This commit is contained in:
2025-10-04 14:32:13 +01:00
parent 1cfc5774ad
commit 7efac1ae33
10 changed files with 1114 additions and 94 deletions
+851 -61
View File
File diff suppressed because it is too large Load Diff
+3 -1
View File
@@ -4,9 +4,11 @@ version = "0.1.0"
edition = "2024"
[dependencies]
chrono = { version = "0.4.42", features = ["serde"] }
futures-util = "0.3.31"
reqwest = { version = "0.12.23", features = ["json"] }
rocket = { version = "0.5.1", features = ["json"] }
rocket_cors = "0.6.0"
rocket_db_pools = { version = "0.2.0", features = ["sqlx_sqlite"] }
rocket_db_pools = { version = "0.2.0", features = ["sqlx_postgres"] }
serde = { version = "1.0.228", features = ["derive"] }
tokio = { version = "1.47.1", features = ["full"] }
+30
View File
@@ -0,0 +1,30 @@
use rocket::{post, serde::json::Json};
use rocket_db_pools::{Connection, Database, sqlx};
use serde::{Deserialize, Serialize};
#[derive(Database)]
#[database("postgres_db")]
pub struct DbConn(sqlx::PgPool);
#[derive(Serialize, Deserialize)]
pub struct UserCredentials {
pub username: String,
pub password: String,
}
#[post("/signup", data = "<cred>")]
pub async fn signup(
conn: Connection<DbConn>,
cred: Json<UserCredentials>,
) -> Result<Json<String>, String> {
Ok(Json("Signup successful".to_string()))
}
#[post("/login", data = "<cred>")]
pub async fn login(
conn: Connection<DbConn>,
cred: Json<UserCredentials>,
) -> Result<Json<String>, String> {
// TODO: implement actual login logic, e.g. verify password and generate token
Ok(Json("Login successful".to_string()))
}
+48
View File
@@ -0,0 +1,48 @@
// src/llm.rs
use serde::{Deserialize, Serialize};
#[derive(Serialize)]
struct LlmRequest {
model: String,
messages: Vec<Message>,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String, // "user" or "assistant"
content: String,
}
pub async fn query_llm(text: &str) -> Result<String, String> {
let client = reqwest::Client::new();
// Build the request body
let payload = LlmRequest {
model: "gemma2-9b-it".into(), // whatever model you run locally
messages: vec![Message {
role: "user".into(),
content: text.into(),
}],
};
// POST to lmstudio (default 127.0.0.1:1234)
let resp = client
.post("http://127.0.0.1:1234/v1/chat/completions")
.json(&payload)
.send()
.await
.unwrap();
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)]
struct LlmResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
message: Message,
}
let llm_resp: LlmResponse = resp.json().await.unwrap();
Ok(llm_resp.choices[0].message.content.clone())
}
+83 -10
View File
@@ -2,39 +2,47 @@
#[macro_use]
extern crate rocket;
use rocket::fairing::Fairing;
use rocket::http::Method;
use rocket::response::stream::{Event, EventStream};
use rocket::serde::json::Json;
use rocket::{Build, Rocket};
use rocket_cors::{AllowedOrigins, CorsOptions};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::llm::query_llm;
pub mod auth;
pub mod llm;
/// ---------- shared broadcaster ----------
struct ChatBroadcaster {
sender: broadcast::Sender<String>,
sender: broadcast::Sender<ChatMsg>,
}
impl ChatBroadcaster {
fn new(buffer_size: usize) -> Self {
let (sender, _rx) = broadcast::channel::<String>(buffer_size);
let (sender, _rx) = broadcast::channel::<ChatMsg>(buffer_size);
Self { sender }
}
async fn publish(&self, msg: String) {
async fn publish(&self, msg: ChatMsg) {
let _ = self.sender.send(msg);
}
fn subscribe(&self) -> broadcast::Receiver<String> {
fn subscribe(&self) -> broadcast::Receiver<ChatMsg> {
self.sender.subscribe()
}
}
/// ---------- Rocket routes ----------
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
struct ChatMsg {
userid: usize,
text: String,
timestamp: usize,
}
#[post("/chat", format = "json", data = "<msg>")]
@@ -42,8 +50,7 @@ async fn post_message(
msg: Json<ChatMsg>,
chat: &rocket::State<Arc<ChatBroadcaster>>,
) -> &'static str {
let text = msg.text.clone();
chat.publish(text).await;
chat.publish(msg.into_inner()).await;
"Message sent"
}
@@ -54,7 +61,7 @@ async fn event_stream(chat: &rocket::State<Arc<ChatBroadcaster>>) -> EventStream
EventStream! {
loop {
match rx.recv().await {
Ok(msg) => yield Event::data(msg),
Ok(msg) => yield Event::json(&msg),
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Event::comment("lagged");
}
@@ -64,6 +71,68 @@ async fn event_stream(chat: &rocket::State<Arc<ChatBroadcaster>>) -> EventStream
}
}
// ---------- LLM worker ----------
async fn start_llm_worker(chat: Arc<ChatBroadcaster>) {
let mut rx = chat.subscribe();
loop {
match rx.recv().await {
Ok(msg) => {
if msg.userid == 0 {
// ignore bot messages
continue;
}
let user_text = msg.text.clone();
let chat_clone = chat.clone();
rocket::tokio::spawn(async move {
match query_llm(&user_text).await {
Ok(reply) => {
let bot_msg = ChatMsg {
userid: 0,
text: reply,
timestamp: chrono::Local::now().timestamp() as usize,
};
chat_clone.publish(bot_msg).await;
}
Err(e) => eprintln!("LLM error: {}", e),
}
});
}
Err(_) => break, // channel closed
}
}
}
pub struct LlmWorkerFairing;
#[rocket::async_trait]
impl Fairing for LlmWorkerFairing {
fn info(&self) -> rocket::fairing::Info {
rocket::fairing::Info {
name: "LLM background worker",
kind: rocket::fairing::Kind::Ignite,
}
}
async fn on_ignite(&self, rocket: Rocket<Build>) -> rocket::fairing::Result {
// Grab the shared broadcaster from state
let chat = rocket
.state::<Arc<ChatBroadcaster>>()
.expect("ChatBroadcaster not managed");
// Clone it so we can move into async block
let chat_clone = Arc::clone(chat);
// Spawn the background worker **inside** on_ignite
tokio::spawn(async move {
start_llm_worker(chat_clone).await;
});
Ok(rocket)
}
}
/// ---------- launch ----------
#[launch]
fn rocket() -> Rocket<Build> {
@@ -82,5 +151,9 @@ fn rocket() -> Rocket<Build> {
rocket::build()
.manage(chat)
.attach(cors.to_cors().unwrap())
.mount("/", routes![post_message, event_stream])
.attach(LlmWorkerFairing {})
.mount(
"/",
routes![post_message, event_stream, auth::signup, auth::login],
)
}