From 8156fcc82cd97353d07cc44b8eada53135ff494a Mon Sep 17 00:00:00 2001 From: Anthony Potappel Date: Sat, 10 Aug 2024 05:54:42 +0200 Subject: [PATCH] ensure that if values are secure, not to allow overwrite with plaintext value --- .../prompt/src/chat/db/user_profiles/mod.rs | 179 ++++++++++-------- .../llm/prompt/src/cli/subcommands/profile.rs | 18 +- 2 files changed, 111 insertions(+), 86 deletions(-) diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profiles/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profiles/mod.rs index bab7504..39a9a2e 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profiles/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profiles/mod.rs @@ -91,52 +91,70 @@ impl UserProfileDbHandler { ) -> Result<(), ApplicationError> { let mut db = self.db.lock().await; db.process_queue_with_result(|tx| { - // First, check if the profile exists and get its current data and is_default status - let current_data: Option<(String, Option, bool)> = tx + let current_data: Option = tx .query_row( - "SELECT options, ssh_key_hash, is_default FROM \ - user_profiles WHERE name = ?", + "SELECT options FROM user_profiles WHERE name = ?", params![profile_name], - |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)), + |row| row.get(0), ) .optional() .map_err(|e| ApplicationError::DatabaseError(e.to_string()))?; - let (merged_settings, ssh_key_hash, is_default) = - if let Some((current_json, existing_hash, is_default)) = - current_data + let merged_settings = if let Some(current_json) = current_data { + let current: JsonValue = serde_json::from_str(¤t_json) + .map_err(|e| { + ApplicationError::InvalidInput(format!( + "Invalid JSON: {}", + e + )) + })?; + + let mut merged = current.clone(); + if let (Some(merged_obj), Some(new_obj)) = + (merged.as_object_mut(), new_settings.as_object()) { - let mut current: JsonValue = - serde_json::from_str(¤t_json).map_err(|e| { - ApplicationError::InvalidInput(format!( - "Invalid JSON: {}", - e - )) - })?; - - // Merge settings, handling deletions - if let Some(current_obj) = current.as_object_mut() { - if let Some(new_obj) = new_settings.as_object() { - for (key, value) in new_obj { - if value.is_null() { - current_obj.remove(key); // Remove the key if the new value is null - } else { - current_obj - .insert(key.clone(), value.clone()); // Otherwise, update or add the key-value pair + for (key, new_value) in new_obj { + if new_value.is_null() { + merged_obj.remove(key); + } else { + let current_value = current.get(key); + let is_currently_encrypted = current_value + .map(Self::is_encrypted_value) + .unwrap_or(false); + let is_new_value_marked_for_encryption = + Self::is_marked_for_encryption(new_value); + + if is_currently_encrypted { + if is_new_value_marked_for_encryption { + // Update with new secure value + merged_obj + .insert(key.clone(), new_value.clone()); + } else if let Some(content) = new_value.as_str() + { + // Re-encrypt the new content + let encrypted = + self.encrypt_value(content)?; + merged_obj.insert(key.clone(), encrypted); } + // If new_value is not a string, keep the current encrypted value + } else if is_new_value_marked_for_encryption { + // New secure value + merged_obj + .insert(key.clone(), new_value.clone()); + } else { + // Non-secure value, update normally + merged_obj + .insert(key.clone(), new_value.clone()); } } } + } + merged + } else { + new_settings.clone() + }; - (current, existing_hash, is_default) - } else { - let ssh_key_hash = self - .encryption_handler - .as_ref() - .and_then(|_| self.calculate_ssh_key_hash().ok()); - (new_settings.clone(), ssh_key_hash, false) // New profiles are not default by default - }; - + // We use encrypt = true and mask_encrypted = false when storing data let processed_settings = self.process_settings(&merged_settings, true, false)?; let json_string = serde_json::to_string(&processed_settings) @@ -147,11 +165,10 @@ impl UserProfileDbHandler { )) })?; - // Use INSERT OR REPLACE, but explicitly set the is_default status tx.execute( - "INSERT OR REPLACE INTO user_profiles (name, options, \ - ssh_key_hash, is_default) VALUES (?, ?, ?, ?)", - params![profile_name, json_string, ssh_key_hash, is_default], + "INSERT OR REPLACE INTO user_profiles (name, options) VALUES \ + (?, ?)", + params![profile_name, json_string], ) .map_err(|e| ApplicationError::DatabaseError(e.to_string()))?; @@ -171,31 +188,28 @@ impl UserProfileDbHandler { let mut new_obj = Map::new(); for (k, v) in obj { if encrypt { - if let JsonValue::Object(inner_obj) = v { - if inner_obj.get("secure") - == Some(&JsonValue::Bool(true)) + if Self::is_marked_for_encryption(v) { + if let Some(JsonValue::String(content)) = + v.get("content") { - // Encrypt secure strings - if let Some(JsonValue::String(content)) = - inner_obj.get("value") - { - new_obj.insert( - k.clone(), - self.encrypt_value( - &JsonValue::String(content.clone()), - )?, - ); - } + new_obj.insert( + k.clone(), + self.encrypt_value(content)?, + ); } else { - // Don't encrypt regular strings - new_obj.insert(k.clone(), v.clone()); + return Err(ApplicationError::InvalidInput( + "Invalid secure string format".to_string(), + )); } + } else if !Self::is_encrypted_value(v) { + // This is a non-secure value, keep it as is + new_obj.insert(k.clone(), v.clone()); } else { - // Don't encrypt non-string values + // This is already an encrypted value, keep it as is new_obj.insert(k.clone(), v.clone()); } } else { - // During get operation, handle decryption and masking + // When not encrypting (i.e., retrieving), decrypt if necessary if Self::is_encrypted_value(v) { if mask_encrypted { new_obj.insert( @@ -226,33 +240,25 @@ impl UserProfileDbHandler { fn encrypt_value( &self, - value: &JsonValue, + content: &str, ) -> Result { if let Some(ref encryption_handler) = self.encryption_handler { - if let JsonValue::String(content) = value { - let (encrypted_content, encryption_key) = encryption_handler - .encrypt_string(content) - .map_err(|e| { - ApplicationError::EncryptionError( - EncryptionError::Other(Box::new(e)), - ) - })?; + let (encrypted_content, encryption_key) = + encryption_handler.encrypt_string(content).map_err(|e| { + ApplicationError::EncryptionError(EncryptionError::Other( + Box::new(e), + )) + })?; - Ok(JsonValue::Object(Map::from_iter(vec![ - ( - "content".to_string(), - JsonValue::String(encrypted_content), - ), - ( - "encryption_key".to_string(), - JsonValue::String(encryption_key), - ), - ]))) - } else { - Ok(value.clone()) - } + Ok(JsonValue::Object(Map::from_iter(vec![ + ("content".to_string(), JsonValue::String(encrypted_content)), + ( + "encryption_key".to_string(), + JsonValue::String(encryption_key), + ), + ]))) } else { - Ok(value.clone()) + Ok(JsonValue::String(content.to_string())) } } @@ -261,7 +267,7 @@ impl UserProfileDbHandler { value: &JsonValue, ) -> Result { if let Some(ref encryption_handler) = self.encryption_handler { - if let JsonValue::Object(obj) = value { + if let Some(obj) = value.as_object() { if let ( Some(JsonValue::String(content)), Some(JsonValue::String(encrypted_key)), @@ -282,12 +288,23 @@ impl UserProfileDbHandler { } fn is_encrypted_value(value: &JsonValue) -> bool { - if let JsonValue::Object(obj) = value { + if let Some(obj) = value.as_object() { obj.contains_key("content") && obj.contains_key("encryption_key") } else { false } } + + fn is_marked_for_encryption(value: &JsonValue) -> bool { + if let Some(obj) = value.as_object() { + obj.contains_key("content") + && obj.get("encryption_key") + == Some(&JsonValue::String("".to_string())) + } else { + false + } + } + pub async fn delete_profile( &self, profile_name: &str, diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs index 21bc3f9..4017c5a 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs @@ -145,10 +145,13 @@ pub async fn handle_profile_subcommand( settings[key.to_string()] = JsonValue::Object(Map::from_iter(vec![ ( - "value".to_string(), + "content".to_string(), JsonValue::String(value.to_string()), ), - ("secure".to_string(), JsonValue::Bool(true)), + ( + "encryption_key".to_string(), + JsonValue::String("".to_string()), + ), ])); } else { settings[key.to_string()] = @@ -231,10 +234,15 @@ pub async fn handle_profile_subcommand( } Some(("show-default", show_default_matches)) => { - if let Some(default_profile) = db_handler.get_default_profile().await? { + if let Some(default_profile) = + db_handler.get_default_profile().await? + { println!("Default profile: {}", default_profile); - let show_decrypted = show_default_matches.get_flag("show-decrypted"); - let settings = db_handler.get_profile_settings(&default_profile, !show_decrypted).await?; + let show_decrypted = + show_default_matches.get_flag("show-decrypted"); + let settings = db_handler + .get_profile_settings(&default_profile, !show_decrypted) + .await?; println!("Settings:"); for (key, value) in settings.as_object().unwrap() { println!(" {}: {}", key, value);