Skip to content

Commit

Permalink
minor improvement on getting server_name in loop
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 2, 2024
1 parent 97c1345 commit aac25f7
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 47 deletions.
12 changes: 0 additions & 12 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session/chat_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,6 @@ impl ThreadedChatSession {
self.event_receiver.resubscribe()
}

pub async fn server_name(&self) -> Result<String, ApplicationError> {
let instruction = self.get_instruction().await?;
instruction
.get_completion_options()
.model_server
.as_ref()
.map(|s| s.to_string())
.ok_or_else(|| {
ApplicationError::NotReady("Server not initialized".to_string())
})
}

pub async fn load_instruction(
&self,
prompt_instruction: PromptInstruction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ pub enum ChatEvent {
Error(String),
}

pub struct SessionInfo {
pub id: ConversationId,
pub server_name: Option<String>,
}

pub struct ChatSessionManager {
sessions: HashMap<ConversationId, ThreadedChatSession>,
active_session_id: ConversationId,
db_conn: Arc<ConversationDatabase>,
pub active_session_info: SessionInfo, // cache frequently accessed session info
}

impl ChatSessionManager {
Expand All @@ -28,6 +33,13 @@ impl ChatSessionManager {
db_conn: Arc<ConversationDatabase>,
) -> Self {
let id = initial_prompt_instruction.get_conversation_id().unwrap();

let server_name = initial_prompt_instruction
.get_completion_options()
.model_server
.as_ref()
.map(|s| s.to_string());

let initial_session = ThreadedChatSession::new(
initial_prompt_instruction,
db_conn.clone(),
Expand All @@ -37,28 +49,17 @@ impl ChatSessionManager {
sessions.insert(id.clone(), initial_session);
Self {
sessions,
active_session_id: id,
db_conn,
active_session_info: SessionInfo { id, server_name },
}
}

pub fn get_active_session(&mut self) -> &mut ThreadedChatSession {
self.sessions.get_mut(&self.active_session_id).unwrap()
self.sessions.get_mut(&self.active_session_info.id).unwrap()
}

pub fn get_active_session_id(&self) -> &ConversationId {
&self.active_session_id
}

pub async fn process_events(&mut self) -> Vec<ChatEvent> {
let mut events = Vec::new();
if let Some(session) = self.sessions.get_mut(&self.active_session_id) {
let mut receiver = session.subscribe();
while let Ok(event) = receiver.try_recv() {
events.push(event);
}
}
events
&self.active_session_info.id
}

pub async fn stop_session(
Expand Down Expand Up @@ -96,12 +97,22 @@ impl ChatSessionManager {
Ok(id)
}

pub fn set_active_session(
pub async fn set_active_session(
&mut self,
id: ConversationId,
) -> Result<(), ApplicationError> {
if self.sessions.contains_key(&id) {
self.active_session_id = id;
self.active_session_info.id = id;
self.active_session_info.server_name = self
.sessions
.get(&id)
.unwrap()
.get_instruction()
.await?
.get_completion_options()
.model_server
.as_ref()
.map(|s| s.to_string());
Ok(())
} else {
Err(ApplicationError::InvalidInput(
Expand All @@ -111,7 +122,9 @@ impl ChatSessionManager {
}

pub fn stop_active_chat_session(&mut self) {
if let Some(session) = self.sessions.get_mut(&self.active_session_id) {
if let Some(session) =
self.sessions.get_mut(&self.active_session_info.id)
{
session.stop();
}
}
Expand Down
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use lumni::api::error::ApplicationError;
use ratatui::backend::Backend;
use ratatui::Terminal;

use super::db::{ConversationDatabase, ConversationDbHandler, ConversationId};
use super::db::{ConversationDatabase, ConversationId};
use super::{
db, draw_ui, AppUi, ColorScheme, ColorSchemeType, CommandLineAction,
CompletionResponse, ConversationEvent, KeyEventHandler, ModalWindowType,
Expand Down
11 changes: 5 additions & 6 deletions lumni/src/apps/builtin/llm/prompt/src/tui/draw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@ pub async fn draw_ui<B: Backend>(
terminal: &mut Terminal<B>,
app: &mut App<'_>,
) -> Result<(), io::Error> {
let server_name = &app
let server_name = app
.chat_manager
.get_active_session()
.server_name()
.await
.active_session_info
.server_name
.as_deref()
.unwrap_or_default();
//let server_name = "TODO: get server name";

terminal.draw(|frame| {
let terminal_area = frame.size();
Expand All @@ -41,7 +40,7 @@ pub async fn draw_ui<B: Backend>(

// add borders to main_window[0]
frame.render_widget(
main_widget(&server_name, window_hint()),
main_widget(server_name, window_hint()),
main_window[0],
);

Expand Down
3 changes: 1 addition & 2 deletions lumni/src/apps/builtin/llm/prompt/src/tui/events/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ use super::window::{
WindowKind,
};
use super::{
ChatSession, ConversationDbHandler, NewConversation, PromptInstruction,
ThreadedChatSession,
ChatSession, ConversationDbHandler, NewConversation, ThreadedChatSession,
};
pub use crate::external as lumni;

Expand Down
22 changes: 14 additions & 8 deletions lumni/src/apps/builtin/llm/prompt/src/tui/modals/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use super::{
ServerManager, ServerTrait, ThreadedChatSession, WindowEvent,
SUPPORTED_MODEL_ENDPOINTS,
};
use crate::apps::builtin::llm::prompt::src::server;

pub struct SelectEndpointModal {
widget: SelectEndpoint,
Expand Down Expand Up @@ -60,14 +61,19 @@ impl ModalWindowTrait for SelectEndpointModal {
let selected_server = self.widget.current_endpoint();
// TODO: allow model selection, + check if model changes
// TODO: catch ApplicationError::NotReady, if it is assume selected_server != tab_chat.server_name()
let should_create_new_conversation =
match tab_chat.server_name().await {
Ok(current_server_name) => {
selected_server != current_server_name
}
Err(ApplicationError::NotReady(_)) => true, // Assume new server if NotReady
Err(e) => return Err(e), // Propagate other errors
};
let instruction = tab_chat.get_instruction().await?;
let server_name = instruction
.get_completion_options()
.model_server
.as_ref()
.map(|s| s.to_string());

let should_create_new_conversation = match server_name {
Some(current_server_name) => {
selected_server != current_server_name
}
None => true, // Assume new server if no current server
};

let event = if should_create_new_conversation {
let server = ModelServer::from_str(selected_server)?;
Expand Down

0 comments on commit aac25f7

Please sign in to comment.