diff --git a/lumni/src/apps/api/error.rs b/lumni/src/apps/api/error.rs index e63ef1c9..8f2d0bf8 100644 --- a/lumni/src/apps/api/error.rs +++ b/lumni/src/apps/api/error.rs @@ -133,7 +133,12 @@ impl std::error::Error for ApplicationError {} impl From for ApplicationError { fn from(error: HttpClientError) -> Self { - ApplicationError::HttpClientError(error) + match error { + HttpClientError::ConnectionError(e) => { + ApplicationError::NotReady(e.to_string()) + } + _ => ApplicationError::HttpClientError(error), + } } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/app.rs b/lumni/src/apps/builtin/llm/prompt/src/app.rs index 5b8d7534..7ae1e12d 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/app.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/app.rs @@ -24,13 +24,15 @@ use tokio::signal; use tokio::sync::{mpsc, Mutex}; use tokio::time::{interval, timeout, Duration}; -use super::chat::{ChatSession, ConversationDatabaseStore}; +use super::chat::{ChatSession, ConversationDatabaseStore, NewConversation}; use super::server::{ - ModelServer, PromptInstruction, ServerManager, ServerTrait, + ModelServer, ModelServerName, PromptInstruction, ServerManager, ServerTrait, }; use super::session::{AppSession, TabSession}; use super::tui::{ - ColorScheme, CommandLineAction, ConversationReader, KeyEventHandler, PromptAction, TabUi, TextWindowTrait, WindowEvent, WindowKind + ColorScheme, CommandLineAction, ConversationEvent, ConversationReader, + KeyEventHandler, PromptAction, TabUi, TextWindowTrait, WindowEvent, + WindowKind, }; pub use crate::external as lumni; @@ -56,11 +58,12 @@ async fn prompt_app( let mut redraw_ui = true; // TODO: reader should be updated when conversation_id changes - let mut reader: Option = if let Some(conversation_id) = tab.chat.get_conversation_id() { - Some(db_conn.get_conversation_reader(conversation_id)) - } else { - None - }; + let mut reader: Option = + if let Some(conversation_id) = tab.chat.get_conversation_id() { + Some(db_conn.get_conversation_reader(conversation_id)) + } else { + None + }; // Buffer to store the trimmed trailing newlines or empty spaces let mut trim_buffer: Option = None; @@ -135,11 +138,24 @@ async fn prompt_app( } Some(WindowEvent::PromptWindow(ref event)) => { match event { - None => {}, - Some(converation_event) => { - // TODO: if conversation_id changes, update reader - eprintln!("Conversation event: {:?}", converation_event); + Some(ConversationEvent::NewConversation(new_conversation)) => { + let prompt_instruction = PromptInstruction::new( + new_conversation.clone(), + &db_conn, + )?; + let chat_session = ChatSession::new( + &new_conversation.server.to_string(), + prompt_instruction, + &db_conn, + ).await?; + tab.new_conversation(chat_session); + reader = if let Some(conversation_id) = tab.chat.get_conversation_id() { + Some(db_conn.get_conversation_reader(conversation_id)) + } else { + None + }; } + _ => {}, } } _ => {} @@ -367,7 +383,13 @@ pub async fn run_cli( // optional arguments let instruction = matches.get_one::("system").cloned(); let assistant = matches.get_one::("assistant").cloned(); - let options = matches.get_one::("options"); + let options = match matches.get_one::("options") { + Some(s) => { + let value = serde_json::from_str::(s)?; + Some(value) + } + None => None, + }; let server_name = matches .get_one::("server") @@ -375,21 +397,29 @@ pub async fn run_cli( .unwrap_or_else(|| "ollama".to_lowercase()); // create new (un-initialized) server from requested server name - let mut server = ModelServer::from_str(&server_name)?; - let 
default_model = server.get_default_model().await; - - // setup prompt, server and chat session - let prompt_instruction = - PromptInstruction::new(default_model, instruction, assistant, options, &db_conn)?; - let conversation_id = prompt_instruction.get_conversation_id(); + let server = ModelServer::from_str(&server_name)?; + let default_model = match server.get_default_model().await { + Ok(model) => Some(model), + Err(e) => { + log::error!("Failed to get default model during startup: {}", e); + None + } + }; - if let Some(conversation_id) = conversation_id { - let reader = db_conn.get_conversation_reader(conversation_id); - server.setup_and_initialize(&reader).await?; - } + let prompt_instruction = PromptInstruction::new( + NewConversation { + server: ModelServerName::from_str(&server_name), + model: default_model, + options, + system_prompt: instruction, + assistant_name: assistant, + parent: None, + }, + &db_conn, + )?; let chat_session = - ChatSession::new(Box::new(server), prompt_instruction).await?; + ChatSession::new(&server_name, prompt_instruction, &db_conn).await?; match poll(Duration::from_millis(0)) { Ok(_) => { diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs index fec8139e..ea7edfb7 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; -use super::{ModelIdentifier, ModelSpec, PromptRole, - ConversationId, Message, - MessageId, AttachmentId, Attachment +use super::{ + Attachment, AttachmentId, ConversationId, Message, MessageId, + ModelIdentifier, ModelSpec, PromptRole, }; #[derive(Debug)] diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs index 830bb4b0..da1c3a02 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs @@ -1,17 +1,26 @@ use serde::{Deserialize, Serialize}; -mod model; mod cache; +mod model; -pub use model::{ModelIdentifier, ModelSpec}; pub use cache::ConversationCache; +pub use model::{ModelIdentifier, ModelSpec}; use super::PromptRole; - #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct ModelServerName(pub String); +impl ModelServerName { + pub fn from_str>(s: T) -> Self { + ModelServerName(s.as_ref().to_string()) + } + + pub fn to_string(&self) -> String { + self.0.clone() + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct ConversationId(pub i64); @@ -21,7 +30,6 @@ pub struct MessageId(pub i64); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct AttachmentId(pub i64); - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Conversation { pub id: ConversationId, @@ -67,4 +75,4 @@ pub struct Attachment { pub metadata: Option, pub created_at: i64, pub is_deleted: bool, -} \ No newline at end of file +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/model.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/model.rs index 65548d6e..4df63b2d 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/model.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/model.rs @@ -1,15 +1,14 @@ +use lazy_static::lazy_static; use lumni::api::error::ApplicationError; +use regex::Regex; use 
serde::{Deserialize, Serialize}; pub use crate::external as lumni; -use lazy_static::lazy_static; -use regex::Regex; - lazy_static! { - static ref IDENTIFIER_REGEX: Regex = Regex::new( - r"^[-a-z0-9_]+::[-a-z0-9_][-a-z0-9_:.]*[-a-z0-9_]+$" - ).unwrap(); + static ref IDENTIFIER_REGEX: Regex = + Regex::new(r"^[-a-z0-9_]+::[-a-z0-9_][-a-z0-9_:.]*[-a-z0-9_]+$") + .unwrap(); } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -21,7 +20,10 @@ impl ModelIdentifier { Ok(ModelIdentifier(identifier_str.to_string())) } else { Err(ApplicationError::InvalidInput(format!( - "Identifier must be in the format 'provider::model_name', where the provider contains only lowercase letters, numbers, hyphens, underscores, and the model name can include internal colons but not start or end with them. Got: '{}'", + "Identifier must be in the format 'provider::model_name', \ + where the provider contains only lowercase letters, numbers, \ + hyphens, underscores, and the model name can include \ + internal colons but not start or end with them. Got: '{}'", identifier_str ))) } @@ -48,12 +50,14 @@ pub struct ModelSpec { } impl ModelSpec { - pub fn new_with_validation(identifier_str: &str) -> Result { + pub fn new_with_validation( + identifier_str: &str, + ) -> Result { let identifier = ModelIdentifier::new(identifier_str)?; Ok(ModelSpec { identifier, info: None, - config: None, + config: None, context_window_size: None, input_token_limit: None, }) @@ -107,7 +111,11 @@ impl ModelSpec { self } - pub fn set_config_value(&mut self, key: &str, value: serde_json::Value) -> &mut Self { + pub fn set_config_value( + &mut self, + key: &str, + value: serde_json::Value, + ) -> &mut Self { if let Some(config) = self.config.as_mut() { if let serde_json::Value::Object(map) = config { map.insert(key.to_string(), value); @@ -142,20 +150,25 @@ impl ModelSpec { } pub fn set_family(&mut self, family: &str) -> &mut Self { - self.set_config_value("family", serde_json::Value::String(family.to_string())) + self.set_config_value( + "family", + serde_json::Value::String(family.to_string()), + ) } pub fn get_family(&self) -> Option<&str> { - self.get_config_value("family") - .and_then(|v| v.as_str()) + self.get_config_value("family").and_then(|v| v.as_str()) } pub fn set_description(&mut self, description: &str) -> &mut Self { - self.set_config_value("description", serde_json::Value::String(description.to_string())) + self.set_config_value( + "description", + serde_json::Value::String(description.to_string()), + ) } pub fn get_description(&self) -> Option<&str> { self.get_config_value("description") .and_then(|v| v.as_str()) } -} \ No newline at end of file +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs index 930f664d..d26611f5 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs @@ -4,5 +4,6 @@ mod reader; mod store; pub use reader::ConversationReader; -pub use super::conversation; pub use store::ConversationDatabaseStore; + +pub use super::conversation; diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs index 5bfd0147..ed5532ce 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex}; use rusqlite::{params, Error as SqliteError, OptionalExtension}; use 
super::connector::DatabaseConnector; -use super::conversation::{ConversationId, ModelIdentifier}; +use super::conversation::{ConversationId, MessageId, ModelIdentifier}; pub struct ConversationReader<'a> { conversation_id: ConversationId, @@ -23,6 +23,10 @@ impl<'a> ConversationReader<'a> { } impl<'a> ConversationReader<'a> { + pub fn get_conversation_id(&self) -> ConversationId { + self.conversation_id + } + pub fn get_model_identifier(&self) -> Result { let query = " SELECT m.identifier @@ -89,9 +93,25 @@ impl<'a> ConversationReader<'a> { }) } - pub fn get_conversation_stats( + pub fn get_last_message_id( &self, - ) -> Result<(i64, i64), SqliteError> { + ) -> Result, SqliteError> { + let query = " + SELECT MAX(id) as last_message_id + FROM messages + WHERE conversation_id = ? AND is_deleted = FALSE + "; + + let mut db = self.db.lock().unwrap(); + db.process_queue_with_result(|tx| { + tx.query_row(query, params![self.conversation_id.0], |row| { + row.get::<_, Option>(0) + .map(|opt_id| opt_id.map(MessageId)) + }) + }) + } + + pub fn get_conversation_stats(&self) -> Result<(i64, i64), SqliteError> { let query = "SELECT message_count, total_tokens FROM conversations \ WHERE id = ?"; let mut db = self.db.lock().unwrap(); diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs index 6281a70c..6d4f3734 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs @@ -4,11 +4,11 @@ use std::sync::{Arc, Mutex}; use rusqlite::{params, Error as SqliteError, OptionalExtension}; use super::connector::DatabaseConnector; -use super::reader::ConversationReader; use super::conversation::{ Attachment, AttachmentData, AttachmentId, Conversation, ConversationId, - Message, MessageId, ModelSpec, ModelIdentifier, ModelServerName, + Message, MessageId, ModelIdentifier, ModelServerName, ModelSpec, }; +use super::reader::ConversationReader; pub struct ConversationDatabaseStore { db: Arc>, @@ -40,20 +40,32 @@ impl ConversationDatabaseStore { let mut db = self.db.lock().unwrap(); db.process_queue_with_result(|tx| { // Ensure the model exists - let exists: bool = tx.query_row( - "SELECT 1 FROM models WHERE identifier = ?", - params![model.identifier.0], - |_| Ok(true), - ).optional()?.unwrap_or(false); + let exists: bool = tx + .query_row( + "SELECT 1 FROM models WHERE identifier = ?", + params![model.identifier.0], + |_| Ok(true), + ) + .optional()? 
+ .unwrap_or(false); if !exists { tx.execute( - "INSERT INTO models (identifier, info, config, context_window_size, input_token_limit) + "INSERT INTO models (identifier, info, config, \ + context_window_size, input_token_limit) VALUES (?, ?, ?, ?, ?)", params![ model.identifier.0, - model.info.as_ref().map(|v| serde_json::to_string(v).unwrap_or_default()), - model.config.as_ref().map(|v| serde_json::to_string(v).unwrap_or_default()), + model + .info + .as_ref() + .map(|v| serde_json::to_string(v) + .unwrap_or_default()), + model + .config + .as_ref() + .map(|v| serde_json::to_string(v) + .unwrap_or_default()), model.context_window_size, model.input_token_limit, ], @@ -77,19 +89,25 @@ impl ConversationDatabaseStore { tx.execute( "INSERT INTO conversations ( - name, info, model_identifier, model_server, parent_conversation_id, - fork_message_id, completion_options, created_at, updated_at, + name, info, model_identifier, model_server, \ + parent_conversation_id, + fork_message_id, completion_options, created_at, \ + updated_at, message_count, total_tokens, is_deleted ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 0, 0, ?)", params![ conversation.name, - serde_json::to_string(&conversation.info).unwrap_or_default(), + serde_json::to_string(&conversation.info) + .unwrap_or_default(), conversation.model_identifier.0, conversation.model_server.0, conversation.parent_conversation_id.map(|id| id.0), conversation.fork_message_id.map(|id| id.0), - conversation.completion_options.as_ref().map(|v| serde_json::to_string(v).unwrap_or_default()), + conversation + .completion_options + .as_ref() + .map(|v| serde_json::to_string(v).unwrap_or_default()), conversation.created_at, conversation.updated_at, conversation.is_deleted, @@ -111,7 +129,8 @@ impl ConversationDatabaseStore { // Get the last message ID for this conversation let last_message_id: Option = tx .query_row( - "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id DESC LIMIT 1", + "SELECT id FROM messages WHERE conversation_id = ? ORDER \ + BY id DESC LIMIT 1", params![message.conversation_id.0], |row| row.get(0), ) @@ -173,7 +192,8 @@ impl ConversationDatabaseStore { let mut new_message_ids = Vec::with_capacity(messages.len()); let mut last_message_id: Option = tx .query_row( - "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id DESC LIMIT 1", + "SELECT id FROM messages WHERE conversation_id = ? ORDER \ + BY id DESC LIMIT 1", params![conversation_id], |row| row.get(0), ) @@ -260,12 +280,14 @@ impl ConversationDatabaseStore { let conversation = Conversation { id: ConversationId(row.get(0)?), name: row.get(1)?, - info: serde_json::from_str(&row.get::<_, String>(2)?).unwrap_or_default(), + info: serde_json::from_str(&row.get::<_, String>(2)?) + .unwrap_or_default(), model_identifier: ModelIdentifier(row.get(3)?), model_server: ModelServerName(row.get(4)?), parent_conversation_id: row.get(5).map(ConversationId).ok(), fork_message_id: row.get(6).map(MessageId).ok(), - completion_options: row.get::<_, Option>(7)? + completion_options: row + .get::<_, Option>(7)? .map(|s| serde_json::from_str(&s).unwrap_or_default()), created_at: row.get(8)?, updated_at: row.get(9)?, @@ -331,12 +353,18 @@ impl ConversationDatabaseStore { Ok(Conversation { id: ConversationId(row.get(0)?), name: row.get(1)?, - info: serde_json::from_str(&row.get::<_, String>(2)?).unwrap_or_default(), + info: serde_json::from_str(&row.get::<_, String>(2)?) 
+                            .unwrap_or_default(),
                         model_identifier: ModelIdentifier(row.get(3)?),
                         model_server: ModelServerName(row.get(4)?),
-                        parent_conversation_id: row.get::<_, Option<i64>>(5)?.map(ConversationId),
-                        fork_message_id: row.get::<_, Option<i64>>(6)?.map(MessageId),
-                        completion_options: row.get::<_, Option<String>>(7)?
+                        parent_conversation_id: row
+                            .get::<_, Option<i64>>(5)?
+                            .map(ConversationId),
+                        fork_message_id: row
+                            .get::<_, Option<i64>>(6)?
+                            .map(MessageId),
+                        completion_options: row
+                            .get::<_, Option<String>>(7)?
                             .map(|s| serde_json::from_str(&s).unwrap_or_default()),
                         created_at: row.get(8)?,
                         updated_at: row.get(9)?,
diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
index 7bdc91c5..1dc794e6 100644
--- a/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
+++ b/lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
@@ -1,14 +1,81 @@
 use lumni::api::error::ApplicationError;
 
-use super::db::ConversationDatabaseStore;
 use super::conversation::{
-    ConversationCache, ConversationId, Message, MessageId,
-    ModelServerName, ModelSpec
+    ConversationCache, ConversationId, Message, MessageId, ModelServerName,
+    ModelSpec,
 };
+use super::db::{ConversationDatabaseStore, ConversationReader};
 use super::prompt::Prompt;
-use super::{ChatCompletionOptions, ChatMessage, PromptRole, PERSONAS};
+use super::{
+    ChatCompletionOptions, ChatMessage, PromptRole, ServerManager, PERSONAS,
+};
 pub use crate::external as lumni;
 
+#[derive(Debug, Clone)]
+pub struct ParentConversation {
+    pub id: ConversationId,
+    pub fork_message_id: MessageId,
+}
+
+#[derive(Debug, Clone)]
+pub struct NewConversation {
+    pub server: ModelServerName,
+    pub model: Option<ModelSpec>,
+    pub options: Option<serde_json::Value>,
+    pub system_prompt: Option<String>, // system_prompt ignored if parent is provided
+    pub assistant_name: Option<String>, // assistant_name ignored if parent is provided
+    pub parent: Option<ParentConversation>, // forked conversation
+}
+
+impl NewConversation {
+    pub fn new(
+        new_server: Box<dyn ServerManager>,
+        new_model: ModelSpec,
+        conversation_reader: Option<&ConversationReader<'_>>,
+    ) -> Result<NewConversation, ApplicationError> {
+        if let Some(reader) = conversation_reader {
+            // fork from an existing conversation
+            let current_conversation_id = reader.get_conversation_id();
+            let current_completion_options = reader.get_completion_options()?;
+
+            if let Some(last_message_id) = reader.get_last_message_id()? {
+                Ok(NewConversation {
+                    server: new_server.server_name(),
+                    model: Some(new_model),
+                    options: Some(current_completion_options),
+                    system_prompt: None, // ignored when forking
+                    assistant_name: None, // ignored when forking
+                    parent: Some(ParentConversation {
+                        id: current_conversation_id,
+                        fork_message_id: last_message_id,
+                    }),
+                })
+            } else {
+                // start a new conversation; as there is no last message, there is nothing to fork from.
+ // Both system_prompt and assistant_name are set to None, because if no messages exist, these were also None in the (empty) parent conversation + Ok(NewConversation { + server: new_server.server_name(), + model: Some(new_model), + options: Some(current_completion_options), + system_prompt: None, + assistant_name: None, + parent: None, + }) + } + } else { + // start a new conversation + Ok(NewConversation { + server: new_server.server_name(), + model: Some(new_model), + options: None, + system_prompt: None, + assistant_name: None, + parent: None, + }) + } + } +} + pub struct PromptInstruction { cache: ConversationCache, model: Option, @@ -18,46 +85,34 @@ pub struct PromptInstruction { impl PromptInstruction { pub fn new( - model: Option, - instruction: Option, - assistant: Option, - options: Option<&String>, + new_conversation: NewConversation, db_conn: &ConversationDatabaseStore, ) -> Result { - // If both instruction and assistant are None, use the default assistant - let assistant = if instruction.is_none() && assistant.is_none() { - Some("Default".to_string()) - } else { - assistant - }; - let completion_options = match options { + let completion_options = match new_conversation.options { Some(opts) => { let mut options = ChatCompletionOptions::default(); - options.update_from_json(opts)?; + options.update(opts)?; serde_json::to_value(options)? } None => serde_json::to_value(ChatCompletionOptions::default())?, }; - let conversation_id = if let Some(ref model) = &model { - // Create a new Conversation in the database - Some({ - db_conn.new_conversation( - "New Conversation", - None, // parent_id, None for new conversation - None, // fork_message_id, None for new conversation - Some(completion_options), // completion_options - model, - ModelServerName("ollama".to_string()), - )? - }) + let conversation_id = if let Some(ref model) = new_conversation.model { + Some(db_conn.new_conversation( + "New Conversation", + new_conversation.parent.as_ref().map(|p| p.id), + new_conversation.parent.as_ref().map(|p| p.fork_message_id), + Some(completion_options), + model, + new_conversation.server, + )?) 
         } else {
             None
         };
 
         let mut prompt_instruction = PromptInstruction {
             cache: ConversationCache::new(),
-            model,
+            model: new_conversation.model,
             prompt_template: None,
             conversation_id,
         };
@@ -68,24 +123,60 @@
                 .set_conversation_id(conversation_id);
         }
 
-        if let Some(assistant) = assistant {
-            prompt_instruction.preload_from_assistant(
-                assistant,
-                instruction, // add user-instruction with assistant
-                db_conn,
-            )?;
-        } else if let Some(instruction) = instruction {
-            prompt_instruction.add_system_message(instruction, db_conn)?;
+        if new_conversation.parent.is_none() {
+            // evaluate system_prompt and assistant_name only if parent is not provided
+            match (
+                new_conversation.system_prompt,
+                new_conversation.assistant_name,
+            ) {
+                (Some(system_prompt), Some(assistant_name)) => {
+                    prompt_instruction.preload_from_assistant(
+                        assistant_name,
+                        Some(system_prompt),
+                        db_conn,
+                    )?;
+                }
+                (Some(system_prompt), None) => {
+                    prompt_instruction
+                        .add_system_message(system_prompt, db_conn)?;
+                }
+                (None, Some(assistant_name)) => {
+                    prompt_instruction.preload_from_assistant(
+                        assistant_name,
+                        None,
+                        db_conn,
+                    )?;
+                }
+                (None, None) => {
+                    // TODO: default should only apply to servers that do not handle this
+                    // internally
+                    // Use the default assistant when both system_prompt and assistant_name are None
+                    prompt_instruction.preload_from_assistant(
+                        "Default".to_string(),
+                        None,
+                        db_conn,
+                    )?;
+                }
+            }
         }
-
         Ok(prompt_instruction)
     }
+
+    pub fn new_from_import(
+        &mut self,
+        reader: &ConversationReader<'_>,
+    ) -> Result<(), ApplicationError> {
+        Ok(())
+    }
 
     pub fn get_model(&self) -> Option<&ModelSpec> {
         self.model.as_ref()
     }
 
     pub fn get_conversation_id(&self) -> Option<ConversationId> {
+        // return the conversation_id from an active conversation
+        // use the ConversationId from this struct, and not the cache as
+        // the latter can be from a non-active conversation
         self.conversation_id
     }
 
@@ -112,42 +203,6 @@
         Ok(())
     }
 
-    pub fn import_conversation(
-        &mut self,
-        id: &str,
-        db_conn: &ConversationDatabaseStore,
-    ) -> Result<(), ApplicationError> {
-        let conversation_id = ConversationId(id.parse().map_err(|_| {
-            ApplicationError::NotFound(format!(
-                "Conversation {id} not found in database"
-            ))
-        })?);
-
-        let (conversation, messages) = db_conn
-            .fetch_conversation(Some(conversation_id), None)?
- .ok_or_else(|| { - ApplicationError::NotFound("Conversation not found".to_string()) - })?; - - // Clear the existing ConversationCache - self.cache = ConversationCache::new(); - - // Set the conversation ID - self.cache.set_conversation_id(conversation.id); - - // Add messages to the cache - for message in messages { - // Fetch and add attachments for each message - let attachments = db_conn.fetch_message_attachments(message.id)?; - for attachment in attachments { - self.cache.add_attachment(attachment); - } - self.cache.add_message(message); - } - - Ok(()) - } - pub fn reset_history( &mut self, db_conn: &ConversationDatabaseStore, @@ -155,15 +210,14 @@ impl PromptInstruction { // reset by creating a new conversation // TODO: clone previous conversation settings if let Some(ref model) = &self.model { - let current_conversation_id = - db_conn.new_conversation( - "New Conversation", - None, - None, - None, - model, - ModelServerName("ollama".to_string()), - )?; + let current_conversation_id = db_conn.new_conversation( + "New Conversation", + None, + None, + None, + model, + ModelServerName("ollama".to_string()), + )?; self.cache.set_conversation_id(current_conversation_id); }; Ok(()) diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs index 64ef7a11..7df851b6 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs @@ -1,25 +1,24 @@ use std::error::Error; +pub mod conversation; mod db; mod instruction; mod options; mod prompt; -mod send; mod prompt_role; -pub mod conversation; +mod send; mod session; +pub use conversation::{ConversationId, ModelServerName, ModelSpec}; pub use db::{ConversationDatabaseStore, ConversationReader}; -pub use conversation::{ConversationId, ModelIdentifier, ModelSpec}; -pub use instruction::PromptInstruction; +pub use instruction::{NewConversation, PromptInstruction}; pub use options::ChatCompletionOptions; use prompt::Prompt; +pub use prompt_role::PromptRole; pub use send::{http_get_with_response, http_post, http_post_with_response}; pub use session::ChatSession; -pub use prompt_role::PromptRole; pub use super::defaults::*; -pub use super::server::{CompletionResponse, ServerManager}; -pub use super::tui::{WindowEvent, ConversationEvent}; +pub use super::server::{CompletionResponse, ModelServer, ServerManager}; // gets PERSONAS from the generated code include!(concat!(env!("OUT_DIR"), "/llm/prompt/templates.rs")); diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs index 40608193..8a8d294a 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/options.rs @@ -38,11 +38,12 @@ impl Default for ChatCompletionOptions { } impl ChatCompletionOptions { - pub fn update_from_json( + pub fn update( &mut self, - json: &str, + value: serde_json::Value, ) -> Result<(), serde_json::Error> { - let user_options = serde_json::from_str::(json)?; + let user_options = + serde_json::from_value::(value)?; self.temperature = user_options.temperature.or(self.temperature); self.top_k = user_options.top_k.or(self.top_k); self.top_p = user_options.top_p.or(self.top_p); @@ -54,12 +55,6 @@ impl ChatCompletionOptions { Ok(()) } -// pub fn update_from_model(&mut self, model: &LLMDefinition) { -// if self.stop.is_none() { -// self.stop = Some(model.get_stop_tokens().clone()); -// } -// } - pub fn set_temperature(mut self, temperature: f64) -> 
Self { self.temperature = Some(temperature); self diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/prompt_role.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/prompt_role.rs index 1be3ee8d..a6c9d40a 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/prompt_role.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/prompt_role.rs @@ -3,7 +3,6 @@ use std::fmt::Display; use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef}; use serde::{Deserialize, Serialize}; - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub enum PromptRole { User, @@ -37,4 +36,4 @@ impl FromSql for PromptRole { _ => Err(FromSqlError::InvalidType.into()), } } -} \ No newline at end of file +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs index 551fd811..d482cd3b 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/session.rs @@ -5,9 +5,8 @@ use bytes::Bytes; use tokio::sync::{mpsc, oneshot, Mutex}; use super::{ - CompletionResponse, ConversationDatabaseStore, ConversationReader, - PromptInstruction, ServerManager, WindowEvent, ConversationEvent, - ConversationId, + CompletionResponse, ConversationDatabaseStore, ConversationId, ModelServer, + PromptInstruction, ServerManager, }; use crate::api::error::ApplicationError; @@ -19,9 +18,18 @@ pub struct ChatSession { impl ChatSession { pub async fn new( - server: Box, + server_name: &str, prompt_instruction: PromptInstruction, + db_conn: &ConversationDatabaseStore, ) -> Result { + let mut server = Box::new(ModelServer::from_str(&server_name)?); + + if let Some(conversation_id) = prompt_instruction.get_conversation_id() + { + let reader = db_conn.get_conversation_reader(conversation_id); + server.setup_and_initialize(&reader).await?; + } + Ok(ChatSession { server, prompt_instruction, @@ -29,40 +37,21 @@ impl ChatSession { }) } - pub fn server_name(&self) -> &str { - self.server.server_name() + pub fn new_prompt_instruction( + &mut self, + prompt_instruction: PromptInstruction, + ) { + // stop any ongoing session + self.stop(); + self.prompt_instruction = prompt_instruction; } - pub fn get_conversation_id(&self) -> Option { - self.prompt_instruction.get_conversation_id() + pub fn server_name(&self) -> String { + self.server.server_name().to_string() } - pub async fn change_server( - &mut self, - mut server: Box, - reader: Option<&ConversationReader<'_>>, - ) -> Result, ApplicationError> { - log::debug!("switching server: {}", server.server_name()); - self.stop(); - - // TODO: update prompt instruction with new server / conversation - //let model = server.get_default_model().await; - if let Some(reader) = reader { - //self.prompt_instruction.set_conversation_id(reader.get_conversation_id()); - server.setup_and_initialize(reader).await?; - } - self.server = server; - // TODO: - // add new events to handle server / conversation change - // if conversation_id changes, return new conversation_id as - // well to create a new ConversationReader - if let Some(new_conversation_id) = self.get_conversation_id() { - return Ok(Some(WindowEvent::PromptWindow(Some(ConversationEvent::New( - new_conversation_id, - ))))); - } else { - return Ok(Some(WindowEvent::PromptWindow(None))); - } + pub fn get_conversation_id(&self) -> Option { + self.prompt_instruction.get_conversation_id() } pub fn stop(&mut self) { @@ -103,11 +92,14 @@ impl ChatSession { tx: mpsc::Sender, question: &str, ) -> Result<(), 
ApplicationError> { - let model = if let Some(model) = self.prompt_instruction.get_model().cloned() { - model - } else { - return Err(ApplicationError::NotReady("Model not available".to_string())); - }; + let model = + if let Some(model) = self.prompt_instruction.get_model().cloned() { + model + } else { + return Err(ApplicationError::NotReady( + "Model not available".to_string(), + )); + }; let max_token_length = self.server.get_max_context_size().await?; let user_question = self.initiate_new_exchange(question).await?; diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs index 1514f6f8..564c9932 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs @@ -19,8 +19,8 @@ use url::Url; use super::{ http_post, ChatMessage, CompletionResponse, CompletionStats, - ConversationReader, Endpoints, ModelSpec, - PromptRole, ServerSpecTrait, ServerTrait, + ConversationReader, Endpoints, ModelSpec, PromptRole, ServerSpecTrait, + ServerTrait, }; pub use crate::external as lumni; @@ -202,7 +202,11 @@ impl ServerTrait for Bedrock { cancel_rx: Option>, ) -> Result<(), ApplicationError> { let resource = HttpClient::percent_encode_with_exclusion( - &format!("/model/{}.{}/converse-stream", model.get_model_provider(), model.get_model_name()), + &format!( + "/model/{}.{}/converse-stream", + model.get_model_provider(), + model.get_model_name() + ), Some(&[b'/', b'.', b'-']), ); let completion_endpoint = self.endpoints.get_completion_endpoint()?; @@ -248,12 +252,10 @@ impl ServerTrait for Bedrock { Ok(()) } - async fn list_models( - &self, - ) -> Result, ApplicationError> { - Ok(vec![ - ModelSpec::new_with_validation("anthropic::claude-3-5-sonnet-20240620-v1:0")?, - ]) + async fn list_models(&self) -> Result, ApplicationError> { + Ok(vec![ModelSpec::new_with_validation( + "anthropic::claude-3-5-sonnet-20240620-v1:0", + )?]) } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/llama3.rs b/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/llama3.rs index 254096c2..03782cf5 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/llama3.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/llama3.rs @@ -53,8 +53,7 @@ impl ModelFormatterTrait for Llama3 { let mut prompt_message = String::new(); prompt_message.push_str(&format!( "<|start_header_id|>{}<|end_header_id|>\n{}", - role_handle, - message + role_handle, message )); if !message.is_empty() { prompt_message.push_str("<|eot_id|>\n"); diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/mod.rs index eb5234f9..85488f56 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/llama/formatters/mod.rs @@ -2,10 +2,10 @@ mod generic; mod llama3; use async_trait::async_trait; -use regex::Regex; - use generic::Generic; use llama3::Llama3; +use regex::Regex; + pub use super::PromptRole; #[derive(Clone, Debug)] diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs index 04fe1214..73551750 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs @@ -4,18 +4,18 @@ use std::error::Error; use async_trait::async_trait; use 
bytes::Bytes; +use formatters::{ModelFormatter, ModelFormatterTrait}; use lumni::api::error::ApplicationError; use serde::{Deserialize, Serialize}; use tokio::sync::{mpsc, oneshot}; use url::Url; +pub use super::PromptRole; use super::{ http_get_with_response, http_post, ChatMessage, CompletionResponse, CompletionStats, ConversationReader, Endpoints, HttpClient, ModelSpec, ServerSpecTrait, ServerTrait, DEFAULT_CONTEXT_SIZE, }; -use formatters::{ModelFormatter, ModelFormatterTrait}; -pub use super::PromptRole; use crate::external as lumni; pub const DEFAULT_COMPLETION_ENDPOINT: &str = @@ -111,7 +111,7 @@ impl Llama { .map_err(|e| { ApplicationError::ServerConfigurationError(e.to_string()) })?), - Err(e) => Err(ApplicationError::NotReady(e.to_string())), + Err(e) => Err(e), // propagate the error } } } @@ -170,17 +170,16 @@ impl ServerTrait for Llama { Ok(()) } - async fn list_models( - &self, - ) -> Result, ApplicationError> { + async fn list_models(&self) -> Result, ApplicationError> { let settings = self.get_props().await?; let model_file = settings.default_generation_settings.model; let file_name = model_file.split('/').last().unwrap(); let model_name = file_name.split('.').next().unwrap().to_lowercase(); - Ok(vec![ - ModelSpec::new_with_validation(&format!("unknown::{}", model_name))?, - ]) + Ok(vec![ModelSpec::new_with_validation(&format!( + "unknown::{}", + model_name + ))?]) } async fn initialize_with_model( diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs index c7e17c08..c8d21169 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs @@ -24,10 +24,10 @@ use tokio::sync::{mpsc, oneshot}; pub use super::chat::{ http_get_with_response, http_post, http_post_with_response, ChatMessage, - ConversationReader, PromptInstruction, + ConversationReader, ModelServerName, ModelSpec, PromptInstruction, + PromptRole, }; pub use super::defaults::*; -pub use super::chat::{ModelIdentifier, ModelSpec, PromptRole}; use crate::external as lumni; pub const SUPPORTED_MODEL_ENDPOINTS: [&str; 4] = @@ -166,9 +166,7 @@ impl ServerTrait for ModelServer { } } - async fn list_models( - &self, - ) -> Result, ApplicationError> { + async fn list_models(&self) -> Result, ApplicationError> { match self { ModelServer::Llama(llama) => llama.list_models().await, ModelServer::Ollama(ollama) => ollama.list_models().await, @@ -195,24 +193,20 @@ pub trait ServerTrait: Send + Sync { cancel_rx: Option>, ) -> Result<(), ApplicationError>; - async fn list_models(&self) - -> Result, ApplicationError>; + async fn list_models(&self) -> Result, ApplicationError>; - async fn get_default_model(&self) -> Option { + async fn get_default_model(&self) -> Result { match self.list_models().await { Ok(models) => { if models.is_empty() { - log::warn!("Received empty model list"); - None + Err(ApplicationError::ServerConfigurationError( + "No models available".to_string(), + )) } else { - log::debug!("Available models: {:?}", models); - Some(models[0].to_owned()) + Ok(models[0].to_owned()) } } - Err(e) => { - log::error!("Failed to list models: {}", e); - None - } + Err(e) => Err(e), // propagate error } } @@ -241,13 +235,10 @@ pub trait ServerManager: ServerTrait { &mut self, reader: &ConversationReader, ) -> Result<(), ApplicationError> { - // update completion options from the model, i.e. 
stop tokens - // TODO: prompt_intruction should be re-initialized with the model - //prompt_instruction.set_model(&model); self.initialize_with_model(reader).await } - fn server_name(&self) -> &str { - self.get_spec().name() + fn server_name(&self) -> ModelServerName { + ModelServerName::from_str(self.get_spec().name().to_lowercase()) } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs index 93807e10..db2bc32f 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs @@ -9,8 +9,8 @@ use url::Url; use super::{ http_get_with_response, http_post, http_post_with_response, ApplicationError, ChatMessage, CompletionResponse, CompletionStats, - ConversationReader, Endpoints, HttpClient, ModelSpec, - ServerSpecTrait, ServerTrait, + ConversationReader, Endpoints, HttpClient, ModelSpec, ServerSpecTrait, + ServerTrait, }; pub const DEFAULT_COMPLETION_ENDPOINT: &str = "http://localhost:11434/api/chat"; @@ -75,15 +75,13 @@ impl ServerTrait for Ollama { ) -> Result<(), ApplicationError> { let identifier = reader.get_model_identifier()?; let model_name = identifier.get_model_name().to_string(); - let payload = OllamaShowPayload { - name: &model_name, - } - .serialize() - .ok_or_else(|| { - ApplicationError::ServerConfigurationError( - "Failed to serialize show payload".to_string(), - ) - })?; + let payload = OllamaShowPayload { name: &model_name } + .serialize() + .ok_or_else(|| { + ApplicationError::ServerConfigurationError( + "Failed to serialize show payload".to_string(), + ) + })?; let response = http_post_with_response( DEFAULT_SHOW_ENDPOINT.to_string(), @@ -145,9 +143,7 @@ impl ServerTrait for Ollama { Ok(()) } - async fn list_models( - &self, - ) -> Result, ApplicationError> { + async fn list_models(&self) -> Result, ApplicationError> { let list_models_endpoint = self.endpoints.get_list_models_endpoint()?; let response = http_get_with_response( list_models_endpoint.to_string(), @@ -173,8 +169,10 @@ impl ServerTrait for Ollama { .models .into_iter() .map(|model| { - let model_identifier = format!("{}::{}", "unknown", model.name.to_lowercase()); - let mut model_spec = ModelSpec::new_with_validation(&model_identifier)?; + let model_identifier = + format!("{}::{}", "unknown", model.name.to_lowercase()); + let mut model_spec = + ModelSpec::new_with_validation(&model_identifier)?; model_spec .set_size(model.size) @@ -183,7 +181,7 @@ impl ServerTrait for Ollama { "Parameter Size: {}", model.details.parameter_size )); - + Ok(model_spec) }) .collect(); diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs index e63e8db2..25d89c32 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs @@ -114,7 +114,6 @@ impl ServerTrait for OpenAI { tx: Option>, cancel_rx: Option>, ) -> Result<(), ApplicationError> { - let completion_endpoint = self.endpoints.get_completion_endpoint()?; let data_payload = self .completion_api_payload(model, messages.clone()) @@ -144,11 +143,9 @@ impl ServerTrait for OpenAI { Ok(()) } - async fn list_models( - &self, - ) -> Result, ApplicationError> { - Ok(vec![ - ModelSpec::new_with_validation("openai::gpt-3.5-turbo")?, - ]) + async fn list_models(&self) -> Result, ApplicationError> { + Ok(vec![ModelSpec::new_with_validation( + "openai::gpt-3.5-turbo", + )?]) } } diff 
--git a/lumni/src/apps/builtin/llm/prompt/src/session.rs b/lumni/src/apps/builtin/llm/prompt/src/session.rs index 40331dd1..21585b97 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/session.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/session.rs @@ -67,6 +67,10 @@ impl TabSession<'_> { } } + pub fn new_conversation(&mut self, chat: ChatSession) { + self.chat = chat; + } + pub fn draw_ui( &mut self, terminal: &mut Terminal, diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs index ea3af4e9..d0fa6958 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs @@ -33,7 +33,7 @@ pub fn draw_ui( // add borders to main_window[0] frame.render_widget( - main_widget(tab.chat.server_name(), window_hint()), + main_widget(&tab.chat.server_name().to_string(), window_hint()), main_window[0], ); diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/events/key_event.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/events/key_event.rs index 18422b7f..73de96be 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/events/key_event.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/events/key_event.rs @@ -225,7 +225,9 @@ impl KeyEventHandler { ApplicationError::NotReady(message) => { // pass as warning to the user log::debug!("Not ready: {:?}", message); - tab_ui.command_line.set_alert(&message); + tab_ui.command_line.set_alert( + &format!("Not Ready: {}", message), + ); return Ok(Some(WindowEvent::Modal( window_type, ))); diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/events/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/events/mod.rs index 95811281..c7e238f5 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/events/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/events/mod.rs @@ -13,7 +13,7 @@ use super::components::{LineType, MoveCursor, TextWindowTrait, WindowKind}; use super::modal::ModalWindowType; use super::ui::TabUi; use super::windows::PromptWindow; -use super::{ChatSession, ConversationReader, ConversationId}; +use super::{ChatSession, ConversationReader, NewConversation}; pub use crate::external as lumni; #[derive(Debug)] @@ -40,5 +40,5 @@ pub enum CommandLineAction { #[derive(Debug)] pub enum ConversationEvent { - New(ConversationId), + NewConversation(NewConversation), } diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/events/text_window_event.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/events/text_window_event.rs index 68f93e72..a2bc5cbc 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/events/text_window_event.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/events/text_window_event.rs @@ -98,9 +98,7 @@ where match window.get_kind() { WindowKind::ResponseWindow => Some(WindowEvent::ResponseWindow), WindowKind::PromptWindow => Some(WindowEvent::PromptWindow(None)), - WindowKind::CommandLine => { - Some(WindowEvent::CommandLine(None)) - } + WindowKind::CommandLine => Some(WindowEvent::CommandLine(None)), } } @@ -197,9 +195,9 @@ where } ':' => { // Switch to command line mode on ":" key press - return Some(WindowEvent::CommandLine(Some(CommandLineAction::Write( - ":".to_string(), - )))); + return Some(WindowEvent::CommandLine(Some( + CommandLineAction::Write(":".to_string()), + ))); } // ignore other characters _ => {} @@ -207,9 +205,7 @@ where match window.get_kind() { WindowKind::ResponseWindow => Some(WindowEvent::ResponseWindow), WindowKind::PromptWindow => Some(WindowEvent::PromptWindow(None)), - WindowKind::CommandLine => { - 
Some(WindowEvent::CommandLine(None)) - } + WindowKind::CommandLine => Some(WindowEvent::CommandLine(None)), } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/mod.rs index 1c8d4d46..13a6c875 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/mod.rs @@ -12,15 +12,17 @@ pub use colorscheme::{ColorScheme, ColorSchemeType}; pub use components::{TextWindowTrait, WindowKind}; pub use draw::draw_ui; pub use events::{ - CommandLineAction, KeyEventHandler, PromptAction, WindowEvent, - ConversationEvent, + CommandLineAction, ConversationEvent, KeyEventHandler, PromptAction, + WindowEvent, }; use lumni::api::error::ApplicationError; pub use modal::{ModalConfigWindow, ModalWindowTrait, ModalWindowType}; pub use ui::TabUi; pub use windows::{CommandLine, PromptWindow, ResponseWindow}; -pub use super::chat::{ChatSession, ConversationReader, ConversationId}; +pub use super::chat::{ + ChatSession, ConversationId, ConversationReader, ModelSpec, NewConversation, +}; pub use super::server::{ModelServer, ServerTrait, SUPPORTED_MODEL_ENDPOINTS}; pub use super::session::TabSession; pub use crate::external as lumni; diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/modal.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/modal.rs index a44d6b93..7727114d 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/modal.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/modal.rs @@ -8,8 +8,8 @@ use super::components::Scroller; use super::events::KeyTrack; use super::widgets::SelectEndpoint; use super::{ - ApplicationError, ChatSession, ConversationReader, ModelServer, - ServerTrait, WindowEvent, + ApplicationError, ChatSession, ConversationEvent, ConversationReader, + ModelServer, ModelSpec, NewConversation, ServerTrait, WindowEvent, }; #[derive(Debug, Clone, Copy, PartialEq)] @@ -73,9 +73,37 @@ impl ModalWindowTrait for ModalConfigWindow { KeyCode::Down => self.widget.key_down(), KeyCode::Enter => { let selected_server = self.widget.current_endpoint(); + // TODO: allow model selection, + check if model changes if selected_server != tab_chat.server_name() { let server = ModelServer::from_str(selected_server)?; - return tab_chat.change_server(Box::new(server), reader).await; + + match server.get_default_model().await { + Ok(model) => { + let new_conversation = NewConversation::new( + Box::new(server), + model, + reader, + )?; + // Return the new conversation event + return Ok(Some(WindowEvent::PromptWindow(Some( + ConversationEvent::NewConversation( + new_conversation, + ), + )))); + } + Err(ApplicationError::NotReady(e)) => { + // already a NotReady error + return Err(ApplicationError::NotReady(e)); + } + Err(e) => { + // ensure each error is converted to NotReady, + // with additional logging as its unexpected + log::error!("Error: {}", e); + return Err(ApplicationError::NotReady( + e.to_string(), + )); + } + } } return Ok(Some(WindowEvent::PromptWindow(None))); } diff --git a/lumni/src/http/client.rs b/lumni/src/http/client.rs index 35a91d60..6009e60a 100644 --- a/lumni/src/http/client.rs +++ b/lumni/src/http/client.rs @@ -56,15 +56,15 @@ impl fmt::Display for HttpClientError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { HttpClientError::ConnectionError(e) => { - write!(f, "Connection error: {}", e) + write!(f, "ConnectionError: {}", e) } - HttpClientError::TimeoutError => write!(f, "Timeout error"), + HttpClientError::TimeoutError => write!(f, "TimeoutError"), 
HttpClientError::HttpError(code, message) => { - write!(f, "HTTP error {}: {}", code, message) + write!(f, "HTTPError: {} {}", code, message) } - HttpClientError::Utf8Error(e) => write!(f, "UTF-8 error: {}", e), - HttpClientError::Other(e) => write!(f, "Other error: {}", e), - HttpClientError::RequestCancelled => write!(f, "Request cancelled"), + HttpClientError::Utf8Error(e) => write!(f, "Utf8Error: {}", e), + HttpClientError::Other(e) => write!(f, "Other: {}", e), + HttpClientError::RequestCancelled => write!(f, "RequestCancelled"), } } } @@ -155,6 +155,7 @@ impl HttpClient { tx: Option>, mut cancel_rx: Option>, ) -> HttpClientResult { + log::debug!("{} {}", method, url); let uri = Uri::from_str(url) .map_err(|e| HttpClientError::Other(e.to_string()))?; @@ -175,11 +176,10 @@ impl HttpClient { .body(request_body) .expect("Failed to build the request"); // Send the request and await the response, handling timeout as needed - let mut response = self - .client - .request(request) - .await - .map_err(|e| HttpClientError::ConnectionError(e.to_string()))?; + let mut response = + self.client.request(request).await.map_err(|_| { + HttpClientError::ConnectionError(url.to_string()) + })?; if !response.status().is_success() { let canonical_reason = response