more progress on TOTP/2FA

This commit is contained in:
2025-10-10 01:45:02 +01:00
parent b13cb5086a
commit 4a6c3bc49c
12 changed files with 189 additions and 197 deletions
+2 -2
View File
@@ -1,6 +1,6 @@
[debug] [debug]
secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU=" secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU="
address = "127.0.0.1" address = "0.0.0.0"
port = 8000 port = 8000
[default.databases.postgres_db] [default.databases.postgres_db]
@@ -8,4 +8,4 @@ url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp"
[default] # run inside a docker container or pod [default] # run inside a docker container or pod
address = "0.0.0.0" address = "0.0.0.0"
port = 8082 port = 8000
+4 -14
View File
@@ -1,15 +1,5 @@
-- Add migration script here -- Add migration script here
DROP TABLE users; ALTER TABLE users ADD COLUMN email VARCHAR(100);
ALTER TABLE users ADD COLUMN twofa_enabled BOOLEAN DEFAULT FALSE;
CREATE TABLE users ( ALTER TABLE users ADD COLUMN totp_secret VARCHAR(64);
id SERIAL PRIMARY KEY, ALTER TABLE users ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP;
username VARCHAR(50) UNIQUE NOT NULL,
display_name VARCHAR(50),
email VARCHAR(100) NOT NULL,
password VARCHAR(100) NOT NULL,
totp_secret VARCHAR(100),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
@@ -0,0 +1,2 @@
-- Add migration script here
ALTER TABLE users ALTER COLUMN twofa_enabled SET NOT NULL;
+74 -32
View File
@@ -5,9 +5,10 @@ use rocket::{
Request, Request,
fs::NamedFile, fs::NamedFile,
http::{CookieJar, Status}, http::{CookieJar, Status},
outcome::Outcome, outcome::{Outcome, try_outcome},
post, post,
request::{self, FromRequest}, request::{self, FromRequest},
response::Redirect,
serde::json::Json, serde::json::Json,
}; };
use rocket_db_pools::{ use rocket_db_pools::{
@@ -38,7 +39,7 @@ pub async fn signup(
cred: Json<UserCredentials>, cred: Json<UserCredentials>,
jar: &CookieJar<'_>, jar: &CookieJar<'_>,
mut db: Connection<DbConn>, mut db: Connection<DbConn>,
) -> Result<Json<String>, String> { ) -> Result<Redirect, Status> {
let result = sqlx::query!( let result = sqlx::query!(
"INSERT INTO users (username, password) VALUES ($1, $2) RETURNING id", "INSERT INTO users (username, password) VALUES ($1, $2) RETURNING id",
cred.username, cred.username,
@@ -46,18 +47,18 @@ pub async fn signup(
) )
.fetch_one(&mut **db) .fetch_one(&mut **db)
.await .await
.map_err(|e| e.to_string())?; .map_err(|e| Status::InternalServerError)?;
let session = Session::new(result.id as usize); let session = Session::new(result.id as usize);
if let Err(e) = session.commit(&mut db).await { if let Err(e) = session.commit(&mut db).await {
eprintln!("Failed to create session: {}", e); eprintln!("Failed to create session: {}", e);
return Err(e.to_string()); return Err(Status::InternalServerError);
} }
jar.add_private(("session", session.token)); jar.add_private(("session", session.token));
println!("Signup successful"); println!("Signup successful");
Ok(Json("Signup successful".to_string())) return Ok(Redirect::to("/chat"));
} }
#[get("/login")] #[get("/login")]
@@ -70,7 +71,7 @@ pub async fn login(
mut db: Connection<DbConn>, mut db: Connection<DbConn>,
jar: &CookieJar<'_>, jar: &CookieJar<'_>,
cred: Json<UserCredentials>, cred: Json<UserCredentials>,
) -> Result<Json<String>, String> { ) -> Result<Redirect, Status> {
if let Ok(row) = sqlx::query!( if let Ok(row) = sqlx::query!(
"SELECT id FROM users WHERE username = $1 AND password = $2", "SELECT id FROM users WHERE username = $1 AND password = $2",
cred.username, cred.username,
@@ -82,41 +83,47 @@ pub async fn login(
let session = Session::new(row.id as usize); let session = Session::new(row.id as usize);
if let Err(e) = session.commit(&mut db).await { if let Err(e) = session.commit(&mut db).await {
eprintln!("Failed to create session: {}", e); eprintln!("Failed to create session: {}", e);
return Err(e.to_string()); return Err(Status::InternalServerError);
} }
jar.add_private(("session", session.token)); jar.add_private(("session", session.token));
return Ok(Json("Signup successful".to_string())); return Ok(Redirect::to("/chat"));
} }
// TODO: implement actual login logic, e.g. verify password and generate token // TODO: implement actual login logic, e.g. verify password and generate token
Err("login failed".to_string()) Err(Status::Unauthorized)
} }
#[get("/totp")] #[get("/totp")]
pub async fn mfa_page(session: Session) -> Template { pub async fn mfa_page(_session: Session) -> Template {
Template::render("2fa", context!()) Template::render("2fa", context!())
} }
#[get("/api/totp.jpg")] #[derive(Serialize)]
pub async fn get_totp(s: Session) -> Option<QrCodeImage> { pub struct QrResponse {
qr_code: String,
}
#[get("/totp.jpg")]
pub async fn get_totp(totp: TOTPCode) -> Option<Json<QrResponse>> {
let totp = TOTP::new( let totp = TOTP::new(
Algorithm::SHA1, Algorithm::SHA1,
6, 6,
1, 1,
30, 30,
Secret::generate_secret().to_bytes().unwrap(), totp.secret.as_bytes().into(),
Some("Github".to_string()), Some("chat.zxq5.dev".to_string()),
format!("{}", s.user_id), format!("{}", totp.user_id),
) )
.unwrap(); .unwrap();
let qr = totp.get_qr_base64().unwrap(); let qr = totp.get_qr_base64().unwrap();
let data_uri = format!("data:image/png;base64,{}", qr);
Some(QrCodeImage(qr.into())) Some(Json(QrResponse { qr_code: data_uri }))
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Session { pub struct Session {
pub token: String, pub token: String,
pub user_id: usize, pub user_id: usize,
@@ -145,6 +152,56 @@ impl Session {
} }
} }
pub struct TOTPCode {
user_id: usize,
secret: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for TOTPCode {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let user = try_outcome!(request.guard::<Session>().await);
let mut pool = match request.guard::<Connection<DbConn>>().await {
Outcome::Success(pool) => pool,
_ => return Outcome::Error((Status::Unauthorized, ())),
};
let (enabled, mut secret) = match sqlx::query!(
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
user.user_id as i32,
)
.fetch_one(&mut **pool)
.await
{
Ok(row) => (row.twofa_enabled, row.totp_secret),
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
};
if !enabled || secret.is_none() {
secret = Some(Secret::generate_secret().to_string());
match sqlx::query!(
"UPDATE users SET totp_secret = $1, twofa_enabled = true WHERE id = $2",
secret.as_ref().unwrap(),
user.user_id as i32,
)
.execute(&mut **pool)
.await
{
Ok(_) => (),
Err(_) => return Outcome::Error((Status::InternalServerError, ())),
}
}
Outcome::Success(TOTPCode {
user_id: user.user_id,
secret: secret.unwrap(),
})
}
}
#[rocket::async_trait] #[rocket::async_trait]
impl<'r> FromRequest<'r> for Session { impl<'r> FromRequest<'r> for Session {
type Error = (); type Error = ();
@@ -178,18 +235,3 @@ impl<'r> FromRequest<'r> for Session {
} }
} }
} }
use rocket::http::ContentType;
use rocket::response::{self, Responder, Response};
use std::io::Cursor;
pub struct QrCodeImage(Vec<u8>);
impl<'r> Responder<'r, 'static> for QrCodeImage {
fn respond_to(self, _: &'r rocket::Request<'_>) -> response::Result<'static> {
Response::build()
.header(ContentType::PNG)
.sized_body(self.0.len(), Cursor::new(self.0))
.ok()
}
}
+2 -2
View File
@@ -60,10 +60,10 @@ impl LlmWorker {
.map_err(|_| String::from("Failed to make request to LLM server"))?; .map_err(|_| String::from("Failed to make request to LLM server"))?;
Ok(ChatMsg { Ok(ChatMsg {
display_name: message.display_name.clone(), display_name: Some(String::from("lmstudio")),
user_id: 0, user_id: 0,
text: llm_resp.choices[0].message.content.clone(), text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now().timestamp() as usize, timestamp: chrono::Utc::now().timestamp_millis() as usize,
}) })
} }
} }
+10 -7
View File
@@ -34,12 +34,15 @@ async fn users(_ag: Session, mut db: Connection<DbConn>) -> Json<Vec<i32>> {
} }
#[get("/users/<id>", rank = 1)] #[get("/users/<id>", rank = 1)]
async fn username_for_id(id: usize, _ag: Session, mut db: Connection<DbConn>) -> String { async fn display_name(id: usize, _ag: Session, mut db: Connection<DbConn>) -> String {
sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32) sqlx::query!(
.fetch_one(&mut **db) "SELECT display_name, username FROM users WHERE id = $1",
.await id as i32
.map(|row| row.username) )
.unwrap_or_else(|_| "User not found".to_string()) .fetch_one(&mut **db)
.await
.map(|row| row.display_name.unwrap_or(row.username))
.unwrap_or_else(|_| "User not found".to_string())
} }
#[launch] #[launch]
@@ -80,7 +83,7 @@ fn rocket() -> Rocket<Build> {
messages::post_message, messages::post_message,
messages::event_stream, messages::event_stream,
users, users,
username_for_id, display_name,
auth::signup, auth::signup,
auth::login, auth::login,
auth::get_totp, auth::get_totp,
+35 -17
View File
@@ -1,6 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use rocket::{ use rocket::{
Shutdown,
response::stream::{Event, EventStream}, response::stream::{Event, EventStream},
serde::json::Json, serde::json::Json,
time::OffsetDateTime, time::OffsetDateTime,
@@ -9,9 +10,9 @@ use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context}; use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::prelude::FromRow; use sqlx::prelude::FromRow;
use tokio::sync::broadcast; use tokio::{select, sync::broadcast};
use crate::{auth::Session, db::DbConn, llm::LlmWorker}; use crate::{auth::Session, db::DbConn, display_name, llm::LlmWorker};
/// ---------- shared broadcaster ---------- /// ---------- shared broadcaster ----------
pub struct ChatBroadcaster { pub struct ChatBroadcaster {
@@ -54,14 +55,25 @@ pub async fn post_message(
let chat = chat.inner().clone(); let chat = chat.inner().clone();
let display_name = sqlx::query!(
"SELECT display_name, username FROM users WHERE id = $1",
session.user_id as i32
)
.fetch_one(&mut **db)
.await
.map(|row| row.display_name.unwrap_or(row.username))
.unwrap_or_else(|_| "Unknown".to_string());
msg.user_id = session.user_id; msg.user_id = session.user_id;
msg.display_name = Some(display_name);
chat.publish(msg.clone().into_inner()).await; chat.publish(msg.clone().into_inner()).await;
sqlx::query!( sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content) VALUES ($1, $2, $3)", "INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
CHANNEL_ID, CHANNEL_ID,
msg.user_id as i32, msg.user_id as i32,
msg.text msg.text,
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
) )
.execute(&mut **db) .execute(&mut **db)
.await .await
@@ -71,15 +83,15 @@ pub async fn post_message(
tokio::spawn(async move { tokio::spawn(async move {
let response = LlmWorker::new(LMSTUDIO_URI.to_string()).query(&msg).await; let response = LlmWorker::new(LMSTUDIO_URI.to_string()).query(&msg).await;
if let Ok(message) = response { if let Ok(reply) = response {
chat.publish(message.clone()).await; chat.publish(reply.clone()).await;
sqlx::query!( sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)", "INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
CHANNEL_ID, CHANNEL_ID,
message.user_id as i32, reply.user_id as i32,
message.text, reply.text,
OffsetDateTime::from_unix_timestamp(message.timestamp as i64).unwrap() OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
) )
.execute(&mut **db) .execute(&mut **db)
.await .await
@@ -106,7 +118,7 @@ pub async fn get_messages(mut db: Connection<DbConn>, _session: Session) -> Json
display_name: Some(msg.display_name.unwrap_or(msg.username)), display_name: Some(msg.display_name.unwrap_or(msg.username)),
user_id: msg.id as usize, user_id: msg.id as usize,
text: msg.content, text: msg.content,
timestamp: msg.created_at.unwrap().unix_timestamp() as usize, timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize,
}) })
.collect(), .collect(),
) )
@@ -118,6 +130,7 @@ pub async fn event_stream(
chat: &rocket::State<Arc<ChatBroadcaster>>, chat: &rocket::State<Arc<ChatBroadcaster>>,
db: Connection<DbConn>, db: Connection<DbConn>,
ag: Session, ag: Session,
mut shutdown: Shutdown,
) -> EventStream![] { ) -> EventStream![] {
let mut rx = chat.subscribe(); let mut rx = chat.subscribe();
@@ -128,18 +141,23 @@ pub async fn event_stream(
} }
loop { loop {
match rx.recv().await { select!{
Ok(msg) => yield Event::json(&msg), // exit early on shutdown
Err(broadcast::error::RecvError::Lagged(_)) => { _ = &mut shutdown => break,
yield Event::comment("RecvError::Lagged");
} msg = rx.recv() => match msg {
Err(broadcast::error::RecvError::Closed) => break, Ok(msg) => yield Event::json(&msg),
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => break,
},
} }
} }
} }
} }
#[get("/")] #[get("/chat")]
pub async fn chat_page(session: Session) -> Template { pub async fn chat_page(session: Session) -> Template {
Template::render("chat", context!(user_id: session.user_id)) Template::render("chat", context!(user_id: session.user_id))
} }
+4 -4
View File
@@ -8,9 +8,9 @@ body {
font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif; font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
background: #0a0a0a; background: #0a0a0a;
color: #e0e0e0; color: #e0e0e0;
min-height: 100vh; max-height: 100dvh;
min-width: 100vw; min-width: 100dvw;
height: 100vh; height: 100dvh;
overflow: hidden; overflow: hidden;
} }
@@ -33,7 +33,7 @@ body {
.chat-container { .chat-container {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
height: 100vh; height: 100dvh;
min-width: 100vw; min-width: 100vw;
margin: 0 0; margin: 0 0;
background: #121212; background: #121212;
+30 -38
View File
@@ -3,54 +3,46 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Discord Clone - Group Chat</title> <title>Discord Clone - Sign Up</title>
<link rel="stylesheet" href="static/css/index.css"/> <link rel="stylesheet" href="static/css/index.css"/>
</head> </head>
<body> <body>
<div class="chat-container"> <div class="signup-container">
<!--<div class="chat-container" style="background-image: url('cdn/background.png'); backdrop-filter: blur(10px); background-size: cover; background-position: center; background-repeat: no-repeat;">--> <div class="signup-header">
<!-- Chat Header --> <div class="logo">DC</div>
<div class="chat-header"> <h1>2FA Setup</h1>
<div class="chat-title">
<img class="user-avatar" src="cdn/profile/0"></img>
<h1>Chat title</h1>
</div>
</div> </div>
<!-- Live Location Notification Bubble --> <div class="signup-form">
<div class="notification-container"> <div class="success-message" id="successMessage">
<div class="live-location-bubble" id="locationBubble"> Login successful, Redirecting
<div class="map-container">
<img src="cdn/map.png" alt="Map" />
</div>
<div class="location-content">
<div class="location-icon">
<img src="cdn/icons/location.svg" alt="Location"></img>
</div>
<button class="join-button" id="joinButton">
Join
</button>
<div class="location-text">Live Location</div>
<div class="location-users" id="locationUsers">
<!-- Users will be added dynamically -->
</div>
</div>
</div> </div>
</div>
<!-- Messages Container --> <img id="qr-code" alt="QR Code" style="width: 100%; height: auto; filter: brightness(0.925) invert(1);">
<!--<div class="messages-container" style="background-image: url('cdn/background.png'); backdrop-filter: blur(10px); background-size: cover; background-position: center; background-repeat: no-repeat;">-->
<div class="messages-container"></div>
<!-- Input Container --> <div class="form-group">
<div class="input-container"> <input
<div class="input-wrapper"> type="text"
<input type="text" placeholder="Start Typing..." /> inputmode="numeric"
<button class="send-button"> pattern="[0-9]*"
<img src="cdn/icons/send.svg" alt="Send" /> maxlength="6"
</button> placeholder="000000"
style="font-size: 24px; letter-spacing: 0.5em; text-align: center; width: 100%;"
>
</div> </div>
<button type="button" class="submit-button" id="submitButton">
Confirm!
</button>
</div> </div>
</div> </div>
<script>
fetch('/api/totp.jpg')
.then(response => response.json())
.then(data => {
document.getElementById('qr-code').src = data.qr_code;
});
</script>
</body> </body>
</html> </html>
+22 -77
View File
@@ -5,6 +5,9 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Discord Clone - Group Chat</title> <title>Discord Clone - Group Chat</title>
<link rel="stylesheet" href="static/css/index.css"/> <link rel="stylesheet" href="static/css/index.css"/>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.11.1/styles/default.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.11.1/highlight.min.js"></script>
</head> </head>
<body> <body>
<div class="chat-container"> <div class="chat-container">
@@ -13,7 +16,7 @@
<div class="chat-header"> <div class="chat-header">
<div class="chat-title"> <div class="chat-title">
<img class="user-avatar" src="cdn/profile/0"></img> <img class="user-avatar" src="cdn/profile/0"></img>
<h1>Chat title</h1> <h1>Wish.com Discord frfr</h1>
</div> </div>
</div> </div>
@@ -53,84 +56,26 @@
</div> </div>
</div> </div>
<script> <script type="module">
import markdownit from 'https://cdn.jsdelivr.net/npm/markdown-it@14.1.0/+esm';
const md = markdownit({
html: true,
linkify: true,
typographer: true,
highlight: function (str, lang) {
if (lang && hljs.getLanguage(lang)) {
try {
return hljs.highlight(str, { language: lang }).value;
} catch (__) {}
}
return ''; // use external default escaping
}
})
const user_id = {{ user_id }}; const user_id = {{ user_id }};
var users = {}; var users = {};
// Location tracker state
const locationUsers = [
{ id: 1, color: "linear-gradient(135deg, #ff6b6b, #ff8e8e)" },
{ id: 2, color: "linear-gradient(135deg, #4ecdc4, #7fdbda)" },
{ id: 3, color: "linear-gradient(135deg, #45b7d1, #6cc5e0)" },
];
let hiddenUsersCount = 5; // Users not shown in the visible stack
let currentUserInLocation = false;
function updateLocationUsers() {
const container = document.getElementById("locationUsers");
container.innerHTML = "";
const maxVisible = 3;
const visibleUsers = locationUsers.slice(0, maxVisible);
visibleUsers.forEach((user) => {
const pic = document.createElement("div");
pic.className = "location-user-pic";
pic.style.background = user.color;
container.appendChild(pic);
});
// Calculate total hidden users (hidden users + current user if tracking)
const totalHidden =
hiddenUsersCount + (currentUserInLocation ? 1 : 0);
// Add count indicator if there are more users
if (totalHidden > 0) {
const count = document.createElement("div");
count.className = "location-count";
count.textContent = `+${totalHidden}`;
container.appendChild(count);
}
}
// Toggle location tracking
// const locationBubble = document.getElementById("locationBubble");
// const joinButton = document.getElementById("joinButton");
// let isExpanded = false;
// locationBubble.addEventListener("click", function (e) {
// // Don't toggle expanded state if clicking the join button
// if (e.target.id === "joinButton") return;
// isExpanded = !isExpanded;
// this.classList.toggle("expanded", isExpanded);
// });
// joinButton.addEventListener("click", function (e) {
// e.stopPropagation();
// currentUserInLocation = !currentUserInLocation;
// this.classList.toggle("active", currentUserInLocation);
// this.textContent = currentUserInLocation ? "Leave" : "Join";
// locationBubble.classList.toggle(
// "active",
// currentUserInLocation,
// );
// // Animate click
// this.style.transform = "scale(0.95)";
// setTimeout(() => {
// this.style.transform = "";
// }, 150);
// updateLocationUsers();
// });
// // Initialize location users
// updateLocationUsers();
// Handle message sending // Handle message sending
const input = document.querySelector("input"); const input = document.querySelector("input");
const sendButton = document.querySelector(".send-button"); const sendButton = document.querySelector(".send-button");
@@ -154,7 +99,7 @@
<span class="username">${message.display_name}</span> <span class="username">${message.display_name}</span>
<span class="timestamp">${date}</span> <span class="timestamp">${date}</span>
</div> </div>
<div class="message-text">${message.text}</div> <div class="message-text">${md.render(message.text)}</div>
</div> </div>
`; `;
messagesContainer.appendChild(messageEl); messagesContainer.appendChild(messageEl);
+2 -2
View File
@@ -126,8 +126,8 @@
// Optional: Redirect after success // Optional: Redirect after success
setTimeout(() => { setTimeout(() => {
// window.location.href = '/chat'; window.location.href = '/chat';
console.log("Redirecting to chat..."); // console.log("Redirecting to chat...");
}, 2000); }, 2000);
} else { } else {
const error = await response.json(); const error = await response.json();
+2 -2
View File
@@ -242,8 +242,8 @@
// Optional: Redirect after success // Optional: Redirect after success
setTimeout(() => { setTimeout(() => {
// window.location.href = '/chat'; window.location.href = '/chat';
console.log("Redirecting to chat..."); // console.log("Redirecting to chat...");
}, 2000); }, 2000);
} else { } else {
const error = await response.json(); const error = await response.json();