Skip to content

Commit

Permalink
ensure conversation_id can be kept optional in chatsession-manager, u…
Browse files Browse the repository at this point in the history
…se uuid for session keys
  • Loading branch information
aprxi committed Aug 3, 2024
1 parent 1f0d89b commit 3918104
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 61 deletions.
1 change: 1 addition & 0 deletions lumni/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ lazy_static = { version = "1.5" }
rayon = { version = "1.10" }
crossbeam-channel = { version = "0.5" }
globset = { version = "0.4" }
uuid = { version = "1.10.0", features = ["v4"] }

# CLI
env_logger = { version = "0.9", optional = true }
Expand Down
13 changes: 6 additions & 7 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ pub async fn run_cli(
// continue the conversation, otherwise start a new conversation

let mut db_handler = db_conn.get_conversation_handler(None);

let conversation_id = db_conn.fetch_last_conversation_id().await?;

let prompt_instruction = if let Some(conversation_id) = conversation_id {
Expand Down Expand Up @@ -182,24 +181,24 @@ pub async fn run_cli(
match poll(Duration::from_millis(0)) {
Ok(_) => {
// Starting interactive session
let app =
App::new(prompt_instruction, Arc::clone(&db_conn)).await?;
interactive_mode(app, db_conn).await
log::debug!("Starting interactive session");
interactive_mode(prompt_instruction, db_conn).await
}
Err(_) => {
// potential non-interactive input detected due to poll error.
// attempt to use in non interactive mode
//let chat_session = ChatSession::new(prompt_instruction);
log::debug!("Starting non-interactive session");
process_non_interactive_input(prompt_instruction, db_conn).await
}
}
}

async fn interactive_mode(
app: App<'_>,
prompt_instruction: PromptInstruction,
db_conn: Arc<ConversationDatabase>,
) -> Result<(), ApplicationError> {
println!("Interactive mode detected. Starting interactive session:");
let app =
App::new(prompt_instruction, Arc::clone(&db_conn)).await?;
let mut stdout = io::stdout().lock();

// Enable raw mode and setup the screen
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl PromptInstruction {
.await?,
)
} else {
// Model is required to add conversation to the database
None
};

Expand All @@ -64,7 +65,11 @@ impl PromptInstruction {
.set_conversation_id(conversation_id);
}

if new_conversation.parent.is_none() {
if new_conversation.parent.is_some() || conversation_id.is_none() {
// if parent is provided, do not evaluate system_prompt and initial_messages
// as they are already evaluated in the parent
// If conversation_id is none, cant create system prompt or initial messages yet
} else {
// evaluate system_prompt and initial_messages only if parent is not provided
if let Some(messages) = new_conversation.initial_messages {
let mut messages_to_insert = Vec::new();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;

use lumni::api::error::ApplicationError;

Expand All @@ -17,13 +18,14 @@ pub enum ChatEvent {
}

pub struct SessionInfo {
pub id: ConversationId,
pub id: Uuid,
pub conversation_id: Option<ConversationId>,
pub server_name: Option<String>,
}

pub struct ChatSessionManager {
sessions: HashMap<ConversationId, ThreadedChatSession>,
pub active_session_info: SessionInfo, // cache frequently accessed session info
sessions: HashMap<Uuid, ThreadedChatSession>,
pub active_session_info: SessionInfo,
}

#[allow(dead_code)]
Expand All @@ -32,46 +34,55 @@ impl ChatSessionManager {
initial_prompt_instruction: PromptInstruction,
db_conn: Arc<ConversationDatabase>,
) -> Self {
let id = initial_prompt_instruction.get_conversation_id().unwrap();

let session_id = Uuid::new_v4();
let conversation_id = initial_prompt_instruction.get_conversation_id();
let server_name = initial_prompt_instruction
.get_completion_options()
.model_server
.as_ref()
.map(|s| s.to_string());

let initial_session = ThreadedChatSession::new(
initial_prompt_instruction,
db_conn.clone(),
);

let mut sessions = HashMap::new();
sessions.insert(id.clone(), initial_session);
sessions.insert(session_id, initial_session);

Self {
sessions,
active_session_info: SessionInfo { id, server_name },
active_session_info: SessionInfo {
id: session_id,
conversation_id,
server_name
},
}
}

pub fn get_active_session(&mut self) -> &mut ThreadedChatSession {
self.sessions.get_mut(&self.active_session_info.id).unwrap()
pub fn get_active_session(&mut self) -> Result<&mut ThreadedChatSession, ApplicationError> {
self.sessions.get_mut(&self.active_session_info.id).ok_or_else(||
ApplicationError::Runtime("Active session not found".to_string())
)
}

pub fn get_active_session_id(&self) -> &ConversationId {
&self.active_session_info.id
pub fn get_conversation_id_for_active_session(&self) -> Option<ConversationId> {
self.active_session_info.conversation_id
}

pub fn get_active_session_id(&self) -> Uuid {
self.active_session_info.id
}

pub async fn stop_session(
&mut self,
id: &ConversationId,
id: &Uuid,
) -> Result<(), ApplicationError> {
if let Some(session) = self.sessions.remove(id) {
session.stop();
Ok(())
} else {
Err(ApplicationError::InvalidInput(
"Session not found".to_string(),
))
Err(ApplicationError::InvalidInput("Session not found".to_string()))
}
}

Expand All @@ -85,46 +96,43 @@ impl ChatSessionManager {
&mut self,
prompt_instruction: PromptInstruction,
db_conn: Arc<ConversationDatabase>,
) -> Result<ConversationId, ApplicationError> {
let id = prompt_instruction.get_conversation_id().ok_or_else(|| {
ApplicationError::Runtime(
"Failed to get conversation ID".to_string(),
)
})?;
) -> Uuid {
let session_id = Uuid::new_v4();
let new_session = ThreadedChatSession::new(prompt_instruction, db_conn);
self.sessions.insert(id.clone(), new_session);
Ok(id)
self.sessions.insert(session_id, new_session);
session_id
}

pub async fn set_active_session(
&mut self,
id: ConversationId,
id: Uuid,
) -> Result<(), ApplicationError> {
if self.sessions.contains_key(&id) {
self.active_session_info.id = id;
self.active_session_info.server_name = self
.sessions
.get(&id)
.unwrap()
.get_instruction()
.await?
if let Some(session) = self.sessions.get(&id) {
let instruction = session.get_instruction().await?;
let conversation_id = instruction.get_conversation_id();
let server_name = instruction
.get_completion_options()
.model_server
.as_ref()
.map(|s| s.to_string());

self.active_session_info = SessionInfo {
id,
conversation_id,
server_name,
};
Ok(())
} else {
Err(ApplicationError::InvalidInput(
"Session not found".to_string(),
))
Err(ApplicationError::InvalidInput("Session not found".to_string()))
}
}

pub fn stop_active_chat_session(&mut self) {
if let Some(session) =
self.sessions.get_mut(&self.active_session_info.id)
{
pub fn stop_active_chat_session(&mut self) -> Result<(), ApplicationError> {
if let Some(session) = self.sessions.get_mut(&self.active_session_info.id) {
session.stop();
Ok(())
} else {
Err(ApplicationError::Runtime("Active session not found".to_string()))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ pub async fn prompt_app<B: Backend>(
let mut current_mode = Some(WindowEvent::ResponseWindow);
let mut key_event_handler = KeyEventHandler::new();
let mut redraw_ui = true;
let conversation_id = app.get_conversation_id_for_active_session().clone();
let conversation_id = app.get_conversation_id_for_active_session();
let mut db_handler =
db_conn.get_conversation_handler(Some(conversation_id));
db_conn.get_conversation_handler(conversation_id);

loop {
tokio::select! {
Expand All @@ -42,7 +42,7 @@ pub async fn prompt_app<B: Backend>(
}
_ = async {
// Process chat events
let events = app.chat_manager.get_active_session().subscribe().recv().await;
let events = app.chat_manager.get_active_session()?.subscribe().recv().await;
if let Ok(event) = events {
match event {
ChatEvent::ResponseUpdate(content) => {
Expand Down Expand Up @@ -121,7 +121,7 @@ async fn handle_key_event(
.process_key(
key_event,
&mut app.ui,
app.chat_manager.get_active_session(),
app.chat_manager.get_active_session()?,
mode,
keep_running.clone(),
db_handler,
Expand Down Expand Up @@ -170,7 +170,7 @@ async fn handle_prompt_action(
send_prompt(app, &prompt, color_scheme).await?;
}
PromptAction::Stop => {
app.stop_active_chat_session().await;
app.stop_active_chat_session().await?;
app.ui
.response
.text_append("\n", Some(color_scheme.get_secondary_style()))?;
Expand Down Expand Up @@ -261,7 +261,7 @@ async fn send_prompt<'a>(
let formatted_prompt = format!("{}\n", prompt.trim_end());
let result = app
.chat_manager
.get_active_session()
.get_active_session()?
.message(&formatted_prompt)
.await;

Expand Down
12 changes: 6 additions & 6 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl App<'_> {
) -> Result<(), ApplicationError> {
let prompt_instruction = self
.chat_manager
.get_active_session()
.get_active_session()?
.get_instruction()
.await?;

Expand Down Expand Up @@ -93,20 +93,20 @@ impl App<'_> {
Ok(())
}

pub fn get_conversation_id_for_active_session(&self) -> &ConversationId {
self.chat_manager.get_active_session_id()
pub fn get_conversation_id_for_active_session(&self) -> Option<ConversationId> {
self.chat_manager.get_conversation_id_for_active_session()
}

pub async fn stop_active_chat_session(&mut self) {
self.chat_manager.stop_active_chat_session();
pub async fn stop_active_chat_session(&mut self) -> Result<(), ApplicationError> {
self.chat_manager.stop_active_chat_session()
}

pub async fn load_instruction_for_active_session(
&mut self,
prompt_instruction: PromptInstruction,
) -> Result<(), ApplicationError> {
self.chat_manager
.get_active_session()
.get_active_session()?
.load_instruction(prompt_instruction)
.await
}
Expand Down
4 changes: 4 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/tui/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ impl AppUi<'_> {
self.command_line.init(); // initialize with defaults
}

pub fn set_alert(&mut self, message: &str) -> Result<(), ApplicationError> {
self.command_line.set_alert(message)
}

pub async fn set_new_modal(
&mut self,
modal_type: ModalWindowType,
Expand Down

0 comments on commit 3918104

Please sign in to comment.