Skip to content

Commit

Permalink
ensure that if values are secure, not to allow overwrite with plainte…
Browse files Browse the repository at this point in the history
…xt value
  • Loading branch information
aprxi committed Aug 10, 2024
1 parent 16b9d2d commit 8156fcc
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 86 deletions.
179 changes: 98 additions & 81 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profiles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,52 +91,70 @@ impl UserProfileDbHandler {
) -> Result<(), ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
// First, check if the profile exists and get its current data and is_default status
let current_data: Option<(String, Option<String>, bool)> = tx
let current_data: Option<String> = tx
.query_row(
"SELECT options, ssh_key_hash, is_default FROM \
user_profiles WHERE name = ?",
"SELECT options FROM user_profiles WHERE name = ?",
params![profile_name],
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
|row| row.get(0),
)
.optional()
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

let (merged_settings, ssh_key_hash, is_default) =
if let Some((current_json, existing_hash, is_default)) =
current_data
let merged_settings = if let Some(current_json) = current_data {
let current: JsonValue = serde_json::from_str(&current_json)
.map_err(|e| {
ApplicationError::InvalidInput(format!(
"Invalid JSON: {}",
e
))
})?;

let mut merged = current.clone();
if let (Some(merged_obj), Some(new_obj)) =
(merged.as_object_mut(), new_settings.as_object())
{
let mut current: JsonValue =
serde_json::from_str(&current_json).map_err(|e| {
ApplicationError::InvalidInput(format!(
"Invalid JSON: {}",
e
))
})?;

// Merge settings, handling deletions
if let Some(current_obj) = current.as_object_mut() {
if let Some(new_obj) = new_settings.as_object() {
for (key, value) in new_obj {
if value.is_null() {
current_obj.remove(key); // Remove the key if the new value is null
} else {
current_obj
.insert(key.clone(), value.clone()); // Otherwise, update or add the key-value pair
for (key, new_value) in new_obj {
if new_value.is_null() {
merged_obj.remove(key);
} else {
let current_value = current.get(key);
let is_currently_encrypted = current_value
.map(Self::is_encrypted_value)
.unwrap_or(false);
let is_new_value_marked_for_encryption =
Self::is_marked_for_encryption(new_value);

if is_currently_encrypted {
if is_new_value_marked_for_encryption {
// Update with new secure value
merged_obj
.insert(key.clone(), new_value.clone());
} else if let Some(content) = new_value.as_str()
{
// Re-encrypt the new content
let encrypted =
self.encrypt_value(content)?;
merged_obj.insert(key.clone(), encrypted);
}
// If new_value is not a string, keep the current encrypted value
} else if is_new_value_marked_for_encryption {
// New secure value
merged_obj
.insert(key.clone(), new_value.clone());
} else {
// Non-secure value, update normally
merged_obj
.insert(key.clone(), new_value.clone());
}
}
}
}
merged
} else {
new_settings.clone()
};

(current, existing_hash, is_default)
} else {
let ssh_key_hash = self
.encryption_handler
.as_ref()
.and_then(|_| self.calculate_ssh_key_hash().ok());
(new_settings.clone(), ssh_key_hash, false) // New profiles are not default by default
};

// We use encrypt = true and mask_encrypted = false when storing data
let processed_settings =
self.process_settings(&merged_settings, true, false)?;
let json_string = serde_json::to_string(&processed_settings)
Expand All @@ -147,11 +165,10 @@ impl UserProfileDbHandler {
))
})?;

// Use INSERT OR REPLACE, but explicitly set the is_default status
tx.execute(
"INSERT OR REPLACE INTO user_profiles (name, options, \
ssh_key_hash, is_default) VALUES (?, ?, ?, ?)",
params![profile_name, json_string, ssh_key_hash, is_default],
"INSERT OR REPLACE INTO user_profiles (name, options) VALUES \
(?, ?)",
params![profile_name, json_string],
)
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

Expand All @@ -171,31 +188,28 @@ impl UserProfileDbHandler {
let mut new_obj = Map::new();
for (k, v) in obj {
if encrypt {
if let JsonValue::Object(inner_obj) = v {
if inner_obj.get("secure")
== Some(&JsonValue::Bool(true))
if Self::is_marked_for_encryption(v) {
if let Some(JsonValue::String(content)) =
v.get("content")
{
// Encrypt secure strings
if let Some(JsonValue::String(content)) =
inner_obj.get("value")
{
new_obj.insert(
k.clone(),
self.encrypt_value(
&JsonValue::String(content.clone()),
)?,
);
}
new_obj.insert(
k.clone(),
self.encrypt_value(content)?,
);
} else {
// Don't encrypt regular strings
new_obj.insert(k.clone(), v.clone());
return Err(ApplicationError::InvalidInput(
"Invalid secure string format".to_string(),
));
}
} else if !Self::is_encrypted_value(v) {
// This is a non-secure value, keep it as is
new_obj.insert(k.clone(), v.clone());
} else {
// Don't encrypt non-string values
// This is already an encrypted value, keep it as is
new_obj.insert(k.clone(), v.clone());
}
} else {
// During get operation, handle decryption and masking
// When not encrypting (i.e., retrieving), decrypt if necessary
if Self::is_encrypted_value(v) {
if mask_encrypted {
new_obj.insert(
Expand Down Expand Up @@ -226,33 +240,25 @@ impl UserProfileDbHandler {

fn encrypt_value(
&self,
value: &JsonValue,
content: &str,
) -> Result<JsonValue, ApplicationError> {
if let Some(ref encryption_handler) = self.encryption_handler {
if let JsonValue::String(content) = value {
let (encrypted_content, encryption_key) = encryption_handler
.encrypt_string(content)
.map_err(|e| {
ApplicationError::EncryptionError(
EncryptionError::Other(Box::new(e)),
)
})?;
let (encrypted_content, encryption_key) =
encryption_handler.encrypt_string(content).map_err(|e| {
ApplicationError::EncryptionError(EncryptionError::Other(
Box::new(e),
))
})?;

Ok(JsonValue::Object(Map::from_iter(vec![
(
"content".to_string(),
JsonValue::String(encrypted_content),
),
(
"encryption_key".to_string(),
JsonValue::String(encryption_key),
),
])))
} else {
Ok(value.clone())
}
Ok(JsonValue::Object(Map::from_iter(vec![
("content".to_string(), JsonValue::String(encrypted_content)),
(
"encryption_key".to_string(),
JsonValue::String(encryption_key),
),
])))
} else {
Ok(value.clone())
Ok(JsonValue::String(content.to_string()))
}
}

Expand All @@ -261,7 +267,7 @@ impl UserProfileDbHandler {
value: &JsonValue,
) -> Result<JsonValue, ApplicationError> {
if let Some(ref encryption_handler) = self.encryption_handler {
if let JsonValue::Object(obj) = value {
if let Some(obj) = value.as_object() {
if let (
Some(JsonValue::String(content)),
Some(JsonValue::String(encrypted_key)),
Expand All @@ -282,12 +288,23 @@ impl UserProfileDbHandler {
}

fn is_encrypted_value(value: &JsonValue) -> bool {
if let JsonValue::Object(obj) = value {
if let Some(obj) = value.as_object() {
obj.contains_key("content") && obj.contains_key("encryption_key")
} else {
false
}
}

fn is_marked_for_encryption(value: &JsonValue) -> bool {
if let Some(obj) = value.as_object() {
obj.contains_key("content")
&& obj.get("encryption_key")
== Some(&JsonValue::String("".to_string()))
} else {
false
}
}

pub async fn delete_profile(
&self,
profile_name: &str,
Expand Down
18 changes: 13 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,13 @@ pub async fn handle_profile_subcommand(
settings[key.to_string()] =
JsonValue::Object(Map::from_iter(vec![
(
"value".to_string(),
"content".to_string(),
JsonValue::String(value.to_string()),
),
("secure".to_string(), JsonValue::Bool(true)),
(
"encryption_key".to_string(),
JsonValue::String("".to_string()),
),
]));
} else {
settings[key.to_string()] =
Expand Down Expand Up @@ -231,10 +234,15 @@ pub async fn handle_profile_subcommand(
}

Some(("show-default", show_default_matches)) => {
if let Some(default_profile) = db_handler.get_default_profile().await? {
if let Some(default_profile) =
db_handler.get_default_profile().await?
{
println!("Default profile: {}", default_profile);
let show_decrypted = show_default_matches.get_flag("show-decrypted");
let settings = db_handler.get_profile_settings(&default_profile, !show_decrypted).await?;
let show_decrypted =
show_default_matches.get_flag("show-decrypted");
let settings = db_handler
.get_profile_settings(&default_profile, !show_decrypted)
.await?;
println!("Settings:");
for (key, value) in settings.as_object().unwrap() {
println!(" {}: {}", key, value);
Expand Down

0 comments on commit 8156fcc

Please sign in to comment.