Skip to content

Commit

Permalink
simplify ConversationDatabaseStore, add additional display queries
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 17, 2024
1 parent c163737 commit 40b71f4
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 225 deletions.
6 changes: 6 additions & 0 deletions lumni/src/apps/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub enum ApplicationError {
Unexpected(String),
Runtime(String),
InvalidCredentials(String),
InvalidInput(String),
NotFound(String),
ServerConfigurationError(String),
HttpClientError(HttpClientError),
IoError(std::io::Error),
Expand Down Expand Up @@ -105,6 +107,10 @@ impl fmt::Display for ApplicationError {
ApplicationError::InvalidCredentials(s) => {
write!(f, "InvalidCredentials: {}", s)
}
ApplicationError::InvalidInput(s) => {
write!(f, "InvalidInput: {}", s)
}
ApplicationError::NotFound(s) => write!(f, "NotFound: {}", s),
ApplicationError::ServerConfigurationError(s) => {
write!(f, "ServerConfigurationError: {}", s)
}
Expand Down
79 changes: 28 additions & 51 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use tokio::signal;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, timeout, Duration};

use super::chat::{ChatSession, ConversationDatabase};
use super::chat::{ChatSession, ConversationDatabaseStore};
use super::server::{ModelServer, PromptInstruction, ServerTrait};
use super::session::{AppSession, TabSession};
use super::tui::{
Expand All @@ -40,7 +40,7 @@ const CHANNEL_QUEUE_SIZE: usize = 32;
async fn prompt_app<B: Backend>(
terminal: &mut Terminal<B>,
mut app_session: AppSession<'_>,
db_conn: ConversationDatabase,
db_conn: ConversationDatabaseStore,
) -> Result<(), ApplicationError> {
let defaults = app_session.get_defaults().clone();

Expand Down Expand Up @@ -252,7 +252,7 @@ async fn prompt_app<B: Backend>(
async fn finalize_response(
chat: &mut ChatSession,
tab_ui: &mut TabUi<'_>,
db_conn: &ConversationDatabase,
db_conn: &ConversationDatabaseStore,
tokens_predicted: Option<usize>,
color_scheme: &ColorScheme,
) -> Result<(), ApplicationError> {
Expand Down Expand Up @@ -285,12 +285,19 @@ fn parse_cli_arguments(spec: ApplicationSpec) -> Command {
Command::new("db")
.about("Query the conversation database")
.arg(
Arg::new("recent")
.long("recent")
.short('r')
.help("Get recent conversations")
Arg::new("list")
.long("list")
.short('l')
.help("List recent conversations")
.num_args(0..=1)
.default_value("20"),
.value_name("LIMIT"),
)
.arg(
Arg::new("id")
.long("id")
.short('i')
.help("Fetch a specific conversation by ID")
.num_args(1),
),
)
.arg(
Expand Down Expand Up @@ -330,19 +337,21 @@ pub async fn run_cli(
let config_dir =
env.get_config_dir().expect("Config directory not defined");
let sqlite_file = config_dir.join("chat.db");
let db_conn = ConversationDatabase::new(&sqlite_file)?;
let db_conn = ConversationDatabaseStore::new(&sqlite_file)?;

if let Some(db_matches) = matches.subcommand_matches("db") {
if db_matches.contains_id("recent") {
let limit: usize = db_matches
.get_one::<String>("recent")
.unwrap()
.parse()
.unwrap_or(20);
return query_recent_conversations(&db_conn, limit).await;
if db_matches.contains_id("list") {
let limit = match db_matches.get_one::<String>("list") {
Some(value) => value.parse().unwrap_or(20),
None => 20, // Default value when --list is used without a value
};
return db_conn.print_conversation_list(limit).await;
} else if let Some(id_value) = db_matches.get_one::<String>("id") {
return db_conn.print_conversation_by_id(id_value).await;
} else {
return db_conn.print_last_conversation().await;
}
}

// optional arguments
let instruction = matches.get_one::<String>("system").cloned();
let assistant = matches.get_one::<String>("assistant").cloned();
Expand Down Expand Up @@ -381,7 +390,7 @@ pub async fn run_cli(

async fn interactive_mode(
app_session: AppSession<'_>,
db_conn: ConversationDatabase,
db_conn: ConversationDatabaseStore,
) -> Result<(), ApplicationError> {
println!("Interactive mode detected. Starting interactive session:");
let mut stdout = io::stdout().lock();
Expand Down Expand Up @@ -427,7 +436,7 @@ async fn interactive_mode(

async fn process_non_interactive_input(
chat: ChatSession,
_db_conn: ConversationDatabase,
_db_conn: ConversationDatabaseStore,
) -> Result<(), ApplicationError> {
let chat = Arc::new(Mutex::new(chat));
let stdin = tokio::io::stdin();
Expand Down Expand Up @@ -572,35 +581,3 @@ async fn send_prompt<'a>(
}
Ok(())
}

async fn query_recent_conversations(
db_conn: &ConversationDatabase,
limit: usize,
) -> Result<(), ApplicationError> {
let recent_conversations = db_conn
.get_recent_conversations_with_last_exchange_and_messages(limit)?;

for (conversation, last_exchange_with_messages) in recent_conversations {
println!(
"Conversation: {} (ID: {})",
conversation.name, conversation.id.0
);
println!("Updated at: {}", conversation.updated_at);

if let Some((exchange, messages)) = last_exchange_with_messages {
println!("Last exchange: {}", exchange.created_at);
println!("Model: {:?}", exchange.model_id);
println!("Messages:");
for message in messages {
println!(" Role: {}", message.role);
println!(" Content: {}", message.content);
println!(" ---");
}
} else {
println!("No exchanges yet");
}
println!("===============================");
}

Ok(())
}
78 changes: 78 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/display.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use lumni::api::error::ApplicationError;

use super::schema::{
Conversation, ConversationId, Message,
};
use super::ConversationDatabaseStore;
pub use crate::external as lumni;

impl ConversationDatabaseStore {
pub async fn print_last_conversation(
&self,
) -> Result<(), ApplicationError> {
if let Some((conversation, messages)) =
self.fetch_conversation(None, None)?
{
display_conversation_with_messages(&conversation, &messages);
} else {
println!("No conversations found.");
}
Ok(())
}

pub async fn print_conversation_list(
&self,
limit: usize,
) -> Result<(), ApplicationError> {
let conversations = self.fetch_conversation_list(limit)?;
for conversation in conversations {
println!(
"ID: {}, Name: {}, Updated: {}",
conversation.id.0, conversation.name, conversation.updated_at
);
}
Ok(())
}

pub async fn print_conversation_by_id(
&self,
id: &str,
) -> Result<(), ApplicationError> {
let conversation_id = ConversationId(id.parse().map_err(|_| {
ApplicationError::NotFound(
format!("Conversation {id} not found in database"),
)})?);

if let Some((conversation, messages)) =
self.fetch_conversation(Some(conversation_id), None)?
{
display_conversation_with_messages(&conversation, &messages);
} else {
println!("Conversation not found.");
}
Ok(())
}
}

fn display_conversation_with_messages(
conversation: &Conversation,
messages: &[Message],
) {
println!(
"Conversation: {} (ID: {})",
conversation.name, conversation.id.0
);
println!("Updated at: {}", conversation.updated_at);

if !messages.is_empty() {
println!("Messages:");
for message in messages {
println!(" Role: {}", message.role);
println!(" Content: {}", message.content);
println!(" ---");
}
} else {
println!("No messages");
}
println!("===============================");
}
80 changes: 2 additions & 78 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,85 +1,9 @@
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use rusqlite::Error as SqliteError;

mod connector;
mod display;
mod schema;
mod store;

pub use schema::{
Attachment, Conversation, ConversationCache, ConversationId, Exchange,
ExchangeId, Message, ModelId,
};
pub use schema::{ConversationCache, Exchange, ExchangeId, ConversationId, Message, ModelId};
pub use store::ConversationDatabaseStore;

pub use super::PromptRole;

pub struct ConversationDatabase {
pub store: Arc<Mutex<ConversationDatabaseStore>>,
}

impl ConversationDatabase {
pub fn new(sqlite_file: &PathBuf) -> Result<Self, SqliteError> {
Ok(Self {
store: Arc::new(Mutex::new(ConversationDatabaseStore::new(
sqlite_file,
)?)),
})
}

pub fn new_conversation(
&self,
name: &str,
parent_id: Option<ConversationId>,
) -> Result<ConversationId, SqliteError> {
let mut store = self.store.lock().unwrap();
let conversation = Conversation {
id: ConversationId(-1), // Temporary ID
name: name.to_string(),
metadata: serde_json::Value::Null,
parent_conversation_id: parent_id,
fork_exchange_id: None,
schema_version: 1,
created_at: 0,
updated_at: 0,
is_deleted: false,
};
let conversation_id = store.put_new_conversation(&conversation)?;
Ok(conversation_id)
}

pub fn finalize_exchange(
&self,
exchange: &Exchange,
cache: &ConversationCache,
) -> Result<(), SqliteError> {
let messages = cache.get_exchange_messages(exchange.id);
let attachments = messages
.iter()
.flat_map(|message| cache.get_message_attachments(message.id))
.collect::<Vec<_>>();
let owned_messages: Vec<Message> =
messages.into_iter().cloned().collect();
let owned_attachments: Vec<Attachment> =
attachments.into_iter().cloned().collect();
let mut store = self.store.lock().unwrap();
store.put_finalized_exchange(
exchange,
&owned_messages,
&owned_attachments,
)?;
Ok(())
}

pub fn get_recent_conversations_with_last_exchange_and_messages(
&self,
limit: usize,
) -> Result<
Vec<(Conversation, Option<(Exchange, Vec<Message>)>)>,
SqliteError,
> {
let mut store = self.store.lock().unwrap();
store.get_recent_conversations_with_last_exchange_and_messages(limit)
}
}
Loading

0 comments on commit 40b71f4

Please sign in to comment.