From 114244a3029cebf21560e6fc755aa483ea05a4a2 Mon Sep 17 00:00:00 2001 From: Anthony Potappel Date: Wed, 14 Aug 2024 07:53:25 +0200 Subject: [PATCH] allow updating profile values through helper --- .../src/cli/subcommands/profile_helper.rs | 335 +++++++++++------- .../llm/prompt/src/server/bedrock/mod.rs | 2 +- .../llm/prompt/src/server/llama/mod.rs | 7 + .../apps/builtin/llm/prompt/src/server/mod.rs | 5 +- .../llm/prompt/src/server/ollama/mod.rs | 7 + .../llm/prompt/src/server/openai/mod.rs | 2 +- 6 files changed, 218 insertions(+), 140 deletions(-) diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs index 3c62c97e..383347e5 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs @@ -8,7 +8,7 @@ use lumni::api::error::ApplicationError; use serde_json::{json, Map, Value as JsonValue}; use super::{ - EncryptionHandler, ModelServer, ModelSpec, ServerTrait, + EncryptionHandler, MaskMode, ModelServer, ModelSpec, ServerTrait, UserProfileDbHandler, SUPPORTED_MODEL_ENDPOINTS, }; use crate::external as lumni; @@ -24,7 +24,7 @@ pub async fn interactive_profile_edit( db_handler: &mut UserProfileDbHandler, profile_name_to_update: Option, ) -> Result<(), ApplicationError> { - println!("Welcome to the profile creation wizard!"); + println!("Welcome to the profile creation/editing wizard!"); let profile_name = match &profile_name_to_update { Some(name) => name.clone(), @@ -32,14 +32,40 @@ pub async fn interactive_profile_edit( }; db_handler.set_profile_name(profile_name.clone()); - let profile_type = select_profile_type()?; - let mut settings = JsonValue::Object(Map::new()); + let (mut settings, is_updating) = match profile_name_to_update { + Some(name) => match db_handler + .get_profile_settings(&name, MaskMode::Unmask) + .await + { + Ok(existing_settings) => (existing_settings, true), + Err(ApplicationError::DatabaseError(_)) => { + println!( + "Profile '{}' does not exist. Creating a new profile.", + name + ); + (JsonValue::Object(Map::new()), false) + } + Err(e) => return Err(e), + }, + None => (JsonValue::Object(Map::new()), false), + }; + + let profile_type = if is_updating { + settings["__PROFILE_TYPE"] + .as_str() + .unwrap_or("Custom") + .to_string() + } else { + let selected_type = select_profile_type()?; + settings["__PROFILE_TYPE"] = JsonValue::String(selected_type.clone()); + selected_type + }; - if profile_type != "Custom" { + if !is_updating && profile_type != "Custom" { let model_server = ModelServer::from_str(&profile_type)?; if let Some(selected_model) = select_model(&model_server).await? { - settings["MODEL_IDENTIFIER"] = + settings["__MODEL_IDENTIFIER"] = JsonValue::String(selected_model.identifier.0); } else { println!("No model selected. Skipping model selection."); @@ -49,23 +75,31 @@ pub async fn interactive_profile_edit( let server_settings = model_server.get_profile_settings(); if let JsonValue::Object(map) = server_settings { for (key, value) in map { - settings[key] = value; + if settings.get(&key).is_none() { + settings[key] = value; + } } } } - if ask_yes_no("Do you want to set a project directory?", true)? { - let dir = get_project_directory()?; - settings["PROJECT_DIRECTORY"] = JsonValue::String(dir); + if !is_updating || settings.get("__PROJECT_DIRECTORY").is_none() { + if ask_yes_no("Do you want to set a project directory?", true)? { + let dir = get_project_directory()?; + settings["__PROJECT_DIRECTORY"] = JsonValue::String(dir); + } } - if profile_type == "Custom" { + collect_profile_settings(&mut settings, is_updating)?; + + // Allow adding custom keys, but default to No if updating or if a specific profile type is chosen + if ask_yes_no( + "Do you want to add custom keys?", + !is_updating && profile_type == "Custom", + )? { collect_custom_settings(&mut settings)?; - } else { - collect_profile_settings(&mut settings, &profile_type)?; } - if ask_for_custom_ssh_key()? { + if !is_updating && ask_for_custom_ssh_key()? { setup_custom_encryption(db_handler).await?; } @@ -76,16 +110,162 @@ pub async fn interactive_profile_edit( println!( "Profile '{}' {} successfully!", profile_name, - if profile_name_to_update.is_some() { - "updated" - } else { - "created" - } + if is_updating { "updated" } else { "created" } ); Ok(()) } +fn collect_profile_settings( + settings: &mut JsonValue, + is_updating: bool, +) -> Result<(), ApplicationError> { + if let JsonValue::Object(ref mut map) = settings { + for (key, value) in map.clone().iter() { + if key.starts_with("__") { + // Skip protected values when editing + if is_updating { + continue; + } + // For new profiles, just display the value of protected settings + println!("{}: {}", key, value); + continue; + } + + let current_value = parse_value(value); + + let prompt = if is_updating { + format!( + "Current value for '{}' is '{}'. Enter new value (or \ + press Enter to keep current): ", + key, current_value + ) + } else { + format!("Enter value for '{}': ", key) + }; + + print!("{}", prompt); + io::stdout().flush()?; + let mut new_value = String::new(); + io::stdin().read_line(&mut new_value)?; + let new_value = new_value.trim(); + + if !new_value.is_empty() { + match value { + JsonValue::Object(obj) + if obj.contains_key("content") + && obj.contains_key("encryption_key") => + { + // This is a predefined encrypted value, maintain its structure + map.insert( + key.to_string(), + json!({ + "content": new_value, + "encryption_key": "", + }), + ); + } + JsonValue::Null => { + // This is a predefined non-encrypted value + map.insert( + key.to_string(), + JsonValue::String(new_value.to_string()), + ); + } + _ => { + // For custom keys or updating existing values + if !is_updating && should_encrypt_value()? { + map.insert( + key.to_string(), + json!({ + "content": new_value, + "encryption_key": "", + }), + ); + } else { + map.insert( + key.to_string(), + JsonValue::String(new_value.to_string()), + ); + } + } + } + } + } + } + + Ok(()) +} + +fn collect_custom_settings( + settings: &mut JsonValue, +) -> Result<(), ApplicationError> { + loop { + print!("Enter a new custom key (or press Enter to finish): "); + io::stdout().flush()?; + let mut key = String::new(); + io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + break; + } + + if key.starts_with("__") { + println!( + "Keys starting with '__' are reserved. Please choose a \ + different key." + ); + continue; + } + + if let JsonValue::Object(ref map) = settings { + if map.get(key).is_some() { + println!( + "Key '{}' already exists. Please choose a different key.", + key + ); + continue; + } + } + + let value = get_value_for_key(key)?; + let encrypt = should_encrypt_value()?; + + if let JsonValue::Object(ref mut map) = settings { + if encrypt { + map.insert( + key.to_string(), + json!({ + "content": value, + "encryption_key": "", + }), + ); + } else { + map.insert(key.to_string(), JsonValue::String(value)); + } + } + } + + Ok(()) +} + +fn parse_value(value: &JsonValue) -> String { + match value { + JsonValue::Object(obj) + if obj.contains_key("content") + && obj.contains_key("encryption_key") => + { + obj.get("content") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string() + } + JsonValue::Null => "[Not set]".to_string(), + _ => value.as_str().unwrap_or_default().to_string(), + } +} + async fn select_model( model_server: &ModelServer, ) -> Result, ApplicationError> { @@ -174,7 +354,8 @@ fn select_model_from_list( fn handle_not_ready() -> Result { loop { print!( - "Press Enter to retry, 'q' to quit, or 's' to skip model selection: " + "Press Enter to retry, 'q' to quit, or 's' to skip model \ + selection: " ); io::stdout().flush()?; let mut input = String::new(); @@ -221,120 +402,6 @@ fn select_profile_type() -> Result { } } -fn collect_custom_settings( - settings: &mut JsonValue, -) -> Result<(), ApplicationError> { - loop { - print!("Enter a custom key (or press Enter to finish): "); - io::stdout().flush()?; - let mut key = String::new(); - io::stdin().read_line(&mut key)?; - let key = key.trim(); - - if key.is_empty() { - break; - } - - let value = get_value_for_key(key)?; - let encrypt = should_encrypt_value()?; - - if let JsonValue::Object(ref mut map) = settings { - if encrypt { - map.insert( - key.to_string(), - json!({ - "content": value, - "encryption_key": "", - }), - ); - } else { - map.insert(key.to_string(), JsonValue::String(value)); - } - } - } - - Ok(()) -} - -fn collect_profile_settings( - settings: &mut JsonValue, - profile_type: &str, -) -> Result<(), ApplicationError> { - if let JsonValue::Object(ref mut map) = settings { - let mut updates = Vec::new(); - let mut removals = Vec::new(); - - for (key, value) in map.iter() { - if *value == JsonValue::Null { - if let Some(new_value) = get_optional_value(key)? { - updates.push((key.clone(), JsonValue::String(new_value))); - } else { - removals.push(key.clone()); - } - } else if let JsonValue::Object(obj) = value { - if obj.contains_key("content") - && obj.contains_key("encryption_key") - { - if let Some(new_value) = get_secure_value(key)? { - updates.push(( - key.clone(), - json!({ - "content": new_value, - "encryption_key": "", - }), - )); - } else { - removals.push(key.clone()); - } - } - } - } - - // Apply updates - for (key, value) in updates { - map.insert(key, value); - } - - // Apply removals - for key in removals { - map.remove(&key); - } - } - - if profile_type == "Custom" { - loop { - print!("Enter a custom key (or press Enter to finish): "); - io::stdout().flush()?; - let mut key = String::new(); - io::stdin().read_line(&mut key)?; - let key = key.trim(); - - if key.is_empty() { - break; - } - - let value = get_value_for_key(key)?; - let encrypt = should_encrypt_value()?; - - if let JsonValue::Object(ref mut map) = settings { - if encrypt { - map.insert( - key.to_string(), - json!({ - "content": value, - "encryption_key": "", - }), - ); - } else { - map.insert(key.to_string(), JsonValue::String(value)); - } - } - } - } - - Ok(()) -} - fn ask_yes_no(question: &str, default: bool) -> Result { let default_option = if default { "Y/n" } else { "y/N" }; print!("{} [{}]: ", question, default_option); diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs index b7db2469..a31d8eb6 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs @@ -123,7 +123,7 @@ impl ServerTrait for Bedrock { fn get_profile_settings(&self) -> JsonValue { json!({ - "MODEL_SERVER": "bedrock", + "__MODEL_SERVER": "bedrock", "AWS_PROFILE": null, "AWS_REGION": null }) diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs index 5e23a85f..a856d5ed 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs @@ -7,6 +7,7 @@ use bytes::Bytes; use formatters::{ModelFormatter, ModelFormatterTrait}; use lumni::api::error::ApplicationError; use serde::{Deserialize, Serialize}; +use serde_json::{json, Value as JsonValue}; use tokio::sync::{mpsc, oneshot}; use url::Url; @@ -137,6 +138,12 @@ impl ServerTrait for Llama { &self.spec } + fn get_profile_settings(&self) -> JsonValue { + json!({ + "__MODEL_SERVER": "llama", + }) + } + fn process_response( &mut self, response: Bytes, diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs index 0a85566e..9b792675 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs @@ -179,10 +179,7 @@ impl ServerTrait for ModelServer { #[async_trait] pub trait ServerTrait: Send + Sync { fn get_spec(&self) -> &dyn ServerSpecTrait; - - fn get_profile_settings(&self) -> JsonValue { - JsonValue::Object(serde_json::Map::new()) - } + fn get_profile_settings(&self) -> JsonValue; async fn initialize_with_model( &mut self, diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs index fdd06b81..443314c5 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/ollama/mod.rs @@ -3,6 +3,7 @@ use std::error::Error; use async_trait::async_trait; use bytes::Bytes; use serde::{Deserialize, Serialize}; +use serde_json::{json, Value as JsonValue}; use tokio::sync::{mpsc, oneshot}; use url::Url; @@ -69,6 +70,12 @@ impl ServerTrait for Ollama { &self.spec } + fn get_profile_settings(&self) -> JsonValue { + json!({ + "__MODEL_SERVER": "ollama", + }) + } + async fn initialize_with_model( &mut self, handler: &ConversationDbHandler, diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs index 236e8aee..d4a93dcc 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs @@ -94,7 +94,7 @@ impl ServerTrait for OpenAI { fn get_profile_settings(&self) -> JsonValue { json!({ - "MODEL_SERVER": "openai", + "__MODEL_SERVER": "openai", "OPENAI_API_KEY": { "content": "", "encryption_key": "",