refactored auth for JWTs, completed TOTP/2FA implementation, added endpoints to change display name and password
This commit is contained in:
@@ -9,6 +9,7 @@ chrono = { version = "0.4.42", features = ["serde"] }
|
||||
dotenv = "0.15.0"
|
||||
futures-util = "0.3.31"
|
||||
image = "0.25.8"
|
||||
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
|
||||
rand = "0.9.2"
|
||||
redis = { version = "0.25.4", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.23", features = ["json"] }
|
||||
@@ -22,4 +23,5 @@ sha2 = "0.10.9"
|
||||
sqlx = { version = "0.7.4", features = ["macros", "time"] }
|
||||
tokio = { version = "1.47.1", features = ["full"] }
|
||||
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
|
||||
tracing = "0.1.44"
|
||||
uuid = { version = "1.18.1", features = ["v4"] }
|
||||
|
||||
+6
-4
@@ -3,15 +3,17 @@ secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU="
|
||||
address = "0.0.0.0"
|
||||
port = 8000
|
||||
|
||||
[default.databases.postgres_db]
|
||||
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp"
|
||||
[debug.databases.postgres_db]
|
||||
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev"
|
||||
|
||||
[default.databases.redis_cache]
|
||||
url = "redis://chatapp_redis:6379"
|
||||
[release.databases.postgres_db]
|
||||
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_prod"
|
||||
|
||||
[debug.databases.redis_cache]
|
||||
url = "redis://localhost:6379"
|
||||
|
||||
[release.databases.redis_cache]
|
||||
url = "redis://chatapp_redis:6379"
|
||||
|
||||
[default] # run inside a docker container or pod
|
||||
address = "0.0.0.0"
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
-- Add migration script here
|
||||
TRUNCATE TABLE users CASCADE;
|
||||
|
||||
ALTER TABLE users DROP COLUMN password;
|
||||
ALTER TABLE users ADD COLUMN pass_hash VARCHAR(255) NOT NULL;
|
||||
ALTER TABLE users ADD CONSTRAINT users_username_unique UNIQUE (username);
|
||||
@@ -0,0 +1,13 @@
|
||||
-- Add migration script here
|
||||
CREATE TYPE status AS ENUM ('pending', 'accepted', 'blocked');
|
||||
|
||||
CREATE TABLE relationships (
|
||||
id SERIAL PRIMARY KEY,
|
||||
from_user INTEGER NOT NULL REFERENCES users(id),
|
||||
to_user INTEGER NOT NULL REFERENCES users(id),
|
||||
status status NOT NULL DEFAULT 'pending',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
|
||||
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
|
||||
);
|
||||
+69
-40
@@ -1,3 +1,8 @@
|
||||
use argon2::{
|
||||
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
|
||||
password_hash::{SaltString, rand_core::OsRng},
|
||||
};
|
||||
use jsonwebtoken::{EncodingKey, Header, encode};
|
||||
use rocket::{
|
||||
http::{CookieJar, Status},
|
||||
response::{Redirect, status::BadRequest},
|
||||
@@ -9,7 +14,11 @@ use rocket_dyn_templates::{Template, context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{auth::session::Session, db::Postgres};
|
||||
use crate::{
|
||||
auth::session::{Claims, Session, TokenScope},
|
||||
db::Postgres,
|
||||
user::User,
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SignupCredentials {
|
||||
@@ -25,6 +34,12 @@ pub struct LoginCredentials {
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AuthResponse {
|
||||
pub token: String,
|
||||
pub totp_required: bool,
|
||||
}
|
||||
|
||||
#[get("/signup")]
|
||||
pub async fn signup_page() -> Template {
|
||||
Template::render("signup", context!())
|
||||
@@ -35,35 +50,38 @@ pub async fn signup(
|
||||
cred: Json<SignupCredentials>,
|
||||
jar: &CookieJar<'_>,
|
||||
mut db: Connection<Postgres>,
|
||||
) -> Result<Redirect, BadRequest<String>> {
|
||||
println!("phase 1 {}", cred.access_token);
|
||||
let token_id = AccessToken::validate(&cred.access_token, &mut db).await?;
|
||||
) -> Result<Json<AuthResponse>, Status> {
|
||||
let token_id = AccessToken::validate(&cred.access_token, &mut db)
|
||||
.await
|
||||
.map_err(|_| Status::Unauthorized)?;
|
||||
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let hashed = Argon2::default()
|
||||
.hash_password(cred.password.as_bytes(), &salt)
|
||||
.map_err(|_| Status::InternalServerError)?
|
||||
.to_string();
|
||||
|
||||
println!("phase 2");
|
||||
let result = sqlx::query!(
|
||||
"INSERT INTO users (email, username, password) VALUES ($1, $2, $3) RETURNING id",
|
||||
"INSERT INTO users (email, username, pass_hash) VALUES ($1, $2, $3) RETURNING id",
|
||||
cred.email,
|
||||
cred.username,
|
||||
cred.password
|
||||
hashed,
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map_err(|_| BadRequest(String::from("Failed to create user")))?;
|
||||
.map_err(|_| Status::InternalServerError)?;
|
||||
|
||||
println!("phase 3");
|
||||
let session = Session::new(result.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(BadRequest(String::from("Failed to create session")));
|
||||
}
|
||||
let jwt = Claims::new(result.id as usize, TokenScope::Full).encode();
|
||||
|
||||
println!("phase 4");
|
||||
jar.add_private(("session", session.token));
|
||||
token_id
|
||||
.use_token(&mut db)
|
||||
.await
|
||||
.expect("unable to use access code");
|
||||
|
||||
token_id.use_token(&mut db).await?;
|
||||
|
||||
println!("phase 5");
|
||||
Ok(Redirect::to("/chat"))
|
||||
Ok(Json(AuthResponse {
|
||||
token: jwt,
|
||||
totp_required: false,
|
||||
}))
|
||||
}
|
||||
|
||||
#[get("/login")]
|
||||
@@ -74,29 +92,40 @@ pub async fn login_page() -> Template {
|
||||
#[post("/login", data = "<cred>")]
|
||||
pub async fn login(
|
||||
mut db: Connection<Postgres>,
|
||||
jar: &CookieJar<'_>,
|
||||
cred: Json<LoginCredentials>,
|
||||
) -> Result<Redirect, Status> {
|
||||
if let Ok(row) = sqlx::query!(
|
||||
"SELECT id FROM users WHERE username = $1 AND password = $2",
|
||||
) -> Result<Json<AuthResponse>, Status> {
|
||||
println!("e");
|
||||
let row = sqlx::query!(
|
||||
"SELECT id, pass_hash, twofa_enabled FROM users WHERE username = $1",
|
||||
cred.username,
|
||||
cred.password,
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
{
|
||||
let session = Session::new(row.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(Status::InternalServerError);
|
||||
}
|
||||
.map_err(|_| Status::Unauthorized)?;
|
||||
|
||||
jar.add_private(("session", session.token));
|
||||
return Ok(Redirect::to("/chat"));
|
||||
}
|
||||
println!("ok");
|
||||
|
||||
// TODO: implement actual login logic, e.g. verify password and generate token
|
||||
Err(Status::Unauthorized)
|
||||
// verify password as before
|
||||
let parsed_hash = PasswordHash::new(&row.pass_hash).map_err(|_| Status::InternalServerError)?;
|
||||
Argon2::default()
|
||||
.verify_password(cred.password.as_bytes(), &parsed_hash)
|
||||
.map_err(|_| Status::Unauthorized)?;
|
||||
|
||||
println!("ok2");
|
||||
|
||||
let user_id = row.id as usize;
|
||||
|
||||
// issue either a partial or full token depending on 2FA status
|
||||
let (session, totp_required) = if row.twofa_enabled {
|
||||
(Claims::new(user_id, TokenScope::TotpPending), true)
|
||||
} else {
|
||||
(Claims::new(user_id, TokenScope::Full), false)
|
||||
};
|
||||
|
||||
Ok(Json(AuthResponse {
|
||||
token: session.encode(),
|
||||
totp_required,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -149,7 +178,7 @@ impl AccessToken {
|
||||
pub async fn validate(
|
||||
token: &str,
|
||||
db: &mut Connection<Postgres>,
|
||||
) -> Result<AccessToken, BadRequest<String>> {
|
||||
) -> Result<AccessToken, String> {
|
||||
match sqlx::query!(
|
||||
"SELECT id FROM access_codes
|
||||
WHERE code = $1
|
||||
@@ -165,18 +194,18 @@ impl AccessToken {
|
||||
id: row.id,
|
||||
_code: token.to_string(),
|
||||
}),
|
||||
Err(_) => Err(BadRequest(String::from("Invalid or Expired token!"))),
|
||||
Err(_) => Err(String::from("Invalid or Expired token!")),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn use_token(&self, db: &mut Connection<Postgres>) -> Result<(), BadRequest<String>> {
|
||||
pub async fn use_token(&self, db: &mut Connection<Postgres>) -> Result<(), String> {
|
||||
sqlx::query!(
|
||||
"UPDATE access_codes SET uses = uses + 1 WHERE id = $1",
|
||||
self.id
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await
|
||||
.map_err(|_| BadRequest(String::from("Invalid or Expired token!")))?;
|
||||
.map_err(|_| String::from("Invalid or Expired token!"))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
pub mod account;
|
||||
pub mod profile;
|
||||
pub mod session;
|
||||
pub mod two_factor;
|
||||
|
||||
pub use session::Session;
|
||||
|
||||
pub use account::{generate_invite, invite_page, login, login_page, signup, signup_page};
|
||||
pub use two_factor::{confirm_totp, get_totp, mfa_page};
|
||||
pub use profile::{change_display_name, change_password};
|
||||
pub use two_factor::{
|
||||
confirm_totp, disable_totp, get_totp, get_totp_status, mfa_page, verify_totp,
|
||||
};
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
use argon2::{
|
||||
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
|
||||
password_hash::{SaltString, rand_core::OsRng},
|
||||
};
|
||||
use rocket::{http::Status, serde::json::Json};
|
||||
use rocket_db_pools::Connection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{auth::Session, db::Postgres, user::User};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PasswordForm {
|
||||
old_password: String,
|
||||
new_password: String,
|
||||
}
|
||||
|
||||
#[post("/settings/password", data = "<form>")]
|
||||
pub async fn change_password(
|
||||
session: Session,
|
||||
mut db: Connection<Postgres>,
|
||||
form: Json<PasswordForm>,
|
||||
) -> Result<(), Status> {
|
||||
let mut user = User::get_by_id(session.user_id, &mut db)
|
||||
.await
|
||||
.ok_or(Status::NotFound)
|
||||
.inspect_err(|_| {
|
||||
tracing::error!(
|
||||
"Valid session does not have a valid user. ID: {}",
|
||||
session.user_id
|
||||
)
|
||||
})?;
|
||||
|
||||
let parsed_hash = PasswordHash::new(&user.pass_hash)
|
||||
.inspect_err(|e| tracing::error!("Failed to parse hash for password! uid:{} {e}", user.id))
|
||||
.map_err(|_| Status::InternalServerError)?;
|
||||
|
||||
Argon2::default()
|
||||
.verify_password(form.old_password.as_bytes(), &parsed_hash)
|
||||
.map_err(|_| Status::Unauthorized)?;
|
||||
|
||||
// old password is correct, so new one can be set.
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let hashed = Argon2::default()
|
||||
.hash_password(form.new_password.as_bytes(), &salt)
|
||||
.inspect_err(|e| tracing::error!("failed to hash password! {e}"))
|
||||
.map_err(|_| Status::InternalServerError)?
|
||||
.to_string();
|
||||
|
||||
user.set_pass_hash(hashed, &mut db)
|
||||
.await
|
||||
.inspect_err(|e| tracing::error!("{e}"))
|
||||
.map_err(|_| Status::InternalServerError)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct DisplayNameForm {
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
#[post("/settings/display_name", data = "<new>")]
|
||||
pub async fn change_display_name(
|
||||
session: Session,
|
||||
mut db: Connection<Postgres>,
|
||||
new: Json<DisplayNameForm>,
|
||||
) -> Result<(), Status> {
|
||||
let mut user = User::get_by_id(session.user_id, &mut db)
|
||||
.await
|
||||
.ok_or(Status::NotFound)
|
||||
.inspect_err(|_| {
|
||||
tracing::error!(
|
||||
"Valid session does not have a valid user. ID: {}",
|
||||
session.user_id
|
||||
)
|
||||
})?;
|
||||
|
||||
user.set_display_name(new.display_name.clone(), &mut db)
|
||||
.await
|
||||
.inspect_err(|e| tracing::error!("{e}"))
|
||||
.map_err(|_| Status::InternalServerError)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
+85
-55
@@ -1,5 +1,9 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use std::{
|
||||
sync::LazyLock,
|
||||
time::{SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||||
use rand::Rng;
|
||||
use rocket::{
|
||||
Request,
|
||||
@@ -7,73 +11,99 @@ use rocket::{
|
||||
request::{self, FromRequest, Outcome},
|
||||
};
|
||||
use rocket_db_pools::Connection;
|
||||
use sha2::{Digest, Sha256};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256, digest::block_buffer::Lazy};
|
||||
use sqlx::postgres::PgQueryResult;
|
||||
|
||||
use crate::db::Postgres;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Session {
|
||||
pub token: String,
|
||||
pub user_id: usize,
|
||||
static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap());
|
||||
|
||||
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TokenScope {
|
||||
Full,
|
||||
TotpPending,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(user_id: usize) -> Self {
|
||||
let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
|
||||
let random: u32 = rand::rng().random();
|
||||
let token = format!("{}-{}", current_time.as_secs(), random);
|
||||
let hashed = format!("{:x}", Sha256::digest(token.as_bytes()));
|
||||
Self {
|
||||
token: hashed,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn commit(
|
||||
&self,
|
||||
db: &mut Connection<Postgres>,
|
||||
) -> Result<PgQueryResult, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
"INSERT INTO sessions (user_id, token) VALUES ($1, $2)",
|
||||
self.user_id as i32,
|
||||
self.token,
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await
|
||||
}
|
||||
pub struct Session {
|
||||
pub user_id: usize,
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for Session {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
if let Some(c) = request.cookies().get_private("session") {
|
||||
let mut pool = match request.guard::<Connection<Postgres>>().await {
|
||||
Outcome::Success(pool) => pool,
|
||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
let value = c.value();
|
||||
let result = sqlx::query!(
|
||||
"SELECT user_id, token FROM sessions WHERE token = $1 AND expires_at > NOW()",
|
||||
value
|
||||
)
|
||||
.fetch_optional(&mut **pool)
|
||||
.await
|
||||
.expect("query failed!");
|
||||
|
||||
if let Some(session) = result {
|
||||
Outcome::Success(Self {
|
||||
user_id: session.user_id as usize,
|
||||
token: session.token,
|
||||
})
|
||||
} else {
|
||||
Outcome::Error((Status::Unauthorized, ()))
|
||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
match Claims::from_request(req).await {
|
||||
Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session {
|
||||
user_id: user.sub as usize,
|
||||
}),
|
||||
Outcome::Success(_) => {
|
||||
eprintln!("warning: user with scope other than Full attempted to access session");
|
||||
Outcome::Error((Status::Forbidden, ()))
|
||||
}
|
||||
Outcome::Error(err) => {
|
||||
eprintln!("Session request guard failed: {:?}", err);
|
||||
Outcome::Error(err)
|
||||
}
|
||||
_ => unreachable!("forward should never be called"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: i32,
|
||||
pub exp: usize,
|
||||
pub scope: TokenScope,
|
||||
}
|
||||
|
||||
impl Claims {
|
||||
pub fn new(user_id: usize, scope: TokenScope) -> Self {
|
||||
Self {
|
||||
sub: user_id as i32,
|
||||
exp: (SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
+ 3600) as usize,
|
||||
scope,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode(&self) -> String {
|
||||
encode(
|
||||
&Header::default(),
|
||||
self,
|
||||
&EncodingKey::from_secret(JWT_SECRET.as_bytes()),
|
||||
)
|
||||
.expect("unable to encode jwt")
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for Claims {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let token = req
|
||||
.headers()
|
||||
.get_one("Authorization")
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
match token {
|
||||
None => Outcome::Error((Status::Unauthorized, ())),
|
||||
Some(t) => {
|
||||
match decode::<Claims>(
|
||||
t,
|
||||
&DecodingKey::from_secret(JWT_SECRET.as_bytes()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(data) => Outcome::Success(data.claims),
|
||||
Err(_) => Outcome::Error((Status::Unauthorized, ())),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Outcome::Error((Status::Unauthorized, ()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+146
-23
@@ -11,7 +11,13 @@ use rocket_dyn_templates::{Template, context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use totp_rs::{Algorithm, Secret, TOTP};
|
||||
|
||||
use crate::{auth::session::Session, db::Postgres};
|
||||
use crate::{
|
||||
auth::{
|
||||
account::AuthResponse,
|
||||
session::{Claims, Session, TokenScope},
|
||||
},
|
||||
db::Postgres,
|
||||
};
|
||||
|
||||
// Utility methods
|
||||
|
||||
@@ -35,25 +41,23 @@ pub async fn mfa_page(_session: Session) -> Template {
|
||||
Template::render("2fa", context!())
|
||||
}
|
||||
|
||||
// api
|
||||
|
||||
#[post("/totp", data = "<form>")]
|
||||
pub async fn confirm_totp(
|
||||
mfa: TOTPSecret,
|
||||
form: Json<TOTPSixDigitCode>,
|
||||
mut db: Connection<Postgres>,
|
||||
) -> Result<(), status::Custom<&'static str>> {
|
||||
if form.code.len() != 6 && form.code.parse::<usize>().is_err() {
|
||||
if form.code.len() != 6 || form.code.parse::<u32>().is_err() {
|
||||
return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code"));
|
||||
}
|
||||
|
||||
println!("valid");
|
||||
|
||||
let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes()).unwrap();
|
||||
if !totp.check_current(&form.code.to_string()).unwrap() {
|
||||
return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code"));
|
||||
let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes())
|
||||
.map_err(|_| status::Custom(Status::InternalServerError, "TOTP Error"))?;
|
||||
if !totp.check_current(&form.code).unwrap_or(false) {
|
||||
return Err(status::Custom(Status::BadRequest, "Incorrect code"));
|
||||
}
|
||||
|
||||
println!("correct");
|
||||
|
||||
if sqlx::query!(
|
||||
@@ -92,6 +96,13 @@ pub struct TOTPSixDigitCode {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TotpStatus {
|
||||
Enabled,
|
||||
Disabled,
|
||||
}
|
||||
|
||||
pub struct TOTPSecret {
|
||||
user_id: usize,
|
||||
secret: String,
|
||||
@@ -107,37 +118,53 @@ impl<'r> FromRequest<'r> for TOTPSecret {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
let auth_header = request.headers().get_one("Authorization");
|
||||
println!(
|
||||
"TOTPSecret guard - Auth header present: {}",
|
||||
auth_header.is_some()
|
||||
);
|
||||
|
||||
let user = try_outcome!(request.guard::<Claims>().await);
|
||||
println!(
|
||||
"TOTPSecret guard - Claims ok, user: {}, scope: {:?}",
|
||||
user.sub, user.scope
|
||||
);
|
||||
|
||||
// only allow full tokens for TOTP setup
|
||||
if user.scope != TokenScope::Full {
|
||||
println!("TOTPSecret guard - rejected, scope is {:?}", user.scope);
|
||||
return Outcome::Error((Status::Forbidden, ()));
|
||||
}
|
||||
|
||||
let user = try_outcome!(request.guard::<Session>().await);
|
||||
let mut pool = match request.guard::<Connection<Postgres>>().await {
|
||||
Outcome::Success(pool) => pool,
|
||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
let (enabled, mut secret) = match sqlx::query!(
|
||||
let row = sqlx::query!(
|
||||
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
|
||||
user.user_id as i32,
|
||||
user.user_id as i32
|
||||
)
|
||||
.fetch_one(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(row) => (row.twofa_enabled, row.totp_secret),
|
||||
.await;
|
||||
|
||||
let (enabled, mut secret) = match row {
|
||||
Ok(r) => (r.twofa_enabled, r.totp_secret),
|
||||
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
if !enabled || secret.is_none() {
|
||||
secret = Some(Secret::generate_secret().to_string());
|
||||
|
||||
match sqlx::query!(
|
||||
if secret.is_none() {
|
||||
let new_secret = Secret::generate_secret().to_encoded().to_string();
|
||||
sqlx::query!(
|
||||
"UPDATE users SET totp_secret = $1 WHERE id = $2",
|
||||
secret.as_ref().unwrap(),
|
||||
user.user_id as i32,
|
||||
new_secret,
|
||||
user.user_id as i32
|
||||
)
|
||||
.execute(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(_) => return Outcome::Error((Status::InternalServerError, ())),
|
||||
}
|
||||
.ok();
|
||||
secret = Some(new_secret);
|
||||
}
|
||||
|
||||
Outcome::Success(TOTPSecret {
|
||||
@@ -161,3 +188,99 @@ impl TOTPSecret {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct TotpVerifyRequest {
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
#[get("/totp/status")]
|
||||
pub async fn get_totp_status(
|
||||
user: Session,
|
||||
mut db: Connection<Postgres>,
|
||||
) -> Result<Json<TotpStatus>, Status> {
|
||||
Ok(Json(
|
||||
if sqlx::query!(
|
||||
"SELECT twofa_enabled FROM users WHERE id = $1",
|
||||
user.user_id as i32,
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map_err(|_| Status::NotFound)?
|
||||
.twofa_enabled
|
||||
{
|
||||
TotpStatus::Enabled
|
||||
} else {
|
||||
TotpStatus::Disabled
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
#[delete("/totp")]
|
||||
pub async fn disable_totp(
|
||||
user: Session,
|
||||
mut db: Connection<Postgres>,
|
||||
) -> Result<Json<AuthResponse>, Status> {
|
||||
sqlx::query!(
|
||||
"UPDATE users SET twofa_enabled = false, totp_secret = NULL WHERE id = $1",
|
||||
user.user_id as i32,
|
||||
)
|
||||
.execute(&mut **db)
|
||||
.await
|
||||
.map_err(|_| Status::NotFound)?;
|
||||
|
||||
Ok(Json(AuthResponse {
|
||||
token: Claims::new(user.user_id, TokenScope::Full).encode(),
|
||||
totp_required: false,
|
||||
}))
|
||||
}
|
||||
|
||||
#[post("/totp/verify", data = "<body>")]
|
||||
pub async fn verify_totp(
|
||||
user: Claims, // request guard checks token validity
|
||||
mut db: Connection<Postgres>,
|
||||
body: Json<TotpVerifyRequest>,
|
||||
) -> Result<Json<AuthResponse>, Status> {
|
||||
println!("reached 1");
|
||||
|
||||
// reject if they somehow got here with a full token
|
||||
if user.scope != TokenScope::TotpPending {
|
||||
return Err(Status::Forbidden);
|
||||
}
|
||||
|
||||
println!("reached 2");
|
||||
|
||||
let row = sqlx::query!(
|
||||
"SELECT totp_secret FROM users WHERE id = $1 AND twofa_enabled = TRUE",
|
||||
user.sub
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map_err(|_| Status::Unauthorized)?;
|
||||
|
||||
println!("reached 3");
|
||||
|
||||
let totp = totp_gen(
|
||||
user.sub as usize,
|
||||
row.totp_secret
|
||||
.expect("user with 2fa enabled has no totp secret")
|
||||
.as_bytes(),
|
||||
)
|
||||
.map_err(|_| Status::InternalServerError)?;
|
||||
|
||||
if !totp
|
||||
.check_current(&body.code)
|
||||
.map_err(|_| Status::InternalServerError)?
|
||||
{
|
||||
return Err(Status::Unauthorized);
|
||||
}
|
||||
|
||||
println!("reached 5");
|
||||
|
||||
let claims = Claims::new(user.sub as usize, TokenScope::Full);
|
||||
|
||||
Ok(Json(AuthResponse {
|
||||
token: claims.encode(),
|
||||
totp_required: false,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -73,6 +73,11 @@ fn rocket() -> Rocket<Build> {
|
||||
auth::get_totp,
|
||||
auth::confirm_totp,
|
||||
auth::generate_invite,
|
||||
auth::verify_totp,
|
||||
auth::disable_totp,
|
||||
auth::get_totp_status,
|
||||
auth::change_password,
|
||||
auth::change_display_name
|
||||
],
|
||||
)
|
||||
.register(
|
||||
|
||||
@@ -96,6 +96,8 @@ pub async fn post_message(
|
||||
.await
|
||||
.map_err(|_| "Failed".to_string())?;
|
||||
|
||||
println!("gisfujdeghnjuisdfjngiosdfgjkosdf gnojdfsg nmodfsg");
|
||||
|
||||
if let Some(ref mut cache) = cache {
|
||||
messenger::cache::insert(cache, channel_id, &msg)
|
||||
.await
|
||||
|
||||
+58
-1
@@ -1,5 +1,5 @@
|
||||
use redis::AsyncCommands;
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::{serde::json::Json, time::OffsetDateTime};
|
||||
use rocket_db_pools::Connection;
|
||||
|
||||
use crate::{
|
||||
@@ -7,6 +7,63 @@ use crate::{
|
||||
db::{Postgres, Redis},
|
||||
};
|
||||
|
||||
pub struct User {
|
||||
pub id: i32,
|
||||
pub email: Option<String>,
|
||||
pub username: String,
|
||||
pub display_name: Option<String>,
|
||||
pub pass_hash: String,
|
||||
pub twofa_enabled: bool,
|
||||
pub totp_secret: Option<String>,
|
||||
pub created_at: Option<OffsetDateTime>,
|
||||
pub updated_at: Option<OffsetDateTime>,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub async fn get_by_id(id: usize, db: &mut Connection<Postgres>) -> Option<Self> {
|
||||
sqlx::query_as!(
|
||||
Self,
|
||||
"SELECT id, email, username, display_name, pass_hash, twofa_enabled, totp_secret, created_at, updated_at FROM users WHERE id = $1",
|
||||
id as i32
|
||||
)
|
||||
.fetch_optional(&mut ***db)
|
||||
.await
|
||||
.unwrap_or(None)
|
||||
}
|
||||
|
||||
pub async fn set_display_name(
|
||||
&mut self,
|
||||
display_name: Option<String>,
|
||||
db: &mut Connection<Postgres>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
self.display_name = display_name;
|
||||
sqlx::query!(
|
||||
"UPDATE users SET display_name = $1 WHERE id = $2",
|
||||
self.display_name,
|
||||
self.id
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_pass_hash(
|
||||
&mut self,
|
||||
pass_hash: String,
|
||||
db: &mut Connection<Postgres>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
self.pass_hash = pass_hash;
|
||||
sqlx::query!(
|
||||
"UPDATE users SET pass_hash = $1 WHERE id = $2",
|
||||
self.pass_hash,
|
||||
self.id
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/users", rank = 2)]
|
||||
pub async fn users(_ag: Session, mut db: Connection<Postgres>) -> Json<Vec<i32>> {
|
||||
sqlx::query!("SELECT id FROM users")
|
||||
|
||||
@@ -124,9 +124,9 @@
|
||||
successMessage.classList.add("show");
|
||||
submitButton.innerHTML = "Logged in!!";
|
||||
|
||||
setTimeout(() => {
|
||||
window.location.replace('/chat');
|
||||
}, 1000);
|
||||
// setTimeout(() => {
|
||||
// window.location.replace('/chat');
|
||||
// }, 1000);
|
||||
} else {
|
||||
const error = await response.text();
|
||||
throw new Error(error || "Login failed");
|
||||
|
||||
Reference in New Issue
Block a user