full backend rewrite.
calling this v0.4.0
This commit is contained in:
Generated
+10
@@ -0,0 +1,10 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Ignored default folder with query files
|
||||||
|
/queries/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
+5
-2
@@ -10,7 +10,7 @@ dotenv = "0.15.0"
|
|||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
image = "0.25.8"
|
image = "0.25.8"
|
||||||
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
|
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
|
||||||
rand = "0.9.2"
|
rand = "0.8"
|
||||||
redis = { version = "0.25.4", features = ["tokio-comp"] }
|
redis = { version = "0.25.4", features = ["tokio-comp"] }
|
||||||
reqwest = { version = "0.12.23", features = ["json"] }
|
reqwest = { version = "0.12.23", features = ["json"] }
|
||||||
rocket = { version = "0.5.1", features = ["json", "secrets"] }
|
rocket = { version = "0.5.1", features = ["json", "secrets"] }
|
||||||
@@ -20,8 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
|
|||||||
serde = { version = "1.0.228", features = ["derive"] }
|
serde = { version = "1.0.228", features = ["derive"] }
|
||||||
serde_json = "1.0.145"
|
serde_json = "1.0.145"
|
||||||
sha2 = "0.10.9"
|
sha2 = "0.10.9"
|
||||||
sqlx = { version = "0.7.4", features = ["macros", "time"] }
|
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] }
|
||||||
tokio = { version = "1.47.1", features = ["full"] }
|
tokio = { version = "1.47.1", features = ["full"] }
|
||||||
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
|
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
|
||||||
tracing = "0.1.44"
|
tracing = "0.1.44"
|
||||||
uuid = { version = "1.18.1", features = ["v4"] }
|
uuid = { version = "1.18.1", features = ["v4"] }
|
||||||
|
thiserror = "1.0.69"
|
||||||
|
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
|
||||||
|
clap = { version = "4.5", features = ["derive"] }
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
CREATE TABLE users (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
username VARCHAR(50) UNIQUE NOT NULL,
|
|
||||||
password VARCHAR(50) NOT NULL,
|
|
||||||
display_name VARCHAR(50),
|
|
||||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE channels (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
name VARCHAR(50) NOT NULL,
|
|
||||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE messages (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
|
||||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
|
||||||
content TEXT NOT NULL,
|
|
||||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
is_edited BOOLEAN DEFAULT FALSE
|
|
||||||
);
|
|
||||||
|
|
||||||
create table attachments (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
|
|
||||||
path TEXT NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX idx_messages_channel_id ON messages (channel_id, id DESC);
|
|
||||||
CREATE INDEX idx_new_messages ON messages(created_at DESC);
|
|
||||||
|
|
||||||
-- Create a function to update the updated_at timestamp
|
|
||||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
|
||||||
RETURNS TRIGGER AS $$
|
|
||||||
BEGIN
|
|
||||||
NEW.updated_at = CURRENT_TIMESTAMP;
|
|
||||||
RETURN NEW;
|
|
||||||
END;
|
|
||||||
$$ language 'plpgsql';
|
|
||||||
|
|
||||||
-- Create trigger for messages table
|
|
||||||
CREATE TRIGGER update_messages_updated_at
|
|
||||||
BEFORE UPDATE ON messages
|
|
||||||
FOR EACH ROW
|
|
||||||
EXECUTE FUNCTION update_updated_at_column();
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
CREATE TABLE sessions (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
|
||||||
token TEXT NOT NULL UNIQUE,
|
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '7 days'
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
|
|
||||||
RETURNS TRIGGER AS $$
|
|
||||||
BEGIN
|
|
||||||
DELETE FROM sessions WHERE expires_at < NOW();
|
|
||||||
RETURN NULL;
|
|
||||||
END;
|
|
||||||
$$ LANGUAGE plpgsql;
|
|
||||||
|
|
||||||
CREATE TRIGGER trigger_cleanup_sessions
|
|
||||||
AFTER INSERT ON sessions
|
|
||||||
EXECUTE FUNCTION cleanup_expired_sessions();
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
ALTER TABLE users ADD COLUMN email VARCHAR(100);
|
|
||||||
ALTER TABLE users ADD COLUMN twofa_enabled BOOLEAN DEFAULT FALSE;
|
|
||||||
ALTER TABLE users ADD COLUMN totp_secret VARCHAR(64);
|
|
||||||
ALTER TABLE users ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP;
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
ALTER TABLE users ALTER COLUMN twofa_enabled SET NOT NULL;
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
CREATE TABLE access_codes (
|
|
||||||
-- identifiers
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
creator_id INTEGER NOT NULL REFERENCES users(id),
|
|
||||||
|
|
||||||
-- code data
|
|
||||||
code VARCHAR(255) NOT NULL,
|
|
||||||
name VARCHAR(255) NOT NULL,
|
|
||||||
|
|
||||||
-- uses
|
|
||||||
uses INTEGER NOT NULL DEFAULT 0,
|
|
||||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
|
||||||
|
|
||||||
-- time data
|
|
||||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '1 day'
|
|
||||||
);
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
ALTER TABLE access_codes
|
|
||||||
ALTER COLUMN created_at
|
|
||||||
TYPE TIMESTAMP WITH TIME ZONE
|
|
||||||
USING created_at AT TIME ZONE 'UTC';
|
|
||||||
|
|
||||||
ALTER TABLE access_codes
|
|
||||||
ALTER COLUMN expires_at
|
|
||||||
TYPE TIMESTAMP WITH TIME ZONE
|
|
||||||
USING expires_at AT TIME ZONE 'UTC';
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
TRUNCATE TABLE users CASCADE;
|
|
||||||
|
|
||||||
ALTER TABLE users DROP COLUMN password;
|
|
||||||
ALTER TABLE users ADD COLUMN pass_hash VARCHAR(255) NOT NULL;
|
|
||||||
ALTER TABLE users ADD CONSTRAINT users_username_unique UNIQUE (username);
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
CREATE TYPE status AS ENUM ('pending', 'accepted', 'blocked');
|
|
||||||
|
|
||||||
CREATE TABLE relationships (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
from_user INTEGER NOT NULL REFERENCES users(id),
|
|
||||||
to_user INTEGER NOT NULL REFERENCES users(id),
|
|
||||||
status status NOT NULL DEFAULT 'pending',
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
|
|
||||||
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
|
|
||||||
);
|
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
-- Add migration script here
|
||||||
|
-- Add migration script here
|
||||||
|
CREATE TYPE user_role AS ENUM ('user', 'admin');
|
||||||
|
CREATE TYPE totp_status AS ENUM ('disabled', 'pending', 'enabled');
|
||||||
|
|
||||||
|
CREATE TABLE users (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
role user_role NOT NULL DEFAULT 'user',
|
||||||
|
|
||||||
|
-- profile
|
||||||
|
nickname VARCHAR(255),
|
||||||
|
|
||||||
|
-- basic auth
|
||||||
|
username VARCHAR(255) UNIQUE NOT NULL,
|
||||||
|
passhash VARCHAR(255) NOT NULL,
|
||||||
|
|
||||||
|
-- email
|
||||||
|
email VARCHAR(255),
|
||||||
|
email_verified BOOLEAN DEFAULT FALSE,
|
||||||
|
|
||||||
|
-- 2fa
|
||||||
|
totp_secret VARCHAR(255),
|
||||||
|
totp_status totp_status NOT NULL DEFAULT 'disabled',
|
||||||
|
|
||||||
|
-- update tracking
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
deleted_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE access_tokens (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
creator_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
|
||||||
|
code VARCHAR(255) NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
|
||||||
|
uses INTEGER NOT NULL DEFAULT 0,
|
||||||
|
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||||
|
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '24 hours',
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE refresh_tokens (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
token_hash VARCHAR(255) NOT NULL,
|
||||||
|
|
||||||
|
revoked BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '7 days'
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE spaces (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
|
||||||
|
owner_id BIGINT NOT NULL REFERENCES users(id),
|
||||||
|
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE channels (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
|
||||||
|
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
|
||||||
|
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE space_members (
|
||||||
|
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
|
||||||
|
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
|
||||||
|
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
role user_role DEFAULT 'user',
|
||||||
|
|
||||||
|
PRIMARY KEY (space_id, user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE messages (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||||
|
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
is_edited BOOLEAN DEFAULT FALSE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE attachments (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
message_id BIGINT NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
|
||||||
|
filename VARCHAR(255) NOT NULL,
|
||||||
|
content_type VARCHAR(100) NOT NULL, -- mime type e.g. image/png, video/mp4
|
||||||
|
size_bytes BIGINT NOT NULL,
|
||||||
|
url TEXT NOT NULL, -- path to file on your CDN/storage
|
||||||
|
width INTEGER, -- null for non-image/video
|
||||||
|
height INTEGER, -- null for non-image/video
|
||||||
|
duration_ms INTEGER, -- null for non-audio/video
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TYPE relationship_status AS ENUM ('pending', 'accepted', 'blocked');
|
||||||
|
|
||||||
|
CREATE TABLE relationships (
|
||||||
|
id BIGSERIAL PRIMARY KEY NOT NULL,
|
||||||
|
|
||||||
|
from_user BIGINT NOT NULL REFERENCES users(id),
|
||||||
|
to_user BIGINT NOT NULL REFERENCES users(id),
|
||||||
|
status relationship_status NOT NULL DEFAULT 'pending',
|
||||||
|
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
|
||||||
|
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX idx_messages_channel_id ON messages(channel_id, created_at DESC);
|
||||||
|
CREATE INDEX idx_messages_user_id ON messages(user_id);
|
||||||
|
CREATE INDEX idx_attachments_message ON attachments(message_id);
|
||||||
|
CREATE INDEX idx_channels_space_id ON channels(space_id);
|
||||||
|
CREATE INDEX idx_space_members_user ON space_members(user_id);
|
||||||
|
CREATE INDEX idx_refresh_tokens_hash ON refresh_tokens(token_hash);
|
||||||
|
CREATE INDEX idx_relationships_from ON relationships(from_user, to_user);
|
||||||
|
CREATE INDEX idx_relationships_to ON relationships(to_user);
|
||||||
|
CREATE INDEX idx_access_tokens_code ON access_tokens(code);
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION update_updated_at()
|
||||||
|
RETURNS TRIGGER AS $$
|
||||||
|
BEGIN
|
||||||
|
NEW.updated_at = NOW();
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
CREATE TRIGGER users_updated_at
|
||||||
|
BEFORE UPDATE ON users
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
|
||||||
|
|
||||||
|
CREATE TRIGGER spaces_updated_at
|
||||||
|
BEFORE UPDATE ON spaces
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
|
||||||
|
|
||||||
|
CREATE TRIGGER channels_updated_at
|
||||||
|
BEFORE UPDATE ON channels
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
|
||||||
|
|
||||||
|
CREATE TRIGGER messages_updated_at
|
||||||
|
BEFORE UPDATE ON messages
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
|
||||||
|
|
||||||
|
CREATE TRIGGER relationships_updated_at
|
||||||
|
BEFORE UPDATE ON relationships
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
|
||||||
|
|
||||||
|
CREATE TRIGGER access_tokens_updated_at
|
||||||
|
BEFORE UPDATE ON access_tokens
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION add_owner_to_space()
|
||||||
|
RETURNS TRIGGER AS $$
|
||||||
|
BEGIN
|
||||||
|
INSERT INTO space_members (space_id, user_id, role, joined_at)
|
||||||
|
VALUES (NEW.id, NEW.owner_id, 'admin', NOW());
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
CREATE TRIGGER space_owner_becomes_member
|
||||||
|
AFTER INSERT ON spaces
|
||||||
|
FOR EACH ROW EXECUTE FUNCTION add_owner_to_space();
|
||||||
@@ -1,21 +1,46 @@
|
|||||||
use std::{
|
use crate::error::ApiResult;
|
||||||
sync::LazyLock,
|
use crate::model::auth::{AccessTokenForm, AuthResponse, LoginCredentials, SignupCredentials};
|
||||||
time::{SystemTime, UNIX_EPOCH},
|
use crate::svc::auth_svc::AuthService;
|
||||||
};
|
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||||
|
use rocket::http::Status;
|
||||||
|
use rocket::request::{FromRequest, Outcome};
|
||||||
|
use rocket::serde::json::Json;
|
||||||
|
use rocket::serde::{Deserialize, Serialize};
|
||||||
|
use rocket::{Request, State};
|
||||||
|
use std::sync::LazyLock;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
use crate::svc::access_token_svc::AccessTokenService;
|
||||||
|
|
||||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
#[post("/signup", data = "<cred>")]
|
||||||
use rand::Rng;
|
pub async fn signup(
|
||||||
use rocket::{
|
cred: Json<SignupCredentials>,
|
||||||
Request,
|
svc: &State<AuthService>
|
||||||
http::Status,
|
) -> ApiResult<Json<AuthResponse>> {
|
||||||
request::{self, FromRequest, Outcome},
|
let response = svc
|
||||||
};
|
.signup(
|
||||||
use rocket_db_pools::Connection;
|
&cred.email, &cred.username, &cred.password, &cred.access_token,
|
||||||
use serde::{Deserialize, Serialize};
|
).await?;
|
||||||
use sha2::{Digest, Sha256, digest::block_buffer::Lazy};
|
Ok(Json(response))
|
||||||
use sqlx::postgres::PgQueryResult;
|
}
|
||||||
|
|
||||||
use crate::db::Postgres;
|
#[post("/login", data = "<cred>")]
|
||||||
|
pub async fn login(
|
||||||
|
cred: Json<LoginCredentials>,
|
||||||
|
svc: &State<AuthService>
|
||||||
|
) -> ApiResult<Json<AuthResponse>> {
|
||||||
|
Ok(Json(svc.login(&cred.username, &cred.password).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/invite", data = "<form>")]
|
||||||
|
pub async fn generate_invite(
|
||||||
|
session: Session,
|
||||||
|
form: Json<AccessTokenForm>,
|
||||||
|
svc: &State<AccessTokenService>
|
||||||
|
) -> ApiResult<String> {
|
||||||
|
svc.create(
|
||||||
|
session.uid, &form.name, form.max_uses,
|
||||||
|
form.start_date, form.expiry_date).await
|
||||||
|
}
|
||||||
|
|
||||||
static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap());
|
static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap());
|
||||||
|
|
||||||
@@ -27,7 +52,7 @@ pub enum TokenScope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
pub user_id: usize,
|
pub uid: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[rocket::async_trait]
|
#[rocket::async_trait]
|
||||||
@@ -37,7 +62,7 @@ impl<'r> FromRequest<'r> for Session {
|
|||||||
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match Claims::from_request(req).await {
|
match Claims::from_request(req).await {
|
||||||
Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session {
|
Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session {
|
||||||
user_id: user.sub as usize,
|
uid: user.sub as i64,
|
||||||
}),
|
}),
|
||||||
Outcome::Success(_) => {
|
Outcome::Success(_) => {
|
||||||
eprintln!("warning: user with scope other than Full attempted to access session");
|
eprintln!("warning: user with scope other than Full attempted to access session");
|
||||||
@@ -106,4 +131,4 @@ impl<'r> FromRequest<'r> for Claims {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -26,7 +26,7 @@ pub async fn profile_pic(user_id: usize) -> Option<NamedFile> {
|
|||||||
Some(image)
|
Some(image)
|
||||||
} else {
|
} else {
|
||||||
Some(
|
Some(
|
||||||
NamedFile::open("./cdn/profiles/full/default.svg")
|
NamedFile::open("../../cdn/profiles/full/default.svg")
|
||||||
.await
|
.await
|
||||||
.ok()?,
|
.ok()?,
|
||||||
)
|
)
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
use crate::api::auth::Session;
|
||||||
|
use crate::error::ApiResult;
|
||||||
|
use crate::svc::chat_svc::ChatService;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rocket::response::stream::Event;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
|
use rocket::serde::{Deserialize, Serialize};
|
||||||
|
use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
|
||||||
|
use sqlx::FromRow;
|
||||||
|
use tokio::select;
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
|
||||||
|
/// ---------- Rocket routes ----------
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
|
||||||
|
pub struct ChatMsg {
|
||||||
|
pub display_name: Option<String>,
|
||||||
|
pub user_id: i64,
|
||||||
|
pub text: String,
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
|
||||||
|
pub async fn post_message(
|
||||||
|
msg: Json<ChatMsg>,
|
||||||
|
chat: &State<ChatService>,
|
||||||
|
session: Session,
|
||||||
|
channel_id: i64,
|
||||||
|
) -> ApiResult<()> {
|
||||||
|
chat.send(channel_id, session.uid, &msg.text, Utc::now()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/events/<channel_id>")]
|
||||||
|
pub async fn event_stream(
|
||||||
|
chat: &State<ChatService>,
|
||||||
|
s: Session,
|
||||||
|
mut shutdown: Shutdown,
|
||||||
|
channel_id: i64,
|
||||||
|
) -> ApiResult<EventStream![]> {
|
||||||
|
let messages = chat.get_messages(channel_id, 100)
|
||||||
|
.await?; // if get message returned err, inform user.
|
||||||
|
|
||||||
|
let mut rx = chat.subscribe(channel_id).await;
|
||||||
|
let id = s.uid;
|
||||||
|
|
||||||
|
Ok(EventStream! {
|
||||||
|
for msg in messages {
|
||||||
|
yield Event::json(&msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
select!{
|
||||||
|
_ = &mut shutdown => break, // exit early on shutdown
|
||||||
|
msg = rx.recv() => match msg {
|
||||||
|
Ok(msg) => {
|
||||||
|
tracing::info!("yielding message!");
|
||||||
|
yield Event::json(&msg)
|
||||||
|
},
|
||||||
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||||
|
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
|
||||||
|
yield Event::comment("RecvError::Lagged");
|
||||||
|
}
|
||||||
|
Err(broadcast::error::RecvError::Closed) => {
|
||||||
|
tracing::info!("Broadcaster hung up on channel {channel_id}!");
|
||||||
|
break
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
pub mod auth;
|
||||||
|
pub mod chat;
|
||||||
|
pub mod totp;
|
||||||
|
pub mod settings;
|
||||||
|
pub mod cdn;
|
||||||
|
pub mod profile;
|
||||||
|
pub mod space;
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
use rocket::State;
|
||||||
|
use crate::api::auth::Session;
|
||||||
|
use crate::error::ApiResult;
|
||||||
|
use crate::svc::user_svc::UserService;
|
||||||
|
|
||||||
|
#[get("/users/<id>")]
|
||||||
|
pub async fn display_name(
|
||||||
|
id: i64,
|
||||||
|
_ag: Session,
|
||||||
|
svc: &State<UserService>,
|
||||||
|
) -> ApiResult<String> {
|
||||||
|
svc.get_username(id).await
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
use crate::api::auth::Session;
|
||||||
|
use crate::error::ApiResult;
|
||||||
|
use crate::svc::settings_svc::SettingsService;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
|
use rocket::serde::{Deserialize, Serialize};
|
||||||
|
use rocket::State;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PasswordForm {
|
||||||
|
pub old_password: String,
|
||||||
|
pub new_password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/settings/password", data = "<form>")]
|
||||||
|
pub async fn change_password(
|
||||||
|
session: Session,
|
||||||
|
form: Json<PasswordForm>,
|
||||||
|
settings: &State<SettingsService>
|
||||||
|
) -> ApiResult<()> {
|
||||||
|
settings.change_password(
|
||||||
|
session.uid, &form.old_password, &form.new_password
|
||||||
|
).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
|
pub struct DisplayNameForm {
|
||||||
|
pub display_name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
|
pub struct PasswordAnd2faForm {
|
||||||
|
pub password: String,
|
||||||
|
pub totp_code: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[delete("/settings", data = "<data>")]
|
||||||
|
pub async fn delete_account(
|
||||||
|
session: Session,
|
||||||
|
data: Json<PasswordAnd2faForm>,
|
||||||
|
settings: &State<SettingsService>
|
||||||
|
) -> ApiResult<()> {
|
||||||
|
settings.delete_account(
|
||||||
|
session.uid, &data.password, &data.totp_code
|
||||||
|
).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[patch("/settings/display_name", data = "<new>")]
|
||||||
|
pub async fn change_display_name(
|
||||||
|
session: Session,
|
||||||
|
new: Json<DisplayNameForm>,
|
||||||
|
settings: &State<SettingsService>
|
||||||
|
) -> ApiResult<()> {
|
||||||
|
settings.change_display_name(session.uid, new.display_name.clone()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct UsernameForm {
|
||||||
|
pub username: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[patch("/settings/username", data = "<new>")]
|
||||||
|
pub async fn change_username(
|
||||||
|
session: Session,
|
||||||
|
new: Json<UsernameForm>,
|
||||||
|
settings: &State<SettingsService>
|
||||||
|
) -> ApiResult<()> {
|
||||||
|
settings.change_username(session.uid, &new.username).await
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
use crate::error::ApiResult;
|
||||||
|
use crate::model::space::{Space, SpaceDto};
|
||||||
|
use crate::model::space::Channel;
|
||||||
|
use crate::repo::{SpaceRepo, ChannelRepo};
|
||||||
|
use rocket::serde::json::Json;
|
||||||
|
use rocket::State;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::api::auth::Session;
|
||||||
|
use crate::svc::chat_svc::ChatService;
|
||||||
|
|
||||||
|
#[get("/spaces")]
|
||||||
|
pub async fn list_spaces(
|
||||||
|
space_repo: &State<Arc<dyn SpaceRepo>>
|
||||||
|
) -> ApiResult<Json<Vec<Space>>> {
|
||||||
|
let spaces = space_repo.get_all().await?;
|
||||||
|
Ok(Json(spaces))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/spaces/<space_id>/channels")]
|
||||||
|
pub async fn list_channels(
|
||||||
|
space_id: i64,
|
||||||
|
channel_repo: &State<Arc<dyn ChannelRepo>>
|
||||||
|
) -> ApiResult<Json<Vec<Channel>>> {
|
||||||
|
let channels = channel_repo.get_by_space_id(space_id).await?;
|
||||||
|
Ok(Json(channels))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/accessible_channels")]
|
||||||
|
pub async fn get_accessible_channels(
|
||||||
|
session: Session,
|
||||||
|
svc: &State<ChatService>
|
||||||
|
) -> ApiResult<Json<Vec<SpaceDto>>> {
|
||||||
|
let space = svc.get_accessible_channels(session.uid).await?;
|
||||||
|
println!("{:?}", space);
|
||||||
|
Ok(Json(space))
|
||||||
|
}
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
use crate::api::auth::{Claims, Session, TokenScope};
|
||||||
|
use crate::error::{ApiResult, AppError};
|
||||||
|
use crate::model::auth::AuthResponse;
|
||||||
|
use crate::svc::auth_svc::AuthService;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
|
use rocket::serde::{Deserialize, Serialize};
|
||||||
|
use rocket::State;
|
||||||
|
use totp_rs::{Algorithm, TOTP};
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct TOTPSixDigitCode {
|
||||||
|
code: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, sqlx::Type, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
#[sqlx(type_name = "totp_status", rename_all = "lowercase")]
|
||||||
|
pub enum TotpStatus {
|
||||||
|
Enabled,
|
||||||
|
Pending,
|
||||||
|
Disabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct QrResponse {
|
||||||
|
qr_code: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct TotpVerifyRequest {
|
||||||
|
pub code: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct PasswordConfirmation {
|
||||||
|
password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct PasswordAnd2fa {
|
||||||
|
pub password: String,
|
||||||
|
pub totp_code: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn totp_gen(user_id: i64, secret: &[u8]) -> ApiResult<TOTP> {
|
||||||
|
TOTP::new(
|
||||||
|
Algorithm::SHA1,
|
||||||
|
6,
|
||||||
|
1,
|
||||||
|
30,
|
||||||
|
secret.to_owned(),
|
||||||
|
Some("chat.zxq5.dev".to_string()),
|
||||||
|
format!("{}", user_id),
|
||||||
|
)
|
||||||
|
.map_err(|_| AppError::internal("failed to generate totp"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/totp", data = "<form>")]
|
||||||
|
pub async fn confirm_totp(
|
||||||
|
user: Session,
|
||||||
|
form: Json<TOTPSixDigitCode>,
|
||||||
|
svc: &State<AuthService>,
|
||||||
|
) -> ApiResult<()> {
|
||||||
|
svc.confirm_totp(user.uid, &form.code).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/totp.jpg", data = "<form>")]
|
||||||
|
pub async fn get_totp(
|
||||||
|
user: Session,
|
||||||
|
form: Json<PasswordConfirmation>,
|
||||||
|
svc: &State<AuthService>,
|
||||||
|
) -> ApiResult<Json<QrResponse>> {
|
||||||
|
let secret = svc.get_or_create_totp_secret(user.uid, &form.password).await?;
|
||||||
|
|
||||||
|
let qr_b64 = totp_gen(user.uid, secret.as_bytes())
|
||||||
|
.map_err(|_| AppError::internal("invalid totp secret"))?
|
||||||
|
.get_qr_base64()
|
||||||
|
.map_err(|_| AppError::internal("failed to generate qr code"))?;
|
||||||
|
|
||||||
|
Ok(Json(QrResponse {
|
||||||
|
qr_code: format!("data:image/png;base64,{}", qr_b64),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/totp/status")]
|
||||||
|
pub async fn get_totp_status(
|
||||||
|
user: Session,
|
||||||
|
svc: &State<AuthService>,
|
||||||
|
) -> ApiResult<Json<TotpStatus>> {
|
||||||
|
Ok(Json(
|
||||||
|
svc.get_totp_status(user.uid).await?
|
||||||
|
.then_some(TotpStatus::Enabled)
|
||||||
|
.unwrap_or(TotpStatus::Disabled),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[delete("/totp", data = "<form>")]
|
||||||
|
pub async fn disable_totp(
|
||||||
|
user: Session,
|
||||||
|
form: Json<PasswordAnd2fa>,
|
||||||
|
svc: &State<AuthService>,
|
||||||
|
) -> ApiResult<Json<AuthResponse>> {
|
||||||
|
let response = svc.disable_totp(user.uid, &form.password, &form.totp_code).await?;
|
||||||
|
Ok(Json(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/totp/verify", data = "<body>")]
|
||||||
|
pub async fn verify_totp(
|
||||||
|
claims: Claims,
|
||||||
|
body: Json<TotpVerifyRequest>,
|
||||||
|
svc: &State<AuthService>,
|
||||||
|
) -> ApiResult<Json<AuthResponse>> {
|
||||||
|
// reject if they somehow got here with a full token
|
||||||
|
if claims.scope != TokenScope::TotpPending {
|
||||||
|
return Err(AppError::Forbidden);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = svc.login_totp(claims.sub as i64, &body.code).await?;
|
||||||
|
Ok(Json(response))
|
||||||
|
}
|
||||||
@@ -1,211 +0,0 @@
|
|||||||
use argon2::{
|
|
||||||
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
|
|
||||||
password_hash::{SaltString, rand_core::OsRng},
|
|
||||||
};
|
|
||||||
use jsonwebtoken::{EncodingKey, Header, encode};
|
|
||||||
use rocket::{
|
|
||||||
http::{CookieJar, Status},
|
|
||||||
response::{Redirect, status::BadRequest},
|
|
||||||
serde::json::Json,
|
|
||||||
time::OffsetDateTime,
|
|
||||||
};
|
|
||||||
use rocket_db_pools::Connection;
|
|
||||||
use rocket_dyn_templates::{Template, context};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
auth::session::{Claims, Session, TokenScope},
|
|
||||||
db::Postgres,
|
|
||||||
user::User,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct SignupCredentials {
|
|
||||||
pub email: String,
|
|
||||||
pub username: String,
|
|
||||||
pub password: String,
|
|
||||||
pub access_token: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct LoginCredentials {
|
|
||||||
pub username: String,
|
|
||||||
pub password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct AuthResponse {
|
|
||||||
pub token: String,
|
|
||||||
pub totp_required: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/signup")]
|
|
||||||
pub async fn signup_page() -> Template {
|
|
||||||
Template::render("signup", context!())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/signup", data = "<cred>")]
|
|
||||||
pub async fn signup(
|
|
||||||
cred: Json<SignupCredentials>,
|
|
||||||
jar: &CookieJar<'_>,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
) -> Result<Json<AuthResponse>, Status> {
|
|
||||||
let token_id = AccessToken::validate(&cred.access_token, &mut db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Status::Unauthorized)?;
|
|
||||||
|
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
|
||||||
let hashed = Argon2::default()
|
|
||||||
.hash_password(cred.password.as_bytes(), &salt)
|
|
||||||
.map_err(|_| Status::InternalServerError)?
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let result = sqlx::query!(
|
|
||||||
"INSERT INTO users (email, username, pass_hash) VALUES ($1, $2, $3) RETURNING id",
|
|
||||||
cred.email,
|
|
||||||
cred.username,
|
|
||||||
hashed,
|
|
||||||
)
|
|
||||||
.fetch_one(&mut **db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
let jwt = Claims::new(result.id as usize, TokenScope::Full).encode();
|
|
||||||
|
|
||||||
token_id
|
|
||||||
.use_token(&mut db)
|
|
||||||
.await
|
|
||||||
.expect("unable to use access code");
|
|
||||||
|
|
||||||
Ok(Json(AuthResponse {
|
|
||||||
token: jwt,
|
|
||||||
totp_required: false,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/login")]
|
|
||||||
pub async fn login_page() -> Template {
|
|
||||||
Template::render("login", context!())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/login", data = "<cred>")]
|
|
||||||
pub async fn login(
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
cred: Json<LoginCredentials>,
|
|
||||||
) -> Result<Json<AuthResponse>, Status> {
|
|
||||||
println!("e");
|
|
||||||
let row = sqlx::query!(
|
|
||||||
"SELECT id, pass_hash, twofa_enabled FROM users WHERE username = $1",
|
|
||||||
cred.username,
|
|
||||||
)
|
|
||||||
.fetch_one(&mut **db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Status::Unauthorized)?;
|
|
||||||
|
|
||||||
println!("ok");
|
|
||||||
|
|
||||||
// verify password as before
|
|
||||||
let parsed_hash = PasswordHash::new(&row.pass_hash).map_err(|_| Status::InternalServerError)?;
|
|
||||||
Argon2::default()
|
|
||||||
.verify_password(cred.password.as_bytes(), &parsed_hash)
|
|
||||||
.map_err(|_| Status::Unauthorized)?;
|
|
||||||
|
|
||||||
println!("ok2");
|
|
||||||
|
|
||||||
let user_id = row.id as usize;
|
|
||||||
|
|
||||||
// issue either a partial or full token depending on 2FA status
|
|
||||||
let (session, totp_required) = if row.twofa_enabled {
|
|
||||||
(Claims::new(user_id, TokenScope::TotpPending), true)
|
|
||||||
} else {
|
|
||||||
(Claims::new(user_id, TokenScope::Full), false)
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Json(AuthResponse {
|
|
||||||
token: session.encode(),
|
|
||||||
totp_required,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct AccessTokenForm {
|
|
||||||
pub name: String,
|
|
||||||
pub max_uses: usize,
|
|
||||||
pub expiry_date: usize,
|
|
||||||
pub start_date: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/invite")]
|
|
||||||
pub async fn invite_page(_s: Session) -> Template {
|
|
||||||
Template::render("invite", context! {})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/invite", data = "<form>")]
|
|
||||||
pub async fn generate_invite(
|
|
||||||
session: Session,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
form: Json<AccessTokenForm>,
|
|
||||||
) -> Result<String, Status> {
|
|
||||||
if form.start_date > form.expiry_date {
|
|
||||||
return Err(Status::BadRequest);
|
|
||||||
}
|
|
||||||
|
|
||||||
let code = Uuid::new_v4().to_string();
|
|
||||||
sqlx::query!(
|
|
||||||
"INSERT INTO access_codes (name, code, creator_id, max_uses, created_at, expires_at)
|
|
||||||
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
|
|
||||||
form.name,
|
|
||||||
code,
|
|
||||||
session.user_id as i32,
|
|
||||||
form.max_uses as i32,
|
|
||||||
OffsetDateTime::from_unix_timestamp_nanos(form.start_date as i128 * 1_000_000).unwrap(),
|
|
||||||
OffsetDateTime::from_unix_timestamp_nanos(form.expiry_date as i128 * 1_000_000).unwrap()
|
|
||||||
)
|
|
||||||
.fetch_one(&mut **db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
Ok(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct AccessToken {
|
|
||||||
id: i32,
|
|
||||||
_code: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AccessToken {
|
|
||||||
pub async fn validate(
|
|
||||||
token: &str,
|
|
||||||
db: &mut Connection<Postgres>,
|
|
||||||
) -> Result<AccessToken, String> {
|
|
||||||
match sqlx::query!(
|
|
||||||
"SELECT id FROM access_codes
|
|
||||||
WHERE code = $1
|
|
||||||
AND created_at < NOW()
|
|
||||||
AND expires_at > NOW()
|
|
||||||
AND uses < max_uses",
|
|
||||||
token
|
|
||||||
)
|
|
||||||
.fetch_one(&mut ***db)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(row) => Ok(AccessToken {
|
|
||||||
id: row.id,
|
|
||||||
_code: token.to_string(),
|
|
||||||
}),
|
|
||||||
Err(_) => Err(String::from("Invalid or Expired token!")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn use_token(&self, db: &mut Connection<Postgres>) -> Result<(), String> {
|
|
||||||
sqlx::query!(
|
|
||||||
"UPDATE access_codes SET uses = uses + 1 WHERE id = $1",
|
|
||||||
self.id
|
|
||||||
)
|
|
||||||
.execute(&mut ***db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| String::from("Invalid or Expired token!"))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
pub mod account;
|
|
||||||
pub mod profile;
|
|
||||||
pub mod session;
|
|
||||||
pub mod two_factor;
|
|
||||||
|
|
||||||
pub use session::Session;
|
|
||||||
|
|
||||||
pub use account::{generate_invite, invite_page, login, login_page, signup, signup_page};
|
|
||||||
pub use profile::{change_display_name, change_password, change_username, delete_account};
|
|
||||||
pub use two_factor::{
|
|
||||||
confirm_totp, disable_totp, get_totp, get_totp_status, mfa_page, verify_totp,
|
|
||||||
};
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
use argon2::{
|
|
||||||
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
|
|
||||||
password_hash::{SaltString, rand_core::OsRng},
|
|
||||||
};
|
|
||||||
use rocket::{http::Status, serde::json::Json};
|
|
||||||
use rocket_db_pools::Connection;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::{auth::Session, db::Postgres, user::User};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct PasswordForm {
|
|
||||||
old_password: String,
|
|
||||||
new_password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/settings/password", data = "<form>")]
|
|
||||||
pub async fn change_password(
|
|
||||||
session: Session,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
form: Json<PasswordForm>,
|
|
||||||
) -> Result<(), Status> {
|
|
||||||
let mut user = User::get_by_id(session.user_id, &mut db)
|
|
||||||
.await
|
|
||||||
.ok_or(Status::NotFound)
|
|
||||||
.inspect_err(|_| {
|
|
||||||
tracing::error!(
|
|
||||||
"Valid session does not have a valid user. ID: {}",
|
|
||||||
session.user_id
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
user.verify_password(&form.old_password)?;
|
|
||||||
|
|
||||||
// old password is correct, so new one can be set.
|
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
|
||||||
let hashed = Argon2::default()
|
|
||||||
.hash_password(form.new_password.as_bytes(), &salt)
|
|
||||||
.inspect_err(|e| tracing::error!("failed to hash password! {e}"))
|
|
||||||
.map_err(|_| Status::InternalServerError)?
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
user.set_pass_hash(hashed, &mut db)
|
|
||||||
.await
|
|
||||||
.inspect_err(|e| tracing::error!("{e}"))
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone)]
|
|
||||||
pub struct DisplayNameForm {
|
|
||||||
pub display_name: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone)]
|
|
||||||
pub struct PasswordAnd2fa {
|
|
||||||
pub password: String,
|
|
||||||
pub totp_code: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[delete("/settings", data = "<data>")]
|
|
||||||
pub async fn delete_account(
|
|
||||||
session: Session,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
data: Json<PasswordAnd2fa>,
|
|
||||||
) -> Result<(), Status> {
|
|
||||||
let mut user = User::get_by_id(session.user_id, &mut db)
|
|
||||||
.await
|
|
||||||
.ok_or(Status::NotFound)
|
|
||||||
.inspect_err(|_| {
|
|
||||||
tracing::error!(
|
|
||||||
"Valid session does not have a valid user. ID: {}",
|
|
||||||
session.user_id
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
user.verify_password(&data.password)?;
|
|
||||||
|
|
||||||
if user.twofa_enabled {
|
|
||||||
user.verify_2fa(data.totp_code.as_deref().unwrap_or(""))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
user.delete(&mut db)
|
|
||||||
.await
|
|
||||||
.inspect_err(|e| tracing::error!("{e}"))
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[patch("/settings/display_name", data = "<new>")]
|
|
||||||
pub async fn change_display_name(
|
|
||||||
session: Session,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
new: Json<DisplayNameForm>,
|
|
||||||
) -> Result<(), Status> {
|
|
||||||
let mut user = User::get_by_id(session.user_id, &mut db)
|
|
||||||
.await
|
|
||||||
.ok_or(Status::NotFound)
|
|
||||||
.inspect_err(|_| {
|
|
||||||
tracing::error!(
|
|
||||||
"Valid session does not have a valid user. ID: {}",
|
|
||||||
session.user_id
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
user.set_display_name(new.display_name.clone(), &mut db)
|
|
||||||
.await
|
|
||||||
.inspect_err(|e| tracing::error!("{e}"))
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct UsernameForm {
|
|
||||||
username: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[patch("/settings/username", data = "<new>")]
|
|
||||||
pub async fn change_username(
|
|
||||||
session: Session,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
new: Json<UsernameForm>,
|
|
||||||
) -> Result<(), Status> {
|
|
||||||
let mut user = User::get_by_id(session.user_id, &mut db)
|
|
||||||
.await
|
|
||||||
.ok_or(Status::NotFound)
|
|
||||||
.inspect_err(|_| {
|
|
||||||
tracing::error!(
|
|
||||||
"Valid session does not have a valid user. ID: {}",
|
|
||||||
session.user_id
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
user.set_username(new.username.clone(), &mut db)
|
|
||||||
.await
|
|
||||||
.inspect_err(|e| tracing::error!("{e}"))
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@@ -1,301 +0,0 @@
|
|||||||
use futures_util::TryFutureExt;
|
|
||||||
use rocket::{
|
|
||||||
Request,
|
|
||||||
http::Status,
|
|
||||||
outcome::{Outcome, try_outcome},
|
|
||||||
request::{self, FromRequest},
|
|
||||||
response::status::{self},
|
|
||||||
serde::json::Json,
|
|
||||||
};
|
|
||||||
use rocket_db_pools::Connection;
|
|
||||||
use rocket_dyn_templates::{Template, context};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use totp_rs::{Algorithm, Secret, TOTP};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
auth::{
|
|
||||||
account::AuthResponse,
|
|
||||||
profile::PasswordAnd2fa,
|
|
||||||
session::{Claims, Session, TokenScope},
|
|
||||||
},
|
|
||||||
db::Postgres,
|
|
||||||
user::User,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Utility methods
|
|
||||||
|
|
||||||
pub fn totp_gen(user_id: usize, secret: &[u8]) -> Result<TOTP, String> {
|
|
||||||
TOTP::new(
|
|
||||||
Algorithm::SHA1,
|
|
||||||
6,
|
|
||||||
1,
|
|
||||||
30,
|
|
||||||
secret.to_owned(),
|
|
||||||
Some("chat.zxq5.dev".to_string()),
|
|
||||||
format!("{}", user_id),
|
|
||||||
)
|
|
||||||
.map_err(|_| String::from("Invalid Secret"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// pages
|
|
||||||
|
|
||||||
#[get("/totp")]
|
|
||||||
pub async fn mfa_page(_session: Session) -> Template {
|
|
||||||
Template::render("2fa", context!())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/totp", data = "<form>")]
|
|
||||||
pub async fn confirm_totp(
|
|
||||||
mfa: TOTPSecret,
|
|
||||||
form: Json<TOTPSixDigitCode>,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
) -> Result<(), status::Custom<&'static str>> {
|
|
||||||
if form.code.len() != 6 || form.code.parse::<u32>().is_err() {
|
|
||||||
return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code"));
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("valid");
|
|
||||||
|
|
||||||
let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes())
|
|
||||||
.map_err(|_| status::Custom(Status::InternalServerError, "TOTP Error"))?;
|
|
||||||
if !totp.check_current(&form.code).unwrap_or(false) {
|
|
||||||
return Err(status::Custom(Status::BadRequest, "Incorrect code"));
|
|
||||||
}
|
|
||||||
println!("correct");
|
|
||||||
|
|
||||||
if sqlx::query!(
|
|
||||||
"UPDATE users SET twofa_enabled = true WHERE id = $1",
|
|
||||||
mfa.user_id as i32
|
|
||||||
)
|
|
||||||
.execute(&mut **db)
|
|
||||||
.await
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
return Err(status::Custom(
|
|
||||||
Status::InternalServerError,
|
|
||||||
"unable to enable 2fa",
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("enabled");
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct PasswordConfirmation {
|
|
||||||
password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/totp.jpg", data = "<form>")]
|
|
||||||
pub async fn get_totp(
|
|
||||||
mfa: TOTPSecret,
|
|
||||||
form: Json<PasswordConfirmation>,
|
|
||||||
) -> Option<Json<QrResponse>> {
|
|
||||||
let qr_b64 = totp_gen(mfa.user_id, mfa.secret.as_bytes())
|
|
||||||
.expect("Invalid TOTP")
|
|
||||||
.get_qr_base64()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Some(Json(QrResponse {
|
|
||||||
qr_code: format!("data:image/png;base64,{}", qr_b64),
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct TOTPSixDigitCode {
|
|
||||||
code: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum TotpStatus {
|
|
||||||
Enabled,
|
|
||||||
Disabled,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct TOTPSecret {
|
|
||||||
user_id: usize,
|
|
||||||
secret: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct QrResponse {
|
|
||||||
qr_code: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[rocket::async_trait]
|
|
||||||
impl<'r> FromRequest<'r> for TOTPSecret {
|
|
||||||
type Error = ();
|
|
||||||
|
|
||||||
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
|
||||||
let auth_header = request.headers().get_one("Authorization");
|
|
||||||
println!(
|
|
||||||
"TOTPSecret guard - Auth header present: {}",
|
|
||||||
auth_header.is_some()
|
|
||||||
);
|
|
||||||
|
|
||||||
let user = try_outcome!(request.guard::<Claims>().await);
|
|
||||||
println!(
|
|
||||||
"TOTPSecret guard - Claims ok, user: {}, scope: {:?}",
|
|
||||||
user.sub, user.scope
|
|
||||||
);
|
|
||||||
|
|
||||||
// only allow full tokens for TOTP setup
|
|
||||||
if user.scope != TokenScope::Full {
|
|
||||||
println!("TOTPSecret guard - rejected, scope is {:?}", user.scope);
|
|
||||||
return Outcome::Error((Status::Forbidden, ()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let user = try_outcome!(request.guard::<Session>().await);
|
|
||||||
let mut pool = match request.guard::<Connection<Postgres>>().await {
|
|
||||||
Outcome::Success(pool) => pool,
|
|
||||||
_ => return Outcome::Error((Status::Unauthorized, ())),
|
|
||||||
};
|
|
||||||
|
|
||||||
let row = sqlx::query!(
|
|
||||||
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
|
|
||||||
user.user_id as i32
|
|
||||||
)
|
|
||||||
.fetch_one(&mut **pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let (enabled, mut secret) = match row {
|
|
||||||
Ok(r) => (r.twofa_enabled, r.totp_secret),
|
|
||||||
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
|
|
||||||
};
|
|
||||||
|
|
||||||
if secret.is_none() {
|
|
||||||
let new_secret = Secret::generate_secret().to_encoded().to_string();
|
|
||||||
sqlx::query!(
|
|
||||||
"UPDATE users SET totp_secret = $1 WHERE id = $2",
|
|
||||||
new_secret,
|
|
||||||
user.user_id as i32
|
|
||||||
)
|
|
||||||
.execute(&mut **pool)
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
secret = Some(new_secret);
|
|
||||||
}
|
|
||||||
|
|
||||||
Outcome::Success(TOTPSecret {
|
|
||||||
user_id: user.user_id,
|
|
||||||
secret: secret.unwrap(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TOTPSecret {
|
|
||||||
pub async fn enable(&self, db: &mut Connection<Postgres>) -> Result<(), ()> {
|
|
||||||
match sqlx::query!(
|
|
||||||
"UPDATE users SET twofa_enabled = true WHERE id = $1",
|
|
||||||
self.user_id as i32,
|
|
||||||
)
|
|
||||||
.execute(&mut ***db)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
Err(_) => Err(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct TotpVerifyRequest {
|
|
||||||
pub code: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/totp/status")]
|
|
||||||
pub async fn get_totp_status(
|
|
||||||
user: Session,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
) -> Result<Json<TotpStatus>, Status> {
|
|
||||||
Ok(Json(
|
|
||||||
if sqlx::query!(
|
|
||||||
"SELECT twofa_enabled FROM users WHERE id = $1",
|
|
||||||
user.user_id as i32,
|
|
||||||
)
|
|
||||||
.fetch_one(&mut **db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Status::NotFound)?
|
|
||||||
.twofa_enabled
|
|
||||||
{
|
|
||||||
TotpStatus::Enabled
|
|
||||||
} else {
|
|
||||||
TotpStatus::Disabled
|
|
||||||
},
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[delete("/totp", data = "<form>")]
|
|
||||||
pub async fn disable_totp(
|
|
||||||
user: Session,
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
form: Json<PasswordAnd2fa>,
|
|
||||||
) -> Result<Json<AuthResponse>, Status> {
|
|
||||||
let totp_code = form.totp_code.clone().ok_or(Status::BadRequest)?;
|
|
||||||
let mut user = User::get_by_id(user.user_id, &mut db)
|
|
||||||
.await
|
|
||||||
.ok_or(Status::NotFound)?;
|
|
||||||
|
|
||||||
user.verify_password(&form.password)?;
|
|
||||||
user.verify_2fa(&totp_code)?;
|
|
||||||
user.set_twofa_enabled(false, &mut db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
Ok(Json(AuthResponse {
|
|
||||||
token: Claims::new(user.id as usize, TokenScope::Full).encode(),
|
|
||||||
totp_required: false,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/totp/verify", data = "<body>")]
|
|
||||||
pub async fn verify_totp(
|
|
||||||
claims: Claims, // request guard checks token validity
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
body: Json<TotpVerifyRequest>,
|
|
||||||
) -> Result<Json<AuthResponse>, Status> {
|
|
||||||
println!("reached 1");
|
|
||||||
|
|
||||||
// reject if they somehow got here with a full token
|
|
||||||
if claims.scope != TokenScope::TotpPending {
|
|
||||||
return Err(Status::Forbidden);
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("reached 2");
|
|
||||||
|
|
||||||
let row = sqlx::query!(
|
|
||||||
"SELECT totp_secret FROM users WHERE id = $1 AND twofa_enabled = TRUE",
|
|
||||||
claims.sub
|
|
||||||
)
|
|
||||||
.fetch_one(&mut **db)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Status::Unauthorized)?;
|
|
||||||
|
|
||||||
println!("reached 3");
|
|
||||||
|
|
||||||
let totp = totp_gen(
|
|
||||||
claims.sub as usize,
|
|
||||||
row.totp_secret
|
|
||||||
.expect("user with 2fa enabled has no totp secret")
|
|
||||||
.as_bytes(),
|
|
||||||
)
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
if !totp
|
|
||||||
.check_current(&body.code)
|
|
||||||
.map_err(|_| Status::InternalServerError)?
|
|
||||||
{
|
|
||||||
return Err(Status::Unauthorized);
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("reached 5");
|
|
||||||
|
|
||||||
let claims = Claims::new(claims.sub as usize, TokenScope::Full);
|
|
||||||
|
|
||||||
Ok(Json(AuthResponse {
|
|
||||||
token: claims.encode(),
|
|
||||||
totp_required: false,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use sqlx::postgres::PgPoolOptions;
|
||||||
|
use std::time::Duration;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use crate::repo::user_repo::UserRepository;
|
||||||
|
use crate::repo::space_repo::SpaceRepository;
|
||||||
|
use crate::repo::channel_repo::ChannelRepository;
|
||||||
|
use crate::repo::{UserRepo, SpaceRepo, ChannelRepo};
|
||||||
|
use argon2::{
|
||||||
|
password_hash::{PasswordHasher, SaltString},
|
||||||
|
Argon2,
|
||||||
|
};
|
||||||
|
use rand::rngs::OsRng;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
pub struct Cli {
|
||||||
|
#[command(subcommand)]
|
||||||
|
pub command: Option<Commands>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand)]
|
||||||
|
pub enum Commands {
|
||||||
|
/// First-time setup for the server
|
||||||
|
Setup {
|
||||||
|
/// Admin username
|
||||||
|
#[arg(short, long)]
|
||||||
|
username: String,
|
||||||
|
|
||||||
|
/// Admin password
|
||||||
|
#[arg(short, long)]
|
||||||
|
password: String,
|
||||||
|
|
||||||
|
/// Default space name
|
||||||
|
#[arg(short, long, default_value = "Default Space")]
|
||||||
|
space: String,
|
||||||
|
|
||||||
|
/// Default channel name
|
||||||
|
#[arg(short, long, default_value = "general")]
|
||||||
|
channel: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_cli() -> bool {
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
match cli.command {
|
||||||
|
Some(Commands::Setup { username, password, space, channel }) => {
|
||||||
|
if let Err(e) = run_setup(username, password, space, channel).await {
|
||||||
|
eprintln!("Setup failed: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
println!("Setup completed successfully!");
|
||||||
|
true
|
||||||
|
}
|
||||||
|
None => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_setup(username: String, password: String, space_name: String, channel_name: String) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
dotenv::dotenv().ok();
|
||||||
|
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
|
||||||
|
|
||||||
|
let pool = PgPoolOptions::new()
|
||||||
|
.max_connections(1)
|
||||||
|
.acquire_timeout(Duration::from_secs(5))
|
||||||
|
.connect(&db_url)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let user_repo = UserRepository::new(pool.clone());
|
||||||
|
let space_repo = SpaceRepository::new(pool.clone());
|
||||||
|
let channel_repo = ChannelRepository::new(pool.clone());
|
||||||
|
|
||||||
|
// 1. Create admin user
|
||||||
|
println!("Creating admin user: {}...", username);
|
||||||
|
|
||||||
|
let argon2 = Argon2::default();
|
||||||
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
|
let passhash = argon2
|
||||||
|
.hash_password(password.as_bytes(), &salt)
|
||||||
|
.map_err(|e| e.to_string())?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let user_id = user_repo.new_user("admin@localhost", &username, &passhash).await?;
|
||||||
|
user_repo.set_role(user_id, "admin").await?;
|
||||||
|
|
||||||
|
// 2. Create default space
|
||||||
|
println!("Creating default space: {}...", space_name);
|
||||||
|
let space_id = space_repo.create(&space_name, Some("Default space created during setup"), user_id).await?;
|
||||||
|
|
||||||
|
// 3. Create default channel
|
||||||
|
println!("Creating default channel: {}...", channel_name);
|
||||||
|
channel_repo.create(&channel_name, Some("Default channel"), space_id).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
use rocket_db_pools::{Database, deadpool_redis};
|
|
||||||
|
|
||||||
#[derive(Database)]
|
|
||||||
#[database("postgres_db")]
|
|
||||||
pub struct Postgres(sqlx::PgPool);
|
|
||||||
|
|
||||||
#[derive(Database)]
|
|
||||||
#[database("redis_cache")]
|
|
||||||
pub struct Redis(deadpool_redis::Pool);
|
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
// error.rs
|
||||||
|
use rocket::{http::Status, response::{self, Responder}, Request, Response};
|
||||||
|
use thiserror::Error;
|
||||||
|
use rocket_dyn_templates::Template;
|
||||||
|
use rocket::serde::Serialize;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum AppError {
|
||||||
|
#[error("Not found")]
|
||||||
|
NotFound,
|
||||||
|
|
||||||
|
#[error("Unauthorized")]
|
||||||
|
Unauthorised(String),
|
||||||
|
|
||||||
|
#[error("Forbidden")]
|
||||||
|
Forbidden,
|
||||||
|
|
||||||
|
#[error("Bad request: {0}")]
|
||||||
|
BadRequest(String),
|
||||||
|
|
||||||
|
#[error("Database error: {0}")]
|
||||||
|
Database(#[from] sqlx::Error),
|
||||||
|
|
||||||
|
#[error("Internal error: {0}")]
|
||||||
|
Internal(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppError {
|
||||||
|
pub fn internal(msg: impl Into<String>) -> Self {
|
||||||
|
Self::Internal(msg.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bad_request(msg: impl Into<String>) -> Self {
|
||||||
|
Self::BadRequest(msg.into())
|
||||||
|
}
|
||||||
|
pub fn unauthorised(msg: impl Into<String>) -> Self {
|
||||||
|
Self::Unauthorised(msg.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'r> Responder<'r, 'static> for AppError {
|
||||||
|
fn respond_to(self, _req: &'r Request<'_>) -> response::Result<'static> {
|
||||||
|
let status = match &self {
|
||||||
|
AppError::NotFound => Status::NotFound,
|
||||||
|
AppError::Unauthorised(_) => Status::Unauthorized,
|
||||||
|
AppError::Forbidden => Status::Forbidden,
|
||||||
|
AppError::BadRequest(_) => Status::BadRequest,
|
||||||
|
AppError::Database(_) => Status::InternalServerError,
|
||||||
|
AppError::Internal(_) => Status::InternalServerError,
|
||||||
|
};
|
||||||
|
|
||||||
|
// log internal errors
|
||||||
|
if status == Status::InternalServerError {
|
||||||
|
tracing::error!("Internal Server Error: {}", self);
|
||||||
|
}
|
||||||
|
|
||||||
|
Response::build()
|
||||||
|
.status(status)
|
||||||
|
.header(rocket::http::ContentType::Plain)
|
||||||
|
.sized_body(
|
||||||
|
self.to_string().len(),
|
||||||
|
std::io::Cursor::new(self.to_string())
|
||||||
|
)
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type ApiResult<T> = Result<T, AppError>;
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ErrorContext {
|
||||||
|
error_code: u16,
|
||||||
|
error_message: &'static str,
|
||||||
|
additional_info: &'static str,
|
||||||
|
redirect: Option<RedirectContext>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct RedirectContext {
|
||||||
|
url: &'static str,
|
||||||
|
message: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[catch(404)]
|
||||||
|
pub async fn handle_404() -> Template {
|
||||||
|
Template::render(
|
||||||
|
"error",
|
||||||
|
ErrorContext {
|
||||||
|
error_code: 404,
|
||||||
|
error_message: "Not Found",
|
||||||
|
additional_info: "There's nothing here.",
|
||||||
|
redirect: Some(RedirectContext {
|
||||||
|
url: "/",
|
||||||
|
message: "Home",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[catch(401)]
|
||||||
|
pub async fn handle_401() -> Template {
|
||||||
|
Template::render(
|
||||||
|
"error",
|
||||||
|
ErrorContext {
|
||||||
|
error_code: 401,
|
||||||
|
error_message: "Unauthorised",
|
||||||
|
additional_info: "You are not authorised to access this resource.",
|
||||||
|
redirect: Some(RedirectContext {
|
||||||
|
url: "/login",
|
||||||
|
message: "Login",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[catch(default)]
|
||||||
|
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
|
||||||
|
Template::render(
|
||||||
|
"error",
|
||||||
|
ErrorContext {
|
||||||
|
error_code: status.code,
|
||||||
|
error_message: "Unknown Error",
|
||||||
|
additional_info: "I don't know what to do with this error.",
|
||||||
|
redirect: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
use rocket::{Request, http::Status};
|
|
||||||
use rocket_dyn_templates::Template;
|
|
||||||
use serde::Serialize;
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ErrorContext {
|
|
||||||
error_code: u16,
|
|
||||||
error_message: &'static str,
|
|
||||||
additional_info: &'static str,
|
|
||||||
redirect: Option<RedirectContext>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct RedirectContext {
|
|
||||||
url: &'static str,
|
|
||||||
message: &'static str,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[catch(404)]
|
|
||||||
pub async fn handle_404() -> Template {
|
|
||||||
Template::render(
|
|
||||||
"error",
|
|
||||||
ErrorContext {
|
|
||||||
error_code: 404,
|
|
||||||
error_message: "Not Found",
|
|
||||||
additional_info: "There's nothing here.",
|
|
||||||
redirect: Some(RedirectContext {
|
|
||||||
url: "/",
|
|
||||||
message: "Home",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[catch(401)]
|
|
||||||
pub async fn handle_401() -> Template {
|
|
||||||
Template::render(
|
|
||||||
"error",
|
|
||||||
ErrorContext {
|
|
||||||
error_code: 401,
|
|
||||||
error_message: "Unauthorised",
|
|
||||||
additional_info: "You are not authorised to access this resource.",
|
|
||||||
redirect: Some(RedirectContext {
|
|
||||||
url: "/login",
|
|
||||||
message: "Login",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[catch(default)]
|
|
||||||
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
|
|
||||||
Template::render(
|
|
||||||
"error",
|
|
||||||
ErrorContext {
|
|
||||||
error_code: status.code,
|
|
||||||
error_message: "Unknown Error",
|
|
||||||
additional_info: "I don't know what to do with this error.",
|
|
||||||
redirect: None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
#![deny(clippy::unwrap_used)]
|
||||||
|
#![warn(clippy::all, clippy::nursery, clippy::cargo, clippy::pedantic)]
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
extern crate rocket;
|
||||||
|
|
||||||
|
pub mod messenger;
|
||||||
|
pub mod api;
|
||||||
|
pub mod repo;
|
||||||
|
pub mod error;
|
||||||
|
pub mod svc;
|
||||||
|
pub mod model;
|
||||||
|
pub mod cli;
|
||||||
|
|
||||||
|
use crate::repo::{access_token_repo::AccessTokenRepo, Repo};
|
||||||
|
use crate::repo::message_repo::MessageRepository;
|
||||||
|
use crate::repo::user_repo::UserRepository;
|
||||||
|
use crate::repo::space_repo::SpaceRepository;
|
||||||
|
use crate::repo::channel_repo::ChannelRepository;
|
||||||
|
use crate::svc::auth_svc::AuthService;
|
||||||
|
use crate::svc::chat_svc::ChatService;
|
||||||
|
use crate::svc::settings_svc::SettingsService;
|
||||||
|
use crate::svc::user_svc::UserService;
|
||||||
|
use rocket::fs::{FileServer, NamedFile};
|
||||||
|
use rocket::http::Method;
|
||||||
|
use rocket_cors::{AllowedOrigins, CorsOptions};
|
||||||
|
use rocket_dyn_templates::Template;
|
||||||
|
use sqlx::postgres::PgPoolOptions;
|
||||||
|
use std::env;
|
||||||
|
use std::sync::{Arc, LazyLock};
|
||||||
|
use std::time::Duration;
|
||||||
|
use api::cdn;
|
||||||
|
use crate::svc::access_token_svc::AccessTokenService;
|
||||||
|
use crate::svc::llm_service::LlmService;
|
||||||
|
|
||||||
|
pub fn rocket() -> rocket::Rocket<rocket::Build> {
|
||||||
|
if std::env::var("RELEASE_MODE").unwrap_or_default() != "1" {
|
||||||
|
dotenv::dotenv().ok();
|
||||||
|
}
|
||||||
|
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
|
||||||
|
|
||||||
|
|
||||||
|
let pool = PgPoolOptions::new()
|
||||||
|
.max_connections(25)
|
||||||
|
.min_connections(5)
|
||||||
|
.acquire_timeout(Duration::from_secs(5))
|
||||||
|
.connect_lazy(&db_url)
|
||||||
|
.expect("Failed to create database pool");
|
||||||
|
|
||||||
|
let user_repo = Arc::new(UserRepository::new(pool.clone()));
|
||||||
|
let message_repo = MessageRepository::new(pool.clone());
|
||||||
|
let token_repo = Arc::new(AccessTokenRepo::new(pool.clone()));
|
||||||
|
let space_repo: Arc<dyn repo::SpaceRepo> = Arc::new(SpaceRepository::new(pool.clone()));
|
||||||
|
let channel_repo: Arc<dyn repo::ChannelRepo> = Arc::new(ChannelRepository::new(pool.clone()));
|
||||||
|
let llm_service = LlmService::new();
|
||||||
|
let chat_service = ChatService::new(32, llm_service.clone(), message_repo.clone(), user_repo.clone(), channel_repo.clone(), space_repo.clone());
|
||||||
|
|
||||||
|
rocket_builder(user_repo, token_repo, space_repo, channel_repo, chat_service)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rocket_builder(
|
||||||
|
user_repo: Arc<dyn repo::UserRepo>,
|
||||||
|
token_repo: Arc<dyn repo::AccessTokenRepoTrait>,
|
||||||
|
space_repo: Arc<dyn repo::SpaceRepo>,
|
||||||
|
channel_repo: Arc<dyn repo::ChannelRepo>,
|
||||||
|
chat_service: ChatService
|
||||||
|
) -> rocket::Rocket<rocket::Build> {
|
||||||
|
|
||||||
|
|
||||||
|
let cors = CorsOptions::default()
|
||||||
|
.allowed_origins(AllowedOrigins::all())
|
||||||
|
.allowed_methods(
|
||||||
|
vec![Method::Get, Method::Post, Method::Patch]
|
||||||
|
.into_iter()
|
||||||
|
.map(From::from)
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
.allow_credentials(true);
|
||||||
|
|
||||||
|
let access_token_svc = AccessTokenService::new(token_repo.clone());
|
||||||
|
let auth_service = AuthService::new(user_repo.clone(), access_token_svc.clone());
|
||||||
|
let settings_service = SettingsService::new(auth_service.clone(), user_repo.clone());
|
||||||
|
let user_service = UserService::new(user_repo.clone());
|
||||||
|
|
||||||
|
rocket::build()
|
||||||
|
.manage(chat_service)
|
||||||
|
.manage(auth_service)
|
||||||
|
.manage(settings_service)
|
||||||
|
.manage(user_service)
|
||||||
|
.manage(space_repo)
|
||||||
|
.manage(channel_repo)
|
||||||
|
.attach(cors.to_cors().unwrap())
|
||||||
|
.attach(Template::fairing())
|
||||||
|
.mount("/static", FileServer::from("static"))
|
||||||
|
.mount("/cdn", cdn::routes())
|
||||||
|
.mount(
|
||||||
|
"/",
|
||||||
|
routes![
|
||||||
|
favicon,
|
||||||
|
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.mount(
|
||||||
|
"/api",
|
||||||
|
routes![
|
||||||
|
cdn::upload_profile_pic,
|
||||||
|
api::profile::display_name,
|
||||||
|
|
||||||
|
// basic auth
|
||||||
|
api::auth::login,
|
||||||
|
api::auth::signup,
|
||||||
|
|
||||||
|
// 2fa
|
||||||
|
api::totp::confirm_totp,
|
||||||
|
api::totp::disable_totp,
|
||||||
|
api::totp::get_totp,
|
||||||
|
api::totp::get_totp_status,
|
||||||
|
api::totp::verify_totp,
|
||||||
|
|
||||||
|
// chat
|
||||||
|
api::chat::event_stream,
|
||||||
|
api::chat::post_message,
|
||||||
|
|
||||||
|
// user settings
|
||||||
|
api::settings::change_display_name,
|
||||||
|
api::settings::change_password,
|
||||||
|
api::settings::change_username,
|
||||||
|
api::settings::delete_account,
|
||||||
|
|
||||||
|
// spaces
|
||||||
|
api::space::list_spaces,
|
||||||
|
api::space::list_channels,
|
||||||
|
api::space::get_accessible_channels
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.register(
|
||||||
|
"/",
|
||||||
|
catchers![
|
||||||
|
error::handle_401,
|
||||||
|
error::handle_404,
|
||||||
|
error::handle_default,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/favicon.ico")]
|
||||||
|
pub async fn favicon() -> NamedFile {
|
||||||
|
NamedFile::open("static/favicon.ico").await.unwrap()
|
||||||
|
}
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
// src/llm.rs
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::messenger::ChatMsg;
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct LlmRequest {
|
|
||||||
model: String,
|
|
||||||
messages: Vec<Message>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
struct Message {
|
|
||||||
role: String, // "user" or "assistant"
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct LlmWorker {
|
|
||||||
uri: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LlmWorker {
|
|
||||||
pub fn new(uri: String) -> Self {
|
|
||||||
Self { uri }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn query(&self, message: &ChatMsg) -> Result<ChatMsg, String> {
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
// Build the request body
|
|
||||||
let payload = LlmRequest {
|
|
||||||
model: "gpt-oss-20b".into(), // whatever model you run locally
|
|
||||||
messages: vec![Message {
|
|
||||||
role: "user".into(),
|
|
||||||
content: message.text.clone(),
|
|
||||||
}],
|
|
||||||
};
|
|
||||||
|
|
||||||
// POST to lm‑studio (default 127.0.0.1:1234)
|
|
||||||
let resp = client
|
|
||||||
.post(self.uri.clone())
|
|
||||||
.json(&payload)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|_| String::from("Failed to make request to LLM server"))?;
|
|
||||||
|
|
||||||
// The API returns a JSON with `choices[].message.content`
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct LlmResponse {
|
|
||||||
choices: Vec<Choice>,
|
|
||||||
}
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct Choice {
|
|
||||||
message: Message,
|
|
||||||
}
|
|
||||||
|
|
||||||
let llm_resp: LlmResponse = resp
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|_| String::from("Failed to make request to LLM server"))?;
|
|
||||||
|
|
||||||
Ok(ChatMsg {
|
|
||||||
display_name: Some(String::from("lmstudio")),
|
|
||||||
user_id: 0,
|
|
||||||
text: llm_resp.choices[0].message.content.clone(),
|
|
||||||
timestamp: chrono::Utc::now().timestamp_millis() as usize,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+9
-96
@@ -1,98 +1,11 @@
|
|||||||
// src/main.rs
|
use backend::rocket;
|
||||||
#[macro_use]
|
use backend::cli::handle_cli;
|
||||||
extern crate rocket;
|
|
||||||
|
|
||||||
use rocket::fs::{FileServer, NamedFile};
|
#[rocket::main]
|
||||||
use rocket::http::Method;
|
async fn main() -> Result<(), rocket::Error> {
|
||||||
use rocket::{Build, Rocket};
|
if handle_cli().await {
|
||||||
use rocket_cors::{AllowedOrigins, CorsOptions};
|
return Ok(());
|
||||||
use rocket_db_pools::Database;
|
}
|
||||||
use rocket_dyn_templates::Template;
|
rocket().launch().await?;
|
||||||
use std::env;
|
Ok(())
|
||||||
use std::sync::{Arc, LazyLock};
|
|
||||||
|
|
||||||
use crate::db::{Postgres, Redis};
|
|
||||||
|
|
||||||
pub mod auth;
|
|
||||||
pub mod cdn;
|
|
||||||
pub mod db;
|
|
||||||
pub mod handlers;
|
|
||||||
pub mod llm;
|
|
||||||
pub mod messenger;
|
|
||||||
pub mod user;
|
|
||||||
|
|
||||||
static LMSTUDIO_URL: LazyLock<String> =
|
|
||||||
LazyLock::new(|| env::var("LMSTUDIO_URL").expect("Ensure LMSTUDIO_URL is set!"));
|
|
||||||
|
|
||||||
#[launch]
|
|
||||||
fn rocket() -> Rocket<Build> {
|
|
||||||
// make sure the env is loaded
|
|
||||||
dotenv::dotenv().expect("Failed to load env! aborting launch!");
|
|
||||||
|
|
||||||
let chat = Arc::new(crate::messenger::ChatBroadcaster::new(32));
|
|
||||||
|
|
||||||
let cors = CorsOptions::default()
|
|
||||||
.allowed_origins(AllowedOrigins::all())
|
|
||||||
.allowed_methods(
|
|
||||||
vec![Method::Get, Method::Post, Method::Patch]
|
|
||||||
.into_iter()
|
|
||||||
.map(From::from)
|
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
.allow_credentials(true);
|
|
||||||
|
|
||||||
rocket::build()
|
|
||||||
.manage(chat)
|
|
||||||
.attach(cors.to_cors().unwrap())
|
|
||||||
.attach(Postgres::init())
|
|
||||||
.attach(Redis::init())
|
|
||||||
.attach(Template::fairing())
|
|
||||||
.mount("/static", FileServer::from("static"))
|
|
||||||
.mount("/cdn", cdn::routes())
|
|
||||||
.mount(
|
|
||||||
"/",
|
|
||||||
routes![
|
|
||||||
favicon,
|
|
||||||
messenger::chat_page,
|
|
||||||
auth::signup_page,
|
|
||||||
auth::login_page,
|
|
||||||
auth::mfa_page,
|
|
||||||
auth::invite_page,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
.mount(
|
|
||||||
"/api",
|
|
||||||
routes![
|
|
||||||
cdn::upload_profile_pic,
|
|
||||||
messenger::post_message,
|
|
||||||
messenger::event_stream,
|
|
||||||
user::users,
|
|
||||||
user::display_name,
|
|
||||||
auth::signup,
|
|
||||||
auth::login,
|
|
||||||
auth::get_totp,
|
|
||||||
auth::confirm_totp,
|
|
||||||
auth::generate_invite,
|
|
||||||
auth::verify_totp,
|
|
||||||
auth::disable_totp,
|
|
||||||
auth::get_totp_status,
|
|
||||||
auth::change_password,
|
|
||||||
auth::change_display_name,
|
|
||||||
auth::change_username,
|
|
||||||
auth::delete_account,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
.register(
|
|
||||||
"/",
|
|
||||||
catchers![
|
|
||||||
handlers::handle_404,
|
|
||||||
handlers::handle_401,
|
|
||||||
handlers::handle_default
|
|
||||||
],
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/favicon.ico")]
|
|
||||||
async fn favicon() -> NamedFile {
|
|
||||||
NamedFile::open("static/favicon.ico").await.unwrap()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
use redis::AsyncCommands;
|
use redis::AsyncCommands;
|
||||||
use rocket_db_pools::Connection;
|
use rocket_db_pools::Connection;
|
||||||
|
|
||||||
use crate::{
|
use crate::api::chat::ChatMsg;
|
||||||
db::{Postgres, Redis},
|
use crate::db::{Postgres, Redis};
|
||||||
messenger::ChatMsg,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Helper function to cache message in Redis
|
// Helper function to cache message in Redis
|
||||||
pub async fn insert(
|
pub async fn insert(
|
||||||
|
|||||||
@@ -1,220 +0,0 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use rocket::{
|
|
||||||
Shutdown,
|
|
||||||
response::stream::{Event, EventStream},
|
|
||||||
serde::json::Json,
|
|
||||||
time::OffsetDateTime,
|
|
||||||
};
|
|
||||||
use rocket_db_pools::Connection;
|
|
||||||
use rocket_dyn_templates::{Template, context};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use sqlx::prelude::FromRow;
|
|
||||||
use tokio::{select, sync::broadcast};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
auth::Session,
|
|
||||||
db::{Postgres, Redis},
|
|
||||||
llm::LlmWorker,
|
|
||||||
messenger,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// ---------- shared broadcaster ----------
|
|
||||||
pub struct ChatBroadcaster {
|
|
||||||
buffer_size: usize,
|
|
||||||
senders: std::sync::Mutex<std::collections::HashMap<i32, broadcast::Sender<ChatMsg>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatBroadcaster {
|
|
||||||
pub fn new(buffer_size: usize) -> Self {
|
|
||||||
Self {
|
|
||||||
buffer_size,
|
|
||||||
senders: std::sync::Mutex::new(std::collections::HashMap::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Publish a message to the specified channel.
|
|
||||||
pub async fn publish(&self, channel_id: i32, msg: ChatMsg) {
|
|
||||||
let mut map = self.senders.lock().unwrap();
|
|
||||||
let sender = map
|
|
||||||
.entry(channel_id)
|
|
||||||
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
|
||||||
let _ = sender.send(msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Subscribe to the specified channel.
|
|
||||||
pub fn subscribe(&self, channel_id: i32) -> broadcast::Receiver<ChatMsg> {
|
|
||||||
let mut map = self.senders.lock().unwrap();
|
|
||||||
let sender = map
|
|
||||||
.entry(channel_id)
|
|
||||||
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
|
||||||
sender.subscribe()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// ---------- Rocket routes ----------
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
|
|
||||||
pub struct ChatMsg {
|
|
||||||
pub display_name: Option<String>,
|
|
||||||
pub user_id: usize,
|
|
||||||
pub text: String,
|
|
||||||
pub timestamp: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
|
|
||||||
pub async fn post_message(
|
|
||||||
mut msg: Json<ChatMsg>,
|
|
||||||
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
|
||||||
mut postgres: Connection<Postgres>,
|
|
||||||
mut cache: Option<Connection<Redis>>,
|
|
||||||
session: Session,
|
|
||||||
channel_id: i32,
|
|
||||||
) -> Result<(), String> {
|
|
||||||
let chat = chat.inner().clone();
|
|
||||||
|
|
||||||
let display_name = sqlx::query!(
|
|
||||||
"SELECT display_name, username FROM users WHERE id = $1",
|
|
||||||
session.user_id as i32
|
|
||||||
)
|
|
||||||
.fetch_one(&mut **postgres)
|
|
||||||
.await
|
|
||||||
.map(|row| row.display_name.unwrap_or(row.username))
|
|
||||||
.unwrap_or_else(|_| "Unknown".to_string());
|
|
||||||
|
|
||||||
msg.user_id = session.user_id;
|
|
||||||
msg.display_name = Some(display_name);
|
|
||||||
chat.publish(channel_id, msg.clone().into_inner()).await;
|
|
||||||
|
|
||||||
sqlx::query!(
|
|
||||||
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
|
||||||
channel_id,
|
|
||||||
msg.user_id as i32,
|
|
||||||
msg.text,
|
|
||||||
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
|
|
||||||
)
|
|
||||||
.execute(&mut **postgres)
|
|
||||||
.await
|
|
||||||
.map_err(|_| "Failed".to_string())?;
|
|
||||||
|
|
||||||
println!("gisfujdeghnjuisdfjngiosdfgjkosdf gnojdfsg nmodfsg");
|
|
||||||
|
|
||||||
if let Some(ref mut cache) = cache {
|
|
||||||
messenger::cache::insert(cache, channel_id, &msg)
|
|
||||||
.await
|
|
||||||
.map_err(|_| "Redis cache failed".to_string())?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get response
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let response = LlmWorker::new(crate::LMSTUDIO_URL.to_string())
|
|
||||||
.query(&msg)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if let Ok(reply) = response {
|
|
||||||
chat.publish(channel_id, reply.clone()).await;
|
|
||||||
|
|
||||||
if let Some(ref mut cache) = cache {
|
|
||||||
messenger::cache::insert(cache, channel_id, &reply)
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlx::query!(
|
|
||||||
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
|
||||||
channel_id,
|
|
||||||
reply.user_id as i32,
|
|
||||||
reply.text,
|
|
||||||
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
|
|
||||||
)
|
|
||||||
.execute(&mut **postgres)
|
|
||||||
.await
|
|
||||||
.map_err(|_| "Failed".to_string())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_messages(
|
|
||||||
mut db: Connection<Postgres>,
|
|
||||||
mut redis: Connection<Redis>,
|
|
||||||
channel_id: i32,
|
|
||||||
) -> Json<Vec<ChatMsg>> {
|
|
||||||
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
|
|
||||||
&& !messages.is_empty()
|
|
||||||
{
|
|
||||||
return Json(messages);
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(x) = messenger::cache::initialise(&mut redis, &mut db, channel_id).await {
|
|
||||||
eprintln!("WARN: {x:?}");
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
|
|
||||||
&& !messages.is_empty()
|
|
||||||
{
|
|
||||||
return Json(messages);
|
|
||||||
};
|
|
||||||
|
|
||||||
let res = sqlx::query!(
|
|
||||||
"SELECT u.username, u.display_name, u.id, m.content, m.created_at
|
|
||||||
FROM messages m
|
|
||||||
JOIN users u ON m.user_id = u.id
|
|
||||||
WHERE m.channel_id = $1
|
|
||||||
ORDER BY m.created_at DESC LIMIT 100",
|
|
||||||
channel_id
|
|
||||||
)
|
|
||||||
.fetch_all(&mut **db)
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|_| Vec::new())
|
|
||||||
.into_iter()
|
|
||||||
.rev()
|
|
||||||
.map(|msg| ChatMsg {
|
|
||||||
display_name: Some(msg.display_name.unwrap_or(msg.username)),
|
|
||||||
user_id: msg.id as usize,
|
|
||||||
text: msg.content,
|
|
||||||
timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Json(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/events/<channel_id>")]
|
|
||||||
pub async fn event_stream(
|
|
||||||
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
|
||||||
postgres: Connection<Postgres>,
|
|
||||||
cache: Connection<Redis>,
|
|
||||||
_session: Session,
|
|
||||||
mut shutdown: Shutdown,
|
|
||||||
channel_id: i32,
|
|
||||||
) -> EventStream![] {
|
|
||||||
let mut rx = chat.subscribe(channel_id);
|
|
||||||
|
|
||||||
EventStream! {
|
|
||||||
// Initialize the stream with the last 100 messages
|
|
||||||
for msg in get_messages(postgres, cache, channel_id).await.0 {
|
|
||||||
yield Event::json(&msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
loop {
|
|
||||||
select!{
|
|
||||||
// exit early on shutdown
|
|
||||||
_ = &mut shutdown => break,
|
|
||||||
|
|
||||||
msg = rx.recv() => match msg {
|
|
||||||
Ok(msg) => yield Event::json(&msg),
|
|
||||||
Err(broadcast::error::RecvError::Lagged(_)) => {
|
|
||||||
yield Event::comment("RecvError::Lagged");
|
|
||||||
}
|
|
||||||
Err(broadcast::error::RecvError::Closed) => break,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/chat")]
|
|
||||||
pub async fn chat_page(session: Session) -> Template {
|
|
||||||
Template::render("chat", context!(user_id: session.user_id))
|
|
||||||
}
|
|
||||||
@@ -1,4 +1 @@
|
|||||||
mod cache;
|
// mod cache;
|
||||||
mod messages;
|
|
||||||
|
|
||||||
pub use messages::{ChatBroadcaster, ChatMsg, chat_page, event_stream, get_messages, post_message};
|
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
use rocket::serde::{Deserialize, Serialize};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct SignupCredentials {
|
||||||
|
pub email: String,
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
pub access_token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct LoginCredentials {
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct AuthResponse {
|
||||||
|
pub token: String,
|
||||||
|
pub totp_required: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AccessTokenForm {
|
||||||
|
pub name: String,
|
||||||
|
pub max_uses: i32,
|
||||||
|
pub expiry_date: DateTime<Utc>,
|
||||||
|
pub start_date: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AccessToken {
|
||||||
|
pub id: i64,
|
||||||
|
pub code: String,
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod auth;
|
||||||
|
pub mod user;
|
||||||
|
pub mod space;
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sqlx::FromRow;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||||
|
pub struct Space {
|
||||||
|
pub id: i64,
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub owner_id: i64,
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
pub updated_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||||
|
pub struct Channel {
|
||||||
|
pub id: i64,
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub space_id: i64,
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
pub updated_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SpaceDto {
|
||||||
|
pub channels: Vec<Channel>,
|
||||||
|
pub id: i64,
|
||||||
|
pub owner_id: i64,
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
pub updated_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
use crate::api::auth::Session;
|
||||||
|
use crate::error::ApiResult;
|
||||||
|
use crate::svc::user_svc::UserService;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rocket::State;
|
||||||
|
use sqlx::FromRow;
|
||||||
|
use crate::api::totp::TotpStatus;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
#[derive(FromRow)]
|
||||||
|
pub struct User {
|
||||||
|
pub id: i64,
|
||||||
|
pub email: Option<String>,
|
||||||
|
pub username: String,
|
||||||
|
pub nickname: Option<String>,
|
||||||
|
pub passhash: String,
|
||||||
|
pub totp_status: TotpStatus,
|
||||||
|
pub totp_secret: Option<String>,
|
||||||
|
pub created_at: Option<DateTime<Utc>>,
|
||||||
|
pub updated_at: Option<DateTime<Utc>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub struct UserCache {}
|
||||||
|
//
|
||||||
|
// impl UserCache {
|
||||||
|
// pub async fn username(
|
||||||
|
// id: usize,
|
||||||
|
// redis_conn: &mut Connection<Redis>,
|
||||||
|
// pgsql_conn: &mut Connection<Postgres>,
|
||||||
|
// ) -> String {
|
||||||
|
// if let Ok(val) = redis_conn.get(format!("users:{id}")).await {
|
||||||
|
// return val;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
|
||||||
|
// .fetch_one(&mut ***pgsql_conn)
|
||||||
|
// .await
|
||||||
|
// {
|
||||||
|
// let username = v.username;
|
||||||
|
// Self::insert(id, &username, redis_conn).await;
|
||||||
|
// username
|
||||||
|
// } else {
|
||||||
|
// unimplemented!()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
|
||||||
|
// conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800)
|
||||||
|
// .await
|
||||||
|
// .expect("failed to insert key");
|
||||||
|
// }
|
||||||
|
// }
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
use crate::repo::{Repo, AccessTokenRepoTrait};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use crate::model::auth::AccessToken;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AccessTokenRepo {
|
||||||
|
pool: PgPool
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Repo for AccessTokenRepo {
|
||||||
|
type Target = AccessToken;
|
||||||
|
|
||||||
|
fn new(pool: PgPool) -> Self {
|
||||||
|
Self { pool }
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
|
||||||
|
sqlx::query_as!(AccessToken, "SELECT id, code FROM access_tokens WHERE id = $1", id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
.unwrap_or(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AccessTokenRepoTrait for AccessTokenRepo {
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
|
||||||
|
Repo::get_by_id(self, id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_new(&self,
|
||||||
|
uid: i64, name: &str, code: &str, max_uses: i32,
|
||||||
|
start_date: DateTime<Utc>, expiry_date: DateTime<Utc>
|
||||||
|
) -> Result<i64, sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"INSERT INTO access_tokens (name, code, creator_id, max_uses, created_at, expires_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
|
||||||
|
name,
|
||||||
|
code,
|
||||||
|
uid,
|
||||||
|
max_uses,
|
||||||
|
start_date,
|
||||||
|
expiry_date
|
||||||
|
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn use_token(&self, id: i64) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!("UPDATE access_tokens SET uses = uses + 1 WHERE id = $1", id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error> {
|
||||||
|
sqlx::query_as!(AccessToken,
|
||||||
|
"SELECT id, code FROM access_tokens
|
||||||
|
WHERE code = $1
|
||||||
|
AND created_at < NOW()
|
||||||
|
AND expires_at > NOW()
|
||||||
|
AND uses < max_uses",
|
||||||
|
code
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
use crate::repo::ChannelRepo;
|
||||||
|
use crate::model::space::Channel;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ChannelRepository {
|
||||||
|
pool: PgPool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelRepository {
|
||||||
|
pub fn new(pool: PgPool) -> Self {
|
||||||
|
Self { pool }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl ChannelRepo for ChannelRepository {
|
||||||
|
async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result<i64, sqlx::Error> {
|
||||||
|
let row = sqlx::query!(
|
||||||
|
"INSERT INTO channels (name, description, space_id) VALUES ($1, $2, $3) RETURNING id",
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
space_id
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(row.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_by_space_id(&self, space_id: i64) -> Result<Vec<Channel>, sqlx::Error> {
|
||||||
|
sqlx::query_as!(
|
||||||
|
Channel,
|
||||||
|
"SELECT id, name, description, space_id, created_at as \"created_at!\", updated_at as \"updated_at!\" FROM channels WHERE space_id = $1",
|
||||||
|
space_id
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
use crate::api::chat::ChatMsg;
|
||||||
|
use crate::repo::Repo;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MessageRepository {
|
||||||
|
pool: PgPool
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Repo for MessageRepository {
|
||||||
|
type Target = ChatMsg;
|
||||||
|
|
||||||
|
fn new(pool: PgPool) -> Self {
|
||||||
|
Self { pool }
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: caching with redis
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
|
||||||
|
sqlx::query!(
|
||||||
|
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
|
||||||
|
FROM messages m
|
||||||
|
JOIN users u ON m.user_id = u.id
|
||||||
|
WHERE m.id = $1",
|
||||||
|
id
|
||||||
|
).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg {
|
||||||
|
display_name: Some(row.nickname.unwrap_or(row.username)),
|
||||||
|
user_id: row.user_id,
|
||||||
|
text: row.content,
|
||||||
|
timestamp: row.created_at,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MessageRepository {
|
||||||
|
|
||||||
|
// TODO! caching with redis
|
||||||
|
pub async fn create_new(
|
||||||
|
&self, uid: i64, channel_id: i64,
|
||||||
|
text: &str, created_at: DateTime<Utc>
|
||||||
|
) -> Result<i64, sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"INSERT INTO messages (channel_id, user_id, content, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4) RETURNING id",
|
||||||
|
channel_id,
|
||||||
|
uid,
|
||||||
|
text,
|
||||||
|
created_at
|
||||||
|
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TODO: caching with redis
|
||||||
|
pub async fn get_by_channel(&self, channel_id: i64, limit: usize)
|
||||||
|
-> Result<Vec<ChatMsg>, sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
|
||||||
|
FROM messages m
|
||||||
|
JOIN users u ON m.user_id = u.id
|
||||||
|
WHERE m.channel_id = $1
|
||||||
|
ORDER BY m.created_at DESC LIMIT $2",
|
||||||
|
channel_id,
|
||||||
|
limit as i64
|
||||||
|
).fetch_all(&self.pool).await.map(|messages| {
|
||||||
|
messages.into_iter().rev().map(|msg| {
|
||||||
|
ChatMsg {
|
||||||
|
display_name: Some(msg.nickname.unwrap_or(msg.username)),
|
||||||
|
user_id: msg.user_id,
|
||||||
|
text: msg.content,
|
||||||
|
timestamp: msg.created_at,
|
||||||
|
}
|
||||||
|
}).collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
use crate::repo::{UserRepo, AccessTokenRepoTrait};
|
||||||
|
use crate::model::user::User;
|
||||||
|
use crate::model::auth::AccessToken;
|
||||||
|
use rocket::async_trait;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use chrono::Utc;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use sqlx::Error;
|
||||||
|
use crate::api::totp::TotpStatus;
|
||||||
|
use crate::api::totp::TotpStatus::Disabled;
|
||||||
|
|
||||||
|
pub struct MockAccessTokenRepo {
|
||||||
|
pub tokens: Mutex<Vec<AccessToken>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AccessTokenRepoTrait for MockAccessTokenRepo {
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
|
||||||
|
self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() })
|
||||||
|
}
|
||||||
|
async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime<Utc>, _expiry_date: chrono::DateTime<Utc>) -> Result<i64, sqlx::Error> {
|
||||||
|
let mut tokens = self.tokens.lock().unwrap();
|
||||||
|
let id = tokens.len() as i64 + 1;
|
||||||
|
tokens.push(AccessToken { id, code: code.to_string() });
|
||||||
|
Ok(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn use_token(&self, id: i64) -> Result<(), Error> {
|
||||||
|
// let mut tokens = self.tokens.lock().unwrap();
|
||||||
|
// if let Some(pos) = tokens.iter().position(|t| t.id == id) {
|
||||||
|
// tokens.get_mut(pos).uses =
|
||||||
|
// }
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, Error> {
|
||||||
|
Ok(self.tokens.lock().unwrap()
|
||||||
|
.iter().find(|t| t.code == code)
|
||||||
|
.map(|t| AccessToken { id: t.id, code: t.code.clone() }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MockUserRepo {
|
||||||
|
pub users: Mutex<Vec<User>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl UserRepo for MockUserRepo {
|
||||||
|
fn pool(&self) -> &sqlx::PgPool {
|
||||||
|
unimplemented!("MockUserRepo does not have a real pool")
|
||||||
|
}
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<User> {
|
||||||
|
self.users.lock().unwrap().iter().find(|u| u.id == id).cloned()
|
||||||
|
}
|
||||||
|
async fn save(&self, user: &User) -> Result<(), sqlx::Error> {
|
||||||
|
let mut users = self.users.lock().unwrap();
|
||||||
|
if let Some(pos) = users.iter().position(|u| u.id == user.id) {
|
||||||
|
users[pos] = user.clone();
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result<i64, sqlx::Error> {
|
||||||
|
let mut users = self.users.lock().unwrap();
|
||||||
|
let id = users.len() as i64 + 1;
|
||||||
|
users.push(User {
|
||||||
|
id,
|
||||||
|
email: Some(email.to_string()),
|
||||||
|
username: username.to_string(),
|
||||||
|
nickname: None,
|
||||||
|
passhash: pass_hash.to_string(),
|
||||||
|
totp_status: Disabled,
|
||||||
|
totp_secret: None,
|
||||||
|
created_at: Some(Utc::now()),
|
||||||
|
updated_at: Some(Utc::now()),
|
||||||
|
});
|
||||||
|
Ok(id)
|
||||||
|
}
|
||||||
|
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error> {
|
||||||
|
Ok(self.users.lock().unwrap().iter().find(|u| u.username == username).cloned())
|
||||||
|
}
|
||||||
|
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> {
|
||||||
|
self.users.lock().unwrap().retain(|u| u.id != id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error> {
|
||||||
|
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
|
||||||
|
u.nickname = display_name;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> {
|
||||||
|
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
|
||||||
|
u.username = username.to_string();
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> {
|
||||||
|
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
|
||||||
|
u.totp_status = *enabled;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error> {
|
||||||
|
Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).and_then(|u| u.totp_secret.clone()))
|
||||||
|
}
|
||||||
|
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error> {
|
||||||
|
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
|
||||||
|
u.totp_secret = secret;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error> {
|
||||||
|
Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).map(|u| u.passhash.clone()).unwrap())
|
||||||
|
}
|
||||||
|
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> {
|
||||||
|
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
|
||||||
|
u.passhash = pass_hash.to_string();
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> {
|
||||||
|
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
|
||||||
|
u.email = Some(email.to_string());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn set_role(&self, _id: i64, _role: &str) -> Result<(), sqlx::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MockTokenRepo {
|
||||||
|
pub tokens: Mutex<Vec<AccessToken>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AccessTokenRepoTrait for MockTokenRepo {
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
|
||||||
|
self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() })
|
||||||
|
}
|
||||||
|
async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime<Utc>, _expiry_date: chrono::DateTime<Utc>) -> Result<i64, sqlx::Error> {
|
||||||
|
let mut tokens = self.tokens.lock().unwrap();
|
||||||
|
let id = tokens.len() as i64 + 1;
|
||||||
|
tokens.push(AccessToken { id, code: code.to_string() });
|
||||||
|
Ok(id)
|
||||||
|
}
|
||||||
|
async fn use_token(&self, _id: i64) -> Result<(), sqlx::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error> {
|
||||||
|
Ok(self.tokens.lock().unwrap().iter().find(|t| t.code == code).map(|t| AccessToken { id: t.id, code: t.code.clone() }))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
use crate::model::auth::AccessToken;
|
||||||
|
use crate::model::user::User;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use crate::api::totp::TotpStatus;
|
||||||
|
use crate::model::space::Space;
|
||||||
|
|
||||||
|
pub mod user_repo;
|
||||||
|
pub mod message_repo;
|
||||||
|
pub mod access_token_repo;
|
||||||
|
pub mod space_repo;
|
||||||
|
pub mod channel_repo;
|
||||||
|
pub mod mock;
|
||||||
|
|
||||||
|
pub trait Repo: Clone + Send + Sync {
|
||||||
|
type Target;
|
||||||
|
|
||||||
|
fn new(pool: sqlx::PgPool) -> Self;
|
||||||
|
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<Self::Target>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
pub trait UserRepo: Send + Sync {
|
||||||
|
fn pool(&self) -> &sqlx::PgPool;
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<User>;
|
||||||
|
async fn save(&self, user: &User) -> Result<(), sqlx::Error>;
|
||||||
|
async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result<i64, sqlx::Error>;
|
||||||
|
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error>;
|
||||||
|
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error>;
|
||||||
|
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error>;
|
||||||
|
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error>;
|
||||||
|
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error>;
|
||||||
|
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error>;
|
||||||
|
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error>;
|
||||||
|
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error>;
|
||||||
|
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error>;
|
||||||
|
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error>;
|
||||||
|
async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
pub trait SpaceRepo: Send + Sync {
|
||||||
|
async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result<i64, sqlx::Error>;
|
||||||
|
async fn get_all(&self) -> Result<Vec<crate::model::space::Space>, sqlx::Error>;
|
||||||
|
async fn get_by_member(&self, uid: i64) -> Result<Vec<Space>, sqlx::Error>;
|
||||||
|
async fn get_by_id(&self, id: i64) -> Result<Option<crate::model::space::Space>, sqlx::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
pub trait ChannelRepo: Send + Sync {
|
||||||
|
async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result<i64, sqlx::Error>;
|
||||||
|
async fn get_by_space_id(&self, space_id: i64) -> Result<Vec<crate::model::space::Channel>, sqlx::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait AccessTokenRepoTrait: Send + Sync {
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<AccessToken>;
|
||||||
|
async fn create_new(&self,
|
||||||
|
uid: i64, name: &str, code: &str, max_uses: i32,
|
||||||
|
start_date: DateTime<Utc>, expiry_date: DateTime<Utc>
|
||||||
|
) -> Result<i64, sqlx::Error>;
|
||||||
|
async fn use_token(&self, id: i64) -> Result<(), sqlx::Error>;
|
||||||
|
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error>;
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
use crate::repo::SpaceRepo;
|
||||||
|
use crate::model::space::Space;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SpaceRepository {
|
||||||
|
pool: PgPool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SpaceRepository {
|
||||||
|
pub fn new(pool: PgPool) -> Self {
|
||||||
|
Self { pool }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
|
impl SpaceRepo for SpaceRepository {
|
||||||
|
async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result<i64, sqlx::Error> {
|
||||||
|
let row = sqlx::query!(
|
||||||
|
"INSERT INTO spaces (name, description, owner_id) VALUES ($1, $2, $3) RETURNING id",
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
owner_id
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(row.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_all(&self) -> Result<Vec<Space>, sqlx::Error> {
|
||||||
|
sqlx::query_as!(Space,
|
||||||
|
"SELECT id, name, description, owner_id, created_at, updated_at FROM spaces"
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_by_member(&self, uid: i64) -> Result<Vec<Space>, sqlx::Error> {
|
||||||
|
sqlx::query_as!(Space,
|
||||||
|
"SELECT s.id, s.name, s.description, s.created_at, s.updated_at, s.owner_id
|
||||||
|
FROM spaces s JOIN space_members sm ON s.id = sm.space_id
|
||||||
|
WHERE sm.user_id = $1",
|
||||||
|
uid
|
||||||
|
).fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_by_id(&self, id: i64) -> Result<Option<Space>, sqlx::Error> {
|
||||||
|
sqlx::query_as!(Space,
|
||||||
|
"SELECT id, name, description, owner_id, created_at, updated_at FROM spaces WHERE id = $1",
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,212 @@
|
|||||||
|
use crate::repo::{Repo, UserRepo};
|
||||||
|
use crate::model::user::User;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use crate::api::totp::TotpStatus;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct UserRepository {
|
||||||
|
pool: PgPool
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserRepository {
|
||||||
|
pub fn new(pool: PgPool) -> Self {
|
||||||
|
Self { pool }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pool(&self) -> &PgPool {
|
||||||
|
&self.pool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Repo for UserRepository {
|
||||||
|
type Target = User;
|
||||||
|
fn new(pool: PgPool) -> Self {
|
||||||
|
Self::new(pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
|
||||||
|
sqlx::query_as!(
|
||||||
|
User,
|
||||||
|
"SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE id = $1",
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::error!("Database error in get_by_id: {}", e);
|
||||||
|
e
|
||||||
|
})
|
||||||
|
.ok()?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl UserRepo for UserRepository {
|
||||||
|
fn pool(&self) -> &sqlx::PgPool {
|
||||||
|
&self.pool
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_by_id(&self, id: i64) -> Option<User> {
|
||||||
|
Repo::get_by_id(self, id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn save(&self, user: &User) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"UPDATE users SET email = $1, username = $2, nickname = $3, passhash = $4, totp_status = $5, totp_secret = $6, created_at = $7, updated_at = $8 WHERE id = $9",
|
||||||
|
user.email,
|
||||||
|
user.username,
|
||||||
|
user.nickname,
|
||||||
|
user.passhash,
|
||||||
|
user.totp_status as TotpStatus,
|
||||||
|
user.totp_secret,
|
||||||
|
user.created_at,
|
||||||
|
user.updated_at,
|
||||||
|
user.id
|
||||||
|
).execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn new_user(&self, email: &str, username: &str, passhash: &str) -> Result<i64, sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"INSERT INTO users (email, username, passhash) VALUES ($1, $2, $3) RETURNING id",
|
||||||
|
email,
|
||||||
|
username,
|
||||||
|
passhash
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error> {
|
||||||
|
sqlx::query_as!(
|
||||||
|
User,
|
||||||
|
"SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE username = $1",
|
||||||
|
username
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!("DELETE FROM users WHERE id = $1", id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"UPDATE users SET nickname = $1 WHERE id = $2",
|
||||||
|
display_name,
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"UPDATE users SET username = $1 WHERE id = $2",
|
||||||
|
username,
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"UPDATE users SET totp_status = $1 WHERE id = $2",
|
||||||
|
enabled as &TotpStatus,
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"SELECT totp_secret FROM users WHERE id = $1",
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
.map(|opt| opt.and_then(|row| row.totp_secret))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"UPDATE users SET totp_secret = $1 WHERE id = $2",
|
||||||
|
secret,
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"SELECT passhash FROM users WHERE id = $1",
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
.and_then(|row| row.map(|r| r.passhash).ok_or(sqlx::Error::RowNotFound))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"UPDATE users SET passhash = $1 WHERE id = $2",
|
||||||
|
pass_hash,
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query!(
|
||||||
|
"UPDATE users SET email = $1 WHERE id = $2",
|
||||||
|
email,
|
||||||
|
id
|
||||||
|
)
|
||||||
|
.execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error> {
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE users SET role = $1::user_role WHERE id = $2"
|
||||||
|
)
|
||||||
|
.bind(role)
|
||||||
|
.bind(id)
|
||||||
|
.execute(&self.pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use uuid::Uuid;
|
||||||
|
use crate::error::{ApiResult, AppError};
|
||||||
|
use crate::model::auth::AccessToken;
|
||||||
|
use crate::repo::access_token_repo::AccessTokenRepo;
|
||||||
|
use crate::repo::AccessTokenRepoTrait;
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AccessTokenService {
|
||||||
|
repo: Arc<dyn AccessTokenRepoTrait>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AccessTokenService {
|
||||||
|
pub fn new(repo: Arc<dyn AccessTokenRepoTrait>) -> Self {
|
||||||
|
Self { repo }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create(&self,
|
||||||
|
uid: i64, name: &str, max_uses: i32,
|
||||||
|
valid_from: DateTime<Utc>, valid_until: DateTime<Utc>
|
||||||
|
) -> ApiResult<String> {
|
||||||
|
if valid_from > valid_until {
|
||||||
|
return Err(AppError::bad_request("start date must be before end date"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if valid_until < Utc::now() {
|
||||||
|
return Err(AppError::bad_request("expiry date must be after current date"))
|
||||||
|
}
|
||||||
|
|
||||||
|
let code = Uuid::new_v4().to_string();
|
||||||
|
self.repo.create_new(uid, name, &code, max_uses, valid_from, valid_until).await?;
|
||||||
|
|
||||||
|
Ok(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_valid_token(&self, token: &str) -> ApiResult<AccessToken> {
|
||||||
|
self.repo.get_code_not_expired(token).await?
|
||||||
|
.ok_or(AppError::unauthorised("invalid access token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn use_token(&self, id: i64) -> ApiResult<()> {
|
||||||
|
self.repo.use_token(id).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,259 @@
|
|||||||
|
use crate::api::auth::{Claims, TokenScope};
|
||||||
|
use crate::api::totp::totp_gen;
|
||||||
|
use crate::error::{ApiResult, AppError};
|
||||||
|
use crate::model::auth::AuthResponse;
|
||||||
|
use crate::repo::{UserRepo, AccessTokenRepoTrait};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use argon2::password_hash::rand_core::OsRng;
|
||||||
|
use argon2::password_hash::SaltString;
|
||||||
|
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use uuid::Uuid;
|
||||||
|
use crate::api::totp::TotpStatus::{Disabled, Enabled};
|
||||||
|
use crate::svc::access_token_svc::AccessTokenService;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AuthService {
|
||||||
|
users: Arc<dyn UserRepo>,
|
||||||
|
tokens: AccessTokenService,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthService {
|
||||||
|
pub fn new(users: Arc<dyn UserRepo>, tokens: AccessTokenService) -> Self {
|
||||||
|
Self { users, tokens }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn signup(&self,
|
||||||
|
email: &str, username: &str,
|
||||||
|
password: &str, access_token: &str
|
||||||
|
) -> ApiResult<AuthResponse> {
|
||||||
|
let tok_id = self.tokens.get_valid_token(access_token).await?.id;
|
||||||
|
|
||||||
|
let pass = password.to_string();
|
||||||
|
let svc = self.clone();
|
||||||
|
let hashed = tokio::task::spawn_blocking(move || svc.hash_password(&pass))
|
||||||
|
.await
|
||||||
|
.map_err(|_| AppError::internal("blocking task panicked"))??;
|
||||||
|
|
||||||
|
let uid = self.users
|
||||||
|
.new_user(email, username, &hashed).await?;
|
||||||
|
|
||||||
|
self.tokens.use_token(tok_id).await?;
|
||||||
|
|
||||||
|
let jwt = Claims::new(uid as usize, TokenScope::Full).encode();
|
||||||
|
Ok(AuthResponse {
|
||||||
|
token: jwt,
|
||||||
|
totp_required: false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn login(&self, username: &str, password: &str) -> ApiResult<AuthResponse> {
|
||||||
|
let user = self.users
|
||||||
|
.get_by_username(username).await?
|
||||||
|
.ok_or(AppError::unauthorised("invalid username"))?;
|
||||||
|
|
||||||
|
let pass = password.to_string();
|
||||||
|
let user_hash = user.passhash.clone();
|
||||||
|
let svc = self.clone();
|
||||||
|
tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass))
|
||||||
|
.await
|
||||||
|
.map_err(|_| AppError::internal("blocking task panicked"))??;
|
||||||
|
|
||||||
|
let scope = if user.totp_status == Enabled { TokenScope::TotpPending } else { TokenScope::Full };
|
||||||
|
let jwt = Claims::new(user.id as usize, scope).encode();
|
||||||
|
|
||||||
|
Ok(AuthResponse {
|
||||||
|
token: jwt,
|
||||||
|
totp_required: user.totp_status == Enabled
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn login_totp(&self, uid: i64, code: &str) -> ApiResult<AuthResponse> {
|
||||||
|
let secret = self.users.get_totp_secret(uid).await?
|
||||||
|
.ok_or(AppError::unauthorised("2fa not enabled"))?;
|
||||||
|
|
||||||
|
self.verify_2fa(uid, &secret, code)?;
|
||||||
|
|
||||||
|
let jwt = Claims::new(uid as usize, TokenScope::Full).encode();
|
||||||
|
|
||||||
|
Ok(AuthResponse {
|
||||||
|
token: jwt,
|
||||||
|
totp_required: false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn disable_totp(&self, uid: i64, password: &str, totp_code: &str) -> ApiResult<AuthResponse> {
|
||||||
|
let mut user = self.users.get_by_id(uid).await
|
||||||
|
.ok_or(AppError::internal("user not found"))?;
|
||||||
|
|
||||||
|
let Some(secret) = user.totp_secret else {
|
||||||
|
return Err(AppError::bad_request("2fa not enabled"));
|
||||||
|
};
|
||||||
|
|
||||||
|
self.verify_password(&user.passhash, password)?;
|
||||||
|
self.verify_2fa(uid, &secret, totp_code)?;
|
||||||
|
|
||||||
|
user.totp_secret = None;
|
||||||
|
user.totp_status = Disabled;
|
||||||
|
self.users.save(&user).await?;
|
||||||
|
|
||||||
|
Ok(AuthResponse {
|
||||||
|
token: Claims::new(uid as usize, TokenScope::Full).encode(),
|
||||||
|
totp_required: false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_totp_status(&self, uid: i64) -> ApiResult<bool> {
|
||||||
|
Ok(
|
||||||
|
self.users.get_totp_secret(uid).await?.is_some()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn confirm_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> {
|
||||||
|
let secret = self.users.get_totp_secret(uid).await?
|
||||||
|
.ok_or(AppError::bad_request("2fa setup not initialised"))?;
|
||||||
|
|
||||||
|
self.verify_2fa(uid, &secret, totp_code)?;
|
||||||
|
|
||||||
|
self.users.set_twofa_enabled(uid, &Enabled).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_or_create_totp_secret(
|
||||||
|
&self, uid: i64, password: &str,
|
||||||
|
) -> ApiResult<String> {
|
||||||
|
let user = self.users.get_by_id(uid).await
|
||||||
|
.ok_or(AppError::internal("user not found"))?;
|
||||||
|
|
||||||
|
let pass = password.to_string();
|
||||||
|
let user_hash = user.passhash.clone();
|
||||||
|
let svc = self.clone();
|
||||||
|
tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass))
|
||||||
|
.await
|
||||||
|
.map_err(|_| AppError::internal("blocking task panicked"))??;
|
||||||
|
|
||||||
|
if let Some(secret) = user.totp_secret {
|
||||||
|
return Ok(secret);
|
||||||
|
}
|
||||||
|
|
||||||
|
let new_secret = totp_rs::Secret::generate_secret()
|
||||||
|
.to_encoded()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
self.users.set_totp_secret(uid, Some(new_secret.clone())).await?;
|
||||||
|
|
||||||
|
Ok(new_secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn verify_user_password(&self, uid: i64, password: &str) -> ApiResult<()> {
|
||||||
|
let hash = self.users.get_pass_hash(uid).await
|
||||||
|
.map_err(|_| AppError::internal("user not found"))?;
|
||||||
|
|
||||||
|
let pass = password.to_string();
|
||||||
|
let svc = self.clone();
|
||||||
|
tokio::task::spawn_blocking(move || svc.verify_password(&hash, &pass))
|
||||||
|
.await
|
||||||
|
.map_err(|_| AppError::internal("blocking task panicked"))??;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn verify_user_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> {
|
||||||
|
let secret = self.users.get_totp_secret(uid).await?
|
||||||
|
.ok_or(AppError::internal("user not found"))?;
|
||||||
|
|
||||||
|
self.verify_2fa(uid, &secret, totp_code)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub fn hash_password(&self, password: &str) -> ApiResult<String> {
|
||||||
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
|
Argon2::default()
|
||||||
|
.hash_password(password.as_bytes(), &salt)
|
||||||
|
.map_err(|_| AppError::internal("failed to hash password"))
|
||||||
|
.map(|hash| hash.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Private helpers
|
||||||
|
fn verify_password(&self, pass_hash: &str, password: &str) -> ApiResult<()> {
|
||||||
|
let parsed_hash = PasswordHash::new(&pass_hash)
|
||||||
|
.map_err(|_| AppError::internal("invalid password hash"))?;
|
||||||
|
|
||||||
|
Argon2::default()
|
||||||
|
.verify_password(password.as_bytes(), &parsed_hash)
|
||||||
|
.map_err(|_| AppError::unauthorised("incorrect password"))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn verify_2fa(&self, uid: i64, totp_secret: &str, totp_code: &str) -> ApiResult<()> {
|
||||||
|
if totp_gen(uid, totp_secret.as_bytes())
|
||||||
|
.map_err(|_| AppError::internal("invalid totp secret"))?
|
||||||
|
.check_current(totp_code)
|
||||||
|
.map_err(|_| AppError::internal("invalid totp code"))? {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(AppError::unauthorised("incorrect totp code"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::repo::mock::{MockUserRepo, MockTokenRepo};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
fn setup() -> AuthService {
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("JWT_SECRET", "test_secret");
|
||||||
|
}
|
||||||
|
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
|
||||||
|
let tok_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
|
||||||
|
let tokens = AccessTokenService::new(tok_repo);
|
||||||
|
AuthService::new(users, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_signup_and_login() {
|
||||||
|
let auth = setup();
|
||||||
|
let code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap();
|
||||||
|
|
||||||
|
let signup_res = auth.signup("test@example.com", "tester", "password123", &code).await;
|
||||||
|
assert!(signup_res.is_ok());
|
||||||
|
|
||||||
|
let login_res = auth.login("tester", "password123").await;
|
||||||
|
assert!(login_res.is_ok());
|
||||||
|
let login_data = login_res.unwrap();
|
||||||
|
assert!(!login_data.totp_required);
|
||||||
|
assert!(!login_data.token.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_login_invalid_password() {
|
||||||
|
let auth = setup();
|
||||||
|
let token_code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap();
|
||||||
|
auth.signup("test@example.com", "tester", "password123", &token_code).await.unwrap();
|
||||||
|
|
||||||
|
let login_res = auth.login("tester", "wrong_password").await;
|
||||||
|
assert!(login_res.is_err());
|
||||||
|
if let Err(AppError::Unauthorised(msg)) = login_res {
|
||||||
|
assert_eq!(msg, "incorrect password");
|
||||||
|
} else {
|
||||||
|
panic!("Expected Unauthorised error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invite() {
|
||||||
|
let auth = setup();
|
||||||
|
let res = auth.tokens.create(1, "invite", 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await;
|
||||||
|
assert!(res.is_ok());
|
||||||
|
let code = res.unwrap();
|
||||||
|
assert!(!code.is_empty());
|
||||||
|
|
||||||
|
let token = auth.tokens.get_valid_token(&code).await;
|
||||||
|
assert!(token.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,187 @@
|
|||||||
|
use crate::api::chat::ChatMsg;
|
||||||
|
use crate::error::{ApiResult, AppError};
|
||||||
|
use crate::repo::message_repo::MessageRepository;
|
||||||
|
use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::broadcast::Sender;
|
||||||
|
use tokio::sync::{broadcast, Mutex};
|
||||||
|
use crate::model::space::SpaceDto;
|
||||||
|
use crate::svc::llm_service::LlmService;
|
||||||
|
|
||||||
|
/// ---------- shared broadcaster ----------
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ChatService {
|
||||||
|
users: Arc<dyn UserRepo>,
|
||||||
|
channels: Arc<dyn ChannelRepo>,
|
||||||
|
spaces: Arc<dyn SpaceRepo>,
|
||||||
|
messages: MessageRepository,
|
||||||
|
|
||||||
|
llm: LlmService,
|
||||||
|
buffer_size: usize,
|
||||||
|
senders: Arc<Mutex<HashMap<i64, Sender<ChatMsg>>>>,
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatService {
|
||||||
|
pub fn new(
|
||||||
|
buffer_size: usize, llm: LlmService,
|
||||||
|
messages: MessageRepository, users: Arc<dyn UserRepo>,
|
||||||
|
channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
channels,
|
||||||
|
spaces,
|
||||||
|
llm,
|
||||||
|
users,
|
||||||
|
messages,
|
||||||
|
buffer_size,
|
||||||
|
senders: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_accessible_channels(&self, uid: i64) -> ApiResult<Vec<SpaceDto>> {
|
||||||
|
// let spaces = self.spaces.get_by_member(uid).await?;
|
||||||
|
// TODO! UNCOMMENT THIS ^^^^^^
|
||||||
|
let spaces = self.spaces.get_all().await?;
|
||||||
|
|
||||||
|
let mut result = Vec::new();
|
||||||
|
for space in spaces {
|
||||||
|
let channels = self.channels.get_by_space_id(space.id).await?;
|
||||||
|
result.push(SpaceDto {
|
||||||
|
channels,
|
||||||
|
id: space.id,
|
||||||
|
owner_id: space.owner_id,
|
||||||
|
name: space.name,
|
||||||
|
description: space.description,
|
||||||
|
created_at: space.created_at,
|
||||||
|
updated_at: space.updated_at,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
|
||||||
|
let messages = self.messages.get_by_channel(channel_id, limit).await?;
|
||||||
|
Ok(messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends a chat message to the specified channel, persists it to the database,
|
||||||
|
/// and handles potential AI-generated replies asynchronously.
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
/// - `channel_id`: The ID of the channel to which the message will be sent.
|
||||||
|
/// - `uid`: The user ID of the sender.
|
||||||
|
/// - `text`: The content of the message to be sent.
|
||||||
|
/// - `created_at`: The timestamp at which the message was created.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// - `ApiResult<()>`: Indicates success or failure of the operation.
|
||||||
|
///
|
||||||
|
/// # Behavior
|
||||||
|
/// 1. Fetches the user by their `uid`. Returns an error if the user is not found.
|
||||||
|
/// 2. Constructs a `ChatMsg` object with the sender's `display_name` or `username`,
|
||||||
|
/// and the specified message content and timestamp.
|
||||||
|
/// 3. Publishes the constructed message to the given channel.
|
||||||
|
/// 4. Persists the message in the database.
|
||||||
|
/// 5. Spawns an asynchronous task to generate an LLM-powered (language model) reply:
|
||||||
|
/// - Sends the original message to the LLM worker for a potential reply.
|
||||||
|
/// - Publishes the LLM's reply to the same channel if successful.
|
||||||
|
/// - Persists the LLM's reply to the database.
|
||||||
|
///
|
||||||
|
/// # Notes
|
||||||
|
/// - Caching with Redis is planned for both message persistence and AI replies, but
|
||||||
|
/// is not implemented in the current version.
|
||||||
|
/// - The spawned asynchronous task does not block the main execution flow.
|
||||||
|
///
|
||||||
|
/// # Potential Errors
|
||||||
|
/// - Returns `AppError::NotFound` if the `uid` does not map to an existing user.
|
||||||
|
/// - Returns an error wrapped in `ApiResult` if the database operations fail.
|
||||||
|
///
|
||||||
|
/// # TODO
|
||||||
|
/// - Implement caching for both user-supplied messages and LLM-generated replies
|
||||||
|
/// using Redis at the repository or service layer.
|
||||||
|
pub async fn send(&self,
|
||||||
|
channel_id: i64, uid: i64,
|
||||||
|
text: &str, created_at: DateTime<Utc>
|
||||||
|
) -> ApiResult<()> {
|
||||||
|
let user = self.users.get_by_id(uid).await
|
||||||
|
.ok_or(AppError::NotFound)?;
|
||||||
|
|
||||||
|
let message = ChatMsg {
|
||||||
|
display_name: Some(user
|
||||||
|
.nickname.clone()
|
||||||
|
.unwrap_or_else(|| user.username.clone())),
|
||||||
|
user_id: uid,
|
||||||
|
text: text.to_string(),
|
||||||
|
timestamp: created_at,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.publish(channel_id, message.clone()).await;
|
||||||
|
|
||||||
|
let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?;
|
||||||
|
// TODO: caching w redis at repository layer
|
||||||
|
|
||||||
|
let svc_instance = self.clone();
|
||||||
|
|
||||||
|
let Some(text) = text.strip_prefix("/ask ") else {
|
||||||
|
return Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
if !svc_instance.llm.enabled() {
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let response = svc_instance.llm
|
||||||
|
.query(&message)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Ok(reply) = response {
|
||||||
|
|
||||||
|
tracing::info!("LLM reply: {}", reply.text);
|
||||||
|
|
||||||
|
svc_instance.publish(channel_id, reply.clone()).await;
|
||||||
|
// TODO: cache response (or do with redis!)
|
||||||
|
if let Err(e) = svc_instance.messages
|
||||||
|
.create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await {
|
||||||
|
tracing::error!("Failed to persist LLM reply: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("LLM reply persisted");
|
||||||
|
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Error contacting LLM: {:?}", response);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Subscribe to the specified channel.
|
||||||
|
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatMsg> {
|
||||||
|
let mut map = self.senders.lock().await;
|
||||||
|
let sender = map
|
||||||
|
.entry(channel_id)
|
||||||
|
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
||||||
|
sender.subscribe()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Private helper methods
|
||||||
|
|
||||||
|
/// Publish a message to the specified channel.
|
||||||
|
async fn publish(&self, channel_id: i64, msg: ChatMsg) {
|
||||||
|
let mut map = self.senders.lock().await;
|
||||||
|
let sender = map
|
||||||
|
.entry(channel_id)
|
||||||
|
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
||||||
|
let _ = sender.send(msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
#[derive(Clone)]
|
||||||
|
pub struct LlmService;
|
||||||
|
|
||||||
|
|
||||||
|
static LMSTUDIO_URL: LazyLock<Option<String>> = LazyLock::new(|| env::var("LMSTUDIO_URL").ok());
|
||||||
|
static LMSTUDIO_MODEL: LazyLock<Option<String>> = LazyLock::new(|| env::var("LMSTUDIO_MODEL").ok());
|
||||||
|
|
||||||
|
impl LlmService {
|
||||||
|
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn enabled(&self) -> bool {
|
||||||
|
LMSTUDIO_URL.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn query(&self, message: &ChatMsg) -> ApiResult<ChatMsg> {
|
||||||
|
let Some(url) = LMSTUDIO_URL.clone() else {
|
||||||
|
return Err(AppError::internal("AI not enabled!"))
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into());
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
// Build the request body
|
||||||
|
let payload = LlmRequest {
|
||||||
|
model, // whatever model you run locally
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".into(),
|
||||||
|
content: message.text.clone(),
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
|
// POST to lm‑studio (default 127.0.0.1:1234)
|
||||||
|
let resp = client
|
||||||
|
.post(url)
|
||||||
|
.json(&payload)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|_| AppError::internal("Failed to make request to LLM server"))?;
|
||||||
|
|
||||||
|
// The API returns a JSON with `choices[].message.content`
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct LlmResponse {
|
||||||
|
choices: Vec<Choice>,
|
||||||
|
}
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Choice {
|
||||||
|
message: Message,
|
||||||
|
}
|
||||||
|
|
||||||
|
let llm_resp: LlmResponse = resp
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|_| AppError::internal("Failed to parse LLM response"))?;
|
||||||
|
|
||||||
|
Ok(ChatMsg {
|
||||||
|
display_name: Some(String::from("llm")),
|
||||||
|
user_id: 0,
|
||||||
|
text: llm_resp.choices[0].message.content.clone(),
|
||||||
|
timestamp: chrono::Utc::now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use std::env;
|
||||||
|
use std::sync::LazyLock;
|
||||||
|
// src/llm.rs
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::api::chat::ChatMsg;
|
||||||
|
use crate::error::{ApiResult, AppError};
|
||||||
|
use crate::svc::chat_svc::ChatService;
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct LlmRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<Message>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct Message {
|
||||||
|
role: String, // "user" or "assistant"
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
pub mod auth_svc;
|
||||||
|
pub mod chat_svc;
|
||||||
|
pub mod settings_svc;
|
||||||
|
pub mod user_svc;
|
||||||
|
pub mod access_token_svc;
|
||||||
|
pub mod llm_service;
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
//! The `SettingsService` is responsible for managing user account settings, allowing users to
|
||||||
|
//! update their username, password, display name, email, and delete their account.
|
||||||
|
//! It interacts with the `AuthService` to handle authentication and password-related functionality
|
||||||
|
//! and the `UserRepository` to perform updates to user accounts in the data store.
|
||||||
|
|
||||||
|
use crate::error::{ApiResult, AppError};
|
||||||
|
use crate::repo::UserRepo;
|
||||||
|
use crate::svc::auth_svc::AuthService;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SettingsService {
|
||||||
|
auth: AuthService,
|
||||||
|
users: Arc<dyn UserRepo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SettingsService {
|
||||||
|
pub fn new(auth: AuthService, users: Arc<dyn UserRepo>) -> Self {
|
||||||
|
Self { auth, users }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn change_username(&self, uid: i64, new: &str) -> ApiResult<()> {
|
||||||
|
self.users.set_username(uid, new).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn change_password(&self, uid: i64, old: &str, new: &str) -> ApiResult<()> {
|
||||||
|
self.auth.verify_user_password(uid, old).await?;
|
||||||
|
let hashed = self.auth.hash_password(new)?;
|
||||||
|
self.users.set_pass_hash(uid, &hashed).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn change_display_name(&self, uid: i64, new: Option<String>) -> ApiResult<()> {
|
||||||
|
self.users.set_display_name(uid, new).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn change_email(&self, uid: i64, new: &str) -> ApiResult<()> {
|
||||||
|
self.users.set_email(uid, new).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_account(&self, uid: i64, password: &str, totp_code: &Option<String>) -> ApiResult<()> {
|
||||||
|
self.auth.verify_user_password(uid, password).await?;
|
||||||
|
|
||||||
|
// check 2fa code is correct if enabled
|
||||||
|
if self.auth.get_totp_status(uid).await? {
|
||||||
|
|
||||||
|
let Some(totp_code) = totp_code else {
|
||||||
|
return Err(AppError::unauthorised("2fa code is required"))
|
||||||
|
};
|
||||||
|
|
||||||
|
self.auth.verify_user_totp(uid, totp_code).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.users.delete_by_id(uid).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::repo::mock::{MockUserRepo, MockTokenRepo};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use chrono::Utc;
|
||||||
|
use crate::svc::access_token_svc::AccessTokenService;
|
||||||
|
|
||||||
|
fn setup() -> SettingsService {
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("JWT_SECRET", "test_secret");
|
||||||
|
}
|
||||||
|
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
|
||||||
|
let token_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
|
||||||
|
let tokens_svc = AccessTokenService::new(token_repo);
|
||||||
|
let auth = AuthService::new(users.clone(), tokens_svc.clone());
|
||||||
|
SettingsService::new(auth, users)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_change_username() {
|
||||||
|
let settings = setup();
|
||||||
|
let uid = settings.users.new_user("test@example.com", "old", "pass").await.unwrap();
|
||||||
|
|
||||||
|
settings.change_username(uid, "new").await.unwrap();
|
||||||
|
let user = settings.users.get_by_id(uid).await.unwrap();
|
||||||
|
assert_eq!(user.username, "new");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_change_password() {
|
||||||
|
let settings = setup();
|
||||||
|
let pass = "old_pass";
|
||||||
|
let hashed = settings.auth.hash_password(pass).unwrap();
|
||||||
|
let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap();
|
||||||
|
|
||||||
|
settings.change_password(uid, pass, "new_pass").await.unwrap();
|
||||||
|
let _user = settings.users.get_by_id(uid).await.unwrap();
|
||||||
|
assert!(settings.auth.verify_user_password(uid, "new_pass").await.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_account() {
|
||||||
|
let settings = setup();
|
||||||
|
let pass = "password";
|
||||||
|
let hashed = settings.auth.hash_password(pass).unwrap();
|
||||||
|
let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap();
|
||||||
|
|
||||||
|
let res = settings.delete_account(uid, pass, &None).await;
|
||||||
|
assert!(res.is_ok());
|
||||||
|
|
||||||
|
let user = settings.users.get_by_id(uid).await;
|
||||||
|
assert!(user.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
use crate::error::ApiResult;
|
||||||
|
use crate::repo::UserRepo;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub struct UserService {
|
||||||
|
repo: Arc<dyn UserRepo>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserService {
|
||||||
|
pub fn new(repo: Arc<dyn UserRepo>) -> Self {
|
||||||
|
Self { repo }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_display_name(&self, uid: i64) -> ApiResult<String> {
|
||||||
|
// TODO: redis caching for display names
|
||||||
|
|
||||||
|
let user = self.repo.get_by_id(uid)
|
||||||
|
.await.ok_or(crate::error::AppError::NotFound)?;
|
||||||
|
|
||||||
|
Ok(user.nickname.unwrap_or_else(|| user.username))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_username(&self, uid: i64) -> ApiResult<String> {
|
||||||
|
self.repo.get_by_id(uid)
|
||||||
|
.await.ok_or(crate::error::AppError::NotFound)
|
||||||
|
.map(|u| u.username)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,188 +0,0 @@
|
|||||||
use argon2::{Argon2, PasswordHash, PasswordVerifier};
|
|
||||||
use redis::AsyncCommands;
|
|
||||||
use rocket::{http::Status, serde::json::Json, time::OffsetDateTime};
|
|
||||||
use rocket_db_pools::Connection;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
auth::{Session, two_factor::totp_gen},
|
|
||||||
db::{Postgres, Redis},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub struct User {
|
|
||||||
pub id: i32,
|
|
||||||
pub email: Option<String>,
|
|
||||||
pub username: String,
|
|
||||||
pub display_name: Option<String>,
|
|
||||||
pub pass_hash: String,
|
|
||||||
pub twofa_enabled: bool,
|
|
||||||
pub totp_secret: Option<String>,
|
|
||||||
pub created_at: Option<OffsetDateTime>,
|
|
||||||
pub updated_at: Option<OffsetDateTime>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl User {
|
|
||||||
pub async fn get_by_id(id: usize, db: &mut Connection<Postgres>) -> Option<Self> {
|
|
||||||
sqlx::query_as!(
|
|
||||||
Self,
|
|
||||||
"SELECT id, email, username, display_name, pass_hash, twofa_enabled, totp_secret, created_at, updated_at FROM users WHERE id = $1",
|
|
||||||
id as i32
|
|
||||||
)
|
|
||||||
.fetch_optional(&mut ***db)
|
|
||||||
.await
|
|
||||||
.unwrap_or(None)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete(&mut self, db: &mut Connection<Postgres>) -> Result<(), sqlx::Error> {
|
|
||||||
sqlx::query!("DELETE FROM users WHERE id = $1", self.id)
|
|
||||||
.execute(&mut ***db)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn verify_2fa(&self, code: &str) -> Result<(), Status> {
|
|
||||||
if totp_gen(
|
|
||||||
self.id as usize,
|
|
||||||
self.totp_secret
|
|
||||||
.clone()
|
|
||||||
.expect("user with 2fa enabled has no totp secret")
|
|
||||||
.as_bytes(),
|
|
||||||
)
|
|
||||||
.map_err(|_| Status::InternalServerError)?
|
|
||||||
.check_current(code)
|
|
||||||
.map_err(|_| Status::InternalServerError)?
|
|
||||||
{
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(Status::Unauthorized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn verify_password(&self, password: &str) -> Result<(), Status> {
|
|
||||||
let parsed_hash = PasswordHash::new(&self.pass_hash)
|
|
||||||
.inspect_err(|e| {
|
|
||||||
tracing::error!("Failed to parse hash for password! uid:{} {e}", self.id)
|
|
||||||
})
|
|
||||||
.map_err(|_| Status::InternalServerError)?;
|
|
||||||
|
|
||||||
Argon2::default()
|
|
||||||
.verify_password(password.as_bytes(), &parsed_hash)
|
|
||||||
.map_err(|_| Status::Unauthorized)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_display_name(
|
|
||||||
&mut self,
|
|
||||||
display_name: Option<String>,
|
|
||||||
db: &mut Connection<Postgres>,
|
|
||||||
) -> Result<(), sqlx::Error> {
|
|
||||||
self.display_name = display_name;
|
|
||||||
sqlx::query!(
|
|
||||||
"UPDATE users SET display_name = $1 WHERE id = $2",
|
|
||||||
self.display_name,
|
|
||||||
self.id
|
|
||||||
)
|
|
||||||
.execute(&mut ***db)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_username(
|
|
||||||
&mut self,
|
|
||||||
username: String,
|
|
||||||
db: &mut Connection<Postgres>,
|
|
||||||
) -> Result<(), sqlx::Error> {
|
|
||||||
self.username = username;
|
|
||||||
sqlx::query!(
|
|
||||||
"UPDATE users SET username = $1 WHERE id = $2",
|
|
||||||
self.username,
|
|
||||||
self.id
|
|
||||||
)
|
|
||||||
.execute(&mut ***db)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_twofa_enabled(
|
|
||||||
&mut self,
|
|
||||||
enabled: bool,
|
|
||||||
db: &mut Connection<Postgres>,
|
|
||||||
) -> Result<(), sqlx::Error> {
|
|
||||||
self.twofa_enabled = enabled;
|
|
||||||
sqlx::query!(
|
|
||||||
"UPDATE users SET twofa_enabled = $1 WHERE id = $2",
|
|
||||||
self.twofa_enabled,
|
|
||||||
self.id
|
|
||||||
)
|
|
||||||
.execute(&mut ***db)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_pass_hash(
|
|
||||||
&mut self,
|
|
||||||
pass_hash: String,
|
|
||||||
db: &mut Connection<Postgres>,
|
|
||||||
) -> Result<(), sqlx::Error> {
|
|
||||||
self.pass_hash = pass_hash;
|
|
||||||
sqlx::query!(
|
|
||||||
"UPDATE users SET pass_hash = $1 WHERE id = $2",
|
|
||||||
self.pass_hash,
|
|
||||||
self.id
|
|
||||||
)
|
|
||||||
.execute(&mut ***db)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/users", rank = 2)]
|
|
||||||
pub async fn users(_ag: Session, mut db: Connection<Postgres>) -> Json<Vec<i32>> {
|
|
||||||
sqlx::query!("SELECT id FROM users")
|
|
||||||
.fetch_all(&mut **db)
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|_| Vec::new())
|
|
||||||
.into_iter()
|
|
||||||
.map(|row| row.id)
|
|
||||||
.collect::<Vec<i32>>()
|
|
||||||
.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/users/<id>", rank = 1)]
|
|
||||||
pub async fn display_name(
|
|
||||||
id: usize,
|
|
||||||
_ag: Session,
|
|
||||||
mut pgsql_conn: Connection<Postgres>,
|
|
||||||
mut redis_conn: Connection<Redis>,
|
|
||||||
) -> String {
|
|
||||||
UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct UserCache {}
|
|
||||||
|
|
||||||
impl UserCache {
|
|
||||||
pub async fn username(
|
|
||||||
id: usize,
|
|
||||||
redis_conn: &mut Connection<Redis>,
|
|
||||||
pgsql_conn: &mut Connection<Postgres>,
|
|
||||||
) -> String {
|
|
||||||
if let Ok(val) = redis_conn.get(format!("users:{id}")).await {
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
|
|
||||||
.fetch_one(&mut ***pgsql_conn)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
let username = v.username;
|
|
||||||
Self::insert(id, &username, redis_conn).await;
|
|
||||||
username
|
|
||||||
} else {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
|
|
||||||
conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800)
|
|
||||||
.await
|
|
||||||
.expect("failed to insert key");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
use backend::rocket_builder;
|
||||||
|
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
|
||||||
|
use backend::repo::message_repo::MessageRepository;
|
||||||
|
use backend::svc::chat_svc::ChatService;
|
||||||
|
use backend::repo::user_repo::UserRepository;
|
||||||
|
use backend::repo::{Repo, AccessTokenRepoTrait};
|
||||||
|
use rocket::local::asynchronous::Client;
|
||||||
|
use rocket::http::{Status, ContentType};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use chrono::Utc;
|
||||||
|
use backend::svc::llm_service::LlmService;
|
||||||
|
|
||||||
|
async fn test_rocket() -> rocket::Rocket<rocket::Build> {
|
||||||
|
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
|
||||||
|
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
|
||||||
|
|
||||||
|
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
|
||||||
|
let messages = MessageRepository::new(pool.clone());
|
||||||
|
let user_repo = Arc::new(UserRepository::new(pool));
|
||||||
|
let llm_service = LlmService::new();
|
||||||
|
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
|
||||||
|
|
||||||
|
rocket_builder(users, tokens, chat_service)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_unauthorized_access() {
|
||||||
|
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
|
||||||
|
|
||||||
|
// Attempt to access a protected endpoint without authentication
|
||||||
|
let response = client.patch("/api/settings/display_name").dispatch().await;
|
||||||
|
assert_eq!(response.status(), Status::Unauthorized);
|
||||||
|
|
||||||
|
let response = client.post("/api/settings/password").dispatch().await;
|
||||||
|
assert_eq!(response.status(), Status::Unauthorized);
|
||||||
|
|
||||||
|
let response = client.delete("/api/settings").dispatch().await;
|
||||||
|
assert_eq!(response.status(), Status::Unauthorized);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_signup_invalid_token() {
|
||||||
|
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
|
||||||
|
|
||||||
|
let signup_data = json!({
|
||||||
|
"email": "test@example.com",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123",
|
||||||
|
"access_token": "invalid-token"
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = client.post("/api/signup")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(signup_data.to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Unauthorized);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_login_invalid_credentials() {
|
||||||
|
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
|
||||||
|
|
||||||
|
let login_data = json!({
|
||||||
|
"username": "nonexistent",
|
||||||
|
"password": "wrongpassword"
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = client.post("/api/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(login_data.to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Unauthorized);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_full_auth_flow() {
|
||||||
|
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
|
||||||
|
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
|
||||||
|
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
|
||||||
|
let messages = MessageRepository::new(pool.clone());
|
||||||
|
let user_repo = Arc::new(UserRepository::new(pool));
|
||||||
|
let llm_service = LlmService::new();
|
||||||
|
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
|
||||||
|
|
||||||
|
let token_code = "valid-token";
|
||||||
|
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
|
||||||
|
|
||||||
|
let client = Client::tracked(rocket_builder(users, tokens, chat_service)).await.expect("valid rocket instance");
|
||||||
|
|
||||||
|
// 1. Signup
|
||||||
|
let signup_data = json!({
|
||||||
|
"email": "test@example.com",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123",
|
||||||
|
"access_token": token_code
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = client.post("/api/signup")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(signup_data.to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
let body = response.into_string().await.unwrap();
|
||||||
|
assert!(body.contains("token"));
|
||||||
|
|
||||||
|
// 2. Login
|
||||||
|
let login_data = json!({
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123"
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = client.post("/api/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(login_data.to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
let body = response.into_string().await.unwrap();
|
||||||
|
assert!(body.contains("token"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_delete_account_security() {
|
||||||
|
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
|
||||||
|
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
|
||||||
|
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
|
||||||
|
let messages = MessageRepository::new(pool.clone());
|
||||||
|
let user_repo = Arc::new(UserRepository::new(pool));
|
||||||
|
let llm_service = LlmService::new();
|
||||||
|
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
|
||||||
|
|
||||||
|
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
|
||||||
|
|
||||||
|
let token_code = "valid-token";
|
||||||
|
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
|
||||||
|
|
||||||
|
client.post("/api/signup")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"email": "test@example.com",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123",
|
||||||
|
"access_token": token_code
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Login to get JWT
|
||||||
|
let login_res = client.post("/api/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap();
|
||||||
|
let jwt = auth_resp["token"].as_str().unwrap();
|
||||||
|
|
||||||
|
// 1. Delete with WRONG password
|
||||||
|
let response = client.delete("/api/settings")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt)))
|
||||||
|
.body(json!({
|
||||||
|
"password": "wrongpassword",
|
||||||
|
"totp_code": null
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Unauthorized);
|
||||||
|
|
||||||
|
// 2. Delete with CORRECT password
|
||||||
|
let response = client.delete("/api/settings")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt)))
|
||||||
|
.body(json!({
|
||||||
|
"password": "password123",
|
||||||
|
"totp_code": null
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
|
||||||
|
// Verify user is gone
|
||||||
|
assert!(users.users.lock().unwrap().is_empty());
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
use backend::rocket_builder;
|
||||||
|
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
|
||||||
|
use backend::repo::message_repo::MessageRepository;
|
||||||
|
use backend::svc::chat_svc::ChatService;
|
||||||
|
use backend::repo::{Repo, AccessTokenRepoTrait};
|
||||||
|
use rocket::local::asynchronous::Client;
|
||||||
|
use rocket::http::{Status, ContentType, Header};
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use chrono::Utc;
|
||||||
|
use backend::svc::llm_service::LlmService;
|
||||||
|
|
||||||
|
async fn setup_client_with_svc(chat_service: ChatService, users: Arc<MockUserRepo>, tokens: Arc<MockTokenRepo>) -> (Client, String) {
|
||||||
|
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
|
||||||
|
|
||||||
|
// Create a user and get JWT
|
||||||
|
let token_code = "valid-token";
|
||||||
|
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
|
||||||
|
|
||||||
|
let jwt = {
|
||||||
|
let signup_res = client.post("/api/signup")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"email": "test@example.com",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123",
|
||||||
|
"access_token": token_code
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
assert_eq!(signup_res.status(), Status::Ok);
|
||||||
|
|
||||||
|
let login_res = client.post("/api/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(login_res.status(), Status::Ok, "Login failed");
|
||||||
|
|
||||||
|
let body = login_res.into_string().await.expect("login body");
|
||||||
|
let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap();
|
||||||
|
auth_resp["token"].as_str().unwrap().to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
(client, jwt)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_chat_event_stream_consistency() {
|
||||||
|
unsafe { std::env::set_var("JWT_SECRET", "test_secret"); }
|
||||||
|
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
|
||||||
|
let messages = <MessageRepository as Repo>::new(pool.clone());
|
||||||
|
let users_repo = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
|
||||||
|
let tokens_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
|
||||||
|
let llm_service = LlmService::new();
|
||||||
|
let chat_service = ChatService::new(1024, messages, users_repo.clone(), llm_service);
|
||||||
|
|
||||||
|
let (client, jwt) = setup_client_with_svc(chat_service.clone(), users_repo.clone(), tokens_repo.clone()).await;
|
||||||
|
|
||||||
|
// Use the same client for sender but with a different user (or the same, doesn't matter for broadcast)
|
||||||
|
// Actually, to simulate another user, we should sign up another user.
|
||||||
|
let jwt_sender = {
|
||||||
|
let token_code = "valid-token-2";
|
||||||
|
tokens_repo.create_new(1, "test2", token_code, 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await.unwrap();
|
||||||
|
let signup_res = client.post("/api/signup")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"email": "test2@example.com",
|
||||||
|
"username": "testuser2",
|
||||||
|
"password": "password123",
|
||||||
|
"access_token": token_code
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
assert_eq!(signup_res.status(), Status::Ok);
|
||||||
|
let login_res = client.post("/api/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"username": "testuser2",
|
||||||
|
"password": "password123"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
let body = login_res.into_string().await.unwrap();
|
||||||
|
let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap();
|
||||||
|
auth_resp["token"].as_str().unwrap().to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let channel_id = 1;
|
||||||
|
|
||||||
|
// Start listening to the event stream
|
||||||
|
let mut response = client.get(format!("/api/events/{}", channel_id))
|
||||||
|
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
|
||||||
|
let num_messages = 5; // Reduced for faster debugging
|
||||||
|
let mut received_count = 0;
|
||||||
|
|
||||||
|
let jwt_clone = jwt.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
for i in 0..num_messages {
|
||||||
|
let msg = format!("Message {}", i);
|
||||||
|
let res = sender_client.post(format!("/api/chat/{}", channel_id))
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.header(Header::new("Authorization", format!("Bearer {}", jwt_clone)))
|
||||||
|
.body(json!({
|
||||||
|
"display_name": "testuser",
|
||||||
|
"user_id": 1,
|
||||||
|
"text": msg,
|
||||||
|
"timestamp": Utc::now()
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
assert_eq!(res.status(), Status::Ok);
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait a bit for messages to be posted
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||||
|
|
||||||
|
// Consume the stream
|
||||||
|
let text = response.into_string().await.unwrap();
|
||||||
|
println!("Received chunk: {}", text);
|
||||||
|
let mut received_count = 0;
|
||||||
|
for line in text.lines() {
|
||||||
|
if line.starts_with("data:") {
|
||||||
|
received_count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(received_count, num_messages, "Should receive all posted messages. Received: {}. Full text: {}", received_count, text);
|
||||||
|
}
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
use backend::rocket_builder;
|
||||||
|
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
|
||||||
|
use backend::repo::message_repo::MessageRepository;
|
||||||
|
use backend::svc::chat_svc::ChatService;
|
||||||
|
use backend::repo::user_repo::UserRepository;
|
||||||
|
use backend::repo::{Repo, AccessTokenRepoTrait};
|
||||||
|
use rocket::local::asynchronous::Client;
|
||||||
|
use rocket::http::{Status, ContentType, Header};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use chrono::Utc;
|
||||||
|
use backend::svc::llm_service::LlmService;
|
||||||
|
|
||||||
|
async fn setup_client() -> (Client, Arc<MockUserRepo>, String) {
|
||||||
|
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
|
||||||
|
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
|
||||||
|
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
|
||||||
|
let messages = MessageRepository::new(pool.clone());
|
||||||
|
let user_repo = Arc::new(UserRepository::new(pool));
|
||||||
|
let llm_service = LlmService::new();
|
||||||
|
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
|
||||||
|
|
||||||
|
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
|
||||||
|
|
||||||
|
// Create a user and get JWT
|
||||||
|
let token_code = "valid-token";
|
||||||
|
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
|
||||||
|
|
||||||
|
client.post("/api/signup")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"email": "test@example.com",
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123",
|
||||||
|
"access_token": token_code
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let login_res = client.post("/api/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "password123"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap();
|
||||||
|
let jwt = auth_resp["token"].as_str().unwrap().to_string();
|
||||||
|
|
||||||
|
(client, users, jwt)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_change_display_name() {
|
||||||
|
let (client, users, jwt) = setup_client().await;
|
||||||
|
|
||||||
|
let response = client.patch("/api/settings/display_name")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
|
||||||
|
.body(json!({
|
||||||
|
"display_name": "New Display Name"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
|
||||||
|
let user = users.users.lock().unwrap()[0].clone();
|
||||||
|
assert_eq!(user.nickname, Some("New Display Name".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_change_username() {
|
||||||
|
let (client, users, jwt) = setup_client().await;
|
||||||
|
|
||||||
|
let response = client.patch("/api/settings/username")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
|
||||||
|
.body(json!({
|
||||||
|
"username": "newusername"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
|
||||||
|
let user = users.users.lock().unwrap()[0].clone();
|
||||||
|
assert_eq!(user.username, "newusername");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rocket::async_test]
|
||||||
|
async fn test_change_password() {
|
||||||
|
let (client, _, jwt) = setup_client().await;
|
||||||
|
|
||||||
|
let response = client.post("/api/settings/password")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
|
||||||
|
.body(json!({
|
||||||
|
"old_password": "password123",
|
||||||
|
"new_password": "newpassword456"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(response.status(), Status::Ok);
|
||||||
|
|
||||||
|
// Verify login with new password
|
||||||
|
let login_res = client.post("/api/login")
|
||||||
|
.header(ContentType::JSON)
|
||||||
|
.body(json!({
|
||||||
|
"username": "testuser",
|
||||||
|
"password": "newpassword456"
|
||||||
|
}).to_string())
|
||||||
|
.dispatch()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(login_res.status(), Status::Ok);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user