diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs index f152ccb8..7295f5f5 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs @@ -11,6 +11,7 @@ use super::chat::db::{ ConversationDatabase, EncryptionHandler, EncryptionMode, MaskMode, UserProfileDbHandler, }; +use super::server::{ModelServer, ServerTrait, SUPPORTED_MODEL_ENDPOINTS}; use crate::external as lumni; pub fn parse_cli_arguments(spec: ApplicationSpec) -> Command { diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/mod.rs index 87bd9fc2..5b42cda8 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/mod.rs @@ -3,6 +3,6 @@ pub mod profile; pub mod profile_helper; use super::{ - ConversationDatabase, EncryptionHandler, EncryptionMode, MaskMode, - UserProfileDbHandler, + ConversationDatabase, EncryptionHandler, MaskMode, ModelServer, + ServerTrait, UserProfileDbHandler, SUPPORTED_MODEL_ENDPOINTS, }; diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs index 997c5ad5..51107ed9 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs @@ -4,8 +4,8 @@ use clap::{Arg, ArgAction, ArgMatches, Command}; use lumni::api::error::ApplicationError; use serde_json::{json, Map, Value as JsonValue}; -use super::profile_helper::interactive_profile_creation; -use super::{EncryptionMode, MaskMode, UserProfileDbHandler}; +use super::profile_helper::interactive_profile_edit; +use super::{MaskMode, UserProfileDbHandler}; use crate::external as lumni; pub fn create_profile_subcommand() -> Command { @@ -19,7 +19,7 @@ pub fn create_profile_subcommand() -> Command { .subcommand(create_rm_subcommand()) .subcommand(create_set_default_subcommand()) .subcommand(create_show_default_subcommand()) - .subcommand(create_add_profile_subcommand()) + .subcommand(create_edit_subcommand()) .subcommand(create_key_subcommand()) .subcommand(create_export_subcommand()) } @@ -98,8 +98,10 @@ fn create_show_default_subcommand() -> Command { ) } -fn create_add_profile_subcommand() -> Command { - Command::new("add").about("Add a new profile with guided setup") +fn create_edit_subcommand() -> Command { + Command::new("edit") + .about("Add a new profile or edit an existing one with guided setup") + .arg(Arg::new("name").help("Name of the profile to edit (optional)")) } fn create_key_subcommand() -> Command { @@ -339,8 +341,9 @@ pub async fn handle_profile_subcommand( } } - Some(("add", _)) => { - interactive_profile_creation(db_handler).await?; + Some(("edit", edit_matches)) => { + let profile_name = edit_matches.get_one::("name").cloned(); + interactive_profile_edit(db_handler, profile_name).await?; } Some(("key", key_matches)) => match key_matches.subcommand() { diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs index 51f62b79..3fe50290 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; +use std::env; use std::io::{self, Write}; use std::path::PathBuf; use std::sync::Arc; @@ -6,43 +8,91 @@ use dirs::home_dir; use lumni::api::error::ApplicationError; use serde_json::{json, Map, Value as JsonValue}; -use super::{EncryptionHandler, UserProfileDbHandler}; +use super::{ + EncryptionHandler, ModelServer, ServerTrait, UserProfileDbHandler, + SUPPORTED_MODEL_ENDPOINTS, +}; use crate::external as lumni; -pub async fn interactive_profile_creation( +pub async fn interactive_profile_edit( db_handler: &mut UserProfileDbHandler, + profile_name_to_update: Option, ) -> Result<(), ApplicationError> { println!("Welcome to the profile creation wizard!"); - let profile_name = get_profile_name()?; + let profile_name = match &profile_name_to_update { + Some(name) => name.clone(), + None => get_profile_name()?, + }; db_handler.set_profile_name(profile_name.clone()); - let settings = collect_profile_settings()?; + let profile_type = select_profile_type()?; + let mut settings = get_predefined_settings(&profile_type)?; + + if profile_type == "Custom" { + collect_custom_settings(&mut settings)?; + } else { + collect_profile_settings(&mut settings, &profile_type)?; + } if ask_for_custom_ssh_key()? { setup_custom_encryption(db_handler).await?; } + db_handler .create_or_update(&profile_name, &settings) .await?; - println!("Profile '{}' created successfully!", profile_name); + + println!( + "Profile '{}' {} successfully!", + profile_name, + if profile_name_to_update.is_some() { + "updated" + } else { + "created" + } + ); Ok(()) } -fn get_profile_name() -> Result { - print!("Enter a name for the new profile: "); - io::stdout().flush()?; - let mut profile_name = String::new(); - io::stdin().read_line(&mut profile_name)?; - Ok(profile_name.trim().to_string()) -} +fn select_profile_type() -> Result { + println!("Select a profile type:"); + println!("0. Custom (default)"); + for (index, server_type) in SUPPORTED_MODEL_ENDPOINTS.iter().enumerate() { + println!("{}. {}", index + 1, server_type); + } -fn collect_profile_settings() -> Result { - let mut settings = JsonValue::Object(Map::new()); + loop { + print!( + "Enter your choice (0-{}, or press Enter for Custom): ", + SUPPORTED_MODEL_ENDPOINTS.len() + ); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + if choice.is_empty() { + return Ok("Custom".to_string()); + } + if let Ok(index) = choice.parse::() { + if index == 0 { + return Ok("Custom".to_string()); + } else if index <= SUPPORTED_MODEL_ENDPOINTS.len() { + return Ok(SUPPORTED_MODEL_ENDPOINTS[index - 1].to_string()); + } + } + println!("Invalid choice. Please try again."); + } +} + +fn collect_custom_settings( + settings: &mut JsonValue, +) -> Result<(), ApplicationError> { loop { - print!("Enter a key (or press Enter to finish): "); + print!("Enter a custom key (or press Enter to finish): "); io::stdout().flush()?; let mut key = String::new(); io::stdin().read_line(&mut key)?; @@ -55,19 +105,217 @@ fn collect_profile_settings() -> Result { let value = get_value_for_key(key)?; let encrypt = should_encrypt_value()?; - if encrypt { - settings[key] = json!({ - "content": value, - "encryption_key": "", - }); - } else { - settings[key] = JsonValue::String(value); + if let JsonValue::Object(ref mut map) = settings { + if encrypt { + map.insert( + key.to_string(), + json!({ + "content": value, + "encryption_key": "", + }), + ); + } else { + map.insert(key.to_string(), JsonValue::String(value)); + } + } + } + + Ok(()) +} + +fn get_predefined_settings( + profile_type: &str, +) -> Result { + let mut settings = JsonValue::Object(Map::new()); + + if ask_yes_no("Do you want to set a project directory?", true)? { + let dir = get_project_directory()?; + settings["PROJECT_DIRECTORY"] = JsonValue::String(dir); + } + + if profile_type != "Custom" { + let model_server = ModelServer::from_str(profile_type)?; + let server_settings = model_server.get_profile_settings(); + + if let JsonValue::Object(map) = server_settings { + for (key, value) in map { + settings[key] = value; + } } + } else { + println!( + "Custom profile selected. You can add custom key-value pairs." + ); } Ok(settings) } +fn collect_profile_settings( + settings: &mut JsonValue, + profile_type: &str, +) -> Result<(), ApplicationError> { + if let JsonValue::Object(ref mut map) = settings { + let mut updates = Vec::new(); + let mut removals = Vec::new(); + + for (key, value) in map.iter() { + if *value == JsonValue::Null { + if let Some(new_value) = get_optional_value(key)? { + updates.push((key.clone(), JsonValue::String(new_value))); + } else { + removals.push(key.clone()); + } + } else if let JsonValue::Object(obj) = value { + if obj.contains_key("content") + && obj.contains_key("encryption_key") + { + if let Some(new_value) = get_secure_value(key)? { + updates.push(( + key.clone(), + json!({ + "content": new_value, + "encryption_key": "", + }), + )); + } else { + removals.push(key.clone()); + } + } + } + } + + // Apply updates + for (key, value) in updates { + map.insert(key, value); + } + + // Apply removals + for key in removals { + map.remove(&key); + } + } + + if profile_type == "Custom" { + loop { + print!("Enter a custom key (or press Enter to finish): "); + io::stdout().flush()?; + let mut key = String::new(); + io::stdin().read_line(&mut key)?; + let key = key.trim(); + + if key.is_empty() { + break; + } + + let value = get_value_for_key(key)?; + let encrypt = should_encrypt_value()?; + + if let JsonValue::Object(ref mut map) = settings { + if encrypt { + map.insert( + key.to_string(), + json!({ + "content": value, + "encryption_key": "", + }), + ); + } else { + map.insert(key.to_string(), JsonValue::String(value)); + } + } + } + } + + Ok(()) +} + +fn ask_yes_no(question: &str, default: bool) -> Result { + let default_option = if default { "Y/n" } else { "y/N" }; + print!("{} [{}]: ", question, default_option); + io::stdout().flush()?; + let mut response = String::new(); + io::stdin().read_line(&mut response)?; + let response = response.trim().to_lowercase(); + Ok(match response.as_str() { + "y" | "yes" => true, + "n" | "no" => false, + "" => default, + _ => ask_yes_no(question, default)?, + }) +} + +fn path_to_tilde_string(path: &PathBuf) -> String { + if let Ok(home_dir) = env::var("HOME") { + let home_path = PathBuf::from(home_dir); + if let Ok(relative_path) = path.strip_prefix(&home_path) { + return format!("~/{}", relative_path.display()); + } + } + path.to_string_lossy().to_string() +} + +fn get_project_directory() -> Result { + let current_dir = env::current_dir()?; + let tilde_current_dir = path_to_tilde_string(¤t_dir); + + println!("Current directory:"); + println!(" {}", tilde_current_dir); + + print!("Enter project directory (or press Enter for current directory): "); + io::stdout().flush()?; + let mut dir = String::new(); + io::stdin().read_line(&mut dir)?; + let dir = dir.trim(); + + let path = if dir.is_empty() { + current_dir.clone() + } else { + PathBuf::from(dir) + }; + + // Convert to absolute path + let absolute_path = if path.is_absolute() { + path + } else { + current_dir.join(path) + }; + + Ok(path_to_tilde_string(&absolute_path)) +} + +fn get_optional_value(key: &str) -> Result, ApplicationError> { + print!( + "Enter the value for '{}' (optional, press Enter to skip): ", + key + ); + io::stdout().flush()?; + let mut value = String::new(); + io::stdin().read_line(&mut value)?; + let value = value.trim().to_string(); + Ok(if value.is_empty() { None } else { Some(value) }) +} + +fn get_secure_value(key: &str) -> Result, ApplicationError> { + print!( + "Enter the secure value for '{}' (optional, press Enter to skip): ", + key + ); + io::stdout().flush()?; + let mut value = String::new(); + io::stdin().read_line(&mut value)?; + let value = value.trim().to_string(); + Ok(if value.is_empty() { None } else { Some(value) }) +} + +fn get_profile_name() -> Result { + print!("Enter a name for the new profile: "); + io::stdout().flush()?; + let mut profile_name = String::new(); + io::stdin().read_line(&mut profile_name)?; + Ok(profile_name.trim().to_string()) +} + fn get_value_for_key(key: &str) -> Result { print!("Enter the value for '{}': ", key); io::stdout().flush()?; diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs index 2faa7572..b7db2469 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs @@ -12,7 +12,7 @@ use eventstream::EventStreamMessage; use lumni::api::error::ApplicationError; use lumni::{AWSCredentials, AWSRequestBuilder, HttpClient}; use request::*; -use serde_json::Value; +use serde_json::{json, Value as JsonValue}; use sha2::{Digest, Sha256}; use tokio::sync::{mpsc, oneshot}; use url::Url; @@ -121,6 +121,14 @@ impl ServerTrait for Bedrock { &self.spec } + fn get_profile_settings(&self) -> JsonValue { + json!({ + "MODEL_SERVER": "bedrock", + "AWS_PROFILE": null, + "AWS_REGION": null + }) + } + async fn initialize_with_model( &mut self, _reader: &ConversationDbHandler, @@ -307,7 +315,7 @@ fn process_event_payload( (None, true, tokens_predicted, tokens_in_prompt) } -fn parse_payload(payload: Option) -> Option { +fn parse_payload(payload: Option) -> Option { payload.and_then(|p| match serde_json::from_slice(&p) { Ok(json) => Some(json), Err(_) => { diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs index c645d6c3..0a85566e 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs @@ -21,6 +21,7 @@ pub use ollama::Ollama; pub use openai::OpenAI; pub use response::{CompletionResponse, CompletionStats}; use send::{http_get_with_response, http_post, http_post_with_response}; +use serde_json::Value as JsonValue; pub use spec::ServerSpecTrait; use tokio::sync::{mpsc, oneshot}; @@ -81,6 +82,15 @@ impl ServerTrait for ModelServer { } } + fn get_profile_settings(&self) -> JsonValue { + match self { + ModelServer::Llama(llama) => llama.get_profile_settings(), + ModelServer::Ollama(ollama) => ollama.get_profile_settings(), + ModelServer::Bedrock(bedrock) => bedrock.get_profile_settings(), + ModelServer::OpenAI(openai) => openai.get_profile_settings(), + } + } + async fn initialize_with_model( &mut self, reader: &ConversationDbHandler, @@ -170,6 +180,10 @@ impl ServerTrait for ModelServer { pub trait ServerTrait: Send + Sync { fn get_spec(&self) -> &dyn ServerSpecTrait; + fn get_profile_settings(&self) -> JsonValue { + JsonValue::Object(serde_json::Map::new()) + } + async fn initialize_with_model( &mut self, reader: &ConversationDbHandler, diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs index f46dfe7c..236e8aee 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs @@ -15,6 +15,7 @@ use lumni::api::error::ApplicationError; use lumni::HttpClient; use request::{OpenAIChatMessage, OpenAIRequestPayload, StreamOptions}; use response::StreamParser; +use serde_json::{json, Value as JsonValue}; use tokio::sync::{mpsc, oneshot}; use url::Url; @@ -91,6 +92,16 @@ impl ServerTrait for OpenAI { &self.spec } + fn get_profile_settings(&self) -> JsonValue { + json!({ + "MODEL_SERVER": "openai", + "OPENAI_API_KEY": { + "content": "", + "encryption_key": "", + } + }) + } + async fn initialize_with_model( &mut self, _reader: &ConversationDbHandler,