Skip to content

Commit

Permalink
wip - refactor profile/ provider II
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 30, 2024
1 parent 0ddd980 commit 861e8b8
Show file tree
Hide file tree
Showing 24 changed files with 2,515 additions and 1,744 deletions.
9 changes: 5 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ pub use lumni::Timestamp;
pub use model::{ModelIdentifier, ModelSpec};
use serde::{Deserialize, Serialize};
pub use store::ConversationDatabase;
pub use user_profile::{MaskMode, UserProfile, UserProfileDbHandler};
pub use user_profile::{
MaskMode, ProviderConfig, ProviderConfigOptions, UserProfile,
UserProfileDbHandler,
};

pub use super::ConversationCache;
use super::{
AdditionalSetting, ModelBackend, ModelServer, PromptRole, ProviderConfig,
};
use super::{ModelBackend, ModelServer, PromptRole};
use crate::external as lumni;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down
2 changes: 0 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ CREATE TABLE user_profiles (
options TEXT NOT NULL, -- JSON string
is_default INTEGER DEFAULT 0,
encryption_key_id INTEGER NOT NULL,
provider_config_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (encryption_key_id) REFERENCES encryption_keys(id)
FOREIGN KEY (provider_config_id) REFERENCES provider_configs(id)
);

CREATE TABLE provider_configs (
Expand Down
27 changes: 23 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ mod database_operations;
mod encryption_operations;
mod profile_operations;
mod provider_config;

use std::collections::HashMap;
use std::sync::Arc;

use lumni::api::error::{ApplicationError, EncryptionError};
use rusqlite::{params, OptionalExtension};
use serde_json::{json, Value as JsonValue};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use tokio::sync::Mutex as TokioMutex;

use super::connector::{DatabaseConnector, DatabaseOperationError};
use super::encryption::EncryptionHandler;
use super::{
AdditionalSetting, ModelBackend, ModelServer, ModelSpec, ProviderConfig,
};
use super::{ModelBackend, ModelServer, ModelSpec};
use crate::external as lumni;

#[derive(Debug, Clone, PartialEq)]
Expand All @@ -30,6 +31,24 @@ pub struct UserProfileDbHandler {
encryption_handler: Option<Arc<EncryptionHandler>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
pub id: Option<usize>,
pub name: String,
pub provider_type: String,
pub model_identifier: Option<String>,
pub additional_settings: HashMap<String, ProviderConfigOptions>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfigOptions {
pub name: String,
pub display_name: String,
pub value: String,
pub is_secure: bool,
pub placeholder: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum EncryptionMode {
Encrypt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,37 @@ impl UserProfileDbHandler {
pub async fn save_provider_config(
&mut self,
config: &ProviderConfig,
) -> Result<(), ApplicationError> {
) -> Result<ProviderConfig, ApplicationError> {
let encryption_key_id = self.get_or_create_encryption_key().await?;

let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
let additional_settings: HashMap<String, String> = config
.additional_settings
.iter()
.map(|(k, v)| (k.clone(), v.value.clone()))
.collect();

let additional_settings_json =
serde_json::to_string(&additional_settings).map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(format!(
"Failed to serialize additional settings: {}",
e
)),
)
})?;
let additional_settings_json = serde_json::to_string(
&config.additional_settings,
)
.map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(format!(
"Failed to serialize additional settings: {}",
e
)),
)
})?;

// Use encrypt_value method to encrypt the settings
let encrypted_value = self
.encrypt_value(&JsonValue::String(additional_settings_json))
.map_err(DatabaseOperationError::ApplicationError)?;

// Extract the encrypted content
let encrypted_settings =
encrypted_value["content"].as_str().ok_or_else(|| {
let encrypted_value_json = serde_json::to_string(&encrypted_value)
.map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(
"Failed to extract encrypted content".to_string(),
),
ApplicationError::InvalidInput(format!(
"Failed to serialize encrypted value: {}",
e
)),
)
})?;

if let Some(id) = config.id {
let config_id = if let Some(id) = config.id {
// Update existing config
tx.execute(
"UPDATE provider_configs SET
Expand All @@ -55,10 +48,11 @@ impl UserProfileDbHandler {
config.name,
config.provider_type,
config.model_identifier,
encrypted_settings,
encrypted_value_json,
id as i64
],
)?;
id
} else {
// Insert new config
tx.execute(
Expand All @@ -70,110 +64,104 @@ impl UserProfileDbHandler {
config.name,
config.provider_type,
config.model_identifier,
encrypted_settings,
encrypted_value_json,
encryption_key_id
],
)?;
}

Ok(())
tx.last_insert_rowid() as usize
};

Ok(ProviderConfig {
id: Some(config_id),
name: config.name.clone(),
provider_type: config.provider_type.clone(),
model_identifier: config.model_identifier.clone(),
additional_settings: config.additional_settings.clone(),
})
})
.map_err(ApplicationError::from)
}
}

impl UserProfileDbHandler {
pub async fn load_provider_configs(
&self,
) -> Result<Vec<ProviderConfig>, ApplicationError> {
let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
let mut stmt = tx.prepare(
"SELECT pc.id, pc.name, pc.provider_type, \
pc.model_identifier, pc.additional_settings,
ek.file_path as encryption_key_path
"SELECT pc.id, pc.name, pc.provider_type,
pc.model_identifier, pc.additional_settings,
ek.file_path as encryption_key_path, ek.sha256_hash
FROM provider_configs pc
JOIN encryption_keys ek ON pc.encryption_key_id = ek.id",
)?;

let configs = stmt.query_map([], |row| {
let id: i64 = row.get(0)?;
let name: String = row.get(1)?;
let provider_type: String = row.get(2)?;
let model_identifier: Option<String> = row.get(3)?;
let additional_settings_encrypted: String = row.get(4)?;
let encrypted_value_json: String = row.get(4)?;
let encryption_key_path: String = row.get(5)?;

let sha256_hash: String = row.get(6)?;
Ok((
id,
name,
provider_type,
model_identifier,
additional_settings_encrypted,
encrypted_value_json,
encryption_key_path,
sha256_hash,
))
})?;

let mut result = Vec::new();

for config in configs {
let (
id,
name,
provider_type,
model_identifier,
additional_settings_encrypted,
encrypted_value_json,
encryption_key_path,
sha256_hash,
) = config?;

// Load the specific encryption handler for this config
let encryption_handler = EncryptionHandler::new_from_path(
&PathBuf::from(encryption_key_path),
)
.map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::EncryptionError(
EncryptionError::KeyGenerationFailed(e.to_string()),
),
)
})?
.ok_or_else(|| {
DatabaseOperationError::ApplicationError(
ApplicationError::EncryptionError(
EncryptionError::InvalidKey(
"Failed to create encryption handler"
.to_string(),
),
),
)
})?;

// Decrypt the additional settings using the specific encryption handler
let decrypted_value = encryption_handler
.decrypt_string(
&additional_settings_encrypted,
"", // The actual key should be retrieved from the encryption handler
)
.map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::EncryptionError(
EncryptionError::DecryptionFailed(
e.to_string(),
),
),
)
})?;

let additional_settings: HashMap<String, AdditionalSetting> =
serde_json::from_str(&decrypted_value).map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(format!(
"Failed to deserialize decrypted settings: {}",
e
)),
)
})?;
// Create a new EncryptionHandler for this config
let encryption_handler = EncryptionHandler::new_from_path(&PathBuf::from(&encryption_key_path))
.map_err(|e| DatabaseOperationError::ApplicationError(e))?
.ok_or_else(|| DatabaseOperationError::ApplicationError(ApplicationError::EncryptionError(
EncryptionError::InvalidKey("Failed to create encryption handler".to_string())
)))?;

// Verify the encryption key hash
if encryption_handler.get_sha256_hash() != sha256_hash {
return Err(DatabaseOperationError::ApplicationError(ApplicationError::EncryptionError(
EncryptionError::InvalidKey("Encryption key hash mismatch".to_string())
)));
}

let encrypted_value: JsonValue = serde_json::from_str(&encrypted_value_json)
.map_err(|e| DatabaseOperationError::ApplicationError(ApplicationError::InvalidInput(
format!("Failed to deserialize encrypted value: {}", e)
)))?;

// Use the encryption handler to decrypt the value
let decrypted_value = if let (Some(content), Some(encryption_key)) = (
encrypted_value["content"].as_str(),
encrypted_value["encryption_key"].as_str()
) {
encryption_handler.decrypt_string(content, encryption_key)
.map_err(|e| DatabaseOperationError::ApplicationError(e))?
} else {
return Err(DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput("Invalid encrypted value format".to_string())
));
};

let additional_settings: HashMap<String, ProviderConfigOptions> =
serde_json::from_str(&decrypted_value)
.map_err(|e| DatabaseOperationError::ApplicationError(ApplicationError::InvalidInput(
format!("Failed to deserialize decrypted settings: {}", e)
)))?;

result.push(ProviderConfig {
id: Some(id as usize),
Expand All @@ -183,9 +171,33 @@ impl UserProfileDbHandler {
additional_settings,
});
}

Ok(result)
})
.map_err(ApplicationError::from)
}

pub async fn delete_provider_config(
&self,
config_id: usize,
) -> Result<(), ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
let deleted_rows = tx.execute(
"DELETE FROM provider_configs WHERE id = ?",
params![config_id as i64],
)?;

if deleted_rows == 0 {
Err(DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(format!(
"No provider config found with ID {}",
config_id
)),
))
} else {
Ok(())
}
})
.map_err(ApplicationError::from)
}
}
10 changes: 5 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod session;
pub use completion_options::ChatCompletionOptions;
pub use conversation::{ConversationCache, NewConversation, PromptInstruction};
use prompt::Prompt;
pub use prompt::{AssistantManager, PromptInstructionBuilder, PromptRole};
pub use prompt::{PromptInstructionBuilder, PromptRole};
pub use session::{prompt_app, App, ChatEvent, ThreadedChatSession};

pub use super::defaults::*;
Expand All @@ -17,10 +17,10 @@ use super::server::{
CompletionResponse, ModelBackend, ModelServer, ServerManager,
};
use super::tui::{
draw_ui, AdditionalSetting, AppUi, ColorScheme, ColorSchemeType,
CommandLineAction, ConversationEvent, KeyEventHandler, ModalAction,
ModalWindowType, PromptAction, ProviderConfig, SimpleString, TextLine,
TextSegment, TextWindowTrait, UserEvent, WindowEvent, WindowKind,
draw_ui, AppUi, ColorScheme, ColorSchemeType, CommandLineAction,
ConversationEvent, KeyEventHandler, ModalAction, ModalWindowType,
PromptAction, SimpleString, TextLine, TextSegment, TextWindowTrait,
UserEvent, WindowEvent, WindowKind,
};

// gets PERSONAS from the generated code
Expand Down
Loading

0 comments on commit 861e8b8

Please sign in to comment.