// src/main.rs #[macro_use] extern crate rocket; use rocket::fairing::Fairing; use rocket::http::Method; use rocket::response::stream::{Event, EventStream}; use rocket::serde::json::Json; use rocket::{Build, Rocket}; use rocket_cors::{AllowedOrigins, CorsOptions}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::broadcast; use crate::llm::query_llm; pub mod auth; pub mod llm; /// ---------- shared broadcaster ---------- struct ChatBroadcaster { sender: broadcast::Sender, } impl ChatBroadcaster { fn new(buffer_size: usize) -> Self { let (sender, _rx) = broadcast::channel::(buffer_size); Self { sender } } async fn publish(&self, msg: ChatMsg) { let _ = self.sender.send(msg); } fn subscribe(&self) -> broadcast::Receiver { self.sender.subscribe() } } /// ---------- Rocket routes ---------- #[derive(Debug, Serialize, Deserialize, Clone)] struct ChatMsg { userid: usize, text: String, timestamp: usize, } #[post("/chat", format = "json", data = "")] async fn post_message( msg: Json, chat: &rocket::State>, ) -> &'static str { chat.publish(msg.into_inner()).await; "Message sent" } #[get("/events")] async fn event_stream(chat: &rocket::State>) -> EventStream![] { let mut rx = chat.subscribe(); EventStream! { loop { match rx.recv().await { Ok(msg) => yield Event::json(&msg), Err(broadcast::error::RecvError::Lagged(_)) => { yield Event::comment("lagged"); } Err(broadcast::error::RecvError::Closed) => break, } } } } // ---------- LLM worker ---------- async fn start_llm_worker(chat: Arc) { let mut rx = chat.subscribe(); loop { match rx.recv().await { Ok(msg) => { if msg.userid == 0 { // ignore bot messages continue; } let user_text = msg.text.clone(); let chat_clone = chat.clone(); rocket::tokio::spawn(async move { match query_llm(&user_text).await { Ok(reply) => { let bot_msg = ChatMsg { userid: 0, text: reply, timestamp: chrono::Local::now().timestamp() as usize, }; chat_clone.publish(bot_msg).await; } Err(e) => eprintln!("LLM error: {}", e), } }); } Err(_) => break, // channel closed } } } pub struct LlmWorkerFairing; #[rocket::async_trait] impl Fairing for LlmWorkerFairing { fn info(&self) -> rocket::fairing::Info { rocket::fairing::Info { name: "LLM background worker", kind: rocket::fairing::Kind::Ignite, } } async fn on_ignite(&self, rocket: Rocket) -> rocket::fairing::Result { // Grab the shared broadcaster from state let chat = rocket .state::>() .expect("ChatBroadcaster not managed"); // Clone it so we can move into async block let chat_clone = Arc::clone(chat); // Spawn the background worker **inside** on_ignite tokio::spawn(async move { start_llm_worker(chat_clone).await; }); Ok(rocket) } } /// ---------- launch ---------- #[launch] fn rocket() -> Rocket { let chat = Arc::new(ChatBroadcaster::new(32)); let cors = CorsOptions::default() .allowed_origins(AllowedOrigins::all()) .allowed_methods( vec![Method::Get, Method::Post, Method::Patch] .into_iter() .map(From::from) .collect(), ) .allow_credentials(true); rocket::build() .manage(chat) .attach(cors.to_cors().unwrap()) .attach(LlmWorkerFairing {}) .mount( "/", routes![post_message, event_stream, auth::signup, auth::login], ) }