diff --git a/backend/.idea/.gitignore b/backend/.idea/.gitignore new file mode 100644 index 0000000..ab1f416 --- /dev/null +++ b/backend/.idea/.gitignore @@ -0,0 +1,10 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Ignored default folder with query files +/queries/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 6115099..c2318ee 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -10,7 +10,7 @@ dotenv = "0.15.0" futures-util = "0.3.31" image = "0.25.8" jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] } -rand = "0.9.2" +rand = "0.8" redis = { version = "0.25.4", features = ["tokio-comp"] } reqwest = { version = "0.12.23", features = ["json"] } rocket = { version = "0.5.1", features = ["json", "secrets"] } @@ -20,8 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" sha2 = "0.10.9" -sqlx = { version = "0.7.4", features = ["macros", "time"] } +sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] } tokio = { version = "1.47.1", features = ["full"] } totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] } tracing = "0.1.44" uuid = { version = "1.18.1", features = ["v4"] } +thiserror = "1.0.69" +utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] } +clap = { version = "4.5", features = ["derive"] } diff --git a/backend/migrations/20251001212022_prototype-v1.sql b/backend/migrations/20251001212022_prototype-v1.sql deleted file mode 100644 index 857f28d..0000000 --- a/backend/migrations/20251001212022_prototype-v1.sql +++ /dev/null @@ -1,49 +0,0 @@ --- Add migration script here -CREATE TABLE users ( - id SERIAL PRIMARY KEY, - username VARCHAR(50) UNIQUE NOT NULL, - password VARCHAR(50) NOT NULL, - display_name VARCHAR(50), - created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE channels ( - id SERIAL PRIMARY KEY, - name VARCHAR(50) NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE messages ( - id SERIAL PRIMARY KEY, - channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - content TEXT NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, - is_edited BOOLEAN DEFAULT FALSE -); - -create table attachments ( - id SERIAL PRIMARY KEY, - message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE, - path TEXT NOT NULL -); - -CREATE INDEX idx_messages_channel_id ON messages (channel_id, id DESC); -CREATE INDEX idx_new_messages ON messages(created_at DESC); - --- Create a function to update the updated_at timestamp -CREATE OR REPLACE FUNCTION update_updated_at_column() -RETURNS TRIGGER AS $$ -BEGIN - NEW.updated_at = CURRENT_TIMESTAMP; - RETURN NEW; -END; -$$ language 'plpgsql'; - --- Create trigger for messages table -CREATE TRIGGER update_messages_updated_at - BEFORE UPDATE ON messages - FOR EACH ROW - EXECUTE FUNCTION update_updated_at_column(); diff --git a/backend/migrations/20251005014201_session-tokens.sql b/backend/migrations/20251005014201_session-tokens.sql deleted file mode 100644 index 42a2be4..0000000 --- a/backend/migrations/20251005014201_session-tokens.sql +++ /dev/null @@ -1,20 +0,0 @@ --- Add migration script here -CREATE TABLE sessions ( - id SERIAL PRIMARY KEY, - user_id INTEGER NOT NULL REFERENCES users(id), - token TEXT NOT NULL UNIQUE, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '7 days' -); - -CREATE OR REPLACE FUNCTION cleanup_expired_sessions() -RETURNS TRIGGER AS $$ -BEGIN - DELETE FROM sessions WHERE expires_at < NOW(); - RETURN NULL; -END; -$$ LANGUAGE plpgsql; - -CREATE TRIGGER trigger_cleanup_sessions - AFTER INSERT ON sessions - EXECUTE FUNCTION cleanup_expired_sessions(); diff --git a/backend/migrations/20251009091103_auth-v2.sql b/backend/migrations/20251009091103_auth-v2.sql deleted file mode 100644 index 5174a46..0000000 --- a/backend/migrations/20251009091103_auth-v2.sql +++ /dev/null @@ -1,5 +0,0 @@ --- Add migration script here -ALTER TABLE users ADD COLUMN email VARCHAR(100); -ALTER TABLE users ADD COLUMN twofa_enabled BOOLEAN DEFAULT FALSE; -ALTER TABLE users ADD COLUMN totp_secret VARCHAR(64); -ALTER TABLE users ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP; diff --git a/backend/migrations/20251010002253_auth-v3.sql b/backend/migrations/20251010002253_auth-v3.sql deleted file mode 100644 index a5831ac..0000000 --- a/backend/migrations/20251010002253_auth-v3.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -ALTER TABLE users ALTER COLUMN twofa_enabled SET NOT NULL; diff --git a/backend/migrations/20251010231008_invite-codes.sql b/backend/migrations/20251010231008_invite-codes.sql deleted file mode 100644 index 9aabb63..0000000 --- a/backend/migrations/20251010231008_invite-codes.sql +++ /dev/null @@ -1,18 +0,0 @@ --- Add migration script here -CREATE TABLE access_codes ( - -- identifiers - id SERIAL PRIMARY KEY, - creator_id INTEGER NOT NULL REFERENCES users(id), - - -- code data - code VARCHAR(255) NOT NULL, - name VARCHAR(255) NOT NULL, - - -- uses - uses INTEGER NOT NULL DEFAULT 0, - max_uses INTEGER NOT NULL DEFAULT 1, - - -- time data - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '1 day' -); diff --git a/backend/migrations/20251011002049_invite-codes-v2.sql b/backend/migrations/20251011002049_invite-codes-v2.sql deleted file mode 100644 index 3ae502b..0000000 --- a/backend/migrations/20251011002049_invite-codes-v2.sql +++ /dev/null @@ -1,10 +0,0 @@ --- Add migration script here -ALTER TABLE access_codes -ALTER COLUMN created_at -TYPE TIMESTAMP WITH TIME ZONE -USING created_at AT TIME ZONE 'UTC'; - -ALTER TABLE access_codes -ALTER COLUMN expires_at -TYPE TIMESTAMP WITH TIME ZONE -USING expires_at AT TIME ZONE 'UTC'; diff --git a/backend/migrations/20260331174422_password-hashing.sql b/backend/migrations/20260331174422_password-hashing.sql deleted file mode 100644 index d301a06..0000000 --- a/backend/migrations/20260331174422_password-hashing.sql +++ /dev/null @@ -1,6 +0,0 @@ --- Add migration script here -TRUNCATE TABLE users CASCADE; - -ALTER TABLE users DROP COLUMN password; -ALTER TABLE users ADD COLUMN pass_hash VARCHAR(255) NOT NULL; -ALTER TABLE users ADD CONSTRAINT users_username_unique UNIQUE (username); diff --git a/backend/migrations/20260331222425_friends.sql b/backend/migrations/20260331222425_friends.sql deleted file mode 100644 index c1d2b05..0000000 --- a/backend/migrations/20260331222425_friends.sql +++ /dev/null @@ -1,13 +0,0 @@ --- Add migration script here -CREATE TYPE status AS ENUM ('pending', 'accepted', 'blocked'); - -CREATE TABLE relationships ( - id SERIAL PRIMARY KEY, - from_user INTEGER NOT NULL REFERENCES users(id), - to_user INTEGER NOT NULL REFERENCES users(id), - status status NOT NULL DEFAULT 'pending', - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - CONSTRAINT no_self_relationship CHECK (from_user != to_user), - CONSTRAINT unique_relationship UNIQUE (from_user, to_user) -); diff --git a/backend/migrations/20260405234237_v0.4.0.sql b/backend/migrations/20260405234237_v0.4.0.sql new file mode 100644 index 0000000..0467948 --- /dev/null +++ b/backend/migrations/20260405234237_v0.4.0.sql @@ -0,0 +1,183 @@ +-- Add migration script here +-- Add migration script here +CREATE TYPE user_role AS ENUM ('user', 'admin'); +CREATE TYPE totp_status AS ENUM ('disabled', 'pending', 'enabled'); + +CREATE TABLE users ( + id BIGSERIAL PRIMARY KEY NOT NULL, + role user_role NOT NULL DEFAULT 'user', + + -- profile + nickname VARCHAR(255), + + -- basic auth + username VARCHAR(255) UNIQUE NOT NULL, + passhash VARCHAR(255) NOT NULL, + + -- email + email VARCHAR(255), + email_verified BOOLEAN DEFAULT FALSE, + + -- 2fa + totp_secret VARCHAR(255), + totp_status totp_status NOT NULL DEFAULT 'disabled', + + -- update tracking + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ +); + +CREATE TABLE access_tokens ( + id BIGSERIAL PRIMARY KEY NOT NULL, + creator_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + + code VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + + uses INTEGER NOT NULL DEFAULT 0, + max_uses INTEGER NOT NULL DEFAULT 1, + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '24 hours', + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE refresh_tokens ( + id BIGSERIAL PRIMARY KEY NOT NULL, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token_hash VARCHAR(255) NOT NULL, + + revoked BOOLEAN NOT NULL DEFAULT FALSE, + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '7 days' +); + + +CREATE TABLE spaces ( + id BIGSERIAL PRIMARY KEY NOT NULL, + name VARCHAR(255) NOT NULL, + description TEXT, + + owner_id BIGINT NOT NULL REFERENCES users(id), + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE channels ( + id BIGSERIAL PRIMARY KEY NOT NULL, + name VARCHAR(255) NOT NULL, + description TEXT, + + space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE, + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + + +CREATE TABLE space_members ( + space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + + joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + role user_role DEFAULT 'user', + + PRIMARY KEY (space_id, user_id) +); + +CREATE TABLE messages ( + id BIGSERIAL PRIMARY KEY NOT NULL, + channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + content TEXT NOT NULL, + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + is_edited BOOLEAN DEFAULT FALSE +); + +CREATE TABLE attachments ( + id BIGSERIAL PRIMARY KEY NOT NULL, + message_id BIGINT NOT NULL REFERENCES messages(id) ON DELETE CASCADE, + filename VARCHAR(255) NOT NULL, + content_type VARCHAR(100) NOT NULL, -- mime type e.g. image/png, video/mp4 + size_bytes BIGINT NOT NULL, + url TEXT NOT NULL, -- path to file on your CDN/storage + width INTEGER, -- null for non-image/video + height INTEGER, -- null for non-image/video + duration_ms INTEGER, -- null for non-audio/video + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TYPE relationship_status AS ENUM ('pending', 'accepted', 'blocked'); + +CREATE TABLE relationships ( + id BIGSERIAL PRIMARY KEY NOT NULL, + + from_user BIGINT NOT NULL REFERENCES users(id), + to_user BIGINT NOT NULL REFERENCES users(id), + status relationship_status NOT NULL DEFAULT 'pending', + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT no_self_relationship CHECK (from_user != to_user), + CONSTRAINT unique_relationship UNIQUE (from_user, to_user) +); + +CREATE INDEX idx_messages_channel_id ON messages(channel_id, created_at DESC); +CREATE INDEX idx_messages_user_id ON messages(user_id); +CREATE INDEX idx_attachments_message ON attachments(message_id); +CREATE INDEX idx_channels_space_id ON channels(space_id); +CREATE INDEX idx_space_members_user ON space_members(user_id); +CREATE INDEX idx_refresh_tokens_hash ON refresh_tokens(token_hash); +CREATE INDEX idx_relationships_from ON relationships(from_user, to_user); +CREATE INDEX idx_relationships_to ON relationships(to_user); +CREATE INDEX idx_access_tokens_code ON access_tokens(code); + +CREATE OR REPLACE FUNCTION update_updated_at() + RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER users_updated_at + BEFORE UPDATE ON users + FOR EACH ROW EXECUTE FUNCTION update_updated_at(); + +CREATE TRIGGER spaces_updated_at + BEFORE UPDATE ON spaces + FOR EACH ROW EXECUTE FUNCTION update_updated_at(); + +CREATE TRIGGER channels_updated_at + BEFORE UPDATE ON channels + FOR EACH ROW EXECUTE FUNCTION update_updated_at(); + +CREATE TRIGGER messages_updated_at + BEFORE UPDATE ON messages + FOR EACH ROW EXECUTE FUNCTION update_updated_at(); + +CREATE TRIGGER relationships_updated_at + BEFORE UPDATE ON relationships + FOR EACH ROW EXECUTE FUNCTION update_updated_at(); + +CREATE TRIGGER access_tokens_updated_at + BEFORE UPDATE ON access_tokens + FOR EACH ROW EXECUTE FUNCTION update_updated_at(); + +CREATE OR REPLACE FUNCTION add_owner_to_space() + RETURNS TRIGGER AS $$ +BEGIN + INSERT INTO space_members (space_id, user_id, role, joined_at) + VALUES (NEW.id, NEW.owner_id, 'admin', NOW()); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER space_owner_becomes_member + AFTER INSERT ON spaces + FOR EACH ROW EXECUTE FUNCTION add_owner_to_space(); \ No newline at end of file diff --git a/backend/src/auth/session.rs b/backend/src/api/auth.rs similarity index 63% rename from backend/src/auth/session.rs rename to backend/src/api/auth.rs index 0997744..bc31739 100644 --- a/backend/src/auth/session.rs +++ b/backend/src/api/auth.rs @@ -1,21 +1,46 @@ -use std::{ - sync::LazyLock, - time::{SystemTime, UNIX_EPOCH}, -}; +use crate::error::ApiResult; +use crate::model::auth::{AccessTokenForm, AuthResponse, LoginCredentials, SignupCredentials}; +use crate::svc::auth_svc::AuthService; +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use rocket::http::Status; +use rocket::request::{FromRequest, Outcome}; +use rocket::serde::json::Json; +use rocket::serde::{Deserialize, Serialize}; +use rocket::{Request, State}; +use std::sync::LazyLock; +use std::time::{SystemTime, UNIX_EPOCH}; +use crate::svc::access_token_svc::AccessTokenService; -use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode}; -use rand::Rng; -use rocket::{ - Request, - http::Status, - request::{self, FromRequest, Outcome}, -}; -use rocket_db_pools::Connection; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256, digest::block_buffer::Lazy}; -use sqlx::postgres::PgQueryResult; +#[post("/signup", data = "")] +pub async fn signup( + cred: Json, + svc: &State +) -> ApiResult> { + let response = svc + .signup( + &cred.email, &cred.username, &cred.password, &cred.access_token, + ).await?; + Ok(Json(response)) +} -use crate::db::Postgres; +#[post("/login", data = "")] +pub async fn login( + cred: Json, + svc: &State +) -> ApiResult> { + Ok(Json(svc.login(&cred.username, &cred.password).await?)) +} + +#[post("/invite", data = "
")] +pub async fn generate_invite( + session: Session, + form: Json, + svc: &State +) -> ApiResult { + svc.create( + session.uid, &form.name, form.max_uses, + form.start_date, form.expiry_date).await +} static JWT_SECRET: LazyLock = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap()); @@ -27,7 +52,7 @@ pub enum TokenScope { } pub struct Session { - pub user_id: usize, + pub uid: i64, } #[rocket::async_trait] @@ -37,7 +62,7 @@ impl<'r> FromRequest<'r> for Session { async fn from_request(req: &'r Request<'_>) -> Outcome { match Claims::from_request(req).await { Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session { - user_id: user.sub as usize, + uid: user.sub as i64, }), Outcome::Success(_) => { eprintln!("warning: user with scope other than Full attempted to access session"); @@ -106,4 +131,4 @@ impl<'r> FromRequest<'r> for Claims { } } } -} +} \ No newline at end of file diff --git a/backend/src/cdn.rs b/backend/src/api/cdn.rs similarity index 98% rename from backend/src/cdn.rs rename to backend/src/api/cdn.rs index 82ada9b..84fa347 100644 --- a/backend/src/cdn.rs +++ b/backend/src/api/cdn.rs @@ -26,7 +26,7 @@ pub async fn profile_pic(user_id: usize) -> Option { Some(image) } else { Some( - NamedFile::open("./cdn/profiles/full/default.svg") + NamedFile::open("../../cdn/profiles/full/default.svg") .await .ok()?, ) diff --git a/backend/src/api/chat.rs b/backend/src/api/chat.rs new file mode 100644 index 0000000..63cda73 --- /dev/null +++ b/backend/src/api/chat.rs @@ -0,0 +1,70 @@ +use crate::api::auth::Session; +use crate::error::ApiResult; +use crate::svc::chat_svc::ChatService; +use chrono::{DateTime, Utc}; +use rocket::response::stream::Event; +use rocket::serde::json::Json; +use rocket::serde::{Deserialize, Serialize}; +use rocket::{Shutdown, State, ___internal_EventStream as EventStream}; +use sqlx::FromRow; +use tokio::select; +use tokio::sync::broadcast; + +/// ---------- Rocket routes ---------- +#[derive(Debug, Serialize, Deserialize, Clone, FromRow)] +pub struct ChatMsg { + pub display_name: Option, + pub user_id: i64, + pub text: String, + pub timestamp: DateTime, +} + +#[post("/chat/", format = "json", data = "")] +pub async fn post_message( + msg: Json, + chat: &State, + session: Session, + channel_id: i64, +) -> ApiResult<()> { + chat.send(channel_id, session.uid, &msg.text, Utc::now()).await +} + +#[get("/events/")] +pub async fn event_stream( + chat: &State, + s: Session, + mut shutdown: Shutdown, + channel_id: i64, +) -> ApiResult { + let messages = chat.get_messages(channel_id, 100) + .await?; // if get message returned err, inform user. + + let mut rx = chat.subscribe(channel_id).await; + let id = s.uid; + + Ok(EventStream! { + for msg in messages { + yield Event::json(&msg); + } + + loop { + select!{ + _ = &mut shutdown => break, // exit early on shutdown + msg = rx.recv() => match msg { + Ok(msg) => { + tracing::info!("yielding message!"); + yield Event::json(&msg) + }, + Err(broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",); + yield Event::comment("RecvError::Lagged"); + } + Err(broadcast::error::RecvError::Closed) => { + tracing::info!("Broadcaster hung up on channel {channel_id}!"); + break + }, + }, + } + } + }) +} \ No newline at end of file diff --git a/backend/src/api/mod.rs b/backend/src/api/mod.rs new file mode 100644 index 0000000..444c985 --- /dev/null +++ b/backend/src/api/mod.rs @@ -0,0 +1,7 @@ +pub mod auth; +pub mod chat; +pub mod totp; +pub mod settings; +pub mod cdn; +pub mod profile; +pub mod space; \ No newline at end of file diff --git a/backend/src/api/profile.rs b/backend/src/api/profile.rs new file mode 100644 index 0000000..e44e9b2 --- /dev/null +++ b/backend/src/api/profile.rs @@ -0,0 +1,13 @@ +use rocket::State; +use crate::api::auth::Session; +use crate::error::ApiResult; +use crate::svc::user_svc::UserService; + +#[get("/users/")] +pub async fn display_name( + id: i64, + _ag: Session, + svc: &State, +) -> ApiResult { + svc.get_username(id).await +} diff --git a/backend/src/api/settings.rs b/backend/src/api/settings.rs new file mode 100644 index 0000000..662b9af --- /dev/null +++ b/backend/src/api/settings.rs @@ -0,0 +1,68 @@ +use crate::api::auth::Session; +use crate::error::ApiResult; +use crate::svc::settings_svc::SettingsService; +use rocket::serde::json::Json; +use rocket::serde::{Deserialize, Serialize}; +use rocket::State; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PasswordForm { + pub old_password: String, + pub new_password: String, +} + +#[post("/settings/password", data = "")] +pub async fn change_password( + session: Session, + form: Json, + settings: &State +) -> ApiResult<()> { + settings.change_password( + session.uid, &form.old_password, &form.new_password + ).await +} + +#[derive(Deserialize, Debug, Clone)] +pub struct DisplayNameForm { + pub display_name: Option, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct PasswordAnd2faForm { + pub password: String, + pub totp_code: Option, +} + +#[delete("/settings", data = "")] +pub async fn delete_account( + session: Session, + data: Json, + settings: &State +) -> ApiResult<()> { + settings.delete_account( + session.uid, &data.password, &data.totp_code + ).await +} + +#[patch("/settings/display_name", data = "")] +pub async fn change_display_name( + session: Session, + new: Json, + settings: &State +) -> ApiResult<()> { + settings.change_display_name(session.uid, new.display_name.clone()).await +} + +#[derive(Deserialize)] +pub struct UsernameForm { + pub username: String, +} + +#[patch("/settings/username", data = "")] +pub async fn change_username( + session: Session, + new: Json, + settings: &State +) -> ApiResult<()> { + settings.change_username(session.uid, &new.username).await +} \ No newline at end of file diff --git a/backend/src/api/space.rs b/backend/src/api/space.rs new file mode 100644 index 0000000..85dca5c --- /dev/null +++ b/backend/src/api/space.rs @@ -0,0 +1,36 @@ +use crate::error::ApiResult; +use crate::model::space::{Space, SpaceDto}; +use crate::model::space::Channel; +use crate::repo::{SpaceRepo, ChannelRepo}; +use rocket::serde::json::Json; +use rocket::State; +use std::sync::Arc; +use crate::api::auth::Session; +use crate::svc::chat_svc::ChatService; + +#[get("/spaces")] +pub async fn list_spaces( + space_repo: &State> +) -> ApiResult>> { + let spaces = space_repo.get_all().await?; + Ok(Json(spaces)) +} + +#[get("/spaces//channels")] +pub async fn list_channels( + space_id: i64, + channel_repo: &State> +) -> ApiResult>> { + let channels = channel_repo.get_by_space_id(space_id).await?; + Ok(Json(channels)) +} + +#[get("/accessible_channels")] +pub async fn get_accessible_channels( + session: Session, + svc: &State +) -> ApiResult>> { + let space = svc.get_accessible_channels(session.uid).await?; + println!("{:?}", space); + Ok(Json(space)) +} \ No newline at end of file diff --git a/backend/src/api/totp.rs b/backend/src/api/totp.rs new file mode 100644 index 0000000..50400b4 --- /dev/null +++ b/backend/src/api/totp.rs @@ -0,0 +1,120 @@ +use crate::api::auth::{Claims, Session, TokenScope}; +use crate::error::{ApiResult, AppError}; +use crate::model::auth::AuthResponse; +use crate::svc::auth_svc::AuthService; +use rocket::serde::json::Json; +use rocket::serde::{Deserialize, Serialize}; +use rocket::State; +use totp_rs::{Algorithm, TOTP}; + +#[derive(Debug, Deserialize)] +pub struct TOTPSixDigitCode { + code: String, +} + +#[derive(Debug, sqlx::Type, Clone, Copy, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +#[sqlx(type_name = "totp_status", rename_all = "lowercase")] +pub enum TotpStatus { + Enabled, + Pending, + Disabled, +} + +#[derive(Serialize)] +pub struct QrResponse { + qr_code: String, +} + +#[derive(Deserialize)] +pub struct TotpVerifyRequest { + pub code: String, +} + +#[derive(Deserialize)] +pub struct PasswordConfirmation { + password: String, +} + +#[derive(Deserialize)] +pub struct PasswordAnd2fa { + pub password: String, + pub totp_code: String, +} + +pub fn totp_gen(user_id: i64, secret: &[u8]) -> ApiResult { + TOTP::new( + Algorithm::SHA1, + 6, + 1, + 30, + secret.to_owned(), + Some("chat.zxq5.dev".to_string()), + format!("{}", user_id), + ) + .map_err(|_| AppError::internal("failed to generate totp")) +} + +#[post("/totp", data = "")] +pub async fn confirm_totp( + user: Session, + form: Json, + svc: &State, +) -> ApiResult<()> { + svc.confirm_totp(user.uid, &form.code).await +} + +#[post("/totp.jpg", data = "")] +pub async fn get_totp( + user: Session, + form: Json, + svc: &State, +) -> ApiResult> { + let secret = svc.get_or_create_totp_secret(user.uid, &form.password).await?; + + let qr_b64 = totp_gen(user.uid, secret.as_bytes()) + .map_err(|_| AppError::internal("invalid totp secret"))? + .get_qr_base64() + .map_err(|_| AppError::internal("failed to generate qr code"))?; + + Ok(Json(QrResponse { + qr_code: format!("data:image/png;base64,{}", qr_b64), + })) +} + +#[get("/totp/status")] +pub async fn get_totp_status( + user: Session, + svc: &State, +) -> ApiResult> { + Ok(Json( + svc.get_totp_status(user.uid).await? + .then_some(TotpStatus::Enabled) + .unwrap_or(TotpStatus::Disabled), + )) +} + +#[delete("/totp", data = "")] +pub async fn disable_totp( + user: Session, + form: Json, + svc: &State, +) -> ApiResult> { + let response = svc.disable_totp(user.uid, &form.password, &form.totp_code).await?; + Ok(Json(response)) +} + +#[post("/totp/verify", data = "")] +pub async fn verify_totp( + claims: Claims, + body: Json, + svc: &State, +) -> ApiResult> { + // reject if they somehow got here with a full token + if claims.scope != TokenScope::TotpPending { + return Err(AppError::Forbidden); + } + + let response = svc.login_totp(claims.sub as i64, &body.code).await?; + Ok(Json(response)) +} \ No newline at end of file diff --git a/backend/src/auth/account.rs b/backend/src/auth/account.rs deleted file mode 100644 index 29bdef1..0000000 --- a/backend/src/auth/account.rs +++ /dev/null @@ -1,211 +0,0 @@ -use argon2::{ - Argon2, PasswordHash, PasswordHasher, PasswordVerifier, - password_hash::{SaltString, rand_core::OsRng}, -}; -use jsonwebtoken::{EncodingKey, Header, encode}; -use rocket::{ - http::{CookieJar, Status}, - response::{Redirect, status::BadRequest}, - serde::json::Json, - time::OffsetDateTime, -}; -use rocket_db_pools::Connection; -use rocket_dyn_templates::{Template, context}; -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -use crate::{ - auth::session::{Claims, Session, TokenScope}, - db::Postgres, - user::User, -}; - -#[derive(Serialize, Deserialize)] -pub struct SignupCredentials { - pub email: String, - pub username: String, - pub password: String, - pub access_token: String, -} - -#[derive(Serialize, Deserialize)] -pub struct LoginCredentials { - pub username: String, - pub password: String, -} - -#[derive(Serialize, Deserialize)] -pub struct AuthResponse { - pub token: String, - pub totp_required: bool, -} - -#[get("/signup")] -pub async fn signup_page() -> Template { - Template::render("signup", context!()) -} - -#[post("/signup", data = "")] -pub async fn signup( - cred: Json, - jar: &CookieJar<'_>, - mut db: Connection, -) -> Result, Status> { - let token_id = AccessToken::validate(&cred.access_token, &mut db) - .await - .map_err(|_| Status::Unauthorized)?; - - let salt = SaltString::generate(&mut OsRng); - let hashed = Argon2::default() - .hash_password(cred.password.as_bytes(), &salt) - .map_err(|_| Status::InternalServerError)? - .to_string(); - - let result = sqlx::query!( - "INSERT INTO users (email, username, pass_hash) VALUES ($1, $2, $3) RETURNING id", - cred.email, - cred.username, - hashed, - ) - .fetch_one(&mut **db) - .await - .map_err(|_| Status::InternalServerError)?; - - let jwt = Claims::new(result.id as usize, TokenScope::Full).encode(); - - token_id - .use_token(&mut db) - .await - .expect("unable to use access code"); - - Ok(Json(AuthResponse { - token: jwt, - totp_required: false, - })) -} - -#[get("/login")] -pub async fn login_page() -> Template { - Template::render("login", context!()) -} - -#[post("/login", data = "")] -pub async fn login( - mut db: Connection, - cred: Json, -) -> Result, Status> { - println!("e"); - let row = sqlx::query!( - "SELECT id, pass_hash, twofa_enabled FROM users WHERE username = $1", - cred.username, - ) - .fetch_one(&mut **db) - .await - .map_err(|_| Status::Unauthorized)?; - - println!("ok"); - - // verify password as before - let parsed_hash = PasswordHash::new(&row.pass_hash).map_err(|_| Status::InternalServerError)?; - Argon2::default() - .verify_password(cred.password.as_bytes(), &parsed_hash) - .map_err(|_| Status::Unauthorized)?; - - println!("ok2"); - - let user_id = row.id as usize; - - // issue either a partial or full token depending on 2FA status - let (session, totp_required) = if row.twofa_enabled { - (Claims::new(user_id, TokenScope::TotpPending), true) - } else { - (Claims::new(user_id, TokenScope::Full), false) - }; - - Ok(Json(AuthResponse { - token: session.encode(), - totp_required, - })) -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AccessTokenForm { - pub name: String, - pub max_uses: usize, - pub expiry_date: usize, - pub start_date: usize, -} - -#[get("/invite")] -pub async fn invite_page(_s: Session) -> Template { - Template::render("invite", context! {}) -} - -#[post("/invite", data = "")] -pub async fn generate_invite( - session: Session, - mut db: Connection, - form: Json, -) -> Result { - if form.start_date > form.expiry_date { - return Err(Status::BadRequest); - } - - let code = Uuid::new_v4().to_string(); - sqlx::query!( - "INSERT INTO access_codes (name, code, creator_id, max_uses, created_at, expires_at) - VALUES ($1, $2, $3, $4, $5, $6) RETURNING id", - form.name, - code, - session.user_id as i32, - form.max_uses as i32, - OffsetDateTime::from_unix_timestamp_nanos(form.start_date as i128 * 1_000_000).unwrap(), - OffsetDateTime::from_unix_timestamp_nanos(form.expiry_date as i128 * 1_000_000).unwrap() - ) - .fetch_one(&mut **db) - .await - .map_err(|_| Status::InternalServerError)?; - - Ok(code) -} - -pub struct AccessToken { - id: i32, - _code: String, -} - -impl AccessToken { - pub async fn validate( - token: &str, - db: &mut Connection, - ) -> Result { - match sqlx::query!( - "SELECT id FROM access_codes - WHERE code = $1 - AND created_at < NOW() - AND expires_at > NOW() - AND uses < max_uses", - token - ) - .fetch_one(&mut ***db) - .await - { - Ok(row) => Ok(AccessToken { - id: row.id, - _code: token.to_string(), - }), - Err(_) => Err(String::from("Invalid or Expired token!")), - } - } - - pub async fn use_token(&self, db: &mut Connection) -> Result<(), String> { - sqlx::query!( - "UPDATE access_codes SET uses = uses + 1 WHERE id = $1", - self.id - ) - .execute(&mut ***db) - .await - .map_err(|_| String::from("Invalid or Expired token!"))?; - Ok(()) - } -} diff --git a/backend/src/auth/mod.rs b/backend/src/auth/mod.rs deleted file mode 100644 index 6cde613..0000000 --- a/backend/src/auth/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub mod account; -pub mod profile; -pub mod session; -pub mod two_factor; - -pub use session::Session; - -pub use account::{generate_invite, invite_page, login, login_page, signup, signup_page}; -pub use profile::{change_display_name, change_password, change_username, delete_account}; -pub use two_factor::{ - confirm_totp, disable_totp, get_totp, get_totp_status, mfa_page, verify_totp, -}; diff --git a/backend/src/auth/profile.rs b/backend/src/auth/profile.rs deleted file mode 100644 index a3eab8c..0000000 --- a/backend/src/auth/profile.rs +++ /dev/null @@ -1,143 +0,0 @@ -use argon2::{ - Argon2, PasswordHash, PasswordHasher, PasswordVerifier, - password_hash::{SaltString, rand_core::OsRng}, -}; -use rocket::{http::Status, serde::json::Json}; -use rocket_db_pools::Connection; -use serde::{Deserialize, Serialize}; - -use crate::{auth::Session, db::Postgres, user::User}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PasswordForm { - old_password: String, - new_password: String, -} - -#[post("/settings/password", data = "")] -pub async fn change_password( - session: Session, - mut db: Connection, - form: Json, -) -> Result<(), Status> { - let mut user = User::get_by_id(session.user_id, &mut db) - .await - .ok_or(Status::NotFound) - .inspect_err(|_| { - tracing::error!( - "Valid session does not have a valid user. ID: {}", - session.user_id - ) - })?; - - user.verify_password(&form.old_password)?; - - // old password is correct, so new one can be set. - let salt = SaltString::generate(&mut OsRng); - let hashed = Argon2::default() - .hash_password(form.new_password.as_bytes(), &salt) - .inspect_err(|e| tracing::error!("failed to hash password! {e}")) - .map_err(|_| Status::InternalServerError)? - .to_string(); - - user.set_pass_hash(hashed, &mut db) - .await - .inspect_err(|e| tracing::error!("{e}")) - .map_err(|_| Status::InternalServerError)?; - - Ok(()) -} - -#[derive(Deserialize, Debug, Clone)] -pub struct DisplayNameForm { - pub display_name: Option, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct PasswordAnd2fa { - pub password: String, - pub totp_code: Option, -} - -#[delete("/settings", data = "")] -pub async fn delete_account( - session: Session, - mut db: Connection, - data: Json, -) -> Result<(), Status> { - let mut user = User::get_by_id(session.user_id, &mut db) - .await - .ok_or(Status::NotFound) - .inspect_err(|_| { - tracing::error!( - "Valid session does not have a valid user. ID: {}", - session.user_id - ) - })?; - - user.verify_password(&data.password)?; - - if user.twofa_enabled { - user.verify_2fa(data.totp_code.as_deref().unwrap_or(""))?; - } - - user.delete(&mut db) - .await - .inspect_err(|e| tracing::error!("{e}")) - .map_err(|_| Status::InternalServerError)?; - - Ok(()) -} - -#[patch("/settings/display_name", data = "")] -pub async fn change_display_name( - session: Session, - mut db: Connection, - new: Json, -) -> Result<(), Status> { - let mut user = User::get_by_id(session.user_id, &mut db) - .await - .ok_or(Status::NotFound) - .inspect_err(|_| { - tracing::error!( - "Valid session does not have a valid user. ID: {}", - session.user_id - ) - })?; - - user.set_display_name(new.display_name.clone(), &mut db) - .await - .inspect_err(|e| tracing::error!("{e}")) - .map_err(|_| Status::InternalServerError)?; - - Ok(()) -} - -#[derive(Deserialize)] -pub struct UsernameForm { - username: String, -} - -#[patch("/settings/username", data = "")] -pub async fn change_username( - session: Session, - mut db: Connection, - new: Json, -) -> Result<(), Status> { - let mut user = User::get_by_id(session.user_id, &mut db) - .await - .ok_or(Status::NotFound) - .inspect_err(|_| { - tracing::error!( - "Valid session does not have a valid user. ID: {}", - session.user_id - ) - })?; - - user.set_username(new.username.clone(), &mut db) - .await - .inspect_err(|e| tracing::error!("{e}")) - .map_err(|_| Status::InternalServerError)?; - - Ok(()) -} diff --git a/backend/src/auth/two_factor.rs b/backend/src/auth/two_factor.rs deleted file mode 100644 index c0fe8aa..0000000 --- a/backend/src/auth/two_factor.rs +++ /dev/null @@ -1,301 +0,0 @@ -use futures_util::TryFutureExt; -use rocket::{ - Request, - http::Status, - outcome::{Outcome, try_outcome}, - request::{self, FromRequest}, - response::status::{self}, - serde::json::Json, -}; -use rocket_db_pools::Connection; -use rocket_dyn_templates::{Template, context}; -use serde::{Deserialize, Serialize}; -use totp_rs::{Algorithm, Secret, TOTP}; - -use crate::{ - auth::{ - account::AuthResponse, - profile::PasswordAnd2fa, - session::{Claims, Session, TokenScope}, - }, - db::Postgres, - user::User, -}; - -// Utility methods - -pub fn totp_gen(user_id: usize, secret: &[u8]) -> Result { - TOTP::new( - Algorithm::SHA1, - 6, - 1, - 30, - secret.to_owned(), - Some("chat.zxq5.dev".to_string()), - format!("{}", user_id), - ) - .map_err(|_| String::from("Invalid Secret")) -} - -// pages - -#[get("/totp")] -pub async fn mfa_page(_session: Session) -> Template { - Template::render("2fa", context!()) -} - -#[post("/totp", data = "")] -pub async fn confirm_totp( - mfa: TOTPSecret, - form: Json, - mut db: Connection, -) -> Result<(), status::Custom<&'static str>> { - if form.code.len() != 6 || form.code.parse::().is_err() { - return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code")); - } - - println!("valid"); - - let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes()) - .map_err(|_| status::Custom(Status::InternalServerError, "TOTP Error"))?; - if !totp.check_current(&form.code).unwrap_or(false) { - return Err(status::Custom(Status::BadRequest, "Incorrect code")); - } - println!("correct"); - - if sqlx::query!( - "UPDATE users SET twofa_enabled = true WHERE id = $1", - mfa.user_id as i32 - ) - .execute(&mut **db) - .await - .is_err() - { - return Err(status::Custom( - Status::InternalServerError, - "unable to enable 2fa", - )); - }; - - println!("enabled"); - - Ok(()) -} - -#[derive(Deserialize)] -pub struct PasswordConfirmation { - password: String, -} - -#[post("/totp.jpg", data = "")] -pub async fn get_totp( - mfa: TOTPSecret, - form: Json, -) -> Option> { - let qr_b64 = totp_gen(mfa.user_id, mfa.secret.as_bytes()) - .expect("Invalid TOTP") - .get_qr_base64() - .unwrap(); - - Some(Json(QrResponse { - qr_code: format!("data:image/png;base64,{}", qr_b64), - })) -} - -#[derive(Debug, Deserialize)] -pub struct TOTPSixDigitCode { - code: String, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum TotpStatus { - Enabled, - Disabled, -} - -pub struct TOTPSecret { - user_id: usize, - secret: String, -} - -#[derive(Serialize)] -pub struct QrResponse { - qr_code: String, -} - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for TOTPSecret { - type Error = (); - - async fn from_request(request: &'r Request<'_>) -> request::Outcome { - let auth_header = request.headers().get_one("Authorization"); - println!( - "TOTPSecret guard - Auth header present: {}", - auth_header.is_some() - ); - - let user = try_outcome!(request.guard::().await); - println!( - "TOTPSecret guard - Claims ok, user: {}, scope: {:?}", - user.sub, user.scope - ); - - // only allow full tokens for TOTP setup - if user.scope != TokenScope::Full { - println!("TOTPSecret guard - rejected, scope is {:?}", user.scope); - return Outcome::Error((Status::Forbidden, ())); - } - - let user = try_outcome!(request.guard::().await); - let mut pool = match request.guard::>().await { - Outcome::Success(pool) => pool, - _ => return Outcome::Error((Status::Unauthorized, ())), - }; - - let row = sqlx::query!( - "SELECT twofa_enabled, totp_secret FROM users WHERE id = $1", - user.user_id as i32 - ) - .fetch_one(&mut **pool) - .await; - - let (enabled, mut secret) = match row { - Ok(r) => (r.twofa_enabled, r.totp_secret), - Err(_) => return Outcome::Error((Status::Unauthorized, ())), - }; - - if secret.is_none() { - let new_secret = Secret::generate_secret().to_encoded().to_string(); - sqlx::query!( - "UPDATE users SET totp_secret = $1 WHERE id = $2", - new_secret, - user.user_id as i32 - ) - .execute(&mut **pool) - .await - .ok(); - secret = Some(new_secret); - } - - Outcome::Success(TOTPSecret { - user_id: user.user_id, - secret: secret.unwrap(), - }) - } -} - -impl TOTPSecret { - pub async fn enable(&self, db: &mut Connection) -> Result<(), ()> { - match sqlx::query!( - "UPDATE users SET twofa_enabled = true WHERE id = $1", - self.user_id as i32, - ) - .execute(&mut ***db) - .await - { - Ok(_) => Ok(()), - Err(_) => Err(()), - } - } -} - -#[derive(Deserialize)] -pub struct TotpVerifyRequest { - pub code: String, -} - -#[get("/totp/status")] -pub async fn get_totp_status( - user: Session, - mut db: Connection, -) -> Result, Status> { - Ok(Json( - if sqlx::query!( - "SELECT twofa_enabled FROM users WHERE id = $1", - user.user_id as i32, - ) - .fetch_one(&mut **db) - .await - .map_err(|_| Status::NotFound)? - .twofa_enabled - { - TotpStatus::Enabled - } else { - TotpStatus::Disabled - }, - )) -} - -#[delete("/totp", data = "")] -pub async fn disable_totp( - user: Session, - mut db: Connection, - form: Json, -) -> Result, Status> { - let totp_code = form.totp_code.clone().ok_or(Status::BadRequest)?; - let mut user = User::get_by_id(user.user_id, &mut db) - .await - .ok_or(Status::NotFound)?; - - user.verify_password(&form.password)?; - user.verify_2fa(&totp_code)?; - user.set_twofa_enabled(false, &mut db) - .await - .map_err(|_| Status::InternalServerError)?; - - Ok(Json(AuthResponse { - token: Claims::new(user.id as usize, TokenScope::Full).encode(), - totp_required: false, - })) -} - -#[post("/totp/verify", data = "")] -pub async fn verify_totp( - claims: Claims, // request guard checks token validity - mut db: Connection, - body: Json, -) -> Result, Status> { - println!("reached 1"); - - // reject if they somehow got here with a full token - if claims.scope != TokenScope::TotpPending { - return Err(Status::Forbidden); - } - - println!("reached 2"); - - let row = sqlx::query!( - "SELECT totp_secret FROM users WHERE id = $1 AND twofa_enabled = TRUE", - claims.sub - ) - .fetch_one(&mut **db) - .await - .map_err(|_| Status::Unauthorized)?; - - println!("reached 3"); - - let totp = totp_gen( - claims.sub as usize, - row.totp_secret - .expect("user with 2fa enabled has no totp secret") - .as_bytes(), - ) - .map_err(|_| Status::InternalServerError)?; - - if !totp - .check_current(&body.code) - .map_err(|_| Status::InternalServerError)? - { - return Err(Status::Unauthorized); - } - - println!("reached 5"); - - let claims = Claims::new(claims.sub as usize, TokenScope::Full); - - Ok(Json(AuthResponse { - token: claims.encode(), - totp_required: false, - })) -} diff --git a/backend/src/cli.rs b/backend/src/cli.rs new file mode 100644 index 0000000..d27b26c --- /dev/null +++ b/backend/src/cli.rs @@ -0,0 +1,96 @@ +use clap::{Parser, Subcommand}; +use sqlx::postgres::PgPoolOptions; +use std::time::Duration; +use std::sync::Arc; +use crate::repo::user_repo::UserRepository; +use crate::repo::space_repo::SpaceRepository; +use crate::repo::channel_repo::ChannelRepository; +use crate::repo::{UserRepo, SpaceRepo, ChannelRepo}; +use argon2::{ + password_hash::{PasswordHasher, SaltString}, + Argon2, +}; +use rand::rngs::OsRng; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +pub struct Cli { + #[command(subcommand)] + pub command: Option, +} + +#[derive(Subcommand)] +pub enum Commands { + /// First-time setup for the server + Setup { + /// Admin username + #[arg(short, long)] + username: String, + + /// Admin password + #[arg(short, long)] + password: String, + + /// Default space name + #[arg(short, long, default_value = "Default Space")] + space: String, + + /// Default channel name + #[arg(short, long, default_value = "general")] + channel: String, + }, +} + +pub async fn handle_cli() -> bool { + let cli = Cli::parse(); + + match cli.command { + Some(Commands::Setup { username, password, space, channel }) => { + if let Err(e) = run_setup(username, password, space, channel).await { + eprintln!("Setup failed: {}", e); + std::process::exit(1); + } + println!("Setup completed successfully!"); + true + } + None => false, + } +} + +async fn run_setup(username: String, password: String, space_name: String, channel_name: String) -> Result<(), Box> { + dotenv::dotenv().ok(); + let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let pool = PgPoolOptions::new() + .max_connections(1) + .acquire_timeout(Duration::from_secs(5)) + .connect(&db_url) + .await?; + + let user_repo = UserRepository::new(pool.clone()); + let space_repo = SpaceRepository::new(pool.clone()); + let channel_repo = ChannelRepository::new(pool.clone()); + + // 1. Create admin user + println!("Creating admin user: {}...", username); + + let argon2 = Argon2::default(); + let salt = SaltString::generate(&mut OsRng); + let passhash = argon2 + .hash_password(password.as_bytes(), &salt) + .map_err(|e| e.to_string())? + .to_string(); + + let user_id = user_repo.new_user("admin@localhost", &username, &passhash).await?; + user_repo.set_role(user_id, "admin").await?; + + // 2. Create default space + println!("Creating default space: {}...", space_name); + let space_id = space_repo.create(&space_name, Some("Default space created during setup"), user_id).await?; + + // 3. Create default channel + println!("Creating default channel: {}...", channel_name); + channel_repo.create(&channel_name, Some("Default channel"), space_id).await?; + + Ok(()) +} diff --git a/backend/src/db.rs b/backend/src/db.rs deleted file mode 100644 index c863e21..0000000 --- a/backend/src/db.rs +++ /dev/null @@ -1,9 +0,0 @@ -use rocket_db_pools::{Database, deadpool_redis}; - -#[derive(Database)] -#[database("postgres_db")] -pub struct Postgres(sqlx::PgPool); - -#[derive(Database)] -#[database("redis_cache")] -pub struct Redis(deadpool_redis::Pool); diff --git a/backend/src/error.rs b/backend/src/error.rs new file mode 100644 index 0000000..0cb8f6a --- /dev/null +++ b/backend/src/error.rs @@ -0,0 +1,127 @@ +// error.rs +use rocket::{http::Status, response::{self, Responder}, Request, Response}; +use thiserror::Error; +use rocket_dyn_templates::Template; +use rocket::serde::Serialize; + +#[derive(Error, Debug)] +pub enum AppError { + #[error("Not found")] + NotFound, + + #[error("Unauthorized")] + Unauthorised(String), + + #[error("Forbidden")] + Forbidden, + + #[error("Bad request: {0}")] + BadRequest(String), + + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + + #[error("Internal error: {0}")] + Internal(String), +} + +impl AppError { + pub fn internal(msg: impl Into) -> Self { + Self::Internal(msg.into()) + } + + pub fn bad_request(msg: impl Into) -> Self { + Self::BadRequest(msg.into()) + } + pub fn unauthorised(msg: impl Into) -> Self { + Self::Unauthorised(msg.into()) + } +} + +impl<'r> Responder<'r, 'static> for AppError { + fn respond_to(self, _req: &'r Request<'_>) -> response::Result<'static> { + let status = match &self { + AppError::NotFound => Status::NotFound, + AppError::Unauthorised(_) => Status::Unauthorized, + AppError::Forbidden => Status::Forbidden, + AppError::BadRequest(_) => Status::BadRequest, + AppError::Database(_) => Status::InternalServerError, + AppError::Internal(_) => Status::InternalServerError, + }; + + // log internal errors + if status == Status::InternalServerError { + tracing::error!("Internal Server Error: {}", self); + } + + Response::build() + .status(status) + .header(rocket::http::ContentType::Plain) + .sized_body( + self.to_string().len(), + std::io::Cursor::new(self.to_string()) + ) + .ok() + } +} + +pub type ApiResult = Result; + +#[derive(Serialize)] +struct ErrorContext { + error_code: u16, + error_message: &'static str, + additional_info: &'static str, + redirect: Option, +} + +#[derive(Serialize)] +struct RedirectContext { + url: &'static str, + message: &'static str, +} + +#[catch(404)] +pub async fn handle_404() -> Template { + Template::render( + "error", + ErrorContext { + error_code: 404, + error_message: "Not Found", + additional_info: "There's nothing here.", + redirect: Some(RedirectContext { + url: "/", + message: "Home", + }), + }, + ) +} + +#[catch(401)] +pub async fn handle_401() -> Template { + Template::render( + "error", + ErrorContext { + error_code: 401, + error_message: "Unauthorised", + additional_info: "You are not authorised to access this resource.", + redirect: Some(RedirectContext { + url: "/login", + message: "Login", + }), + }, + ) +} + +#[catch(default)] +pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template { + Template::render( + "error", + ErrorContext { + error_code: status.code, + error_message: "Unknown Error", + additional_info: "I don't know what to do with this error.", + redirect: None, + }, + ) +} \ No newline at end of file diff --git a/backend/src/handlers.rs b/backend/src/handlers.rs deleted file mode 100644 index b8c6938..0000000 --- a/backend/src/handlers.rs +++ /dev/null @@ -1,62 +0,0 @@ -use rocket::{Request, http::Status}; -use rocket_dyn_templates::Template; -use serde::Serialize; - -#[derive(Serialize)] -struct ErrorContext { - error_code: u16, - error_message: &'static str, - additional_info: &'static str, - redirect: Option, -} - -#[derive(Serialize)] -struct RedirectContext { - url: &'static str, - message: &'static str, -} - -#[catch(404)] -pub async fn handle_404() -> Template { - Template::render( - "error", - ErrorContext { - error_code: 404, - error_message: "Not Found", - additional_info: "There's nothing here.", - redirect: Some(RedirectContext { - url: "/", - message: "Home", - }), - }, - ) -} - -#[catch(401)] -pub async fn handle_401() -> Template { - Template::render( - "error", - ErrorContext { - error_code: 401, - error_message: "Unauthorised", - additional_info: "You are not authorised to access this resource.", - redirect: Some(RedirectContext { - url: "/login", - message: "Login", - }), - }, - ) -} - -#[catch(default)] -pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template { - Template::render( - "error", - ErrorContext { - error_code: status.code, - error_message: "Unknown Error", - additional_info: "I don't know what to do with this error.", - redirect: None, - }, - ) -} diff --git a/backend/src/lib.rs b/backend/src/lib.rs new file mode 100644 index 0000000..bfa8fa8 --- /dev/null +++ b/backend/src/lib.rs @@ -0,0 +1,149 @@ +#![deny(clippy::unwrap_used)] +#![warn(clippy::all, clippy::nursery, clippy::cargo, clippy::pedantic)] + +#[macro_use] +extern crate rocket; + +pub mod messenger; +pub mod api; +pub mod repo; +pub mod error; +pub mod svc; +pub mod model; +pub mod cli; + +use crate::repo::{access_token_repo::AccessTokenRepo, Repo}; +use crate::repo::message_repo::MessageRepository; +use crate::repo::user_repo::UserRepository; +use crate::repo::space_repo::SpaceRepository; +use crate::repo::channel_repo::ChannelRepository; +use crate::svc::auth_svc::AuthService; +use crate::svc::chat_svc::ChatService; +use crate::svc::settings_svc::SettingsService; +use crate::svc::user_svc::UserService; +use rocket::fs::{FileServer, NamedFile}; +use rocket::http::Method; +use rocket_cors::{AllowedOrigins, CorsOptions}; +use rocket_dyn_templates::Template; +use sqlx::postgres::PgPoolOptions; +use std::env; +use std::sync::{Arc, LazyLock}; +use std::time::Duration; +use api::cdn; +use crate::svc::access_token_svc::AccessTokenService; +use crate::svc::llm_service::LlmService; + +pub fn rocket() -> rocket::Rocket { + if std::env::var("RELEASE_MODE").unwrap_or_default() != "1" { + dotenv::dotenv().ok(); + } + let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + + let pool = PgPoolOptions::new() + .max_connections(25) + .min_connections(5) + .acquire_timeout(Duration::from_secs(5)) + .connect_lazy(&db_url) + .expect("Failed to create database pool"); + + let user_repo = Arc::new(UserRepository::new(pool.clone())); + let message_repo = MessageRepository::new(pool.clone()); + let token_repo = Arc::new(AccessTokenRepo::new(pool.clone())); + let space_repo: Arc = Arc::new(SpaceRepository::new(pool.clone())); + let channel_repo: Arc = Arc::new(ChannelRepository::new(pool.clone())); + let llm_service = LlmService::new(); + let chat_service = ChatService::new(32, llm_service.clone(), message_repo.clone(), user_repo.clone(), channel_repo.clone(), space_repo.clone()); + + rocket_builder(user_repo, token_repo, space_repo, channel_repo, chat_service) +} + +pub fn rocket_builder( + user_repo: Arc, + token_repo: Arc, + space_repo: Arc, + channel_repo: Arc, + chat_service: ChatService +) -> rocket::Rocket { + + + let cors = CorsOptions::default() + .allowed_origins(AllowedOrigins::all()) + .allowed_methods( + vec![Method::Get, Method::Post, Method::Patch] + .into_iter() + .map(From::from) + .collect(), + ) + .allow_credentials(true); + + let access_token_svc = AccessTokenService::new(token_repo.clone()); + let auth_service = AuthService::new(user_repo.clone(), access_token_svc.clone()); + let settings_service = SettingsService::new(auth_service.clone(), user_repo.clone()); + let user_service = UserService::new(user_repo.clone()); + + rocket::build() + .manage(chat_service) + .manage(auth_service) + .manage(settings_service) + .manage(user_service) + .manage(space_repo) + .manage(channel_repo) + .attach(cors.to_cors().unwrap()) + .attach(Template::fairing()) + .mount("/static", FileServer::from("static")) + .mount("/cdn", cdn::routes()) + .mount( + "/", + routes![ + favicon, + + ], + ) + .mount( + "/api", + routes![ + cdn::upload_profile_pic, + api::profile::display_name, + + // basic auth + api::auth::login, + api::auth::signup, + + // 2fa + api::totp::confirm_totp, + api::totp::disable_totp, + api::totp::get_totp, + api::totp::get_totp_status, + api::totp::verify_totp, + + // chat + api::chat::event_stream, + api::chat::post_message, + + // user settings + api::settings::change_display_name, + api::settings::change_password, + api::settings::change_username, + api::settings::delete_account, + + // spaces + api::space::list_spaces, + api::space::list_channels, + api::space::get_accessible_channels + ], + ) + .register( + "/", + catchers![ + error::handle_401, + error::handle_404, + error::handle_default, + ], + ) +} + +#[get("/favicon.ico")] +pub async fn favicon() -> NamedFile { + NamedFile::open("static/favicon.ico").await.unwrap() +} diff --git a/backend/src/llm.rs b/backend/src/llm.rs deleted file mode 100644 index ea79cb4..0000000 --- a/backend/src/llm.rs +++ /dev/null @@ -1,69 +0,0 @@ -// src/llm.rs -use serde::{Deserialize, Serialize}; - -use crate::messenger::ChatMsg; - -#[derive(Serialize)] -struct LlmRequest { - model: String, - messages: Vec, -} - -#[derive(Serialize, Deserialize)] -struct Message { - role: String, // "user" or "assistant" - content: String, -} - -pub struct LlmWorker { - uri: String, -} - -impl LlmWorker { - pub fn new(uri: String) -> Self { - Self { uri } - } - - pub async fn query(&self, message: &ChatMsg) -> Result { - let client = reqwest::Client::new(); - - // Build the request body - let payload = LlmRequest { - model: "gpt-oss-20b".into(), // whatever model you run locally - messages: vec![Message { - role: "user".into(), - content: message.text.clone(), - }], - }; - - // POST to lm‑studio (default 127.0.0.1:1234) - let resp = client - .post(self.uri.clone()) - .json(&payload) - .send() - .await - .map_err(|_| String::from("Failed to make request to LLM server"))?; - - // The API returns a JSON with `choices[].message.content` - #[derive(Deserialize)] - struct LlmResponse { - choices: Vec, - } - #[derive(Deserialize)] - struct Choice { - message: Message, - } - - let llm_resp: LlmResponse = resp - .json() - .await - .map_err(|_| String::from("Failed to make request to LLM server"))?; - - Ok(ChatMsg { - display_name: Some(String::from("lmstudio")), - user_id: 0, - text: llm_resp.choices[0].message.content.clone(), - timestamp: chrono::Utc::now().timestamp_millis() as usize, - }) - } -} diff --git a/backend/src/main.rs b/backend/src/main.rs index 22de1c1..9f6980e 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,98 +1,11 @@ -// src/main.rs -#[macro_use] -extern crate rocket; +use backend::rocket; +use backend::cli::handle_cli; -use rocket::fs::{FileServer, NamedFile}; -use rocket::http::Method; -use rocket::{Build, Rocket}; -use rocket_cors::{AllowedOrigins, CorsOptions}; -use rocket_db_pools::Database; -use rocket_dyn_templates::Template; -use std::env; -use std::sync::{Arc, LazyLock}; - -use crate::db::{Postgres, Redis}; - -pub mod auth; -pub mod cdn; -pub mod db; -pub mod handlers; -pub mod llm; -pub mod messenger; -pub mod user; - -static LMSTUDIO_URL: LazyLock = - LazyLock::new(|| env::var("LMSTUDIO_URL").expect("Ensure LMSTUDIO_URL is set!")); - -#[launch] -fn rocket() -> Rocket { - // make sure the env is loaded - dotenv::dotenv().expect("Failed to load env! aborting launch!"); - - let chat = Arc::new(crate::messenger::ChatBroadcaster::new(32)); - - let cors = CorsOptions::default() - .allowed_origins(AllowedOrigins::all()) - .allowed_methods( - vec![Method::Get, Method::Post, Method::Patch] - .into_iter() - .map(From::from) - .collect(), - ) - .allow_credentials(true); - - rocket::build() - .manage(chat) - .attach(cors.to_cors().unwrap()) - .attach(Postgres::init()) - .attach(Redis::init()) - .attach(Template::fairing()) - .mount("/static", FileServer::from("static")) - .mount("/cdn", cdn::routes()) - .mount( - "/", - routes![ - favicon, - messenger::chat_page, - auth::signup_page, - auth::login_page, - auth::mfa_page, - auth::invite_page, - ], - ) - .mount( - "/api", - routes![ - cdn::upload_profile_pic, - messenger::post_message, - messenger::event_stream, - user::users, - user::display_name, - auth::signup, - auth::login, - auth::get_totp, - auth::confirm_totp, - auth::generate_invite, - auth::verify_totp, - auth::disable_totp, - auth::get_totp_status, - auth::change_password, - auth::change_display_name, - auth::change_username, - auth::delete_account, - ], - ) - .register( - "/", - catchers![ - handlers::handle_404, - handlers::handle_401, - handlers::handle_default - ], - ) -} - -#[get("/favicon.ico")] -async fn favicon() -> NamedFile { - NamedFile::open("static/favicon.ico").await.unwrap() +#[rocket::main] +async fn main() -> Result<(), rocket::Error> { + if handle_cli().await { + return Ok(()); + } + rocket().launch().await?; + Ok(()) } diff --git a/backend/src/messenger/cache.rs b/backend/src/messenger/cache.rs index ae0443f..810dfd3 100644 --- a/backend/src/messenger/cache.rs +++ b/backend/src/messenger/cache.rs @@ -1,10 +1,8 @@ use redis::AsyncCommands; use rocket_db_pools::Connection; -use crate::{ - db::{Postgres, Redis}, - messenger::ChatMsg, -}; +use crate::api::chat::ChatMsg; +use crate::db::{Postgres, Redis}; // Helper function to cache message in Redis pub async fn insert( diff --git a/backend/src/messenger/messages.rs b/backend/src/messenger/messages.rs deleted file mode 100644 index eafff4a..0000000 --- a/backend/src/messenger/messages.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::sync::Arc; - -use rocket::{ - Shutdown, - response::stream::{Event, EventStream}, - serde::json::Json, - time::OffsetDateTime, -}; -use rocket_db_pools::Connection; -use rocket_dyn_templates::{Template, context}; -use serde::{Deserialize, Serialize}; -use sqlx::prelude::FromRow; -use tokio::{select, sync::broadcast}; - -use crate::{ - auth::Session, - db::{Postgres, Redis}, - llm::LlmWorker, - messenger, -}; - -/// ---------- shared broadcaster ---------- -pub struct ChatBroadcaster { - buffer_size: usize, - senders: std::sync::Mutex>>, -} - -impl ChatBroadcaster { - pub fn new(buffer_size: usize) -> Self { - Self { - buffer_size, - senders: std::sync::Mutex::new(std::collections::HashMap::new()), - } - } - - /// Publish a message to the specified channel. - pub async fn publish(&self, channel_id: i32, msg: ChatMsg) { - let mut map = self.senders.lock().unwrap(); - let sender = map - .entry(channel_id) - .or_insert_with(|| broadcast::channel::(self.buffer_size).0); - let _ = sender.send(msg); - } - - /// Subscribe to the specified channel. - pub fn subscribe(&self, channel_id: i32) -> broadcast::Receiver { - let mut map = self.senders.lock().unwrap(); - let sender = map - .entry(channel_id) - .or_insert_with(|| broadcast::channel::(self.buffer_size).0); - sender.subscribe() - } -} - -/// ---------- Rocket routes ---------- -#[derive(Debug, Serialize, Deserialize, Clone, FromRow)] -pub struct ChatMsg { - pub display_name: Option, - pub user_id: usize, - pub text: String, - pub timestamp: usize, -} - -#[post("/chat/", format = "json", data = "")] -pub async fn post_message( - mut msg: Json, - chat: &rocket::State>, - mut postgres: Connection, - mut cache: Option>, - session: Session, - channel_id: i32, -) -> Result<(), String> { - let chat = chat.inner().clone(); - - let display_name = sqlx::query!( - "SELECT display_name, username FROM users WHERE id = $1", - session.user_id as i32 - ) - .fetch_one(&mut **postgres) - .await - .map(|row| row.display_name.unwrap_or(row.username)) - .unwrap_or_else(|_| "Unknown".to_string()); - - msg.user_id = session.user_id; - msg.display_name = Some(display_name); - chat.publish(channel_id, msg.clone().into_inner()).await; - - sqlx::query!( - "INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)", - channel_id, - msg.user_id as i32, - msg.text, - OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap() - ) - .execute(&mut **postgres) - .await - .map_err(|_| "Failed".to_string())?; - - println!("gisfujdeghnjuisdfjngiosdfgjkosdf gnojdfsg nmodfsg"); - - if let Some(ref mut cache) = cache { - messenger::cache::insert(cache, channel_id, &msg) - .await - .map_err(|_| "Redis cache failed".to_string())?; - } - - // get response - tokio::spawn(async move { - let response = LlmWorker::new(crate::LMSTUDIO_URL.to_string()) - .query(&msg) - .await; - - if let Ok(reply) = response { - chat.publish(channel_id, reply.clone()).await; - - if let Some(ref mut cache) = cache { - messenger::cache::insert(cache, channel_id, &reply) - .await - .ok(); - } - - sqlx::query!( - "INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)", - channel_id, - reply.user_id as i32, - reply.text, - OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap() - ) - .execute(&mut **postgres) - .await - .map_err(|_| "Failed".to_string()) - .unwrap(); - } - }); - - Ok(()) -} - -pub async fn get_messages( - mut db: Connection, - mut redis: Connection, - channel_id: i32, -) -> Json> { - if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await - && !messages.is_empty() - { - return Json(messages); - }; - - if let Err(x) = messenger::cache::initialise(&mut redis, &mut db, channel_id).await { - eprintln!("WARN: {x:?}"); - } - - if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await - && !messages.is_empty() - { - return Json(messages); - }; - - let res = sqlx::query!( - "SELECT u.username, u.display_name, u.id, m.content, m.created_at - FROM messages m - JOIN users u ON m.user_id = u.id - WHERE m.channel_id = $1 - ORDER BY m.created_at DESC LIMIT 100", - channel_id - ) - .fetch_all(&mut **db) - .await - .unwrap_or_else(|_| Vec::new()) - .into_iter() - .rev() - .map(|msg| ChatMsg { - display_name: Some(msg.display_name.unwrap_or(msg.username)), - user_id: msg.id as usize, - text: msg.content, - timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize, - }) - .collect(); - Json(res) -} - -#[get("/events/")] -pub async fn event_stream( - chat: &rocket::State>, - postgres: Connection, - cache: Connection, - _session: Session, - mut shutdown: Shutdown, - channel_id: i32, -) -> EventStream![] { - let mut rx = chat.subscribe(channel_id); - - EventStream! { - // Initialize the stream with the last 100 messages - for msg in get_messages(postgres, cache, channel_id).await.0 { - yield Event::json(&msg); - } - - loop { - select!{ - // exit early on shutdown - _ = &mut shutdown => break, - - msg = rx.recv() => match msg { - Ok(msg) => yield Event::json(&msg), - Err(broadcast::error::RecvError::Lagged(_)) => { - yield Event::comment("RecvError::Lagged"); - } - Err(broadcast::error::RecvError::Closed) => break, - }, - } - } - } -} - -#[get("/chat")] -pub async fn chat_page(session: Session) -> Template { - Template::render("chat", context!(user_id: session.user_id)) -} diff --git a/backend/src/messenger/mod.rs b/backend/src/messenger/mod.rs index 556635c..81bcd8f 100644 --- a/backend/src/messenger/mod.rs +++ b/backend/src/messenger/mod.rs @@ -1,4 +1 @@ -mod cache; -mod messages; - -pub use messages::{ChatBroadcaster, ChatMsg, chat_page, event_stream, get_messages, post_message}; +// mod cache; diff --git a/backend/src/model/auth.rs b/backend/src/model/auth.rs new file mode 100644 index 0000000..5f0a088 --- /dev/null +++ b/backend/src/model/auth.rs @@ -0,0 +1,35 @@ +use rocket::serde::{Deserialize, Serialize}; +use chrono::{DateTime, Utc}; + +#[derive(Serialize, Deserialize)] +pub struct SignupCredentials { + pub email: String, + pub username: String, + pub password: String, + pub access_token: String, +} + +#[derive(Serialize, Deserialize)] +pub struct LoginCredentials { + pub username: String, + pub password: String, +} + +#[derive(Serialize, Deserialize)] +pub struct AuthResponse { + pub token: String, + pub totp_required: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessTokenForm { + pub name: String, + pub max_uses: i32, + pub expiry_date: DateTime, + pub start_date: DateTime, +} + +pub struct AccessToken { + pub id: i64, + pub code: String, +} \ No newline at end of file diff --git a/backend/src/model/mod.rs b/backend/src/model/mod.rs new file mode 100644 index 0000000..23732a7 --- /dev/null +++ b/backend/src/model/mod.rs @@ -0,0 +1,3 @@ +pub mod auth; +pub mod user; +pub mod space; \ No newline at end of file diff --git a/backend/src/model/space.rs b/backend/src/model/space.rs new file mode 100644 index 0000000..701cd63 --- /dev/null +++ b/backend/src/model/space.rs @@ -0,0 +1,35 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use chrono::{DateTime, Utc}; + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct Space { + pub id: i64, + pub name: String, + pub description: Option, + pub owner_id: i64, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct Channel { + pub id: i64, + pub name: String, + pub description: Option, + pub space_id: i64, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpaceDto { + pub channels: Vec, + pub id: i64, + pub owner_id: i64, + pub name: String, + pub description: Option, + + pub created_at: DateTime, + pub updated_at: DateTime, +} \ No newline at end of file diff --git a/backend/src/model/user.rs b/backend/src/model/user.rs new file mode 100644 index 0000000..63c0220 --- /dev/null +++ b/backend/src/model/user.rs @@ -0,0 +1,52 @@ +use crate::api::auth::Session; +use crate::error::ApiResult; +use crate::svc::user_svc::UserService; +use chrono::{DateTime, Utc}; +use rocket::State; +use sqlx::FromRow; +use crate::api::totp::TotpStatus; + +#[derive(Clone)] +#[derive(FromRow)] +pub struct User { + pub id: i64, + pub email: Option, + pub username: String, + pub nickname: Option, + pub passhash: String, + pub totp_status: TotpStatus, + pub totp_secret: Option, + pub created_at: Option>, + pub updated_at: Option>, +} + +// pub struct UserCache {} +// +// impl UserCache { +// pub async fn username( +// id: usize, +// redis_conn: &mut Connection, +// pgsql_conn: &mut Connection, +// ) -> String { +// if let Ok(val) = redis_conn.get(format!("users:{id}")).await { +// return val; +// } +// +// if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32) +// .fetch_one(&mut ***pgsql_conn) +// .await +// { +// let username = v.username; +// Self::insert(id, &username, redis_conn).await; +// username +// } else { +// unimplemented!() +// } +// } +// +// pub async fn insert(id: usize, username: &str, conn: &mut Connection) { +// conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800) +// .await +// .expect("failed to insert key"); +// } +// } diff --git a/backend/src/repo/access_token_repo.rs b/backend/src/repo/access_token_repo.rs new file mode 100644 index 0000000..d40333c --- /dev/null +++ b/backend/src/repo/access_token_repo.rs @@ -0,0 +1,67 @@ +use crate::repo::{Repo, AccessTokenRepoTrait}; +use chrono::{DateTime, Utc}; +use sqlx::PgPool; +use crate::model::auth::AccessToken; + +#[derive(Clone)] +pub struct AccessTokenRepo { + pool: PgPool +} + +impl Repo for AccessTokenRepo { + type Target = AccessToken; + + fn new(pool: PgPool) -> Self { + Self { pool } + } + + async fn get_by_id(&self, id: i64) -> Option { + sqlx::query_as!(AccessToken, "SELECT id, code FROM access_tokens WHERE id = $1", id) + .fetch_optional(&self.pool) + .await + .unwrap_or(None) + } +} + +#[async_trait] +impl AccessTokenRepoTrait for AccessTokenRepo { + async fn get_by_id(&self, id: i64) -> Option { + Repo::get_by_id(self, id).await + } + + async fn create_new(&self, + uid: i64, name: &str, code: &str, max_uses: i32, + start_date: DateTime, expiry_date: DateTime + ) -> Result { + sqlx::query!( + "INSERT INTO access_tokens (name, code, creator_id, max_uses, created_at, expires_at) + VALUES ($1, $2, $3, $4, $5, $6) RETURNING id", + name, + code, + uid, + max_uses, + start_date, + expiry_date + ).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound)) + } + + async fn use_token(&self, id: i64) -> Result<(), sqlx::Error> { + sqlx::query!("UPDATE access_tokens SET uses = uses + 1 WHERE id = $1", id) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn get_code_not_expired(&self, code: &str) -> Result, sqlx::Error> { + sqlx::query_as!(AccessToken, + "SELECT id, code FROM access_tokens + WHERE code = $1 + AND created_at < NOW() + AND expires_at > NOW() + AND uses < max_uses", + code + ) + .fetch_optional(&self.pool) + .await + } +} \ No newline at end of file diff --git a/backend/src/repo/channel_repo.rs b/backend/src/repo/channel_repo.rs new file mode 100644 index 0000000..6970dec --- /dev/null +++ b/backend/src/repo/channel_repo.rs @@ -0,0 +1,39 @@ +use crate::repo::ChannelRepo; +use crate::model::space::Channel; +use sqlx::PgPool; + +#[derive(Clone)] +pub struct ChannelRepository { + pool: PgPool, +} + +impl ChannelRepository { + pub fn new(pool: PgPool) -> Self { + Self { pool } + } +} + +#[rocket::async_trait] +impl ChannelRepo for ChannelRepository { + async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result { + let row = sqlx::query!( + "INSERT INTO channels (name, description, space_id) VALUES ($1, $2, $3) RETURNING id", + name, + description, + space_id + ) + .fetch_one(&self.pool) + .await?; + Ok(row.id) + } + + async fn get_by_space_id(&self, space_id: i64) -> Result, sqlx::Error> { + sqlx::query_as!( + Channel, + "SELECT id, name, description, space_id, created_at as \"created_at!\", updated_at as \"updated_at!\" FROM channels WHERE space_id = $1", + space_id + ) + .fetch_all(&self.pool) + .await + } +} diff --git a/backend/src/repo/message_repo.rs b/backend/src/repo/message_repo.rs new file mode 100644 index 0000000..ea9b4a8 --- /dev/null +++ b/backend/src/repo/message_repo.rs @@ -0,0 +1,95 @@ +use crate::api::chat::ChatMsg; +use crate::repo::Repo; +use chrono::{DateTime, Utc}; +use sqlx::PgPool; + +#[derive(Clone)] +pub struct MessageRepository { + pool: PgPool +} + +impl Repo for MessageRepository { + type Target = ChatMsg; + + fn new(pool: PgPool) -> Self { + Self { pool } + } + + // TODO: caching with redis + async fn get_by_id(&self, id: i64) -> Option { + sqlx::query!( + "SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at + FROM messages m + JOIN users u ON m.user_id = u.id + WHERE m.id = $1", + id + ).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg { + display_name: Some(row.nickname.unwrap_or(row.username)), + user_id: row.user_id, + text: row.content, + timestamp: row.created_at, + }) + } +} + +impl MessageRepository { + + // TODO! caching with redis + pub async fn create_new( + &self, uid: i64, channel_id: i64, + text: &str, created_at: DateTime + ) -> Result { + sqlx::query!( + "INSERT INTO messages (channel_id, user_id, content, created_at) + VALUES ($1, $2, $3, $4) RETURNING id", + channel_id, + uid, + text, + created_at + ).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound)) + } + + /// TODO: caching with redis + pub async fn get_by_channel(&self, channel_id: i64, limit: usize) + -> Result, sqlx::Error> { + sqlx::query!( + "SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at + FROM messages m + JOIN users u ON m.user_id = u.id + WHERE m.channel_id = $1 + ORDER BY m.created_at DESC LIMIT $2", + channel_id, + limit as i64 + ).fetch_all(&self.pool).await.map(|messages| { + messages.into_iter().rev().map(|msg| { + ChatMsg { + display_name: Some(msg.nickname.unwrap_or(msg.username)), + user_id: msg.user_id, + text: msg.content, + timestamp: msg.created_at, + } + }).collect::>() + }) + } +} + + + + + + + + + + + + + + + + + + + + + diff --git a/backend/src/repo/mock.rs b/backend/src/repo/mock.rs new file mode 100644 index 0000000..34daf6a --- /dev/null +++ b/backend/src/repo/mock.rs @@ -0,0 +1,153 @@ +use crate::repo::{UserRepo, AccessTokenRepoTrait}; +use crate::model::user::User; +use crate::model::auth::AccessToken; +use rocket::async_trait; +use std::sync::Mutex; +use chrono::Utc; +use std::sync::Arc; +use sqlx::Error; +use crate::api::totp::TotpStatus; +use crate::api::totp::TotpStatus::Disabled; + +pub struct MockAccessTokenRepo { + pub tokens: Mutex>, +} + +#[async_trait] +impl AccessTokenRepoTrait for MockAccessTokenRepo { + async fn get_by_id(&self, id: i64) -> Option { + self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() }) + } + async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime, _expiry_date: chrono::DateTime) -> Result { + let mut tokens = self.tokens.lock().unwrap(); + let id = tokens.len() as i64 + 1; + tokens.push(AccessToken { id, code: code.to_string() }); + Ok(id) + } + + async fn use_token(&self, id: i64) -> Result<(), Error> { + // let mut tokens = self.tokens.lock().unwrap(); + // if let Some(pos) = tokens.iter().position(|t| t.id == id) { + // tokens.get_mut(pos).uses = + // } + Ok(()) + } + + async fn get_code_not_expired(&self, code: &str) -> Result, Error> { + Ok(self.tokens.lock().unwrap() + .iter().find(|t| t.code == code) + .map(|t| AccessToken { id: t.id, code: t.code.clone() })) + } +} + +pub struct MockUserRepo { + pub users: Mutex>, +} + +#[async_trait] +impl UserRepo for MockUserRepo { + fn pool(&self) -> &sqlx::PgPool { + unimplemented!("MockUserRepo does not have a real pool") + } + async fn get_by_id(&self, id: i64) -> Option { + self.users.lock().unwrap().iter().find(|u| u.id == id).cloned() + } + async fn save(&self, user: &User) -> Result<(), sqlx::Error> { + let mut users = self.users.lock().unwrap(); + if let Some(pos) = users.iter().position(|u| u.id == user.id) { + users[pos] = user.clone(); + } + Ok(()) + } + async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result { + let mut users = self.users.lock().unwrap(); + let id = users.len() as i64 + 1; + users.push(User { + id, + email: Some(email.to_string()), + username: username.to_string(), + nickname: None, + passhash: pass_hash.to_string(), + totp_status: Disabled, + totp_secret: None, + created_at: Some(Utc::now()), + updated_at: Some(Utc::now()), + }); + Ok(id) + } + async fn get_by_username(&self, username: &str) -> Result, sqlx::Error> { + Ok(self.users.lock().unwrap().iter().find(|u| u.username == username).cloned()) + } + async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> { + self.users.lock().unwrap().retain(|u| u.id != id); + Ok(()) + } + async fn set_display_name(&self, id: i64, display_name: Option) -> Result<(), sqlx::Error> { + if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) { + u.nickname = display_name; + } + Ok(()) + } + async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> { + if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) { + u.username = username.to_string(); + } + Ok(()) + } + async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> { + if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) { + u.totp_status = *enabled; + } + Ok(()) + } + async fn get_totp_secret(&self, id: i64) -> Result, sqlx::Error> { + Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).and_then(|u| u.totp_secret.clone())) + } + async fn set_totp_secret(&self, id: i64, secret: Option) -> Result<(), sqlx::Error> { + if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) { + u.totp_secret = secret; + } + Ok(()) + } + async fn get_pass_hash(&self, id: i64) -> Result { + Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).map(|u| u.passhash.clone()).unwrap()) + } + async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> { + if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) { + u.passhash = pass_hash.to_string(); + } + Ok(()) + } + async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> { + if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) { + u.email = Some(email.to_string()); + } + Ok(()) + } + async fn set_role(&self, _id: i64, _role: &str) -> Result<(), sqlx::Error> { + Ok(()) + } +} + +pub struct MockTokenRepo { + pub tokens: Mutex>, +} + +#[async_trait] +impl AccessTokenRepoTrait for MockTokenRepo { + async fn get_by_id(&self, id: i64) -> Option { + self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() }) + } + async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime, _expiry_date: chrono::DateTime) -> Result { + let mut tokens = self.tokens.lock().unwrap(); + let id = tokens.len() as i64 + 1; + tokens.push(AccessToken { id, code: code.to_string() }); + Ok(id) + } + async fn use_token(&self, _id: i64) -> Result<(), sqlx::Error> { + Ok(()) + } + async fn get_code_not_expired(&self, code: &str) -> Result, sqlx::Error> { + Ok(self.tokens.lock().unwrap().iter().find(|t| t.code == code).map(|t| AccessToken { id: t.id, code: t.code.clone() })) + } +} diff --git a/backend/src/repo/mod.rs b/backend/src/repo/mod.rs new file mode 100644 index 0000000..2e48f94 --- /dev/null +++ b/backend/src/repo/mod.rs @@ -0,0 +1,64 @@ +use crate::model::auth::AccessToken; +use crate::model::user::User; +use chrono::{DateTime, Utc}; +use crate::api::totp::TotpStatus; +use crate::model::space::Space; + +pub mod user_repo; +pub mod message_repo; +pub mod access_token_repo; +pub mod space_repo; +pub mod channel_repo; +pub mod mock; + +pub trait Repo: Clone + Send + Sync { + type Target; + + fn new(pool: sqlx::PgPool) -> Self; + + async fn get_by_id(&self, id: i64) -> Option; +} + +#[rocket::async_trait] +pub trait UserRepo: Send + Sync { + fn pool(&self) -> &sqlx::PgPool; + async fn get_by_id(&self, id: i64) -> Option; + async fn save(&self, user: &User) -> Result<(), sqlx::Error>; + async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result; + async fn get_by_username(&self, username: &str) -> Result, sqlx::Error>; + async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error>; + async fn set_display_name(&self, id: i64, display_name: Option) -> Result<(), sqlx::Error>; + async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error>; + async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error>; + async fn get_totp_secret(&self, id: i64) -> Result, sqlx::Error>; + async fn set_totp_secret(&self, id: i64, secret: Option) -> Result<(), sqlx::Error>; + async fn get_pass_hash(&self, id: i64) -> Result; + async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error>; + async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error>; + async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error>; +} + +#[rocket::async_trait] +pub trait SpaceRepo: Send + Sync { + async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result; + async fn get_all(&self) -> Result, sqlx::Error>; + async fn get_by_member(&self, uid: i64) -> Result, sqlx::Error>; + async fn get_by_id(&self, id: i64) -> Result, sqlx::Error>; +} + +#[rocket::async_trait] +pub trait ChannelRepo: Send + Sync { + async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result; + async fn get_by_space_id(&self, space_id: i64) -> Result, sqlx::Error>; +} + +#[async_trait] +pub trait AccessTokenRepoTrait: Send + Sync { + async fn get_by_id(&self, id: i64) -> Option; + async fn create_new(&self, + uid: i64, name: &str, code: &str, max_uses: i32, + start_date: DateTime, expiry_date: DateTime + ) -> Result; + async fn use_token(&self, id: i64) -> Result<(), sqlx::Error>; + async fn get_code_not_expired(&self, code: &str) -> Result, sqlx::Error>; +} \ No newline at end of file diff --git a/backend/src/repo/space_repo.rs b/backend/src/repo/space_repo.rs new file mode 100644 index 0000000..50cf4a6 --- /dev/null +++ b/backend/src/repo/space_repo.rs @@ -0,0 +1,56 @@ +use crate::repo::SpaceRepo; +use crate::model::space::Space; +use sqlx::PgPool; + +#[derive(Clone)] +pub struct SpaceRepository { + pool: PgPool, +} + +impl SpaceRepository { + pub fn new(pool: PgPool) -> Self { + Self { pool } + } +} + +#[rocket::async_trait] +impl SpaceRepo for SpaceRepository { + async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result { + let row = sqlx::query!( + "INSERT INTO spaces (name, description, owner_id) VALUES ($1, $2, $3) RETURNING id", + name, + description, + owner_id + ) + .fetch_one(&self.pool) + .await?; + Ok(row.id) + } + + async fn get_all(&self) -> Result, sqlx::Error> { + sqlx::query_as!(Space, + "SELECT id, name, description, owner_id, created_at, updated_at FROM spaces" + ) + .fetch_all(&self.pool) + .await + } + + async fn get_by_member(&self, uid: i64) -> Result, sqlx::Error> { + sqlx::query_as!(Space, + "SELECT s.id, s.name, s.description, s.created_at, s.updated_at, s.owner_id + FROM spaces s JOIN space_members sm ON s.id = sm.space_id + WHERE sm.user_id = $1", + uid + ).fetch_all(&self.pool) + .await + } + + async fn get_by_id(&self, id: i64) -> Result, sqlx::Error> { + sqlx::query_as!(Space, + "SELECT id, name, description, owner_id, created_at, updated_at FROM spaces WHERE id = $1", + id + ) + .fetch_optional(&self.pool) + .await + } +} diff --git a/backend/src/repo/user_repo.rs b/backend/src/repo/user_repo.rs new file mode 100644 index 0000000..3794b40 --- /dev/null +++ b/backend/src/repo/user_repo.rs @@ -0,0 +1,212 @@ +use crate::repo::{Repo, UserRepo}; +use crate::model::user::User; +use sqlx::PgPool; +use crate::api::totp::TotpStatus; + +#[derive(Clone)] +pub struct UserRepository { + pool: PgPool +} + +impl UserRepository { + pub fn new(pool: PgPool) -> Self { + Self { pool } + } + + pub fn pool(&self) -> &PgPool { + &self.pool + } +} + +impl Repo for UserRepository { + type Target = User; + fn new(pool: PgPool) -> Self { + Self::new(pool) + } + + async fn get_by_id(&self, id: i64) -> Option { + sqlx::query_as!( + User, + "SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE id = $1", + id + ) + .fetch_optional(&self.pool) + .await + .map_err(|e| { + tracing::error!("Database error in get_by_id: {}", e); + e + }) + .ok()? + } +} + +#[async_trait] +impl UserRepo for UserRepository { + fn pool(&self) -> &sqlx::PgPool { + &self.pool + } + + async fn get_by_id(&self, id: i64) -> Option { + Repo::get_by_id(self, id).await + } + + async fn save(&self, user: &User) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE users SET email = $1, username = $2, nickname = $3, passhash = $4, totp_status = $5, totp_secret = $6, created_at = $7, updated_at = $8 WHERE id = $9", + user.email, + user.username, + user.nickname, + user.passhash, + user.totp_status as TotpStatus, + user.totp_secret, + user.created_at, + user.updated_at, + user.id + ).execute(&self.pool).await?; + Ok(()) + } + + async fn new_user(&self, email: &str, username: &str, passhash: &str) -> Result { + sqlx::query!( + "INSERT INTO users (email, username, passhash) VALUES ($1, $2, $3) RETURNING id", + email, + username, + passhash + ) + .fetch_optional(&self.pool) + .await + .and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound)) + } + + async fn get_by_username(&self, username: &str) -> Result, sqlx::Error> { + sqlx::query_as!( + User, + "SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE username = $1", + username + ) + .fetch_optional(&self.pool) + .await + } + + async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> { + sqlx::query!("DELETE FROM users WHERE id = $1", id) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn set_display_name(&self, id: i64, display_name: Option) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE users SET nickname = $1 WHERE id = $2", + display_name, + id + ) + .execute(&self.pool).await?; + Ok(()) + } + + async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE users SET username = $1 WHERE id = $2", + username, + id + ) + .execute(&self.pool).await?; + Ok(()) + } + + async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE users SET totp_status = $1 WHERE id = $2", + enabled as &TotpStatus, + id + ) + .execute(&self.pool).await?; + Ok(()) + } + + async fn get_totp_secret(&self, id: i64) -> Result, sqlx::Error> { + sqlx::query!( + "SELECT totp_secret FROM users WHERE id = $1", + id + ) + .fetch_optional(&self.pool) + .await + .map(|opt| opt.and_then(|row| row.totp_secret)) + } + + async fn set_totp_secret(&self, id: i64, secret: Option) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE users SET totp_secret = $1 WHERE id = $2", + secret, + id + ) + .execute(&self.pool).await?; + Ok(()) + } + + async fn get_pass_hash(&self, id: i64) -> Result { + sqlx::query!( + "SELECT passhash FROM users WHERE id = $1", + id + ) + .fetch_optional(&self.pool) + .await + .and_then(|row| row.map(|r| r.passhash).ok_or(sqlx::Error::RowNotFound)) + } + + async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE users SET passhash = $1 WHERE id = $2", + pass_hash, + id + ) + .execute(&self.pool).await?; + Ok(()) + } + + async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE users SET email = $1 WHERE id = $2", + email, + id + ) + .execute(&self.pool).await?; + Ok(()) + } + + async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error> { + sqlx::query( + "UPDATE users SET role = $1::user_role WHERE id = $2" + ) + .bind(role) + .bind(id) + .execute(&self.pool).await?; + Ok(()) + } +} + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/backend/src/svc/access_token_svc.rs b/backend/src/svc/access_token_svc.rs new file mode 100644 index 0000000..47a5ab8 --- /dev/null +++ b/backend/src/svc/access_token_svc.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; +use chrono::{DateTime, Utc}; +use uuid::Uuid; +use crate::error::{ApiResult, AppError}; +use crate::model::auth::AccessToken; +use crate::repo::access_token_repo::AccessTokenRepo; +use crate::repo::AccessTokenRepoTrait; + + +#[derive(Clone)] +pub struct AccessTokenService { + repo: Arc +} + +impl AccessTokenService { + pub fn new(repo: Arc) -> Self { + Self { repo } + } + + pub async fn create(&self, + uid: i64, name: &str, max_uses: i32, + valid_from: DateTime, valid_until: DateTime + ) -> ApiResult { + if valid_from > valid_until { + return Err(AppError::bad_request("start date must be before end date")) + } + + if valid_until < Utc::now() { + return Err(AppError::bad_request("expiry date must be after current date")) + } + + let code = Uuid::new_v4().to_string(); + self.repo.create_new(uid, name, &code, max_uses, valid_from, valid_until).await?; + + Ok(code) + } + + pub async fn get_valid_token(&self, token: &str) -> ApiResult { + self.repo.get_code_not_expired(token).await? + .ok_or(AppError::unauthorised("invalid access token")) + } + + pub async fn use_token(&self, id: i64) -> ApiResult<()> { + self.repo.use_token(id).await?; + Ok(()) + } +} \ No newline at end of file diff --git a/backend/src/svc/auth_svc.rs b/backend/src/svc/auth_svc.rs new file mode 100644 index 0000000..07a2636 --- /dev/null +++ b/backend/src/svc/auth_svc.rs @@ -0,0 +1,259 @@ +use crate::api::auth::{Claims, TokenScope}; +use crate::api::totp::totp_gen; +use crate::error::{ApiResult, AppError}; +use crate::model::auth::AuthResponse; +use crate::repo::{UserRepo, AccessTokenRepoTrait}; +use std::sync::Arc; +use argon2::password_hash::rand_core::OsRng; +use argon2::password_hash::SaltString; +use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; +use chrono::{DateTime, Utc}; +use uuid::Uuid; +use crate::api::totp::TotpStatus::{Disabled, Enabled}; +use crate::svc::access_token_svc::AccessTokenService; + +#[derive(Clone)] +pub struct AuthService { + users: Arc, + tokens: AccessTokenService, +} + +impl AuthService { + pub fn new(users: Arc, tokens: AccessTokenService) -> Self { + Self { users, tokens } + } + + pub async fn signup(&self, + email: &str, username: &str, + password: &str, access_token: &str + ) -> ApiResult { + let tok_id = self.tokens.get_valid_token(access_token).await?.id; + + let pass = password.to_string(); + let svc = self.clone(); + let hashed = tokio::task::spawn_blocking(move || svc.hash_password(&pass)) + .await + .map_err(|_| AppError::internal("blocking task panicked"))??; + + let uid = self.users + .new_user(email, username, &hashed).await?; + + self.tokens.use_token(tok_id).await?; + + let jwt = Claims::new(uid as usize, TokenScope::Full).encode(); + Ok(AuthResponse { + token: jwt, + totp_required: false + }) + } + + pub async fn login(&self, username: &str, password: &str) -> ApiResult { + let user = self.users + .get_by_username(username).await? + .ok_or(AppError::unauthorised("invalid username"))?; + + let pass = password.to_string(); + let user_hash = user.passhash.clone(); + let svc = self.clone(); + tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass)) + .await + .map_err(|_| AppError::internal("blocking task panicked"))??; + + let scope = if user.totp_status == Enabled { TokenScope::TotpPending } else { TokenScope::Full }; + let jwt = Claims::new(user.id as usize, scope).encode(); + + Ok(AuthResponse { + token: jwt, + totp_required: user.totp_status == Enabled + }) + } + + pub async fn login_totp(&self, uid: i64, code: &str) -> ApiResult { + let secret = self.users.get_totp_secret(uid).await? + .ok_or(AppError::unauthorised("2fa not enabled"))?; + + self.verify_2fa(uid, &secret, code)?; + + let jwt = Claims::new(uid as usize, TokenScope::Full).encode(); + + Ok(AuthResponse { + token: jwt, + totp_required: false + }) + } + + pub async fn disable_totp(&self, uid: i64, password: &str, totp_code: &str) -> ApiResult { + let mut user = self.users.get_by_id(uid).await + .ok_or(AppError::internal("user not found"))?; + + let Some(secret) = user.totp_secret else { + return Err(AppError::bad_request("2fa not enabled")); + }; + + self.verify_password(&user.passhash, password)?; + self.verify_2fa(uid, &secret, totp_code)?; + + user.totp_secret = None; + user.totp_status = Disabled; + self.users.save(&user).await?; + + Ok(AuthResponse { + token: Claims::new(uid as usize, TokenScope::Full).encode(), + totp_required: false + }) + } + + pub async fn get_totp_status(&self, uid: i64) -> ApiResult { + Ok( + self.users.get_totp_secret(uid).await?.is_some() + ) + } + + pub async fn confirm_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> { + let secret = self.users.get_totp_secret(uid).await? + .ok_or(AppError::bad_request("2fa setup not initialised"))?; + + self.verify_2fa(uid, &secret, totp_code)?; + + self.users.set_twofa_enabled(uid, &Enabled).await?; + + Ok(()) + } + + pub async fn get_or_create_totp_secret( + &self, uid: i64, password: &str, + ) -> ApiResult { + let user = self.users.get_by_id(uid).await + .ok_or(AppError::internal("user not found"))?; + + let pass = password.to_string(); + let user_hash = user.passhash.clone(); + let svc = self.clone(); + tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass)) + .await + .map_err(|_| AppError::internal("blocking task panicked"))??; + + if let Some(secret) = user.totp_secret { + return Ok(secret); + } + + let new_secret = totp_rs::Secret::generate_secret() + .to_encoded() + .to_string(); + + self.users.set_totp_secret(uid, Some(new_secret.clone())).await?; + + Ok(new_secret) + } + + pub async fn verify_user_password(&self, uid: i64, password: &str) -> ApiResult<()> { + let hash = self.users.get_pass_hash(uid).await + .map_err(|_| AppError::internal("user not found"))?; + + let pass = password.to_string(); + let svc = self.clone(); + tokio::task::spawn_blocking(move || svc.verify_password(&hash, &pass)) + .await + .map_err(|_| AppError::internal("blocking task panicked"))??; + + Ok(()) + } + + pub async fn verify_user_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> { + let secret = self.users.get_totp_secret(uid).await? + .ok_or(AppError::internal("user not found"))?; + + self.verify_2fa(uid, &secret, totp_code) + } + + + pub fn hash_password(&self, password: &str) -> ApiResult { + let salt = SaltString::generate(&mut OsRng); + Argon2::default() + .hash_password(password.as_bytes(), &salt) + .map_err(|_| AppError::internal("failed to hash password")) + .map(|hash| hash.to_string()) + } + + // Private helpers + fn verify_password(&self, pass_hash: &str, password: &str) -> ApiResult<()> { + let parsed_hash = PasswordHash::new(&pass_hash) + .map_err(|_| AppError::internal("invalid password hash"))?; + + Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .map_err(|_| AppError::unauthorised("incorrect password"))?; + + Ok(()) + } + + pub fn verify_2fa(&self, uid: i64, totp_secret: &str, totp_code: &str) -> ApiResult<()> { + if totp_gen(uid, totp_secret.as_bytes()) + .map_err(|_| AppError::internal("invalid totp secret"))? + .check_current(totp_code) + .map_err(|_| AppError::internal("invalid totp code"))? { + Ok(()) + } else { + Err(AppError::unauthorised("incorrect totp code")) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::repo::mock::{MockUserRepo, MockTokenRepo}; + use std::sync::Mutex; + + fn setup() -> AuthService { + unsafe { + std::env::set_var("JWT_SECRET", "test_secret"); + } + let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) }); + let tok_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) }); + let tokens = AccessTokenService::new(tok_repo); + AuthService::new(users, tokens) + } + + #[tokio::test] + async fn test_signup_and_login() { + let auth = setup(); + let code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap(); + + let signup_res = auth.signup("test@example.com", "tester", "password123", &code).await; + assert!(signup_res.is_ok()); + + let login_res = auth.login("tester", "password123").await; + assert!(login_res.is_ok()); + let login_data = login_res.unwrap(); + assert!(!login_data.totp_required); + assert!(!login_data.token.is_empty()); + } + + #[tokio::test] + async fn test_login_invalid_password() { + let auth = setup(); + let token_code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap(); + auth.signup("test@example.com", "tester", "password123", &token_code).await.unwrap(); + + let login_res = auth.login("tester", "wrong_password").await; + assert!(login_res.is_err()); + if let Err(AppError::Unauthorised(msg)) = login_res { + assert_eq!(msg, "incorrect password"); + } else { + panic!("Expected Unauthorised error"); + } + } + + #[tokio::test] + async fn test_invite() { + let auth = setup(); + let res = auth.tokens.create(1, "invite", 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await; + assert!(res.is_ok()); + let code = res.unwrap(); + assert!(!code.is_empty()); + + let token = auth.tokens.get_valid_token(&code).await; + assert!(token.is_ok()); + } +} \ No newline at end of file diff --git a/backend/src/svc/chat_svc.rs b/backend/src/svc/chat_svc.rs new file mode 100644 index 0000000..0572f8a --- /dev/null +++ b/backend/src/svc/chat_svc.rs @@ -0,0 +1,187 @@ +use crate::api::chat::ChatMsg; +use crate::error::{ApiResult, AppError}; +use crate::repo::message_repo::MessageRepository; +use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo}; +use chrono::{DateTime, Utc}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::broadcast::Sender; +use tokio::sync::{broadcast, Mutex}; +use crate::model::space::SpaceDto; +use crate::svc::llm_service::LlmService; + +/// ---------- shared broadcaster ---------- + +#[derive(Clone)] +pub struct ChatService { + users: Arc, + channels: Arc, + spaces: Arc, + messages: MessageRepository, + + llm: LlmService, + buffer_size: usize, + senders: Arc>>>, + + +} + +impl ChatService { + pub fn new( + buffer_size: usize, llm: LlmService, + messages: MessageRepository, users: Arc, + channels: Arc, spaces: Arc, + ) -> Self { + Self { + channels, + spaces, + llm, + users, + messages, + buffer_size, + senders: Arc::new(Mutex::new(std::collections::HashMap::new())), + } + } + + pub async fn get_accessible_channels(&self, uid: i64) -> ApiResult> { + // let spaces = self.spaces.get_by_member(uid).await?; + // TODO! UNCOMMENT THIS ^^^^^^ + let spaces = self.spaces.get_all().await?; + + let mut result = Vec::new(); + for space in spaces { + let channels = self.channels.get_by_space_id(space.id).await?; + result.push(SpaceDto { + channels, + id: space.id, + owner_id: space.owner_id, + name: space.name, + description: space.description, + created_at: space.created_at, + updated_at: space.updated_at, + }); + } + + Ok(result) + } + + pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult> { + let messages = self.messages.get_by_channel(channel_id, limit).await?; + Ok(messages) + } + + /// Sends a chat message to the specified channel, persists it to the database, + /// and handles potential AI-generated replies asynchronously. + /// + /// # Parameters + /// - `channel_id`: The ID of the channel to which the message will be sent. + /// - `uid`: The user ID of the sender. + /// - `text`: The content of the message to be sent. + /// - `created_at`: The timestamp at which the message was created. + /// + /// # Returns + /// - `ApiResult<()>`: Indicates success or failure of the operation. + /// + /// # Behavior + /// 1. Fetches the user by their `uid`. Returns an error if the user is not found. + /// 2. Constructs a `ChatMsg` object with the sender's `display_name` or `username`, + /// and the specified message content and timestamp. + /// 3. Publishes the constructed message to the given channel. + /// 4. Persists the message in the database. + /// 5. Spawns an asynchronous task to generate an LLM-powered (language model) reply: + /// - Sends the original message to the LLM worker for a potential reply. + /// - Publishes the LLM's reply to the same channel if successful. + /// - Persists the LLM's reply to the database. + /// + /// # Notes + /// - Caching with Redis is planned for both message persistence and AI replies, but + /// is not implemented in the current version. + /// - The spawned asynchronous task does not block the main execution flow. + /// + /// # Potential Errors + /// - Returns `AppError::NotFound` if the `uid` does not map to an existing user. + /// - Returns an error wrapped in `ApiResult` if the database operations fail. + /// + /// # TODO + /// - Implement caching for both user-supplied messages and LLM-generated replies + /// using Redis at the repository or service layer. + pub async fn send(&self, + channel_id: i64, uid: i64, + text: &str, created_at: DateTime + ) -> ApiResult<()> { + let user = self.users.get_by_id(uid).await + .ok_or(AppError::NotFound)?; + + let message = ChatMsg { + display_name: Some(user + .nickname.clone() + .unwrap_or_else(|| user.username.clone())), + user_id: uid, + text: text.to_string(), + timestamp: created_at, + }; + + self.publish(channel_id, message.clone()).await; + + let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?; + // TODO: caching w redis at repository layer + + let svc_instance = self.clone(); + + let Some(text) = text.strip_prefix("/ask ") else { + return Ok(()) + }; + + if !svc_instance.llm.enabled() { + return Ok(()) + } + + tokio::spawn(async move { + let response = svc_instance.llm + .query(&message) + .await; + + if let Ok(reply) = response { + + tracing::info!("LLM reply: {}", reply.text); + + svc_instance.publish(channel_id, reply.clone()).await; + // TODO: cache response (or do with redis!) + if let Err(e) = svc_instance.messages + .create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await { + tracing::error!("Failed to persist LLM reply: {}", e); + } + + tracing::info!("LLM reply persisted"); + + } else { + tracing::warn!("Error contacting LLM: {:?}", response); + } + }); + + Ok(()) + } + + /// Subscribe to the specified channel. + pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver { + let mut map = self.senders.lock().await; + let sender = map + .entry(channel_id) + .or_insert_with(|| broadcast::channel::(self.buffer_size).0); + sender.subscribe() + } + + // Private helper methods + + /// Publish a message to the specified channel. + async fn publish(&self, channel_id: i64, msg: ChatMsg) { + let mut map = self.senders.lock().await; + let sender = map + .entry(channel_id) + .or_insert_with(|| broadcast::channel::(self.buffer_size).0); + let _ = sender.send(msg); + } + + + +} \ No newline at end of file diff --git a/backend/src/svc/llm_service.rs b/backend/src/svc/llm_service.rs new file mode 100644 index 0000000..9b0cace --- /dev/null +++ b/backend/src/svc/llm_service.rs @@ -0,0 +1,89 @@ +#[derive(Clone)] +pub struct LlmService; + + +static LMSTUDIO_URL: LazyLock> = LazyLock::new(|| env::var("LMSTUDIO_URL").ok()); +static LMSTUDIO_MODEL: LazyLock> = LazyLock::new(|| env::var("LMSTUDIO_MODEL").ok()); + +impl LlmService { + + pub fn new() -> Self { + Self {} + } + + pub fn enabled(&self) -> bool { + LMSTUDIO_URL.is_some() + } + + pub async fn query(&self, message: &ChatMsg) -> ApiResult { + let Some(url) = LMSTUDIO_URL.clone() else { + return Err(AppError::internal("AI not enabled!")) + }; + + let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into()); + + let client = reqwest::Client::new(); + + // Build the request body + let payload = LlmRequest { + model, // whatever model you run locally + messages: vec![Message { + role: "user".into(), + content: message.text.clone(), + }], + }; + + // POST to lm‑studio (default 127.0.0.1:1234) + let resp = client + .post(url) + .json(&payload) + .send() + .await + .map_err(|_| AppError::internal("Failed to make request to LLM server"))?; + + // The API returns a JSON with `choices[].message.content` + #[derive(Deserialize)] + struct LlmResponse { + choices: Vec, + } + #[derive(Deserialize)] + struct Choice { + message: Message, + } + + let llm_resp: LlmResponse = resp + .json() + .await + .map_err(|_| AppError::internal("Failed to parse LLM response"))?; + + Ok(ChatMsg { + display_name: Some(String::from("llm")), + user_id: 0, + text: llm_resp.choices[0].message.content.clone(), + timestamp: chrono::Utc::now(), + }) + } +} + +use std::env; +use std::sync::LazyLock; +// src/llm.rs +use serde::{Deserialize, Serialize}; + +use crate::api::chat::ChatMsg; +use crate::error::{ApiResult, AppError}; +use crate::svc::chat_svc::ChatService; + +#[derive(Serialize)] +struct LlmRequest { + model: String, + messages: Vec, +} + +#[derive(Serialize, Deserialize)] +struct Message { + role: String, // "user" or "assistant" + content: String, +} + + diff --git a/backend/src/svc/mod.rs b/backend/src/svc/mod.rs new file mode 100644 index 0000000..7c366a3 --- /dev/null +++ b/backend/src/svc/mod.rs @@ -0,0 +1,6 @@ +pub mod auth_svc; +pub mod chat_svc; +pub mod settings_svc; +pub mod user_svc; +pub mod access_token_svc; +pub mod llm_service; \ No newline at end of file diff --git a/backend/src/svc/settings_svc.rs b/backend/src/svc/settings_svc.rs new file mode 100644 index 0000000..a5127f9 --- /dev/null +++ b/backend/src/svc/settings_svc.rs @@ -0,0 +1,116 @@ +//! The `SettingsService` is responsible for managing user account settings, allowing users to +//! update their username, password, display name, email, and delete their account. +//! It interacts with the `AuthService` to handle authentication and password-related functionality +//! and the `UserRepository` to perform updates to user accounts in the data store. + +use crate::error::{ApiResult, AppError}; +use crate::repo::UserRepo; +use crate::svc::auth_svc::AuthService; +use std::sync::Arc; + +#[derive(Clone)] +pub struct SettingsService { + auth: AuthService, + users: Arc, +} + +impl SettingsService { + pub fn new(auth: AuthService, users: Arc) -> Self { + Self { auth, users } + } + + pub async fn change_username(&self, uid: i64, new: &str) -> ApiResult<()> { + self.users.set_username(uid, new).await?; + Ok(()) + } + + pub async fn change_password(&self, uid: i64, old: &str, new: &str) -> ApiResult<()> { + self.auth.verify_user_password(uid, old).await?; + let hashed = self.auth.hash_password(new)?; + self.users.set_pass_hash(uid, &hashed).await?; + Ok(()) + } + + pub async fn change_display_name(&self, uid: i64, new: Option) -> ApiResult<()> { + self.users.set_display_name(uid, new).await?; + Ok(()) + } + + pub async fn change_email(&self, uid: i64, new: &str) -> ApiResult<()> { + self.users.set_email(uid, new).await?; + Ok(()) + } + + pub async fn delete_account(&self, uid: i64, password: &str, totp_code: &Option) -> ApiResult<()> { + self.auth.verify_user_password(uid, password).await?; + + // check 2fa code is correct if enabled + if self.auth.get_totp_status(uid).await? { + + let Some(totp_code) = totp_code else { + return Err(AppError::unauthorised("2fa code is required")) + }; + + self.auth.verify_user_totp(uid, totp_code).await?; + } + + self.users.delete_by_id(uid).await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::repo::mock::{MockUserRepo, MockTokenRepo}; + use std::sync::Mutex; + use chrono::Utc; + use crate::svc::access_token_svc::AccessTokenService; + + fn setup() -> SettingsService { + unsafe { + std::env::set_var("JWT_SECRET", "test_secret"); + } + let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) }); + let token_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) }); + let tokens_svc = AccessTokenService::new(token_repo); + let auth = AuthService::new(users.clone(), tokens_svc.clone()); + SettingsService::new(auth, users) + } + + #[tokio::test] + async fn test_change_username() { + let settings = setup(); + let uid = settings.users.new_user("test@example.com", "old", "pass").await.unwrap(); + + settings.change_username(uid, "new").await.unwrap(); + let user = settings.users.get_by_id(uid).await.unwrap(); + assert_eq!(user.username, "new"); + } + + #[tokio::test] + async fn test_change_password() { + let settings = setup(); + let pass = "old_pass"; + let hashed = settings.auth.hash_password(pass).unwrap(); + let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap(); + + settings.change_password(uid, pass, "new_pass").await.unwrap(); + let _user = settings.users.get_by_id(uid).await.unwrap(); + assert!(settings.auth.verify_user_password(uid, "new_pass").await.is_ok()); + } + + #[tokio::test] + async fn test_delete_account() { + let settings = setup(); + let pass = "password"; + let hashed = settings.auth.hash_password(pass).unwrap(); + let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap(); + + let res = settings.delete_account(uid, pass, &None).await; + assert!(res.is_ok()); + + let user = settings.users.get_by_id(uid).await; + assert!(user.is_none()); + } +} \ No newline at end of file diff --git a/backend/src/svc/user_svc.rs b/backend/src/svc/user_svc.rs new file mode 100644 index 0000000..c94ac3c --- /dev/null +++ b/backend/src/svc/user_svc.rs @@ -0,0 +1,28 @@ +use crate::error::ApiResult; +use crate::repo::UserRepo; +use std::sync::Arc; + +pub struct UserService { + repo: Arc +} + +impl UserService { + pub fn new(repo: Arc) -> Self { + Self { repo } + } + + pub async fn get_display_name(&self, uid: i64) -> ApiResult { + // TODO: redis caching for display names + + let user = self.repo.get_by_id(uid) + .await.ok_or(crate::error::AppError::NotFound)?; + + Ok(user.nickname.unwrap_or_else(|| user.username)) + } + + pub async fn get_username(&self, uid: i64) -> ApiResult { + self.repo.get_by_id(uid) + .await.ok_or(crate::error::AppError::NotFound) + .map(|u| u.username) + } +} \ No newline at end of file diff --git a/backend/src/user.rs b/backend/src/user.rs deleted file mode 100644 index 7934936..0000000 --- a/backend/src/user.rs +++ /dev/null @@ -1,188 +0,0 @@ -use argon2::{Argon2, PasswordHash, PasswordVerifier}; -use redis::AsyncCommands; -use rocket::{http::Status, serde::json::Json, time::OffsetDateTime}; -use rocket_db_pools::Connection; - -use crate::{ - auth::{Session, two_factor::totp_gen}, - db::{Postgres, Redis}, -}; - -pub struct User { - pub id: i32, - pub email: Option, - pub username: String, - pub display_name: Option, - pub pass_hash: String, - pub twofa_enabled: bool, - pub totp_secret: Option, - pub created_at: Option, - pub updated_at: Option, -} - -impl User { - pub async fn get_by_id(id: usize, db: &mut Connection) -> Option { - sqlx::query_as!( - Self, - "SELECT id, email, username, display_name, pass_hash, twofa_enabled, totp_secret, created_at, updated_at FROM users WHERE id = $1", - id as i32 - ) - .fetch_optional(&mut ***db) - .await - .unwrap_or(None) - } - - pub async fn delete(&mut self, db: &mut Connection) -> Result<(), sqlx::Error> { - sqlx::query!("DELETE FROM users WHERE id = $1", self.id) - .execute(&mut ***db) - .await?; - Ok(()) - } - - pub fn verify_2fa(&self, code: &str) -> Result<(), Status> { - if totp_gen( - self.id as usize, - self.totp_secret - .clone() - .expect("user with 2fa enabled has no totp secret") - .as_bytes(), - ) - .map_err(|_| Status::InternalServerError)? - .check_current(code) - .map_err(|_| Status::InternalServerError)? - { - Ok(()) - } else { - Err(Status::Unauthorized) - } - } - - pub fn verify_password(&self, password: &str) -> Result<(), Status> { - let parsed_hash = PasswordHash::new(&self.pass_hash) - .inspect_err(|e| { - tracing::error!("Failed to parse hash for password! uid:{} {e}", self.id) - }) - .map_err(|_| Status::InternalServerError)?; - - Argon2::default() - .verify_password(password.as_bytes(), &parsed_hash) - .map_err(|_| Status::Unauthorized) - } - - pub async fn set_display_name( - &mut self, - display_name: Option, - db: &mut Connection, - ) -> Result<(), sqlx::Error> { - self.display_name = display_name; - sqlx::query!( - "UPDATE users SET display_name = $1 WHERE id = $2", - self.display_name, - self.id - ) - .execute(&mut ***db) - .await?; - Ok(()) - } - - pub async fn set_username( - &mut self, - username: String, - db: &mut Connection, - ) -> Result<(), sqlx::Error> { - self.username = username; - sqlx::query!( - "UPDATE users SET username = $1 WHERE id = $2", - self.username, - self.id - ) - .execute(&mut ***db) - .await?; - Ok(()) - } - - pub async fn set_twofa_enabled( - &mut self, - enabled: bool, - db: &mut Connection, - ) -> Result<(), sqlx::Error> { - self.twofa_enabled = enabled; - sqlx::query!( - "UPDATE users SET twofa_enabled = $1 WHERE id = $2", - self.twofa_enabled, - self.id - ) - .execute(&mut ***db) - .await?; - Ok(()) - } - - pub async fn set_pass_hash( - &mut self, - pass_hash: String, - db: &mut Connection, - ) -> Result<(), sqlx::Error> { - self.pass_hash = pass_hash; - sqlx::query!( - "UPDATE users SET pass_hash = $1 WHERE id = $2", - self.pass_hash, - self.id - ) - .execute(&mut ***db) - .await?; - Ok(()) - } -} - -#[get("/users", rank = 2)] -pub async fn users(_ag: Session, mut db: Connection) -> Json> { - sqlx::query!("SELECT id FROM users") - .fetch_all(&mut **db) - .await - .unwrap_or_else(|_| Vec::new()) - .into_iter() - .map(|row| row.id) - .collect::>() - .into() -} - -#[get("/users/", rank = 1)] -pub async fn display_name( - id: usize, - _ag: Session, - mut pgsql_conn: Connection, - mut redis_conn: Connection, -) -> String { - UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await -} - -pub struct UserCache {} - -impl UserCache { - pub async fn username( - id: usize, - redis_conn: &mut Connection, - pgsql_conn: &mut Connection, - ) -> String { - if let Ok(val) = redis_conn.get(format!("users:{id}")).await { - return val; - } - - if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32) - .fetch_one(&mut ***pgsql_conn) - .await - { - let username = v.username; - Self::insert(id, &username, redis_conn).await; - username - } else { - unimplemented!() - } - } - - pub async fn insert(id: usize, username: &str, conn: &mut Connection) { - conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800) - .await - .expect("failed to insert key"); - } -} diff --git a/backend/tests/auth_integration.rs b/backend/tests/auth_integration.rs new file mode 100644 index 0000000..5dab21f --- /dev/null +++ b/backend/tests/auth_integration.rs @@ -0,0 +1,198 @@ +use backend::rocket_builder; +use backend::repo::mock::{MockUserRepo, MockTokenRepo}; +use backend::repo::message_repo::MessageRepository; +use backend::svc::chat_svc::ChatService; +use backend::repo::user_repo::UserRepository; +use backend::repo::{Repo, AccessTokenRepoTrait}; +use rocket::local::asynchronous::Client; +use rocket::http::{Status, ContentType}; +use serde_json::json; +use std::sync::{Arc, Mutex}; +use sqlx::PgPool; +use chrono::Utc; +use backend::svc::llm_service::LlmService; + +async fn test_rocket() -> rocket::Rocket { + let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) }); + let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) }); + + let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap(); + let messages = MessageRepository::new(pool.clone()); + let user_repo = Arc::new(UserRepository::new(pool)); + let llm_service = LlmService::new(); + let chat_service = ChatService::new(32, messages, user_repo, llm_service); + + rocket_builder(users, tokens, chat_service) +} + +#[rocket::async_test] +async fn test_unauthorized_access() { + let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance"); + + // Attempt to access a protected endpoint without authentication + let response = client.patch("/api/settings/display_name").dispatch().await; + assert_eq!(response.status(), Status::Unauthorized); + + let response = client.post("/api/settings/password").dispatch().await; + assert_eq!(response.status(), Status::Unauthorized); + + let response = client.delete("/api/settings").dispatch().await; + assert_eq!(response.status(), Status::Unauthorized); +} + +#[rocket::async_test] +async fn test_signup_invalid_token() { + let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance"); + + let signup_data = json!({ + "email": "test@example.com", + "username": "testuser", + "password": "password123", + "access_token": "invalid-token" + }); + + let response = client.post("/api/signup") + .header(ContentType::JSON) + .body(signup_data.to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Unauthorized); +} + +#[rocket::async_test] +async fn test_login_invalid_credentials() { + let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance"); + + let login_data = json!({ + "username": "nonexistent", + "password": "wrongpassword" + }); + + let response = client.post("/api/login") + .header(ContentType::JSON) + .body(login_data.to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Unauthorized); +} + +#[rocket::async_test] +async fn test_full_auth_flow() { + let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) }); + let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) }); + let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap(); + let messages = MessageRepository::new(pool.clone()); + let user_repo = Arc::new(UserRepository::new(pool)); + let llm_service = LlmService::new(); + let chat_service = ChatService::new(32, messages, user_repo, llm_service); + + let token_code = "valid-token"; + tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap(); + + let client = Client::tracked(rocket_builder(users, tokens, chat_service)).await.expect("valid rocket instance"); + + // 1. Signup + let signup_data = json!({ + "email": "test@example.com", + "username": "testuser", + "password": "password123", + "access_token": token_code + }); + + let response = client.post("/api/signup") + .header(ContentType::JSON) + .body(signup_data.to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + let body = response.into_string().await.unwrap(); + assert!(body.contains("token")); + + // 2. Login + let login_data = json!({ + "username": "testuser", + "password": "password123" + }); + + let response = client.post("/api/login") + .header(ContentType::JSON) + .body(login_data.to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + let body = response.into_string().await.unwrap(); + assert!(body.contains("token")); +} + +#[rocket::async_test] +async fn test_delete_account_security() { + let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) }); + let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) }); + let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap(); + let messages = MessageRepository::new(pool.clone()); + let user_repo = Arc::new(UserRepository::new(pool)); + let llm_service = LlmService::new(); + let chat_service = ChatService::new(32, messages, user_repo, llm_service); + + let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance"); + + let token_code = "valid-token"; + tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap(); + + client.post("/api/signup") + .header(ContentType::JSON) + .body(json!({ + "email": "test@example.com", + "username": "testuser", + "password": "password123", + "access_token": token_code + }).to_string()) + .dispatch() + .await; + + // Login to get JWT + let login_res = client.post("/api/login") + .header(ContentType::JSON) + .body(json!({ + "username": "testuser", + "password": "password123" + }).to_string()) + .dispatch() + .await; + + let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap(); + let jwt = auth_resp["token"].as_str().unwrap(); + + // 1. Delete with WRONG password + let response = client.delete("/api/settings") + .header(ContentType::JSON) + .header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt))) + .body(json!({ + "password": "wrongpassword", + "totp_code": null + }).to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Unauthorized); + + // 2. Delete with CORRECT password + let response = client.delete("/api/settings") + .header(ContentType::JSON) + .header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt))) + .body(json!({ + "password": "password123", + "totp_code": null + }).to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + + // Verify user is gone + assert!(users.users.lock().unwrap().is_empty()); +} diff --git a/backend/tests/chat_integration.rs b/backend/tests/chat_integration.rs new file mode 100644 index 0000000..10578ac --- /dev/null +++ b/backend/tests/chat_integration.rs @@ -0,0 +1,142 @@ +use backend::rocket_builder; +use backend::repo::mock::{MockUserRepo, MockTokenRepo}; +use backend::repo::message_repo::MessageRepository; +use backend::svc::chat_svc::ChatService; +use backend::repo::{Repo, AccessTokenRepoTrait}; +use rocket::local::asynchronous::Client; +use rocket::http::{Status, ContentType, Header}; +use serde_json::{json, Value}; +use std::sync::{Arc, Mutex}; +use sqlx::PgPool; +use chrono::Utc; +use backend::svc::llm_service::LlmService; + +async fn setup_client_with_svc(chat_service: ChatService, users: Arc, tokens: Arc) -> (Client, String) { + let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance"); + + // Create a user and get JWT + let token_code = "valid-token"; + tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap(); + + let jwt = { + let signup_res = client.post("/api/signup") + .header(ContentType::JSON) + .body(json!({ + "email": "test@example.com", + "username": "testuser", + "password": "password123", + "access_token": token_code + }).to_string()) + .dispatch() + .await; + assert_eq!(signup_res.status(), Status::Ok); + + let login_res = client.post("/api/login") + .header(ContentType::JSON) + .body(json!({ + "username": "testuser", + "password": "password123" + }).to_string()) + .dispatch() + .await; + + assert_eq!(login_res.status(), Status::Ok, "Login failed"); + + let body = login_res.into_string().await.expect("login body"); + let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap(); + auth_resp["token"].as_str().unwrap().to_string() + }; + + (client, jwt) +} + +#[rocket::async_test] +async fn test_chat_event_stream_consistency() { + unsafe { std::env::set_var("JWT_SECRET", "test_secret"); } + let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap(); + let messages = ::new(pool.clone()); + let users_repo = Arc::new(MockUserRepo { users: Mutex::new(vec![]) }); + let tokens_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) }); + let llm_service = LlmService::new(); + let chat_service = ChatService::new(1024, messages, users_repo.clone(), llm_service); + + let (client, jwt) = setup_client_with_svc(chat_service.clone(), users_repo.clone(), tokens_repo.clone()).await; + + // Use the same client for sender but with a different user (or the same, doesn't matter for broadcast) + // Actually, to simulate another user, we should sign up another user. + let jwt_sender = { + let token_code = "valid-token-2"; + tokens_repo.create_new(1, "test2", token_code, 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await.unwrap(); + let signup_res = client.post("/api/signup") + .header(ContentType::JSON) + .body(json!({ + "email": "test2@example.com", + "username": "testuser2", + "password": "password123", + "access_token": token_code + }).to_string()) + .dispatch() + .await; + assert_eq!(signup_res.status(), Status::Ok); + let login_res = client.post("/api/login") + .header(ContentType::JSON) + .body(json!({ + "username": "testuser2", + "password": "password123" + }).to_string()) + .dispatch() + .await; + let body = login_res.into_string().await.unwrap(); + let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap(); + auth_resp["token"].as_str().unwrap().to_string() + }; + + let channel_id = 1; + + // Start listening to the event stream + let mut response = client.get(format!("/api/events/{}", channel_id)) + .header(Header::new("Authorization", format!("Bearer {}", jwt))) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + + let num_messages = 5; // Reduced for faster debugging + let mut received_count = 0; + + let jwt_clone = jwt.clone(); + + tokio::spawn(async move { + for i in 0..num_messages { + let msg = format!("Message {}", i); + let res = sender_client.post(format!("/api/chat/{}", channel_id)) + .header(ContentType::JSON) + .header(Header::new("Authorization", format!("Bearer {}", jwt_clone))) + .body(json!({ + "display_name": "testuser", + "user_id": 1, + "text": msg, + "timestamp": Utc::now() + }).to_string()) + .dispatch() + .await; + assert_eq!(res.status(), Status::Ok); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + } + }); + + // Wait a bit for messages to be posted + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + // Consume the stream + let text = response.into_string().await.unwrap(); + println!("Received chunk: {}", text); + let mut received_count = 0; + for line in text.lines() { + if line.starts_with("data:") { + received_count += 1; + } + } + + assert_eq!(received_count, num_messages, "Should receive all posted messages. Received: {}. Full text: {}", received_count, text); +} diff --git a/backend/tests/settings_integration.rs b/backend/tests/settings_integration.rs new file mode 100644 index 0000000..c74d200 --- /dev/null +++ b/backend/tests/settings_integration.rs @@ -0,0 +1,121 @@ +use backend::rocket_builder; +use backend::repo::mock::{MockUserRepo, MockTokenRepo}; +use backend::repo::message_repo::MessageRepository; +use backend::svc::chat_svc::ChatService; +use backend::repo::user_repo::UserRepository; +use backend::repo::{Repo, AccessTokenRepoTrait}; +use rocket::local::asynchronous::Client; +use rocket::http::{Status, ContentType, Header}; +use serde_json::json; +use std::sync::{Arc, Mutex}; +use sqlx::PgPool; +use chrono::Utc; +use backend::svc::llm_service::LlmService; + +async fn setup_client() -> (Client, Arc, String) { + let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) }); + let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) }); + let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap(); + let messages = MessageRepository::new(pool.clone()); + let user_repo = Arc::new(UserRepository::new(pool)); + let llm_service = LlmService::new(); + let chat_service = ChatService::new(32, messages, user_repo, llm_service); + + let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance"); + + // Create a user and get JWT + let token_code = "valid-token"; + tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap(); + + client.post("/api/signup") + .header(ContentType::JSON) + .body(json!({ + "email": "test@example.com", + "username": "testuser", + "password": "password123", + "access_token": token_code + }).to_string()) + .dispatch() + .await; + + let login_res = client.post("/api/login") + .header(ContentType::JSON) + .body(json!({ + "username": "testuser", + "password": "password123" + }).to_string()) + .dispatch() + .await; + + let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap(); + let jwt = auth_resp["token"].as_str().unwrap().to_string(); + + (client, users, jwt) +} + +#[rocket::async_test] +async fn test_change_display_name() { + let (client, users, jwt) = setup_client().await; + + let response = client.patch("/api/settings/display_name") + .header(ContentType::JSON) + .header(Header::new("Authorization", format!("Bearer {}", jwt))) + .body(json!({ + "display_name": "New Display Name" + }).to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + + let user = users.users.lock().unwrap()[0].clone(); + assert_eq!(user.nickname, Some("New Display Name".to_string())); +} + +#[rocket::async_test] +async fn test_change_username() { + let (client, users, jwt) = setup_client().await; + + let response = client.patch("/api/settings/username") + .header(ContentType::JSON) + .header(Header::new("Authorization", format!("Bearer {}", jwt))) + .body(json!({ + "username": "newusername" + }).to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + + let user = users.users.lock().unwrap()[0].clone(); + assert_eq!(user.username, "newusername"); +} + +#[rocket::async_test] +async fn test_change_password() { + let (client, _, jwt) = setup_client().await; + + let response = client.post("/api/settings/password") + .header(ContentType::JSON) + .header(Header::new("Authorization", format!("Bearer {}", jwt))) + .body(json!({ + "old_password": "password123", + "new_password": "newpassword456" + }).to_string()) + .dispatch() + .await; + + assert_eq!(response.status(), Status::Ok); + + // Verify login with new password + let login_res = client.post("/api/login") + .header(ContentType::JSON) + .body(json!({ + "username": "testuser", + "password": "newpassword456" + }).to_string()) + .dispatch() + .await; + + assert_eq!(login_res.status(), Status::Ok); +}