use crate::error::ApiResult; use crate::model::auth::{AccessTokenForm, AuthResponse, LoginCredentials, SignupCredentials}; use crate::svc::access_token_svc::AccessTokenService; use crate::svc::auth_svc::AuthService; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode}; use rocket::http::Status; use rocket::request::{FromRequest, Outcome}; use rocket::serde::json::Json; use rocket::serde::{Deserialize, Serialize}; use rocket::{Request, State}; use std::sync::LazyLock; use std::time::{SystemTime, UNIX_EPOCH}; #[post("/signup", data = "")] pub async fn signup( cred: Json, svc: &State, ) -> ApiResult> { let response = svc .signup( &cred.email, &cred.username, &cred.password, &cred.access_token, ) .await?; Ok(Json(response)) } #[post("/login", data = "")] pub async fn login( cred: Json, svc: &State, ) -> ApiResult> { Ok(Json(svc.login(&cred.username, &cred.password).await?)) } #[post("/invite", data = "
")] pub async fn generate_invite( session: AdminSession, form: Json, svc: &State, ) -> ApiResult { svc.create( session.uid, &form.name, form.max_uses, form.start_date, form.expiry_date, ) .await } static JWT_SECRET: LazyLock = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap()); #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)] #[serde(rename_all = "snake_case")] pub enum TokenScope { Full, TotpPending, } pub struct Session { pub uid: i64, } #[rocket::async_trait] impl<'r> FromRequest<'r> for Session { type Error = (); async fn from_request(req: &'r Request<'_>) -> Outcome { match Claims::from_request(req).await { Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session { uid: user.sub as i64, }), Outcome::Success(_) => { eprintln!("warning: user with scope other than Full attempted to access session"); Outcome::Error((Status::Forbidden, ())) } Outcome::Error(err) => { eprintln!("Session request guard failed: {:?}", err); Outcome::Error(err) } _ => unreachable!("forward should never be called"), } } } pub struct AdminSession { pub uid: i64, } #[rocket::async_trait] impl<'r> FromRequest<'r> for AdminSession { type Error = (); async fn from_request(req: &'r Request<'_>) -> Outcome { // First verify the session is valid match Claims::from_request(req).await { Outcome::Success(user) if user.scope == TokenScope::Full => { let uid = user.sub as i64; // Get AuthService from Rocket state let auth_svc = match req.guard::<&State>().await { Outcome::Success(svc) => svc, Outcome::Error(err) => { tracing::error!("AdminSession: Failed to get AuthService from state"); return Outcome::Error(err); } _ => unreachable!("forward should never be called"), }; // Check if user is admin match auth_svc.is_admin(uid).await { Ok(true) => Outcome::Success(AdminSession { uid }), Ok(false) => { tracing::debug!("non-admin user attempted to access admin session"); Outcome::Error((Status::Forbidden, ())) } Err(err) => { tracing::error!("AdminSession: is_admin check failed: {:?}", err); Outcome::Error((Status::InternalServerError, ())) } } } Outcome::Success(_) => { tracing::debug!("warning: user with scope other than Full attempted to access admin session"); Outcome::Error((Status::Forbidden, ())) } Outcome::Error(err) => { tracing::debug!("AdminSession request guard failed: {:?}", err); Outcome::Error(err) } _ => unreachable!("forward should never be called"), } } } #[derive(Serialize, Deserialize)] pub struct Claims { pub sub: i32, pub exp: usize, pub scope: TokenScope, } impl Claims { pub fn new(user_id: usize, scope: TokenScope) -> Self { Self { sub: user_id as i32, exp: (SystemTime::now() .duration_since(UNIX_EPOCH) .expect("Failed to get time") .as_secs() + 60 * 60 * 24 * 7) as usize, scope, } } pub fn encode(&self) -> String { encode( &Header::default(), self, &EncodingKey::from_secret(JWT_SECRET.as_bytes()), ) .expect("unable to encode jwt") } } #[rocket::async_trait] impl<'r> FromRequest<'r> for Claims { type Error = (); async fn from_request(req: &'r Request<'_>) -> Outcome { let token = req .headers() .get_one("Authorization") .and_then(|v| v.strip_prefix("Bearer ")); match token { None => Outcome::Error((Status::Unauthorized, ())), Some(t) => { match decode::( t, &DecodingKey::from_secret(JWT_SECRET.as_bytes()), &Validation::default(), ) { Ok(data) => Outcome::Success(data.claims), Err(_) => Outcome::Error((Status::Unauthorized, ())), } } } } }