diff --git a/backend/src/auth.rs b/backend/src/auth.rs index 0d91ee6..11f4e3a 100644 --- a/backend/src/auth.rs +++ b/backend/src/auth.rs @@ -16,6 +16,7 @@ use rocket_db_pools::{ use rocket_dyn_templates::{Template, context}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; +use sqlx::postgres::PgQueryResult; use crate::db::DbConn; @@ -25,6 +26,11 @@ pub struct UserCredentials { pub password: String, } +#[get("/signup")] +pub async fn signup_page() -> Template { + Template::render("signup", context!()) +} + #[post("/signup", data = "")] pub async fn signup( cred: Json, @@ -40,16 +46,8 @@ pub async fn signup( .await .map_err(|e| e.to_string())?; - let session = SessionToken::new(result.id as usize); - let result = sqlx::query!( - "INSERT INTO sessions (user_id, token) VALUES ($1, $2)", - result.id, - session.token, - ) - .execute(&mut **db) - .await; - - if let Err(e) = result { + let session = Session::new(result.id as usize); + if let Err(e) = session.commit(&mut db).await { eprintln!("Failed to create session: {}", e); return Err(e.to_string()); } @@ -60,45 +58,70 @@ pub async fn signup( Ok(Json("Signup successful".to_string())) } -#[get("/signup")] -pub async fn signup_page() -> Template { - Template::render("signup", context!()) +#[get("/login")] +pub async fn login_page() -> Template { + Template::render("login", context!()) } #[post("/login", data = "")] pub async fn login( - conn: Connection, + mut db: Connection, + jar: &CookieJar<'_>, cred: Json, ) -> Result, String> { + if let Ok(row) = sqlx::query!( + "SELECT id FROM users WHERE username = $1 AND password = $2", + cred.username, + cred.password, + ) + .fetch_one(&mut **db) + .await + { + let session = Session::new(row.id as usize); + if let Err(e) = session.commit(&mut db).await { + eprintln!("Failed to create session: {}", e); + return Err(e.to_string()); + } + + jar.add_private(("session", session.token)); + return Ok(Json("Signup successful".to_string())); + } + // TODO: implement actual login logic, e.g. verify password and generate token - Ok(Json("Login successful".to_string())) + Err("login failed".to_string()) } -pub struct SessionToken { +#[derive(Debug)] +pub struct Session { pub token: String, pub user_id: usize, } -impl SessionToken { +impl Session { pub fn new(user_id: usize) -> Self { let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); let random: u32 = rand::rng().random(); let token = format!("{}-{}", current_time.as_secs(), random); let hashed = format!("{:x}", Sha256::digest(token.as_bytes())); - SessionToken { + Self { token: hashed, user_id, } } + + pub async fn commit(&self, db: &mut Connection) -> Result { + sqlx::query!( + "INSERT INTO sessions (user_id, token) VALUES ($1, $2)", + self.user_id as i32, + self.token, + ) + .execute(&mut ***db) + .await + } } -type UserID = usize; - -#[derive(Debug)] -pub struct AuthGuard(pub UserID); - #[rocket::async_trait] -impl<'r> FromRequest<'r> for AuthGuard { +impl<'r> FromRequest<'r> for Session { type Error = (); async fn from_request(request: &'r Request<'_>) -> request::Outcome { @@ -110,16 +133,18 @@ impl<'r> FromRequest<'r> for AuthGuard { let value = c.value(); let result = sqlx::query!( - "SELECT user_id FROM sessions WHERE token = $1 AND expires_at > NOW()", + "SELECT user_id, token FROM sessions WHERE token = $1 AND expires_at > NOW()", value ) .fetch_optional(&mut **pool) .await .expect("query failed!"); - if let Some(token) = result { - let user_id = token.user_id; - Outcome::Success(AuthGuard(user_id as usize)) + if let Some(session) = result { + Outcome::Success(Self { + user_id: session.user_id as usize, + token: session.token, + }) } else { Outcome::Error((Status::Unauthorized, ())) } diff --git a/backend/src/llm.rs b/backend/src/llm.rs index 52d04a2..7fda239 100644 --- a/backend/src/llm.rs +++ b/backend/src/llm.rs @@ -42,7 +42,7 @@ impl LlmWorker { .json(&payload) .send() .await - .unwrap(); + .map_err(|_| String::from("Failed to make request to LLM server"))?; // The API returns a JSON with `choices[].message.content` #[derive(Deserialize)] @@ -54,7 +54,10 @@ impl LlmWorker { message: Message, } - let llm_resp: LlmResponse = resp.json().await.unwrap(); + let llm_resp: LlmResponse = resp + .json() + .await + .map_err(|_| String::from("Failed to make request to LLM server"))?; Ok(ChatMsg { display_name: message.display_name.clone(), diff --git a/backend/src/main.rs b/backend/src/main.rs index 4b7faf8..0643a54 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -11,7 +11,7 @@ use rocket_db_pools::{Connection, Database}; use rocket_dyn_templates::Template; use std::sync::Arc; -use crate::auth::AuthGuard; +use crate::auth::Session; use crate::db::DbConn; use crate::messages::ChatBroadcaster; @@ -22,7 +22,7 @@ pub mod llm; pub mod messages; #[get("/users", rank = 2)] -async fn users(_ag: AuthGuard, mut db: Connection) -> Json> { +async fn users(_ag: Session, mut db: Connection) -> Json> { sqlx::query!("SELECT id FROM users") .fetch_all(&mut **db) .await @@ -34,7 +34,7 @@ async fn users(_ag: AuthGuard, mut db: Connection) -> Json> { } #[get("/users/", rank = 1)] -async fn username_for_id(id: usize, _ag: AuthGuard, mut db: Connection) -> String { +async fn username_for_id(id: usize, _ag: Session, mut db: Connection) -> String { sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32) .fetch_one(&mut **db) .await @@ -74,6 +74,7 @@ fn rocket() -> Rocket { messages::event_stream, auth::signup, auth::signup_page, + auth::login_page, auth::login ], ) diff --git a/backend/src/messages.rs b/backend/src/messages.rs index 8796553..8062a68 100644 --- a/backend/src/messages.rs +++ b/backend/src/messages.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; use sqlx::prelude::FromRow; use tokio::sync::broadcast; -use crate::{auth::AuthGuard, db::DbConn, llm::LlmWorker}; +use crate::{auth::Session, db::DbConn, llm::LlmWorker}; /// ---------- shared broadcaster ---------- pub struct ChatBroadcaster { @@ -47,14 +47,14 @@ pub async fn post_message( mut msg: Json, chat: &rocket::State>, mut db: Connection, - ag: AuthGuard, + session: Session, ) -> Result<(), String> { const CHANNEL_ID: i32 = 1; const LMSTUDIO_URI: &'static str = "http://127.0.0.1:1234/v1/chat/completions"; let chat = chat.inner().clone(); - msg.user_id = ag.0; + msg.user_id = session.user_id; chat.publish(msg.clone().into_inner()).await; sqlx::query!( @@ -92,7 +92,7 @@ pub async fn post_message( } #[get("/messages")] -pub async fn get_messages(mut db: Connection, _ag: AuthGuard) -> Json> { +pub async fn get_messages(mut db: Connection, _session: Session) -> Json> { Json( sqlx::query!( "SELECT u.username, u.display_name, u.id, m.content, m.created_at FROM messages m JOIN users u ON m.user_id = u.id ORDER BY m.created_at DESC LIMIT 100" @@ -117,7 +117,7 @@ pub async fn get_messages(mut db: Connection, _ag: AuthGuard) -> Json>, db: Connection, - ag: AuthGuard, + ag: Session, ) -> EventStream![] { let mut rx = chat.subscribe(); @@ -140,6 +140,6 @@ pub async fn event_stream( } #[get("/")] -pub async fn chat_page(ag: AuthGuard) -> Template { - Template::render("chat", context!(user_id: ag.0)) +pub async fn chat_page(session: Session) -> Template { + Template::render("chat", context!(user_id: session.user_id)) } diff --git a/backend/static/css/index.css b/backend/static/css/index.css index e9e4efa..4a7ed54 100644 --- a/backend/static/css/index.css +++ b/backend/static/css/index.css @@ -382,7 +382,6 @@ body { background-size: cover; border: 2px solid #252525; flex-shrink: 0; - background-image: url("static/profile_pics/default.jpg"); } .user-avatar.blue { diff --git a/backend/templates/login.html.tera b/backend/templates/login.html.tera new file mode 100644 index 0000000..a4a1c0e --- /dev/null +++ b/backend/templates/login.html.tera @@ -0,0 +1,159 @@ + + + + + + Discord Clone - Sign Up + + + + + + + +