diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 22207d7..14ae531 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [dependencies] argon2 = "0.5.3" chrono = { version = "0.4.42", features = ["serde"] } +dotenv = "0.15.0" futures-util = "0.3.31" image = "0.25.8" rand = "0.9.2" diff --git a/backend/src/llm.rs b/backend/src/llm.rs index 5a5388a..be47d9f 100644 --- a/backend/src/llm.rs +++ b/backend/src/llm.rs @@ -1,7 +1,7 @@ // src/llm.rs use serde::{Deserialize, Serialize}; -use crate::messages::ChatMsg; +use crate::messenger::ChatMsg; #[derive(Serialize)] struct LlmRequest { diff --git a/backend/src/main.rs b/backend/src/main.rs index a1a298c..90f9bbf 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -10,7 +10,8 @@ use rocket::{Build, Rocket}; use rocket_cors::{AllowedOrigins, CorsOptions}; use rocket_db_pools::{Connection, Database}; use rocket_dyn_templates::Template; -use std::sync::Arc; +use std::env; +use std::sync::{Arc, LazyLock}; use crate::auth::Session; use crate::db::{Postgres, Redis}; @@ -20,33 +21,18 @@ pub mod cdn; pub mod db; pub mod handlers; pub mod llm; -pub mod messages; +pub mod messenger; +pub mod user; -#[get("/users", rank = 2)] -async fn users(_ag: Session, mut db: Connection) -> Json> { - sqlx::query!("SELECT id FROM users") - .fetch_all(&mut **db) - .await - .unwrap_or_else(|_| Vec::new()) - .into_iter() - .map(|row| row.id) - .collect::>() - .into() -} - -#[get("/users/", rank = 1)] -async fn display_name( - id: usize, - _ag: Session, - mut pgsql_conn: Connection, - mut redis_conn: Connection, -) -> String { - UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await -} +static LMSTUDIO_URL: LazyLock = + LazyLock::new(|| env::var("LMSTUDIO_URL").expect("Ensure LMSTUDIO_URL is set!")); #[launch] fn rocket() -> Rocket { - let chat = Arc::new(crate::messages::ChatBroadcaster::new(32)); + // make sure the env is loaded + dotenv::dotenv().expect("Failed to load env! aborting launch!"); + + let chat = Arc::new(crate::messenger::ChatBroadcaster::new(32)); let cors = CorsOptions::default() .allowed_origins(AllowedOrigins::all()) @@ -70,7 +56,7 @@ fn rocket() -> Rocket { "/", routes![ favicon, - messages::chat_page, + messenger::chat_page, auth::signup_page, auth::login_page, auth::mfa_page, @@ -81,11 +67,11 @@ fn rocket() -> Rocket { "/api", routes![ cdn::upload_profile_pic, - messages::get_messages, - messages::post_message, - messages::event_stream, - users, - display_name, + messenger::get_messages, + messenger::post_message, + messenger::event_stream, + user::users, + user::display_name, auth::signup, auth::login, auth::get_totp, @@ -107,45 +93,3 @@ fn rocket() -> Rocket { async fn favicon() -> NamedFile { NamedFile::open("static/favicon.ico").await.unwrap() } - -pub struct UserCache {} - -impl UserCache { - pub async fn username( - id: usize, - redis_conn: &mut Connection, - pgsql_conn: &mut Connection, - ) -> String { - if let Ok(val) = cmd("GET") - .arg(&[format!("users:{id}")]) - .query_async(&mut **redis_conn) - .await - { - return val; - } - - if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32) - .fetch_one(&mut ***pgsql_conn) - .await - { - let username = v.username; - Self::insert(id, &username, redis_conn).await; - username - } else { - unimplemented!() - } - } - - pub async fn insert(id: usize, username: &str, conn: &mut Connection) { - cmd("SET") - .arg(&[ - format!("users:{id}"), - username.to_string(), - "EX".to_string(), - "1800".to_string(), - ]) - .query_async(&mut **conn) - .await - .expect("failed to insert key") - } -} diff --git a/backend/src/messages/cache.rs b/backend/src/messenger/cache.rs similarity index 99% rename from backend/src/messages/cache.rs rename to backend/src/messenger/cache.rs index 445f06f..68c2d5e 100644 --- a/backend/src/messages/cache.rs +++ b/backend/src/messenger/cache.rs @@ -3,7 +3,7 @@ use rocket_db_pools::Connection; use crate::{ db::{Postgres, Redis}, - messages::ChatMsg, + messenger::ChatMsg, }; // Helper function to cache message in Redis diff --git a/backend/src/messages/messages.rs b/backend/src/messenger/messages.rs similarity index 68% rename from backend/src/messages/messages.rs rename to backend/src/messenger/messages.rs index 3723c70..bf0fffb 100644 --- a/backend/src/messages/messages.rs +++ b/backend/src/messenger/messages.rs @@ -1,13 +1,11 @@ use std::sync::Arc; -use redis::{AsyncCommands, cmd}; use rocket::{ Shutdown, response::stream::{Event, EventStream}, serde::json::Json, time::OffsetDateTime, }; -use rocket_cors::CorsOptions; use rocket_db_pools::Connection; use rocket_dyn_templates::{Template, context}; use serde::{Deserialize, Serialize}; @@ -22,21 +20,34 @@ use crate::{ /// ---------- shared broadcaster ---------- pub struct ChatBroadcaster { - sender: broadcast::Sender, + buffer_size: usize, + senders: std::sync::Mutex>>, } impl ChatBroadcaster { pub fn new(buffer_size: usize) -> Self { - let (sender, _rx) = broadcast::channel::(buffer_size); - Self { sender } + Self { + buffer_size, + senders: std::sync::Mutex::new(std::collections::HashMap::new()), + } } - pub async fn publish(&self, msg: ChatMsg) { - let _ = self.sender.send(msg); + /// Publish a message to the specified channel. + pub async fn publish(&self, channel_id: i32, msg: ChatMsg) { + let mut map = self.senders.lock().unwrap(); + let sender = map + .entry(channel_id) + .or_insert_with(|| broadcast::channel::(self.buffer_size).0); + let _ = sender.send(msg); } - pub fn subscribe(&self) -> broadcast::Receiver { - self.sender.subscribe() + /// Subscribe to the specified channel. + pub fn subscribe(&self, channel_id: i32) -> broadcast::Receiver { + let mut map = self.senders.lock().unwrap(); + let sender = map + .entry(channel_id) + .or_insert_with(|| broadcast::channel::(self.buffer_size).0); + sender.subscribe() } } @@ -49,18 +60,15 @@ pub struct ChatMsg { pub timestamp: usize, } -#[post("/chat", format = "json", data = "")] +#[post("/chat/", format = "json", data = "")] pub async fn post_message( mut msg: Json, chat: &rocket::State>, mut postgres: Connection, - mut cache: Connection, + mut cache: Option>, session: Session, + channel_id: i32, ) -> Result<(), String> { - const CHANNEL_ID: i32 = 1; - let channel_id = CHANNEL_ID; - const LMSTUDIO_URI: &'static str = "http://127.0.0.1:1234/v1/chat/completions"; - let chat = chat.inner().clone(); let display_name = sqlx::query!( @@ -74,11 +82,11 @@ pub async fn post_message( msg.user_id = session.user_id; msg.display_name = Some(display_name); - chat.publish(msg.clone().into_inner()).await; + chat.publish(channel_id, msg.clone().into_inner()).await; sqlx::query!( "INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)", - CHANNEL_ID, + channel_id, msg.user_id as i32, msg.text, OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap() @@ -87,22 +95,30 @@ pub async fn post_message( .await .map_err(|_| "Failed".to_string())?; - super::cache::insert(&mut cache, channel_id, &msg) - .await - .map_err(|_| "Redis cache failed".to_string())?; + if let Some(ref mut cache) = cache { + messenger::cache::insert(cache, channel_id, &msg) + .await + .map_err(|_| "Redis cache failed".to_string())?; + } + // get response tokio::spawn(async move { - let response = LlmWorker::new(LMSTUDIO_URI.to_string()).query(&msg).await; + let response = LlmWorker::new(crate::LMSTUDIO_URL.to_string()) + .query(&msg) + .await; if let Ok(reply) = response { - chat.publish(reply.clone()).await; - super::cache::insert(&mut cache, CHANNEL_ID, &reply) - .await - .ok(); + chat.publish(channel_id, reply.clone()).await; + + if let Some(ref mut cache) = cache { + messenger::cache::insert(cache, channel_id, &reply) + .await + .ok(); + } sqlx::query!( "INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)", - CHANNEL_ID, + channel_id, reply.user_id as i32, reply.text, OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap() @@ -126,17 +142,17 @@ pub async fn get_messages( const CHANNEL_ID: i32 = 1; let channel_id = CHANNEL_ID; - if let Ok(messages) = super::cache::get(&mut redis, channel_id).await + if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await && !messages.is_empty() { return Json(messages); }; - if let Err(x) = super::cache::initialise(&mut redis, &mut db, channel_id).await { + if let Err(x) = messenger::cache::initialise(&mut redis, &mut db, channel_id).await { eprintln!("WARN: {x:?}"); } - if let Ok(messages) = super::cache::get(&mut redis, channel_id).await + if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await && !messages.is_empty() { return Json(messages); @@ -165,15 +181,16 @@ pub async fn get_messages( Json(res) } -#[get("/events")] +#[get("/events/")] pub async fn event_stream( chat: &rocket::State>, postgres: Connection, cache: Connection, ag: Session, mut shutdown: Shutdown, + channel_id: i32, ) -> EventStream![] { - let mut rx = chat.subscribe(); + let mut rx = chat.subscribe(channel_id); EventStream! { // Initialize the stream with the last 100 messages @@ -202,8 +219,3 @@ pub async fn event_stream( pub async fn chat_page(session: Session) -> Template { Template::render("chat", context!(user_id: session.user_id)) } - -#[get("/chatpreview")] -pub async fn chat_page_preview(session: Session) -> Template { - Template::render("chatpreview", context!(user_id: session.user_id)) -} diff --git a/backend/src/messages/mod.rs b/backend/src/messenger/mod.rs similarity index 100% rename from backend/src/messages/mod.rs rename to backend/src/messenger/mod.rs diff --git a/backend/src/user.rs b/backend/src/user.rs new file mode 100644 index 0000000..bdb6d91 --- /dev/null +++ b/backend/src/user.rs @@ -0,0 +1,72 @@ +use redis::cmd; +use rocket::serde::json::Json; +use rocket_db_pools::Connection; + +use crate::{ + auth::Session, + db::{Postgres, Redis}, +}; + +#[get("/users", rank = 2)] +pub async fn users(_ag: Session, mut db: Connection) -> Json> { + sqlx::query!("SELECT id FROM users") + .fetch_all(&mut **db) + .await + .unwrap_or_else(|_| Vec::new()) + .into_iter() + .map(|row| row.id) + .collect::>() + .into() +} + +#[get("/users/", rank = 1)] +pub async fn display_name( + id: usize, + _ag: Session, + mut pgsql_conn: Connection, + mut redis_conn: Connection, +) -> String { + UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await +} + +pub struct UserCache {} + +impl UserCache { + pub async fn username( + id: usize, + redis_conn: &mut Connection, + pgsql_conn: &mut Connection, + ) -> String { + if let Ok(val) = cmd("GET") + .arg(&[format!("users:{id}")]) + .query_async(&mut **redis_conn) + .await + { + return val; + } + + if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32) + .fetch_one(&mut ***pgsql_conn) + .await + { + let username = v.username; + Self::insert(id, &username, redis_conn).await; + username + } else { + unimplemented!() + } + } + + pub async fn insert(id: usize, username: &str, conn: &mut Connection) { + cmd("SET") + .arg(&[ + format!("users:{id}"), + username.to_string(), + "EX".to_string(), + "1800".to_string(), + ]) + .query_async(&mut **conn) + .await + .expect("failed to insert key") + } +}