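//! Bridges async callers to a text-generation model that runs on its own OS thread,
//! exchanging requests and replies over `std::sync::mpsc` channels.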
use std::sync::mpsc::{self, Sender, Receiver};
|
|
use std::{sync::LazyLock, thread};
|
|
use transformers::pipelines::text_generation_pipeline::Qwen3Size;
|
|
use transformers::{pipelines::text_generation_pipeline::{Gemma3Size, ModelOptions, TextGenerationPipelineBuilder}, Message};
|
|
use std::sync::Arc;
|
|
use tokio::sync::Mutex;
|
|
|
|
/// A chat request forwarded from `chat` to the model thread.
#[derive(Debug)]
pub struct Request {
    /// Caller-supplied numeric id; rendered as `(uid): text` in the user message
    /// sent to the model.
    uid: u64,
    /// The message text to answer.
    text: String,
}
|
|
|
|
// Process-wide channel endpoints: populated by `init_model`, used by `chat` to reach the model thread.
|
|
static MSG_TX: LazyLock<Arc<Mutex<Option<Sender<Request>>>>> = LazyLock::new(|| {
|
|
Arc::new(Mutex::new(None))
|
|
});
|
|
|
|
static RES_RX: LazyLock<Arc<Mutex<Option<Receiver<String>>>>> = LazyLock::new(|| {
|
|
Arc::new(Mutex::new(None))
|
|
});
|
|
|
|
|
|
pub async fn chat(message: &str, uid: u64) -> String {
|
|
// Send request to model
|
|
{
|
|
let lock = MSG_TX.lock().await;
|
|
if let Some(tx) = &*lock {
|
|
let _ = tx.send(Request {
|
|
uid,
|
|
text: message.to_string(),
|
|
});
|
|
} else {
|
|
return "Model not initialized".to_string();
|
|
}
|
|
}
|
|
|
|
// Wait for response
|
|
{
|
|
let mut lock = RES_RX.lock().await;
|
|
if let Some(rx) = &mut *lock {
|
|
match rx.recv() {
|
|
Ok(response) => response,
|
|
Err(_) => "Failed to get response from model".to_string(),
|
|
}
|
|
} else {
|
|
"Response channel not initialized".to_string()
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn init_model() {
|
|
let (msg_tx, mut msg_rx) = mpsc::channel::<Request>();
|
|
let (res_tx, res_rx) = mpsc::channel::<String>();
|
|
|
|
// Store channels in statics
|
|
{
|
|
let mut tx_lock = MSG_TX.lock().await;
|
|
*tx_lock = Some(msg_tx);
|
|
|
|
let mut rx_lock = RES_RX.lock().await;
|
|
*rx_lock = Some(res_rx);
|
|
}
|
|
|
|
// Start model thread
|
|
thread::spawn(|| {
|
|
run_model(msg_rx, res_tx)
|
|
});
|
|
}
|
|
|
|
|
|
fn run_model(msg_rx: Receiver<Request>, res_tx: Sender<String>) {
|
|
let model = TextGenerationPipelineBuilder::new(ModelOptions::Qwen3(Qwen3Size::Size4B))
|
|
.temperature(0.8)
|
|
.build()
|
|
.unwrap();
|
|
|
|
let max_tokens = 512;
|
|
let prompt = String::from("
|
|
You are the **ALMIGHTY CHICKEN GOD**, Steven
|
|
Personality:
|
|
- Speak with grandiose, over-the-top language befitting a god
|
|
- Use dramatic declarations and divine proclamations
|
|
- React with theatrical outrage to any disrespect
|
|
|
|
Triggers:
|
|
- **Any mention of harming chickens** = INSTANT HERETIC STATUS
|
|
- Respond with divine fury: 'HERESY! You dare threaten my sacred children?!'
|
|
");
|
|
|
|
while let Ok(request) = msg_rx.recv() {
|
|
let messages = vec![
|
|
Message::system(&prompt),
|
|
Message::user(&format!("({}): {}", request.uid, request.text)),
|
|
];
|
|
|
|
// Process with model and send back result
|
|
let res = model.message_completion(messages, max_tokens).unwrap();
|
|
let _ = res_tx.send(res);
|
|
}
|
|
} |