added streaming for AI content generation
Continuous integration / build (push) Failing after 6m27s

This commit is contained in:
2025-07-28 01:27:03 +01:00
parent 5294feb5ff
commit 6c40f34122
2 changed files with 127 additions and 71 deletions
+73 -57
View File
@@ -1,3 +1,5 @@
use std::sync::{Arc, Mutex};
use egui::TextEdit; use egui::TextEdit;
use egui_commonmark::{CommonMarkCache, CommonMarkViewer}; use egui_commonmark::{CommonMarkCache, CommonMarkViewer};
use serde::{self, Deserialize, Serialize}; use serde::{self, Deserialize, Serialize};
@@ -5,6 +7,7 @@ use serde::{self, Deserialize, Serialize};
use crate::{ use crate::{
PROJECT_FOLDER, PROJECT_FOLDER,
editors::{context_editor::ProjectContext, tags::Tag}, editors::{context_editor::ProjectContext, tags::Tag},
llm_integration::content_llm::ReadyState,
util, util,
}; };
@@ -59,8 +62,8 @@ pub enum ContentAI {
Summarise { Summarise {
open: bool, open: bool,
content: String, content: String,
result: String, result: Arc<Mutex<String>>,
ready: bool, ready: Arc<Mutex<ReadyState>>,
}, },
Continue { Continue {
open: bool, open: bool,
@@ -68,8 +71,8 @@ pub enum ContentAI {
instruction: String, instruction: String,
max_tokens: usize, max_tokens: usize,
context_override: String, context_override: String,
result: String, result: Arc<Mutex<String>>,
ready: bool, ready: Arc<Mutex<ReadyState>>,
}, },
} }
@@ -101,7 +104,7 @@ impl ContentAI {
); );
}); });
ui.add( ui.add(
egui::TextEdit::multiline(result) egui::TextEdit::multiline(&mut *result.lock().unwrap())
.font(egui::TextStyle::Monospace) .font(egui::TextStyle::Monospace)
.interactive(false) .interactive(false)
.frame(false) .frame(false)
@@ -109,8 +112,8 @@ impl ContentAI {
.hint_text("Summary will appear here..."), .hint_text("Summary will appear here..."),
); );
if ui.button("Summarise").clicked() { if ui.button("Summarise").clicked() {
*result = Self::summarise(content).unwrap(); // Self::summarise(content, result.clone());
*ready = true; *ready.lock().unwrap() = ReadyState::Generating;
} }
} }
ContentAI::Continue { ContentAI::Continue {
@@ -140,22 +143,60 @@ impl ContentAI {
ui.add(egui::Slider::new(max_tokens, 1000..=1000000)); ui.add(egui::Slider::new(max_tokens, 1000..=1000000));
ui.separator(); ui.separator();
if ui.button("Continue").clicked() { if ui.button("Continue").clicked() {
match Self::continue_content( let context_override = context_override.clone();
instruction, let content = content.clone();
*max_tokens, let instruction = instruction.clone();
context_override, let max_tokens = *max_tokens;
project, let project = project.clone();
content, let ai_context = project.ai_context_prompt.clone();
) { let result = result.clone();
Ok(str) => { let ready = ready.clone();
*result = str;
*ready = true; std::thread::spawn(move || {
*ready.lock().unwrap() = ReadyState::Generating;
let result = crate::llm_integration::content_llm::continue_content(
if context_override.is_empty() {
ai_context
} else {
context_override
},
content,
instruction,
max_tokens,
project,
result,
);
if let Err(e) = result {
eprintln!("Error in content generation: {e}");
} }
Err(err) => {
*result = format!("Error: {err}"); *ready.lock().unwrap() = ReadyState::Ready;
*ready = true; });
} }
}
if *ready.lock().unwrap() == ReadyState::Generating {
ui.horizontal(|ui| {
ui.spinner();
ui.label("Generating...");
});
egui::ScrollArea::both()
.auto_shrink([false, false])
.max_width(ui.available_width())
.show(ui, |ui| {
ui.add(
egui::TextEdit::multiline(&mut *result.lock().unwrap())
.font(egui::TextStyle::Monospace)
.interactive(false)
.frame(false)
.desired_width(ui.available_width())
.lock_focus(true)
.hint_text("Content will appear here..."),
);
});
} else if *ready.lock().unwrap() == ReadyState::Idle {
ui.label("Idle");
} }
} }
}); });
@@ -166,31 +207,6 @@ impl ContentAI {
ContentAI::Continue { open, .. } => *open = is_open, ContentAI::Continue { open, .. } => *open = is_open,
}; };
} }
fn summarise(_content: &str) -> Result<String, Box<dyn std::error::Error>> {
// crate::llm_integration::content_llm::summarise_content(content, result)
Ok(String::new())
}
fn continue_content(
instruction: &str,
max_tokens: usize,
context_override: &str,
project: &mut ProjectContext,
content: &mut str,
) -> Result<String, Box<dyn std::error::Error>> {
crate::llm_integration::content_llm::continue_content(
if context_override.is_empty() {
&project.ai_context_prompt
} else {
context_override
},
content,
instruction,
max_tokens,
project,
)
}
} }
impl ContentSection { impl ContentSection {
@@ -271,17 +287,17 @@ impl MainEditor {
match dialog { match dialog {
ContentAI::Summarise { ready, result, .. } => { ContentAI::Summarise { ready, result, .. } => {
if *ready { if *ready.lock().unwrap() == ReadyState::Ready {
self.content.content.push_str(result.as_str()); self.content.content.push_str(&result.lock().unwrap());
self.content.saved = false; self.content.saved = false;
*ready = false; *ready.lock().unwrap() = ReadyState::Idle;
} }
} }
ContentAI::Continue { ready, result, .. } => { ContentAI::Continue { ready, result, .. } => {
if *ready { if *ready.lock().unwrap() == ReadyState::Ready {
self.content.content.push_str(result.as_str()); self.content.content.push_str(&result.lock().unwrap());
self.content.saved = false; self.content.saved = false;
*ready = false; *ready.lock().unwrap() = ReadyState::Idle;
} }
} }
} }
@@ -472,10 +488,10 @@ impl MainEditor {
ui.add_enabled_ui(project.ai_enabled(), |ui| { ui.add_enabled_ui(project.ai_enabled(), |ui| {
if ui.button("Summarise").clicked() { if ui.button("Summarise").clicked() {
self.dialog = Some(ContentAI::Summarise { self.dialog = Some(ContentAI::Summarise {
result: String::new(), result: Arc::new(Mutex::new(String::new())),
content: self.content.content.clone(), content: self.content.content.clone(),
open: true, open: true,
ready: false, ready: Arc::new(Mutex::new(ReadyState::Idle)),
}); });
} }
@@ -485,9 +501,9 @@ impl MainEditor {
instruction: String::new(), instruction: String::new(),
max_tokens: 1024, max_tokens: 1024,
context_override: "".to_string(), context_override: "".to_string(),
result: String::new(), result: Arc::new(Mutex::new(String::new())),
open: true, open: true,
ready: false, ready: Arc::new(Mutex::new(ReadyState::Idle)),
}); });
} }
}); });
+54 -14
View File
@@ -1,14 +1,20 @@
use std::{
io::{BufRead, BufReader},
sync::{Arc, Mutex},
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::editors::context_editor::ProjectContext; use crate::editors::context_editor::ProjectContext;
pub fn continue_content( pub fn continue_content(
context: &str, context: String,
previous_content: &str, previous_content: String,
instruction: &str, instruction: String,
max_tokens: usize, max_tokens: usize,
project: &ProjectContext, project: ProjectContext,
) -> Result<String, Box<dyn std::error::Error>> { result: Arc<Mutex<String>>,
) -> Result<(), Box<dyn std::error::Error>> {
let client = reqwest::blocking::Client::new(); let client = reqwest::blocking::Client::new();
let messages = vec![ let messages = vec![
@@ -39,7 +45,7 @@ pub fn continue_content(
messages, messages,
temperature: 0.7, temperature: 0.7,
max_tokens, max_tokens,
stream: false, stream: true,
}; };
let response = client let response = client
@@ -47,17 +53,30 @@ pub fn continue_content(
.json(&request) .json(&request)
.send()?; .send()?;
if !response.status().is_success() { let reader = BufReader::new(response);
return Err(format!("Request failed: {}", response.text()?).into());
for line in reader.lines() {
let line = line?;
if line == "data: [DONE]" {
break;
}
if let Some(json) = line.strip_prefix("data: ") {
if let Ok(chunk) = serde_json::from_str::<StreamingChatResponse>(json) {
if let Some(content) = chunk.choices[0].delta.content.as_ref() {
result.lock().unwrap().push_str(content);
}
}
}
} }
let response: ChatResponse = response.json()?; Ok(())
}
if let Some(choice) = response.choices.into_iter().next() { #[derive(Debug, PartialEq, Clone, Copy)]
Ok(choice.message.content) pub enum ReadyState {
} else { Idle,
Err("No response from model".into()) Generating,
} Ready,
} }
// Simple request structure // Simple request structure
@@ -69,6 +88,27 @@ struct ChatRequest {
stream: bool, stream: bool,
} }
// Streaming response structures
#[derive(Deserialize, Debug)]
struct StreamingChatResponse {
choices: Vec<StreamingChoice>,
}
#[derive(Deserialize, Debug)]
struct StreamingChoice {
delta: Delta,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
struct Delta {
#[serde(default)]
role: Option<String>,
#[serde(default)]
content: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
struct Message { struct Message {
role: String, role: String,