Skip to content

Commit

Permalink
wip - refactor profile/ provider edit
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 29, 2024
1 parent 02143f1 commit 0ddd980
Show file tree
Hide file tree
Showing 16 changed files with 1,721 additions and 1,804 deletions.
4 changes: 3 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ pub use store::ConversationDatabase;
pub use user_profile::{MaskMode, UserProfile, UserProfileDbHandler};

pub use super::ConversationCache;
use super::{ModelBackend, ModelServer, PromptRole};
use super::{
AdditionalSetting, ModelBackend, ModelServer, PromptRole, ProviderConfig,
};
use crate::external as lumni;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down
8 changes: 8 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/model.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::{self, Display, Formatter};

use lazy_static::lazy_static;
use lumni::api::error::ApplicationError;
use regex::Regex;
Expand All @@ -14,6 +16,12 @@ lazy_static! {
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelIdentifier(pub String);

impl Display for ModelIdentifier {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}

impl ModelIdentifier {
pub fn new(identifier_str: &str) -> Result<Self, ApplicationError> {
if IDENTIFIER_REGEX.is_match(identifier_str) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ impl UserProfileDbHandler {
}
}

fn process_value(
pub fn process_value(
&self,
value: &JsonValue,
encryption_mode: EncryptionMode,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;

use dirs::home_dir;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ mod content_operations;
mod database_operations;
mod encryption_operations;
mod profile_operations;
mod provider_config;
use std::sync::Arc;

use lumni::api::error::ApplicationError;
use lumni::api::error::{ApplicationError, EncryptionError};
use rusqlite::{params, OptionalExtension};
use serde_json::{json, Value as JsonValue};
use tokio::sync::Mutex as TokioMutex;

use super::connector::{DatabaseConnector, DatabaseOperationError};
use super::encryption::EncryptionHandler;
use super::{ModelBackend, ModelServer, ModelSpec};
use super::{
AdditionalSetting, ModelBackend, ModelServer, ModelSpec, ProviderConfig,
};
use crate::external as lumni;

#[derive(Debug, Clone, PartialEq)]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
use std::collections::HashMap;
use std::path::PathBuf;

use super::*;

impl UserProfileDbHandler {
pub async fn save_provider_config(
&mut self,
config: &ProviderConfig,
) -> Result<(), ApplicationError> {
let encryption_key_id = self.get_or_create_encryption_key().await?;

let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
let additional_settings: HashMap<String, String> = config
.additional_settings
.iter()
.map(|(k, v)| (k.clone(), v.value.clone()))
.collect();

let additional_settings_json =
serde_json::to_string(&additional_settings).map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(format!(
"Failed to serialize additional settings: {}",
e
)),
)
})?;

// Use encrypt_value method to encrypt the settings
let encrypted_value = self
.encrypt_value(&JsonValue::String(additional_settings_json))
.map_err(DatabaseOperationError::ApplicationError)?;

// Extract the encrypted content
let encrypted_settings =
encrypted_value["content"].as_str().ok_or_else(|| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(
"Failed to extract encrypted content".to_string(),
),
)
})?;

if let Some(id) = config.id {
// Update existing config
tx.execute(
"UPDATE provider_configs SET
name = ?, provider_type = ?, model_identifier = ?,
additional_settings = ?
WHERE id = ?",
params![
config.name,
config.provider_type,
config.model_identifier,
encrypted_settings,
id as i64
],
)?;
} else {
// Insert new config
tx.execute(
"INSERT INTO provider_configs
(name, provider_type, model_identifier,
additional_settings, encryption_key_id)
VALUES (?, ?, ?, ?, ?)",
params![
config.name,
config.provider_type,
config.model_identifier,
encrypted_settings,
encryption_key_id
],
)?;
}

Ok(())
})
.map_err(ApplicationError::from)
}
}

impl UserProfileDbHandler {
pub async fn load_provider_configs(
&self,
) -> Result<Vec<ProviderConfig>, ApplicationError> {
let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
let mut stmt = tx.prepare(
"SELECT pc.id, pc.name, pc.provider_type, \
pc.model_identifier, pc.additional_settings,
ek.file_path as encryption_key_path
FROM provider_configs pc
JOIN encryption_keys ek ON pc.encryption_key_id = ek.id",
)?;

let configs = stmt.query_map([], |row| {
let id: i64 = row.get(0)?;
let name: String = row.get(1)?;
let provider_type: String = row.get(2)?;
let model_identifier: Option<String> = row.get(3)?;
let additional_settings_encrypted: String = row.get(4)?;
let encryption_key_path: String = row.get(5)?;

Ok((
id,
name,
provider_type,
model_identifier,
additional_settings_encrypted,
encryption_key_path,
))
})?;

let mut result = Vec::new();

for config in configs {
let (
id,
name,
provider_type,
model_identifier,
additional_settings_encrypted,
encryption_key_path,
) = config?;

// Load the specific encryption handler for this config
let encryption_handler = EncryptionHandler::new_from_path(
&PathBuf::from(encryption_key_path),
)
.map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::EncryptionError(
EncryptionError::KeyGenerationFailed(e.to_string()),
),
)
})?
.ok_or_else(|| {
DatabaseOperationError::ApplicationError(
ApplicationError::EncryptionError(
EncryptionError::InvalidKey(
"Failed to create encryption handler"
.to_string(),
),
),
)
})?;

// Decrypt the additional settings using the specific encryption handler
let decrypted_value = encryption_handler
.decrypt_string(
&additional_settings_encrypted,
"", // The actual key should be retrieved from the encryption handler
)
.map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::EncryptionError(
EncryptionError::DecryptionFailed(
e.to_string(),
),
),
)
})?;

let additional_settings: HashMap<String, AdditionalSetting> =
serde_json::from_str(&decrypted_value).map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(format!(
"Failed to deserialize decrypted settings: {}",
e
)),
)
})?;

result.push(ProviderConfig {
id: Some(id as usize),
name,
provider_type,
model_identifier,
additional_settings,
});
}

Ok(result)
})
.map_err(ApplicationError::from)
}
}
8 changes: 4 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use super::server::{
CompletionResponse, ModelBackend, ModelServer, ServerManager,
};
use super::tui::{
draw_ui, AppUi, ColorScheme, ColorSchemeType, CommandLineAction,
ConversationEvent, KeyEventHandler, ModalAction, ModalWindowType,
PromptAction, SimpleString, TextLine, TextSegment, TextWindowTrait,
UserEvent, WindowEvent, WindowKind,
draw_ui, AdditionalSetting, AppUi, ColorScheme, ColorSchemeType,
CommandLineAction, ConversationEvent, KeyEventHandler, ModalAction,
ModalWindowType, PromptAction, ProviderConfig, SimpleString, TextLine,
TextSegment, TextWindowTrait, UserEvent, WindowEvent, WindowKind,
};

// gets PERSONAS from the generated code
Expand Down
8 changes: 6 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/tui/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ pub use events::{
PromptAction, UserEvent, WindowEvent,
};
use lumni::api::error::ApplicationError;
pub use modals::{ModalAction, ModalWindowTrait, ModalWindowType};
pub use modals::{
AdditionalSetting, ModalAction, ModalWindowTrait, ModalWindowType,
ProviderConfig,
};
pub use ui::AppUi;
pub use window::{
CommandLine, ResponseWindow, SimpleString, TextArea, TextLine, TextSegment,
Expand All @@ -23,7 +26,8 @@ pub use window::{

use super::chat::db::{
Conversation, ConversationDatabase, ConversationDbHandler, ConversationId,
ConversationStatus, MaskMode, UserProfile, UserProfileDbHandler,
ConversationStatus, MaskMode, ModelIdentifier, ModelSpec, UserProfile,
UserProfileDbHandler,
};
use super::chat::{
App, NewConversation, PromptInstruction, ThreadedChatSession,
Expand Down
9 changes: 5 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/tui/modals/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ mod profiles;
use async_trait::async_trait;
pub use conversations::ConversationListModal;
pub use filebrowser::FileBrowserModal;
pub use profiles::ProfileEditModal;
pub use profiles::{AdditionalSetting, ProfileEditModal, ProviderConfig};
use ratatui::layout::Rect;
use ratatui::Frame;
use widgets::FileBrowserWidget;

pub use super::widgets;
use super::{
ApplicationError, Conversation, ConversationDbHandler, ConversationStatus,
KeyTrack, MaskMode, ModelServer, PromptInstruction, ServerTrait,
SimpleString, TextArea, TextWindowTrait, ThreadedChatSession, UserEvent,
UserProfile, UserProfileDbHandler, WindowEvent, SUPPORTED_MODEL_ENDPOINTS,
KeyTrack, MaskMode, ModelIdentifier, ModelServer, ModelSpec,
PromptInstruction, ServerTrait, SimpleString, TextArea, TextWindowTrait,
ThreadedChatSession, UserEvent, UserProfile, UserProfileDbHandler,
WindowEvent, SUPPORTED_MODEL_ENDPOINTS,
};

#[derive(Debug, Clone, PartialEq)]
Expand Down
Loading

0 comments on commit 0ddd980

Please sign in to comment.