update OpenAI implementation to working version, preparations to allow server changes from modal window
aprxi committed Jul 2, 2024
1 parent a5b070e commit c9cbc10
Showing 17 changed files with 232 additions and 129 deletions.
69 changes: 37 additions & 32 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
@@ -55,15 +55,16 @@ async fn prompt_app<B: Backend>(
// TODO: add color scheme selection via modal
let color_scheme = ColorScheme::new(ColorSchemeType::Default);

+ // TODO: replace with loaded server and model
+ tab.ui.response.set_window_title("Response");

loop {
tokio::select! {
_ = tick.tick() => {
if redraw_ui {
tab.draw_ui(terminal)?;
redraw_ui = false;
}
- let mut tab_ui = &mut tab.ui;
- let mut chat = &mut tab.chat;

// set timeout to 1ms to allow for non-blocking polling
if poll(Duration::from_millis(1))? {
@@ -74,31 +75,31 @@
// toggle beteen prompt and response windows
current_mode = match current_mode {
Some(WindowEvent::PromptWindow) => {
- if tab_ui.prompt.is_status_insert() {
+ if tab.ui.prompt.is_status_insert() {
// tab is locked to prompt window when in insert mode
Some(WindowEvent::PromptWindow)
} else {
- tab_ui.prompt.set_status_inactive();
- tab_ui.response.set_status_normal();
+ tab.ui.prompt.set_status_inactive();
+ tab.ui.response.set_status_normal();
Some(WindowEvent::ResponseWindow)
}
}
Some(WindowEvent::ResponseWindow) => {
- tab_ui.response.set_status_inactive();
- tab_ui.prompt.set_status_normal();
+ tab.ui.response.set_status_inactive();
+ tab.ui.prompt.set_status_normal();
Some(WindowEvent::PromptWindow)
}
Some(WindowEvent::CommandLine(_)) => {
// exit command line mode
- tab_ui.command_line.text_empty();
- tab_ui.command_line.set_status_inactive();
+ tab.ui.command_line.text_empty();
+ tab.ui.command_line.set_status_inactive();

// switch to the active window,
- if tab_ui.response.is_active() {
- tab_ui.response.set_status_normal();
+ if tab.ui.response.is_active() {
+ tab.ui.response.set_status_normal();
Some(WindowEvent::ResponseWindow)
} else {
- tab_ui.prompt.set_status_normal();
+ tab.ui.prompt.set_status_normal();
Some(WindowEvent::PromptWindow)
}
}
@@ -109,7 +110,8 @@ async fn prompt_app<B: Backend>(
current_mode = if let Some(mode) = current_mode {
key_event_handler.process_key(
key_event,
- &mut tab_ui,
+ &mut tab.ui,
+ &mut tab.chat,
mode,
keep_running.clone(),
).await
@@ -127,57 +129,57 @@ async fn prompt_app<B: Backend>(
// prompt should end with single newline
let formatted_prompt = format!("{}\n", prompt.trim_end());

- tab_ui.response.text_append_with_insert(
+ tab.ui.response.text_append_with_insert(
&formatted_prompt,
//Some(PromptStyle::user()),
Some(color_scheme.get_primary_style()),
);
- tab_ui.response.text_append_with_insert(
+ tab.ui.response.text_append_with_insert(
"\n",
Some(Style::reset()),
);

- chat.message(tx.clone(), formatted_prompt).await?;
+ tab.chat.message(tx.clone(), formatted_prompt).await?;
}
PromptAction::Clear => {
- tab_ui.response.text_empty();
- chat.reset();
+ tab.ui.response.text_empty();
+ tab.chat.reset();
trim_buffer = None;
}
PromptAction::Stop => {
- chat.stop();
- finalize_response(&mut chat, &mut tab_ui, None, &color_scheme).await?;
+ tab.chat.stop();
+ finalize_response(&mut tab.chat, &mut tab.ui, None, &color_scheme).await?;
trim_buffer = None;
}
}
current_mode = Some(WindowEvent::PromptWindow);
}
Some(WindowEvent::CommandLine(ref action)) => {
// enter command line mode
- if tab_ui.prompt.is_active() {
- tab_ui.prompt.set_status_background();
+ if tab.ui.prompt.is_active() {
+ tab.ui.prompt.set_status_background();
} else {
- tab_ui.response.set_status_background();
+ tab.ui.response.set_status_background();
}
match action {
CommandLineAction::Write(prefix) => {
- tab_ui.command_line.set_insert_mode();
- tab_ui.command_line.text_set(prefix, None);
+ tab.ui.command_line.set_insert_mode();
+ tab.ui.command_line.text_set(prefix, None);
}
CommandLineAction::None => {}
}
}
Some(WindowEvent::Modal(modal_window_type)) => {
- if tab_ui.needs_modal_update(modal_window_type) {
- tab_ui.set_new_modal(modal_window_type);
+ if tab.ui.needs_modal_update(modal_window_type) {
+ tab.ui.set_new_modal(modal_window_type);
}
}
_ => {}
}
},
Event::Mouse(mouse_event) => {
// TODO: should track on which window the cursor actually is
- let window = &mut tab_ui.response;
+ let window = &mut tab.ui.response;
match mouse_event.kind {
MouseEventKind::ScrollUp => {
window.scroll_up();
@@ -207,13 +209,16 @@ async fn prompt_app<B: Backend>(
let mut tab_ui = &mut tab.ui;
let mut chat = &mut tab.chat;

- if trim_buffer.is_none() {
+ let start_of_stream = if trim_buffer.is_none() {
// new response stream started
log::debug!("New response stream started");
tab_ui.response.enable_auto_scroll();
- }
+ true
+ } else {
+ false
+ };

- let (response_content, is_final, tokens_predicted) = chat.process_response(response_bytes);
+ let (response_content, is_final, tokens_predicted) = chat.process_response(response_bytes, start_of_stream);

let trimmed_response = if let Some(text) = response_content.as_ref() {
text.trim_end().to_string()
@@ -237,7 +242,7 @@ async fn prompt_app<B: Backend>(
// e.g. for logging or metrics. These should be retrieved to ensure
// the stream is fully consumed and processed.
while let Ok(post_bytes) = rx.try_recv() {
- chat.process_response(post_bytes);
+ chat.process_response(post_bytes, false);
}
finalize_response(&mut chat, &mut tab_ui, tokens_predicted, &color_scheme).await?;
trim_buffer = None;
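A note on the tab_ui / chat changes above: key_event_handler.process_key now takes both the UI and the chat session mutably, and in the key-event arm the commit swaps the long-lived local re-borrows (tab_ui, chat) for direct field access (tab.ui, tab.chat) at each call site. Mutably borrowing two disjoint fields of the same struct at once is accepted by the borrow checker, which is the pattern relied on here; a standalone sketch with made-up types (not from this repository):

// Sketch only: Ui, Chat, and Tab are hypothetical stand-ins.
struct Ui { title: String }
struct Chat { history: Vec<String> }
struct Tab { ui: Ui, chat: Chat }

// A callee that needs mutable access to both parts at the same time.
fn process(ui: &mut Ui, chat: &mut Chat) {
    chat.history.push("hello".to_string());
    ui.title = format!("{} messages", chat.history.len());
}

fn main() {
    let mut tab = Tab {
        ui: Ui { title: String::new() },
        chat: Chat { history: Vec::new() },
    };
    // Disjoint mutable borrows of tab.ui and tab.chat compile fine.
    process(&mut tab.ui, &mut tab.chat);
    println!("{}", tab.ui.title);
}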
16 changes: 12 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
@@ -36,6 +36,11 @@ impl ChatSession {
})
}

+ pub fn server_name(&self) -> &str {
+ // TODO: get server name from server
+ "foobar"
+ }
+
pub fn stop(&mut self) {
// Stop the chat session by sending a cancel signal
if let Some(cancel_tx) = self.cancel_tx.take() {
@@ -157,10 +162,11 @@ impl ChatSession {
}

pub fn process_response(
- &self,
+ &mut self,
response: Bytes,
+ start_of_stream: bool,
) -> (Option<String>, bool, Option<usize>) {
- self.server.process_response(response)
+ self.server.process_response(response, start_of_stream)
}

// used in non-interactive mode
@@ -177,11 +183,12 @@
}

async fn handle_response(
- &self,
+ &mut self,
mut rx: mpsc::Receiver<Bytes>,
stop_signal: Arc<Mutex<bool>>,
) -> Result<(), ApplicationError> {
let mut final_received = false;
+ let mut start_of_stream = true;
while let Some(response) = rx.recv().await {
// check if the session must be kept running
if !*stop_signal.lock().await {
@@ -196,12 +203,13 @@
continue;
}
let (response_content, is_final, _) =
- self.process_response(response);
+ self.process_response(response, start_of_stream);
if let Some(response_content) = response_content {
print!("{}", response_content);
}
io::stdout().flush().expect("Failed to flush stdout");

+ start_of_stream = false;
if is_final {
final_received = true;
}
3 changes: 2 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs
@@ -120,8 +120,9 @@ impl ServerTrait for Bedrock {
}

fn process_response(
- &self,
+ &mut self,
response_bytes: Bytes,
+ _start_of_stream: bool,
) -> (Option<String>, bool, Option<usize>) {
match EventStreamMessage::from_bytes(response_bytes) {
Ok(event) => {
3 changes: 2 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/server/llama.rs
@@ -108,8 +108,9 @@ impl Llama {
#[async_trait]
impl ServerTrait for Llama {
fn process_response(
- &self,
+ &mut self,
response: Bytes,
+ _start_of_stream: bool,
) -> (Option<String>, bool, Option<usize>) {
match LlamaCompletionResponse::extract_content(response) {
Ok(chat) => (Some(chat.content), chat.stop, chat.tokens_predicted),
14 changes: 8 additions & 6 deletions lumni/src/apps/builtin/llm/prompt/src/server/mod.rs
@@ -96,14 +96,15 @@ impl ServerTrait for ModelServer {
}

fn process_response(
- &self,
+ &mut self,
response: Bytes,
+ start_of_stream: bool,
) -> (Option<String>, bool, Option<usize>) {
match self {
- ModelServer::Llama(llama) => llama.process_response(response),
- ModelServer::Ollama(ollama) => ollama.process_response(response),
- ModelServer::Bedrock(bedrock) => bedrock.process_response(response),
- ModelServer::OpenAI(openai) => openai.process_response(response),
+ ModelServer::Llama(llama) => llama.process_response(response, start_of_stream),
+ ModelServer::Ollama(ollama) => ollama.process_response(response, start_of_stream),
+ ModelServer::Bedrock(bedrock) => bedrock.process_response(response, start_of_stream),
+ ModelServer::OpenAI(openai) => openai.process_response(response, start_of_stream),
}
}

@@ -222,8 +223,9 @@ pub trait ServerTrait: Send + Sync {
}

fn process_response(
- &self,
+ &mut self,
response: Bytes,
+ start_of_stream: bool,
) -> (Option<String>, bool, Option<usize>);

async fn tokenizer(
3 changes: 2 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/server/ollama.rs
@@ -104,8 +104,9 @@ impl ServerTrait for Ollama {
}

fn process_response(
- &self,
+ &mut self,
response: Bytes,
+ _start_of_stream: bool,
) -> (Option<String>, bool, Option<usize>) {
match OllamaCompletionResponse::extract_content(response) {
Ok(chat) => {
29 changes: 6 additions & 23 deletions lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs
@@ -21,14 +21,15 @@ use super::{
};
use credentials::OpenAICredentials;
use request::OpenAIRequestPayload;
- use response::OpenAIResponsePayload;
+ use response::StreamParser;

pub use crate::external as lumni;

pub struct OpenAI {
http_client: HttpClient,
endpoints: Endpoints,
model: Option<LLMDefinition>,
+ stream_parser: StreamParser,
}

const OPENAI_COMPLETION_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
@@ -44,6 +45,7 @@ impl OpenAI {
.with_error_handler(Arc::new(OpenAIErrorHandler)),
endpoints,
model: None,
+ stream_parser: StreamParser::new(),
})
}

@@ -93,30 +95,11 @@ impl ServerTrait for OpenAI {
}

fn process_response(
- &self,
+ &mut self,
response_bytes: Bytes,
+ start_of_stream: bool,
) -> (Option<String>, bool, Option<usize>) {
- // TODO: OpenAI sents back split responses, which we need to concatenate first
- match OpenAIResponsePayload::extract_content(response_bytes) {
- Ok(chat) => {
- let choices = chat.choices;
- if choices.is_empty() {
- return (None, false, None);
- }
- let chat_message = &choices[0];
- let delta = &chat_message.delta;
- let stop = if let Some(_) = chat_message.finish_reason {
- true // if finish_reason is present, then always stop
- } else {
- false
- };
- let stop = true;
- (delta.content.clone(), stop, None)
- }
- Err(e) => {
- (Some(format!("Failed to parse JSON: {}", e)), true, None)
- }
- }
+ self.stream_parser.process_chunk(response_bytes, start_of_stream)
}

async fn completion(
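The new StreamParser lives in response.rs, which is not part of this excerpt, so its implementation is not shown; keeping it as a field on OpenAI is presumably also why process_response now takes &mut self across the ServerTrait. As a rough sketch of what a parser with this process_chunk signature might do — buffer partial SSE "data:" lines across chunks, reset its state when start_of_stream is true, and report the end of the stream on a [DONE] marker — it could look something like the following (all names and behavior below are assumptions, not the commit's code):

use bytes::Bytes;

// Hypothetical sketch only; the real StreamParser in response.rs may differ.
pub struct StreamParser {
    buffer: String,
}

impl StreamParser {
    pub fn new() -> Self {
        StreamParser { buffer: String::new() }
    }

    // Mirrors the (Option<String>, bool, Option<usize>) shape returned by
    // ServerTrait::process_response in this commit.
    pub fn process_chunk(
        &mut self,
        chunk: Bytes,
        start_of_stream: bool,
    ) -> (Option<String>, bool, Option<usize>) {
        if start_of_stream {
            // Drop any leftover state from a previous stream.
            self.buffer.clear();
        }
        self.buffer.push_str(&String::from_utf8_lossy(&chunk));

        let mut content = String::new();
        let mut is_final = false;

        // Consume only newline-terminated lines; a trailing partial line
        // stays buffered until the next chunk arrives.
        while let Some(pos) = self.buffer.find('\n') {
            let line: String = self.buffer.drain(..=pos).collect();
            let line = line.trim();
            if let Some(data) = line.strip_prefix("data:").map(str::trim) {
                if data == "[DONE]" {
                    is_final = true;
                } else if let Some(delta) = extract_delta(data) {
                    content.push_str(&delta);
                }
            }
        }

        let content = if content.is_empty() { None } else { Some(content) };
        (content, is_final, None)
    }
}

// Placeholder: a real implementation would deserialize the JSON chunk and
// pull out choices[0].delta.content plus finish_reason.
fn extract_delta(_json: &str) -> Option<String> {
    None
}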