more progress on TOTP/2FA
This commit is contained in:
+74
-32
@@ -5,9 +5,10 @@ use rocket::{
|
||||
Request,
|
||||
fs::NamedFile,
|
||||
http::{CookieJar, Status},
|
||||
outcome::Outcome,
|
||||
outcome::{Outcome, try_outcome},
|
||||
post,
|
||||
request::{self, FromRequest},
|
||||
response::Redirect,
|
||||
serde::json::Json,
|
||||
};
|
||||
use rocket_db_pools::{
|
||||
@@ -38,7 +39,7 @@ pub async fn signup(
|
||||
cred: Json<UserCredentials>,
|
||||
jar: &CookieJar<'_>,
|
||||
mut db: Connection<DbConn>,
|
||||
) -> Result<Json<String>, String> {
|
||||
) -> Result<Redirect, Status> {
|
||||
let result = sqlx::query!(
|
||||
"INSERT INTO users (username, password) VALUES ($1, $2) RETURNING id",
|
||||
cred.username,
|
||||
@@ -46,18 +47,18 @@ pub async fn signup(
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
.map_err(|e| Status::InternalServerError)?;
|
||||
|
||||
let session = Session::new(result.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(e.to_string());
|
||||
return Err(Status::InternalServerError);
|
||||
}
|
||||
|
||||
jar.add_private(("session", session.token));
|
||||
|
||||
println!("Signup successful");
|
||||
Ok(Json("Signup successful".to_string()))
|
||||
return Ok(Redirect::to("/chat"));
|
||||
}
|
||||
|
||||
#[get("/login")]
|
||||
@@ -70,7 +71,7 @@ pub async fn login(
|
||||
mut db: Connection<DbConn>,
|
||||
jar: &CookieJar<'_>,
|
||||
cred: Json<UserCredentials>,
|
||||
) -> Result<Json<String>, String> {
|
||||
) -> Result<Redirect, Status> {
|
||||
if let Ok(row) = sqlx::query!(
|
||||
"SELECT id FROM users WHERE username = $1 AND password = $2",
|
||||
cred.username,
|
||||
@@ -82,41 +83,47 @@ pub async fn login(
|
||||
let session = Session::new(row.id as usize);
|
||||
if let Err(e) = session.commit(&mut db).await {
|
||||
eprintln!("Failed to create session: {}", e);
|
||||
return Err(e.to_string());
|
||||
return Err(Status::InternalServerError);
|
||||
}
|
||||
|
||||
jar.add_private(("session", session.token));
|
||||
return Ok(Json("Signup successful".to_string()));
|
||||
return Ok(Redirect::to("/chat"));
|
||||
}
|
||||
|
||||
// TODO: implement actual login logic, e.g. verify password and generate token
|
||||
Err("login failed".to_string())
|
||||
Err(Status::Unauthorized)
|
||||
}
|
||||
|
||||
#[get("/totp")]
|
||||
pub async fn mfa_page(session: Session) -> Template {
|
||||
pub async fn mfa_page(_session: Session) -> Template {
|
||||
Template::render("2fa", context!())
|
||||
}
|
||||
|
||||
#[get("/api/totp.jpg")]
|
||||
pub async fn get_totp(s: Session) -> Option<QrCodeImage> {
|
||||
#[derive(Serialize)]
|
||||
pub struct QrResponse {
|
||||
qr_code: String,
|
||||
}
|
||||
|
||||
#[get("/totp.jpg")]
|
||||
pub async fn get_totp(totp: TOTPCode) -> Option<Json<QrResponse>> {
|
||||
let totp = TOTP::new(
|
||||
Algorithm::SHA1,
|
||||
6,
|
||||
1,
|
||||
30,
|
||||
Secret::generate_secret().to_bytes().unwrap(),
|
||||
Some("Github".to_string()),
|
||||
format!("{}", s.user_id),
|
||||
totp.secret.as_bytes().into(),
|
||||
Some("chat.zxq5.dev".to_string()),
|
||||
format!("{}", totp.user_id),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let qr = totp.get_qr_base64().unwrap();
|
||||
let data_uri = format!("data:image/png;base64,{}", qr);
|
||||
|
||||
Some(QrCodeImage(qr.into()))
|
||||
Some(Json(QrResponse { qr_code: data_uri }))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Session {
|
||||
pub token: String,
|
||||
pub user_id: usize,
|
||||
@@ -145,6 +152,56 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TOTPCode {
|
||||
user_id: usize,
|
||||
secret: String,
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for TOTPCode {
|
||||
type Error = ();
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||
let user = try_outcome!(request.guard::<Session>().await);
|
||||
let mut pool = match request.guard::<Connection<DbConn>>().await {
|
||||
Outcome::Success(pool) => pool,
|
||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
let (enabled, mut secret) = match sqlx::query!(
|
||||
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
|
||||
user.user_id as i32,
|
||||
)
|
||||
.fetch_one(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(row) => (row.twofa_enabled, row.totp_secret),
|
||||
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
|
||||
};
|
||||
|
||||
if !enabled || secret.is_none() {
|
||||
secret = Some(Secret::generate_secret().to_string());
|
||||
|
||||
match sqlx::query!(
|
||||
"UPDATE users SET totp_secret = $1, twofa_enabled = true WHERE id = $2",
|
||||
secret.as_ref().unwrap(),
|
||||
user.user_id as i32,
|
||||
)
|
||||
.execute(&mut **pool)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(_) => return Outcome::Error((Status::InternalServerError, ())),
|
||||
}
|
||||
}
|
||||
|
||||
Outcome::Success(TOTPCode {
|
||||
user_id: user.user_id,
|
||||
secret: secret.unwrap(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for Session {
|
||||
type Error = ();
|
||||
@@ -178,18 +235,3 @@ impl<'r> FromRequest<'r> for Session {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use rocket::http::ContentType;
|
||||
use rocket::response::{self, Responder, Response};
|
||||
use std::io::Cursor;
|
||||
|
||||
pub struct QrCodeImage(Vec<u8>);
|
||||
|
||||
impl<'r> Responder<'r, 'static> for QrCodeImage {
|
||||
fn respond_to(self, _: &'r rocket::Request<'_>) -> response::Result<'static> {
|
||||
Response::build()
|
||||
.header(ContentType::PNG)
|
||||
.sized_body(self.0.len(), Cursor::new(self.0))
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
|
||||
+2
-2
@@ -60,10 +60,10 @@ impl LlmWorker {
|
||||
.map_err(|_| String::from("Failed to make request to LLM server"))?;
|
||||
|
||||
Ok(ChatMsg {
|
||||
display_name: message.display_name.clone(),
|
||||
display_name: Some(String::from("lmstudio")),
|
||||
user_id: 0,
|
||||
text: llm_resp.choices[0].message.content.clone(),
|
||||
timestamp: chrono::Utc::now().timestamp() as usize,
|
||||
timestamp: chrono::Utc::now().timestamp_millis() as usize,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+10
-7
@@ -34,12 +34,15 @@ async fn users(_ag: Session, mut db: Connection<DbConn>) -> Json<Vec<i32>> {
|
||||
}
|
||||
|
||||
#[get("/users/<id>", rank = 1)]
|
||||
async fn username_for_id(id: usize, _ag: Session, mut db: Connection<DbConn>) -> String {
|
||||
sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map(|row| row.username)
|
||||
.unwrap_or_else(|_| "User not found".to_string())
|
||||
async fn display_name(id: usize, _ag: Session, mut db: Connection<DbConn>) -> String {
|
||||
sqlx::query!(
|
||||
"SELECT display_name, username FROM users WHERE id = $1",
|
||||
id as i32
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map(|row| row.display_name.unwrap_or(row.username))
|
||||
.unwrap_or_else(|_| "User not found".to_string())
|
||||
}
|
||||
|
||||
#[launch]
|
||||
@@ -80,7 +83,7 @@ fn rocket() -> Rocket<Build> {
|
||||
messages::post_message,
|
||||
messages::event_stream,
|
||||
users,
|
||||
username_for_id,
|
||||
display_name,
|
||||
auth::signup,
|
||||
auth::login,
|
||||
auth::get_totp,
|
||||
|
||||
+35
-17
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use rocket::{
|
||||
Shutdown,
|
||||
response::stream::{Event, EventStream},
|
||||
serde::json::Json,
|
||||
time::OffsetDateTime,
|
||||
@@ -9,9 +10,9 @@ use rocket_db_pools::Connection;
|
||||
use rocket_dyn_templates::{Template, context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::prelude::FromRow;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::{select, sync::broadcast};
|
||||
|
||||
use crate::{auth::Session, db::DbConn, llm::LlmWorker};
|
||||
use crate::{auth::Session, db::DbConn, display_name, llm::LlmWorker};
|
||||
|
||||
/// ---------- shared broadcaster ----------
|
||||
pub struct ChatBroadcaster {
|
||||
@@ -54,14 +55,25 @@ pub async fn post_message(
|
||||
|
||||
let chat = chat.inner().clone();
|
||||
|
||||
let display_name = sqlx::query!(
|
||||
"SELECT display_name, username FROM users WHERE id = $1",
|
||||
session.user_id as i32
|
||||
)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map(|row| row.display_name.unwrap_or(row.username))
|
||||
.unwrap_or_else(|_| "Unknown".to_string());
|
||||
|
||||
msg.user_id = session.user_id;
|
||||
msg.display_name = Some(display_name);
|
||||
chat.publish(msg.clone().into_inner()).await;
|
||||
|
||||
sqlx::query!(
|
||||
"INSERT INTO messages (channel_id, user_id, content) VALUES ($1, $2, $3)",
|
||||
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
||||
CHANNEL_ID,
|
||||
msg.user_id as i32,
|
||||
msg.text
|
||||
msg.text,
|
||||
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
|
||||
)
|
||||
.execute(&mut **db)
|
||||
.await
|
||||
@@ -71,15 +83,15 @@ pub async fn post_message(
|
||||
tokio::spawn(async move {
|
||||
let response = LlmWorker::new(LMSTUDIO_URI.to_string()).query(&msg).await;
|
||||
|
||||
if let Ok(message) = response {
|
||||
chat.publish(message.clone()).await;
|
||||
if let Ok(reply) = response {
|
||||
chat.publish(reply.clone()).await;
|
||||
|
||||
sqlx::query!(
|
||||
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
||||
CHANNEL_ID,
|
||||
message.user_id as i32,
|
||||
message.text,
|
||||
OffsetDateTime::from_unix_timestamp(message.timestamp as i64).unwrap()
|
||||
reply.user_id as i32,
|
||||
reply.text,
|
||||
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
|
||||
)
|
||||
.execute(&mut **db)
|
||||
.await
|
||||
@@ -106,7 +118,7 @@ pub async fn get_messages(mut db: Connection<DbConn>, _session: Session) -> Json
|
||||
display_name: Some(msg.display_name.unwrap_or(msg.username)),
|
||||
user_id: msg.id as usize,
|
||||
text: msg.content,
|
||||
timestamp: msg.created_at.unwrap().unix_timestamp() as usize,
|
||||
timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize,
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
@@ -118,6 +130,7 @@ pub async fn event_stream(
|
||||
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
||||
db: Connection<DbConn>,
|
||||
ag: Session,
|
||||
mut shutdown: Shutdown,
|
||||
) -> EventStream![] {
|
||||
let mut rx = chat.subscribe();
|
||||
|
||||
@@ -128,18 +141,23 @@ pub async fn event_stream(
|
||||
}
|
||||
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(msg) => yield Event::json(&msg),
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => {
|
||||
yield Event::comment("RecvError::Lagged");
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
select!{
|
||||
// exit early on shutdown
|
||||
_ = &mut shutdown => break,
|
||||
|
||||
msg = rx.recv() => match msg {
|
||||
Ok(msg) => yield Event::json(&msg),
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => {
|
||||
yield Event::comment("RecvError::Lagged");
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
#[get("/chat")]
|
||||
pub async fn chat_page(session: Session) -> Template {
|
||||
Template::render("chat", context!(user_id: session.user_id))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user