From 39181040170a52b78ac924323541be2fd05ddcb0 Mon Sep 17 00:00:00 2001 From: Anthony Potappel Date: Sun, 4 Aug 2024 00:08:31 +0200 Subject: [PATCH] ensure conversation_id can be kept optional in chatsession-manager, use uuid for session keys --- lumni/Cargo.toml | 1 + lumni/src/apps/builtin/llm/prompt/src/app.rs | 13 ++- .../src/chat/conversation/instruction.rs | 7 +- .../src/chat/session/chat_session_manager.rs | 90 ++++++++++--------- .../src/chat/session/conversation_loop.rs | 12 +-- .../llm/prompt/src/chat/session/mod.rs | 12 +-- .../src/apps/builtin/llm/prompt/src/tui/ui.rs | 4 + 7 files changed, 78 insertions(+), 61 deletions(-) diff --git a/lumni/Cargo.toml b/lumni/Cargo.toml index fcd69e67..984bc497 100644 --- a/lumni/Cargo.toml +++ b/lumni/Cargo.toml @@ -52,6 +52,7 @@ lazy_static = { version = "1.5" } rayon = { version = "1.10" } crossbeam-channel = { version = "0.5" } globset = { version = "0.4" } +uuid = { version = "1.10.0", features = ["v4"] } # CLI env_logger = { version = "0.9", optional = true } diff --git a/lumni/src/apps/builtin/llm/prompt/src/app.rs b/lumni/src/apps/builtin/llm/prompt/src/app.rs index d1f608ec..abb518ed 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/app.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/app.rs @@ -153,7 +153,6 @@ pub async fn run_cli( // continue the conversation, otherwise start a new conversation let mut db_handler = db_conn.get_conversation_handler(None); - let conversation_id = db_conn.fetch_last_conversation_id().await?; let prompt_instruction = if let Some(conversation_id) = conversation_id { @@ -182,24 +181,24 @@ pub async fn run_cli( match poll(Duration::from_millis(0)) { Ok(_) => { // Starting interactive session - let app = - App::new(prompt_instruction, Arc::clone(&db_conn)).await?; - interactive_mode(app, db_conn).await + log::debug!("Starting interactive session"); + interactive_mode(prompt_instruction, db_conn).await } Err(_) => { // potential non-interactive input detected due to poll error. // attempt to use in non interactive mode - //let chat_session = ChatSession::new(prompt_instruction); + log::debug!("Starting non-interactive session"); process_non_interactive_input(prompt_instruction, db_conn).await } } } async fn interactive_mode( - app: App<'_>, + prompt_instruction: PromptInstruction, db_conn: Arc, ) -> Result<(), ApplicationError> { - println!("Interactive mode detected. Starting interactive session:"); + let app = + App::new(prompt_instruction, Arc::clone(&db_conn)).await?; let mut stdout = io::stdout().lock(); // Enable raw mode and setup the screen diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs index 778b699d..25663426 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs @@ -48,6 +48,7 @@ impl PromptInstruction { .await?, ) } else { + // Model is required to add conversation to the database None }; @@ -64,7 +65,11 @@ impl PromptInstruction { .set_conversation_id(conversation_id); } - if new_conversation.parent.is_none() { + if new_conversation.parent.is_some() || conversation_id.is_none() { + // if parent is provided, do not evaluate system_prompt and initial_messages + // as they are already evaluated in the parent + // If conversation_id is none, cant create system prompt or initial messages yet + } else { // evaluate system_prompt and initial_messages only if parent is not provided if let Some(messages) = new_conversation.initial_messages { let mut messages_to_insert = Vec::new(); diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/session/chat_session_manager.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/session/chat_session_manager.rs index e31afc69..8de035cd 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/session/chat_session_manager.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/session/chat_session_manager.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use uuid::Uuid; use lumni::api::error::ApplicationError; @@ -17,13 +18,14 @@ pub enum ChatEvent { } pub struct SessionInfo { - pub id: ConversationId, + pub id: Uuid, + pub conversation_id: Option, pub server_name: Option, } pub struct ChatSessionManager { - sessions: HashMap, - pub active_session_info: SessionInfo, // cache frequently accessed session info + sessions: HashMap, + pub active_session_info: SessionInfo, } #[allow(dead_code)] @@ -32,46 +34,55 @@ impl ChatSessionManager { initial_prompt_instruction: PromptInstruction, db_conn: Arc, ) -> Self { - let id = initial_prompt_instruction.get_conversation_id().unwrap(); - + let session_id = Uuid::new_v4(); + let conversation_id = initial_prompt_instruction.get_conversation_id(); let server_name = initial_prompt_instruction .get_completion_options() .model_server .as_ref() .map(|s| s.to_string()); - + let initial_session = ThreadedChatSession::new( initial_prompt_instruction, db_conn.clone(), ); - + let mut sessions = HashMap::new(); - sessions.insert(id.clone(), initial_session); + sessions.insert(session_id, initial_session); + Self { sessions, - active_session_info: SessionInfo { id, server_name }, + active_session_info: SessionInfo { + id: session_id, + conversation_id, + server_name + }, } } - pub fn get_active_session(&mut self) -> &mut ThreadedChatSession { - self.sessions.get_mut(&self.active_session_info.id).unwrap() + pub fn get_active_session(&mut self) -> Result<&mut ThreadedChatSession, ApplicationError> { + self.sessions.get_mut(&self.active_session_info.id).ok_or_else(|| + ApplicationError::Runtime("Active session not found".to_string()) + ) } - pub fn get_active_session_id(&self) -> &ConversationId { - &self.active_session_info.id + pub fn get_conversation_id_for_active_session(&self) -> Option { + self.active_session_info.conversation_id + } + + pub fn get_active_session_id(&self) -> Uuid { + self.active_session_info.id } pub async fn stop_session( &mut self, - id: &ConversationId, + id: &Uuid, ) -> Result<(), ApplicationError> { if let Some(session) = self.sessions.remove(id) { session.stop(); Ok(()) } else { - Err(ApplicationError::InvalidInput( - "Session not found".to_string(), - )) + Err(ApplicationError::InvalidInput("Session not found".to_string())) } } @@ -85,46 +96,43 @@ impl ChatSessionManager { &mut self, prompt_instruction: PromptInstruction, db_conn: Arc, - ) -> Result { - let id = prompt_instruction.get_conversation_id().ok_or_else(|| { - ApplicationError::Runtime( - "Failed to get conversation ID".to_string(), - ) - })?; + ) -> Uuid { + let session_id = Uuid::new_v4(); let new_session = ThreadedChatSession::new(prompt_instruction, db_conn); - self.sessions.insert(id.clone(), new_session); - Ok(id) + self.sessions.insert(session_id, new_session); + session_id } pub async fn set_active_session( &mut self, - id: ConversationId, + id: Uuid, ) -> Result<(), ApplicationError> { - if self.sessions.contains_key(&id) { - self.active_session_info.id = id; - self.active_session_info.server_name = self - .sessions - .get(&id) - .unwrap() - .get_instruction() - .await? + if let Some(session) = self.sessions.get(&id) { + let instruction = session.get_instruction().await?; + let conversation_id = instruction.get_conversation_id(); + let server_name = instruction .get_completion_options() .model_server .as_ref() .map(|s| s.to_string()); + + self.active_session_info = SessionInfo { + id, + conversation_id, + server_name, + }; Ok(()) } else { - Err(ApplicationError::InvalidInput( - "Session not found".to_string(), - )) + Err(ApplicationError::InvalidInput("Session not found".to_string())) } } - pub fn stop_active_chat_session(&mut self) { - if let Some(session) = - self.sessions.get_mut(&self.active_session_info.id) - { + pub fn stop_active_chat_session(&mut self) -> Result<(), ApplicationError> { + if let Some(session) = self.sessions.get_mut(&self.active_session_info.id) { session.stop(); + Ok(()) + } else { + Err(ApplicationError::Runtime("Active session not found".to_string())) } } } diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/session/conversation_loop.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/session/conversation_loop.rs index b56622d3..0f756b16 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/session/conversation_loop.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/session/conversation_loop.rs @@ -31,9 +31,9 @@ pub async fn prompt_app( let mut current_mode = Some(WindowEvent::ResponseWindow); let mut key_event_handler = KeyEventHandler::new(); let mut redraw_ui = true; - let conversation_id = app.get_conversation_id_for_active_session().clone(); + let conversation_id = app.get_conversation_id_for_active_session(); let mut db_handler = - db_conn.get_conversation_handler(Some(conversation_id)); + db_conn.get_conversation_handler(conversation_id); loop { tokio::select! { @@ -42,7 +42,7 @@ pub async fn prompt_app( } _ = async { // Process chat events - let events = app.chat_manager.get_active_session().subscribe().recv().await; + let events = app.chat_manager.get_active_session()?.subscribe().recv().await; if let Ok(event) = events { match event { ChatEvent::ResponseUpdate(content) => { @@ -121,7 +121,7 @@ async fn handle_key_event( .process_key( key_event, &mut app.ui, - app.chat_manager.get_active_session(), + app.chat_manager.get_active_session()?, mode, keep_running.clone(), db_handler, @@ -170,7 +170,7 @@ async fn handle_prompt_action( send_prompt(app, &prompt, color_scheme).await?; } PromptAction::Stop => { - app.stop_active_chat_session().await; + app.stop_active_chat_session().await?; app.ui .response .text_append("\n", Some(color_scheme.get_secondary_style()))?; @@ -261,7 +261,7 @@ async fn send_prompt<'a>( let formatted_prompt = format!("{}\n", prompt.trim_end()); let result = app .chat_manager - .get_active_session() + .get_active_session()? .message(&formatted_prompt) .await; diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs index 07e8d363..ef4779ab 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs @@ -64,7 +64,7 @@ impl App<'_> { ) -> Result<(), ApplicationError> { let prompt_instruction = self .chat_manager - .get_active_session() + .get_active_session()? .get_instruction() .await?; @@ -93,12 +93,12 @@ impl App<'_> { Ok(()) } - pub fn get_conversation_id_for_active_session(&self) -> &ConversationId { - self.chat_manager.get_active_session_id() + pub fn get_conversation_id_for_active_session(&self) -> Option { + self.chat_manager.get_conversation_id_for_active_session() } - pub async fn stop_active_chat_session(&mut self) { - self.chat_manager.stop_active_chat_session(); + pub async fn stop_active_chat_session(&mut self) -> Result<(), ApplicationError> { + self.chat_manager.stop_active_chat_session() } pub async fn load_instruction_for_active_session( @@ -106,7 +106,7 @@ impl App<'_> { prompt_instruction: PromptInstruction, ) -> Result<(), ApplicationError> { self.chat_manager - .get_active_session() + .get_active_session()? .load_instruction(prompt_instruction) .await } diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/ui.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/ui.rs index 6b9be0e8..6e9618a0 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/ui.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/ui.rs @@ -40,6 +40,10 @@ impl AppUi<'_> { self.command_line.init(); // initialize with defaults } + pub fn set_alert(&mut self, message: &str) -> Result<(), ApplicationError> { + self.command_line.set_alert(message) + } + pub async fn set_new_modal( &mut self, modal_type: ModalWindowType,