diff --git a/lumni/src/apps/builtin/llm/prompt/src/app.rs b/lumni/src/apps/builtin/llm/prompt/src/app.rs index ea3b918f..ef127f8a 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/app.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/app.rs @@ -95,7 +95,7 @@ async fn prompt_app( Some(WindowEvent::Prompt(prompt_action)) => { match prompt_action { PromptAction::Write(prompt) => { - send_prompt(tab, &db_conn, &prompt, &color_scheme, tx.clone()).await?; + send_prompt(tab, &prompt, &color_scheme, tx.clone()).await?; } PromptAction::Clear => { tab.ui.response.text_empty(); @@ -188,7 +188,7 @@ async fn prompt_app( let display_content = format!("{}{}", trim_buffer.unwrap_or("".to_string()), trimmed_response); if !display_content.is_empty() { - chat.update_last_exchange(&db_conn, &display_content); + chat.update_last_exchange(&display_content); tab_ui.response.text_append_with_insert(&display_content, Some(color_scheme.get_secondary_style())); } @@ -375,7 +375,7 @@ async fn interactive_mode( async fn process_non_interactive_input( chat: ChatSession, - db_conn: ConversationDatabase, + _db_conn: ConversationDatabase, ) -> Result<(), ApplicationError> { let chat = Arc::new(Mutex::new(chat)); let stdin = tokio::io::stdin(); @@ -409,7 +409,7 @@ async fn process_non_interactive_input( // Process the prompt let process_handle = tokio::spawn(async move { let mut chat = chat_clone.lock().await; - chat.process_prompt(&db_conn, input, running.clone()).await + chat.process_prompt(input, running.clone()).await }); // Wait for the process to complete or for a shutdown signal @@ -491,17 +491,13 @@ async fn handle_ctrl_c(r: Arc>, s: Arc>) { async fn send_prompt<'a>( tab: &mut TabSession<'a>, - db_conn: &ConversationDatabase, prompt: &str, color_scheme: &ColorScheme, tx: mpsc::Sender, ) -> Result<(), ApplicationError> { // prompt should end with single newline let formatted_prompt = format!("{}\n", prompt.trim_end()); - let result = tab - .chat - .message(tx.clone(), db_conn, &formatted_prompt) - .await; + let result = tab.chat.message(tx.clone(), &formatted_prompt).await; match result { Ok(_) => { diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs index 55bd8d1c..013aac01 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs @@ -8,8 +8,8 @@ mod schema; mod store; pub use schema::{ - Attachment, Conversation, ConversationId, Exchange, - ExchangeId, ConversationCache, Message, ModelId, + Attachment, Conversation, ConversationCache, ConversationId, Exchange, + ExchangeId, Message, ModelId, }; pub use store::ConversationDatabaseStore; @@ -17,7 +17,6 @@ pub use super::PromptRole; pub struct ConversationDatabase { pub store: Arc>, - pub cache: Arc>, } impl ConversationDatabase { @@ -26,7 +25,6 @@ impl ConversationDatabase { store: Arc::new(Mutex::new(ConversationDatabaseStore::new( sqlite_file, )?)), - cache: Arc::new(Mutex::new(ConversationCache::new())), }) } @@ -54,8 +52,8 @@ impl ConversationDatabase { pub fn finalize_exchange( &self, exchange: &Exchange, + cache: &ConversationCache, ) -> Result<(), SqliteError> { - let cache = self.cache.lock().unwrap(); let messages = cache.get_exchange_messages(exchange.id); let attachments = messages .iter() diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs index 41cd67db..c4ad6efd 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs @@ -89,6 +89,7 @@ pub struct Attachment { #[derive(Debug)] pub struct ConversationCache { + conversation_id: ConversationId, models: HashMap, exchanges: Vec, messages: HashMap, @@ -100,6 +101,7 @@ pub struct ConversationCache { impl ConversationCache { pub fn new() -> Self { ConversationCache { + conversation_id: ConversationId(0), models: HashMap::new(), exchanges: Vec::new(), messages: HashMap::new(), @@ -109,6 +111,14 @@ impl ConversationCache { } } + pub fn get_conversation_id(&self) -> ConversationId { + self.conversation_id + } + + pub fn set_conversation_id(&mut self, conversation_id: ConversationId) { + self.conversation_id = conversation_id; + } + pub fn new_exchange_id(&self) -> ExchangeId { ExchangeId(self.exchanges.len() as i64) } diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs index b141cd9f..4cae6ced 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs @@ -4,8 +4,7 @@ use rusqlite::Error as SqliteError; use super::connector::DatabaseConnector; use super::schema::{ - Attachment, AttachmentData, Conversation, ConversationId, - Exchange, Message, + Attachment, AttachmentData, Conversation, ConversationId, Exchange, Message, }; pub struct ConversationDatabaseStore { diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs index 8307864a..4718d19f 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs @@ -1,10 +1,8 @@ -use std::sync::MutexGuard; - use lumni::api::error::ApplicationError; use super::db::{ - ConversationDatabase, ConversationId, Exchange, ExchangeId, - ConversationCache, Message, ModelId, + ConversationCache, ConversationDatabase, ConversationId, Exchange, + ExchangeId, Message, ModelId, }; use super::prompt::Prompt; use super::{ @@ -14,11 +12,11 @@ use super::{ pub use crate::external as lumni; pub struct PromptInstruction { + cache: ConversationCache, completion_options: ChatCompletionOptions, prompt_options: PromptOptions, // TODO: get from db system_prompt: SystemPrompt, prompt_template: Option, - current_conversation_id: ConversationId, } impl Default for PromptInstruction { @@ -30,11 +28,11 @@ impl Default for PromptInstruction { .set_stream(true); PromptInstruction { + cache: ConversationCache::new(), completion_options, prompt_options: PromptOptions::default(), system_prompt: SystemPrompt::default(), prompt_template: None, - current_conversation_id: ConversationId(0), } } } @@ -69,17 +67,15 @@ impl PromptInstruction { prompt_instruction.preload_from_assistant( assistant, instruction, // add user-instruction with assistant - db_conn, )?; } else if let Some(instruction) = instruction { prompt_instruction.set_system_prompt(instruction); }; // Create a new Conversation in the database - let conversation_id = { - db_conn.new_conversation("New Conversation", None)? - }; - prompt_instruction.current_conversation_id = conversation_id; + let conversation_id = + { db_conn.new_conversation("New Conversation", None)? }; + prompt_instruction.cache.set_conversation_id(conversation_id); Ok(prompt_instruction) } @@ -89,28 +85,21 @@ impl PromptInstruction { db_conn: &ConversationDatabase, ) -> Result<(), ApplicationError> { // reset by creating a new conversation - self.current_conversation_id = + let current_conversation_id = db_conn.new_conversation("New Conversation", None)?; + self.cache.set_conversation_id(current_conversation_id); Ok(()) } - pub fn append_last_response( - &mut self, - answer: &str, - db_conn: &ConversationDatabase, - ) { + pub fn append_last_response(&mut self, answer: &str) { ExchangeHandler::append_response( - &mut db_conn.cache.lock().unwrap(), - self.current_conversation_id, + &mut self.cache, answer, ); } - pub fn get_last_response( - &self, - db_conn: &ConversationDatabase, - ) -> Option { - ExchangeHandler::get_last_response(&db_conn.cache.lock().unwrap()) + pub fn get_last_response(&mut self) -> Option { + ExchangeHandler::get_last_response(&mut self.cache) } pub fn put_last_response( @@ -120,19 +109,19 @@ impl PromptInstruction { db_conn: &ConversationDatabase, ) { let exchange = ExchangeHandler::put_last_response( - db_conn, + &mut self.cache, answer, tokens_predicted, ); if let Some(exchange) = exchange { - let _result = db_conn.finalize_exchange(&exchange); + let _result = db_conn.finalize_exchange(&exchange, &self.cache); } } fn first_exchange(&self) -> Exchange { Exchange { id: ExchangeId(0), - conversation_id: self.current_conversation_id, + conversation_id: self.cache.get_conversation_id(), model_id: ModelId(0), system_prompt: Some(self.system_prompt.instruction.clone()), completion_options: serde_json::to_value(&self.completion_options) @@ -146,14 +135,11 @@ impl PromptInstruction { } } - pub fn subsequent_exchange( - &mut self, - cache: &mut MutexGuard, - ) -> Exchange { - if let Some(last) = cache.get_last_exchange() { + pub fn subsequent_exchange(&mut self) -> Exchange { + if let Some(last) = self.cache.get_last_exchange() { Exchange { - id: cache.new_exchange_id(), - conversation_id: self.current_conversation_id, + id: self.cache.new_exchange_id(), + conversation_id: self.cache.get_conversation_id(), model_id: last.model_id, system_prompt: last.system_prompt.clone(), completion_options: last.completion_options.clone(), @@ -169,8 +155,8 @@ impl PromptInstruction { let exchange = self.first_exchange(); // add system prompt let system_message = Message { - id: cache.new_message_id(), - conversation_id: self.current_conversation_id, + id: self.cache.new_message_id(), + conversation_id: self.cache.get_conversation_id(), exchange_id: exchange.id, role: PromptRole::System, message_type: "text".to_string(), @@ -184,12 +170,12 @@ impl PromptInstruction { is_deleted: false, }; // add first exchange including system prompt message - cache.add_message(system_message); - cache.add_exchange(exchange.clone()); + self.cache.add_message(system_message); + self.cache.add_exchange(exchange.clone()); // return subsequent exchange Exchange { - id: cache.new_exchange_id(), + id: self.cache.new_exchange_id(), prompt_tokens: None, completion_tokens: None, previous_exchange_id: Some(exchange.id), @@ -203,18 +189,16 @@ impl PromptInstruction { question: &str, token_length: Option, max_token_length: usize, - db_conn: &ConversationDatabase, ) -> Vec { // token budget for the system prompt let system_prompt_token_length = self.get_n_keep().unwrap_or(0); // add the partial exchange (question) to the conversation - let mut cache = db_conn.cache.lock().unwrap(); - let exchange = self.subsequent_exchange(&mut cache); + let exchange = self.subsequent_exchange(); let user_message = Message { - id: cache.new_message_id(), - conversation_id: self.current_conversation_id, + id: self.cache.new_message_id(), + conversation_id: self.cache.get_conversation_id(), exchange_id: exchange.id, role: PromptRole::User, message_type: "text".to_string(), @@ -224,11 +208,11 @@ impl PromptInstruction { created_at: 0, is_deleted: false, }; - cache.add_exchange(exchange); + self.cache.add_exchange(exchange); // new prompt only has user message, answer is not yet generated - cache.add_message(user_message); + self.cache.add_message(user_message); - let current_exchanges = cache.get_exchanges(); + let current_exchanges = self.cache.get_exchanges(); // Collect messages while respecting token limits let mut messages: Vec = Vec::new(); @@ -236,8 +220,11 @@ impl PromptInstruction { // Add messages from most recent to oldest, respecting token limit for exchange in current_exchanges.into_iter().rev() { - for msg in - cache.get_exchange_messages(exchange.id).into_iter().rev() + for msg in self + .cache + .get_exchange_messages(exchange.id) + .into_iter() + .rev() { let msg_token_length = msg.token_length.map(|len| len as usize).unwrap_or(0); @@ -322,7 +309,6 @@ impl PromptInstruction { &mut self, assistant: String, user_instruction: Option, - db_conn: &ConversationDatabase, ) -> Result<(), ApplicationError> { let assistant_prompts: Vec = serde_yaml::from_str(PERSONAS) .map_err(|e| { @@ -346,13 +332,11 @@ impl PromptInstruction { self.set_system_prompt(system_prompt.clone()); if let Some(exchanges) = prompt.exchanges() { - let mut cache = db_conn.cache.lock().unwrap(); - // Create a new exchange with the system prompt let exchange = self.first_exchange(); let system_message = Message { - id: cache.new_message_id(), - conversation_id: self.current_conversation_id, + id: self.cache.new_message_id(), + conversation_id: self.cache.get_conversation_id(), exchange_id: exchange.id, role: PromptRole::System, message_type: "text".to_string(), @@ -362,15 +346,15 @@ impl PromptInstruction { created_at: 0, is_deleted: false, }; - cache.add_message(system_message); - cache.add_exchange(exchange); + self.cache.add_message(system_message); + self.cache.add_exchange(exchange); for loaded_exchange in exchanges.iter() { - let exchange = self.subsequent_exchange(&mut cache); + let exchange = self.subsequent_exchange(); let user_message = Message { - id: cache.new_message_id(), - conversation_id: self.current_conversation_id, + id: self.cache.new_message_id(), + conversation_id: self.cache.get_conversation_id(), exchange_id: exchange.id, role: PromptRole::User, message_type: "text".to_string(), @@ -381,8 +365,8 @@ impl PromptInstruction { is_deleted: false, }; let assistant_message = Message { - id: cache.new_message_id(), - conversation_id: self.current_conversation_id, + id: self.cache.new_message_id(), + conversation_id: self.cache.get_conversation_id(), exchange_id: exchange.id, role: PromptRole::Assistant, message_type: "text".to_string(), @@ -392,9 +376,9 @@ impl PromptInstruction { created_at: 0, // Use proper timestamp is_deleted: false, }; - cache.add_message(user_message); - cache.add_message(assistant_message); - cache.add_exchange(exchange); + self.cache.add_message(user_message); + self.cache.add_message(assistant_message); + self.cache.add_exchange(exchange); } } @@ -426,8 +410,7 @@ pub struct ExchangeHandler; impl ExchangeHandler { pub fn append_response( - cache: &mut MutexGuard, - current_conversation_id: ConversationId, + cache: &mut ConversationCache, answer: &str, ) { let last_exchange = cache.get_last_exchange(); @@ -447,7 +430,7 @@ impl ExchangeHandler { // If the last message is from User or there's no message, create a new Assistant message let new_message = Message { id: cache.new_message_id(), - conversation_id: current_conversation_id, + conversation_id: cache.get_conversation_id(), exchange_id: exchange.id, role: PromptRole::Assistant, message_type: "text".to_string(), @@ -468,13 +451,11 @@ impl ExchangeHandler { } } - pub fn get_last_response( - db_lock: &MutexGuard, - ) -> Option { - db_lock + pub fn get_last_response(cache: &mut ConversationCache) -> Option { + cache .get_last_exchange() .and_then(|last_exchange| { - db_lock.get_last_message_of_exchange(last_exchange.id) + cache.get_last_message_of_exchange(last_exchange.id) }) .and_then(|last_message| { if last_message.role == PromptRole::Assistant { @@ -486,13 +467,13 @@ impl ExchangeHandler { } pub fn put_last_response( - db_conn: &ConversationDatabase, + cache: &mut ConversationCache, answer: &str, tokens_predicted: Option, ) -> Option { // Capture the necessary message and exchange IDs in a separate scope let (message_id, exchange) = { - let cache = db_conn.cache.lock().unwrap(); + //let cache = db_conn.cache.lock().unwrap(); let last_exchange = cache.get_last_exchange()?; let last_message = cache.get_last_message_of_exchange(last_exchange.id)?; @@ -509,7 +490,7 @@ impl ExchangeHandler { if let (Some(message_id), Some(exchange)) = (message_id, exchange) { // Perform the update in a separate cache lock scope let token_length = tokens_predicted.map(|t| t as i32); - let mut cache = db_conn.cache.lock().unwrap(); + //let mut cache = db_conn.cache.lock().unwrap(); cache.update_message_by_id(message_id, answer, token_length); Some(exchange) } else { @@ -524,7 +505,7 @@ struct SystemPrompt { } impl SystemPrompt { - pub fn default() -> Self { + fn default() -> Self { SystemPrompt { instruction: "".to_string(), token_length: Some(0), diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs index 35fc9571..96ad6094 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs @@ -67,15 +67,11 @@ impl ChatSession { pub fn reset(&mut self, db: &ConversationDatabase) { self.stop(); - self.prompt_instruction.reset_history(db); + _ = self.prompt_instruction.reset_history(db); } - pub fn update_last_exchange( - &mut self, - db: &ConversationDatabase, - answer: &str, - ) { - self.prompt_instruction.append_last_response(answer, db); + pub fn update_last_exchange(&mut self, answer: &str) { + self.prompt_instruction.append_last_response(answer); } pub async fn finalize_last_exchange( @@ -83,7 +79,7 @@ impl ChatSession { db: &ConversationDatabase, _tokens_predicted: Option, ) -> Result<(), ApplicationError> { - let last_answer = self.prompt_instruction.get_last_response(db); + let last_answer = self.prompt_instruction.get_last_response(); if let Some(last_answer) = last_answer { let trimmed_answer = last_answer.trim(); @@ -106,7 +102,6 @@ impl ChatSession { pub async fn message( &mut self, tx: mpsc::Sender, - db: &ConversationDatabase, question: &str, ) -> Result<(), ApplicationError> { let max_token_length = self @@ -119,7 +114,6 @@ impl ChatSession { &user_question, token_length, max_token_length, - db, ); let (cancel_tx, cancel_rx) = oneshot::channel(); @@ -178,12 +172,11 @@ impl ChatSession { // used in non-interactive mode pub async fn process_prompt( &mut self, - db: &ConversationDatabase, question: String, stop_signal: Arc>, ) -> Result<(), ApplicationError> { let (tx, rx) = mpsc::channel(32); - let _ = self.message(tx, db, &question).await; + let _ = self.message(tx, &question).await; self.handle_response(rx, stop_signal).await?; self.stop(); Ok(())