From ce10c259c521dcd34a97ef392f331fc8e6fcb973 Mon Sep 17 00:00:00 2001 From: Anthony Potappel Date: Tue, 9 Jul 2024 22:50:47 +0200 Subject: [PATCH] wip - remove ChatHistory/ Exchange structs in favor of new sql aligned db struct --- lumni/Cargo.toml | 1 + .../builtin/llm/prompt/src/chat/exchange.rs | 42 -- .../builtin/llm/prompt/src/chat/history.rs | 149 ------- .../llm/prompt/src/chat/instruction.rs | 397 +++++++++++++++--- .../apps/builtin/llm/prompt/src/chat/mod.rs | 12 +- .../builtin/llm/prompt/src/chat/options.rs | 4 +- .../builtin/llm/prompt/src/chat/prompt.rs | 9 +- .../builtin/llm/prompt/src/chat/schema.rs | 134 +++++- .../builtin/llm/prompt/src/chat/schema.sql | 2 + .../builtin/llm/prompt/src/chat/session.rs | 73 ++-- .../builtin/llm/prompt/src/model/formatter.rs | 8 +- .../builtin/llm/prompt/src/model/llama3.rs | 4 +- .../llm/prompt/src/server/bedrock/mod.rs | 43 +- .../llm/prompt/src/server/llama/mod.rs | 20 +- .../apps/builtin/llm/prompt/src/server/mod.rs | 17 +- .../llm/prompt/src/server/ollama/mod.rs | 37 +- .../llm/prompt/src/server/openai/mod.rs | 31 +- .../llm/prompt/src/server/openai/request.rs | 10 +- .../prompt/src/tui/components/text_buffer.rs | 4 +- .../prompt/src/tui/components/text_window.rs | 3 +- .../apps/builtin/llm/prompt/src/tui/draw.rs | 5 +- 21 files changed, 607 insertions(+), 398 deletions(-) delete mode 100644 lumni/src/apps/builtin/llm/prompt/src/chat/exchange.rs delete mode 100644 lumni/src/apps/builtin/llm/prompt/src/chat/history.rs diff --git a/lumni/Cargo.toml b/lumni/Cargo.toml index 8f479faf..4bc3391e 100644 --- a/lumni/Cargo.toml +++ b/lumni/Cargo.toml @@ -48,6 +48,7 @@ libc = "0.2" tiktoken-rs = "0.5.9" syntect = { version = "5.2.0", default-features = false, features = ["parsing", "default-fancy"] } crc32fast = { version = "1.4" } +rusqlite = { version = "0.31" } # CLI env_logger = { version = "0.9", optional = true } diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/exchange.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/exchange.rs deleted file mode 100644 index 994e3941..00000000 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/exchange.rs +++ /dev/null @@ -1,42 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ChatExchange { - question: String, - answer: String, - token_length: Option, -} - -impl ChatExchange { - pub fn new(question: String, answer: String) -> Self { - ChatExchange { - question, - answer, - token_length: None, - } - } - - pub fn get_question(&self) -> &str { - &self.question - } - - pub fn get_answer(&self) -> &str { - &self.answer - } - - pub fn set_answer(&mut self, answer: String) { - self.answer = answer; - } - - pub fn push_to_answer(&mut self, text: &str) { - self.answer.push_str(text); - } - - pub fn get_token_length(&self) -> Option { - self.token_length - } - - pub fn set_token_length(&mut self, token_length: usize) { - self.token_length = Some(token_length); - } -} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/history.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/history.rs deleted file mode 100644 index 72c3e8eb..00000000 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/history.rs +++ /dev/null @@ -1,149 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use super::exchange::ChatExchange; -use super::{LLMDefinition, PromptRole}; - -#[derive(Debug, Clone)] -pub struct ChatHistory { - exchanges: Vec, - keep_n: Option, // keep n exchanges in history if reset -} - -impl ChatHistory { - pub fn new() -> Self { - ChatHistory { - exchanges: Vec::new(), - keep_n: None, - } - } - - pub fn new_with_exchanges(exchanges: Vec) -> Self { - let keep_n = Some(exchanges.len()); // keep initial exchanges if reset - ChatHistory { exchanges, keep_n } - } - - pub fn reset(&mut self) { - if let Some(keep_n) = self.keep_n { - self.exchanges.truncate(keep_n); - } else { - self.exchanges.clear(); - } - } - - pub fn get_last_exchange_mut(&mut self) -> Option<&mut ChatExchange> { - self.exchanges.last_mut() - } - - pub fn update_last_exchange(&mut self, answer: &str) { - if let Some(last_exchange) = self.exchanges.last_mut() { - last_exchange.push_to_answer(answer); - } - } - - pub fn new_prompt( - &mut self, - new_exchange: ChatExchange, - max_token_length: usize, - system_prompt_length: Option, - ) -> Vec { - let mut result_exchanges = Vec::new(); - - // instruction and new exchange should always be added, - // calculate the remaining tokens to see how much history can be added - let tokens_remaining = { - let tokens_required = new_exchange.get_token_length().unwrap_or(0) - + system_prompt_length.unwrap_or(0); - max_token_length.saturating_sub(tokens_required) - }; - - // cleanup last exchange if second (answer) element is un-answered (empty) - if let Some(last_exchange) = self.exchanges.last() { - if last_exchange.get_answer().is_empty() { - self.exchanges.pop(); - } - } - - let mut history_tokens = 0; - - for exchange in self.exchanges.iter().rev() { - let exchange_tokens = exchange.get_token_length().unwrap_or(0); - if history_tokens + exchange_tokens > tokens_remaining { - break; - } - history_tokens += exchange_tokens; - result_exchanges.insert(0, exchange.clone()); - } - - // add the new exchange to both the result and the history - result_exchanges.push(new_exchange.clone()); - self.exchanges.push(new_exchange); - result_exchanges - } - - pub fn exchanges_to_string<'a, I>( - model: &LLMDefinition, - exchanges: I, - ) -> String - where - I: IntoIterator, - { - let mut prompt = String::new(); - let formatter = model.get_formatter(); - - for exchange in exchanges { - prompt.push_str( - &formatter.fmt_prompt_message( - PromptRole::User, - exchange.get_question(), - ), - ); - prompt.push_str(&formatter.fmt_prompt_message( - PromptRole::Assistant, - exchange.get_answer(), - )); - } - prompt - } - - pub fn exchanges_to_messages<'a, I>( - exchanges: I, - system_prompt: Option<&str>, - fn_role_name: &dyn Fn(PromptRole) -> &'static str, - ) -> Vec - where - I: IntoIterator, - { - let mut messages = Vec::new(); - - if let Some(system_prompt) = system_prompt { - messages.push(ChatMessage { - role: fn_role_name(PromptRole::System).to_string(), - content: system_prompt.to_string(), - }); - } - - for exchange in exchanges { - messages.push(ChatMessage { - role: fn_role_name(PromptRole::User).to_string(), - content: exchange.get_question().to_string(), - }); - - // dont add empty answers - let content = exchange.get_answer().to_string(); - if content.is_empty() { - continue; - } - messages.push(ChatMessage { - role: fn_role_name(PromptRole::Assistant).to_string(), - content, - }); - } - messages - } -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ChatMessage { - pub role: String, - pub content: String, -} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs index 5c378e59..515bc18a 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs @@ -1,19 +1,24 @@ +use std::sync::{Arc, Mutex, MutexGuard}; + use lumni::api::error::ApplicationError; -use super::history::ChatHistory; use super::prompt::Prompt; +use super::schema::{ + ConversationId, Exchange, InMemoryDatabase, Message, ModelId, +}; use super::{ - ChatCompletionOptions, ChatExchange, PromptOptions, DEFAULT_N_PREDICT, - DEFAULT_TEMPERATURE, PERSONAS, + ChatCompletionOptions, ChatMessage, PromptOptions, PromptRole, + DEFAULT_N_PREDICT, DEFAULT_TEMPERATURE, PERSONAS, }; pub use crate::external as lumni; pub struct PromptInstruction { completion_options: ChatCompletionOptions, - prompt_options: PromptOptions, + prompt_options: PromptOptions, // TODO: get from db system_prompt: SystemPrompt, - history: ChatHistory, prompt_template: Option, + pub db: Arc>, + current_conversation_id: ConversationId, } impl Default for PromptInstruction { @@ -28,8 +33,9 @@ impl Default for PromptInstruction { completion_options, prompt_options: PromptOptions::default(), system_prompt: SystemPrompt::default(), - history: ChatHistory::new(), prompt_template: None, + db: Arc::new(Mutex::new(InMemoryDatabase::new())), + current_conversation_id: ConversationId(0), } } } @@ -67,48 +73,205 @@ impl PromptInstruction { } else if let Some(instruction) = instruction { prompt_instruction.set_system_prompt(instruction); }; + + // Create a new Conversation in the database + let conversation_id = { + let mut db_lock = prompt_instruction.db.lock().unwrap(); + db_lock.new_conversation("New Conversation", None) + }; + prompt_instruction.current_conversation_id = conversation_id; + Ok(prompt_instruction) } pub fn reset_history(&mut self) { - self.history.reset(); + // Create a new Conversation in the database + let new_conversation_id = { + let mut db_lock = self.db.lock().unwrap(); + db_lock.new_conversation( + "New Conversation", + Some(self.current_conversation_id), + ) + }; + self.current_conversation_id = new_conversation_id; + } + + pub fn append_last_response(&mut self, answer: &str) { + ExchangeHandler::append_response( + &mut self.db.lock().unwrap(), + self.current_conversation_id, + answer, + ); } - pub fn update_last_exchange(&mut self, answer: &str) { - self.history.update_last_exchange(answer); + pub fn get_last_response(&self) -> Option { + ExchangeHandler::get_last_response( + &self.db.lock().unwrap(), + self.current_conversation_id, + ) } - pub fn get_last_exchange_mut(&mut self) -> Option<&mut ChatExchange> { - self.history.get_last_exchange_mut() + pub fn put_last_response( + &mut self, + answer: &str, + tokens_predicted: Option, + ) { + ExchangeHandler::put_last_response( + &mut self.db.lock().unwrap(), + self.current_conversation_id, + answer, + tokens_predicted, + ); + } + + fn first_exchange( + &self, + db_lock: &mut MutexGuard<'_, InMemoryDatabase>, + ) -> Exchange { + Exchange { + id: db_lock.new_exchange_id(), + conversation_id: self.current_conversation_id, + model_id: ModelId(0), + system_prompt: self.system_prompt.instruction.clone(), + completion_options: serde_json::to_value(&self.completion_options) + .unwrap_or_default(), + prompt_options: serde_json::to_value(&self.prompt_options) + .unwrap_or_default(), + completion_tokens: 0, + prompt_tokens: 0, + created_at: 0, + previous_exchange_id: None, + } } - pub fn new_prompt( + pub fn subsequent_exchange( &mut self, - new_exchange: ChatExchange, + question: &str, + token_length: Option, max_token_length: usize, - n_keep: Option, - ) -> Vec { - self.history - .new_prompt(new_exchange, max_token_length, n_keep) + ) -> Vec { + let mut db_lock = self.db.lock().unwrap(); + + // token budget for the system prompt + let system_prompt_token_length = self.get_n_keep().unwrap_or(0); + + // add the partial exchange (question) to the conversation + let last_exchange = + db_lock.get_last_exchange(self.current_conversation_id); + + let exchange = if let Some(last) = last_exchange { + // add exchange based on the last one + Exchange { + id: db_lock.new_exchange_id(), + conversation_id: self.current_conversation_id, + model_id: last.model_id, + system_prompt: last.system_prompt.clone(), // copy from previous exchange + completion_options: last.completion_options.clone(), // copy from previous exchange + prompt_options: last.prompt_options.clone(), // copy from previous exchange + completion_tokens: 0, + prompt_tokens: 0, + created_at: 0, + previous_exchange_id: Some(last.id), + } + } else { + // create first exchange + let exchange = self.first_exchange(&mut db_lock); + let system_message = Message { + id: db_lock.new_message_id(), + conversation_id: self.current_conversation_id, + exchange_id: exchange.id, + role: PromptRole::System, + message_type: "text".to_string(), + content: self.system_prompt.get_instruction().to_string(), + has_attachments: false, + token_length: self.system_prompt.get_token_length().unwrap_or(0) as i32, + created_at: 0, + }; + db_lock.add_message(system_message); + exchange + }; + + let user_message = Message { + id: db_lock.new_message_id(), + conversation_id: self.current_conversation_id, + exchange_id: exchange.id, + role: PromptRole::User, + message_type: "text".to_string(), + content: question.to_string(), + has_attachments: false, + token_length: token_length.unwrap_or(0) as i32, + created_at: 0, + }; + + db_lock.add_exchange(exchange); + // new_prompt only has user question, answer is added later + db_lock.add_message(user_message); + + let current_exchanges = + db_lock.get_conversation_exchanges(self.current_conversation_id); + + // Collect messages while respecting token limits + let mut messages: Vec = Vec::new(); + let mut total_tokens = system_prompt_token_length; + + // Add messages from most recent to oldest, respecting token limit + for exchange in current_exchanges.into_iter().rev() { + for msg in + db_lock.get_exchange_messages(exchange.id).into_iter().rev() + { + if msg.role == PromptRole::System { + continue; // system prompt is included separately + } + if total_tokens + msg.token_length as usize <= max_token_length + { + total_tokens += msg.token_length as usize; + messages.push(ChatMessage { + role: msg.role, + content: msg.content.clone(), + }); + } else { + // reached token limit + break; + } + } + + if total_tokens >= max_token_length { + break; + } + } + // ensure the system prompt is always included + // after reverse, the system prompt will be at the beginning + messages.push(ChatMessage { + role: PromptRole::System, + content: self.system_prompt.get_instruction().to_string(), + }); + // Reverse the messages to maintain chronological order + messages.reverse(); + messages } pub fn get_completion_options(&self) -> &ChatCompletionOptions { + // no need to change this yet &self.completion_options } pub fn get_completion_options_mut(&mut self) -> &mut ChatCompletionOptions { + // no need to change this yet &mut self.completion_options } pub fn get_prompt_options(&self) -> &PromptOptions { + // no need to change this yet &self.prompt_options } pub fn get_prompt_options_mut(&mut self) -> &mut PromptOptions { + // no need to change this yet &mut self.prompt_options } pub fn get_n_keep(&self) -> Option { + // no need to change this yet self.completion_options.get_n_keep() } @@ -137,54 +300,186 @@ impl PromptInstruction { assistant: String, user_instruction: Option, ) -> Result<(), ApplicationError> { - // Find the selected persona by name let assistant_prompts: Vec = serde_yaml::from_str(PERSONAS) .map_err(|e| { ApplicationError::Unexpected(format!( "Failed to parse persona data: {}", - e.to_string() + e )) })?; - if let Some(prompt) = assistant_prompts + + let prompt = assistant_prompts .into_iter() .find(|p| p.name() == assistant) - { - // system prompt is the assistant instruction + user instruction - // default to empty string if either is not available - let system_prompt = - if let Some(assistant_instruction) = prompt.system_prompt() { - let system_prompt = - if let Some(user_instruction) = user_instruction { - // strip trailing whitespace from assistant instruction - format!( - "{} {}", - assistant_instruction.trim_end(), - user_instruction - ) - } else { - assistant_instruction.to_string() - }; - system_prompt - } else { - user_instruction.unwrap_or_default() + .ok_or_else(|| { + ApplicationError::Unexpected(format!( + "Assistant '{}' not found in the dataset", + assistant + )) + })?; + + let system_prompt = build_system_prompt(&prompt, &user_instruction); + self.set_system_prompt(system_prompt.clone()); + + if let Some(exchanges) = prompt.exchanges() { + let mut db_lock = self.db.lock().unwrap(); + + // Create a new exchange with the system prompt + let exchange = self.first_exchange(&mut db_lock); + let system_message = Message { + id: db_lock.new_message_id(), + conversation_id: self.current_conversation_id, + exchange_id: exchange.id, + role: PromptRole::System, + message_type: "text".to_string(), + content: system_prompt, + has_attachments: false, + token_length: 0, // TODO: compute token length + created_at: 0, + }; + db_lock.add_message(system_message); + + for loaded_exchange in exchanges.iter() { + let user_message = Message { + id: db_lock.new_message_id(), + conversation_id: self.current_conversation_id, + exchange_id: exchange.id, + role: PromptRole::User, + message_type: "text".to_string(), + content: loaded_exchange.question.clone(), + has_attachments: false, + token_length: 0, // Implement proper token counting + created_at: 0, // Use proper timestamp + }; + let assistant_message = Message { + id: db_lock.new_message_id(), + conversation_id: self.current_conversation_id, + exchange_id: exchange.id, + role: PromptRole::Assistant, + message_type: "text".to_string(), + content: loaded_exchange.answer.clone(), + has_attachments: false, + token_length: 0, // Implement proper token counting + created_at: 0, // Use proper timestamp }; - self.set_system_prompt(system_prompt); + db_lock.add_message(user_message); + db_lock.add_message(assistant_message); + } + } + + if let Some(prompt_template) = prompt.prompt_template() { + self.prompt_template = Some(prompt_template.to_string()); + } + + Ok(()) + } +} + +fn build_system_prompt( + prompt: &Prompt, + user_instruction: &Option, +) -> String { + match (prompt.system_prompt(), user_instruction) { + (Some(assistant_instruction), Some(user_instr)) => { + format!("{} {}", assistant_instruction.trim_end(), user_instr) + } + (Some(assistant_instruction), None) => { + assistant_instruction.to_string() + } + (None, Some(user_instr)) => user_instr.to_string(), + (None, None) => String::new(), + } +} - // Load predefined exchanges from persona if available - if let Some(exchanges) = prompt.exchanges() { - self.history = - ChatHistory::new_with_exchanges(exchanges.clone()); +pub struct ExchangeHandler; + +impl ExchangeHandler { + pub fn append_response( + db_lock: &mut MutexGuard, + current_conversation_id: ConversationId, + answer: &str, + ) { + let last_exchange = db_lock.get_last_exchange(current_conversation_id); + + if let Some(exchange) = last_exchange { + let last_message = + db_lock.get_last_message_of_exchange(exchange.id).cloned(); + + match last_message { + Some(msg) if msg.role == PromptRole::Assistant => { + // If the last message is from Assistant, append to it + let new_content = + format!("{}{}", msg.content, answer).to_string(); + db_lock.update_message_by_id(msg.id, &new_content, None); + } + _ => { + // If the last message is from User or there's no message, create a new Assistant message + let new_message = Message { + id: db_lock.new_message_id(), + conversation_id: current_conversation_id, + exchange_id: exchange.id, + role: PromptRole::Assistant, + message_type: "text".to_string(), + content: answer.to_string(), + has_attachments: false, + token_length: answer.len() as i32, // Simplified token count + created_at: 0, // You might want to use a proper timestamp here + }; + db_lock.add_message(new_message); + } } + } else { + // If there's no exchange, something went wrong + eprintln!("Error: No exchange found when trying to append answer"); + } + } + + pub fn get_last_response( + db_lock: &MutexGuard, + current_conversation_id: ConversationId, + ) -> Option { + db_lock + .get_last_exchange(current_conversation_id) + .and_then(|last_exchange| { + db_lock.get_last_message_of_exchange(last_exchange.id) + }) + .and_then(|last_message| { + if last_message.role == PromptRole::Assistant { + Some(last_message.content.clone()) + } else { + None + } + }) + } - if let Some(prompt_template) = prompt.prompt_template() { - self.prompt_template = Some(prompt_template.to_string()); + pub fn put_last_response( + db_lock: &mut MutexGuard, + current_conversation_id: ConversationId, + answer: &str, + tokens_predicted: Option, + ) { + let (message_id, is_assistant) = if let Some(last_exchange) = + db_lock.get_last_exchange(current_conversation_id) + { + if let Some(last_message) = + db_lock.get_last_message_of_exchange(last_exchange.id) + { + // Check the role directly here and only pass on the ID if it's an assistant's message + ( + Some(last_message.id), + last_message.role == PromptRole::Assistant, + ) + } else { + (None, false) } - Ok(()) } else { - Err(ApplicationError::Unexpected(format!( - "Assistant '{}' not found in the dataset", - assistant - ))) + (None, false) + }; + + // Perform the update if the message is from an assistant + if let (Some(id), true) = (message_id, is_assistant) { + let token_length = tokens_predicted.map(|t| t as i32); + db_lock.update_message_by_id(id, answer, token_length); } } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs index a85b0f00..48b40013 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs @@ -1,15 +1,11 @@ use std::error::Error; - -mod exchange; -mod history; mod instruction; mod options; mod prompt; +mod schema; mod send; mod session; -pub use exchange::ChatExchange; -pub use history::{ChatHistory, ChatMessage}; pub use instruction::PromptInstruction; pub use options::{ChatCompletionOptions, PromptOptions}; use prompt::Prompt; @@ -43,3 +39,9 @@ impl TokenResponse { &self.tokens } } + +#[derive(Debug, Clone)] +pub struct ChatMessage { + pub role: PromptRole, + pub content: String, +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs index ad3d7766..d252b6f9 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs @@ -94,7 +94,7 @@ impl ChatCompletionOptions { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct RolePrefix { user: String, assistant: String, @@ -121,7 +121,7 @@ impl RolePrefix { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct PromptOptions { n_ctx: Option, #[serde(default)] diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/prompt.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/prompt.rs index 4304482b..9648c504 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/prompt.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/prompt.rs @@ -1,7 +1,5 @@ use serde::{Deserialize, Serialize}; -use super::exchange::ChatExchange; - #[derive(Debug, Serialize, Deserialize)] pub struct Prompt { name: String, @@ -27,3 +25,10 @@ impl Prompt { self.exchanges.as_ref() } } + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ChatExchange { + pub question: String, + pub answer: String, + pub token_length: Option, +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/schema.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/schema.rs index 88c4b67e..0ad11eb7 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/schema.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/schema.rs @@ -1,8 +1,11 @@ -use serde::{Serialize, Deserialize}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::thread; + use rusqlite; +use serde::{Deserialize, Serialize}; + +use super::PromptRole; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct ModelId(pub i64); @@ -49,6 +52,7 @@ pub struct Exchange { pub completion_tokens: i32, pub prompt_tokens: i32, pub created_at: i64, + pub previous_exchange_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -56,7 +60,7 @@ pub struct Message { pub id: MessageId, pub conversation_id: ConversationId, pub exchange_id: ExchangeId, - pub role: Role, + pub role: PromptRole, pub message_type: String, pub content: String, pub has_attachments: bool, @@ -64,13 +68,6 @@ pub struct Message { pub created_at: i64, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Role { - User, - Assistant, - System, -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub enum AttachmentData { Uri(String), @@ -89,14 +86,30 @@ pub struct Attachment { pub created_at: i64, } +impl Conversation { + pub fn new(name: &str) -> Self { + Conversation { + id: ConversationId(0), // You might want to generate a unique ID here + name: name.to_string(), + metadata: serde_json::Value::Null, + parent_conversation_id: ConversationId(0), + fork_exchange_id: ExchangeId(0), + schema_version: 1, + created_at: 0, // not using timestamps for now, stick with 0 for now + updated_at: 0, // not using timestamps for now, stick with 0 for now + } + } +} + +#[derive(Debug)] pub struct InMemoryDatabase { models: HashMap, conversations: HashMap, exchanges: HashMap, messages: HashMap, attachments: HashMap, - - conversation_exchanges: HashMap>, + + conversation_exchanges: HashMap>, exchange_messages: HashMap>, message_attachments: HashMap>, } @@ -115,6 +128,37 @@ impl InMemoryDatabase { } } + pub fn new_conversation_id(&self) -> ConversationId { + ConversationId(self.conversations.len() as i64) + } + + pub fn new_exchange_id(&self) -> ExchangeId { + ExchangeId(self.exchanges.len() as i64) + } + + pub fn new_message_id(&self) -> MessageId { + MessageId(self.messages.len() as i64) + } + + pub fn new_attachment_id(&self) -> AttachmentId { + AttachmentId(self.attachments.len() as i64) + } + + pub fn new_conversation( + &mut self, + name: &str, + parent_id: Option, + ) -> ConversationId { + let new_id = self.new_conversation_id(); + let mut conversation = Conversation::new(name); + conversation.id = new_id; + if let Some(parent) = parent_id { + conversation.parent_conversation_id = parent; + } + self.add_conversation(conversation); + new_id + } + pub fn add_model(&mut self, model: Model) { self.models.insert(model.model_id, model); } @@ -127,7 +171,7 @@ impl InMemoryDatabase { self.conversation_exchanges .entry(exchange.conversation_id) .or_default() - .insert(exchange.id); + .push(exchange.id); self.exchanges.insert(exchange.id, exchange); } @@ -139,15 +183,39 @@ impl InMemoryDatabase { self.messages.insert(message.id, message); } + pub fn update_message(&mut self, updated_message: Message) { + if let Some(existing_message) = self.messages.get_mut(&updated_message.id) { + *existing_message = updated_message; + } + } + + pub fn update_message_by_id( + &mut self, + message_id: MessageId, + new_content: &str, + new_token_length: Option, + ) { + if let Some(message) = self.messages.get_mut(&message_id) { + message.content = new_content.to_string(); + if let Some(token_length) = new_token_length { + message.token_length = token_length; + } + } + } + pub fn add_attachment(&mut self, attachment: Attachment) { self.message_attachments .entry(attachment.message_id) .or_default() .push(attachment.attachment_id); - self.attachments.insert(attachment.attachment_id, attachment); + self.attachments + .insert(attachment.attachment_id, attachment); } - pub fn get_conversation_exchanges(&self, conversation_id: ConversationId) -> Vec<&Exchange> { + pub fn get_conversation_exchanges( + &self, + conversation_id: ConversationId, + ) -> Vec<&Exchange> { self.conversation_exchanges .get(&conversation_id) .map(|exchange_ids| { @@ -159,7 +227,22 @@ impl InMemoryDatabase { .unwrap_or_default() } - pub fn get_exchange_messages(&self, exchange_id: ExchangeId) -> Vec<&Message> { + pub fn get_last_exchange( + &self, + conversation_id: ConversationId, + ) -> Option { + self.conversation_exchanges + .get(&conversation_id) + .and_then(|exchanges| exchanges.last()) + .and_then(|last_exchange_id| { + self.exchanges.get(last_exchange_id).cloned() + }) + } + + pub fn get_exchange_messages( + &self, + exchange_id: ExchangeId, + ) -> Vec<&Message> { self.exchange_messages .get(&exchange_id) .map(|message_ids| { @@ -171,7 +254,20 @@ impl InMemoryDatabase { .unwrap_or_default() } - pub fn get_message_attachments(&self, message_id: MessageId) -> Vec<&Attachment> { + pub fn get_last_message_of_exchange( + &self, + exchange_id: ExchangeId, + ) -> Option<&Message> { + self.exchange_messages + .get(&exchange_id) + .and_then(|messages| messages.last()) + .and_then(|last_message_id| self.messages.get(last_message_id)) + } + + pub fn get_message_attachments( + &self, + message_id: MessageId, + ) -> Vec<&Attachment> { self.message_attachments .get(&message_id) .map(|attachment_ids| { @@ -201,7 +297,7 @@ impl Database { pub fn save_in_background(&self) { let in_memory = Arc::clone(&self.in_memory); let sqlite_conn = Arc::clone(&self.sqlite_conn); - + thread::spawn(move || { let data = in_memory.lock().unwrap(); let conn = sqlite_conn.lock().unwrap(); @@ -233,4 +329,4 @@ impl Database { let mut data = self.in_memory.lock().unwrap(); data.add_attachment(attachment); } -} \ No newline at end of file +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/schema.sql b/lumni/src/apps/builtin/llm/prompt/src/chat/schema.sql index 4e80a58f..6de4bfa9 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/schema.sql +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/schema.sql @@ -27,8 +27,10 @@ CREATE TABLE exchanges ( completion_tokens INTEGER, prompt_tokens INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + previous_exchange_id INTEGER, FOREIGN KEY (conversation_id) REFERENCES conversations(id), FOREIGN KEY (model_id) REFERENCES models(model_id) + FOREIGN KEY (previous_exchange_id) REFERENCES exchanges(id) ); CREATE TABLE messages ( diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs index 96fbf713..0f5ba308 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs @@ -4,9 +4,7 @@ use std::sync::Arc; use bytes::Bytes; use tokio::sync::{mpsc, oneshot, Mutex}; -use super::exchange::ChatExchange; -use super::history::ChatHistory; -use super::{LLMDefinition, PromptInstruction, ServerManager}; +use super::{LLMDefinition, PromptInstruction, PromptRole, ServerManager}; use crate::api::error::ApplicationError; pub struct ChatSession { @@ -70,46 +68,27 @@ impl ChatSession { } pub fn update_last_exchange(&mut self, answer: &str) { - self.prompt_instruction.update_last_exchange(answer); + self.prompt_instruction.append_last_response(answer); } pub async fn finalize_last_exchange( &mut self, _tokens_predicted: Option, ) -> Result<(), ApplicationError> { - // extract the last exchange, trim and tokenize it - let token_length = if let Some(last_exchange) = - self.prompt_instruction.get_last_exchange_mut() - { - // Strip off trailing whitespaces or newlines from the last exchange - let trimmed_answer = last_exchange.get_answer().trim().to_string(); - last_exchange.set_answer(trimmed_answer); - - let temp_vec = vec![&*last_exchange]; - let model = self.server.get_selected_model()?; - - let last_prompt_text = - ChatHistory::exchanges_to_string(model, temp_vec); + let last_answer = self.prompt_instruction.get_last_response(); - if let Some(response) = - self.server.tokenizer(&last_prompt_text).await? + if let Some(last_answer) = last_answer { + let trimmed_answer = last_answer.trim(); + let tokens_predicted = if let Some(response) = + self.server.tokenizer(trimmed_answer).await? { Some(response.get_tokens().len()) } else { None - } - } else { - None - }; - - if let Some(token_length) = token_length { - if let Some(last_exchange) = - self.prompt_instruction.get_last_exchange_mut() - { - last_exchange.set_token_length(token_length); - } + }; + self.prompt_instruction + .put_last_response(trimmed_answer, tokens_predicted); } - Ok(()) } @@ -122,12 +101,12 @@ impl ChatSession { .server .get_context_size(&mut self.prompt_instruction) .await?; - let new_exchange = self.initiate_new_exchange(question).await?; - let n_keep = self.prompt_instruction.get_n_keep(); - let exchanges = self.prompt_instruction.new_prompt( - new_exchange, + let (user_question, token_length) = + self.initiate_new_exchange(question).await?; + let messages = self.prompt_instruction.subsequent_exchange( + &user_question, + token_length, max_token_length, - n_keep, ); let (cancel_tx, cancel_rx) = oneshot::channel(); @@ -135,7 +114,7 @@ impl ChatSession { self.server .completion( - &exchanges, + &messages, &self.prompt_instruction, Some(tx), Some(cancel_rx), @@ -147,7 +126,7 @@ impl ChatSession { pub async fn initiate_new_exchange( &self, user_question: &str, - ) -> Result { + ) -> Result<(String, Option), ApplicationError> { let user_question = user_question.trim(); let user_question = if user_question.is_empty() { "continue".to_string() @@ -161,20 +140,18 @@ impl ChatSession { } }; - let mut new_exchange = ChatExchange::new(user_question, "".to_string()); - let temp_vec = vec![&new_exchange]; - let model = self.server.get_selected_model()?; - + let formatter = model.get_formatter(); let last_prompt_text = - ChatHistory::exchanges_to_string(model, temp_vec); - - if let Some(token_response) = + formatter.fmt_prompt_message(&PromptRole::User, &user_question); + let token_length = if let Some(token_response) = self.server.tokenizer(&last_prompt_text).await? { - new_exchange.set_token_length(token_response.get_tokens().len()); - } - Ok(new_exchange) + Some(token_response.get_tokens().len()) + } else { + None + }; + Ok((user_question, token_length)) } pub fn process_response( diff --git a/lumni/src/apps/builtin/llm/prompt/src/model/formatter.rs b/lumni/src/apps/builtin/llm/prompt/src/model/formatter.rs index 13f76731..79c9a31d 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/model/formatter.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/model/formatter.rs @@ -1,9 +1,11 @@ use async_trait::async_trait; use regex::Regex; +use serde::{Deserialize, Serialize}; use super::generic::Generic; use super::llama3::Llama3; +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub enum PromptRole { User, Assistant, @@ -59,7 +61,7 @@ impl ModelFormatterTrait for ModelFormatter { fn fmt_prompt_message( &self, - prompt_role: PromptRole, + prompt_role: &PromptRole, message: &str, ) -> String { match self.model { @@ -85,7 +87,7 @@ pub trait ModelFormatterTrait: Send + Sync { } } - fn get_role_prefix(&self, prompt_role: PromptRole) -> &str { + fn get_role_prefix(&self, prompt_role: &PromptRole) -> &str { match prompt_role { PromptRole::User => "### User: ", PromptRole::Assistant => "### Assistant: ", @@ -95,7 +97,7 @@ pub trait ModelFormatterTrait: Send + Sync { fn fmt_prompt_message( &self, - prompt_role: PromptRole, + prompt_role: &PromptRole, message: &str, ) -> String { let prompt_message = match prompt_role { diff --git a/lumni/src/apps/builtin/llm/prompt/src/model/llama3.rs b/lumni/src/apps/builtin/llm/prompt/src/model/llama3.rs index 42ac62a4..ba668601 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/model/llama3.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/model/llama3.rs @@ -32,7 +32,7 @@ impl ModelFormatterTrait for Llama3 { if let Some(instruction) = instruction { return format!( "<|begin_of_text|>{}", - self.fmt_prompt_message(PromptRole::System, instruction) + self.fmt_prompt_message(&PromptRole::System, instruction) ) .to_string(); } else { @@ -42,7 +42,7 @@ impl ModelFormatterTrait for Llama3 { fn fmt_prompt_message( &self, - prompt_role: PromptRole, + prompt_role: &PromptRole, message: &str, ) -> String { let role_handle = match prompt_role { diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs index c623e5c8..16d9bedb 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs @@ -18,8 +18,8 @@ use tokio::sync::{mpsc, oneshot}; use url::Url; use super::{ - http_post, ChatExchange, ChatHistory, ChatMessage, Endpoints, - LLMDefinition, PromptInstruction, ServerSpecTrait, ServerTrait, + http_post, ChatMessage, Endpoints, LLMDefinition, PromptInstruction, + ServerSpecTrait, ServerTrait, PromptRole, }; pub use crate::external as lumni; @@ -55,22 +55,32 @@ impl Bedrock { fn completion_api_payload( &self, _model: &LLMDefinition, - exchanges: &Vec, - system_prompt: Option<&str>, + chat_messages: &Vec, ) -> Result { - // Convert ChatExchange to list of ChatMessages - let chat_messages: Vec = - ChatHistory::exchanges_to_messages( - exchanges, - None, // dont add system prompt for Bedrock, this is added in the system field - &|role| self.get_role_name(role), - ); + // Check if the first message is a system prompt + let system_prompt = match chat_messages.first() { + Some(chat_message) => { + if chat_message.role == PromptRole::System { + Some(chat_message.content.clone()) + } else { + None + } + } + None => None, + }; + // skip system prompt if it exists + let skip = if system_prompt.is_some() { + 1 + } else { + 0 + }; // Convert ChatMessages to Messages for BedrockRequestPayload let messages: Vec = chat_messages .iter() + .skip(skip) .map(|chat_message| Message { - role: chat_message.role.clone(), + role: self.get_role_name(&chat_message.role).to_string(), content: vec![Content { text: Some(chat_message.content.clone()), image: None, @@ -82,7 +92,7 @@ impl Bedrock { }) .collect(); - // Cconvert system_prompt to a system message for BedrockRequestPayload + // Convert system_prompt to a system message for BedrockRequestPayload let system = if let Some(prompt) = system_prompt { Some(vec![SystemMessage { text: prompt.to_string(), @@ -152,13 +162,12 @@ impl ServerTrait for Bedrock { async fn completion( &self, - exchanges: &Vec, - prompt_instruction: &PromptInstruction, + messages: &Vec, + _prompt_instruction: &PromptInstruction, tx: Option>, cancel_rx: Option>, ) -> Result<(), ApplicationError> { let model = self.get_selected_model()?; - let system_prompt = prompt_instruction.get_instruction(); let resource = HttpClient::percent_encode_with_exclusion( &format!("/model/{}/converse-stream", model.get_name()), @@ -168,7 +177,7 @@ impl ServerTrait for Bedrock { let full_url = format!("{}{}", completion_endpoint, resource); let data_payload = self - .completion_api_payload(model, exchanges, Some(system_prompt)) + .completion_api_payload(model, messages) .map_err(|e| { ApplicationError::InvalidUserConfiguration(e.to_string()) })?; diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs index de151b21..c0202ab9 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs @@ -10,10 +10,9 @@ use tokio::sync::{mpsc, oneshot}; use url::Url; use super::{ - http_get_with_response, http_post, ChatCompletionOptions, ChatExchange, - ChatHistory, Endpoints, HttpClient, LLMDefinition, PromptInstruction, - PromptRole, ServerSpecTrait, ServerTrait, TokenResponse, - DEFAULT_CONTEXT_SIZE, + http_get_with_response, http_post, ChatCompletionOptions, ChatMessage, + Endpoints, HttpClient, LLMDefinition, PromptInstruction, PromptRole, + ServerSpecTrait, ServerTrait, TokenResponse, DEFAULT_CONTEXT_SIZE, }; use crate::external as lumni; @@ -76,7 +75,6 @@ impl Llama { fn completion_api_payload( &self, prompt: String, - _exchanges: &Vec, prompt_instruction: &PromptInstruction, ) -> Result { let payload = LlamaServerPayload { @@ -137,15 +135,21 @@ impl ServerTrait for Llama { async fn completion( &self, - exchanges: &Vec, + messages: &Vec, prompt_instruction: &PromptInstruction, tx: Option>, cancel_rx: Option>, ) -> Result<(), ApplicationError> { let model = self.get_selected_model()?; - let prompt = ChatHistory::exchanges_to_string(model, exchanges); + let formatter = model.get_formatter(); + let prompt = messages + .into_iter() + .map(|m| formatter.fmt_prompt_message(&m.role, &m.content)) + .collect::>() + .join("\n"); + let data_payload = - self.completion_api_payload(prompt, exchanges, prompt_instruction); + self.completion_api_payload(prompt, prompt_instruction); let completion_endpoint = self.endpoints.get_completion_endpoint()?; diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs index 41af7162..82578c4b 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs @@ -24,8 +24,7 @@ use tokio::sync::{mpsc, oneshot}; pub use super::chat::{ http_get_with_response, http_post, http_post_with_response, - ChatCompletionOptions, ChatExchange, ChatHistory, ChatMessage, - PromptInstruction, TokenResponse, + ChatCompletionOptions, ChatMessage, PromptInstruction, TokenResponse, }; pub use super::defaults::*; pub use super::model::{ModelFormatter, ModelFormatterTrait, PromptRole}; @@ -165,7 +164,7 @@ impl ServerTrait for ModelServer { async fn completion( &self, - exchanges: &Vec, + messages: &Vec, prompt_instruction: &PromptInstruction, tx: Option>, cancel_rx: Option>, @@ -173,22 +172,22 @@ impl ServerTrait for ModelServer { match self { ModelServer::Llama(llama) => { llama - .completion(exchanges, prompt_instruction, tx, cancel_rx) + .completion(messages, prompt_instruction, tx, cancel_rx) .await } ModelServer::Ollama(ollama) => { ollama - .completion(exchanges, prompt_instruction, tx, cancel_rx) + .completion(messages, prompt_instruction, tx, cancel_rx) .await } ModelServer::Bedrock(bedrock) => { bedrock - .completion(exchanges, prompt_instruction, tx, cancel_rx) + .completion(messages, prompt_instruction, tx, cancel_rx) .await } ModelServer::OpenAI(openai) => { openai - .completion(exchanges, prompt_instruction, tx, cancel_rx) + .completion(messages, prompt_instruction, tx, cancel_rx) .await } } @@ -227,7 +226,7 @@ pub trait ServerTrait: Send + Sync { async fn completion( &self, - exchanges: &Vec, + messages: &Vec, prompt_instruction: &PromptInstruction, tx: Option>, cancel_rx: Option>, @@ -296,7 +295,7 @@ pub trait ServerTrait: Send + Sync { Ok(DEFAULT_CONTEXT_SIZE) } - fn get_role_name(&self, prompt_role: PromptRole) -> &'static str { + fn get_role_name(&self, prompt_role: &PromptRole) -> &'static str { match prompt_role { PromptRole::User => "user", PromptRole::Assistant => "assistant", diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs index 9adbe21c..1b70835f 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs @@ -8,9 +8,9 @@ use tokio::sync::{mpsc, oneshot}; use url::Url; use super::{ - http_get_with_response, http_post, http_post_with_response, ChatExchange, - ChatHistory, ChatMessage, Endpoints, HttpClient, LLMDefinition, - PromptInstruction, ServerSpecTrait, ServerTrait, + http_get_with_response, http_post, http_post_with_response, ChatMessage, + Endpoints, HttpClient, LLMDefinition, PromptInstruction, ServerSpecTrait, + ServerTrait, }; use crate::external as lumni; @@ -47,14 +47,15 @@ impl Ollama { fn completion_api_payload( &self, model: &LLMDefinition, - exchanges: &Vec, - system_prompt: Option<&str>, + chat_messages: &Vec, ) -> Result { - let messages = ChatHistory::exchanges_to_messages( - exchanges, - system_prompt, - &|role| self.get_role_name(role), - ); + let messages: Vec = chat_messages + .iter() + .map(|m| OllamaChatMessage { + role: self.get_role_name(&m.role).to_string(), + content: m.content.to_string(), + }) + .collect(); let payload = ServerPayload { model: model.get_name(), @@ -130,16 +131,14 @@ impl ServerTrait for Ollama { async fn completion( &self, - exchanges: &Vec, - prompt_instruction: &PromptInstruction, + messages: &Vec, + _prompt_instruction: &PromptInstruction, tx: Option>, cancel_rx: Option>, ) -> Result<(), ApplicationError> { let model = self.get_selected_model()?; - let system_prompt = prompt_instruction.get_instruction(); - let data_payload = - self.completion_api_payload(model, exchanges, Some(system_prompt)); + self.completion_api_payload(model, messages); let completion_endpoint = self.endpoints.get_completion_endpoint()?; if let Ok(payload) = data_payload { @@ -199,10 +198,16 @@ impl ServerTrait for Ollama { } } +#[derive(Serialize)] +struct OllamaChatMessage { + role: String, + content: String, +} + #[derive(Serialize)] struct ServerPayload<'a> { model: &'a str, - messages: &'a Vec, + messages: &'a Vec, // TODO: reformat and pass options to ollama //#[serde(flatten)] // options: &'a ChatCompletionOptions, diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs index 8775d7d0..4cfec936 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs @@ -13,14 +13,14 @@ use credentials::OpenAICredentials; use error::OpenAIErrorHandler; use lumni::api::error::ApplicationError; use lumni::HttpClient; -use request::OpenAIRequestPayload; +use request::{OpenAIChatMessage, OpenAIRequestPayload}; use response::StreamParser; use tokio::sync::{mpsc, oneshot}; use url::Url; use super::{ - http_post, ChatExchange, ChatHistory, ChatMessage, Endpoints, - LLMDefinition, PromptInstruction, ServerSpecTrait, ServerTrait, + http_post, ChatMessage, Endpoints, LLMDefinition, PromptInstruction, + ServerSpecTrait, ServerTrait, }; pub use crate::external as lumni; @@ -57,14 +57,15 @@ impl OpenAI { fn completion_api_payload( &self, model: &LLMDefinition, - exchanges: &Vec, - system_prompt: Option<&str>, + chat_messages: Vec, ) -> Result { - let messages: Vec = ChatHistory::exchanges_to_messages( - exchanges, - system_prompt, - &|role| self.get_role_name(role), - ); + let messages: Vec = chat_messages + .iter() + .map(|m| OpenAIChatMessage { + role: self.get_role_name(&m.role).to_string(), + content: m.content.to_string(), + }) + .collect(); let openai_request_payload = OpenAIRequestPayload { model: model.get_name().to_string(), @@ -113,17 +114,19 @@ impl ServerTrait for OpenAI { async fn completion( &self, - exchanges: &Vec, - prompt_instruction: &PromptInstruction, + messages: &Vec, + _prompt_instruction: &PromptInstruction, tx: Option>, cancel_rx: Option>, ) -> Result<(), ApplicationError> { let model = self.get_selected_model()?; - let system_prompt = prompt_instruction.get_instruction(); let completion_endpoint = self.endpoints.get_completion_endpoint()?; let data_payload = self - .completion_api_payload(model, exchanges, Some(system_prompt)) + .completion_api_payload( + model, + messages.clone(), + ) .map_err(|e| { ApplicationError::InvalidUserConfiguration(e.to_string()) })?; diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/openai/request.rs b/lumni/src/apps/builtin/llm/prompt/src/server/openai/request.rs index feec0a99..4663c4ad 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/openai/request.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/openai/request.rs @@ -1,11 +1,9 @@ use serde::Serialize; -use super::ChatMessage; - #[derive(Debug, Serialize)] pub struct OpenAIRequestPayload { pub model: String, - pub messages: Vec, + pub messages: Vec, pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub frequency_penalty: Option, @@ -30,3 +28,9 @@ impl OpenAIRequestPayload { serde_json::to_string(&self) } } + +#[derive(Debug, Serialize)] +pub struct OpenAIChatMessage { + pub role: String, + pub content: String, +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_buffer.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_buffer.rs index 78c37632..1a97eae2 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_buffer.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_buffer.rs @@ -183,18 +183,16 @@ pub struct TextBuffer<'a> { display: TextDisplay<'a>, // text (e.g. wrapped, highlighted) for display cursor: Cursor, code_blocks: Vec, // code blocks - is_editable: bool, } impl TextBuffer<'_> { - pub fn new(is_editable: bool) -> Self { + pub fn new() -> Self { Self { text: PieceTable::new(), placeholder: String::new(), display: TextDisplay::new(0), cursor: Cursor::new(), code_blocks: Vec::new(), - is_editable, } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_window.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_window.rs index e7dd32a3..7ab30194 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_window.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/components/text_window.rs @@ -20,12 +20,11 @@ pub struct TextWindow<'a> { impl<'a> TextWindow<'a> { pub fn new(window_type: WindowConfig) -> Self { - let is_editable = window_type.is_editable(); Self { area: RectArea::default(), window_type, scroller: Scroller::new(), - text_buffer: TextBuffer::new(is_editable), + text_buffer: TextBuffer::new(), } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs index aca64481..9e5ad07f 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs @@ -3,10 +3,9 @@ use std::io; use ratatui::backend::Backend; use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect}; use ratatui::style::{Color, Style}; -use ratatui::text::Text; -use ratatui::widgets::block::{Padding, Position, Title}; +use ratatui::widgets::block::{Position, Title}; use ratatui::widgets::{ - Block, Borders, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState, + Block, Borders, Scrollbar, ScrollbarOrientation, }; use ratatui::Terminal;