diff --git a/backend/Cargo.toml b/backend/Cargo.toml index c83e5f8..1f1fd71 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -8,10 +8,11 @@ argon2 = "0.5.3" chrono = { version = "0.4.42", features = ["serde"] } futures-util = "0.3.31" rand = "0.9.2" +redis = { version = "0.25.4", features = ["tokio-comp"] } reqwest = { version = "0.12.23", features = ["json"] } rocket = { version = "0.5.1", features = ["json", "secrets"] } rocket_cors = "0.6.0" -rocket_db_pools = { version = "0.2.0", features = ["sqlx_macros", "sqlx_postgres"] } +rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] } rocket_dyn_templates = { version = "0.2.0", features = ["tera"] } serde = { version = "1.0.228", features = ["derive"] } sha2 = "0.10.9" diff --git a/backend/Rocket.toml b/backend/Rocket.toml index 66a779b..edba168 100644 --- a/backend/Rocket.toml +++ b/backend/Rocket.toml @@ -6,6 +6,9 @@ port = 8000 [default.databases.postgres_db] url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp" +[default.databases.redis_cache] +url = "redis://chatapp_redis:6379" + [default] # run inside a docker container or pod address = "0.0.0.0" port = 8000 diff --git a/backend/src/auth/account.rs b/backend/src/auth/account.rs index 96cfc1c..212bbbc 100644 --- a/backend/src/auth/account.rs +++ b/backend/src/auth/account.rs @@ -9,7 +9,7 @@ use rocket_dyn_templates::{Template, context}; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{auth::session::Session, db::DbConn}; +use crate::{auth::session::Session, db::Postgres}; #[derive(Serialize, Deserialize)] pub struct SignupCredentials { @@ -34,7 +34,7 @@ pub async fn signup_page() -> Template { pub async fn signup( cred: Json, jar: &CookieJar<'_>, - mut db: Connection, + mut db: Connection, ) -> Result> { println!("phase 1 {}", cred.access_token); let token_id = AccessToken::validate(&cred.access_token, &mut db).await?; @@ -73,7 +73,7 @@ pub async fn login_page() -> Template { #[post("/login", data = "")] pub async fn login( - mut db: Connection, + mut db: Connection, jar: &CookieJar<'_>, cred: Json, ) -> Result { @@ -115,7 +115,7 @@ pub async fn invite_page(_s: Session) -> Template { #[post("/invite", data = "
")] pub async fn generate_invite( session: Session, - mut db: Connection, + mut db: Connection, form: Json, ) -> Result { if form.start_date > form.expiry_date { @@ -148,7 +148,7 @@ pub struct AccessToken { impl AccessToken { pub async fn validate( token: &str, - db: &mut Connection, + db: &mut Connection, ) -> Result> { match sqlx::query!( "SELECT id FROM access_codes @@ -169,7 +169,7 @@ impl AccessToken { } } - pub async fn use_token(&self, db: &mut Connection) -> Result<(), BadRequest> { + pub async fn use_token(&self, db: &mut Connection) -> Result<(), BadRequest> { sqlx::query!( "UPDATE access_codes SET uses = uses + 1 WHERE id = $1", self.id diff --git a/backend/src/auth/session.rs b/backend/src/auth/session.rs index 95d5601..1525cfd 100644 --- a/backend/src/auth/session.rs +++ b/backend/src/auth/session.rs @@ -10,7 +10,7 @@ use rocket_db_pools::Connection; use sha2::{Digest, Sha256}; use sqlx::postgres::PgQueryResult; -use crate::db::DbConn; +use crate::db::Postgres; #[derive(Debug, Clone)] pub struct Session { @@ -30,7 +30,10 @@ impl Session { } } - pub async fn commit(&self, db: &mut Connection) -> Result { + pub async fn commit( + &self, + db: &mut Connection, + ) -> Result { sqlx::query!( "INSERT INTO sessions (user_id, token) VALUES ($1, $2)", self.user_id as i32, @@ -47,7 +50,7 @@ impl<'r> FromRequest<'r> for Session { async fn from_request(request: &'r Request<'_>) -> request::Outcome { if let Some(c) = request.cookies().get_private("session") { - let mut pool = match request.guard::>().await { + let mut pool = match request.guard::>().await { Outcome::Success(pool) => pool, _ => return Outcome::Error((Status::Unauthorized, ())), }; diff --git a/backend/src/auth/two_factor.rs b/backend/src/auth/two_factor.rs index 80a050a..038cc6c 100644 --- a/backend/src/auth/two_factor.rs +++ b/backend/src/auth/two_factor.rs @@ -3,6 +3,7 @@ use rocket::{ http::Status, outcome::{Outcome, try_outcome}, request::{self, FromRequest}, + response::status::{self, BadRequest}, serde::json::Json, }; use rocket_db_pools::Connection; @@ -10,7 +11,7 @@ use rocket_dyn_templates::{Template, context}; use serde::{Deserialize, Serialize}; use totp_rs::{Algorithm, Secret, TOTP}; -use crate::{auth::session::Session, db::DbConn}; +use crate::{auth::session::Session, db::Postgres}; // Utility methods @@ -36,36 +37,42 @@ pub async fn mfa_page(_session: Session) -> Template { // api -#[post("/totp", data = "")] +#[post("/totp", data = "")] pub async fn confirm_totp( mfa: TOTPSecret, - totp: Json, - mut db: Connection, -) -> Status { - if totp.code.len() == 6 - && let Ok(code) = totp.code.parse::() - { - let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes()).unwrap(); - - println!("input: {} {}", code, totp.generate_current().unwrap()); - - if totp.check_current(&format!("{}", mfa.user_id)).unwrap() { - if sqlx::query!( - "UPDATE users SET twofa_enabled = true WHERE id = $1", - mfa.user_id as i32 - ) - .execute(&mut **db) - .await - .is_err() - { - return Status::InternalServerError; - }; - } + form: Json, + mut db: Connection, +) -> Result<(), status::Custom<&'static str>> { + if form.code.len() != 6 && form.code.parse::().is_err() { + return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code")); } - println!("ok!"); + println!("valid"); - return Status::Ok; + let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes()).unwrap(); + if !totp.check_current(&format!("{}", form.code)).unwrap() { + return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code")); + } + + println!("correct"); + + if sqlx::query!( + "UPDATE users SET twofa_enabled = true WHERE id = $1", + mfa.user_id as i32 + ) + .execute(&mut **db) + .await + .is_err() + { + return Err(status::Custom( + Status::InternalServerError, + "unable to enable 2fa", + )); + }; + + println!("enabled"); + + return Ok(()); } #[get("/totp.jpg")] @@ -101,7 +108,7 @@ impl<'r> FromRequest<'r> for TOTPSecret { async fn from_request(request: &'r Request<'_>) -> request::Outcome { let user = try_outcome!(request.guard::().await); - let mut pool = match request.guard::>().await { + let mut pool = match request.guard::>().await { Outcome::Success(pool) => pool, _ => return Outcome::Error((Status::Unauthorized, ())), }; @@ -141,7 +148,7 @@ impl<'r> FromRequest<'r> for TOTPSecret { } impl TOTPSecret { - pub async fn enable(&self, db: &mut Connection) -> Result<(), ()> { + pub async fn enable(&self, db: &mut Connection) -> Result<(), ()> { match sqlx::query!( "UPDATE users SET twofa_enabled = true WHERE id = $1", self.user_id as i32, diff --git a/backend/src/db.rs b/backend/src/db.rs index 2fbb5da..c863e21 100644 --- a/backend/src/db.rs +++ b/backend/src/db.rs @@ -1,5 +1,9 @@ -use rocket_db_pools::Database; +use rocket_db_pools::{Database, deadpool_redis}; #[derive(Database)] #[database("postgres_db")] -pub struct DbConn(sqlx::PgPool); +pub struct Postgres(sqlx::PgPool); + +#[derive(Database)] +#[database("redis_cache")] +pub struct Redis(deadpool_redis::Pool); diff --git a/backend/src/main.rs b/backend/src/main.rs index 49ca698..733757a 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -12,7 +12,7 @@ use rocket_dyn_templates::Template; use std::sync::Arc; use crate::auth::Session; -use crate::db::DbConn; +use crate::db::{Postgres, Redis}; use crate::messages::ChatBroadcaster; pub mod auth; @@ -23,7 +23,7 @@ pub mod llm; pub mod messages; #[get("/users", rank = 2)] -async fn users(_ag: Session, mut db: Connection) -> Json> { +async fn users(_ag: Session, mut db: Connection) -> Json> { sqlx::query!("SELECT id FROM users") .fetch_all(&mut **db) .await @@ -35,7 +35,7 @@ async fn users(_ag: Session, mut db: Connection) -> Json> { } #[get("/users/", rank = 1)] -async fn display_name(id: usize, _ag: Session, mut db: Connection) -> String { +async fn display_name(id: usize, _ag: Session, mut db: Connection) -> String { sqlx::query!( "SELECT display_name, username FROM users WHERE id = $1", id as i32 @@ -63,7 +63,8 @@ fn rocket() -> Rocket { rocket::build() .manage(chat) .attach(cors.to_cors().unwrap()) - .attach(DbConn::init()) + .attach(Postgres::init()) + .attach(Redis::init()) .attach(Template::fairing()) .mount("/static", FileServer::from("static")) .mount("/cdn", cdn::routes()) diff --git a/backend/src/messages.rs b/backend/src/messages.rs index 840add5..0a35fd8 100644 --- a/backend/src/messages.rs +++ b/backend/src/messages.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use redis::cmd; use rocket::{ Shutdown, response::stream::{Event, EventStream}, @@ -12,7 +13,11 @@ use serde::{Deserialize, Serialize}; use sqlx::prelude::FromRow; use tokio::{select, sync::broadcast}; -use crate::{auth::Session, db::DbConn, llm::LlmWorker}; +use crate::{ + auth::Session, + db::{Postgres, Redis}, + llm::LlmWorker, +}; /// ---------- shared broadcaster ---------- pub struct ChatBroadcaster { @@ -47,7 +52,7 @@ pub struct ChatMsg { pub async fn post_message( mut msg: Json, chat: &rocket::State>, - mut db: Connection, + mut db: Connection, session: Session, ) -> Result<(), String> { const CHANNEL_ID: i32 = 1; @@ -104,7 +109,7 @@ pub async fn post_message( } #[get("/messages")] -pub async fn get_messages(mut db: Connection, _session: Session) -> Json> { +pub async fn get_messages(mut db: Connection, _session: Session) -> Json> { Json( sqlx::query!( "SELECT u.username, u.display_name, u.id, m.content, m.created_at FROM messages m JOIN users u ON m.user_id = u.id ORDER BY m.created_at DESC LIMIT 100" @@ -128,7 +133,7 @@ pub async fn get_messages(mut db: Connection, _session: Session) -> Json #[get("/events")] pub async fn event_stream( chat: &rocket::State>, - db: Connection, + db: Connection, ag: Session, mut shutdown: Shutdown, ) -> EventStream![] { @@ -166,3 +171,25 @@ pub async fn chat_page(session: Session) -> Template { pub async fn chat_page_preview(session: Session) -> Template { Template::render("chatpreview", context!(user_id: session.user_id)) } + +pub struct UserCache {} + +impl UserCache { + pub async fn username(&mut self, id: usize, redis_conn: &mut Connection) -> String { + if let Ok(val) = cmd("GET") + .arg(&[format!("users:{id}")]) + .query_async(&mut **redis_conn) + .await + { + return val; + } + } + + pub async fn insert(id: usize, username: &str, conn: &mut Connection) { + cmd("SET") + .arg(&[format!("users:{id}"), username.to_owned()]) + .query_async(&mut **conn) + .await + .expect("failed to insert key") + } +} diff --git a/docker-compose.yml b/docker-compose.yml index a80cb18..e8a85bd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,6 @@ services: backend: + container_name: chatapp_backend image: git.zxq5.dev/zxq5/chatapp-backend:latest ports: - "8080:8000" @@ -9,6 +10,7 @@ services: - ROCKET_SECRET_KEY=${ROCKET_SECRET_KEY} - DATABASE_URL=${DATABASE_URL} redis: + container_name: chatapp_redis image: docker.io/library/redis:alpine ports: - "6379:6379"