Skip to content

Commit

Permalink
automatically create database file on startup
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 10, 2024
1 parent d915ae0 commit 892ca49
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 122 deletions.
15 changes: 15 additions & 0 deletions lumni/src/apps/api/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt;

use rusqlite::Error as SqliteError;

// export the http client error via api::error
pub use crate::http::client::HttpClientError;

Expand Down Expand Up @@ -28,6 +30,7 @@ pub enum ApplicationError {
ServerConfigurationError(String),
HttpClientError(HttpClientError),
IoError(std::io::Error),
DatabaseError(String),
NotImplemented(String),
NotReady(String),
}
Expand Down Expand Up @@ -109,6 +112,9 @@ impl fmt::Display for ApplicationError {
write!(f, "HttpClientError: {}", e)
}
ApplicationError::IoError(e) => write!(f, "IoError: {}", e),
ApplicationError::DatabaseError(s) => {
write!(f, "DatabaseError: {}", s)
}
ApplicationError::NotImplemented(s) => {
write!(f, "NotImplemented: {}", s)
}
Expand All @@ -130,3 +136,12 @@ impl From<std::io::Error> for ApplicationError {
ApplicationError::IoError(error)
}
}

impl From<SqliteError> for ApplicationError {
fn from(error: SqliteError) -> Self {
ApplicationError::DatabaseError(format!(
"Database operation failed: {}",
error
))
}
}
7 changes: 5 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,17 @@ fn parse_cli_arguments(spec: ApplicationSpec) -> Command {

pub async fn run_cli(
spec: ApplicationSpec,
_env: ApplicationEnv,
env: ApplicationEnv,
args: Vec<String>,
) -> Result<(), ApplicationError> {
let app = parse_cli_arguments(spec);
let matches = app.try_get_matches_from(args).unwrap_or_else(|e| {
e.exit();
});

let config_dir =
env.get_config_dir().expect("Config directory not defined");

// optional arguments
let instruction = matches.get_one::<String>("system").cloned();
let assistant = matches.get_one::<String>("assistant").cloned();
Expand All @@ -305,7 +308,7 @@ pub async fn run_cli(
match poll(Duration::from_millis(0)) {
Ok(_) => {
// Starting interactive session
let mut app_session = AppSession::new();
let mut app_session = AppSession::new(config_dir)?;
app_session.add_tab(chat_session);
interactive_mode(app_session).await
}
Expand Down
150 changes: 150 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/connector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::thread;

use rusqlite::{
params, Connection, Error as SqliteError, Result as SqliteResult,
};
use serde::{Deserialize, Serialize};

use super::schema::{
Attachment, AttachmentData, Conversation, Exchange, Message,
};

pub struct DatabaseConnector {
connection: rusqlite::Connection,
operation_queue: Arc<Mutex<VecDeque<String>>>,
}

impl DatabaseConnector {
const SCHEMA_SQL: &'static str = include_str!("schema.sql");
const EXPECTED_VERSION: &'static str = "1";
const EXPECTED_IDENTIFIER: &'static str = "prompt.chat";

pub fn new(sqlite_file: &PathBuf) -> Result<Self, SqliteError> {
let connection = rusqlite::Connection::open(sqlite_file)?;
let operation_queue = Arc::new(Mutex::new(VecDeque::new()));

let mut conn = DatabaseConnector {
connection,
operation_queue,
};
conn.initialize_schema()?;
Ok(conn)
}

fn initialize_schema(&mut self) -> Result<(), SqliteError> {
let transaction = self.connection.transaction()?;

// Check if the metadata table exists and has the correct version and identifier
let (version, identifier, need_initialization) = {
let mut stmt = transaction.prepare(
"SELECT key, value FROM metadata WHERE key IN \
('schema_version', 'schema_identifier')",
);

match stmt {
Ok(ref mut stmt) => {
let result: Result<Vec<(String, String)>, rusqlite::Error> =
stmt.query_map([], |row| {
Ok((row.get(0)?, row.get(1)?))
})?
.collect();

let mut version = None;
let mut identifier = None;

match result {
Ok(rows) if !rows.is_empty() => {
for (key, value) in rows {
match key.as_str() {
"schema_version" => version = Some(value),
"schema_identifier" => {
identifier = Some(value)
}
_ => {}
}
}
eprintln!(
"Version: {:?}, Identifier: {:?}",
version, identifier
);
(version, identifier, false)
}
Ok(_) | Err(SqliteError::QueryReturnedNoRows) => {
eprintln!(
"No schema version or identifier found. Need \
initialization."
);
(None, None, true)
}
Err(e) => return Err(e),
}
}
Err(e) => match e {
SqliteError::SqliteFailure(_, Some(ref error_string))
if error_string.contains("no such table") =>
{
eprintln!(
"No metadata table found. Need to create the \
schema."
);
(None, None, true)
}
_ => return Err(e),
},
}
};

if need_initialization {
eprintln!("Initializing database schema...");
transaction.execute_batch(Self::SCHEMA_SQL)?;
transaction.execute(
"INSERT INTO metadata (key, value) VALUES ('schema_version', \
?1), ('schema_identifier', ?2)",
params![Self::EXPECTED_VERSION, Self::EXPECTED_IDENTIFIER],
)?;
eprintln!(
"Schema version and identifier metadata initialized \
successfully."
);
} else if let (Some(v), Some(i)) = (version, identifier) {
if v == Self::EXPECTED_VERSION && i == Self::EXPECTED_IDENTIFIER {
eprintln!("Database schema is up to date (version {}).", v);
} else {
eprintln!(
"Found existing schema version {} for app {}. Expected \
version {} for {}.",
v,
i,
Self::EXPECTED_VERSION,
Self::EXPECTED_IDENTIFIER
);
return Err(SqliteError::SqliteFailure(
rusqlite::ffi::Error::new(1), // 1 is SQLITE_ERROR
Some("Schema version mismatch".to_string()),
));
}
}
transaction.commit()?;
Ok(())
}

pub fn queue_operation(&self, sql: String) {
let mut queue = self.operation_queue.lock().unwrap();
queue.push_back(sql);
}

pub fn process_queue(&mut self) -> Result<(), rusqlite::Error> {
let mut queue = self.operation_queue.lock().unwrap();
let tx = self.connection.transaction()?;

while let Some(sql) = queue.pop_front() {
tx.execute(&sql, [])?;
}

tx.commit()?;
Ok(())
}
}
9 changes: 9 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
mod connector;
mod schema;

pub use connector::DatabaseConnector;
pub use schema::{
ConversationId, Exchange, InMemoryDatabase, Message, ModelId,
};

pub use super::PromptRole;
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

use rusqlite;
use serde::{Deserialize, Serialize};

use super::PromptRole;
Expand Down Expand Up @@ -83,7 +82,7 @@ pub struct Attachment {
pub message_id: MessageId,
pub conversation_id: ConversationId,
pub exchange_id: ExchangeId,
pub data: AttachmentData, // file_uri or file_data
pub data: AttachmentData, // file_uri or file_data
pub file_type: String,
pub metadata: Option<serde_json::Value>,
pub created_at: i64,
Expand Down Expand Up @@ -143,7 +142,7 @@ impl InMemoryDatabase {
id: new_id,
name: name.to_string(),
metadata: serde_json::Value::Null,
parent_conversation_id: parent_id,
parent_conversation_id: parent_id,
fork_exchange_id: None,
schema_version: 1,
created_at: 0, // not using timestamps for now, stick with 0 for now
Expand Down Expand Up @@ -180,7 +179,9 @@ impl InMemoryDatabase {
}

pub fn update_message(&mut self, updated_message: Message) {
if let Some(existing_message) = self.messages.get_mut(&updated_message.id) {
if let Some(existing_message) =
self.messages.get_mut(&updated_message.id)
{
*existing_message = updated_message;
}
}
Expand Down Expand Up @@ -273,54 +274,3 @@ impl InMemoryDatabase {
.unwrap_or_default()
}
}

pub struct Database {
in_memory: Arc<Mutex<InMemoryDatabase>>,
sqlite_conn: Arc<Mutex<rusqlite::Connection>>,
}

impl Database {
pub fn new(sqlite_path: &str) -> rusqlite::Result<Self> {
let sqlite_conn = rusqlite::Connection::open(sqlite_path)?;
Ok(Database {
in_memory: Arc::new(Mutex::new(InMemoryDatabase::new())),
sqlite_conn: Arc::new(Mutex::new(sqlite_conn)),
})
}

pub fn save_in_background(&self) {
let in_memory = Arc::clone(&self.in_memory);
let sqlite_conn = Arc::clone(&self.sqlite_conn);

thread::spawn(move || {
let data = in_memory.lock().unwrap();
let conn = sqlite_conn.lock().unwrap();
// saving to SQLite here
});
}

pub fn add_model(&self, model: Model) {
let mut data = self.in_memory.lock().unwrap();
data.add_model(model);
}

pub fn add_conversation(&self, conversation: Conversation) {
let mut data = self.in_memory.lock().unwrap();
data.add_conversation(conversation);
}

pub fn add_exchange(&self, exchange: Exchange) {
let mut data = self.in_memory.lock().unwrap();
data.add_exchange(exchange);
}

pub fn add_message(&self, message: Message) {
let mut data = self.in_memory.lock().unwrap();
data.add_message(message);
}

pub fn add_attachment(&self, attachment: Attachment) {
let mut data = self.in_memory.lock().unwrap();
data.add_attachment(attachment);
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@

PRAGMA foreign_keys = ON;

CREATE TABLE metadata (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);

CREATE TABLE models (
model_id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
Expand Down Expand Up @@ -76,4 +84,3 @@ CREATE INDEX idx_model_id ON exchanges(model_id);
CREATE INDEX idx_conversation_created_at ON exchanges(conversation_id, created_at);
CREATE INDEX idx_attachment_message ON attachments(message_id);

PRAGMA foreign_keys = ON;
Loading

0 comments on commit 892ca49

Please sign in to comment.