Skip to content

Commit

Permalink
wip - update profiles to primary use id, name is now optional and can…
Browse files Browse the repository at this point in the history
… be non-unique, disable profile helper as its redundant
  • Loading branch information
aprxi committed Aug 20, 2024
1 parent 4f3717d commit e042763
Show file tree
Hide file tree
Showing 18 changed files with 739 additions and 523 deletions.
42 changes: 22 additions & 20 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use super::chat::{
PromptInstruction, ThreadedChatSession,
};
use super::cli::{
handle_db_subcommand, handle_profile_subcommand, parse_cli_arguments,
handle_db_subcommand,
handle_profile_subcommand,
parse_cli_arguments,
};
use super::server::{ModelServer, ServerTrait};
use crate::external as lumni;
Expand All @@ -43,25 +45,25 @@ async fn create_prompt_instruction(
let mut profile_handler = db_conn.get_profile_handler(None);

// Handle --profile option
if let Some(profile_name) =
matches.and_then(|m| m.get_one::<String>("profile"))
{
profile_handler.set_profile_name(profile_name.to_string());
} else {
// Use default profile if set
if let Some(default_profile) =
profile_handler.get_default_profile().await?
{
profile_handler.set_profile_name(default_profile);
}
}

// Check if a profile is set
if profile_handler.get_profile_name().is_none() {
return Err(ApplicationError::InvalidInput(
"No profile set".to_string(),
));
}
// if let Some(profile_name) =
// matches.and_then(|m| m.get_one::<String>("profile"))
// {
// profile_handler.set_profile_name(profile_name.to_string());
// } else {
// // Use default profile if set
// if let Some(default_profile) =
// profile_handler.get_default_profile().await?
// {
// profile_handler.set_profile_name(default_profile);
// }
// }
//
// // Check if a profile is set
// if profile_handler.get_profile_name().is_none() {
// return Err(ApplicationError::InvalidInput(
// "No profile set".to_string(),
// ));
// }

// Get model_backend
let model_backend =
Expand Down
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use lumni::Timestamp;
pub use model::{ModelIdentifier, ModelSpec};
use serde::{Deserialize, Serialize};
pub use store::ConversationDatabase;
pub use user_profile::{MaskMode, UserProfileDbHandler};
pub use user_profile::{MaskMode, UserProfile, UserProfileDbHandler};

pub use super::ConversationCache;
use super::{ModelBackend, ModelServer, PromptRole};
Expand Down
3 changes: 2 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ CREATE TABLE metadata (

CREATE TABLE user_profiles (
id INTEGER PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
options TEXT NOT NULL, -- JSON string
is_default INTEGER DEFAULT 0,
encryption_key_id INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (encryption_key_id) REFERENCES encryption_keys(id)
);

Expand Down
6 changes: 3 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tokio::sync::Mutex as TokioMutex;
use super::connector::{DatabaseConnector, DatabaseOperationError};
use super::conversations::ConversationDbHandler;
use super::encryption::EncryptionHandler;
use super::user_profile::UserProfileDbHandler;
use super::user_profile::{UserProfile, UserProfileDbHandler};
use super::{
Conversation, ConversationId, ConversationStatus, Message, MessageId,
ModelIdentifier,
Expand Down Expand Up @@ -59,10 +59,10 @@ impl ConversationDatabase {

pub fn get_profile_handler(
&self,
profile_name: Option<String>,
profile: Option<UserProfile>,
) -> UserProfileDbHandler {
UserProfileDbHandler::new(
profile_name,
profile,
self.db.clone(),
self.encryption_handler.clone(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,24 @@ use std::path::PathBuf;
use lumni::api::error::ApplicationError;
use rusqlite::{params, OptionalExtension};

use super::{DatabaseOperationError, EncryptionHandler, UserProfileDbHandler};
use super::{
DatabaseOperationError, EncryptionHandler, UserProfile,
UserProfileDbHandler,
};
use crate::external as lumni;

impl UserProfileDbHandler {
pub async fn profile_exists(
&self,
profile_name: &str,
profile: &UserProfile,
) -> Result<bool, ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
let count: i64 = tx
.query_row(
"SELECT COUNT(*) FROM user_profiles WHERE name = ?",
params![profile_name],
"SELECT COUNT(*) FROM user_profiles WHERE id = ? AND name \
= ?",
params![profile.id, profile.name],
|row| row.get(0),
)
.map_err(DatabaseOperationError::SqliteError)?;
Expand All @@ -25,34 +29,86 @@ impl UserProfileDbHandler {
.map_err(ApplicationError::from)
}

pub async fn get_profile_by_id(
&self,
id: i64,
) -> Result<Option<UserProfile>, ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
tx.query_row(
"SELECT id, name FROM user_profiles WHERE id = ?",
params![id],
|row| {
Ok(UserProfile {
id: row.get(0)?,
name: row.get(1)?,
})
},
)
.optional()
.map_err(|e| DatabaseOperationError::SqliteError(e))
})
.map_err(ApplicationError::from)
}

pub async fn get_profiles_by_name(
&self,
name: &str,
) -> Result<Vec<UserProfile>, ApplicationError> {
let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
let mut stmt = tx
.prepare("SELECT id, name FROM user_profiles WHERE name = ?")
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
let profiles = stmt
.query_map(params![name], |row| {
Ok(UserProfile {
id: row.get(0)?,
name: row.get(1)?,
})
})
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?
.collect::<Result<Vec<UserProfile>, _>>()
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
Ok(profiles)
})
.map_err(ApplicationError::from)
}

pub async fn delete_profile(
&self,
profile_name: &str,
profile: &UserProfile,
) -> Result<(), ApplicationError> {
let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
tx.execute(
"DELETE FROM user_profiles WHERE name = ?",
params![profile_name],
"DELETE FROM user_profiles WHERE id = ? AND name = ?",
params![profile.id, profile.name],
)
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
Ok(())
})
.map_err(ApplicationError::from)
}

pub async fn list_profiles(&self) -> Result<Vec<String>, ApplicationError> {
pub async fn list_profiles(&self) -> Result<Vec<UserProfile>, ApplicationError> {
let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
let mut stmt = tx
.prepare("SELECT name FROM user_profiles")
.prepare("SELECT id, name FROM user_profiles ORDER BY created_at DESC")
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
let profiles = stmt
.query_map([], |row| row.get(0))
.query_map([], |row| {
Ok(UserProfile {
id: row.get(0)?,
name: row.get(1)?,
})
})
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?
.collect::<Result<Vec<String>, _>>()
.collect::<Result<Vec<UserProfile>, _>>()
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
Ok(profiles)
})
Expand All @@ -61,13 +117,18 @@ impl UserProfileDbHandler {

pub async fn get_default_profile(
&self,
) -> Result<Option<String>, ApplicationError> {
) -> Result<Option<UserProfile>, ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
tx.query_row(
"SELECT name FROM user_profiles WHERE is_default = 1",
"SELECT id, name FROM user_profiles WHERE is_default = 1",
[],
|row| row.get(0),
|row| {
Ok(UserProfile {
id: row.get(0)?,
name: row.get(1)?,
})
},
)
.optional()
.map_err(|e| DatabaseOperationError::SqliteError(e))
Expand All @@ -77,46 +138,51 @@ impl UserProfileDbHandler {

pub async fn set_default_profile(
&self,
profile_name: &str,
profile: &UserProfile,
) -> Result<(), ApplicationError> {
let mut db = self.db.lock().await;
eprintln!("Setting default profile to {}", profile_name);
eprintln!(
"Setting default profile to {} (ID: {})",
profile.name, profile.id
);
db.process_queue_with_result(|tx| {
tx.execute(
"UPDATE user_profiles SET is_default = 0 WHERE is_default = 1",
[],
)
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
tx.execute(
"UPDATE user_profiles SET is_default = 1 WHERE name = ?",
params![profile_name],
"UPDATE user_profiles SET is_default = 1 WHERE id = ? AND \
name = ?",
params![profile.id, profile.name],
)
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
Ok(())
})
.map_err(ApplicationError::from)
}

pub async fn get_profile_list(
&self,
) -> Result<Vec<String>, ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
let mut stmt =
tx.prepare("SELECT name FROM user_profiles ORDER BY name ASC")?;
let profiles = stmt
.query_map([], |row| row.get(0))?
.collect::<Result<Vec<String>, _>>()
.map_err(|e| DatabaseOperationError::SqliteError(e))?;
Ok(profiles)
})
.map_err(|e| match e {
DatabaseOperationError::SqliteError(sqlite_err) => {
ApplicationError::DatabaseError(sqlite_err.to_string())
}
DatabaseOperationError::ApplicationError(app_err) => app_err,
})
}
// TODO: should return Vec<UserProfile>
//pub async fn get_profile_list(
// &self,
//) -> Result<Vec<String>, ApplicationError> {
// let mut db = self.db.lock().await;
// db.process_queue_with_result(|tx| {
// let mut stmt =
// tx.prepare("SELECT name FROM user_profiles ORDER BY name ASC")?;
// let profiles = stmt
// .query_map([], |row| row.get(0))?
// .collect::<Result<Vec<String>, _>>()
// .map_err(|e| DatabaseOperationError::SqliteError(e))?;
// Ok(profiles)
// })
// .map_err(|e| match e {
// DatabaseOperationError::SqliteError(sqlite_err) => {
// ApplicationError::DatabaseError(sqlite_err.to_string())
// }
// DatabaseOperationError::ApplicationError(app_err) => app_err,
// })
//}

pub async fn register_encryption_key(
&self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;

use dirs::home_dir;
use lumni::api::error::{ApplicationError, EncryptionError};
use serde_json::{json, Value as JsonValue};
use sha2::{Digest, Sha256};

use super::{EncryptionHandler, UserProfileDbHandler};
use super::{DatabaseOperationError, EncryptionHandler, UserProfileDbHandler};
use crate::external as lumni;

impl UserProfileDbHandler {
Expand Down Expand Up @@ -117,6 +118,49 @@ impl UserProfileDbHandler {
}
}

pub async fn get_or_create_encryption_key(
&mut self,
profile_name: &str,
) -> Result<i64, ApplicationError> {
let has_encryption_handler = self.encryption_handler.is_some();
let mut created_encryption_handler: Option<EncryptionHandler> = None;

let mut db = self.db.lock().await;
let key_id = db
.process_queue_with_result(|tx| {
if has_encryption_handler {
let key_path = self
.encryption_handler
.as_ref()
.unwrap()
.get_key_path();
let key_hash =
EncryptionHandler::get_private_key_hash(&key_path)?;
self.get_or_insert_encryption_key(tx, &key_path, &key_hash)
} else {
let (new_encryption_handler, key_path, key_hash) =
Self::generate_encryption_key(profile_name)?;
let key_id = self.get_or_insert_encryption_key(
tx, &key_path, &key_hash,
)?;
created_encryption_handler = Some(new_encryption_handler);
Ok(key_id)
}
})
.map_err(|e| match e {
DatabaseOperationError::SqliteError(sqlite_err) => {
ApplicationError::DatabaseError(sqlite_err.to_string())
}
DatabaseOperationError::ApplicationError(app_err) => app_err,
})?;

if let Some(new_handler) = created_encryption_handler {
self.encryption_handler = Some(Arc::new(new_handler));
}

Ok(key_id)
}

pub fn generate_encryption_key(
profile_name: &str,
) -> Result<(EncryptionHandler, PathBuf, String), ApplicationError> {
Expand Down
Loading

0 comments on commit e042763

Please sign in to comment.