Skip to content

Commit

Permalink
wip - ui updates are now passed over mpsc channel to prep for future …
Browse files Browse the repository at this point in the history
…threading
  • Loading branch information
aprxi committed Jul 31, 2024
1 parent 0b79b78 commit 10e840c
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 32 deletions.
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ pub async fn run_cli(
match poll(Duration::from_millis(0)) {
Ok(_) => {
// Starting interactive session
let app = App::new(chat_session)?;
let app = App::new(chat_session).await?;
interactive_mode(app, db_conn).await
}
Err(_) => {
Expand Down
69 changes: 44 additions & 25 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session/chat_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use bytes::Bytes;
use tokio::sync::{mpsc, oneshot, Mutex};

use super::db::{ConversationDbHandler, ConversationId};
use super::chat_session_manager::UiUpdate;
use super::{
AppUi, TextWindowTrait,
ColorScheme, CompletionResponse, ModelServer, PromptInstruction,
ServerManager, TextLine,
};
Expand Down Expand Up @@ -67,6 +67,7 @@ pub struct ChatSession {
model_server_session: ModelServerSession,
response_sender: mpsc::Sender<Bytes>,
response_receiver: mpsc::Receiver<Bytes>,
ui_sender: Arc<Mutex<Option<mpsc::Sender<UiUpdate>>>>,
}

impl ChatSession {
Expand All @@ -78,9 +79,15 @@ impl ChatSession {
model_server_session: ModelServerSession::new(),
response_sender,
response_receiver,
ui_sender: Arc::new(Mutex::new(None)),
}
}

pub async fn set_ui_sender(&self, sender: Option<mpsc::Sender<UiUpdate>>) {
let mut ui_sender = self.ui_sender.lock().await;
*ui_sender = sender;
}

pub async fn load_instruction(
&mut self,
prompt_instruction: PromptInstruction,
Expand Down Expand Up @@ -301,11 +308,10 @@ impl ChatSession {
&mut self,
db_handler: &mut ConversationDbHandler<'a>,
color_scheme: Option<&ColorScheme>,
ui: Option<&mut AppUi<'a>>,
) -> Result<bool, ApplicationError> {
match self.receive_response().await? {
Some(response) => {
self.process_chat_response(response, db_handler, color_scheme, ui).await?;
self.process_chat_response(response, db_handler, color_scheme).await?;
Ok(true) // Indicates that a response was processed
}
None => {
Expand All @@ -330,32 +336,37 @@ impl ChatSession {
response: CompletionResponse,
db_handler: &mut ConversationDbHandler<'a>,
color_scheme: Option<&ColorScheme>,
mut ui: Option<&mut AppUi<'a>>,
) -> Result<(), ApplicationError> {
log::debug!(
"Received response with length {:?}",
response.get_content().len()
);

let style = if let Some(color_scheme) = color_scheme {
color_scheme.get_secondary_style()
} else {
Style::default()
};

let trimmed_response = response.get_content().trim_end().to_string();
log::debug!("Trimmed response: {:?}", trimmed_response);

if !trimmed_response.is_empty() {
self.update_last_exchange(&trimmed_response);

// Update UI if provided
if let (Some(ui), Some(color_scheme)) = (ui.as_mut(), color_scheme) {
ui.response
.text_append(
&trimmed_response,
Some(color_scheme.get_secondary_style()),
)
.map_err(|e| ApplicationError::Runtime(e.to_string()))?;
// Send UI update if ui_sender is set
let ui_sender = self.ui_sender.lock().await;
if let Some(sender) = &*ui_sender {
let update = UiUpdate {
content: trimmed_response.clone(),
style: Some(style),
};
sender.send(update).await.map_err(|e| ApplicationError::Runtime(e.to_string()))?;
}
}

if response.is_final {
self.finalize_chat_response(response, db_handler, color_scheme, ui).await?;
self.finalize_chat_response(response, db_handler, color_scheme).await?;
}

Ok(())
Expand All @@ -366,22 +377,30 @@ impl ChatSession {
response: CompletionResponse,
db_handler: &mut ConversationDbHandler<'a>,
color_scheme: Option<&ColorScheme>,
mut ui: Option<&mut AppUi<'a>>,
) -> Result<(), ApplicationError> {
let tokens_predicted = response.stats.as_ref().and_then(|s| s.tokens_predicted);

self.stop_chat_session();

if let (Some(ui), Some(color_scheme)) = (ui.as_mut(), color_scheme) {
ui.response
.text_append("\n", Some(color_scheme.get_secondary_style()))
.map_err(|e| ApplicationError::Runtime(e.to_string()))?;

ui.response
.text_append("\n", Some(Style::reset()))
.map_err(|e| ApplicationError::Runtime(e.to_string()))?;
}
// Send UI updates for newlines
{
let ui_sender = self.ui_sender.lock().await;
if let (Some(sender), Some(color_scheme)) = (&*ui_sender, color_scheme) {
// First newline with secondary style
let update1 = UiUpdate {
content: "\n".to_string(),
style: Some(color_scheme.get_secondary_style()),
};
sender.send(update1).await.map_err(|e| ApplicationError::Runtime(e.to_string()))?;

// Second newline with reset style
let update2 = UiUpdate {
content: "\n".to_string(),
style: Some(Style::reset()),
};
sender.send(update2).await.map_err(|e| ApplicationError::Runtime(e.to_string()))?;
}
} // The MutexGuard is dropped here, releasing the immutable borrow

self.stop_chat_session();
self.finalize_last_exchange(db_handler, tokens_predicted).await?;

Ok(())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,55 @@
use std::collections::HashMap;

use lumni::api::error::ApplicationError;
use ratatui::style::Style;

use tokio::sync::mpsc;

use super::db::ConversationId;
use super::ChatSession;
use super::{App, ChatSession, TextWindowTrait};
pub use crate::external as lumni;

pub struct ChatSessionManager {
sessions: HashMap<ConversationId, ChatSession>,
active_session_id: ConversationId,
ui_sender: mpsc::Sender<UiUpdate>,
ui_receiver: mpsc::Receiver<UiUpdate>,
}

pub struct UiUpdate {
pub content: String,
pub style: Option<Style>,
}

impl ChatSessionManager {
pub fn new(initial_session: ChatSession) -> Self {
pub async fn new(initial_session: ChatSession) -> Self {
let id = initial_session.get_conversation_id().unwrap();
let mut sessions = HashMap::new();
let (ui_sender, ui_receiver) = mpsc::channel(100);

initial_session.set_ui_sender(Some(ui_sender.clone())).await;
sessions.insert(id.clone(), initial_session);

Self {
sessions,
active_session_id: id,
ui_sender,
ui_receiver,
}
}

pub fn switch_active_session(&mut self, id: ConversationId) -> Result<(), ApplicationError> {
pub async fn switch_active_session(&mut self, id: ConversationId) -> Result<(), ApplicationError> {
if self.sessions.contains_key(&id) {
// Remove UI sender from the previous active session
if let Some(prev_session) = self.sessions.get_mut(&self.active_session_id) {
prev_session.set_ui_sender(None).await;
}

// Set UI sender for the new active session
if let Some(new_session) = self.sessions.get_mut(&id) {
new_session.set_ui_sender(Some(self.ui_sender.clone())).await;
}

self.active_session_id = id;
Ok(())
} else {
Expand All @@ -34,4 +60,12 @@ impl ChatSessionManager {
pub fn get_active_session(&mut self) -> &mut ChatSession {
self.sessions.get_mut(&self.active_session_id).unwrap()
}

pub fn process_ui_updates(&mut self) -> Vec<UiUpdate> {
let mut updates = Vec::new();
while let Ok(update) = self.ui_receiver.try_recv() {
updates.push(update);
}
updates
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ pub async fn prompt_app<B: Backend>(
result = app.chat_manager.get_active_session().receive_and_process_response(
&mut db_handler,
Some(&app.color_scheme),
Some(&mut app.ui)
) => {
match result {
Ok(true) => {
let updates = app.chat_manager.process_ui_updates();
for update in updates {
app.ui.response.text_append(&update.content, update.style)?;
}
redraw_ui = true;
},
Ok(false) => {
Expand Down
4 changes: 2 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub struct App<'a> {
}

impl App<'_> {
pub fn new(
pub async fn new(
initial_chat_session: ChatSession,
) -> Result<Self, ApplicationError> {
let color_scheme = ColorScheme::new(ColorSchemeType::Default);
Expand All @@ -42,7 +42,7 @@ impl App<'_> {

Ok(App {
ui,
chat_manager: ChatSessionManager::new(initial_chat_session),
chat_manager: ChatSessionManager::new(initial_chat_session).await,
color_scheme,
})
}
Expand Down

0 comments on commit 10e840c

Please sign in to comment.