Files
worldcoder/src/llm_integration/content_llm.rs
T
zxq5 c21819e786
Continuous integration / build (push) Failing after 1m48s
fixed cargo warns
2025-08-09 22:42:42 +01:00

499 lines
16 KiB
Rust

use std::{
io::{BufRead, BufReader},
sync::{Arc, Mutex},
};
use serde::{Deserialize, Serialize};
use crate::editors::settings_editor::ProjectSettings;
/// The AI assistant attached to an editor, in one of its two modes.
///
/// `result` and `ready` are shared (`Arc<Mutex<..>>`) so that a background
/// worker thread can write into them while the UI reads them every frame.
#[derive(Clone)]
pub enum ContentAI {
    /// Summarise mode.
    Summarise {
        /// Whether the assistant window is visible.
        open: bool,
        /// Generation state shared with a worker thread.
        ready: Arc<Mutex<ReadyState>>,
        /// Summary text shown in the window.
        result: Arc<Mutex<String>>,
        /// Text to summarise (displayed read-only).
        content: String,
    },
    /// Continue-writing mode.
    Continue {
        /// Whether the assistant window is visible.
        open: bool,
        /// Generation state shared with the worker thread.
        ready: Arc<Mutex<ReadyState>>,
        /// Buffer the worker streams generated text into.
        result: Arc<Mutex<String>>,
        /// Current file content the model is asked to continue.
        content: String,
        /// User-written instructions, sent as the "Specific instructions" message.
        instruction: String,
        /// Extra context appended after the project's `ai_context`.
        context_override: String,
        /// Model name to request; an empty string means "use the server default".
        model_override: String,
        /// Token limit, forwarded as `AIOptions::max_completion_tokens`.
        max_tokens: usize,
        /// Sampling temperature.
        temperature: f32,
        /// Reasoning effort hint forwarded with the request.
        reasoning_effort: ReasoningEffort,
    },
}
impl ContentAI {
    /// Draws the "AI Assistant" window for whichever mode `self` is in.
    ///
    /// Nothing is drawn while the variant's `open` flag is `false`; closing the
    /// window from its title bar writes `false` back into that flag.
    pub fn ui(&mut self, ui: &mut egui::Ui, project: &mut ProjectSettings) {
        let mut is_open = match self {
            ContentAI::Summarise { open, .. } | ContentAI::Continue { open, .. } => *open,
        };
        if is_open {
            egui::Window::new("AI Assistant")
                .open(&mut is_open)
                .show(ui.ctx(), |ui| match self {
                    ContentAI::Summarise { .. } => {
                        Self::ui_summarise(self, ui, project);
                    }
                    ContentAI::Continue { .. } => {
                        Self::ui_continue(self, ui, project);
                    }
                });
        }
        match self {
            ContentAI::Summarise { open, .. } | ContentAI::Continue { open, .. } => {
                *open = is_open;
            }
        }
    }

    /// Draws the summarise form: the source text (read-only), the summary
    /// output and a "Summarise" button.
    fn ui_summarise(&mut self, ui: &mut egui::Ui, _project: &mut ProjectSettings) {
        let ContentAI::Summarise {
            content,
            result,
            ready,
            ..
        } = self
        else {
            return;
        };

        egui::ScrollArea::vertical()
            .id_salt("summarise")
            .auto_shrink([false, false])
            .max_width(ui.available_width())
            .show(ui, |ui| {
                ui.add(
                    egui::TextEdit::multiline(content)
                        .frame(false)
                        .interactive(false),
                );
            });
        ui.add(
            egui::TextEdit::multiline(&mut *result.lock().unwrap())
                .font(egui::TextStyle::Monospace)
                .interactive(false)
                .frame(false)
                .lock_focus(true)
                .hint_text("Summary will appear here..."),
        );
        if ui.button("Summarise").clicked() {
            // TODO: start a summarisation worker. Nothing here spawns one yet,
            // so this click only updates the shared state flag.
            *ready.lock().unwrap() = ReadyState::Generating;
        }
    }

    /// Draws the "continue writing" form and, on "Generate", spawns a worker
    /// thread running [`continue_content`] that streams into `result`.
    ///
    /// Once taken, the `ready` lock is held until this method returns, so the
    /// worker only observes the button presses of this frame afterwards.
    fn ui_continue(&mut self, ui: &mut egui::Ui, project: &mut ProjectSettings) {
        let ContentAI::Continue {
            content,
            instruction,
            max_tokens,
            context_override,
            result,
            ready,
            temperature,
            model_override,
            reasoning_effort,
            ..
        } = self
        else {
            return;
        };

        ui.weak("(The model will see current file content)");
        ui.separator();

        // Instructions
        egui::ScrollArea::both()
            .id_salt("continue_instruction")
            .auto_shrink([true, false])
            .max_height(ui.available_height() / 4.0)
            .max_width(ui.available_width())
            .show(ui, |ui| {
                ui.add(
                    egui::TextEdit::multiline(instruction)
                        .frame(false)
                        .desired_width(ui.available_width())
                        .hint_text("Writing Instructions"),
                );
            });
        ui.separator();

        // Context
        egui::ScrollArea::both()
            .id_salt("continue_context")
            .auto_shrink([true, false])
            .max_height(ui.available_height() / 4.0)
            .max_width(ui.available_width())
            .show(ui, |ui| {
                ui.add(
                    egui::TextEdit::multiline(context_override)
                        .frame(false)
                        .desired_width(ui.available_width())
                        .hint_text("Any additional context?"),
                );
            });
        ui.separator();

        egui::Grid::new("continue_grid")
            .num_columns(2)
            .striped(true)
            .show(ui, |ui| {
                ui.label("Max Tokens");
                ui.add(
                    egui::DragValue::new(max_tokens)
                        .range(128..=u32::MAX)
                        .speed(128),
                );
                ui.end_row();

                ui.label("Temperature");
                ui.add(
                    egui::DragValue::new(temperature)
                        .range(0.0..=2.0)
                        .speed(0.1),
                );
                // End the row so the reasoning-effort cells start a new row of
                // the two-column grid instead of spilling into this one.
                ui.end_row();

                ui.label("Reasoning effort");
                egui::ComboBox::from_id_salt("reasoning_effort")
                    .selected_text(reasoning_effort.to_string())
                    .show_ui(ui, |ui| {
                        ui.selectable_value(reasoning_effort, ReasoningEffort::Minimal, "Minimal");
                        ui.selectable_value(reasoning_effort, ReasoningEffort::Low, "Low");
                        ui.selectable_value(reasoning_effort, ReasoningEffort::Medium, "Medium");
                        ui.selectable_value(reasoning_effort, ReasoningEffort::High, "High");
                    });
                ui.end_row();

                ui.label("Model override");
                ui.add(egui::TextEdit::singleline(model_override));
                ui.end_row();
            });
        ui.separator();

        let mut ready_lock = ready.lock().unwrap();
        match *ready_lock {
            ReadyState::Idle => {
                // Snapshot the form into owned values and hand them to a worker thread.
                let start_generation = || {
                    let context_override = context_override.clone();
                    let content = content.clone();
                    let instruction = instruction.clone();
                    let project = project.clone();
                    let ai_context = project.ai_context.clone();
                    let result = result.clone();
                    let ready = ready.clone();
                    let options = AIOptions {
                        max_completion_tokens: *max_tokens,
                        reasoning_effort: *reasoning_effort,
                        temperature: *temperature,
                        model_override: (!model_override.is_empty())
                            .then(|| model_override.clone()),
                    };
                    result.lock().unwrap().clear();
                    std::thread::spawn(move || {
                        let outcome = crate::llm_integration::content_llm::continue_content(
                            ai_context + "\n" + &context_override,
                            content,
                            instruction,
                            options,
                            project,
                            result,
                            ready,
                        );
                        if let Err(e) = outcome {
                            eprintln!("Error in content generation: {e}");
                        }
                    });
                };
                ui.horizontal(|ui| {
                    if ui.button("Generate ").clicked() {
                        start_generation();
                        // Switch state now rather than waiting for the worker to
                        // take the lock: otherwise the next frame still shows
                        // "Generate" and a second click spawns a second worker.
                        *ready_lock = ReadyState::Generating;
                    }
                    ui.label("Idle");
                });
            }
            ReadyState::Generating => {
                ui.horizontal(|ui| {
                    // Halted: the worker clears the output and stops.
                    if ui.button("Cancel").clicked() {
                        *ready_lock = ReadyState::Halted;
                    }
                    // Idle: the worker stops but keeps what it produced so far.
                    if ui.button("Stop").clicked() {
                        *ready_lock = ReadyState::Idle;
                    }
                    ui.spinner();
                    ui.label("Generating...");
                });
            }
            ReadyState::Halted | ReadyState::Ready => {}
        }

        egui::ScrollArea::both()
            .auto_shrink([true, true])
            .id_salt("llm_output")
            .max_width(ui.available_width())
            .max_height(ui.available_height() / 4.0)
            .show(ui, |ui| {
                ui.add(
                    egui::TextEdit::multiline(&mut *result.lock().unwrap())
                        .font(egui::TextStyle::Monospace)
                        .interactive(false)
                        .desired_rows(0)
                        .frame(false)
                        .desired_width(ui.available_width())
                        .lock_focus(true)
                        .hint_text("Content will appear here..."),
                );
            });
        ui.separator();

        ui.horizontal(|ui| {
            // NOTE(review): this also stops a running generation, because the
            // worker halts on any state other than `Generating`.
            if ui.button("Insert").clicked() {
                *ready_lock = ReadyState::Ready;
            }
            if ui.button("Clear").clicked() {
                result.lock().unwrap().clear();
            }
        });
    }
}
/// Streams a continuation of `previous_content` from an OpenAI-compatible
/// `/v1/chat/completions` endpoint, appending each received text delta to
/// `result`.
///
/// * `context` – general context/instructions, sent as its own user message.
/// * `previous_content` – the text the model should continue.
/// * `instruction` – specific instructions for this continuation.
/// * `options` – sampling and model settings (see [`AIOptions`]).
/// * `project` – source of the API URI and key; local overrides take
///   precedence over global settings.
/// * `result` – buffer the streamed text is appended to.
/// * `ready` – shared state. It is set to [`ReadyState::Generating`] on entry
///   and polled before each stream line. `Halted` clears `result` and stops;
///   any other non-`Generating` value just stops.
///
/// `ready` is always reset to [`ReadyState::Idle`] before returning, including
/// on error, so the UI is never left showing a spinner.
///
/// # Errors
/// Returns an error if no API URI is configured, or if neither the local
/// overrides nor the global settings provide an API key. An empty key is
/// accepted and means "send no `Authorization` header". Also returns an error
/// if sending the request fails, the server answers with a non-success HTTP
/// status, or reading the stream fails.
///
/// NOTE(review): stopping is cooperative and only noticed when the next stream
/// line arrives. A new generation started before an old worker notices the
/// stop shares the same `result` buffer.
pub fn continue_content(
    context: String,
    previous_content: String,
    instruction: String,
    options: AIOptions,
    project: ProjectSettings,
    result: Arc<Mutex<String>>,
    ready: Arc<Mutex<ReadyState>>,
) -> Result<(), Box<dyn std::error::Error>> {
    *ready.lock().unwrap() = ReadyState::Generating;
    let outcome = stream_continuation(
        context,
        previous_content,
        instruction,
        options,
        project,
        &result,
        &ready,
    );
    // Reset unconditionally: an early `?` return must not leave the UI stuck
    // in `Generating`.
    *ready.lock().unwrap() = ReadyState::Idle;
    outcome
}

/// Builds and sends the chat request, then consumes the server-sent-event
/// stream until `[DONE]`, end of stream, or a stop/cancel seen on `ready`.
fn stream_continuation(
    context: String,
    previous_content: String,
    instruction: String,
    options: AIOptions,
    project: ProjectSettings,
    result: &Mutex<String>,
    ready: &Mutex<ReadyState>,
) -> Result<(), Box<dyn std::error::Error>> {
    let messages = vec![
        Message {
            role: "system".to_string(),
            content: "
Please generate content that is a continuation of the given text.
Your response should be a logical next step in the content and should not repeat any of the text from the instruction or the content.
Do not generate any text that is not a direct continuation of the content.
if extra instructions are provided, follow them exactly, otherwise continue the text in a logical way.
your output should NEVER be a repeat of any previous content
".to_string(),
        },
        Message {
            role: "user".to_string(),
            content: format!("Context / General instructions: {context}"),
        },
        Message {
            role: "user".to_string(),
            content: format!("Content to continue: {previous_content}"),
        },
        Message {
            role: "user".to_string(),
            content: format!("Specific instructions: {instruction}"),
        },
    ];
    let request = ChatRequest {
        messages,
        temperature: options.temperature,
        max_tokens: options.max_completion_tokens,
        model: options.model_override,
        reasoning_effort: options.reasoning_effort,
        stream: true,
    };

    let llm_api_uri = project
        .local_overrides
        .llm_api_uri
        .or(project.global_settings.llm_api_uri)
        .ok_or("No LLM API URI found")?;
    // A missing key is an error; an explicitly empty key means "no auth header".
    let api_key = project
        .local_overrides
        .llm_api_key
        .or(project.global_settings.llm_api_key)
        .ok_or("No API key found")?;

    // Trim a trailing '/' so a base URI like "http://host/" doesn't yield "//v1/...".
    let endpoint = format!("{}/v1/chat/completions", llm_api_uri.trim_end_matches('/'));
    let mut request_builder = reqwest::blocking::Client::new()
        .post(endpoint)
        .json(&request);
    if !api_key.is_empty() {
        request_builder = request_builder.bearer_auth(api_key);
    }
    // Surface 4xx/5xx instead of silently reading an error body with no `data:` lines.
    let response = request_builder.send()?.error_for_status()?;

    for line in BufReader::new(response).lines() {
        // Check whether the user stopped or cancelled before consuming the next line.
        {
            let mut state = ready.lock().unwrap();
            if *state == ReadyState::Halted {
                result.lock().unwrap().clear();
            }
            if *state != ReadyState::Generating {
                *state = ReadyState::Idle;
                break;
            }
        }

        let line = line?;
        // SSE allows an optional single space after the `data:` field name.
        let Some(data) = line.strip_prefix("data:") else {
            continue;
        };
        let data = data.strip_prefix(' ').unwrap_or(data);
        if data == "[DONE]" {
            break;
        }
        let Ok(chunk) = serde_json::from_str::<StreamingChatResponse>(data) else {
            continue;
        };
        // Some servers send chunks with an empty `choices` array (e.g. usage or
        // filter metadata); indexing `[0]` would panic the worker thread.
        if let Some(content) = chunk
            .choices
            .first()
            .and_then(|choice| choice.delta.content.as_deref())
        {
            result.lock().unwrap().push_str(content);
        }
    }
    Ok(())
}
/// Per-request generation settings collected from the UI and passed to
/// [`continue_content`].
pub struct AIOptions {
    /// Model to request instead of the server default; `None` leaves the
    /// `model` field out of the request entirely.
    pub model_override: Option<String>,
    /// Upper bound on generated tokens (sent in the request as `max_tokens`).
    pub max_completion_tokens: usize,
    /// Sampling temperature.
    pub temperature: f32,
    /// Reasoning effort hint forwarded with the request.
    pub reasoning_effort: ReasoningEffort,
}
/// Lifecycle of a background generation, shared between the UI thread and
/// the worker thread.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ReadyState {
    /// No generation is running.
    Idle,
    /// A worker is streaming output; switching to any other state asks it to stop.
    Generating,
    /// Set by the "Insert" button; presumably consumed by the editor elsewhere — TODO confirm.
    Ready,
    /// Set by "Cancel": the worker discards the output and stops.
    Halted,
}
/// How much reasoning the model should spend, sent as the `reasoning_effort`
/// request field.
///
/// Serialised in lowercase (`"minimal"`, `"low"`, ...). Displayed in title case
/// (`"Low"`) for the settings combo box. Defaults to [`ReasoningEffort::Low`].
#[derive(Serialize, Copy, Clone, Debug, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Minimal,
    #[default]
    Low,
    Medium,
    High,
}

// Implementing `Display` (rather than `ToString` directly) keeps `.to_string()`
// working via the blanket impl and also allows `format!("{effort}")`.
impl std::fmt::Display for ReasoningEffort {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ReasoningEffort::Minimal => "Minimal",
            ReasoningEffort::Low => "Low",
            ReasoningEffort::Medium => "Medium",
            ReasoningEffort::High => "High",
        };
        f.write_str(label)
    }
}
/// Request body for the OpenAI-compatible `/v1/chat/completions` call made by
/// `continue_content`.
#[derive(Serialize)]
struct ChatRequest {
    /// Conversation sent to the model: a system prompt followed by user messages.
    messages: Vec<Message>,
    /// Sampling temperature.
    temperature: f32,
    /// Completion-length limit, filled from `AIOptions::max_completion_tokens`.
    // NOTE(review): some newer OpenAI models expect the field name
    // `max_completion_tokens` instead of `max_tokens` — confirm against the
    // servers this is used with.
    max_tokens: usize,
    /// Whether the server should stream the reply as server-sent events.
    stream: bool,
    /// Always serialised; unlike `model`, it has no `skip_serializing_if`.
    reasoning_effort: ReasoningEffort,
    // Omitted entirely when `None`: sending `"model": null` made the API
    // return HTTP 500.
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
}
// Streaming response structures.

/// One JSON payload from a `data:` line of the streamed chat completion.
#[derive(Deserialize, Debug)]
struct StreamingChatResponse {
    /// Choices in this chunk; only the first one is read by `continue_content`.
    choices: Vec<StreamingChoice>,
}
/// A single choice within a streamed chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    /// Incremental message fragment for this choice.
    delta: Delta,
    /// Server-provided finish reason, when present.
    // Deserialised but never read in this file, hence the lint allowance.
    #[serde(default)]
    #[allow(unused)]
    finish_reason: Option<String>,
}
/// Incremental fragment of the assistant's reply in a streamed chunk.
#[derive(Deserialize, Debug)]
struct Delta {
    /// Role, if present in this fragment.
    // Deserialised but never read in this file, hence the lint allowance.
    #[serde(default)]
    #[allow(unused)]
    role: Option<String>,
    /// Text fragment; when present it is appended to the output buffer.
    #[serde(default)]
    content: Option<String>,
}
/// A chat message: a `role` (this file sends `"system"` and `"user"`) and its text.
#[derive(Serialize, Deserialize, Debug)]
struct Message {
    role: String,
    content: String,
}
/// Non-streaming chat completion response.
///
/// The request in this file sets `stream: true` and parses
/// `StreamingChatResponse` instead, so this type is currently unused.
#[derive(Deserialize, Debug)]
struct ChatResponse {
    // Unused (see above), hence the lint allowance.
    #[allow(unused)]
    choices: Vec<Choice>,
}
/// A single choice of a non-streaming chat completion response.
#[derive(Deserialize, Debug)]
struct Choice {
    // Only reachable through the unused `ChatResponse`, hence the lint allowance.
    #[allow(unused)]
    message: Message,
}