Skip to content

Commit

Permalink
wip - initial/ experimental db, move out conversations to db, ensure …
Browse files Browse the repository at this point in the history
…ids are unique when committed
  • Loading branch information
aprxi committed Jul 12, 2024
1 parent b1bbcb8 commit fa8450c
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 287 deletions.
29 changes: 26 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/connector.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::collections::{HashMap, VecDeque};
use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::thread;

use rusqlite::{params, Error as SqliteError, Result as SqliteResult};
use rusqlite::{params, Error as SqliteError, Transaction};

pub struct DatabaseConnector {
connection: rusqlite::Connection,
Expand Down Expand Up @@ -139,4 +138,28 @@ impl DatabaseConnector {
tx.commit()?;
Ok(())
}

pub fn process_queue_with_result<T>(
&mut self,
result_handler: impl FnOnce(&Transaction) -> Result<T, SqliteError>,
) -> Result<T, SqliteError> {
let mut queue = self.operation_queue.lock().unwrap();
let tx = self.connection.transaction()?;

while let Some(sql) = queue.pop_front() {
eprintln!("Executing SQL {}", sql);
if sql.trim().to_lowercase().starts_with("select") {
// For SELECT statements, use query
tx.query_row(&sql, [], |_| Ok(()))?;
} else {
// For other statements (INSERT, UPDATE, DELETE), use execute
tx.execute(&sql, [])?;
}
}
eprintln!("Committing transaction");
let result = result_handler(&tx)?;

tx.commit()?;
Ok(result)
}
}
51 changes: 48 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ mod schema;
mod store;

pub use schema::{
ConversationId, Exchange, InMemoryDatabase, Message, ModelId,
Attachment, Conversation, ConversationId, Exchange,
ExchangeId, ConversationCache, Message, ModelId,
};
pub use store::ConversationDatabaseStore;

pub use super::PromptRole;

pub struct ConversationDatabase {
pub store: Arc<Mutex<ConversationDatabaseStore>>,
pub cache: Arc<Mutex<InMemoryDatabase>>,
pub cache: Arc<Mutex<ConversationCache>>,
}

impl ConversationDatabase {
Expand All @@ -25,7 +26,51 @@ impl ConversationDatabase {
store: Arc::new(Mutex::new(ConversationDatabaseStore::new(
sqlite_file,
)?)),
cache: Arc::new(Mutex::new(InMemoryDatabase::new())),
cache: Arc::new(Mutex::new(ConversationCache::new())),
})
}

pub fn new_conversation(
&self,
name: &str,
parent_id: Option<ConversationId>,
) -> Result<ConversationId, SqliteError> {
let mut store = self.store.lock().unwrap();
let conversation = Conversation {
id: ConversationId(-1), // Temporary ID
name: name.to_string(),
metadata: serde_json::Value::Null,
parent_conversation_id: parent_id,
fork_exchange_id: None,
schema_version: 1,
created_at: 0,
updated_at: 0,
is_deleted: false,
};
let conversation_id = store.store_new_conversation(&conversation)?;
Ok(conversation_id)
}

pub fn finalize_exchange(
&self,
exchange: &Exchange,
) -> Result<(), SqliteError> {
let cache = self.cache.lock().unwrap();
let messages = cache.get_exchange_messages(exchange.id);
let attachments = messages
.iter()
.flat_map(|message| cache.get_message_attachments(message.id))
.collect::<Vec<_>>();
let owned_messages: Vec<Message> =
messages.into_iter().cloned().collect();
let owned_attachments: Vec<Attachment> =
attachments.into_iter().cloned().collect();
let mut store = self.store.lock().unwrap();
store.store_finalized_exchange(
exchange,
&owned_messages,
&owned_attachments,
)?;
Ok(())
}
}
128 changes: 11 additions & 117 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use serde::{Deserialize, Serialize};

use super::store::ConversationDatabaseStore;
use super::PromptRole;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down Expand Up @@ -90,36 +88,27 @@ pub struct Attachment {
}

#[derive(Debug)]
pub struct InMemoryDatabase {
pub struct ConversationCache {
models: HashMap<ModelId, Model>,
conversations: HashMap<ConversationId, Conversation>,
exchanges: HashMap<ExchangeId, Exchange>,
exchanges: Vec<Exchange>,
messages: HashMap<MessageId, Message>,
attachments: HashMap<AttachmentId, Attachment>,

conversation_exchanges: HashMap<ConversationId, Vec<ExchangeId>>,
exchange_messages: HashMap<ExchangeId, Vec<MessageId>>,
message_attachments: HashMap<MessageId, Vec<AttachmentId>>,
}

impl InMemoryDatabase {
impl ConversationCache {
pub fn new() -> Self {
InMemoryDatabase {
ConversationCache {
models: HashMap::new(),
conversations: HashMap::new(),
exchanges: HashMap::new(),
exchanges: Vec::new(),
messages: HashMap::new(),
attachments: HashMap::new(),
conversation_exchanges: HashMap::new(),
exchange_messages: HashMap::new(),
message_attachments: HashMap::new(),
}
}

pub fn new_conversation_id(&self) -> ConversationId {
ConversationId(self.conversations.len() as i64)
}

pub fn new_exchange_id(&self) -> ExchangeId {
ExchangeId(self.exchanges.len() as i64)
}
Expand All @@ -132,51 +121,16 @@ impl InMemoryDatabase {
AttachmentId(self.attachments.len() as i64)
}

pub fn new_conversation(
&mut self,
name: &str,
db_store: &Arc<Mutex<ConversationDatabaseStore>>,
parent_id: Option<ConversationId>,
) -> ConversationId {
let new_id = self.new_conversation_id();
let conversation = Conversation {
id: new_id,
name: name.to_string(),
metadata: serde_json::Value::Null,
parent_conversation_id: parent_id,
fork_exchange_id: None,
schema_version: 1,
created_at: 0, // not using timestamps for now, stick with 0 for now
updated_at: 0, // not using timestamps for now, stick with 0 for now
is_deleted: false,
};
self.add_conversation(db_store, conversation);
new_id
}

pub fn add_model(&mut self, model: Model) {
self.models.insert(model.model_id, model);
}

pub fn add_conversation(
&mut self,
db_store: &Arc<Mutex<ConversationDatabaseStore>>,
conversation: Conversation,
) {
let mut store_lock = db_store.lock().unwrap();
store_lock.store_new_conversation(&conversation);
let result = store_lock.commit_queued_operations();
eprintln!("Commit result: {:?}", result);

self.conversations.insert(conversation.id, conversation);
pub fn add_exchange(&mut self, exchange: Exchange) {
self.exchanges.push(exchange);
}

pub fn add_exchange(&mut self, exchange: Exchange) {
self.conversation_exchanges
.entry(exchange.conversation_id)
.or_default()
.push(exchange.id);
self.exchanges.insert(exchange.id, exchange);
pub fn get_exchanges(&self) -> Vec<&Exchange> {
self.exchanges.iter().collect()
}

pub fn add_message(&mut self, message: Message) {
Expand All @@ -187,14 +141,6 @@ impl InMemoryDatabase {
self.messages.insert(message.id, message);
}

pub fn update_message(&mut self, updated_message: Message) {
if let Some(existing_message) =
self.messages.get_mut(&updated_message.id)
{
*existing_message = updated_message;
}
}

pub fn update_message_by_id(
&mut self,
message_id: MessageId,
Expand All @@ -216,31 +162,8 @@ impl InMemoryDatabase {
.insert(attachment.attachment_id, attachment);
}

pub fn get_conversation_exchanges(
&self,
conversation_id: ConversationId,
) -> Vec<&Exchange> {
self.conversation_exchanges
.get(&conversation_id)
.map(|exchange_ids| {
exchange_ids
.iter()
.filter_map(|id| self.exchanges.get(id))
.collect()
})
.unwrap_or_default()
}

pub fn get_last_exchange(
&self,
conversation_id: ConversationId,
) -> Option<Exchange> {
self.conversation_exchanges
.get(&conversation_id)
.and_then(|exchanges| exchanges.last())
.and_then(|last_exchange_id| {
self.exchanges.get(last_exchange_id).cloned()
})
pub fn get_last_exchange(&self) -> Option<&Exchange> {
self.exchanges.last()
}

pub fn get_exchange_messages(
Expand All @@ -258,35 +181,6 @@ impl InMemoryDatabase {
.unwrap_or_default()
}

pub fn finalize_last_exchange(
&mut self,
db_store: &Arc<Mutex<ConversationDatabaseStore>>,
conversation_id: ConversationId,
) {
let exchange = self.get_last_exchange(conversation_id);
if let Some(exchange) = exchange {
let messages = self.get_exchange_messages(exchange.id);
let attachments = messages
.iter()
.flat_map(|message| {
self.get_message_attachments(message.id)
})
.collect::<Vec<_>>();

// Convert Vec<&Message> to Vec<Message> and Vec<&Attachment> to Vec<Attachment>
let owned_messages: Vec<Message> = messages.into_iter().cloned().collect();
let owned_attachments: Vec<Attachment> = attachments.into_iter().cloned().collect();

eprintln!("Owned messages: {:?}", owned_messages);

let mut db_lock_store = db_store.lock().unwrap();
db_lock_store.store_finalized_exchange(&exchange, &owned_messages, &owned_attachments);
let result = db_lock_store.commit_queued_operations();
eprintln!("Commit result: {:?}", result);
}

}

pub fn get_last_message_of_exchange(
&self,
exchange_id: ExchangeId,
Expand Down
Loading

0 comments on commit fa8450c

Please sign in to comment.