Skip to content

Commit

Permalink
get server from profile at start
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 18, 2024
1 parent e4009ad commit e4a63f0
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 50 deletions.
74 changes: 37 additions & 37 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::io::{self, Write};
use std::path::PathBuf;
use std::sync::Arc;

use clap::{ArgMatches, Command};
Expand Down Expand Up @@ -40,20 +39,37 @@ async fn create_prompt_instruction(
let assistant =
matches.and_then(|m| m.get_one::<String>("assistant").cloned());
let user_options = matches.and_then(|m| m.get_one::<String>("options"));
let server_name = matches
.and_then(|m| m.get_one::<String>("server"))
.map(|s| s.to_lowercase())
.unwrap_or_else(|| "ollama".to_lowercase());

// create new (un-initialized) server from requested server name
let server = ModelServer::from_str(&server_name)?;
let default_model = match server.get_default_model().await {
Ok(model) => Some(model),
Err(e) => {
log::error!("Failed to get default model during startup: {}", e);
None

let mut profile_handler = db_conn.get_profile_handler(None);

// Handle --profile option
if let Some(profile_name) =
matches.and_then(|m| m.get_one::<String>("profile"))
{
profile_handler.set_profile_name(profile_name.to_string());
} else {
// Use default profile if set
if let Some(default_profile) =
profile_handler.get_default_profile().await?
{
profile_handler.set_profile_name(default_profile);
}
};
}

// Check if a profile is set
if profile_handler.get_profile_name().is_none() {
return Err(ApplicationError::InvalidInput(
"No profile set".to_string(),
));
}

// Get model_backend
let model_backend =
profile_handler.model_backend().await?.ok_or_else(|| {
ApplicationError::InvalidInput(
"Failed to get model backend".to_string(),
)
})?;

let assistant_manager =
AssistantManager::new(assistant, instruction.clone())?;
Expand All @@ -62,8 +78,8 @@ async fn create_prompt_instruction(
let mut completion_options =
assistant_manager.get_completion_options().clone();

let model_server = ModelServerName::from_str(&server_name);
completion_options.model_server = Some(model_server.clone());
let model_server_name = model_backend.server_name();
completion_options.model_server = Some(model_server_name.clone());

// overwrite default options with options set by the user
if let Some(s) = user_options {
Expand All @@ -72,8 +88,8 @@ async fn create_prompt_instruction(
}

let new_conversation = NewConversation {
server: model_server,
model: default_model,
server: model_server_name,
model: model_backend.model.clone(),
options: Some(serde_json::to_value(completion_options)?),
system_prompt: instruction,
initial_messages: Some(initial_messages),
Expand Down Expand Up @@ -170,32 +186,16 @@ pub async fn run_cli(
let db_conn = Arc::new(ConversationDatabase::new(&sqlite_file, None)?);

if let Some(ref matches) = matches {
let mut profile_handler = db_conn.get_profile_handler(None);
if let Some(db_matches) = matches.subcommand_matches("db") {
return handle_db_subcommand(db_matches, &db_conn).await;
}
if let Some(profile_matches) = matches.subcommand_matches("profile") {
return handle_profile_subcommand(
profile_matches,
&mut profile_handler,
)
.await;
}

// Handle --profile option
if let Some(profile_name) = matches.get_one::<String>("profile") {
profile_handler.set_profile_name(profile_name.to_string());
} else {
// Use default profile if set
if let Some(default_profile) =
profile_handler.get_default_profile().await?
{
profile_handler.set_profile_name(default_profile);
}
let profile_handler = db_conn.get_profile_handler(None);
return handle_profile_subcommand(profile_matches, profile_handler)
.await;
}
}

// TODO: Add support for --profile option in the prompt command
let prompt_instruction =
create_prompt_instruction(matches.as_ref(), &db_conn).await?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ pub use prepare::NewConversation;

pub use super::db;
use super::{
ChatCompletionOptions, ChatMessage, ColorScheme, PromptError, PromptRole,
SimpleString, TextLine, TextSegment,
ChatCompletionOptions, ChatMessage, ColorScheme, ModelBackend, PromptError,
PromptRole, SimpleString, TextLine, TextSegment,
};

#[derive(Debug, Clone, PartialEq, Copy)]
Expand Down
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub use store::ConversationDatabase;
pub use user_profile::{MaskMode, UserProfileDbHandler};

pub use super::ConversationCache;
use super::PromptRole;
use super::{ModelBackend, ModelServer, PromptRole};
use crate::external as lumni;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down
38 changes: 38 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/user_profile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tokio::sync::Mutex as TokioMutex;

use super::connector::{DatabaseConnector, DatabaseOperationError};
use super::encryption::EncryptionHandler;
use super::{ModelBackend, ModelServer, ModelSpec};
use crate::external as lumni;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -49,6 +50,43 @@ impl UserProfileDbHandler {
self.profile_name = Some(profile_name);
}

pub fn get_profile_name(&self) -> Option<&str> {
self.profile_name.as_deref()
}

pub async fn model_backend(
&mut self,
) -> Result<Option<ModelBackend>, ApplicationError> {
let profile_name = self.profile_name.clone();

if let Some(profile_name) = profile_name {
let settings = self
.get_profile_settings(&profile_name, MaskMode::Unmask)
.await?;

let model_server = settings
.get("__MODEL_SERVER")
.and_then(|v| v.as_str())
.ok_or_else(|| {
ApplicationError::InvalidInput(
"__MODEL_SERVER not found in profile".to_string(),
)
})?;

let server = ModelServer::from_str(model_server)?;

let model = settings
.get("__MODEL_IDENTIFIER")
.and_then(|v| v.as_str())
.map(|identifier| ModelSpec::new_with_validation(identifier))
.transpose()?;

Ok(Some(ModelBackend { server, model }))
} else {
Ok(None)
}
}

pub fn set_encryption_handler(
&mut self,
encryption_handler: Arc<EncryptionHandler>,
Expand Down
4 changes: 3 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ pub use session::{prompt_app, App, ChatEvent, ThreadedChatSession};

pub use super::defaults::*;
pub use super::error::{PromptError, PromptNotReadyReason};
use super::server::{CompletionResponse, ModelServer, ServerManager};
use super::server::{
CompletionResponse, ModelBackend, ModelServer, ServerManager,
};
use super::tui::{
draw_ui, AppUi, ColorScheme, ColorSchemeType, CommandLineAction,
ConversationEvent, KeyEventHandler, ModalAction, ModalWindowType,
Expand Down
6 changes: 0 additions & 6 deletions lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ pub fn parse_cli_arguments(spec: ApplicationSpec) -> Command {
.short('a')
.help("Specify an assistant to use"),
)
.arg(
Arg::new("server")
.long("server")
.short('S')
.help("Server to use for processing the request"),
)
.arg(Arg::new("options").long("options").short('o').help(
"Comma-separated list of model options e.g., \
temperature=1,max_tokens=100",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ fn create_truncate_subcommand() -> Command {

pub async fn handle_profile_subcommand(
profile_matches: &ArgMatches,
db_handler: &mut UserProfileDbHandler,
mut db_handler: UserProfileDbHandler,
) -> Result<(), ApplicationError> {
match profile_matches.subcommand() {
Some(("list", _)) => {
Expand Down Expand Up @@ -365,7 +365,7 @@ pub async fn handle_profile_subcommand(
let custom_ssh_key_path =
edit_matches.get_one::<String>("ssh-key-path").cloned();
interactive_profile_edit(
db_handler,
&mut db_handler,
profile_name,
custom_ssh_key_path,
)
Expand Down
13 changes: 13 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/server/backend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use super::{ModelServer, ModelSpec, ServerManager};
use crate::apps::builtin::llm::prompt::src::chat::db::ModelServerName;

pub struct ModelBackend {
pub server: ModelServer,
pub model: Option<ModelSpec>,
}

impl ModelBackend {
pub fn server_name(&self) -> ModelServerName {
self.server.server_name()
}
}
2 changes: 2 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#[macro_use]
mod spec;

mod backend;
mod bedrock;
mod endpoints;
mod llama;
Expand All @@ -11,6 +12,7 @@ mod response;
mod send;

use async_trait::async_trait;
pub use backend::ModelBackend;
pub use bedrock::Bedrock;
use bytes::Bytes;
pub use endpoints::Endpoints;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ impl ProfileList {
}

pub fn is_default_profile(&self, profile: &str) -> bool {
self.default_profile.as_ref().map_or(false, |default| default == profile)
self.default_profile
.as_ref()
.map_or(false, |default| default == profile)
}

pub fn get_profiles(&self) -> Vec<String> {
Expand Down

0 comments on commit e4a63f0

Please sign in to comment.