From e4a63f09a51339099fbbf549ce59a9b952f06ad5 Mon Sep 17 00:00:00 2001 From: Anthony Potappel Date: Sun, 18 Aug 2024 17:52:31 +0200 Subject: [PATCH] get server from profile at start --- lumni/src/apps/builtin/llm/prompt/src/app.rs | 74 +++++++++---------- .../llm/prompt/src/chat/conversation/mod.rs | 4 +- .../builtin/llm/prompt/src/chat/db/mod.rs | 2 +- .../prompt/src/chat/db/user_profile/mod.rs | 38 ++++++++++ .../apps/builtin/llm/prompt/src/chat/mod.rs | 4 +- .../apps/builtin/llm/prompt/src/cli/mod.rs | 6 -- .../llm/prompt/src/cli/subcommands/profile.rs | 4 +- .../builtin/llm/prompt/src/server/backend.rs | 13 ++++ .../apps/builtin/llm/prompt/src/server/mod.rs | 2 + .../src/tui/modals/profiles/profile_list.rs | 4 +- 10 files changed, 101 insertions(+), 50 deletions(-) create mode 100644 lumni/src/apps/builtin/llm/prompt/src/server/backend.rs diff --git a/lumni/src/apps/builtin/llm/prompt/src/app.rs b/lumni/src/apps/builtin/llm/prompt/src/app.rs index 32a0ece..65bf50b 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/app.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/app.rs @@ -1,5 +1,4 @@ use std::io::{self, Write}; -use std::path::PathBuf; use std::sync::Arc; use clap::{ArgMatches, Command}; @@ -40,20 +39,37 @@ async fn create_prompt_instruction( let assistant = matches.and_then(|m| m.get_one::("assistant").cloned()); let user_options = matches.and_then(|m| m.get_one::("options")); - let server_name = matches - .and_then(|m| m.get_one::("server")) - .map(|s| s.to_lowercase()) - .unwrap_or_else(|| "ollama".to_lowercase()); - - // create new (un-initialized) server from requested server name - let server = ModelServer::from_str(&server_name)?; - let default_model = match server.get_default_model().await { - Ok(model) => Some(model), - Err(e) => { - log::error!("Failed to get default model during startup: {}", e); - None + + let mut profile_handler = db_conn.get_profile_handler(None); + + // Handle --profile option + if let Some(profile_name) = + matches.and_then(|m| m.get_one::("profile")) + { + profile_handler.set_profile_name(profile_name.to_string()); + } else { + // Use default profile if set + if let Some(default_profile) = + profile_handler.get_default_profile().await? + { + profile_handler.set_profile_name(default_profile); } - }; + } + + // Check if a profile is set + if profile_handler.get_profile_name().is_none() { + return Err(ApplicationError::InvalidInput( + "No profile set".to_string(), + )); + } + + // Get model_backend + let model_backend = + profile_handler.model_backend().await?.ok_or_else(|| { + ApplicationError::InvalidInput( + "Failed to get model backend".to_string(), + ) + })?; let assistant_manager = AssistantManager::new(assistant, instruction.clone())?; @@ -62,8 +78,8 @@ async fn create_prompt_instruction( let mut completion_options = assistant_manager.get_completion_options().clone(); - let model_server = ModelServerName::from_str(&server_name); - completion_options.model_server = Some(model_server.clone()); + let model_server_name = model_backend.server_name(); + completion_options.model_server = Some(model_server_name.clone()); // overwrite default options with options set by the user if let Some(s) = user_options { @@ -72,8 +88,8 @@ async fn create_prompt_instruction( } let new_conversation = NewConversation { - server: model_server, - model: default_model, + server: model_server_name, + model: model_backend.model.clone(), options: Some(serde_json::to_value(completion_options)?), system_prompt: instruction, initial_messages: Some(initial_messages), @@ -170,32 +186,16 @@ pub async fn run_cli( let db_conn = Arc::new(ConversationDatabase::new(&sqlite_file, None)?); if let Some(ref matches) = matches { - let mut profile_handler = db_conn.get_profile_handler(None); if let Some(db_matches) = matches.subcommand_matches("db") { return handle_db_subcommand(db_matches, &db_conn).await; } if let Some(profile_matches) = matches.subcommand_matches("profile") { - return handle_profile_subcommand( - profile_matches, - &mut profile_handler, - ) - .await; - } - - // Handle --profile option - if let Some(profile_name) = matches.get_one::("profile") { - profile_handler.set_profile_name(profile_name.to_string()); - } else { - // Use default profile if set - if let Some(default_profile) = - profile_handler.get_default_profile().await? - { - profile_handler.set_profile_name(default_profile); - } + let profile_handler = db_conn.get_profile_handler(None); + return handle_profile_subcommand(profile_matches, profile_handler) + .await; } } - // TODO: Add support for --profile option in the prompt command let prompt_instruction = create_prompt_instruction(matches.as_ref(), &db_conn).await?; diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs index 51f063e..7b687fa 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs @@ -8,8 +8,8 @@ pub use prepare::NewConversation; pub use super::db; use super::{ - ChatCompletionOptions, ChatMessage, ColorScheme, PromptError, PromptRole, - SimpleString, TextLine, TextSegment, + ChatCompletionOptions, ChatMessage, ColorScheme, ModelBackend, PromptError, + PromptRole, SimpleString, TextLine, TextSegment, }; #[derive(Debug, Clone, PartialEq, Copy)] diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs index 2319540..6c9741c 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs @@ -18,7 +18,7 @@ pub use store::ConversationDatabase; pub use user_profile::{MaskMode, UserProfileDbHandler}; pub use super::ConversationCache; -use super::PromptRole; +use super::{ModelBackend, ModelServer, PromptRole}; use crate::external as lumni; #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profile/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profile/mod.rs index 65d7876..b70f58f 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profile/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profile/mod.rs @@ -11,6 +11,7 @@ use tokio::sync::Mutex as TokioMutex; use super::connector::{DatabaseConnector, DatabaseOperationError}; use super::encryption::EncryptionHandler; +use super::{ModelBackend, ModelServer, ModelSpec}; use crate::external as lumni; #[derive(Debug, Clone)] @@ -49,6 +50,43 @@ impl UserProfileDbHandler { self.profile_name = Some(profile_name); } + pub fn get_profile_name(&self) -> Option<&str> { + self.profile_name.as_deref() + } + + pub async fn model_backend( + &mut self, + ) -> Result, ApplicationError> { + let profile_name = self.profile_name.clone(); + + if let Some(profile_name) = profile_name { + let settings = self + .get_profile_settings(&profile_name, MaskMode::Unmask) + .await?; + + let model_server = settings + .get("__MODEL_SERVER") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + ApplicationError::InvalidInput( + "__MODEL_SERVER not found in profile".to_string(), + ) + })?; + + let server = ModelServer::from_str(model_server)?; + + let model = settings + .get("__MODEL_IDENTIFIER") + .and_then(|v| v.as_str()) + .map(|identifier| ModelSpec::new_with_validation(identifier)) + .transpose()?; + + Ok(Some(ModelBackend { server, model })) + } else { + Ok(None) + } + } + pub fn set_encryption_handler( &mut self, encryption_handler: Arc, diff --git a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs index 3ab2043..8855c81 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs @@ -13,7 +13,9 @@ pub use session::{prompt_app, App, ChatEvent, ThreadedChatSession}; pub use super::defaults::*; pub use super::error::{PromptError, PromptNotReadyReason}; -use super::server::{CompletionResponse, ModelServer, ServerManager}; +use super::server::{ + CompletionResponse, ModelBackend, ModelServer, ServerManager, +}; use super::tui::{ draw_ui, AppUi, ColorScheme, ColorSchemeType, CommandLineAction, ConversationEvent, KeyEventHandler, ModalAction, ModalWindowType, diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs index 456c515..b8d992d 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs @@ -41,12 +41,6 @@ pub fn parse_cli_arguments(spec: ApplicationSpec) -> Command { .short('a') .help("Specify an assistant to use"), ) - .arg( - Arg::new("server") - .long("server") - .short('S') - .help("Server to use for processing the request"), - ) .arg(Arg::new("options").long("options").short('o').help( "Comma-separated list of model options e.g., \ temperature=1,max_tokens=100", diff --git a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs index c57554c..cb09f56 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile.rs @@ -191,7 +191,7 @@ fn create_truncate_subcommand() -> Command { pub async fn handle_profile_subcommand( profile_matches: &ArgMatches, - db_handler: &mut UserProfileDbHandler, + mut db_handler: UserProfileDbHandler, ) -> Result<(), ApplicationError> { match profile_matches.subcommand() { Some(("list", _)) => { @@ -365,7 +365,7 @@ pub async fn handle_profile_subcommand( let custom_ssh_key_path = edit_matches.get_one::("ssh-key-path").cloned(); interactive_profile_edit( - db_handler, + &mut db_handler, profile_name, custom_ssh_key_path, ) diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/backend.rs b/lumni/src/apps/builtin/llm/prompt/src/server/backend.rs new file mode 100644 index 0000000..4b8de2d --- /dev/null +++ b/lumni/src/apps/builtin/llm/prompt/src/server/backend.rs @@ -0,0 +1,13 @@ +use super::{ModelServer, ModelSpec, ServerManager}; +use crate::apps::builtin::llm::prompt::src::chat::db::ModelServerName; + +pub struct ModelBackend { + pub server: ModelServer, + pub model: Option, +} + +impl ModelBackend { + pub fn server_name(&self) -> ModelServerName { + self.server.server_name() + } +} diff --git a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs index 9b79267..b85d908 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/server/mod.rs @@ -2,6 +2,7 @@ #[macro_use] mod spec; +mod backend; mod bedrock; mod endpoints; mod llama; @@ -11,6 +12,7 @@ mod response; mod send; use async_trait::async_trait; +pub use backend::ModelBackend; pub use bedrock::Bedrock; use bytes::Bytes; pub use endpoints::Endpoints; diff --git a/lumni/src/apps/builtin/llm/prompt/src/tui/modals/profiles/profile_list.rs b/lumni/src/apps/builtin/llm/prompt/src/tui/modals/profiles/profile_list.rs index 290748c..a12d18f 100644 --- a/lumni/src/apps/builtin/llm/prompt/src/tui/modals/profiles/profile_list.rs +++ b/lumni/src/apps/builtin/llm/prompt/src/tui/modals/profiles/profile_list.rs @@ -95,7 +95,9 @@ impl ProfileList { } pub fn is_default_profile(&self, profile: &str) -> bool { - self.default_profile.as_ref().map_or(false, |default| default == profile) + self.default_profile + .as_ref() + .map_or(false, |default| default == profile) } pub fn get_profiles(&self) -> Vec {