From b3124303a1775e729c09608a0183c0b1f197f84e Mon Sep 17 00:00:00 2001
From: Anthony Potappel
Date: Tue, 23 Jul 2024 23:19:43 +0200
Subject: [PATCH] load last conversation text into response window at
 start-up, fix bug where system message was not properly stored in db

---
 lumni/src/apps/builtin/llm/prompt/src/app.rs  | 21 ++---
 .../builtin/llm/prompt/src/chat/assistant.rs  | 47 +++++-----
 .../llm/prompt/src/chat/conversation/cache.rs | 87 ++++++++++++-------
 .../src/chat/conversation/instruction.rs      | 29 +++++--
 .../llm/prompt/src/chat/conversation/mod.rs   |  4 +-
 .../prompt/src/chat/conversation/prepare.rs   |  2 +-
 .../apps/builtin/llm/prompt/src/chat/mod.rs   |  1 +
 .../builtin/llm/prompt/src/chat/options.rs    |  3 +-
 .../builtin/llm/prompt/src/chat/session.rs    | 10 ++-
 .../apps/builtin/llm/prompt/src/session.rs    | 10 ++-
 .../llm/prompt/src/tui/components/mod.rs      |  1 +
 .../prompt/src/tui/components/piece_table.rs  | 79 ++++++++++++++---
 .../prompt/src/tui/components/text_buffer.rs  | 17 ++--
 .../prompt/src/tui/components/text_window.rs  |  8 +-
 .../prompt/src/tui/components/text_wrapper.rs | 19 ++--
 .../src/tui/events/handle_response_window.rs  | 10 +--
 .../apps/builtin/llm/prompt/src/tui/mod.rs    |  2 +-
 .../src/apps/builtin/llm/prompt/src/tui/ui.rs |  7 +-
 .../builtin/llm/prompt/src/tui/windows.rs     | 11 +--
 19 files changed, 251 insertions(+), 117 deletions(-)

diff --git a/lumni/src/apps/builtin/llm/prompt/src/app.rs b/lumni/src/apps/builtin/llm/prompt/src/app.rs
index ac7d8601..bd815bd0 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/app.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/app.rs
@@ -428,21 +428,22 @@ pub async fn run_cli(
 
     // check if the last conversation is the same as the new conversation, if so,
     // continue the conversation, otherwise start a new conversation
-    let prompt_instruction =
-        if let Some(conversation_id) = db_conn.fetch_last_conversation_id()? {
+    let prompt_instruction = db_conn
+        .fetch_last_conversation_id()?
+        .and_then(|conversation_id| {
             let reader = db_conn.get_conversation_reader(conversation_id);
-            let is_equal = new_conversation.is_equal(&reader)?;
-            if is_equal {
+            // Convert Result to Option using .ok()
+            if new_conversation.is_equal(&reader).ok()? {
                 log::debug!("Continuing last conversation");
-                PromptInstruction::from_reader(&reader)?
+                Some(PromptInstruction::from_reader(&reader))
             } else {
-                log::debug!("Starting new conversation");
-                PromptInstruction::new(new_conversation, &db_conn)?
+                None
             }
-        } else {
+        })
+        .unwrap_or_else(|| {
             log::debug!("Starting new conversation");
-            PromptInstruction::new(new_conversation, &db_conn)?
-        };
+            PromptInstruction::new(new_conversation, &db_conn)
+        })?;
 
     let chat_session =
         ChatSession::new(&server_name, prompt_instruction, &db_conn).await?;
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/assistant.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/assistant.rs
index ffaed7fb..05c26f72 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/assistant.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/assistant.rs
@@ -21,16 +21,18 @@ impl AssistantManager {
             completion_options: ChatCompletionOptions::default(),
         };
 
-        // TODO: default should only apply to servers that do no handle this internally
         // Use default assistant when both system_promt and assistant_name are None
         let assistant_name = assistant_name.or_else(|| {
-            if user_instruction.is_some() {
-                // assistant not needed
-                None
-            } else {
-                // no user instruction, use default assistant
-                Some("Default".to_string())
-            }
+            user_instruction
+                .as_ref()
+                .map(|instruction| {
+                    manager.add_system_message(instruction.to_string());
+                    None // No assistant required
+                })
+                .unwrap_or_else(|| {
+                    // TODO: default should only apply to servers that do not handle this internally
+                    Some("Default".to_string()) // Use default assistant
+                })
         });
 
         if let Some(assistant_name) = assistant_name {
@@ -40,6 +42,21 @@ impl AssistantManager {
         Ok(manager)
     }
 
+    fn add_system_message(&mut self, system_prompt: String) {
+        self.initial_messages.push(Message {
+            id: MessageId(0), // system message is always the first message
+            conversation_id: ConversationId(0), // temporary conversation id
+            role: PromptRole::System,
+            message_type: "text".to_string(),
+            content: system_prompt,
+            has_attachments: false,
+            token_length: None,
+            previous_message_id: None,
+            created_at: 0,
+            is_deleted: false,
+        });
+    }
+
     fn load_assistant(
         &mut self,
         assistant_name: String,
@@ -66,18 +83,7 @@ impl AssistantManager {
         let system_prompt = build_system_prompt(&prompt, &user_instruction);
 
         // Add system message
-        self.initial_messages.push(Message {
-            id: MessageId(0), // system message is always the first message
-            conversation_id: ConversationId(0), // temporary conversation id
-            role: PromptRole::System,
-            message_type: "text".to_string(),
-            content: system_prompt,
-            has_attachments: false,
-            token_length: None,
-            previous_message_id: None,
-            created_at: 0,
-            is_deleted: false,
-        });
+        self.add_system_message(system_prompt);
 
         // Add exchanges if any
         if let Some(exchanges) = prompt.exchanges() {
@@ -116,6 +122,7 @@ impl AssistantManager {
 
         let assistant_options = AssistantOptions {
             name: assistant_name,
+            preloaded_messages: self.initial_messages.len() - 1, // exclude the first system message
             prompt_template: prompt.prompt_template().map(|s| s.to_string()),
         };
         self.completion_options
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs
index 03379b70..b1fcb1bf 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs
@@ -1,28 +1,27 @@
 use std::collections::HashMap;
 
-use super::db::{
-    Attachment, AttachmentId, ConversationId, Message, MessageId,
-    ModelIdentifier, ModelSpec,
-};
-use super::PromptRole;
+use ratatui::style::Style;
+
+use super::db::{Attachment, AttachmentId, ConversationId, Message, MessageId};
+use super::{ColorScheme, PromptRole, TextSegment};
 
 #[derive(Debug)]
 pub struct ConversationCache {
     conversation_id: ConversationId,
-    models: HashMap<ModelIdentifier, ModelSpec>,
     messages: Vec<Message>, // messages have to be ordered
     attachments: HashMap<AttachmentId, Attachment>,
     message_attachments: HashMap<MessageId, Vec<AttachmentId>>,
+    preloaded_messages: usize,
 }
 
 impl ConversationCache {
     pub fn new() -> Self {
         ConversationCache {
             conversation_id: ConversationId(-1),
-            models: HashMap::new(),
             messages: Vec::new(),
             attachments: HashMap::new(),
             message_attachments: HashMap::new(),
+            preloaded_messages: 0,
         }
     }
 
@@ -34,6 +33,10 @@ impl ConversationCache {
         self.conversation_id = conversation_id;
     }
 
+    pub fn set_preloaded_messages(&mut self, preloaded_messages: usize) {
+        self.preloaded_messages = preloaded_messages;
+    }
+
     pub fn new_message_id(&self) -> MessageId {
         MessageId(self.messages.len() as i64)
     }
@@ -42,10 +45,6 @@ impl ConversationCache {
         AttachmentId(self.attachments.len() as i64)
     }
 
-    pub fn add_model(&mut self, model: ModelSpec) {
-        self.models.insert(model.identifier.clone(), model);
-    }
-
     pub fn add_message(&mut self, message: Message) {
         self.messages.push(message);
     }
@@ -100,27 +99,57 @@ impl ConversationCache {
         self.attachments
             .insert(attachment.attachment_id, attachment);
     }
+}
 
-    pub fn get_message_attachments(
+impl ConversationCache {
+    pub fn export_conversation(
         &self,
-        message_id: MessageId,
-    ) -> Vec<&Attachment> {
-        self.message_attachments
-            .get(&message_id)
-            .map(|attachment_ids| {
-                attachment_ids
-                    .iter()
-                    .filter_map(|id| self.attachments.get(id))
-                    .collect()
-            })
-            .unwrap_or_default()
-    }
+        color_scheme: &ColorScheme,
+    ) -> Vec<TextSegment> {
+        let mut skip_count = self.preloaded_messages;
+
+        // Check if the first message is a system message, and skip it
+        if !self.messages.is_empty()
+            && self.messages[0].role == PromptRole::System
+        {
+            skip_count += 1;
+        }
 
-    pub fn get_system_prompt(&self) -> Option<String> {
-        // system prompt is the first message in the conversation
         self.messages
-            .first()
-            .filter(|m| m.role == PromptRole::System)
-            .map(|m| m.content.clone())
+            .iter()
+            .skip(skip_count) // Skip preloaded messages
+            .filter(|m| !m.is_deleted && m.role != PromptRole::System)
+            .flat_map(|message| {
+                let style = match message.role {
+                    PromptRole::User => Some(color_scheme.get_primary_style()),
+                    PromptRole::Assistant => {
+                        Some(color_scheme.get_secondary_style())
+                    }
+                    _ => None,
+                };
+
+                let mut segments = Vec::new();
+
+                // Add the message content
+                let text = message.content.clone();
+                segments.push(TextSegment { text, style });
+
+                // For assistant messages, add an extra newline with the same style
+                if message.role == PromptRole::Assistant {
+                    segments.push(TextSegment {
+                        text: "\n".to_string(),
+                        style,
+                    });
+                }
+
+                // Add an empty unstyled line after each message
+                segments.push(TextSegment {
+                    text: "\n".to_string(),
+                    style: Some(Style::reset()),
+                });
+
+                segments
+            })
+            .collect()
     }
 }
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs
index d0edea30..e93c0490 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs
@@ -2,11 +2,13 @@ use lumni::api::error::ApplicationError;
 
 use super::db::{
     system_time_in_milliseconds, ConversationCache, ConversationDatabaseStore,
-    ConversationId, ConversationReader, Message, MessageId, ModelServerName,
-    ModelSpec,
+    ConversationId, ConversationReader, Message, MessageId, ModelSpec,
 };
 use super::prepare::NewConversation;
-use super::{ChatCompletionOptions, ChatMessage, PromptRole};
+use super::{
+    ChatCompletionOptions, ChatMessage, ColorScheme, PromptRole, TextSegment,
+};
+use crate::apps::builtin::llm::prompt::src::chat::assistant;
 pub use crate::external as lumni;
 
 #[derive(Debug)]
@@ -31,6 +33,7 @@ impl PromptInstruction {
             None => ChatCompletionOptions::default(),
         };
 
+        eprintln!("completion_options: {:?}", completion_options);
         let conversation_id = if let Some(ref model) = new_conversation.model {
             Some(db_conn.new_conversation(
                 "New Conversation",
@@ -58,8 +61,6 @@ impl PromptInstruction {
 
         if new_conversation.parent.is_none() {
             // evaluate system_prompt and initial_messages only if parent is not provided
-            // TODO: check if first initial message is a System message,
-            // if not, and prompt is provided, add it as the first message
             if let Some(messages) = new_conversation.initial_messages {
                 let mut messages_to_insert = Vec::new();
 
@@ -81,6 +82,9 @@ impl PromptInstruction {
                     prompt_instruction.cache.add_message(message.clone());
                     messages_to_insert.push(message);
                 }
+                prompt_instruction
+                    .cache
+                    .set_preloaded_messages(messages_to_insert.len());
 
                 // Insert messages into the database
                 db_conn.put_new_messages(&messages_to_insert)?;
@@ -108,6 +112,11 @@ impl PromptInstruction {
         let completion_options: ChatCompletionOptions =
             serde_json::from_value(completion_options)?;
 
+        let preloaded_messages = completion_options
+            .assistant_options
+            .as_ref()
+            .map_or(0, |options| options.preloaded_messages);
+
         let mut prompt_instruction = PromptInstruction {
             cache: ConversationCache::new(),
             model: Some(model_spec),
@@ -118,6 +127,9 @@ impl PromptInstruction {
         prompt_instruction
             .cache
             .set_conversation_id(conversation_id);
+        prompt_instruction
+            .cache
+            .set_preloaded_messages(preloaded_messages);
 
         // Load messages
         let messages = reader
@@ -380,6 +392,13 @@ impl PromptInstruction {
         messages.reverse();
         messages
     }
+
+    pub fn export_conversation(
+        &self,
+        color_scheme: &ColorScheme,
+    ) -> Vec<TextSegment> {
+        self.cache.export_conversation(color_scheme)
+    }
 }
 
 fn simple_token_estimator(input: &str, chars_per_token: Option<f32>) -> i64 {
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs
index 6a9821df..6f827b44 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs
@@ -7,7 +7,9 @@ pub use instruction::PromptInstruction;
 pub use prepare::NewConversation;
 
 pub use super::db;
-use super::{ChatCompletionOptions, ChatMessage, PromptRole};
+use super::{
+    ChatCompletionOptions, ChatMessage, ColorScheme, PromptRole, TextSegment,
+};
 
 #[derive(Debug, Clone)]
 pub struct ParentConversation {
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/prepare.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/prepare.rs
index df894858..872b2c91 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/prepare.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/prepare.rs
@@ -90,7 +90,6 @@ impl NewConversation {
         if last_options != new_options {
             return Ok(false);
         }
-        // Compare system prompt. If the system prompt is not set in the new conversation, we check by first system prompt in the initial messages
         let last_system_prompt = reader.get_system_prompt()?;
         let new_system_prompt = match &self.system_prompt {
@@ -105,6 +104,7 @@ impl NewConversation {
                 })
             }),
         };
+
         if last_system_prompt.as_deref() != new_system_prompt {
             return Ok(false);
         }
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
index 1f8cddd2..35d94193 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
@@ -18,6 +18,7 @@ pub use session::ChatSession;
 
 pub use super::defaults::*;
 pub use super::server::{CompletionResponse, ModelServer, ServerManager};
+pub use super::tui::{ColorScheme, TextSegment};
 
 // gets PERSONAS from the generated code
 include!(concat!(env!("OUT_DIR"), "/llm/prompt/templates.rs"));
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs
index 51e67104..ccfe040c 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs
@@ -107,7 +107,8 @@ impl ChatCompletionOptions {
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct AssistantOptions {
-    pub name: String, // name of assistant used
+    pub name: String, // name of assistant used
+    pub preloaded_messages: usize, // number of messages loaded by the assistant, does not include the first system message
     #[serde(skip_serializing_if = "Option::is_none")]
     pub prompt_template: Option<String>,
 }
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
index e35900b9..476dc0a6 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
@@ -6,7 +6,8 @@ use tokio::sync::{mpsc, oneshot, Mutex};
 
 use super::db::{ConversationDatabaseStore, ConversationId};
 use super::{
-    CompletionResponse, ModelServer, PromptInstruction, ServerManager,
+    ColorScheme, CompletionResponse, ModelServer, PromptInstruction,
+    ServerManager, TextSegment,
 };
 use crate::api::error::ApplicationError;
 
@@ -190,4 +191,11 @@ impl ChatSession {
         }
         Ok(())
     }
+
+    pub fn export_conversation(
+        &self,
+        color_scheme: &ColorScheme,
+    ) -> Vec<TextSegment> {
+        self.prompt_instruction.export_conversation(color_scheme)
+    }
 }
diff --git a/lumni/src/apps/builtin/llm/prompt/src/session.rs b/lumni/src/apps/builtin/llm/prompt/src/session.rs
index 21585b97..e83eb8de 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/session.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/session.rs
@@ -58,12 +58,18 @@ pub struct TabSession<'a> {
 
 impl TabSession<'_> {
     fn new(chat: ChatSession) -> Self {
-        let mut tab_ui = TabUi::new();
+        let color_scheme = ColorScheme::new(ColorSchemeType::Default);
+        let conversation_text = {
+            let export = chat.export_conversation(&color_scheme);
+            (!export.is_empty()).then(|| export)
+        };
+
+        let mut tab_ui = TabUi::new(conversation_text);
         tab_ui.init();
+
         TabSession {
             ui: tab_ui,
             chat,
-            color_scheme: ColorScheme::new(ColorSchemeType::Default),
+            color_scheme,
         }
     }
diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/components/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/components/mod.rs
index 4100c9ab..9582ce6c 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/tui/components/mod.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/tui/components/mod.rs
@@ -8,6 +8,7 @@ mod text_wrapper;
 mod window_config;
 
 pub use cursor::MoveCursor;
+pub use piece_table::TextSegment;
 pub use scroller::Scroller;
 pub use text_buffer::{LineType, TextBuffer};
 pub use text_window::{TextWindow, TextWindowTrait};
diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/components/piece_table.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/components/piece_table.rs
index bb90d6dc..7b351ec8 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/tui/components/piece_table.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/tui/components/piece_table.rs
@@ -1,5 +1,11 @@
 use ratatui::style::{Color, Style};
 
+#[derive(Clone, Debug, PartialEq)]
+pub struct StyledText {
+    pub content: String,
+    pub style: Option<Style>,
+}
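
Note: the remainder of this patch (the rest of piece_table.rs plus the text_buffer.rs, text_window.rs, text_wrapper.rs, handle_response_window.rs, tui/mod.rs, tui/ui.rs, and windows.rs hunks listed in the diffstat) is truncated above. To illustrate the new export path, the hunks shown thread ConversationCache::export_conversation through PromptInstruction and ChatSession into TabSession::new, which hands the styled text to TabUi::new for the response window. Below is a minimal sketch of how such exported segments can be turned into ratatui spans for rendering. It assumes the TextSegment shape used in cache.rs (text: String, style: Option<Style>); the helper name segments_to_spans is hypothetical and not part of this patch.

    use ratatui::style::Style;
    use ratatui::text::Span;

    pub struct TextSegment {
        pub text: String,
        pub style: Option<Style>,
    }

    // Map exported conversation segments onto ratatui spans;
    // segments without an explicit style fall back to Span::raw.
    fn segments_to_spans(segments: &[TextSegment]) -> Vec<Span<'_>> {
        segments
            .iter()
            .map(|seg| match seg.style {
                Some(style) => Span::styled(seg.text.as_str(), style),
                None => Span::raw(seg.text.as_str()),
            })
            .collect()
    }

This mirrors the convention in export_conversation above: user and assistant messages carry the color scheme's primary and secondary styles, while the separator lines carry Style::reset() so each message block ends on an unstyled line.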