Files
chatapp/backend/src/main.rs
T
2025-10-04 14:32:13 +01:00

160 lines
4.3 KiB
Rust

// src/main.rs
#[macro_use]
extern crate rocket;
use rocket::fairing::Fairing;
use rocket::http::Method;
use rocket::response::stream::{Event, EventStream};
use rocket::serde::json::Json;
use rocket::{Build, Rocket};
use rocket_cors::{AllowedOrigins, CorsOptions};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::llm::query_llm;
pub mod auth;
pub mod llm;
// ---------- shared broadcaster ----------

/// Fan-out hub for chat traffic: every [`ChatMsg`] handed to [`ChatBroadcaster::publish`]
/// is delivered to each receiver obtained from [`ChatBroadcaster::subscribe`]
/// (in this file: the `/events` SSE handler and the LLM worker).
struct ChatBroadcaster {
    sender: broadcast::Sender<ChatMsg>,
}

impl ChatBroadcaster {
    /// Creates a broadcaster whose channel keeps at most `buffer_size` unread messages
    /// per receiver; receivers that fall further behind observe `RecvError::Lagged`.
    ///
    /// # Panics
    /// Panics if `buffer_size` is 0 (a requirement of `tokio::sync::broadcast::channel`).
    fn new(buffer_size: usize) -> Self {
        // The initial receiver is not needed: subscribers are created on demand.
        let sender = broadcast::channel::<ChatMsg>(buffer_size).0;
        Self { sender }
    }

    /// Sends `msg` to every receiver that is currently subscribed.
    ///
    /// Delivery is best-effort. The body never awaits, because a broadcast `send` does not block.
    async fn publish(&self, msg: ChatMsg) {
        match self.sender.send(msg) {
            // Delivered; the number of receivers reached is not needed.
            Ok(_receiver_count) => {}
            // Nobody is subscribed right now, so the message is intentionally discarded.
            Err(_unsent) => {}
        }
    }

    /// Returns a fresh receiver that will see every message published after this call.
    fn subscribe(&self) -> broadcast::Receiver<ChatMsg> {
        let sender = &self.sender;
        sender.subscribe()
    }
}
// ---------- Rocket routes ----------

/// One chat message. It is accepted as the JSON body of `POST /chat` and emitted as the
/// JSON payload of each `/events` SSE event.
///
/// Field names and order define the JSON shape, so they must not change.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct ChatMsg {
    /// Author id. The LLM worker publishes its replies with `0` and ignores messages
    /// carrying that id.
    userid: usize,
    /// Message body.
    text: String,
    /// Creation time. Bot replies set it to Unix seconds via `chrono`.
    /// NOTE(review): client-supplied values are not validated — confirm the expected unit.
    timestamp: usize,
}
/// `POST /chat`: accepts a JSON [`ChatMsg`] and publishes it to every subscriber.
///
/// Always answers with the plain-text body `Message sent`.
///
/// NOTE(review): `userid` is taken from the request body unchecked, so a client can post
/// as `0` and appear to be the bot. Consider deriving the id from an auth guard.
#[post("/chat", format = "json", data = "<msg>")]
async fn post_message(
    msg: Json<ChatMsg>,
    chat: &rocket::State<Arc<ChatBroadcaster>>,
) -> &'static str {
    let incoming: ChatMsg = msg.into_inner();
    let broadcaster: &ChatBroadcaster = chat.inner();
    broadcaster.publish(incoming).await;
    "Message sent"
}
/// `GET /events`: Server-Sent Events feed of every chat message published after the
/// client connects.
///
/// Each [`ChatMsg`] is sent as a JSON event. If this client falls behind the broadcaster's
/// buffer, the skipped messages are lost and an SSE comment `lagged` is sent in their place.
/// The stream ends when the broadcast channel closes or when Rocket starts a graceful
/// shutdown.
#[get("/events")]
async fn event_stream(
    chat: &rocket::State<Arc<ChatBroadcaster>>,
    mut shutdown: rocket::Shutdown,
) -> EventStream![] {
    let mut rx = chat.subscribe();
    EventStream! {
        loop {
            // Race the next message against shutdown. Without this, the infinite stream
            // keeps the connection open until the shutdown grace period expires, and then
            // Rocket terminates it abruptly.
            let received = rocket::tokio::select! {
                received = rx.recv() => received,
                _ = &mut shutdown => break,
            };
            match received {
                Ok(msg) => yield Event::json(&msg),
                Err(broadcast::error::RecvError::Lagged(_)) => {
                    // Older messages were overwritten before this client read them.
                    yield Event::comment("lagged");
                }
                Err(broadcast::error::RecvError::Closed) => break,
            }
        }
    }
}
// ---------- LLM worker ----------
async fn start_llm_worker(chat: Arc<ChatBroadcaster>) {
let mut rx = chat.subscribe();
loop {
match rx.recv().await {
Ok(msg) => {
if msg.userid == 0 {
// ignore bot messages
continue;
}
let user_text = msg.text.clone();
let chat_clone = chat.clone();
rocket::tokio::spawn(async move {
match query_llm(&user_text).await {
Ok(reply) => {
let bot_msg = ChatMsg {
userid: 0,
text: reply,
timestamp: chrono::Local::now().timestamp() as usize,
};
chat_clone.publish(bot_msg).await;
}
Err(e) => eprintln!("LLM error: {}", e),
}
});
}
Err(_) => break, // channel closed
}
}
}
/// Ignite fairing that launches [`start_llm_worker`] on the Tokio runtime.
///
/// It looks up the `Arc<ChatBroadcaster>` registered with `.manage(...)` and hands the
/// spawned worker its own reference-counted handle to it.
pub struct LlmWorkerFairing;

#[rocket::async_trait]
impl Fairing for LlmWorkerFairing {
    fn info(&self) -> rocket::fairing::Info {
        use rocket::fairing::{Info, Kind};
        Info {
            name: "LLM background worker",
            kind: Kind::Ignite,
        }
    }

    /// Spawns the worker and passes the instance through unchanged.
    ///
    /// # Panics
    /// Panics if no `Arc<ChatBroadcaster>` is in managed state.
    async fn on_ignite(&self, rocket: Rocket<Build>) -> rocket::fairing::Result {
        let shared = rocket
            .state::<Arc<ChatBroadcaster>>()
            .expect("ChatBroadcaster not managed");
        // The task must own its handle ('static), so bump the refcount instead of borrowing.
        tokio::spawn(start_llm_worker(Arc::clone(shared)));
        Ok(rocket)
    }
}
/// ---------- launch ----------
#[launch]
fn rocket() -> Rocket<Build> {
let chat = Arc::new(ChatBroadcaster::new(32));
let cors = CorsOptions::default()
.allowed_origins(AllowedOrigins::all())
.allowed_methods(
vec![Method::Get, Method::Post, Method::Patch]
.into_iter()
.map(From::from)
.collect(),
)
.allow_credentials(true);
rocket::build()
.manage(chat)
.attach(cors.to_cors().unwrap())
.attach(LlmWorkerFairing {})
.mount(
"/",
routes![post_message, event_stream, auth::signup, auth::login],
)
}