more progress on TOTP/2FA

This commit is contained in:
2025-10-10 01:45:02 +01:00
parent b13cb5086a
commit 4a6c3bc49c
12 changed files with 189 additions and 197 deletions
+74 -32
View File
@@ -5,9 +5,10 @@ use rocket::{
Request,
fs::NamedFile,
http::{CookieJar, Status},
outcome::Outcome,
outcome::{Outcome, try_outcome},
post,
request::{self, FromRequest},
response::Redirect,
serde::json::Json,
};
use rocket_db_pools::{
@@ -38,7 +39,7 @@ pub async fn signup(
cred: Json<UserCredentials>,
jar: &CookieJar<'_>,
mut db: Connection<DbConn>,
) -> Result<Json<String>, String> {
) -> Result<Redirect, Status> {
let result = sqlx::query!(
"INSERT INTO users (username, password) VALUES ($1, $2) RETURNING id",
cred.username,
@@ -46,18 +47,18 @@ pub async fn signup(
)
.fetch_one(&mut **db)
.await
.map_err(|e| e.to_string())?;
.map_err(|e| Status::InternalServerError)?;
let session = Session::new(result.id as usize);
if let Err(e) = session.commit(&mut db).await {
eprintln!("Failed to create session: {}", e);
return Err(e.to_string());
return Err(Status::InternalServerError);
}
jar.add_private(("session", session.token));
println!("Signup successful");
Ok(Json("Signup successful".to_string()))
return Ok(Redirect::to("/chat"));
}
#[get("/login")]
@@ -70,7 +71,7 @@ pub async fn login(
mut db: Connection<DbConn>,
jar: &CookieJar<'_>,
cred: Json<UserCredentials>,
) -> Result<Json<String>, String> {
) -> Result<Redirect, Status> {
if let Ok(row) = sqlx::query!(
"SELECT id FROM users WHERE username = $1 AND password = $2",
cred.username,
@@ -82,41 +83,47 @@ pub async fn login(
let session = Session::new(row.id as usize);
if let Err(e) = session.commit(&mut db).await {
eprintln!("Failed to create session: {}", e);
return Err(e.to_string());
return Err(Status::InternalServerError);
}
jar.add_private(("session", session.token));
return Ok(Json("Signup successful".to_string()));
return Ok(Redirect::to("/chat"));
}
// TODO: implement actual login logic, e.g. verify password and generate token
Err("login failed".to_string())
Err(Status::Unauthorized)
}
#[get("/totp")]
pub async fn mfa_page(session: Session) -> Template {
pub async fn mfa_page(_session: Session) -> Template {
Template::render("2fa", context!())
}
#[get("/api/totp.jpg")]
pub async fn get_totp(s: Session) -> Option<QrCodeImage> {
#[derive(Serialize)]
pub struct QrResponse {
qr_code: String,
}
#[get("/totp.jpg")]
pub async fn get_totp(totp: TOTPCode) -> Option<Json<QrResponse>> {
let totp = TOTP::new(
Algorithm::SHA1,
6,
1,
30,
Secret::generate_secret().to_bytes().unwrap(),
Some("Github".to_string()),
format!("{}", s.user_id),
totp.secret.as_bytes().into(),
Some("chat.zxq5.dev".to_string()),
format!("{}", totp.user_id),
)
.unwrap();
let qr = totp.get_qr_base64().unwrap();
let data_uri = format!("data:image/png;base64,{}", qr);
Some(QrCodeImage(qr.into()))
Some(Json(QrResponse { qr_code: data_uri }))
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Session {
pub token: String,
pub user_id: usize,
@@ -145,6 +152,56 @@ impl Session {
}
}
pub struct TOTPCode {
user_id: usize,
secret: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for TOTPCode {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let user = try_outcome!(request.guard::<Session>().await);
let mut pool = match request.guard::<Connection<DbConn>>().await {
Outcome::Success(pool) => pool,
_ => return Outcome::Error((Status::Unauthorized, ())),
};
let (enabled, mut secret) = match sqlx::query!(
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
user.user_id as i32,
)
.fetch_one(&mut **pool)
.await
{
Ok(row) => (row.twofa_enabled, row.totp_secret),
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
};
if !enabled || secret.is_none() {
secret = Some(Secret::generate_secret().to_string());
match sqlx::query!(
"UPDATE users SET totp_secret = $1, twofa_enabled = true WHERE id = $2",
secret.as_ref().unwrap(),
user.user_id as i32,
)
.execute(&mut **pool)
.await
{
Ok(_) => (),
Err(_) => return Outcome::Error((Status::InternalServerError, ())),
}
}
Outcome::Success(TOTPCode {
user_id: user.user_id,
secret: secret.unwrap(),
})
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Session {
type Error = ();
@@ -178,18 +235,3 @@ impl<'r> FromRequest<'r> for Session {
}
}
}
use rocket::http::ContentType;
use rocket::response::{self, Responder, Response};
use std::io::Cursor;
pub struct QrCodeImage(Vec<u8>);
impl<'r> Responder<'r, 'static> for QrCodeImage {
fn respond_to(self, _: &'r rocket::Request<'_>) -> response::Result<'static> {
Response::build()
.header(ContentType::PNG)
.sized_body(self.0.len(), Cursor::new(self.0))
.ok()
}
}
+2 -2
View File
@@ -60,10 +60,10 @@ impl LlmWorker {
.map_err(|_| String::from("Failed to make request to LLM server"))?;
Ok(ChatMsg {
display_name: message.display_name.clone(),
display_name: Some(String::from("lmstudio")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now().timestamp() as usize,
timestamp: chrono::Utc::now().timestamp_millis() as usize,
})
}
}
+10 -7
View File
@@ -34,12 +34,15 @@ async fn users(_ag: Session, mut db: Connection<DbConn>) -> Json<Vec<i32>> {
}
#[get("/users/<id>", rank = 1)]
async fn username_for_id(id: usize, _ag: Session, mut db: Connection<DbConn>) -> String {
sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
.fetch_one(&mut **db)
.await
.map(|row| row.username)
.unwrap_or_else(|_| "User not found".to_string())
async fn display_name(id: usize, _ag: Session, mut db: Connection<DbConn>) -> String {
sqlx::query!(
"SELECT display_name, username FROM users WHERE id = $1",
id as i32
)
.fetch_one(&mut **db)
.await
.map(|row| row.display_name.unwrap_or(row.username))
.unwrap_or_else(|_| "User not found".to_string())
}
#[launch]
@@ -80,7 +83,7 @@ fn rocket() -> Rocket<Build> {
messages::post_message,
messages::event_stream,
users,
username_for_id,
display_name,
auth::signup,
auth::login,
auth::get_totp,
+35 -17
View File
@@ -1,6 +1,7 @@
use std::sync::Arc;
use rocket::{
Shutdown,
response::stream::{Event, EventStream},
serde::json::Json,
time::OffsetDateTime,
@@ -9,9 +10,9 @@ use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use sqlx::prelude::FromRow;
use tokio::sync::broadcast;
use tokio::{select, sync::broadcast};
use crate::{auth::Session, db::DbConn, llm::LlmWorker};
use crate::{auth::Session, db::DbConn, display_name, llm::LlmWorker};
/// ---------- shared broadcaster ----------
pub struct ChatBroadcaster {
@@ -54,14 +55,25 @@ pub async fn post_message(
let chat = chat.inner().clone();
let display_name = sqlx::query!(
"SELECT display_name, username FROM users WHERE id = $1",
session.user_id as i32
)
.fetch_one(&mut **db)
.await
.map(|row| row.display_name.unwrap_or(row.username))
.unwrap_or_else(|_| "Unknown".to_string());
msg.user_id = session.user_id;
msg.display_name = Some(display_name);
chat.publish(msg.clone().into_inner()).await;
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content) VALUES ($1, $2, $3)",
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
CHANNEL_ID,
msg.user_id as i32,
msg.text
msg.text,
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **db)
.await
@@ -71,15 +83,15 @@ pub async fn post_message(
tokio::spawn(async move {
let response = LlmWorker::new(LMSTUDIO_URI.to_string()).query(&msg).await;
if let Ok(message) = response {
chat.publish(message.clone()).await;
if let Ok(reply) = response {
chat.publish(reply.clone()).await;
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
CHANNEL_ID,
message.user_id as i32,
message.text,
OffsetDateTime::from_unix_timestamp(message.timestamp as i64).unwrap()
reply.user_id as i32,
reply.text,
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **db)
.await
@@ -106,7 +118,7 @@ pub async fn get_messages(mut db: Connection<DbConn>, _session: Session) -> Json
display_name: Some(msg.display_name.unwrap_or(msg.username)),
user_id: msg.id as usize,
text: msg.content,
timestamp: msg.created_at.unwrap().unix_timestamp() as usize,
timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize,
})
.collect(),
)
@@ -118,6 +130,7 @@ pub async fn event_stream(
chat: &rocket::State<Arc<ChatBroadcaster>>,
db: Connection<DbConn>,
ag: Session,
mut shutdown: Shutdown,
) -> EventStream![] {
let mut rx = chat.subscribe();
@@ -128,18 +141,23 @@ pub async fn event_stream(
}
loop {
match rx.recv().await {
Ok(msg) => yield Event::json(&msg),
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => break,
select!{
// exit early on shutdown
_ = &mut shutdown => break,
msg = rx.recv() => match msg {
Ok(msg) => yield Event::json(&msg),
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => break,
},
}
}
}
}
#[get("/")]
#[get("/chat")]
pub async fn chat_page(session: Session) -> Template {
Template::render("chat", context!(user_id: session.user_id))
}