load last conversation text into response window at start-up, fix bug where system message was not properly stored in db
aprxi committed Jul 23, 2024
1 parent c0d02bb commit b312430
Showing 19 changed files with 251 additions and 117 deletions.
21 changes: 11 additions & 10 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
@@ -428,21 +428,22 @@ pub async fn run_cli(
 
     // check if the last conversation is the same as the new conversation, if so,
     // continue the conversation, otherwise start a new conversation
-    let prompt_instruction =
-        if let Some(conversation_id) = db_conn.fetch_last_conversation_id()? {
+    let prompt_instruction = db_conn
+        .fetch_last_conversation_id()?
+        .and_then(|conversation_id| {
             let reader = db_conn.get_conversation_reader(conversation_id);
-            let is_equal = new_conversation.is_equal(&reader)?;
-            if is_equal {
+            // Convert Result to Option using .ok()
+            if new_conversation.is_equal(&reader).ok()? {
                 log::debug!("Continuing last conversation");
-                PromptInstruction::from_reader(&reader)?
+                Some(PromptInstruction::from_reader(&reader))
             } else {
                 log::debug!("Starting new conversation");
-                PromptInstruction::new(new_conversation, &db_conn)?
+                None
             }
-        } else {
+        })
+        .unwrap_or_else(|| {
             log::debug!("Starting new conversation");
-            PromptInstruction::new(new_conversation, &db_conn)?
-        };
+            PromptInstruction::new(new_conversation, &db_conn)
+        })?;
 
     let chat_session =
         ChatSession::new(&server_name, prompt_instruction, &db_conn).await?;
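The rewritten block swaps the nested if-let/else for Option combinators: the outer Result from fetch_last_conversation_id is unwrapped once with `?`, the inner comparison Result is demoted to an Option via `.ok()?` (so a failed comparison now falls through to starting a new conversation rather than propagating the error), and the trailing `?` surfaces whichever Result the chain produced. A minimal standalone sketch of the same shape, with stand-in functions rather than the crate's API:

    fn fetch_last_id() -> Result<Option<i64>, String> {
        Ok(Some(42))
    }

    fn is_same_conversation(id: i64) -> Result<bool, String> {
        Ok(id == 42)
    }

    fn continue_conversation(id: i64) -> Result<String, String> {
        Ok(format!("continued {id}"))
    }

    fn start_new_conversation() -> Result<String, String> {
        Ok("started new".to_string())
    }

    fn resolve() -> Result<String, String> {
        fetch_last_id()?                 // Result<Option<i64>, _> -> Option<i64>
            .and_then(|id| {
                // `.ok()?` converts the inner Result to an Option, so a
                // comparison failure falls through to `unwrap_or_else` below
                if is_same_conversation(id).ok()? {
                    Some(continue_conversation(id)) // Option<Result<String, _>>
                } else {
                    None
                }
            })
            .unwrap_or_else(start_new_conversation) // -> Result<String, _>
    }

    fn main() {
        assert_eq!(resolve().unwrap(), "continued 42");
    }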
47 changes: 27 additions & 20 deletions lumni/src/apps/builtin/llm/prompt/src/chat/assistant.rs
@@ -21,16 +21,18 @@ impl AssistantManager {
             completion_options: ChatCompletionOptions::default(),
         };
 
-        // TODO: default should only apply to servers that do not handle this internally
+        // Use default assistant when both system_prompt and assistant_name are None
         let assistant_name = assistant_name.or_else(|| {
-            if user_instruction.is_some() {
-                // assistant not needed
-                None
-            } else {
-                // no user instruction, use default assistant
-                Some("Default".to_string())
-            }
+            user_instruction
+                .as_ref()
+                .map(|instruction| {
+                    manager.add_system_message(instruction.to_string());
+                    None // No assistant required
+                })
+                .unwrap_or_else(|| {
+                    // TODO: default should only apply to servers that do not handle this internally
+                    Some("Default".to_string()) // Use default assistant
+                })
        });
 
        if let Some(assistant_name) = assistant_name {
@@ -40,6 +42,21 @@ impl AssistantManager {
         Ok(manager)
     }
 
+    fn add_system_message(&mut self, system_prompt: String) {
+        self.initial_messages.push(Message {
+            id: MessageId(0), // system message is always the first message
+            conversation_id: ConversationId(0), // temporary conversation id
+            role: PromptRole::System,
+            message_type: "text".to_string(),
+            content: system_prompt,
+            has_attachments: false,
+            token_length: None,
+            previous_message_id: None,
+            created_at: 0,
+            is_deleted: false,
+        });
+    }
+
     fn load_assistant(
         &mut self,
         assistant_name: String,
@@ -66,18 +83,7 @@ impl AssistantManager {
         let system_prompt = build_system_prompt(&prompt, &user_instruction);
 
         // Add system message
-        self.initial_messages.push(Message {
-            id: MessageId(0), // system message is always the first message
-            conversation_id: ConversationId(0), // temporary conversation id
-            role: PromptRole::System,
-            message_type: "text".to_string(),
-            content: system_prompt,
-            has_attachments: false,
-            token_length: None,
-            previous_message_id: None,
-            created_at: 0,
-            is_deleted: false,
-        });
+        self.add_system_message(system_prompt);
 
         // Add exchanges if any
         if let Some(exchanges) = prompt.exchanges() {
@@ -116,6 +122,7 @@ impl AssistantManager {
 
         let assistant_options = AssistantOptions {
             name: assistant_name,
+            preloaded_messages: self.initial_messages.len() - 1, // exclude the first system message
             prompt_template: prompt.prompt_template().map(|s| s.to_string()),
         };
         self.completion_options
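With the new helper, a user instruction is recorded as the system message, and every other entry in `initial_messages` is counted as preloaded. The `len() - 1` assumes the system message has already been pushed; on an empty list the subtraction would underflow. A reduced sketch of that bookkeeping with stand-in types, using a saturating subtraction to make the empty case explicit:

    struct Manager {
        initial_messages: Vec<String>, // stand-in for Vec<Message>
    }

    impl Manager {
        // Count everything after the leading system message as preloaded;
        // saturating_sub keeps an empty list at zero instead of underflowing.
        fn preloaded_count(&self) -> usize {
            self.initial_messages.len().saturating_sub(1)
        }
    }

    fn main() {
        let manager = Manager {
            initial_messages: vec![
                "system prompt".into(),
                "preloaded user turn".into(),
                "preloaded assistant turn".into(),
            ],
        };
        assert_eq!(manager.preloaded_count(), 2);
    }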
87 changes: 58 additions & 29 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/cache.rs
@@ -1,28 +1,27 @@
 use std::collections::HashMap;
 
-use super::db::{
-    Attachment, AttachmentId, ConversationId, Message, MessageId,
-    ModelIdentifier, ModelSpec,
-};
-use super::PromptRole;
+use ratatui::style::Style;
+
+use super::db::{Attachment, AttachmentId, ConversationId, Message, MessageId};
+use super::{ColorScheme, PromptRole, TextSegment};
 
 #[derive(Debug)]
 pub struct ConversationCache {
     conversation_id: ConversationId,
-    models: HashMap<ModelIdentifier, ModelSpec>,
     messages: Vec<Message>, // messages have to be ordered
     attachments: HashMap<AttachmentId, Attachment>,
     message_attachments: HashMap<MessageId, Vec<AttachmentId>>,
+    preloaded_messages: usize,
 }
 
 impl ConversationCache {
     pub fn new() -> Self {
         ConversationCache {
             conversation_id: ConversationId(-1),
-            models: HashMap::new(),
             messages: Vec::new(),
             attachments: HashMap::new(),
             message_attachments: HashMap::new(),
+            preloaded_messages: 0,
         }
     }
@@ -34,6 +33,10 @@ impl ConversationCache {
         self.conversation_id = conversation_id;
     }
 
+    pub fn set_preloaded_messages(&mut self, preloaded_messages: usize) {
+        self.preloaded_messages = preloaded_messages;
+    }
+
     pub fn new_message_id(&self) -> MessageId {
         MessageId(self.messages.len() as i64)
     }
@@ -42,10 +45,6 @@ impl ConversationCache {
         AttachmentId(self.attachments.len() as i64)
     }
 
-    pub fn add_model(&mut self, model: ModelSpec) {
-        self.models.insert(model.identifier.clone(), model);
-    }
-
     pub fn add_message(&mut self, message: Message) {
         self.messages.push(message);
     }
@@ -100,27 +99,57 @@ impl ConversationCache {
         self.attachments
             .insert(attachment.attachment_id, attachment);
     }
+}
 
-    pub fn get_message_attachments(
+impl ConversationCache {
+    pub fn export_conversation(
         &self,
-        message_id: MessageId,
-    ) -> Vec<&Attachment> {
-        self.message_attachments
-            .get(&message_id)
-            .map(|attachment_ids| {
-                attachment_ids
-                    .iter()
-                    .filter_map(|id| self.attachments.get(id))
-                    .collect()
-            })
-            .unwrap_or_default()
-    }
+        color_scheme: &ColorScheme,
+    ) -> Vec<TextSegment> {
+        let mut skip_count = self.preloaded_messages;
+
+        // Check if the first message is a system message, and skip it
+        if !self.messages.is_empty()
+            && self.messages[0].role == PromptRole::System
+        {
+            skip_count += 1;
+        }
 
-    pub fn get_system_prompt(&self) -> Option<String> {
-        // system prompt is the first message in the conversation
         self.messages
-            .first()
-            .filter(|m| m.role == PromptRole::System)
-            .map(|m| m.content.clone())
+            .iter()
+            .skip(skip_count) // Skip preloaded messages
+            .filter(|m| !m.is_deleted && m.role != PromptRole::System)
+            .flat_map(|message| {
+                let style = match message.role {
+                    PromptRole::User => Some(color_scheme.get_primary_style()),
+                    PromptRole::Assistant => {
+                        Some(color_scheme.get_secondary_style())
+                    }
+                    _ => None,
+                };
+
+                let mut segments = Vec::new();
+
+                // Add the message content
+                let text = message.content.clone();
+                segments.push(TextSegment { text, style });
+
+                // For assistant messages, add an extra newline with the same style
+                if message.role == PromptRole::Assistant {
+                    segments.push(TextSegment {
+                        text: "\n".to_string(),
+                        style,
+                    });
+                }
+
+                // Add an empty unstyled line after each message
+                segments.push(TextSegment {
+                    text: "\n".to_string(),
+                    style: Some(Style::reset()),
+                });
+
+                segments
+            })
+            .collect()
     }
 }
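`export_conversation` is what feeds the restored transcript into the response window at start-up: skip the preloaded turns (plus a leading system prompt), drop deleted and system messages, then emit one styled segment per message followed by an unstyled separator. A reduced sketch of that walk, with plain strings standing in for ratatui styles and omitting the extra assistant newline:

    #[derive(PartialEq)]
    enum Role {
        System,
        User,
        Assistant,
    }

    struct Msg {
        role: Role,
        content: String,
        is_deleted: bool,
    }

    // Stand-in for Vec<TextSegment>: (text, optional style name).
    fn export(messages: &[Msg], preloaded: usize) -> Vec<(String, Option<&'static str>)> {
        let mut skip = preloaded;
        if messages.first().map_or(false, |m| m.role == Role::System) {
            skip += 1; // the system prompt is never shown in the response window
        }
        messages
            .iter()
            .skip(skip)
            .filter(|m| !m.is_deleted && m.role != Role::System)
            .flat_map(|m| {
                let style = match m.role {
                    Role::User => Some("primary"),
                    Role::Assistant => Some("secondary"),
                    _ => None,
                };
                vec![(m.content.clone(), style), ("\n".to_string(), None)]
            })
            .collect()
    }

    fn main() {
        let msgs = vec![
            Msg { role: Role::System, content: "system".into(), is_deleted: false },
            Msg { role: Role::User, content: "hi".into(), is_deleted: false },
            Msg { role: Role::Assistant, content: "hello".into(), is_deleted: false },
        ];
        // preloaded = 0: only the system prompt is skipped, both turns export
        assert_eq!(export(&msgs, 0).len(), 4);
    }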
lumni/src/apps/builtin/llm/prompt/src/chat/conversation/instruction.rs

@@ -2,11 +2,13 @@ use lumni::api::error::ApplicationError;
 
 use super::db::{
     system_time_in_milliseconds, ConversationCache, ConversationDatabaseStore,
-    ConversationId, ConversationReader, Message, MessageId, ModelServerName,
-    ModelSpec,
+    ConversationId, ConversationReader, Message, MessageId, ModelSpec,
 };
 use super::prepare::NewConversation;
-use super::{ChatCompletionOptions, ChatMessage, PromptRole};
+use super::{
+    ChatCompletionOptions, ChatMessage, ColorScheme, PromptRole, TextSegment,
+};
+use crate::apps::builtin::llm::prompt::src::chat::assistant;
 pub use crate::external as lumni;
 
 #[derive(Debug)]
@@ -31,6 +33,7 @@ impl PromptInstruction {
             None => ChatCompletionOptions::default(),
         };
 
+        eprintln!("completion_options: {:?}", completion_options);
         let conversation_id = if let Some(ref model) = new_conversation.model {
             Some(db_conn.new_conversation(
                 "New Conversation",
@@ -58,8 +61,6 @@ impl PromptInstruction {
 
         if new_conversation.parent.is_none() {
             // evaluate system_prompt and initial_messages only if parent is not provided
-            // TODO: check if first initial message is a System message,
-            // if not, and prompt is provided, add it as the first message
             if let Some(messages) = new_conversation.initial_messages {
                 let mut messages_to_insert = Vec::new();
 
@@ -81,6 +82,9 @@ impl PromptInstruction {
                 prompt_instruction.cache.add_message(message.clone());
                 messages_to_insert.push(message);
             }
+            prompt_instruction
+                .cache
+                .set_preloaded_messages(messages_to_insert.len());
 
             // Insert messages into the database
             db_conn.put_new_messages(&messages_to_insert)?;
@@ -108,6 +112,11 @@ impl PromptInstruction {
         let completion_options: ChatCompletionOptions =
             serde_json::from_value(completion_options)?;
 
+        let preloaded_messages = completion_options
+            .assistant_options
+            .as_ref()
+            .map_or(0, |options| options.preloaded_messages);
+
         let mut prompt_instruction = PromptInstruction {
             cache: ConversationCache::new(),
             model: Some(model_spec),
@@ -118,6 +127,9 @@ impl PromptInstruction {
         prompt_instruction
             .cache
             .set_conversation_id(conversation_id);
+        prompt_instruction
+            .cache
+            .set_preloaded_messages(preloaded_messages);
 
         // Load messages
         let messages = reader
@@ -380,6 +392,13 @@ impl PromptInstruction {
         messages.reverse();
         messages
     }
+
+    pub fn export_conversation(
+        &self,
+        color_scheme: &ColorScheme,
+    ) -> Vec<TextSegment> {
+        self.cache.export_conversation(color_scheme)
+    }
 }
 
 fn simple_token_estimator(input: &str, chars_per_token: Option<f32>) -> i64 {
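Two counting conventions meet in `from_reader`: `AssistantOptions::preloaded_messages` excludes the system message, while `export_conversation` adds one extra skip when the first stored message is a system prompt. (Note that the creation path above stores `messages_to_insert.len()`, which does include the system message, so the two paths count differently.) A small sketch of the reload-path arithmetic, with stand-in values rather than crate code:

    // Effective number of leading messages hidden from the response window.
    fn effective_skip(preloaded_excluding_system: usize, first_is_system: bool) -> usize {
        preloaded_excluding_system + usize::from(first_is_system)
    }

    fn main() {
        // e.g. a system prompt plus two preloaded exchange turns in the
        // database: AssistantOptions records 2, the exporter skips 3 in total.
        assert_eq!(effective_skip(2, true), 3);
    }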
lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs

@@ -7,7 +7,9 @@ pub use instruction::PromptInstruction;
 pub use prepare::NewConversation;
 
 pub use super::db;
-use super::{ChatCompletionOptions, ChatMessage, PromptRole};
+use super::{
+    ChatCompletionOptions, ChatMessage, ColorScheme, PromptRole, TextSegment,
+};
 
 #[derive(Debug, Clone)]
 pub struct ParentConversation {
lumni/src/apps/builtin/llm/prompt/src/chat/conversation/prepare.rs

@@ -90,7 +90,6 @@ impl NewConversation {
         if last_options != new_options {
             return Ok(false);
         }
-
         // Compare system prompt. If the system prompt is not set in the new conversation, we check by first system prompt in the initial messages
         let last_system_prompt = reader.get_system_prompt()?;
         let new_system_prompt = match &self.system_prompt {

@@ -105,6 +104,7 @@ impl NewConversation {
                 })
             }),
         };
+
         if last_system_prompt.as_deref() != new_system_prompt {
             return Ok(false);
         }
1 change: 1 addition & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
@@ -18,6 +18,7 @@ pub use session::ChatSession;
 
 pub use super::defaults::*;
 pub use super::server::{CompletionResponse, ModelServer, ServerManager};
+pub use super::tui::{ColorScheme, TextSegment};
 
 // gets PERSONAS from the generated code
 include!(concat!(env!("OUT_DIR"), "/llm/prompt/templates.rs"));
3 changes: 2 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/options.rs
@@ -107,7 +107,8 @@ impl ChatCompletionOptions {
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct AssistantOptions {
-    pub name: String, // name of assistant used
+    pub name: String,              // name of assistant used
+    pub preloaded_messages: usize, // number of messages loaded by the assistant, does not include the first system message
     #[serde(skip_serializing_if = "Option::is_none")]
     pub prompt_template: Option<String>,
 }
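One consequence of adding `preloaded_messages` as a plain `usize`: it carries no `#[serde(default)]`, so completion options serialized before this commit (without the field) would fail to deserialize. A self-contained sketch of that failure mode; the struct mirrors the diff, but the behaviour shown here is an inference from serde's semantics, not something tested against the crate:

    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Deserialize, Serialize)]
    struct AssistantOptions {
        name: String,
        preloaded_messages: usize,
        #[serde(skip_serializing_if = "Option::is_none")]
        prompt_template: Option<String>,
    }

    fn main() {
        // JSON as an older database row might hold it, before the field existed.
        let old_json = r#"{ "name": "Default" }"#;
        let parsed: Result<AssistantOptions, _> = serde_json::from_str(old_json);
        // fails with: missing field `preloaded_messages`
        // (Option fields default to None when absent; usize fields do not)
        assert!(parsed.is_err());
    }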
10 changes: 9 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
@@ -6,7 +6,8 @@ use tokio::sync::{mpsc, oneshot, Mutex};
 
 use super::db::{ConversationDatabaseStore, ConversationId};
 use super::{
-    CompletionResponse, ModelServer, PromptInstruction, ServerManager,
+    ColorScheme, CompletionResponse, ModelServer, PromptInstruction,
+    ServerManager, TextSegment,
 };
 use crate::api::error::ApplicationError;
@@ -190,4 +191,11 @@ impl ChatSession {
         }
         Ok(())
     }
+
+    pub fn export_conversation(
+        &self,
+        color_scheme: &ColorScheme,
+    ) -> Vec<TextSegment> {
+        self.prompt_instruction.export_conversation(color_scheme)
+    }
 }
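`ChatSession` only delegates, so at start-up the TUI can ask the session for the restored transcript without reaching into `PromptInstruction` or the cache. A minimal runnable sketch of that delegation chain with stand-in types (the `ColorScheme` parameter is dropped for brevity; none of these names are the crate's own):

    struct Cache;
    impl Cache {
        fn export_conversation(&self) -> Vec<String> {
            vec!["restored transcript".to_string()]
        }
    }

    struct Instruction {
        cache: Cache,
    }
    impl Instruction {
        fn export_conversation(&self) -> Vec<String> {
            self.cache.export_conversation()
        }
    }

    struct Session {
        instruction: Instruction,
    }
    impl Session {
        fn export_conversation(&self) -> Vec<String> {
            self.instruction.export_conversation()
        }
    }

    fn main() {
        let session = Session { instruction: Instruction { cache: Cache } };
        // at start-up the TUI would feed this into the response window
        assert_eq!(session.export_conversation(), vec!["restored transcript"]);
    }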