From edc7567d15b8c34ecee663e1760583613cbf014c Mon Sep 17 00:00:00 2001 From: zxq5 Date: Thu, 9 Oct 2025 01:12:08 +0100 Subject: [PATCH] progress --- backend/src/auth.rs | 6 +-- backend/src/db.rs | 5 +++ backend/src/llm.rs | 3 +- backend/src/main.rs | 16 ++++---- backend/src/messages.rs | 37 +++++++++++------- backend/templates/chat.html.tera | 66 ++++++++++++++++---------------- 6 files changed, 73 insertions(+), 60 deletions(-) create mode 100644 backend/src/db.rs diff --git a/backend/src/auth.rs b/backend/src/auth.rs index 890f77d..0d91ee6 100644 --- a/backend/src/auth.rs +++ b/backend/src/auth.rs @@ -10,16 +10,14 @@ use rocket::{ serde::json::Json, }; use rocket_db_pools::{ - Connection, Database, + Connection, sqlx::{self}, }; use rocket_dyn_templates::{Template, context}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -#[derive(Database)] -#[database("postgres_db")] -pub struct DbConn(sqlx::PgPool); +use crate::db::DbConn; #[derive(Serialize, Deserialize)] pub struct UserCredentials { diff --git a/backend/src/db.rs b/backend/src/db.rs new file mode 100644 index 0000000..2fbb5da --- /dev/null +++ b/backend/src/db.rs @@ -0,0 +1,5 @@ +use rocket_db_pools::Database; + +#[derive(Database)] +#[database("postgres_db")] +pub struct DbConn(sqlx::PgPool); diff --git a/backend/src/llm.rs b/backend/src/llm.rs index 792f2c4..52d04a2 100644 --- a/backend/src/llm.rs +++ b/backend/src/llm.rs @@ -57,9 +57,10 @@ impl LlmWorker { let llm_resp: LlmResponse = resp.json().await.unwrap(); Ok(ChatMsg { + display_name: message.display_name.clone(), user_id: 0, text: llm_resp.choices[0].message.content.clone(), - timestamp: chrono::Local::now().timestamp() as usize, + timestamp: chrono::Utc::now().timestamp() as usize, }) } } diff --git a/backend/src/main.rs b/backend/src/main.rs index a24bb90..4b7faf8 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -2,25 +2,22 @@ #[macro_use] extern crate rocket; -use rocket::fairing::Fairing; use rocket::fs::FileServer; use rocket::http::Method; -use rocket::response::stream::{Event, EventStream}; use rocket::serde::json::Json; use rocket::{Build, Rocket}; use rocket_cors::{AllowedOrigins, CorsOptions}; use rocket_db_pools::{Connection, Database}; -use rocket_dyn_templates::{Template, context}; -use serde::{Deserialize, Serialize}; +use rocket_dyn_templates::Template; use std::sync::Arc; -use tokio::sync::broadcast; -use crate::auth::{AuthGuard, DbConn}; -use crate::llm::LlmWorker; +use crate::auth::AuthGuard; +use crate::db::DbConn; use crate::messages::ChatBroadcaster; pub mod auth; pub mod cdn; +pub mod db; pub mod llm; pub mod messages; @@ -29,8 +26,10 @@ async fn users(_ag: AuthGuard, mut db: Connection) -> Json> { sqlx::query!("SELECT id FROM users") .fetch_all(&mut **db) .await - .map(|rows| rows.into_iter().map(|row| row.id).collect()) .unwrap_or_else(|_| Vec::new()) + .into_iter() + .map(|row| row.id) + .collect::>() .into() } @@ -43,7 +42,6 @@ async fn username_for_id(id: usize, _ag: AuthGuard, mut db: Connection) .unwrap_or_else(|_| "User not found".to_string()) } -/// ---------- launch ---------- #[launch] fn rocket() -> Rocket { let chat = Arc::new(ChatBroadcaster::new(32)); diff --git a/backend/src/messages.rs b/backend/src/messages.rs index dd54d09..8796553 100644 --- a/backend/src/messages.rs +++ b/backend/src/messages.rs @@ -3,16 +3,15 @@ use std::sync::Arc; use rocket::{ response::stream::{Event, EventStream}, serde::json::Json, + time::OffsetDateTime, }; use rocket_db_pools::Connection; use rocket_dyn_templates::{Template, context}; use serde::{Deserialize, Serialize}; +use sqlx::prelude::FromRow; use tokio::sync::broadcast; -use crate::{ - auth::{AuthGuard, DbConn}, - llm::LlmWorker, -}; +use crate::{auth::AuthGuard, db::DbConn, llm::LlmWorker}; /// ---------- shared broadcaster ---------- pub struct ChatBroadcaster { @@ -35,8 +34,9 @@ impl ChatBroadcaster { } /// ---------- Rocket routes ---------- -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, FromRow)] pub struct ChatMsg { + pub display_name: Option, pub user_id: usize, pub text: String, pub timestamp: usize, @@ -75,10 +75,11 @@ pub async fn post_message( chat.publish(message.clone()).await; sqlx::query!( - "INSERT INTO messages (channel_id, user_id, content) VALUES ($1, $2, $3)", + "INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)", CHANNEL_ID, message.user_id as i32, - message.text + message.text, + OffsetDateTime::from_unix_timestamp(message.timestamp as i64).unwrap() ) .execute(&mut **db) .await @@ -94,35 +95,43 @@ pub async fn post_message( pub async fn get_messages(mut db: Connection, _ag: AuthGuard) -> Json> { Json( sqlx::query!( - "SELECT user_id, content, created_at FROM messages ORDER BY created_at DESC LIMIT 100" + "SELECT u.username, u.display_name, u.id, m.content, m.created_at FROM messages m JOIN users u ON m.user_id = u.id ORDER BY m.created_at DESC LIMIT 100" ) .fetch_all(&mut **db) .await .unwrap_or_else(|_| Vec::new()) .into_iter() .rev() - .map(|row| ChatMsg { - user_id: row.user_id as usize, - text: row.content, - timestamp: row.created_at.unwrap().unix_timestamp() as usize, + .map(|msg| ChatMsg { + display_name: Some(msg.display_name.unwrap_or(msg.username)), + user_id: msg.id as usize, + text: msg.content, + timestamp: msg.created_at.unwrap().unix_timestamp() as usize, }) .collect(), ) + // Json(vec![]) } #[get("/events")] pub async fn event_stream( chat: &rocket::State>, - _ag: AuthGuard, + db: Connection, + ag: AuthGuard, ) -> EventStream![] { let mut rx = chat.subscribe(); EventStream! { + // Initialize the stream with the last 100 messages + for msg in get_messages(db, ag).await.0 { + yield Event::json(&msg); + } + loop { match rx.recv().await { Ok(msg) => yield Event::json(&msg), Err(broadcast::error::RecvError::Lagged(_)) => { - yield Event::comment("lagged"); + yield Event::comment("RecvError::Lagged"); } Err(broadcast::error::RecvError::Closed) => break, } diff --git a/backend/templates/chat.html.tera b/backend/templates/chat.html.tera index 27aa422..dca6654 100644 --- a/backend/templates/chat.html.tera +++ b/backend/templates/chat.html.tera @@ -137,28 +137,21 @@ const messagesContainer = document.querySelector( ".messages-container", ); - const messageSource = new EventSource("http://localhost:8000/events"); function insertMessage(message) { - const date = new Date(message.timestamp).toLocaleTimeString("en-US", { + const date = new Date(message.timestamp).toLocaleTimeString("en-GB", { hour: "numeric", minute: "2-digit", hour12: true, }); - console.log(users, message); - - const uid = message.user_id; - const uname = users[`${uid}`]; - - console.log(users, uid, uname); const messageEl = document.createElement("div"); messageEl.className = "message"; messageEl.innerHTML = ` - +
- ${uname} + ${message.display_name} ${date}
${message.text}
@@ -169,11 +162,9 @@ messagesContainer.scrollHeight; } - messageSource.onmessage = (event) => insertMessage(JSON.parse(event.data)); - function getCurrentTime() { const now = new Date(); - return now.toLocaleTimeString("en-US", { + return now.toLocaleTimeString("en-GB", { hour: "numeric", minute: "2-digit", hour12: true, @@ -186,7 +177,7 @@ fetch("http://localhost:8000/chat", { method: "POST", body: JSON.stringify({ - userid: user_id, + user_id: user_id, text: message, timestamp: new Date().getTime(), }), @@ -205,26 +196,37 @@ } }); - // get previous messages - fetch("http://localhost:8000/messages") - .then(response => response.json()) - .then(messages => { - messages.forEach(message => { - insertMessage(message); - }); - }); + async function loadData() { + try { + const userIds = await fetch("http://localhost:8000/users/") + .then(r => r.json()); - fetch("http://localhost:8000/users/") - .then(response => response.json()) - .then(items => { - items.forEach(user => { - fetch(`http://localhost:8000/users/${user}`) - .then(response => response.text()) - .then(username => { - users[user] = username; - }); + const userPromises = userIds.map(userId => + fetch(`http://localhost:8000/users/${userId}`) + .then(r => r.text()) + .then(username => ({ userId, username })) + ); + + const userData = await Promise.all(userPromises); + + userData.forEach(({ userId, username }) => { + users[userId] = username; }); - }); + + console.log('Users loaded:', users); + + const messageSource = new EventSource("http://localhost:8000/events"); + messageSource.onmessage = (event) => insertMessage(JSON.parse(event.data)); + messageSource.onerror = (error) => { + console.error('EventSource error:', error); + }; + + } catch (error) { + console.error('Error loading data:', error); + } + } + + loadData();