Skip to content

Commit

Permalink
separate conversation cache from db, each conversation can have its own
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 13, 2024
1 parent c9ff573 commit 61774af
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 104 deletions.
14 changes: 5 additions & 9 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async fn prompt_app<B: Backend>(
Some(WindowEvent::Prompt(prompt_action)) => {
match prompt_action {
PromptAction::Write(prompt) => {
send_prompt(tab, &db_conn, &prompt, &color_scheme, tx.clone()).await?;
send_prompt(tab, &prompt, &color_scheme, tx.clone()).await?;
}
PromptAction::Clear => {
tab.ui.response.text_empty();
Expand Down Expand Up @@ -188,7 +188,7 @@ async fn prompt_app<B: Backend>(
let display_content = format!("{}{}", trim_buffer.unwrap_or("".to_string()), trimmed_response);

if !display_content.is_empty() {
chat.update_last_exchange(&db_conn, &display_content);
chat.update_last_exchange(&display_content);
tab_ui.response.text_append_with_insert(&display_content, Some(color_scheme.get_secondary_style()));
}

Expand Down Expand Up @@ -375,7 +375,7 @@ async fn interactive_mode(

async fn process_non_interactive_input(
chat: ChatSession,
db_conn: ConversationDatabase,
_db_conn: ConversationDatabase,
) -> Result<(), ApplicationError> {
let chat = Arc::new(Mutex::new(chat));
let stdin = tokio::io::stdin();
Expand Down Expand Up @@ -409,7 +409,7 @@ async fn process_non_interactive_input(
// Process the prompt
let process_handle = tokio::spawn(async move {
let mut chat = chat_clone.lock().await;
chat.process_prompt(&db_conn, input, running.clone()).await
chat.process_prompt(input, running.clone()).await
});

// Wait for the process to complete or for a shutdown signal
Expand Down Expand Up @@ -491,17 +491,13 @@ async fn handle_ctrl_c(r: Arc<Mutex<bool>>, s: Arc<Mutex<bool>>) {

async fn send_prompt<'a>(
tab: &mut TabSession<'a>,
db_conn: &ConversationDatabase,
prompt: &str,
color_scheme: &ColorScheme,
tx: mpsc::Sender<Bytes>,
) -> Result<(), ApplicationError> {
// prompt should end with single newline
let formatted_prompt = format!("{}\n", prompt.trim_end());
let result = tab
.chat
.message(tx.clone(), db_conn, &formatted_prompt)
.await;
let result = tab.chat.message(tx.clone(), &formatted_prompt).await;

match result {
Ok(_) => {
Expand Down
8 changes: 3 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@ mod schema;
mod store;

pub use schema::{
Attachment, Conversation, ConversationId, Exchange,
ExchangeId, ConversationCache, Message, ModelId,
Attachment, Conversation, ConversationCache, ConversationId, Exchange,
ExchangeId, Message, ModelId,
};
pub use store::ConversationDatabaseStore;

pub use super::PromptRole;

pub struct ConversationDatabase {
pub store: Arc<Mutex<ConversationDatabaseStore>>,
pub cache: Arc<Mutex<ConversationCache>>,
}

impl ConversationDatabase {
Expand All @@ -26,7 +25,6 @@ impl ConversationDatabase {
store: Arc::new(Mutex::new(ConversationDatabaseStore::new(
sqlite_file,
)?)),
cache: Arc::new(Mutex::new(ConversationCache::new())),
})
}

Expand Down Expand Up @@ -54,8 +52,8 @@ impl ConversationDatabase {
pub fn finalize_exchange(
&self,
exchange: &Exchange,
cache: &ConversationCache,
) -> Result<(), SqliteError> {
let cache = self.cache.lock().unwrap();
let messages = cache.get_exchange_messages(exchange.id);
let attachments = messages
.iter()
Expand Down
10 changes: 10 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ pub struct Attachment {

#[derive(Debug)]
pub struct ConversationCache {
conversation_id: ConversationId,
models: HashMap<ModelId, Model>,
exchanges: Vec<Exchange>,
messages: HashMap<MessageId, Message>,
Expand All @@ -100,6 +101,7 @@ pub struct ConversationCache {
impl ConversationCache {
pub fn new() -> Self {
ConversationCache {
conversation_id: ConversationId(0),
models: HashMap::new(),
exchanges: Vec::new(),
messages: HashMap::new(),
Expand All @@ -109,6 +111,14 @@ impl ConversationCache {
}
}

pub fn get_conversation_id(&self) -> ConversationId {
self.conversation_id
}

pub fn set_conversation_id(&mut self, conversation_id: ConversationId) {
self.conversation_id = conversation_id;
}

pub fn new_exchange_id(&self) -> ExchangeId {
ExchangeId(self.exchanges.len() as i64)
}
Expand Down
3 changes: 1 addition & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use rusqlite::Error as SqliteError;

use super::connector::DatabaseConnector;
use super::schema::{
Attachment, AttachmentData, Conversation, ConversationId,
Exchange, Message,
Attachment, AttachmentData, Conversation, ConversationId, Exchange, Message,
};

pub struct ConversationDatabaseStore {
Expand Down
Loading

0 comments on commit 61774af

Please sign in to comment.