minor refactoring
This commit is contained in:
@@ -6,6 +6,7 @@ edition = "2024"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
argon2 = "0.5.3"
|
argon2 = "0.5.3"
|
||||||
chrono = { version = "0.4.42", features = ["serde"] }
|
chrono = { version = "0.4.42", features = ["serde"] }
|
||||||
|
dotenv = "0.15.0"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
image = "0.25.8"
|
image = "0.25.8"
|
||||||
rand = "0.9.2"
|
rand = "0.9.2"
|
||||||
|
|||||||
+1
-1
@@ -1,7 +1,7 @@
|
|||||||
// src/llm.rs
|
// src/llm.rs
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::messages::ChatMsg;
|
use crate::messenger::ChatMsg;
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct LlmRequest {
|
struct LlmRequest {
|
||||||
|
|||||||
+16
-72
@@ -10,7 +10,8 @@ use rocket::{Build, Rocket};
|
|||||||
use rocket_cors::{AllowedOrigins, CorsOptions};
|
use rocket_cors::{AllowedOrigins, CorsOptions};
|
||||||
use rocket_db_pools::{Connection, Database};
|
use rocket_db_pools::{Connection, Database};
|
||||||
use rocket_dyn_templates::Template;
|
use rocket_dyn_templates::Template;
|
||||||
use std::sync::Arc;
|
use std::env;
|
||||||
|
use std::sync::{Arc, LazyLock};
|
||||||
|
|
||||||
use crate::auth::Session;
|
use crate::auth::Session;
|
||||||
use crate::db::{Postgres, Redis};
|
use crate::db::{Postgres, Redis};
|
||||||
@@ -20,33 +21,18 @@ pub mod cdn;
|
|||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod handlers;
|
pub mod handlers;
|
||||||
pub mod llm;
|
pub mod llm;
|
||||||
pub mod messages;
|
pub mod messenger;
|
||||||
|
pub mod user;
|
||||||
|
|
||||||
#[get("/users", rank = 2)]
|
static LMSTUDIO_URL: LazyLock<String> =
|
||||||
async fn users(_ag: Session, mut db: Connection<Postgres>) -> Json<Vec<i32>> {
|
LazyLock::new(|| env::var("LMSTUDIO_URL").expect("Ensure LMSTUDIO_URL is set!"));
|
||||||
sqlx::query!("SELECT id FROM users")
|
|
||||||
.fetch_all(&mut **db)
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|_| Vec::new())
|
|
||||||
.into_iter()
|
|
||||||
.map(|row| row.id)
|
|
||||||
.collect::<Vec<i32>>()
|
|
||||||
.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/users/<id>", rank = 1)]
|
|
||||||
async fn display_name(
|
|
||||||
id: usize,
|
|
||||||
_ag: Session,
|
|
||||||
mut pgsql_conn: Connection<Postgres>,
|
|
||||||
mut redis_conn: Connection<Redis>,
|
|
||||||
) -> String {
|
|
||||||
UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[launch]
|
#[launch]
|
||||||
fn rocket() -> Rocket<Build> {
|
fn rocket() -> Rocket<Build> {
|
||||||
let chat = Arc::new(crate::messages::ChatBroadcaster::new(32));
|
// make sure the env is loaded
|
||||||
|
dotenv::dotenv().expect("Failed to load env! aborting launch!");
|
||||||
|
|
||||||
|
let chat = Arc::new(crate::messenger::ChatBroadcaster::new(32));
|
||||||
|
|
||||||
let cors = CorsOptions::default()
|
let cors = CorsOptions::default()
|
||||||
.allowed_origins(AllowedOrigins::all())
|
.allowed_origins(AllowedOrigins::all())
|
||||||
@@ -70,7 +56,7 @@ fn rocket() -> Rocket<Build> {
|
|||||||
"/",
|
"/",
|
||||||
routes![
|
routes![
|
||||||
favicon,
|
favicon,
|
||||||
messages::chat_page,
|
messenger::chat_page,
|
||||||
auth::signup_page,
|
auth::signup_page,
|
||||||
auth::login_page,
|
auth::login_page,
|
||||||
auth::mfa_page,
|
auth::mfa_page,
|
||||||
@@ -81,11 +67,11 @@ fn rocket() -> Rocket<Build> {
|
|||||||
"/api",
|
"/api",
|
||||||
routes![
|
routes![
|
||||||
cdn::upload_profile_pic,
|
cdn::upload_profile_pic,
|
||||||
messages::get_messages,
|
messenger::get_messages,
|
||||||
messages::post_message,
|
messenger::post_message,
|
||||||
messages::event_stream,
|
messenger::event_stream,
|
||||||
users,
|
user::users,
|
||||||
display_name,
|
user::display_name,
|
||||||
auth::signup,
|
auth::signup,
|
||||||
auth::login,
|
auth::login,
|
||||||
auth::get_totp,
|
auth::get_totp,
|
||||||
@@ -107,45 +93,3 @@ fn rocket() -> Rocket<Build> {
|
|||||||
async fn favicon() -> NamedFile {
|
async fn favicon() -> NamedFile {
|
||||||
NamedFile::open("static/favicon.ico").await.unwrap()
|
NamedFile::open("static/favicon.ico").await.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct UserCache {}
|
|
||||||
|
|
||||||
impl UserCache {
|
|
||||||
pub async fn username(
|
|
||||||
id: usize,
|
|
||||||
redis_conn: &mut Connection<Redis>,
|
|
||||||
pgsql_conn: &mut Connection<Postgres>,
|
|
||||||
) -> String {
|
|
||||||
if let Ok(val) = cmd("GET")
|
|
||||||
.arg(&[format!("users:{id}")])
|
|
||||||
.query_async(&mut **redis_conn)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
|
|
||||||
.fetch_one(&mut ***pgsql_conn)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
let username = v.username;
|
|
||||||
Self::insert(id, &username, redis_conn).await;
|
|
||||||
username
|
|
||||||
} else {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
|
|
||||||
cmd("SET")
|
|
||||||
.arg(&[
|
|
||||||
format!("users:{id}"),
|
|
||||||
username.to_string(),
|
|
||||||
"EX".to_string(),
|
|
||||||
"1800".to_string(),
|
|
||||||
])
|
|
||||||
.query_async(&mut **conn)
|
|
||||||
.await
|
|
||||||
.expect("failed to insert key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use rocket_db_pools::Connection;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
db::{Postgres, Redis},
|
db::{Postgres, Redis},
|
||||||
messages::ChatMsg,
|
messenger::ChatMsg,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper function to cache message in Redis
|
// Helper function to cache message in Redis
|
||||||
@@ -1,13 +1,11 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use redis::{AsyncCommands, cmd};
|
|
||||||
use rocket::{
|
use rocket::{
|
||||||
Shutdown,
|
Shutdown,
|
||||||
response::stream::{Event, EventStream},
|
response::stream::{Event, EventStream},
|
||||||
serde::json::Json,
|
serde::json::Json,
|
||||||
time::OffsetDateTime,
|
time::OffsetDateTime,
|
||||||
};
|
};
|
||||||
use rocket_cors::CorsOptions;
|
|
||||||
use rocket_db_pools::Connection;
|
use rocket_db_pools::Connection;
|
||||||
use rocket_dyn_templates::{Template, context};
|
use rocket_dyn_templates::{Template, context};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -22,21 +20,34 @@ use crate::{
|
|||||||
|
|
||||||
/// ---------- shared broadcaster ----------
|
/// ---------- shared broadcaster ----------
|
||||||
pub struct ChatBroadcaster {
|
pub struct ChatBroadcaster {
|
||||||
sender: broadcast::Sender<ChatMsg>,
|
buffer_size: usize,
|
||||||
|
senders: std::sync::Mutex<std::collections::HashMap<i32, broadcast::Sender<ChatMsg>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatBroadcaster {
|
impl ChatBroadcaster {
|
||||||
pub fn new(buffer_size: usize) -> Self {
|
pub fn new(buffer_size: usize) -> Self {
|
||||||
let (sender, _rx) = broadcast::channel::<ChatMsg>(buffer_size);
|
Self {
|
||||||
Self { sender }
|
buffer_size,
|
||||||
|
senders: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn publish(&self, msg: ChatMsg) {
|
/// Publish a message to the specified channel.
|
||||||
let _ = self.sender.send(msg);
|
pub async fn publish(&self, channel_id: i32, msg: ChatMsg) {
|
||||||
|
let mut map = self.senders.lock().unwrap();
|
||||||
|
let sender = map
|
||||||
|
.entry(channel_id)
|
||||||
|
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
||||||
|
let _ = sender.send(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn subscribe(&self) -> broadcast::Receiver<ChatMsg> {
|
/// Subscribe to the specified channel.
|
||||||
self.sender.subscribe()
|
pub fn subscribe(&self, channel_id: i32) -> broadcast::Receiver<ChatMsg> {
|
||||||
|
let mut map = self.senders.lock().unwrap();
|
||||||
|
let sender = map
|
||||||
|
.entry(channel_id)
|
||||||
|
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
||||||
|
sender.subscribe()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,18 +60,15 @@ pub struct ChatMsg {
|
|||||||
pub timestamp: usize,
|
pub timestamp: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/chat", format = "json", data = "<msg>")]
|
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
|
||||||
pub async fn post_message(
|
pub async fn post_message(
|
||||||
mut msg: Json<ChatMsg>,
|
mut msg: Json<ChatMsg>,
|
||||||
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
||||||
mut postgres: Connection<Postgres>,
|
mut postgres: Connection<Postgres>,
|
||||||
mut cache: Connection<Redis>,
|
mut cache: Option<Connection<Redis>>,
|
||||||
session: Session,
|
session: Session,
|
||||||
|
channel_id: i32,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
const CHANNEL_ID: i32 = 1;
|
|
||||||
let channel_id = CHANNEL_ID;
|
|
||||||
const LMSTUDIO_URI: &'static str = "http://127.0.0.1:1234/v1/chat/completions";
|
|
||||||
|
|
||||||
let chat = chat.inner().clone();
|
let chat = chat.inner().clone();
|
||||||
|
|
||||||
let display_name = sqlx::query!(
|
let display_name = sqlx::query!(
|
||||||
@@ -74,11 +82,11 @@ pub async fn post_message(
|
|||||||
|
|
||||||
msg.user_id = session.user_id;
|
msg.user_id = session.user_id;
|
||||||
msg.display_name = Some(display_name);
|
msg.display_name = Some(display_name);
|
||||||
chat.publish(msg.clone().into_inner()).await;
|
chat.publish(channel_id, msg.clone().into_inner()).await;
|
||||||
|
|
||||||
sqlx::query!(
|
sqlx::query!(
|
||||||
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
||||||
CHANNEL_ID,
|
channel_id,
|
||||||
msg.user_id as i32,
|
msg.user_id as i32,
|
||||||
msg.text,
|
msg.text,
|
||||||
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
|
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
|
||||||
@@ -87,22 +95,30 @@ pub async fn post_message(
|
|||||||
.await
|
.await
|
||||||
.map_err(|_| "Failed".to_string())?;
|
.map_err(|_| "Failed".to_string())?;
|
||||||
|
|
||||||
super::cache::insert(&mut cache, channel_id, &msg)
|
if let Some(ref mut cache) = cache {
|
||||||
.await
|
messenger::cache::insert(cache, channel_id, &msg)
|
||||||
.map_err(|_| "Redis cache failed".to_string())?;
|
.await
|
||||||
|
.map_err(|_| "Redis cache failed".to_string())?;
|
||||||
|
}
|
||||||
|
|
||||||
// get response
|
// get response
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let response = LlmWorker::new(LMSTUDIO_URI.to_string()).query(&msg).await;
|
let response = LlmWorker::new(crate::LMSTUDIO_URL.to_string())
|
||||||
|
.query(&msg)
|
||||||
|
.await;
|
||||||
|
|
||||||
if let Ok(reply) = response {
|
if let Ok(reply) = response {
|
||||||
chat.publish(reply.clone()).await;
|
chat.publish(channel_id, reply.clone()).await;
|
||||||
super::cache::insert(&mut cache, CHANNEL_ID, &reply)
|
|
||||||
.await
|
if let Some(ref mut cache) = cache {
|
||||||
.ok();
|
messenger::cache::insert(cache, channel_id, &reply)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
sqlx::query!(
|
sqlx::query!(
|
||||||
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
|
||||||
CHANNEL_ID,
|
channel_id,
|
||||||
reply.user_id as i32,
|
reply.user_id as i32,
|
||||||
reply.text,
|
reply.text,
|
||||||
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
|
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
|
||||||
@@ -126,17 +142,17 @@ pub async fn get_messages(
|
|||||||
const CHANNEL_ID: i32 = 1;
|
const CHANNEL_ID: i32 = 1;
|
||||||
let channel_id = CHANNEL_ID;
|
let channel_id = CHANNEL_ID;
|
||||||
|
|
||||||
if let Ok(messages) = super::cache::get(&mut redis, channel_id).await
|
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
|
||||||
&& !messages.is_empty()
|
&& !messages.is_empty()
|
||||||
{
|
{
|
||||||
return Json(messages);
|
return Json(messages);
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(x) = super::cache::initialise(&mut redis, &mut db, channel_id).await {
|
if let Err(x) = messenger::cache::initialise(&mut redis, &mut db, channel_id).await {
|
||||||
eprintln!("WARN: {x:?}");
|
eprintln!("WARN: {x:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(messages) = super::cache::get(&mut redis, channel_id).await
|
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
|
||||||
&& !messages.is_empty()
|
&& !messages.is_empty()
|
||||||
{
|
{
|
||||||
return Json(messages);
|
return Json(messages);
|
||||||
@@ -165,15 +181,16 @@ pub async fn get_messages(
|
|||||||
Json(res)
|
Json(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/events")]
|
#[get("/events/<channel_id>")]
|
||||||
pub async fn event_stream(
|
pub async fn event_stream(
|
||||||
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
||||||
postgres: Connection<Postgres>,
|
postgres: Connection<Postgres>,
|
||||||
cache: Connection<Redis>,
|
cache: Connection<Redis>,
|
||||||
ag: Session,
|
ag: Session,
|
||||||
mut shutdown: Shutdown,
|
mut shutdown: Shutdown,
|
||||||
|
channel_id: i32,
|
||||||
) -> EventStream![] {
|
) -> EventStream![] {
|
||||||
let mut rx = chat.subscribe();
|
let mut rx = chat.subscribe(channel_id);
|
||||||
|
|
||||||
EventStream! {
|
EventStream! {
|
||||||
// Initialize the stream with the last 100 messages
|
// Initialize the stream with the last 100 messages
|
||||||
@@ -202,8 +219,3 @@ pub async fn event_stream(
|
|||||||
pub async fn chat_page(session: Session) -> Template {
|
pub async fn chat_page(session: Session) -> Template {
|
||||||
Template::render("chat", context!(user_id: session.user_id))
|
Template::render("chat", context!(user_id: session.user_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/chatpreview")]
|
|
||||||
pub async fn chat_page_preview(session: Session) -> Template {
|
|
||||||
Template::render("chatpreview", context!(user_id: session.user_id))
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
use redis::cmd;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
|
use rocket_db_pools::Connection;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::Session,
|
||||||
|
db::{Postgres, Redis},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[get("/users", rank = 2)]
|
||||||
|
pub async fn users(_ag: Session, mut db: Connection<Postgres>) -> Json<Vec<i32>> {
|
||||||
|
sqlx::query!("SELECT id FROM users")
|
||||||
|
.fetch_all(&mut **db)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| Vec::new())
|
||||||
|
.into_iter()
|
||||||
|
.map(|row| row.id)
|
||||||
|
.collect::<Vec<i32>>()
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/users/<id>", rank = 1)]
|
||||||
|
pub async fn display_name(
|
||||||
|
id: usize,
|
||||||
|
_ag: Session,
|
||||||
|
mut pgsql_conn: Connection<Postgres>,
|
||||||
|
mut redis_conn: Connection<Redis>,
|
||||||
|
) -> String {
|
||||||
|
UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct UserCache {}
|
||||||
|
|
||||||
|
impl UserCache {
|
||||||
|
pub async fn username(
|
||||||
|
id: usize,
|
||||||
|
redis_conn: &mut Connection<Redis>,
|
||||||
|
pgsql_conn: &mut Connection<Postgres>,
|
||||||
|
) -> String {
|
||||||
|
if let Ok(val) = cmd("GET")
|
||||||
|
.arg(&[format!("users:{id}")])
|
||||||
|
.query_async(&mut **redis_conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
|
||||||
|
.fetch_one(&mut ***pgsql_conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
let username = v.username;
|
||||||
|
Self::insert(id, &username, redis_conn).await;
|
||||||
|
username
|
||||||
|
} else {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
|
||||||
|
cmd("SET")
|
||||||
|
.arg(&[
|
||||||
|
format!("users:{id}"),
|
||||||
|
username.to_string(),
|
||||||
|
"EX".to_string(),
|
||||||
|
"1800".to_string(),
|
||||||
|
])
|
||||||
|
.query_async(&mut **conn)
|
||||||
|
.await
|
||||||
|
.expect("failed to insert key")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user