Skip to content

Commit

Permalink
expand profile helper with model selection
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 13, 2024
1 parent 162cdec commit 58fc138
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 34 deletions.
4 changes: 4 additions & 0 deletions lumni/src/apps/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub enum ApplicationError {
DatabaseError(String),
NotImplemented(String),
NotReady(String),
UserCancelled(String),
EncryptionError(EncryptionError),
CustomError(Box<dyn Error + Send + Sync>),
}
Expand Down Expand Up @@ -138,6 +139,9 @@ impl fmt::Display for ApplicationError {
ApplicationError::EncryptionError(e) => {
write!(f, "EncryptionError: {}", e)
}
ApplicationError::UserCancelled(s) => {
write!(f, "UserCancelled: {}", s)
}
ApplicationError::CustomError(e) => write!(f, "{}", e),
}
}
Expand Down
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use subcommands::profile::create_profile_subcommand;
pub use subcommands::profile::handle_profile_subcommand;

use super::chat::db::{
ConversationDatabase, EncryptionHandler, EncryptionMode, MaskMode,
ConversationDatabase, EncryptionHandler, MaskMode, ModelSpec,
UserProfileDbHandler,
};
use super::server::{ModelServer, ServerTrait, SUPPORTED_MODEL_ENDPOINTS};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ pub mod profile;
pub mod profile_helper;

use super::{
ConversationDatabase, EncryptionHandler, MaskMode, ModelServer,
ConversationDatabase, EncryptionHandler, MaskMode, ModelServer, ModelSpec,
ServerTrait, UserProfileDbHandler, SUPPORTED_MODEL_ENDPOINTS,
};
167 changes: 135 additions & 32 deletions lumni/src/apps/builtin/llm/prompt/src/cli/subcommands/profile_helper.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::env;
use std::io::{self, Write};
use std::path::PathBuf;
Expand All @@ -9,11 +8,18 @@ use lumni::api::error::ApplicationError;
use serde_json::{json, Map, Value as JsonValue};

use super::{
EncryptionHandler, ModelServer, ServerTrait, UserProfileDbHandler,
SUPPORTED_MODEL_ENDPOINTS,
EncryptionHandler, ModelServer, ModelSpec, ServerTrait,
UserProfileDbHandler, SUPPORTED_MODEL_ENDPOINTS,
};
use crate::external as lumni;

enum ModelSelection {
Selected(ModelSpec),
Reload,
Quit,
Skip,
}

pub async fn interactive_profile_edit(
db_handler: &mut UserProfileDbHandler,
profile_name_to_update: Option<String>,
Expand All @@ -27,7 +33,31 @@ pub async fn interactive_profile_edit(
db_handler.set_profile_name(profile_name.clone());

let profile_type = select_profile_type()?;
let mut settings = get_predefined_settings(&profile_type)?;
let mut settings = JsonValue::Object(Map::new());

if profile_type != "Custom" {
let model_server = ModelServer::from_str(&profile_type)?;

if let Some(selected_model) = select_model(&model_server).await? {
settings["MODEL_IDENTIFIER"] =
JsonValue::String(selected_model.identifier.0);
} else {
println!("No model selected. Skipping model selection.");
}

// Get other predefined settings
let server_settings = model_server.get_profile_settings();
if let JsonValue::Object(map) = server_settings {
for (key, value) in map {
settings[key] = value;
}
}
}

if ask_yes_no("Do you want to set a project directory?", true)? {
let dir = get_project_directory()?;
settings["PROJECT_DIRECTORY"] = JsonValue::String(dir);
}

if profile_type == "Custom" {
collect_custom_settings(&mut settings)?;
Expand Down Expand Up @@ -56,6 +86,107 @@ pub async fn interactive_profile_edit(
Ok(())
}

async fn select_model(
model_server: &ModelServer,
) -> Result<Option<ModelSpec>, ApplicationError> {
loop {
match model_server.list_models().await {
Ok(models) => {
if models.is_empty() {
println!("No models available for this server.");
return Ok(None);
}

match select_model_from_list(&models)? {
ModelSelection::Selected(model) => return Ok(Some(model)),
ModelSelection::Reload => {
println!("Reloading model list...");
continue;
}
ModelSelection::Quit => {
return Err(ApplicationError::UserCancelled(
"Model selection cancelled by user".to_string(),
))
}
ModelSelection::Skip => return Ok(None),
}
}
Err(ApplicationError::NotReady(msg)) => {
println!("Error: {}", msg);
if let Err(e) = handle_not_ready() {
return Err(e);
}
}
Err(e) => return Err(e), // propagate other errors
}
}
}

fn select_model_from_list(
models: &[ModelSpec],
) -> Result<ModelSelection, ApplicationError> {
println!("Available models:");
for (index, model) in models.iter().enumerate() {
println!("{}. {}", index + 1, model.identifier.0);
}

print!(
"Select a model (1-{}), press Enter to reload, 'q' to quit, or 's' to \
skip: ",
models.len()
);
io::stdout().flush()?;
let mut choice = String::new();
io::stdin().read_line(&mut choice)?;
let choice = choice.trim().to_lowercase();

match choice.as_str() {
"" => Ok(ModelSelection::Reload),
"q" => {
println!("Quitting model selection.");
Ok(ModelSelection::Quit)
}
"s" => {
println!("Skipping model selection.");
Ok(ModelSelection::Skip)
}
_ => {
if let Ok(index) = choice.parse::<usize>() {
if index > 0 && index <= models.len() {
return Ok(ModelSelection::Selected(
models[index - 1].clone(),
));
}
}
println!("Invalid choice. Please try again.");
select_model_from_list(models) // Recursively ask for selection again
}
}
}

fn handle_not_ready() -> Result<(), ApplicationError> {
loop {
print!(
"Press Enter to retry, 'q' to quit, or 's' to skip model \
selection: "
);
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;

match input.trim().to_lowercase().as_str() {
"" => return Ok(()),
"q" => {
return Err(ApplicationError::UserCancelled(
"Model selection cancelled by user".to_string(),
))
}
"s" => return Ok(()),
_ => println!("Invalid input. Please try again."),
}
}
}

fn select_profile_type() -> Result<String, ApplicationError> {
println!("Select a profile type:");
println!("0. Custom (default)");
Expand Down Expand Up @@ -123,34 +254,6 @@ fn collect_custom_settings(
Ok(())
}

fn get_predefined_settings(
profile_type: &str,
) -> Result<JsonValue, ApplicationError> {
let mut settings = JsonValue::Object(Map::new());

if ask_yes_no("Do you want to set a project directory?", true)? {
let dir = get_project_directory()?;
settings["PROJECT_DIRECTORY"] = JsonValue::String(dir);
}

if profile_type != "Custom" {
let model_server = ModelServer::from_str(profile_type)?;
let server_settings = model_server.get_profile_settings();

if let JsonValue::Object(map) = server_settings {
for (key, value) in map {
settings[key] = value;
}
}
} else {
println!(
"Custom profile selected. You can add custom key-value pairs."
);
}

Ok(settings)
}

fn collect_profile_settings(
settings: &mut JsonValue,
profile_type: &str,
Expand Down

0 comments on commit 58fc138

Please sign in to comment.