From 6c40f3412250e0404d3b136a321257f7bb48ed31 Mon Sep 17 00:00:00 2001 From: zxq5 Date: Mon, 28 Jul 2025 01:27:03 +0100 Subject: [PATCH] added streaming for AI content generation --- src/editors/content_editor.rs | 130 ++++++++++++++++------------- src/llm_integration/content_llm.rs | 68 +++++++++++---- 2 files changed, 127 insertions(+), 71 deletions(-) diff --git a/src/editors/content_editor.rs b/src/editors/content_editor.rs index 958a7bf..ad6929f 100644 --- a/src/editors/content_editor.rs +++ b/src/editors/content_editor.rs @@ -1,3 +1,5 @@ +use std::sync::{Arc, Mutex}; + use egui::TextEdit; use egui_commonmark::{CommonMarkCache, CommonMarkViewer}; use serde::{self, Deserialize, Serialize}; @@ -5,6 +7,7 @@ use serde::{self, Deserialize, Serialize}; use crate::{ PROJECT_FOLDER, editors::{context_editor::ProjectContext, tags::Tag}, + llm_integration::content_llm::ReadyState, util, }; @@ -59,8 +62,8 @@ pub enum ContentAI { Summarise { open: bool, content: String, - result: String, - ready: bool, + result: Arc>, + ready: Arc>, }, Continue { open: bool, @@ -68,8 +71,8 @@ pub enum ContentAI { instruction: String, max_tokens: usize, context_override: String, - result: String, - ready: bool, + result: Arc>, + ready: Arc>, }, } @@ -101,7 +104,7 @@ impl ContentAI { ); }); ui.add( - egui::TextEdit::multiline(result) + egui::TextEdit::multiline(&mut *result.lock().unwrap()) .font(egui::TextStyle::Monospace) .interactive(false) .frame(false) @@ -109,8 +112,8 @@ impl ContentAI { .hint_text("Summary will appear here..."), ); if ui.button("Summarise").clicked() { - *result = Self::summarise(content).unwrap(); - *ready = true; + // Self::summarise(content, result.clone()); + *ready.lock().unwrap() = ReadyState::Generating; } } ContentAI::Continue { @@ -140,22 +143,60 @@ impl ContentAI { ui.add(egui::Slider::new(max_tokens, 1000..=1000000)); ui.separator(); if ui.button("Continue").clicked() { - match Self::continue_content( - instruction, - *max_tokens, - context_override, - project, - content, - ) { - Ok(str) => { - *result = str; - *ready = true; + let context_override = context_override.clone(); + let content = content.clone(); + let instruction = instruction.clone(); + let max_tokens = *max_tokens; + let project = project.clone(); + let ai_context = project.ai_context_prompt.clone(); + let result = result.clone(); + let ready = ready.clone(); + + std::thread::spawn(move || { + *ready.lock().unwrap() = ReadyState::Generating; + + let result = crate::llm_integration::content_llm::continue_content( + if context_override.is_empty() { + ai_context + } else { + context_override + }, + content, + instruction, + max_tokens, + project, + result, + ); + if let Err(e) = result { + eprintln!("Error in content generation: {e}"); } - Err(err) => { - *result = format!("Error: {err}"); - *ready = true; - } - } + + *ready.lock().unwrap() = ReadyState::Ready; + }); + } + + if *ready.lock().unwrap() == ReadyState::Generating { + ui.horizontal(|ui| { + ui.spinner(); + ui.label("Generating..."); + }); + + egui::ScrollArea::both() + .auto_shrink([false, false]) + .max_width(ui.available_width()) + .show(ui, |ui| { + ui.add( + egui::TextEdit::multiline(&mut *result.lock().unwrap()) + .font(egui::TextStyle::Monospace) + .interactive(false) + .frame(false) + .desired_width(ui.available_width()) + .lock_focus(true) + .hint_text("Content will appear here..."), + ); + }); + } else if *ready.lock().unwrap() == ReadyState::Idle { + ui.label("Idle"); } } }); @@ -166,31 +207,6 @@ impl ContentAI { ContentAI::Continue { open, .. } => *open = is_open, }; } - - fn summarise(_content: &str) -> Result> { - // crate::llm_integration::content_llm::summarise_content(content, result) - Ok(String::new()) - } - - fn continue_content( - instruction: &str, - max_tokens: usize, - context_override: &str, - project: &mut ProjectContext, - content: &mut str, - ) -> Result> { - crate::llm_integration::content_llm::continue_content( - if context_override.is_empty() { - &project.ai_context_prompt - } else { - context_override - }, - content, - instruction, - max_tokens, - project, - ) - } } impl ContentSection { @@ -271,17 +287,17 @@ impl MainEditor { match dialog { ContentAI::Summarise { ready, result, .. } => { - if *ready { - self.content.content.push_str(result.as_str()); + if *ready.lock().unwrap() == ReadyState::Ready { + self.content.content.push_str(&result.lock().unwrap()); self.content.saved = false; - *ready = false; + *ready.lock().unwrap() = ReadyState::Idle; } } ContentAI::Continue { ready, result, .. } => { - if *ready { - self.content.content.push_str(result.as_str()); + if *ready.lock().unwrap() == ReadyState::Ready { + self.content.content.push_str(&result.lock().unwrap()); self.content.saved = false; - *ready = false; + *ready.lock().unwrap() = ReadyState::Idle; } } } @@ -472,10 +488,10 @@ impl MainEditor { ui.add_enabled_ui(project.ai_enabled(), |ui| { if ui.button("Summarise").clicked() { self.dialog = Some(ContentAI::Summarise { - result: String::new(), + result: Arc::new(Mutex::new(String::new())), content: self.content.content.clone(), open: true, - ready: false, + ready: Arc::new(Mutex::new(ReadyState::Idle)), }); } @@ -485,9 +501,9 @@ impl MainEditor { instruction: String::new(), max_tokens: 1024, context_override: "".to_string(), - result: String::new(), + result: Arc::new(Mutex::new(String::new())), open: true, - ready: false, + ready: Arc::new(Mutex::new(ReadyState::Idle)), }); } }); diff --git a/src/llm_integration/content_llm.rs b/src/llm_integration/content_llm.rs index a16ff1d..ddf02a7 100644 --- a/src/llm_integration/content_llm.rs +++ b/src/llm_integration/content_llm.rs @@ -1,14 +1,20 @@ +use std::{ + io::{BufRead, BufReader}, + sync::{Arc, Mutex}, +}; + use serde::{Deserialize, Serialize}; use crate::editors::context_editor::ProjectContext; pub fn continue_content( - context: &str, - previous_content: &str, - instruction: &str, + context: String, + previous_content: String, + instruction: String, max_tokens: usize, - project: &ProjectContext, -) -> Result> { + project: ProjectContext, + result: Arc>, +) -> Result<(), Box> { let client = reqwest::blocking::Client::new(); let messages = vec![ @@ -39,7 +45,7 @@ pub fn continue_content( messages, temperature: 0.7, max_tokens, - stream: false, + stream: true, }; let response = client @@ -47,17 +53,30 @@ pub fn continue_content( .json(&request) .send()?; - if !response.status().is_success() { - return Err(format!("Request failed: {}", response.text()?).into()); + let reader = BufReader::new(response); + + for line in reader.lines() { + let line = line?; + if line == "data: [DONE]" { + break; + } + if let Some(json) = line.strip_prefix("data: ") { + if let Ok(chunk) = serde_json::from_str::(json) { + if let Some(content) = chunk.choices[0].delta.content.as_ref() { + result.lock().unwrap().push_str(content); + } + } + } } - let response: ChatResponse = response.json()?; + Ok(()) +} - if let Some(choice) = response.choices.into_iter().next() { - Ok(choice.message.content) - } else { - Err("No response from model".into()) - } +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum ReadyState { + Idle, + Generating, + Ready, } // Simple request structure @@ -69,6 +88,27 @@ struct ChatRequest { stream: bool, } +// Streaming response structures +#[derive(Deserialize, Debug)] +struct StreamingChatResponse { + choices: Vec, +} + +#[derive(Deserialize, Debug)] +struct StreamingChoice { + delta: Delta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +struct Delta { + #[serde(default)] + role: Option, + #[serde(default)] + content: Option, +} + #[derive(Serialize, Deserialize, Debug)] struct Message { role: String,