Skip to content

Commit

Permalink
cli profile usability improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 9, 2024
1 parent 1cc7adc commit 16b9d2d
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fs;
use std::path::PathBuf;

use base64::engine::general_purpose;
use base64::{Engine, Engine as _};
use base64::Engine;
use lumni::api::error::{ApplicationError, EncryptionError};
use ring::aead;
use ring::rand::{SecureRandom, SystemRandom};
Expand Down
67 changes: 30 additions & 37 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profiles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,40 +90,55 @@ impl UserProfileDbHandler {
new_settings: &JsonValue,
) -> Result<(), ApplicationError> {
let mut db = self.db.lock().await;

db.process_queue_with_result(|tx| {
let current_data: Option<(String, Option<String>)> = tx
// First, check if the profile exists and get its current data and is_default status
let current_data: Option<(String, Option<String>, bool)> = tx
.query_row(
"SELECT options, ssh_key_hash FROM user_profiles WHERE \
name = ?",
"SELECT options, ssh_key_hash, is_default FROM \
user_profiles WHERE name = ?",
params![profile_name],
|row| Ok((row.get(0)?, row.get(1)?)),
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
)
.optional()
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

let (merged_settings, ssh_key_hash) =
if let Some((current_json, existing_hash)) = current_data {
let current: JsonValue =
let (merged_settings, ssh_key_hash, is_default) =
if let Some((current_json, existing_hash, is_default)) =
current_data
{
let mut current: JsonValue =
serde_json::from_str(&current_json).map_err(|e| {
ApplicationError::InvalidInput(format!(
"Invalid JSON: {}",
e
))
})?;
let merged = self.merge_settings(&current, new_settings)?;
(merged, existing_hash)

// Merge settings, handling deletions
if let Some(current_obj) = current.as_object_mut() {
if let Some(new_obj) = new_settings.as_object() {
for (key, value) in new_obj {
if value.is_null() {
current_obj.remove(key); // Remove the key if the new value is null
} else {
current_obj
.insert(key.clone(), value.clone()); // Otherwise, update or add the key-value pair
}
}
}
}

(current, existing_hash, is_default)
} else {
let ssh_key_hash = self
.encryption_handler
.as_ref()
.and_then(|_| self.calculate_ssh_key_hash().ok());
(new_settings.clone(), ssh_key_hash)
(new_settings.clone(), ssh_key_hash, false) // New profiles are not default by default
};

let processed_settings =
self.process_settings(&merged_settings, true, false)?;

let json_string = serde_json::to_string(&processed_settings)
.map_err(|e| {
ApplicationError::InvalidInput(format!(
Expand All @@ -132,10 +147,11 @@ impl UserProfileDbHandler {
))
})?;

// Use INSERT OR REPLACE, but explicitly set the is_default status
tx.execute(
"INSERT OR REPLACE INTO user_profiles (name, options, \
ssh_key_hash) VALUES (?, ?, ?)",
params![profile_name, json_string, ssh_key_hash],
ssh_key_hash, is_default) VALUES (?, ?, ?, ?)",
params![profile_name, json_string, ssh_key_hash, is_default],
)
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

Expand All @@ -144,29 +160,6 @@ impl UserProfileDbHandler {
.map_err(ApplicationError::from)
}

fn merge_settings(
&self,
current: &JsonValue,
new: &JsonValue,
) -> Result<JsonValue, ApplicationError> {
match (current, new) {
(JsonValue::Object(current_obj), JsonValue::Object(new_obj)) => {
let mut merged = current_obj.clone();
for (key, value) in new_obj {
merged.insert(
key.clone(),
self.merge_settings(
current_obj.get(key).unwrap_or(&JsonValue::Null),
value,
)?,
);
}
Ok(JsonValue::Object(merged))
}
(_, new) => Ok(new.clone()),
}
}

fn process_settings(
&self,
value: &JsonValue,
Expand Down
Loading

0 comments on commit 16b9d2d

Please sign in to comment.