megacommit
This commit is contained in:
@@ -1,282 +0,0 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use rand::Rng;
|
||||
use rocket::{
|
||||
Request,
|
||||
http::{CookieJar, Status},
|
||||
outcome::{Outcome, try_outcome},
|
||||
post,
|
||||
request::{self, FromRequest},
|
||||
response::{Redirect, status},
|
||||
serde::json::Json,
|
||||
};
|
||||
use rocket_db_pools::{
|
||||
Connection,
|
||||
sqlx::{self},
|
||||
};
|
||||
use rocket_dyn_templates::{Template, context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::postgres::PgQueryResult;
|
||||
use totp_rs::{Algorithm, Secret, TOTP};
|
||||
|
||||
use crate::db::DbConn;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct UserCredentials {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[get("/signup")]
|
||||
pub async fn signup_page() -> Template {
|
||||
Template::render("signup", context!())
|
||||
}
|
||||
|
||||
#[post("/signup", data = "<cred>")]
|
||||
pub async fn signup(
|
||||
cred: Json<UserCredentials>,
|
||||
jar: &CookieJar<'_>,
|
||||
mut db: Connection<DbConn>,
|
||||
) -> Result<Redirect, Status> {
|
||||
let result = sqlx::query!(
|
||||
"INSERT INTO users (username, password) VALUES ($1, $2) RETURNING id",
|
||||
cred.username,
|
||||
cred.password
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map_err(|e| Status::InternalServerError)?;
|
||||
|
||||
let session = Session::new(result.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(Status::InternalServerError);
|
||||
}
|
||||
|
||||
jar.add_private(("session", session.token));
|
||||
|
||||
println!("Signup successful");
|
||||
return Ok(Redirect::to("/chat"));
|
||||
}
|
||||
|
||||
#[get("/login")]
|
||||
pub async fn login_page() -> Template {
|
||||
Template::render("login", context!())
|
||||
}
|
||||
|
||||
#[post("/login", data = "<cred>")]
|
||||
pub async fn login(
|
||||
mut db: Connection<DbConn>,
|
||||
jar: &CookieJar<'_>,
|
||||
cred: Json<UserCredentials>,
|
||||
) -> Result<Redirect, Status> {
|
||||
if let Ok(row) = sqlx::query!(
|
||||
"SELECT id FROM users WHERE username = $1 AND password = $2",
|
||||
cred.username,
|
||||
cred.password,
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
{
|
||||
let session = Session::new(row.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(Status::InternalServerError);
|
||||
}
|
||||
|
||||
jar.add_private(("session", session.token));
|
||||
return Ok(Redirect::to("/chat"));
|
||||
}
|
||||
|
||||
// TODO: implement actual login logic, e.g. verify password and generate token
|
||||
Err(Status::Unauthorized)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Session {
|
||||
pub token: String,
|
||||
pub user_id: usize,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(user_id: usize) -> Self {
|
||||
let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
|
||||
let random: u32 = rand::rng().random();
|
||||
let token = format!("{}-{}", current_time.as_secs(), random);
|
||||
let hashed = format!("{:x}", Sha256::digest(token.as_bytes()));
|
||||
Self {
|
||||
token: hashed,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn commit(&self, db: &mut Connection<DbConn>) -> Result<PgQueryResult, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
"INSERT INTO sessions (user_id, token) VALUES ($1, $2)",
|
||||
self.user_id as i32,
|
||||
self.token,
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for Session {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
if let Some(c) = request.cookies().get_private("session") {
|
||||
let mut pool = match request.guard::<Connection<DbConn>>().await {
|
||||
Outcome::Success(pool) => pool,
|
||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
let value = c.value();
|
||||
let result = sqlx::query!(
|
||||
"SELECT user_id, token FROM sessions WHERE token = $1 AND expires_at > NOW()",
|
||||
value
|
||||
)
|
||||
.fetch_optional(&mut **pool)
|
||||
.await
|
||||
.expect("query failed!");
|
||||
|
||||
if let Some(session) = result {
|
||||
Outcome::Success(Self {
|
||||
user_id: session.user_id as usize,
|
||||
token: session.token,
|
||||
})
|
||||
} else {
|
||||
Outcome::Error((Status::Unauthorized, ()))
|
||||
}
|
||||
} else {
|
||||
Outcome::Error((Status::Unauthorized, ()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------- TOTP 2FA Auth -----------------------
|
||||
|
||||
#[get("/totp")]
|
||||
pub async fn mfa_page(_session: Session) -> Template {
|
||||
Template::render("2fa", context!())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Totp {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[post("/totp", data = "<totp>")]
|
||||
pub async fn confirm_totp(
|
||||
mfa: MultiFactorEnabled,
|
||||
totp: Json<Totp>,
|
||||
mut db: Connection<DbConn>,
|
||||
) -> Status {
|
||||
if totp.code.len() == 6
|
||||
&& let Ok(code) = totp.code.parse::<usize>()
|
||||
{
|
||||
let secret = match sqlx::query!(
|
||||
"SELECT totp_secret FROM users WHERE id = $1",
|
||||
mfa.user_id as i32
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
{
|
||||
Err(_) => return Status::InternalServerError,
|
||||
Ok(user) => user.totp_secret,
|
||||
};
|
||||
}
|
||||
|
||||
return Status::BadRequest;
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct QrResponse {
|
||||
qr_code: String,
|
||||
}
|
||||
|
||||
#[get("/totp.jpg")]
|
||||
pub async fn get_totp(mfa: MultiFactorEnabled) -> Option<Json<QrResponse>> {
|
||||
let totp = TOTP::new(
|
||||
Algorithm::SHA1,
|
||||
6,
|
||||
1,
|
||||
30,
|
||||
mfa.secret.as_bytes().into(),
|
||||
Some("chat.zxq5.dev".to_string()),
|
||||
format!("{}", mfa.user_id),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let qr = totp.get_qr_base64().unwrap();
|
||||
let data_uri = format!("data:image/png;base64,{}", qr);
|
||||
|
||||
Some(Json(QrResponse { qr_code: data_uri }))
|
||||
}
|
||||
|
||||
pub struct MultiFactorEnabled {
|
||||
user_id: usize,
|
||||
secret: String,
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for MultiFactorEnabled {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
let user = try_outcome!(request.guard::<Session>().await);
|
||||
let mut pool = match request.guard::<Connection<DbConn>>().await {
|
||||
Outcome::Success(pool) => pool,
|
||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
let (enabled, mut secret) = match sqlx::query!(
|
||||
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
|
||||
user.user_id as i32,
|
||||
)
|
||||
.fetch_one(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(row) => (row.twofa_enabled, row.totp_secret),
|
||||
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
if !enabled || secret.is_none() {
|
||||
secret = Some(Secret::generate_secret().to_string());
|
||||
|
||||
match sqlx::query!(
|
||||
"UPDATE users SET totp_secret = $1 WHERE id = $2",
|
||||
secret.as_ref().unwrap(),
|
||||
user.user_id as i32,
|
||||
)
|
||||
.execute(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(_) => return Outcome::Error((Status::InternalServerError, ())),
|
||||
}
|
||||
}
|
||||
|
||||
Outcome::Success(MultiFactorEnabled {
|
||||
user_id: user.user_id,
|
||||
secret: secret.unwrap(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiFactorEnabled {
|
||||
pub async fn enable(&self, db: &mut Connection<DbConn>) -> Result<(), ()> {
|
||||
match sqlx::query!(
|
||||
"UPDATE users SET twofa_enabled = true WHERE id = $1",
|
||||
self.user_id as i32,
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
use rocket::{
|
||||
http::{CookieJar, Status},
|
||||
response::{Redirect, status::BadRequest},
|
||||
serde::json::Json,
|
||||
time::{OffsetDateTime, PrimitiveDateTime},
|
||||
};
|
||||
use rocket_db_pools::Connection;
|
||||
use rocket_dyn_templates::{Template, context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{auth::session::Session, db::DbConn};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SignupCredentials {
|
||||
pub email: String,
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub access_token: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct LoginCredentials {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[get("/signup")]
|
||||
pub async fn signup_page() -> Template {
|
||||
Template::render("signup", context!())
|
||||
}
|
||||
|
||||
#[post("/signup", data = "<cred>")]
|
||||
pub async fn signup(
|
||||
cred: Json<SignupCredentials>,
|
||||
jar: &CookieJar<'_>,
|
||||
mut db: Connection<DbConn>,
|
||||
) -> Result<Redirect, BadRequest<String>> {
|
||||
println!("phase 1 {}", cred.access_token);
|
||||
let token_id = AccessToken::validate(&cred.access_token, &mut db).await?;
|
||||
|
||||
println!("phase 2");
|
||||
let result = sqlx::query!(
|
||||
"INSERT INTO users (email, username, password) VALUES ($1, $2, $3) RETURNING id",
|
||||
cred.email,
|
||||
cred.username,
|
||||
cred.password
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map_err(|e| BadRequest(String::from("Failed to create user")))?;
|
||||
|
||||
println!("phase 3");
|
||||
let session = Session::new(result.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(BadRequest(String::from("Failed to create session")));
|
||||
}
|
||||
|
||||
println!("phase 4");
|
||||
jar.add_private(("session", session.token));
|
||||
|
||||
token_id.use_token(&mut db).await?;
|
||||
|
||||
println!("phase 5");
|
||||
return Ok(Redirect::to("/chat"));
|
||||
}
|
||||
|
||||
#[get("/login")]
|
||||
pub async fn login_page() -> Template {
|
||||
Template::render("login", context!())
|
||||
}
|
||||
|
||||
#[post("/login", data = "<cred>")]
|
||||
pub async fn login(
|
||||
mut db: Connection<DbConn>,
|
||||
jar: &CookieJar<'_>,
|
||||
cred: Json<LoginCredentials>,
|
||||
) -> Result<Redirect, Status> {
|
||||
if let Ok(row) = sqlx::query!(
|
||||
"SELECT id FROM users WHERE username = $1 AND password = $2",
|
||||
cred.username,
|
||||
cred.password,
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
{
|
||||
let session = Session::new(row.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(Status::InternalServerError);
|
||||
}
|
||||
|
||||
jar.add_private(("session", session.token));
|
||||
return Ok(Redirect::to("/chat"));
|
||||
}
|
||||
|
||||
// TODO: implement actual login logic, e.g. verify password and generate token
|
||||
Err(Status::Unauthorized)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccessTokenForm {
|
||||
pub name: String,
|
||||
pub max_uses: usize,
|
||||
pub expiry_date: usize,
|
||||
pub start_date: usize,
|
||||
}
|
||||
|
||||
#[get("/invite")]
|
||||
pub async fn invite_page(s: Session) -> Template {
|
||||
Template::render("invite", context! {})
|
||||
}
|
||||
|
||||
#[post("/invite", data = "<form>")]
|
||||
pub async fn generate_invite(
|
||||
session: Session,
|
||||
mut db: Connection<DbConn>,
|
||||
form: Json<AccessTokenForm>,
|
||||
) -> Result<String, Status> {
|
||||
if form.start_date > form.expiry_date {
|
||||
return Err(Status::BadRequest);
|
||||
}
|
||||
|
||||
let code = Uuid::new_v4().to_string();
|
||||
let row = sqlx::query!(
|
||||
"INSERT INTO access_codes (name, code, creator_id, max_uses, created_at, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
|
||||
form.name,
|
||||
code,
|
||||
session.user_id as i32,
|
||||
form.max_uses as i32,
|
||||
OffsetDateTime::from_unix_timestamp_nanos(form.start_date as i128 * 1_000_000).unwrap(),
|
||||
OffsetDateTime::from_unix_timestamp_nanos(form.expiry_date as i128 * 1_000_000).unwrap()
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map_err(|_| Status::InternalServerError)?;
|
||||
|
||||
Ok(code)
|
||||
}
|
||||
|
||||
pub struct AccessToken {
|
||||
id: i32,
|
||||
code: String,
|
||||
}
|
||||
|
||||
impl AccessToken {
|
||||
pub async fn validate(
|
||||
token: &str,
|
||||
db: &mut Connection<DbConn>,
|
||||
) -> Result<AccessToken, BadRequest<String>> {
|
||||
match sqlx::query!(
|
||||
"SELECT id FROM access_codes
|
||||
WHERE code = $1
|
||||
AND created_at < NOW()
|
||||
AND expires_at > NOW()
|
||||
AND uses < max_uses",
|
||||
token
|
||||
)
|
||||
.fetch_one(&mut ***db)
|
||||
.await
|
||||
{
|
||||
Ok(row) => Ok(AccessToken {
|
||||
id: row.id,
|
||||
code: token.to_string(),
|
||||
}),
|
||||
Err(_) => Err(BadRequest(String::from("Invalid or Expired token!"))),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn use_token(&self, db: &mut Connection<DbConn>) -> Result<(), BadRequest<String>> {
|
||||
sqlx::query!(
|
||||
"UPDATE access_codes SET uses = uses + 1 WHERE id = $1",
|
||||
self.id
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await
|
||||
.map_err(|_| BadRequest(String::from("Invalid or Expired token!")))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
pub mod account;
|
||||
pub mod session;
|
||||
pub mod two_factor;
|
||||
|
||||
pub use session::Session;
|
||||
|
||||
pub use account::{generate_invite, invite_page, login, login_page, signup, signup_page};
|
||||
pub use two_factor::{confirm_totp, get_totp, mfa_page};
|
||||
@@ -0,0 +1,76 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use rand::Rng;
|
||||
use rocket::{
|
||||
Request,
|
||||
http::Status,
|
||||
request::{self, FromRequest, Outcome},
|
||||
};
|
||||
use rocket_db_pools::Connection;
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::postgres::PgQueryResult;
|
||||
|
||||
use crate::db::DbConn;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Session {
|
||||
pub token: String,
|
||||
pub user_id: usize,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(user_id: usize) -> Self {
|
||||
let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
|
||||
let random: u32 = rand::rng().random();
|
||||
let token = format!("{}-{}", current_time.as_secs(), random);
|
||||
let hashed = format!("{:x}", Sha256::digest(token.as_bytes()));
|
||||
Self {
|
||||
token: hashed,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn commit(&self, db: &mut Connection<DbConn>) -> Result<PgQueryResult, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
"INSERT INTO sessions (user_id, token) VALUES ($1, $2)",
|
||||
self.user_id as i32,
|
||||
self.token,
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for Session {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
if let Some(c) = request.cookies().get_private("session") {
|
||||
let mut pool = match request.guard::<Connection<DbConn>>().await {
|
||||
Outcome::Success(pool) => pool,
|
||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
let value = c.value();
|
||||
let result = sqlx::query!(
|
||||
"SELECT user_id, token FROM sessions WHERE token = $1 AND expires_at > NOW()",
|
||||
value
|
||||
)
|
||||
.fetch_optional(&mut **pool)
|
||||
.await
|
||||
.expect("query failed!");
|
||||
|
||||
if let Some(session) = result {
|
||||
Outcome::Success(Self {
|
||||
user_id: session.user_id as usize,
|
||||
token: session.token,
|
||||
})
|
||||
} else {
|
||||
Outcome::Error((Status::Unauthorized, ()))
|
||||
}
|
||||
} else {
|
||||
Outcome::Error((Status::Unauthorized, ()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
use rocket::{
|
||||
Request,
|
||||
http::Status,
|
||||
outcome::{Outcome, try_outcome},
|
||||
request::{self, FromRequest},
|
||||
serde::json::Json,
|
||||
};
|
||||
use rocket_db_pools::Connection;
|
||||
use rocket_dyn_templates::{Template, context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use totp_rs::{Algorithm, Secret, TOTP};
|
||||
|
||||
use crate::{auth::session::Session, db::DbConn};
|
||||
|
||||
// Utility methods
|
||||
|
||||
pub fn totp_gen(user_id: usize, secret: &[u8]) -> Result<TOTP, String> {
|
||||
TOTP::new(
|
||||
Algorithm::SHA1,
|
||||
6,
|
||||
1,
|
||||
30,
|
||||
secret.to_owned(),
|
||||
Some("chat.zxq5.dev".to_string()),
|
||||
format!("{}", user_id),
|
||||
)
|
||||
.map_err(|_| String::from("Invalid Secret"))
|
||||
}
|
||||
|
||||
// pages
|
||||
|
||||
#[get("/totp")]
|
||||
pub async fn mfa_page(_session: Session) -> Template {
|
||||
Template::render("2fa", context!())
|
||||
}
|
||||
|
||||
// api
|
||||
|
||||
#[post("/totp", data = "<totp>")]
|
||||
pub async fn confirm_totp(
|
||||
mfa: TOTPSecret,
|
||||
totp: Json<TOTPSixDigitCode>,
|
||||
mut db: Connection<DbConn>,
|
||||
) -> Status {
|
||||
if totp.code.len() == 6
|
||||
&& let Ok(code) = totp.code.parse::<usize>()
|
||||
{
|
||||
let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes()).unwrap();
|
||||
|
||||
println!("input: {} {}", code, totp.generate_current().unwrap());
|
||||
|
||||
if totp.check_current(&format!("{}", mfa.user_id)).unwrap() {
|
||||
if sqlx::query!(
|
||||
"UPDATE users SET twofa_enabled = true WHERE id = $1",
|
||||
mfa.user_id as i32
|
||||
)
|
||||
.execute(&mut **db)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Status::InternalServerError;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
println!("ok!");
|
||||
|
||||
return Status::Ok;
|
||||
}
|
||||
|
||||
#[get("/totp.jpg")]
|
||||
pub async fn get_totp(mfa: TOTPSecret) -> Option<Json<QrResponse>> {
|
||||
let qr_b64 = totp_gen(mfa.user_id, mfa.secret.as_bytes())
|
||||
.expect("Invalid TOTP")
|
||||
.get_qr_base64()
|
||||
.unwrap();
|
||||
|
||||
Some(Json(QrResponse {
|
||||
qr_code: format!("data:image/png;base64,{}", qr_b64),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct TOTPSixDigitCode {
|
||||
code: String,
|
||||
}
|
||||
|
||||
pub struct TOTPSecret {
|
||||
user_id: usize,
|
||||
secret: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct QrResponse {
|
||||
qr_code: String,
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for TOTPSecret {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
let user = try_outcome!(request.guard::<Session>().await);
|
||||
let mut pool = match request.guard::<Connection<DbConn>>().await {
|
||||
Outcome::Success(pool) => pool,
|
||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
let (enabled, mut secret) = match sqlx::query!(
|
||||
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
|
||||
user.user_id as i32,
|
||||
)
|
||||
.fetch_one(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(row) => (row.twofa_enabled, row.totp_secret),
|
||||
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
if !enabled || secret.is_none() {
|
||||
secret = Some(Secret::generate_secret().to_string());
|
||||
|
||||
match sqlx::query!(
|
||||
"UPDATE users SET totp_secret = $1 WHERE id = $2",
|
||||
secret.as_ref().unwrap(),
|
||||
user.user_id as i32,
|
||||
)
|
||||
.execute(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(_) => return Outcome::Error((Status::InternalServerError, ())),
|
||||
}
|
||||
}
|
||||
|
||||
Outcome::Success(TOTPSecret {
|
||||
user_id: user.user_id,
|
||||
secret: secret.unwrap(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TOTPSecret {
|
||||
pub async fn enable(&self, db: &mut Connection<DbConn>) -> Result<(), ()> {
|
||||
match sqlx::query!(
|
||||
"UPDATE users SET twofa_enabled = true WHERE id = $1",
|
||||
self.user_id as i32,
|
||||
)
|
||||
.execute(&mut ***db)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
use rocket::{Request, http::Status};
|
||||
use rocket_dyn_templates::{Template, context};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorContext {
|
||||
error_code: u16,
|
||||
error_message: &'static str,
|
||||
additional_info: &'static str,
|
||||
redirect: Option<RedirectContext>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RedirectContext {
|
||||
url: &'static str,
|
||||
message: &'static str,
|
||||
}
|
||||
|
||||
#[catch(404)]
|
||||
pub async fn handle_404() -> Template {
|
||||
Template::render(
|
||||
"error",
|
||||
ErrorContext {
|
||||
error_code: 404,
|
||||
error_message: "Not Found",
|
||||
additional_info: "There's nothing here.",
|
||||
redirect: Some(RedirectContext {
|
||||
url: "/",
|
||||
message: "Home",
|
||||
}),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[catch(401)]
|
||||
pub async fn handle_401() -> Template {
|
||||
Template::render(
|
||||
"error",
|
||||
ErrorContext {
|
||||
error_code: 401,
|
||||
error_message: "Unauthorised",
|
||||
additional_info: "You are not authorised to access this resource.",
|
||||
redirect: Some(RedirectContext {
|
||||
url: "/login",
|
||||
message: "Login",
|
||||
}),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[catch(default)]
|
||||
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
|
||||
Template::render(
|
||||
"error",
|
||||
ErrorContext {
|
||||
error_code: status.code,
|
||||
error_message: "Unknown Error",
|
||||
additional_info: "I don't know what to do with this error.",
|
||||
redirect: None,
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -18,6 +18,7 @@ use crate::messages::ChatBroadcaster;
|
||||
pub mod auth;
|
||||
pub mod cdn;
|
||||
pub mod db;
|
||||
pub mod handlers;
|
||||
pub mod llm;
|
||||
pub mod messages;
|
||||
|
||||
@@ -71,9 +72,11 @@ fn rocket() -> Rocket<Build> {
|
||||
routes![
|
||||
favicon,
|
||||
messages::chat_page,
|
||||
messages::chat_page_preview,
|
||||
auth::signup_page,
|
||||
auth::login_page,
|
||||
auth::mfa_page,
|
||||
auth::invite_page,
|
||||
],
|
||||
)
|
||||
.mount(
|
||||
@@ -87,6 +90,16 @@ fn rocket() -> Rocket<Build> {
|
||||
auth::signup,
|
||||
auth::login,
|
||||
auth::get_totp,
|
||||
auth::confirm_totp,
|
||||
auth::generate_invite,
|
||||
],
|
||||
)
|
||||
.register(
|
||||
"/",
|
||||
catchers![
|
||||
handlers::handle_404,
|
||||
handlers::handle_401,
|
||||
handlers::handle_default
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
@@ -161,3 +161,8 @@ pub async fn event_stream(
|
||||
pub async fn chat_page(session: Session) -> Template {
|
||||
Template::render("chat", context!(user_id: session.user_id))
|
||||
}
|
||||
|
||||
#[get("/chatpreview")]
|
||||
pub async fn chat_page_preview(session: Session) -> Template {
|
||||
Template::render("chatpreview", context!(user_id: session.user_id))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user