added login

This commit is contained in:
2025-10-09 10:05:59 +01:00
parent edc7567d15
commit 4f72cc6abe
6 changed files with 228 additions and 41 deletions
+53 -28
View File
@@ -16,6 +16,7 @@ use rocket_db_pools::{
use rocket_dyn_templates::{Template, context}; use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use sqlx::postgres::PgQueryResult;
use crate::db::DbConn; use crate::db::DbConn;
@@ -25,6 +26,11 @@ pub struct UserCredentials {
pub password: String, pub password: String,
} }
#[get("/signup")]
pub async fn signup_page() -> Template {
Template::render("signup", context!())
}
#[post("/signup", data = "<cred>")] #[post("/signup", data = "<cred>")]
pub async fn signup( pub async fn signup(
cred: Json<UserCredentials>, cred: Json<UserCredentials>,
@@ -40,16 +46,8 @@ pub async fn signup(
.await .await
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?;
let session = SessionToken::new(result.id as usize); let session = Session::new(result.id as usize);
let result = sqlx::query!( if let Err(e) = session.commit(&mut db).await {
"INSERT INTO sessions (user_id, token) VALUES ($1, $2)",
result.id,
session.token,
)
.execute(&mut **db)
.await;
if let Err(e) = result {
eprintln!("Failed to create session: {}", e); eprintln!("Failed to create session: {}", e);
return Err(e.to_string()); return Err(e.to_string());
} }
@@ -60,45 +58,70 @@ pub async fn signup(
Ok(Json("Signup successful".to_string())) Ok(Json("Signup successful".to_string()))
} }
#[get("/signup")] #[get("/login")]
pub async fn signup_page() -> Template { pub async fn login_page() -> Template {
Template::render("signup", context!()) Template::render("login", context!())
} }
#[post("/login", data = "<cred>")] #[post("/login", data = "<cred>")]
pub async fn login( pub async fn login(
conn: Connection<DbConn>, mut db: Connection<DbConn>,
jar: &CookieJar<'_>,
cred: Json<UserCredentials>, cred: Json<UserCredentials>,
) -> Result<Json<String>, String> { ) -> Result<Json<String>, String> {
if let Ok(row) = sqlx::query!(
"SELECT id FROM users WHERE username = $1 AND password = $2",
cred.username,
cred.password,
)
.fetch_one(&mut **db)
.await
{
let session = Session::new(row.id as usize);
if let Err(e) = session.commit(&mut db).await {
eprintln!("Failed to create session: {}", e);
return Err(e.to_string());
}
jar.add_private(("session", session.token));
return Ok(Json("Signup successful".to_string()));
}
// TODO: implement actual login logic, e.g. verify password and generate token // TODO: implement actual login logic, e.g. verify password and generate token
Ok(Json("Login successful".to_string())) Err("login failed".to_string())
} }
pub struct SessionToken { #[derive(Debug)]
pub struct Session {
pub token: String, pub token: String,
pub user_id: usize, pub user_id: usize,
} }
impl SessionToken { impl Session {
pub fn new(user_id: usize) -> Self { pub fn new(user_id: usize) -> Self {
let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let random: u32 = rand::rng().random(); let random: u32 = rand::rng().random();
let token = format!("{}-{}", current_time.as_secs(), random); let token = format!("{}-{}", current_time.as_secs(), random);
let hashed = format!("{:x}", Sha256::digest(token.as_bytes())); let hashed = format!("{:x}", Sha256::digest(token.as_bytes()));
SessionToken { Self {
token: hashed, token: hashed,
user_id, user_id,
} }
} }
pub async fn commit(&self, db: &mut Connection<DbConn>) -> Result<PgQueryResult, sqlx::Error> {
sqlx::query!(
"INSERT INTO sessions (user_id, token) VALUES ($1, $2)",
self.user_id as i32,
self.token,
)
.execute(&mut ***db)
.await
}
} }
type UserID = usize;
#[derive(Debug)]
pub struct AuthGuard(pub UserID);
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for AuthGuard { impl<'r> FromRequest<'r> for Session {
type Error = (); type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
@@ -110,16 +133,18 @@ impl<'r> FromRequest<'r> for AuthGuard {
let value = c.value(); let value = c.value();
let result = sqlx::query!( let result = sqlx::query!(
"SELECT user_id FROM sessions WHERE token = $1 AND expires_at > NOW()", "SELECT user_id, token FROM sessions WHERE token = $1 AND expires_at > NOW()",
value value
) )
.fetch_optional(&mut **pool) .fetch_optional(&mut **pool)
.await .await
.expect("query failed!"); .expect("query failed!");
if let Some(token) = result { if let Some(session) = result {
let user_id = token.user_id; Outcome::Success(Self {
Outcome::Success(AuthGuard(user_id as usize)) user_id: session.user_id as usize,
token: session.token,
})
} else { } else {
Outcome::Error((Status::Unauthorized, ())) Outcome::Error((Status::Unauthorized, ()))
} }
+5 -2
View File
@@ -42,7 +42,7 @@ impl LlmWorker {
.json(&payload) .json(&payload)
.send() .send()
.await .await
.unwrap(); .map_err(|_| String::from("Failed to make request to LLM server"))?;
// The API returns a JSON with `choices[].message.content` // The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -54,7 +54,10 @@ impl LlmWorker {
message: Message, message: Message,
} }
let llm_resp: LlmResponse = resp.json().await.unwrap(); let llm_resp: LlmResponse = resp
.json()
.await
.map_err(|_| String::from("Failed to make request to LLM server"))?;
Ok(ChatMsg { Ok(ChatMsg {
display_name: message.display_name.clone(), display_name: message.display_name.clone(),
+4 -3
View File
@@ -11,7 +11,7 @@ use rocket_db_pools::{Connection, Database};
use rocket_dyn_templates::Template; use rocket_dyn_templates::Template;
use std::sync::Arc; use std::sync::Arc;
use crate::auth::AuthGuard; use crate::auth::Session;
use crate::db::DbConn; use crate::db::DbConn;
use crate::messages::ChatBroadcaster; use crate::messages::ChatBroadcaster;
@@ -22,7 +22,7 @@ pub mod llm;
pub mod messages; pub mod messages;
#[get("/users", rank = 2)] #[get("/users", rank = 2)]
async fn users(_ag: AuthGuard, mut db: Connection<DbConn>) -> Json<Vec<i32>> { async fn users(_ag: Session, mut db: Connection<DbConn>) -> Json<Vec<i32>> {
sqlx::query!("SELECT id FROM users") sqlx::query!("SELECT id FROM users")
.fetch_all(&mut **db) .fetch_all(&mut **db)
.await .await
@@ -34,7 +34,7 @@ async fn users(_ag: AuthGuard, mut db: Connection<DbConn>) -> Json<Vec<i32>> {
} }
#[get("/users/<id>", rank = 1)] #[get("/users/<id>", rank = 1)]
async fn username_for_id(id: usize, _ag: AuthGuard, mut db: Connection<DbConn>) -> String { async fn username_for_id(id: usize, _ag: Session, mut db: Connection<DbConn>) -> String {
sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32) sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
.fetch_one(&mut **db) .fetch_one(&mut **db)
.await .await
@@ -74,6 +74,7 @@ fn rocket() -> Rocket<Build> {
messages::event_stream, messages::event_stream,
auth::signup, auth::signup,
auth::signup_page, auth::signup_page,
auth::login_page,
auth::login auth::login
], ],
) )
+7 -7
View File
@@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
use sqlx::prelude::FromRow; use sqlx::prelude::FromRow;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use crate::{auth::AuthGuard, db::DbConn, llm::LlmWorker}; use crate::{auth::Session, db::DbConn, llm::LlmWorker};
/// ---------- shared broadcaster ---------- /// ---------- shared broadcaster ----------
pub struct ChatBroadcaster { pub struct ChatBroadcaster {
@@ -47,14 +47,14 @@ pub async fn post_message(
mut msg: Json<ChatMsg>, mut msg: Json<ChatMsg>,
chat: &rocket::State<Arc<ChatBroadcaster>>, chat: &rocket::State<Arc<ChatBroadcaster>>,
mut db: Connection<DbConn>, mut db: Connection<DbConn>,
ag: AuthGuard, session: Session,
) -> Result<(), String> { ) -> Result<(), String> {
const CHANNEL_ID: i32 = 1; const CHANNEL_ID: i32 = 1;
const LMSTUDIO_URI: &'static str = "http://127.0.0.1:1234/v1/chat/completions"; const LMSTUDIO_URI: &'static str = "http://127.0.0.1:1234/v1/chat/completions";
let chat = chat.inner().clone(); let chat = chat.inner().clone();
msg.user_id = ag.0; msg.user_id = session.user_id;
chat.publish(msg.clone().into_inner()).await; chat.publish(msg.clone().into_inner()).await;
sqlx::query!( sqlx::query!(
@@ -92,7 +92,7 @@ pub async fn post_message(
} }
#[get("/messages")] #[get("/messages")]
pub async fn get_messages(mut db: Connection<DbConn>, _ag: AuthGuard) -> Json<Vec<ChatMsg>> { pub async fn get_messages(mut db: Connection<DbConn>, _session: Session) -> Json<Vec<ChatMsg>> {
Json( Json(
sqlx::query!( sqlx::query!(
"SELECT u.username, u.display_name, u.id, m.content, m.created_at FROM messages m JOIN users u ON m.user_id = u.id ORDER BY m.created_at DESC LIMIT 100" "SELECT u.username, u.display_name, u.id, m.content, m.created_at FROM messages m JOIN users u ON m.user_id = u.id ORDER BY m.created_at DESC LIMIT 100"
@@ -117,7 +117,7 @@ pub async fn get_messages(mut db: Connection<DbConn>, _ag: AuthGuard) -> Json<Ve
pub async fn event_stream( pub async fn event_stream(
chat: &rocket::State<Arc<ChatBroadcaster>>, chat: &rocket::State<Arc<ChatBroadcaster>>,
db: Connection<DbConn>, db: Connection<DbConn>,
ag: AuthGuard, ag: Session,
) -> EventStream![] { ) -> EventStream![] {
let mut rx = chat.subscribe(); let mut rx = chat.subscribe();
@@ -140,6 +140,6 @@ pub async fn event_stream(
} }
#[get("/")] #[get("/")]
pub async fn chat_page(ag: AuthGuard) -> Template { pub async fn chat_page(session: Session) -> Template {
Template::render("chat", context!(user_id: ag.0)) Template::render("chat", context!(user_id: session.user_id))
} }
-1
View File
@@ -382,7 +382,6 @@ body {
background-size: cover; background-size: cover;
border: 2px solid #252525; border: 2px solid #252525;
flex-shrink: 0; flex-shrink: 0;
background-image: url("static/profile_pics/default.jpg");
} }
.user-avatar.blue { .user-avatar.blue {
+159
View File
@@ -0,0 +1,159 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Discord Clone - Sign Up</title>
<link rel="stylesheet" href="static/css/index.css"/>
</head>
<body>
<div class="signup-container">
<div class="signup-header">
<div class="logo">DC</div>
<h1>Login</h1>
<p>Enter the chat</p>
</div>
<div class="signup-form">
<div class="success-message" id="successMessage">
Login successful, Redirecting
</div>
<div class="form-group">
<label for="username">Username</label>
<input
type="text"
id="username"
name="username"
placeholder="Enter your username"
required
/>
<div class="error-message">Username is required</div>
</div>
<div class="form-group">
<label for="password">Password</label>
<div class="input-wrapper">
<input
type="password"
id="password"
name="password"
placeholder="Enter your password"
required
/>
<button
type="button"
class="password-toggle"
id="passwordToggle"
>
SHOW
</button>
</div>
<div class="error-message">
Password must be at least 8 characters
</div>
</div>
<button type="button" class="submit-button" id="submitButton">
Login
</button>
</div>
<div class="signup-footer">
Already have an account? <a href="/login">Log in</a>
</div>
</div>
<script>
const form = {
username: document.getElementById("username"),
password: document.getElementById("password"),
};
const submitButton = document.getElementById("submitButton");
const successMessage = document.getElementById("successMessage");
// Password toggle functionality
const passwordToggle = document.getElementById("passwordToggle");
passwordToggle.addEventListener("click", function () {
const type =
form.password.type === "password" ? "text" : "password";
form.password.type = type;
this.textContent = type === "password" ? "SHOW" : "HIDE";
});
function toggleError(input, hasError) {
const formGroup = input.closest(".form-group");
if (hasError) {
formGroup.classList.add("error");
} else {
formGroup.classList.remove("error");
}
}
// Form submission
submitButton.addEventListener("click", async function () {
// Disable button and show loading
submitButton.disabled = true;
submitButton.innerHTML =
'<span class="loading-spinner"></span>Creating Account...';
// Prepare data
const formData = {
username: form.username.value.trim(),
password: form.password.value,
};
try {
// Replace with your actual backend endpoint
const response = await fetch(
"http://localhost:8000/login",
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(formData),
},
);
if (response.ok) {
// Show success message
successMessage.classList.add("show");
submitButton.innerHTML = "Logged in!!";
// Optional: Redirect after success
setTimeout(() => {
// window.location.href = '/chat';
console.log("Redirecting to chat...");
}, 2000);
} else {
const error = await response.json();
throw new Error(error.message || "Login failed");
}
} catch (error) {
console.error("Login error:", error);
alert(
error.message ||
"Failed to login. Please try again.",
);
submitButton.disabled = false;
submitButton.innerHTML = "Login";
}
});
// Allow Enter key to submit
Object.values(form).forEach((input) => {
if (input.tagName === "INPUT") {
input.addEventListener("keypress", function (e) {
if (e.key === "Enter") {
submitButton.click();
}
});
}
});
</script>
</body>
</html>