Skip to content

Commit

Permalink
add keyhash check
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 10, 2024
1 parent 8d533a7 commit 2ae3ae2
Showing 1 changed file with 149 additions and 87 deletions.
236 changes: 149 additions & 87 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profiles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,45 +178,36 @@ impl UserProfileDbHandler {
.map_err(ApplicationError::from)
}

fn calculate_ssh_key_hash(&self) -> Result<String, ApplicationError> {
if let Some(ref encryption_handler) = self.encryption_handler {
let ssh_private_key = encryption_handler
.get_ssh_private_key()
.map_err(|e| ApplicationError::EncryptionError(e.into()))?;
let mut hasher = Sha256::new();
hasher.update(ssh_private_key);
let result = hasher.finalize();
Ok(general_purpose::STANDARD.encode(result))
} else {
Err(ApplicationError::NotReady(
"No encryption handler available".to_string(),
))
}
}

fn verify_ssh_key_hash(
&self,
stored_hash: &str,
) -> Result<(), ApplicationError> {
let current_hash = self.calculate_ssh_key_hash()?;
if current_hash != stored_hash {
return Err(ApplicationError::InvalidInput(
"SSH key hash mismatch".to_string(),
));
}
Ok(())
}

pub async fn get_profile_settings(
&self,
profile_name: &str,
mask_encrypted: bool,
) -> Result<JsonValue, ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
self.fetch_and_process_settings(tx, profile_name, mask_encrypted)
})
.map_err(|e| e.into())
let (json_string, ssh_key_hash): (String, Option<String>) = tx
.query_row(
"SELECT options, ssh_key_hash FROM user_profiles WHERE \
name = ?",
params![profile_name],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.map_err(DatabaseOperationError::SqliteError)?;

self.verify_ssh_key_hash(ssh_key_hash.as_deref())?;

let settings: JsonValue = serde_json::from_str(&json_string)
.map_err(|e| {
DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(format!(
"Invalid JSON: {}",
e
)),
)
})?;

Ok(self.process_settings(&settings, false, mask_encrypted))
})?
}

fn fetch_and_process_settings(
Expand All @@ -230,7 +221,7 @@ impl UserProfileDbHandler {
let settings: JsonValue = self.parse_json(&json_string)?;

if let Some(hash) = ssh_key_hash {
self.verify_ssh_key_hash(&hash)
self.verify_ssh_key_hash(Some(&hash))
.map_err(DatabaseOperationError::ApplicationError)?;
}

Expand Down Expand Up @@ -270,15 +261,39 @@ impl UserProfileDbHandler {
) -> Result<JsonValue, ApplicationError> {
match value {
JsonValue::Object(obj) => {
self.process_object(obj, encrypt, mask_encrypted)
let mut new_obj = Map::new();
for (k, v) in obj {
new_obj.insert(
k.clone(),
self.process_value(v, encrypt, mask_encrypted)?,
);
}
Ok(JsonValue::Object(new_obj))
}
JsonValue::Array(arr) => {
self.process_array(arr, encrypt, mask_encrypted)
let new_arr: Result<Vec<JsonValue>, _> = arr
.iter()
.map(|v| self.process_settings(v, encrypt, mask_encrypted))
.collect();
Ok(JsonValue::Array(new_arr?))
}
_ => Ok(value.clone()),
}
}

fn process_value(
&self,
value: &JsonValue,
encrypt: bool,
mask_encrypted: bool,
) -> Result<JsonValue, ApplicationError> {
if encrypt {
self.handle_encryption(value)
} else {
self.handle_decryption(value, mask_encrypted)
}
}

fn process_object(
&self,
obj: &Map<String, JsonValue>,
Expand Down Expand Up @@ -308,33 +323,22 @@ impl UserProfileDbHandler {
Ok(JsonValue::Array(new_arr?))
}

fn process_value(
&self,
value: &JsonValue,
encrypt: bool,
mask_encrypted: bool,
) -> Result<JsonValue, ApplicationError> {
if encrypt {
self.handle_encryption(value)
} else {
self.handle_decryption(value, mask_encrypted)
}
}

fn handle_encryption(
&self,
value: &JsonValue,
) -> Result<JsonValue, ApplicationError> {
if Self::is_marked_for_encryption(value) {
if let Some(JsonValue::String(content)) = value.get("content") {
self.encrypt_value(content)
if self.encryption_handler.is_some() {
self.encrypt_value(content)
} else {
Ok(value.clone()) // Keep as is if no encryption handler
}
} else {
Err(ApplicationError::InvalidInput(
"Invalid secure string format".to_string(),
))
}
} else if !Self::is_encrypted_value(value) {
Ok(value.clone())
} else {
Ok(value.clone())
}
Expand All @@ -346,23 +350,72 @@ impl UserProfileDbHandler {
mask_encrypted: bool,
) -> Result<JsonValue, ApplicationError> {
if Self::is_encrypted_value(value) {
if mask_encrypted {
Ok(JsonValue::String("*****".to_string()))
if self.encryption_handler.is_some() {
if mask_encrypted {
Ok(JsonValue::String("*****".to_string()))
} else {
self.decrypt_value(value)
}
} else {
self.decrypt_value(value)
Ok(JsonValue::String("*****".to_string())) // Always mask if no encryption handler
}
} else {
Ok(value.clone())
}
}

pub async fn create_or_update(
&self,
profile_name: &str,
new_settings: &JsonValue,
) -> Result<(), ApplicationError> {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
self.perform_create_or_update(tx, profile_name, new_settings)
let (existing_options, existing_hash): (
Option<String>,
Option<String>,
) = tx
.query_row(
"SELECT options, ssh_key_hash FROM user_profiles WHERE \
name = ?",
params![profile_name],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.optional()
.map_err(DatabaseOperationError::SqliteError)?
.unwrap_or((None, None));

let current_hash = self.calculate_ssh_key_hash()?;

// Only verify the SSH key hash if both existing and current hashes are present
if let (Some(existing), Some(current)) =
(existing_hash.as_deref(), current_hash.as_deref())
{
if existing != current {
return Err(DatabaseOperationError::ApplicationError(
ApplicationError::InvalidInput(
"SSH key mismatch".to_string(),
),
));
}
}

let merged_settings =
if let Some(existing_options) = existing_options {
self.merge_settings(Some(existing_options), new_settings)?
} else {
new_settings.clone()
};

let processed_settings =
self.process_settings(&merged_settings, true, false)?;
self.save_profile_settings(
tx,
profile_name,
&processed_settings,
current_hash.as_deref(),
)?;
Ok(())
})
.map_err(|e| match e {
DatabaseOperationError::SqliteError(sqlite_err) => {
Expand All @@ -372,35 +425,6 @@ impl UserProfileDbHandler {
})
}

fn perform_create_or_update(
&self,
tx: &Transaction,
profile_name: &str,
new_settings: &JsonValue,
) -> Result<(), DatabaseOperationError> {
let current_data = self.fetch_current_profile_data(tx, profile_name)?;
let merged_settings =
self.merge_settings(current_data, new_settings)?;
let processed_settings =
self.process_settings(&merged_settings, true, false)?;
self.save_profile_settings(tx, profile_name, &processed_settings)?;
Ok(())
}

fn fetch_current_profile_data(
&self,
tx: &Transaction,
profile_name: &str,
) -> Result<Option<String>, DatabaseOperationError> {
tx.query_row(
"SELECT options FROM user_profiles WHERE name = ?",
params![profile_name],
|row| row.get(0),
)
.optional()
.map_err(|e| DatabaseOperationError::SqliteError(e))
}

fn merge_settings(
&self,
current_data: Option<String>,
Expand Down Expand Up @@ -486,6 +510,7 @@ impl UserProfileDbHandler {
tx: &Transaction,
profile_name: &str,
settings: &JsonValue,
ssh_key_hash: Option<&str>,
) -> Result<(), DatabaseOperationError> {
let json_string = serde_json::to_string(settings).map_err(|e| {
DatabaseOperationError::ApplicationError(
Expand All @@ -497,12 +522,49 @@ impl UserProfileDbHandler {
})?;

tx.execute(
"INSERT OR REPLACE INTO user_profiles (name, options) VALUES (?, \
?)",
params![profile_name, json_string],
"INSERT OR REPLACE INTO user_profiles (name, options, \
ssh_key_hash) VALUES (?, ?, ?)",
params![profile_name, json_string, ssh_key_hash],
)
.map_err(DatabaseOperationError::SqliteError)?;

Ok(())
}

fn calculate_ssh_key_hash(
&self,
) -> Result<Option<String>, ApplicationError> {
if let Some(ref encryption_handler) = self.encryption_handler {
let ssh_private_key = encryption_handler
.get_ssh_private_key()
.map_err(|e| ApplicationError::EncryptionError(e.into()))?;
let mut hasher = Sha256::new();
hasher.update(ssh_private_key);
let result = hasher.finalize();
Ok(Some(general_purpose::STANDARD.encode(result)))
} else {
Ok(None)
}
}

fn verify_ssh_key_hash(
&self,
stored_hash: Option<&str>,
) -> Result<(), ApplicationError> {
match (self.calculate_ssh_key_hash()?, stored_hash) {
(Some(current_hash), Some(stored_hash))
if current_hash != stored_hash =>
{
Err(ApplicationError::InvalidInput(
"SSH key hash mismatch".to_string(),
))
}
(Some(_), None) | (None, Some(_)) => {
Err(ApplicationError::InvalidInput(
"Encryption status mismatch".to_string(),
))
}
_ => Ok(()),
}
}
}

0 comments on commit 2ae3ae2

Please sign in to comment.