536 lines
18 KiB
Rust
536 lines
18 KiB
Rust
use std::{
|
|
io::{BufRead, BufReader},
|
|
sync::{Arc, Mutex},
|
|
};
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::editors::settings_editor::ProjectSettings;
|
|
|
|
#[derive(Clone)]
|
|
pub struct ContentAI {
|
|
pub open: bool,
|
|
|
|
// model input
|
|
pub content: String,
|
|
pub instruction: String,
|
|
pub context_override: String,
|
|
pub system_prompt: String,
|
|
|
|
// model settings
|
|
pub max_tokens: usize,
|
|
pub temperature: f32,
|
|
pub reasoning_effort: ReasoningEffort,
|
|
pub model_override: String,
|
|
|
|
// model output
|
|
pub reasoning: Arc<Mutex<String>>,
|
|
pub result: Arc<Mutex<String>>,
|
|
pub ready: Arc<Mutex<ReadyState>>,
|
|
}
|
|
|
|
impl ContentAI {
|
|
pub fn new(content: String) -> Self {
|
|
Self {
|
|
// model input
|
|
content,
|
|
instruction: String::new(),
|
|
context_override: String::new(),
|
|
system_prompt: String::new(),
|
|
|
|
// model settings
|
|
max_tokens: 2048,
|
|
reasoning_effort: ReasoningEffort::default(),
|
|
temperature: 0.7,
|
|
model_override: String::new(),
|
|
reasoning: Arc::new(Mutex::new(String::new())),
|
|
|
|
// output
|
|
result: Arc::new(Mutex::new(String::new())),
|
|
ready: Arc::new(Mutex::new(ReadyState::Idle)),
|
|
|
|
// ui
|
|
open: true,
|
|
}
|
|
}
|
|
|
|
pub fn ui(&mut self, ui: &mut egui::Ui, project: &mut ProjectSettings) {
|
|
let is_open = self.open;
|
|
|
|
if is_open {
|
|
egui::SidePanel::right("ai_assistant").show_inside(ui, |ui| {
|
|
Self::ui_main(self, ui, project);
|
|
});
|
|
}
|
|
self.open = is_open;
|
|
}
|
|
|
|
fn ui_output_box(&mut self, ui: &mut egui::Ui, project: &mut ProjectSettings) {
|
|
let mut ready_lock = self.ready.lock().unwrap();
|
|
|
|
ui.horizontal(|ui| {
|
|
if *ready_lock == ReadyState::Generating {
|
|
if ui.button("Cancel").clicked() {
|
|
*ready_lock = ReadyState::Halted;
|
|
}
|
|
if ui.button("Stop").clicked() {
|
|
*ready_lock = ReadyState::Idle;
|
|
}
|
|
ui.spinner();
|
|
ui.label("Generating...");
|
|
}
|
|
|
|
if *ready_lock == ReadyState::Idle {
|
|
let continue_content = || {
|
|
let content = self.content.clone();
|
|
let project = project.clone();
|
|
let result = self.result.clone();
|
|
let reasoning = self.reasoning.clone();
|
|
let ready = self.ready.clone();
|
|
|
|
let options = AIOptions {
|
|
max_completion_tokens: self.max_tokens,
|
|
reasoning_effort: self.reasoning_effort,
|
|
temperature: self.temperature,
|
|
model_override: if !self.model_override.is_empty() {
|
|
Some(self.model_override.clone())
|
|
} else {
|
|
None
|
|
},
|
|
};
|
|
|
|
let ai_input = AIInput {
|
|
system_prompt: self.system_prompt.clone(),
|
|
user_prompt: format!(
|
|
"{}\n\n{} {}",
|
|
self.instruction.clone(),
|
|
project.ai_context.clone(),
|
|
self.context_override.clone()
|
|
),
|
|
previous_content: content.clone(),
|
|
structure: None,
|
|
};
|
|
|
|
result.lock().unwrap().clear();
|
|
|
|
std::thread::spawn(move || {
|
|
let result = crate::llm_integration::content_llm::continue_content(
|
|
ai_input,
|
|
options,
|
|
project,
|
|
result,
|
|
reasoning,
|
|
ready.clone(),
|
|
);
|
|
if let Err(e) = result {
|
|
eprintln!("Error in content generation: {e}");
|
|
}
|
|
});
|
|
};
|
|
|
|
if ui.button("Generate ").clicked() {
|
|
continue_content();
|
|
}
|
|
|
|
ui.label("Idle");
|
|
}
|
|
|
|
// show regardless of state
|
|
if ui.button("Insert").clicked() {
|
|
*ready_lock = ReadyState::Ready;
|
|
}
|
|
|
|
if ui.button("Clear").clicked() {
|
|
self.result.lock().unwrap().clear();
|
|
self.reasoning.lock().unwrap().clear();
|
|
}
|
|
});
|
|
|
|
ui.spacing();
|
|
|
|
ui.vertical(|ui| {
|
|
egui::TopBottomPanel::top("reasoning_output")
|
|
.resizable(true)
|
|
.show_inside(ui, |ui| {
|
|
egui::ScrollArea::both()
|
|
.auto_shrink([false, true])
|
|
.id_salt("reasoning_output")
|
|
.max_width(ui.available_width())
|
|
// .max_height(ui.available_height() / 3.0)
|
|
.show(ui, |ui| {
|
|
ui.add(
|
|
egui::TextEdit::multiline(&mut *self.reasoning.lock().unwrap())
|
|
.font(egui::TextStyle::Monospace)
|
|
.interactive(false)
|
|
.desired_rows(5)
|
|
.frame(false)
|
|
.desired_width(ui.available_width())
|
|
.lock_focus(true)
|
|
.hint_text("Reasoning will appear here..."),
|
|
);
|
|
});
|
|
});
|
|
|
|
egui::ScrollArea::both()
|
|
.auto_shrink([false, false])
|
|
.id_salt("llm_output")
|
|
.max_width(ui.available_width())
|
|
.show(ui, |ui| {
|
|
ui.add(
|
|
egui::TextEdit::multiline(&mut *self.result.lock().unwrap())
|
|
.font(egui::TextStyle::Monospace)
|
|
.interactive(false)
|
|
.desired_rows(0)
|
|
.frame(false)
|
|
.desired_width(ui.available_width())
|
|
.lock_focus(true)
|
|
.hint_text("Content will appear here..."),
|
|
);
|
|
});
|
|
});
|
|
}
|
|
|
|
fn ui_main(&mut self, ui: &mut egui::Ui, project: &mut ProjectSettings) {
|
|
{
|
|
ui.weak("(The model will see current file content)");
|
|
|
|
egui::CollapsingHeader::new("Settings")
|
|
.default_open(true)
|
|
.show(ui, |ui| {
|
|
egui::Grid::new("continue_grid")
|
|
.num_columns(2)
|
|
.striped(true)
|
|
.show(ui, |ui| {
|
|
ui.label("Max Tokens");
|
|
ui.add(
|
|
egui::DragValue::new(&mut self.max_tokens)
|
|
.range(128..=u32::MAX)
|
|
.speed(128),
|
|
);
|
|
ui.end_row();
|
|
|
|
ui.label("Temperature");
|
|
ui.add(
|
|
egui::DragValue::new(&mut self.temperature)
|
|
.range(0.0..=2.0)
|
|
.speed(0.1),
|
|
);
|
|
|
|
ui.label("Reasoning effort");
|
|
|
|
egui::ComboBox::from_id_salt("reasoning_effort")
|
|
.selected_text(self.reasoning_effort.to_string())
|
|
.show_ui(ui, |ui| {
|
|
ui.selectable_value(
|
|
&mut self.reasoning_effort,
|
|
ReasoningEffort::Minimal,
|
|
"Minimal",
|
|
);
|
|
ui.selectable_value(
|
|
&mut self.reasoning_effort,
|
|
ReasoningEffort::Low,
|
|
"Low",
|
|
);
|
|
ui.selectable_value(
|
|
&mut self.reasoning_effort,
|
|
ReasoningEffort::Medium,
|
|
"Medium",
|
|
);
|
|
ui.selectable_value(
|
|
&mut self.reasoning_effort,
|
|
ReasoningEffort::High,
|
|
"High",
|
|
);
|
|
});
|
|
|
|
ui.end_row();
|
|
|
|
ui.label("Model override");
|
|
ui.add(egui::TextEdit::singleline(&mut self.model_override));
|
|
ui.end_row();
|
|
});
|
|
});
|
|
|
|
egui::TopBottomPanel::top("continue_instruction")
|
|
.resizable(true)
|
|
.show_inside(ui, |ui| {
|
|
egui::CollapsingHeader::new("Instructions")
|
|
.default_open(true)
|
|
.show(ui, |ui| {
|
|
egui::ScrollArea::vertical()
|
|
.auto_shrink([false, false])
|
|
.max_height(ui.available_height())
|
|
.show(ui, |ui| {
|
|
ui.add(
|
|
egui::TextEdit::multiline(&mut self.instruction)
|
|
.frame(false)
|
|
.desired_width(ui.available_width())
|
|
.hint_text("Writing Instructions"),
|
|
);
|
|
});
|
|
});
|
|
});
|
|
|
|
egui::TopBottomPanel::top("continue_context")
|
|
.resizable(true)
|
|
.show_inside(ui, |ui| {
|
|
egui::CollapsingHeader::new("Context")
|
|
.default_open(true)
|
|
.show(ui, |ui| {
|
|
egui::ScrollArea::vertical()
|
|
.auto_shrink([false, false])
|
|
.max_height(ui.available_height())
|
|
.show(ui, |ui| {
|
|
ui.add(
|
|
egui::TextEdit::multiline(&mut self.context_override)
|
|
.frame(false)
|
|
.desired_width(ui.available_width())
|
|
.hint_text("Any additional context?"),
|
|
);
|
|
});
|
|
});
|
|
});
|
|
|
|
egui::TopBottomPanel::top("continue_system_prompt")
|
|
.resizable(true)
|
|
.show_inside(ui, |ui| {
|
|
egui::CollapsingHeader::new("System prompt")
|
|
.default_open(true)
|
|
.show(ui, |ui| {
|
|
egui::ScrollArea::vertical()
|
|
.auto_shrink([false, false])
|
|
.max_height(ui.available_height())
|
|
.show(ui, |ui| {
|
|
ui.add(
|
|
egui::TextEdit::multiline(&mut self.system_prompt)
|
|
.frame(false)
|
|
.desired_width(ui.available_width())
|
|
.hint_text("System prompt"),
|
|
);
|
|
});
|
|
});
|
|
});
|
|
|
|
self.ui_output_box(ui, project);
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub fn continue_content(
|
|
ai_input: AIInput,
|
|
// context: String,
|
|
// previous_content: String,
|
|
// instruction: String,
|
|
options: AIOptions,
|
|
project: ProjectSettings,
|
|
result: Arc<Mutex<String>>,
|
|
reasoning: Arc<Mutex<String>>,
|
|
ready: Arc<Mutex<ReadyState>>,
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|
*ready.lock().unwrap() = ReadyState::Generating;
|
|
|
|
let client = reqwest::blocking::Client::new();
|
|
|
|
let messages = vec![
|
|
Message {
|
|
role: "system".to_string(),
|
|
content: ai_input.system_prompt,
|
|
},
|
|
Message {
|
|
role: "user".to_string(),
|
|
content: format!(
|
|
"<Instructions> {}\n\n<Previous content> {}\n\n",
|
|
ai_input.user_prompt, ai_input.previous_content
|
|
),
|
|
},
|
|
];
|
|
|
|
let request = ChatRequest {
|
|
messages,
|
|
temperature: options.temperature,
|
|
max_tokens: options.max_completion_tokens,
|
|
model: options.model_override,
|
|
reasoning_effort: options.reasoning_effort,
|
|
stream: true,
|
|
};
|
|
|
|
let llm_api_uri = if let Some(uri) = project.local_overrides.llm_api_uri {
|
|
uri
|
|
} else {
|
|
project.global_settings.llm_api_uri.unwrap()
|
|
};
|
|
|
|
let api_key = if let Some(key) = project.local_overrides.llm_api_key {
|
|
if key.is_empty() { None } else { Some(key) }
|
|
} else if let Some(key) = project.global_settings.llm_api_key {
|
|
if key.is_empty() { None } else { Some(key) }
|
|
} else {
|
|
return Err("No API key found".into());
|
|
};
|
|
|
|
let response = if let Some(k) = api_key {
|
|
client
|
|
.post(llm_api_uri + "/api/v0/chat/completions")
|
|
.json(&request)
|
|
.bearer_auth(k)
|
|
.send()?
|
|
} else {
|
|
client
|
|
.post(llm_api_uri + "/api/v0/chat/completions")
|
|
.json(&request)
|
|
.send()?
|
|
};
|
|
|
|
println!("success!");
|
|
|
|
// println!("response: {}", response.text().unwrap());
|
|
let reader = BufReader::new(response);
|
|
for line in reader.lines() {
|
|
// initial loop to check if the user has terminated the generation
|
|
{
|
|
let mut ready = ready.lock().unwrap();
|
|
|
|
if *ready == ReadyState::Halted {
|
|
result.lock().unwrap().clear();
|
|
reasoning.lock().unwrap().clear();
|
|
}
|
|
|
|
if *ready != ReadyState::Generating {
|
|
*ready = ReadyState::Idle;
|
|
break;
|
|
}
|
|
}
|
|
|
|
let line = line?;
|
|
if line == "data: [DONE]" {
|
|
break;
|
|
}
|
|
|
|
if let Some(json) = line.strip_prefix("data: ") {
|
|
if let Ok(chunk) = serde_json::from_str::<StreamingChatResponse>(json) {
|
|
println!("chunk: {chunk:?}");
|
|
|
|
if let Some(content) = chunk.choices[0].delta.content.as_ref() {
|
|
println!("content: {content}");
|
|
result.lock().unwrap().push_str(content);
|
|
}
|
|
|
|
if let Some(reasoning_content) = chunk.choices[0].delta.reasoning_content.as_ref() {
|
|
println!("reasoning_content: {reasoning_content}");
|
|
reasoning.lock().unwrap().push_str(reasoning_content);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
*ready.lock().unwrap() = ReadyState::Idle;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub struct AIOptions {
|
|
pub max_completion_tokens: usize,
|
|
pub temperature: f32,
|
|
pub reasoning_effort: ReasoningEffort,
|
|
pub model_override: Option<String>,
|
|
}
|
|
|
|
/// Prompt material for a single generation request.
pub struct AIInput {
    /// Text of the system message.
    pub system_prompt: String,
    /// Instructions and context, placed after `<Instructions>` in the user message.
    pub user_prompt: String,
    /// Existing document text, placed after `<Previous content>` in the user message.
    pub previous_content: String,
    /// Optional structure hint; not read anywhere in this file.
    pub structure: Option<String>,
}
|
|
|
|
/// Generation state shared between the UI and the background worker.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ReadyState {
    /// Nothing is running; the UI offers "Generate".
    Idle,
    /// The worker is streaming; the UI offers "Cancel" and "Stop".
    Generating,
    /// Set by the "Insert" button. It is presumably consumed by the editor
    /// elsewhere (not visible in this file).
    Ready,
    /// Set by "Cancel"; the worker discards the streamed output.
    Halted,
}
|
|
|
|
#[derive(Serialize, Copy, Clone, PartialEq, Default)]
|
|
pub enum ReasoningEffort {
|
|
#[serde(rename = "minimal")]
|
|
Minimal,
|
|
|
|
#[default]
|
|
#[serde(rename = "low")]
|
|
Low,
|
|
#[serde(rename = "medium")]
|
|
Medium,
|
|
#[serde(rename = "high")]
|
|
High,
|
|
}
|
|
|
|
impl std::fmt::Display for ReasoningEffort {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
ReasoningEffort::Minimal => write!(f, "Minimal"),
|
|
ReasoningEffort::Low => write!(f, "Low"),
|
|
ReasoningEffort::Medium => write!(f, "Medium"),
|
|
ReasoningEffort::High => write!(f, "High"),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Simple request structure
|
|
#[derive(Serialize)]
|
|
struct ChatRequest {
|
|
messages: Vec<Message>,
|
|
temperature: f32,
|
|
max_tokens: usize,
|
|
stream: bool,
|
|
reasoning_effort: ReasoningEffort,
|
|
|
|
// if we give the API model:null it returns 500
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
model: Option<String>,
|
|
}
|
|
|
|
// Streaming response structures
|
|
#[derive(Deserialize, Debug)]
|
|
struct StreamingChatResponse {
|
|
choices: Vec<StreamingChoice>,
|
|
}
|
|
|
|
#[derive(Deserialize, Debug)]
|
|
struct StreamingChoice {
|
|
delta: Delta,
|
|
#[serde(default)]
|
|
#[allow(unused)]
|
|
finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize, Debug)]
|
|
struct Delta {
|
|
#[serde(default)]
|
|
#[allow(unused)]
|
|
role: Option<String>,
|
|
|
|
#[serde(default)]
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
reasoning_content: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
struct Message {
|
|
role: String,
|
|
content: String,
|
|
}
|
|
|
|
#[derive(Deserialize, Debug)]
|
|
struct ChatResponse {
|
|
#[allow(unused)]
|
|
choices: Vec<Choice>,
|
|
}
|
|
|
|
#[derive(Deserialize, Debug)]
|
|
struct Choice {
|
|
#[allow(unused)]
|
|
message: Message,
|
|
}
|