Skip to content

Commit

Permalink
use threaded chatsession also for non-interactive input
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 2, 2024
1 parent aac25f7 commit a60adee
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 28 deletions.
45 changes: 33 additions & 12 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io;
use std::io::{self, Write};
use std::sync::Arc;

use clap::{Arg, Command};
Expand All @@ -21,8 +21,8 @@ use tokio::time::{timeout, Duration};

use super::chat::db::ConversationDatabase;
use super::chat::{
prompt_app, App, AssistantManager, ChatSession, NewConversation,
PromptInstruction,
prompt_app, App, AssistantManager, ChatEvent, NewConversation,
PromptInstruction, ThreadedChatSession,
};
use super::server::{ModelServer, ModelServerName, ServerTrait};
pub use crate::external as lumni;
Expand Down Expand Up @@ -189,8 +189,8 @@ pub async fn run_cli(
Err(_) => {
// potential non-interactive input detected due to poll error.
// attempt to use in non interactive mode
let chat_session = ChatSession::new(prompt_instruction);
process_non_interactive_input(chat_session, db_conn).await
//let chat_session = ChatSession::new(prompt_instruction);
process_non_interactive_input(prompt_instruction, db_conn).await
}
}
}
Expand Down Expand Up @@ -242,10 +242,13 @@ async fn interactive_mode(
}

async fn process_non_interactive_input(
chat: ChatSession,
prompt_instruction: PromptInstruction,
db_conn: Arc<ConversationDatabase>,
) -> Result<(), ApplicationError> {
let chat = Arc::new(Mutex::new(chat));
let chat = Arc::new(Mutex::new(ThreadedChatSession::new(
prompt_instruction,
db_conn.clone(),
)));
let stdin = tokio::io::stdin();
let mut reader = BufReader::new(stdin);
let mut stdin_input = String::new();
Expand All @@ -272,14 +275,28 @@ async fn process_non_interactive_input(
stdin_input.push('\n'); // Maintain line breaks
}

let chat_clone = chat.clone();
let input = stdin_input.trim_end().to_string();
let chat_clone = chat.clone();

// Process the prompt
let process_handle = tokio::spawn(async move {
let db_handler = db_conn.get_conversation_handler(None);
let mut chat = chat_clone.lock().await;
chat.process_prompt(input, running.clone(), &db_handler)
.await
chat_clone.lock().await.message(&input).await?;

let mut receiver = chat_clone.lock().await.subscribe();
while let Ok(event) = receiver.recv().await {
match event {
ChatEvent::ResponseUpdate(content) => {
print!("{}", content);
std::io::stdout().flush().unwrap();
}
ChatEvent::FinalResponse => break,
ChatEvent::Error(e) => {
return Err(ApplicationError::Unexpected(e));
}
}
}

Ok(())
});

// Wait for the process to complete or for a shutdown signal
Expand All @@ -294,10 +311,12 @@ async fn process_non_interactive_input(
"Processing completed successfully during \
shutdown."
);
chat.lock().await.stop();
return Ok(());
}
Ok(Err(e)) => {
eprintln!("Process error during shutdown: {}", e);
chat.lock().await.stop();
return Err(ApplicationError::Unexpected(format!(
"Process error: {}",
e
Expand All @@ -307,6 +326,7 @@ async fn process_non_interactive_input(
eprintln!(
"Graceful shutdown timed out. Forcing exit..."
);
chat.lock().await.stop();
return Ok(());
}
}
Expand All @@ -328,6 +348,7 @@ async fn process_non_interactive_input(
e
))
})?;
chat.lock().await.stop();
return Ok(());
}

Expand Down
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use completion_options::ChatCompletionOptions;
pub use conversation::{ConversationCache, NewConversation, PromptInstruction};
use prompt::Prompt;
pub use prompt::{AssistantManager, PromptRole};
pub use session::{prompt_app, App, ChatSession, ThreadedChatSession};
pub use session::{prompt_app, App, ChatEvent, ThreadedChatSession};

pub use super::defaults::*;
pub use super::server::{CompletionResponse, ModelServer, ServerManager};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ impl ThreadedChatSession {
}
}

pub struct ChatSession {
struct ChatSession {
prompt_instruction: PromptInstruction,
model_server_session: ModelServerSession,
response_sender: mpsc::Sender<Bytes>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use super::db::{ConversationDatabase, ConversationId};
use super::PromptInstruction;
pub use crate::external as lumni;

// add clone
#[derive(Clone)]
pub enum ChatEvent {
ResponseUpdate(String),
Expand Down
4 changes: 2 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ mod conversation_loop;
use std::io;
use std::sync::Arc;

pub use chat_session::{ChatSession, ThreadedChatSession};
pub use chat_session_manager::ChatSessionManager;
pub use chat_session::ThreadedChatSession;
pub use chat_session_manager::{ChatEvent, ChatSessionManager};
pub use conversation_loop::prompt_app;
use lumni::api::error::ApplicationError;
use ratatui::backend::Backend;
Expand Down
5 changes: 2 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/tui/events/key_event.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::thread::Thread;

use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};

use super::handle_command_line::handle_command_line_event;
use super::handle_prompt_window::handle_prompt_window_event;
use super::handle_response_window::handle_response_window_event;
use super::{
AppUi, ApplicationError, ChatSession, ConversationDbHandler,
ThreadedChatSession, WindowEvent,
AppUi, ApplicationError, ConversationDbHandler, ThreadedChatSession,
WindowEvent,
};

#[derive(Debug, Clone)]
Expand Down
4 changes: 1 addition & 3 deletions lumni/src/apps/builtin/llm/prompt/src/tui/events/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ use super::window::{
LineType, MoveCursor, PromptWindow, TextDocumentTrait, TextWindowTrait,
WindowKind,
};
use super::{
ChatSession, ConversationDbHandler, NewConversation, ThreadedChatSession,
};
use super::{ConversationDbHandler, NewConversation, ThreadedChatSession};
pub use crate::external as lumni;

#[derive(Debug)]
Expand Down
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/tui/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use super::chat::db::{
Conversation, ConversationDbHandler, ConversationStatus,
};
pub use super::chat::{
App, ChatSession, NewConversation, PromptInstruction, ThreadedChatSession,
App, NewConversation, PromptInstruction, ThreadedChatSession,
};
pub use super::server::{
ModelServer, ServerManager, ServerTrait, SUPPORTED_MODEL_ENDPOINTS,
Expand Down
8 changes: 4 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/tui/modals/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use ratatui::layout::Rect;
use ratatui::Frame;

pub use super::{
ApplicationError, ChatSession, CommandLine, Conversation,
ConversationDbHandler, ConversationEvent, ConversationStatus, KeyTrack,
ModelServer, NewConversation, PromptInstruction, Scroller, ServerManager,
ServerTrait, TextWindowTrait, ThreadedChatSession, WindowEvent,
ApplicationError, CommandLine, Conversation, ConversationDbHandler,
ConversationEvent, ConversationStatus, KeyTrack, ModelServer,
NewConversation, PromptInstruction, Scroller, ServerManager, ServerTrait,
TextWindowTrait, ThreadedChatSession, WindowEvent,
SUPPORTED_MODEL_ENDPOINTS,
};

Expand Down

0 comments on commit a60adee

Please sign in to comment.