megacommit
This commit is contained in:
+38
-115
@@ -3,134 +3,44 @@
|
||||
extern crate rocket;
|
||||
|
||||
use rocket::fairing::Fairing;
|
||||
use rocket::fs::FileServer;
|
||||
use rocket::http::Method;
|
||||
use rocket::response::stream::{Event, EventStream};
|
||||
use rocket::serde::json::Json;
|
||||
use rocket::{Build, Rocket};
|
||||
use rocket_cors::{AllowedOrigins, CorsOptions};
|
||||
use rocket_db_pools::{Connection, Database};
|
||||
use rocket_dyn_templates::{Template, context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::llm::query_llm;
|
||||
use crate::auth::{AuthGuard, DbConn};
|
||||
use crate::llm::LlmWorker;
|
||||
use crate::messages::ChatBroadcaster;
|
||||
|
||||
pub mod auth;
|
||||
pub mod cdn;
|
||||
pub mod llm;
|
||||
pub mod messages;
|
||||
|
||||
/// ---------- shared broadcaster ----------
|
||||
struct ChatBroadcaster {
|
||||
sender: broadcast::Sender<ChatMsg>,
|
||||
#[get("/users", rank = 2)]
|
||||
async fn users(_ag: AuthGuard, mut db: Connection<DbConn>) -> Json<Vec<i32>> {
|
||||
sqlx::query!("SELECT id FROM users")
|
||||
.fetch_all(&mut **db)
|
||||
.await
|
||||
.map(|rows| rows.into_iter().map(|row| row.id).collect())
|
||||
.unwrap_or_else(|_| Vec::new())
|
||||
.into()
|
||||
}
|
||||
|
||||
impl ChatBroadcaster {
|
||||
fn new(buffer_size: usize) -> Self {
|
||||
let (sender, _rx) = broadcast::channel::<ChatMsg>(buffer_size);
|
||||
Self { sender }
|
||||
}
|
||||
|
||||
async fn publish(&self, msg: ChatMsg) {
|
||||
let _ = self.sender.send(msg);
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<ChatMsg> {
|
||||
self.sender.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
/// ---------- Rocket routes ----------
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct ChatMsg {
|
||||
userid: usize,
|
||||
text: String,
|
||||
timestamp: usize,
|
||||
}
|
||||
|
||||
#[post("/chat", format = "json", data = "<msg>")]
|
||||
async fn post_message(
|
||||
msg: Json<ChatMsg>,
|
||||
chat: &rocket::State<Arc<ChatBroadcaster>>,
|
||||
) -> &'static str {
|
||||
chat.publish(msg.into_inner()).await;
|
||||
"Message sent"
|
||||
}
|
||||
|
||||
#[get("/events")]
|
||||
async fn event_stream(chat: &rocket::State<Arc<ChatBroadcaster>>) -> EventStream![] {
|
||||
let mut rx = chat.subscribe();
|
||||
|
||||
EventStream! {
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(msg) => yield Event::json(&msg),
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => {
|
||||
yield Event::comment("lagged");
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- LLM worker ----------
|
||||
async fn start_llm_worker(chat: Arc<ChatBroadcaster>) {
|
||||
let mut rx = chat.subscribe();
|
||||
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(msg) => {
|
||||
if msg.userid == 0 {
|
||||
// ignore bot messages
|
||||
continue;
|
||||
}
|
||||
|
||||
let user_text = msg.text.clone();
|
||||
let chat_clone = chat.clone();
|
||||
|
||||
rocket::tokio::spawn(async move {
|
||||
match query_llm(&user_text).await {
|
||||
Ok(reply) => {
|
||||
let bot_msg = ChatMsg {
|
||||
userid: 0,
|
||||
text: reply,
|
||||
timestamp: chrono::Local::now().timestamp() as usize,
|
||||
};
|
||||
chat_clone.publish(bot_msg).await;
|
||||
}
|
||||
Err(e) => eprintln!("LLM error: {}", e),
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(_) => break, // channel closed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LlmWorkerFairing;
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl Fairing for LlmWorkerFairing {
|
||||
fn info(&self) -> rocket::fairing::Info {
|
||||
rocket::fairing::Info {
|
||||
name: "LLM background worker",
|
||||
kind: rocket::fairing::Kind::Ignite,
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_ignite(&self, rocket: Rocket<Build>) -> rocket::fairing::Result {
|
||||
// Grab the shared broadcaster from state
|
||||
let chat = rocket
|
||||
.state::<Arc<ChatBroadcaster>>()
|
||||
.expect("ChatBroadcaster not managed");
|
||||
// Clone it so we can move into async block
|
||||
let chat_clone = Arc::clone(chat);
|
||||
|
||||
// Spawn the background worker **inside** on_ignite
|
||||
tokio::spawn(async move {
|
||||
start_llm_worker(chat_clone).await;
|
||||
});
|
||||
|
||||
Ok(rocket)
|
||||
}
|
||||
#[get("/users/<id>", rank = 1)]
|
||||
async fn username_for_id(id: usize, _ag: AuthGuard, mut db: Connection<DbConn>) -> String {
|
||||
sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
|
||||
.fetch_one(&mut **db)
|
||||
.await
|
||||
.map(|row| row.username)
|
||||
.unwrap_or_else(|_| "User not found".to_string())
|
||||
}
|
||||
|
||||
/// ---------- launch ----------
|
||||
@@ -151,9 +61,22 @@ fn rocket() -> Rocket<Build> {
|
||||
rocket::build()
|
||||
.manage(chat)
|
||||
.attach(cors.to_cors().unwrap())
|
||||
.attach(LlmWorkerFairing {})
|
||||
.attach(DbConn::init())
|
||||
.attach(Template::fairing())
|
||||
.mount("/static", FileServer::from("static"))
|
||||
.mount("/cdn", cdn::routes())
|
||||
.mount(
|
||||
"/",
|
||||
routes![post_message, event_stream, auth::signup, auth::login],
|
||||
routes![
|
||||
users,
|
||||
username_for_id,
|
||||
messages::chat_page,
|
||||
messages::get_messages,
|
||||
messages::post_message,
|
||||
messages::event_stream,
|
||||
auth::signup,
|
||||
auth::signup_page,
|
||||
auth::login
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user