Skip to content

Commit

Permalink
wip - initial/ experimental insert to database
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 11, 2024
1 parent 892ca49 commit f1d204d
Show file tree
Hide file tree
Showing 9 changed files with 306 additions and 64 deletions.
41 changes: 27 additions & 14 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use tokio::signal;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, timeout, Duration};

use super::chat::ChatSession;
use super::chat::{ChatSession, ConversationDatabase};
use super::server::{ModelServer, PromptInstruction, ServerTrait};
use super::session::{AppSession, TabSession};
use super::tui::{
Expand All @@ -40,6 +40,7 @@ const CHANNEL_QUEUE_SIZE: usize = 32;
async fn prompt_app<B: Backend>(
terminal: &mut Terminal<B>,
mut app_session: AppSession<'_>,
db_conn: ConversationDatabase,
) -> Result<(), ApplicationError> {
let defaults = app_session.get_defaults().clone();

Expand Down Expand Up @@ -94,16 +95,16 @@ async fn prompt_app<B: Backend>(
Some(WindowEvent::Prompt(prompt_action)) => {
match prompt_action {
PromptAction::Write(prompt) => {
send_prompt(tab, &prompt, &color_scheme, tx.clone()).await?;
send_prompt(tab, &db_conn, &prompt, &color_scheme, tx.clone()).await?;
}
PromptAction::Clear => {
tab.ui.response.text_empty();
tab.chat.reset();
tab.chat.reset(&db_conn);
trim_buffer = None;
}
PromptAction::Stop => {
tab.chat.stop();
finalize_response(&mut tab.chat, &mut tab.ui, None, &color_scheme).await?;
finalize_response(&mut tab.chat, &mut tab.ui, &db_conn, None, &color_scheme).await?;
trim_buffer = None;
}
}
Expand Down Expand Up @@ -187,7 +188,7 @@ async fn prompt_app<B: Backend>(
let display_content = format!("{}{}", trim_buffer.unwrap_or("".to_string()), trimmed_response);

if !display_content.is_empty() {
chat.update_last_exchange(&display_content);
chat.update_last_exchange(&db_conn, &display_content);
tab_ui.response.text_append_with_insert(&display_content, Some(color_scheme.get_secondary_style()));
}

Expand All @@ -199,7 +200,7 @@ async fn prompt_app<B: Backend>(
while let Ok(post_bytes) = rx.try_recv() {
chat.process_response(post_bytes, false);
}
finalize_response(&mut chat, &mut tab_ui, tokens_predicted, &color_scheme).await?;
finalize_response(&mut chat, &mut tab_ui, &db_conn, tokens_predicted, &color_scheme).await?;
trim_buffer = None;
} else {
// Capture trailing whitespaces or newlines to the trim_buffer
Expand All @@ -221,6 +222,7 @@ async fn prompt_app<B: Backend>(
async fn finalize_response(
chat: &mut ChatSession,
tab_ui: &mut TabUi<'_>,
db_conn: &ConversationDatabase,
tokens_predicted: Option<usize>,
color_scheme: &ColorScheme,
) -> Result<(), ApplicationError> {
Expand All @@ -236,7 +238,8 @@ async fn finalize_response(
.response
.text_append_with_insert("\n", Some(Style::reset()));
// trim exchange + update token length
chat.finalize_last_exchange(tokens_predicted).await?;
chat.finalize_last_exchange(db_conn, tokens_predicted)
.await?;
Ok(())
}

Expand Down Expand Up @@ -298,30 +301,35 @@ pub async fn run_cli(
// create new (un-initialized) server from requested server name
let server = ModelServer::from_str(&server_name)?;
let default_model = server.get_default_model().await;

let sqlite_file = config_dir.join("chat.db");
let db_conn = ConversationDatabase::new(&sqlite_file)?;

// setup prompt, server and chat session
let prompt_instruction =
PromptInstruction::new(instruction, assistant, options)?;
PromptInstruction::new(instruction, assistant, options, &db_conn)?;
let chat_session =
ChatSession::new(Box::new(server), prompt_instruction, default_model)
.await?;

match poll(Duration::from_millis(0)) {
Ok(_) => {
// Starting interactive session
let mut app_session = AppSession::new(config_dir)?;
let mut app_session = AppSession::new()?;
app_session.add_tab(chat_session);
interactive_mode(app_session).await
interactive_mode(app_session, db_conn).await
}
Err(_) => {
// potential non-interactive input detected due to poll error.
// attempt to use in non interactive mode
process_non_interactive_input(chat_session).await
process_non_interactive_input(chat_session, db_conn).await
}
}
}

async fn interactive_mode(
app_session: AppSession<'_>,
db_conn: ConversationDatabase,
) -> Result<(), ApplicationError> {
println!("Interactive mode detected. Starting interactive session:");
let mut stdout = io::stdout().lock();
Expand Down Expand Up @@ -352,7 +360,7 @@ async fn interactive_mode(
};

// Run the application logic and capture the result
let result = prompt_app(&mut terminal, app_session).await;
let result = prompt_app(&mut terminal, app_session, db_conn).await;

// Regardless of the result, perform cleanup
let _ = disable_raw_mode();
Expand All @@ -367,6 +375,7 @@ async fn interactive_mode(

async fn process_non_interactive_input(
chat: ChatSession,
db_conn: ConversationDatabase,
) -> Result<(), ApplicationError> {
let chat = Arc::new(Mutex::new(chat));
let stdin = tokio::io::stdin();
Expand Down Expand Up @@ -400,7 +409,7 @@ async fn process_non_interactive_input(
// Process the prompt
let process_handle = tokio::spawn(async move {
let mut chat = chat_clone.lock().await;
chat.process_prompt(input, running.clone()).await
chat.process_prompt(&db_conn, input, running.clone()).await
});

// Wait for the process to complete or for a shutdown signal
Expand Down Expand Up @@ -482,13 +491,17 @@ async fn handle_ctrl_c(r: Arc<Mutex<bool>>, s: Arc<Mutex<bool>>) {

async fn send_prompt<'a>(
tab: &mut TabSession<'a>,
db_conn: &ConversationDatabase,
prompt: &str,
color_scheme: &ColorScheme,
tx: mpsc::Sender<Bytes>,
) -> Result<(), ApplicationError> {
// prompt should end with single newline
let formatted_prompt = format!("{}\n", prompt.trim_end());
let result = tab.chat.message(tx.clone(), &formatted_prompt).await;
let result = tab
.chat
.message(tx.clone(), db_conn, &formatted_prompt)
.await;

match result {
Ok(_) => {
Expand Down
16 changes: 4 additions & 12 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@ use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::thread;

use rusqlite::{
params, Connection, Error as SqliteError, Result as SqliteResult,
};
use serde::{Deserialize, Serialize};

use super::schema::{
Attachment, AttachmentData, Conversation, Exchange, Message,
};
use rusqlite::{params, Error as SqliteError, Result as SqliteResult};

pub struct DatabaseConnector {
connection: rusqlite::Connection,
Expand Down Expand Up @@ -46,7 +39,7 @@ impl DatabaseConnector {

match stmt {
Ok(ref mut stmt) => {
let result: Result<Vec<(String, String)>, rusqlite::Error> =
let result: Result<Vec<(String, String)>, SqliteError> =
stmt.query_map([], |row| {
Ok((row.get(0)?, row.get(1)?))
})?
Expand Down Expand Up @@ -136,14 +129,13 @@ impl DatabaseConnector {
queue.push_back(sql);
}

pub fn process_queue(&mut self) -> Result<(), rusqlite::Error> {
pub fn process_queue(&mut self) -> Result<(), SqliteError> {
// Lock the queue and start a transaction for items in the queue
let mut queue = self.operation_queue.lock().unwrap();
let tx = self.connection.transaction()?;

while let Some(sql) = queue.pop_front() {
tx.execute(&sql, [])?;
}

tx.commit()?;
Ok(())
}
Expand Down
24 changes: 23 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,31 @@
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use rusqlite::Error as SqliteError;

mod connector;
mod schema;
mod store;

pub use connector::DatabaseConnector;
pub use schema::{
ConversationId, Exchange, InMemoryDatabase, Message, ModelId,
};
pub use store::ConversationDatabaseStore;

pub use super::PromptRole;

pub struct ConversationDatabase {
pub store: Arc<Mutex<ConversationDatabaseStore>>,
pub cache: Arc<Mutex<InMemoryDatabase>>,
}

impl ConversationDatabase {
pub fn new(sqlite_file: &PathBuf) -> Result<Self, SqliteError> {
Ok(Self {
store: Arc::new(Mutex::new(ConversationDatabaseStore::new(
sqlite_file,
)?)),
cache: Arc::new(Mutex::new(InMemoryDatabase::new())),
})
}
}
17 changes: 13 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

use serde::{Deserialize, Serialize};

use super::store::ConversationDatabaseStore;
use super::PromptRole;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down Expand Up @@ -135,6 +135,7 @@ impl InMemoryDatabase {
pub fn new_conversation(
&mut self,
name: &str,
db_store: &Arc<Mutex<ConversationDatabaseStore>>,
parent_id: Option<ConversationId>,
) -> ConversationId {
let new_id = self.new_conversation_id();
Expand All @@ -149,17 +150,25 @@ impl InMemoryDatabase {
updated_at: 0, // not using timestamps for now, stick with 0 for now
is_deleted: false,
};
self.add_conversation(conversation);
self.add_conversation(db_store, conversation);
new_id
}

pub fn add_model(&mut self, model: Model) {
self.models.insert(model.model_id, model);
}

pub fn add_conversation(&mut self, conversation: Conversation) {
pub fn add_conversation(
&mut self,
db_store: &Arc<Mutex<ConversationDatabaseStore>>,
conversation: Conversation,
) {
let mut store_lock = db_store.lock().unwrap();
store_lock.store_new_conversation(&conversation);
let result = store_lock.commit_queued_operations();
eprintln!("Commit result: {:?}", result);

self.conversations.insert(conversation.id, conversation);
// TODO: push to DB
}

pub fn add_exchange(&mut self, exchange: Exchange) {
Expand Down
Loading

0 comments on commit f1d204d

Please sign in to comment.