Skip to content

Commit

Permalink
add model selection to profile modal
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 17, 2024
1 parent 8730d4f commit d0ba08e
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 104 deletions.
6 changes: 2 additions & 4 deletions lumni/src/apps/builtin/llm/prompt/src/tui/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ pub use window::{

use super::chat::db::{
Conversation, ConversationDatabase, ConversationDbHandler, ConversationId,
ConversationStatus, MaskMode, UserProfileDbHandler,
ConversationStatus, MaskMode, ModelSpec, UserProfileDbHandler,
};
use super::chat::{
App, NewConversation, PromptInstruction, ThreadedChatSession,
};
use super::server::{
ModelServer, ServerManager, ServerTrait, SUPPORTED_MODEL_ENDPOINTS,
};
use super::server::{ModelServer, ServerTrait, SUPPORTED_MODEL_ENDPOINTS};
use crate::external as lumni;
5 changes: 3 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/tui/modals/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ use ratatui::Frame;

use super::{
ApplicationError, CommandLine, Conversation, ConversationDbHandler,
ConversationStatus, KeyTrack, MaskMode, PromptInstruction, TextWindowTrait,
ThreadedChatSession, UserEvent, UserProfileDbHandler, WindowEvent,
ConversationStatus, KeyTrack, MaskMode, ModelServer, ModelSpec,
PromptInstruction, ServerTrait, TextWindowTrait, ThreadedChatSession,
UserEvent, UserProfileDbHandler, WindowEvent, SUPPORTED_MODEL_ENDPOINTS,
};

#[derive(Debug, Clone, PartialEq)]
Expand Down
138 changes: 125 additions & 13 deletions lumni/src/apps/builtin/llm/prompt/src/tui/modals/profiles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ use ratatui::widgets::{
Block, Borders, Clear, List, ListItem, ListState, Paragraph,
};
use ratatui::Frame;
use serde_json::{json, Map, Value};
use serde_json::{json, Map, Value as JsonValue};
use settings_editor::SettingsEditor;
use tokio::sync::mpsc;
use ui_state::{EditMode, Focus, UIState};

use super::{
ApplicationError, ConversationDbHandler, KeyTrack, MaskMode, ModalAction,
ModalWindowTrait, ModalWindowType, ThreadedChatSession,
UserProfileDbHandler, WindowEvent,
ModalWindowTrait, ModalWindowType, ModelServer, ModelSpec, ServerTrait,
ThreadedChatSession, UserProfileDbHandler, WindowEvent,
SUPPORTED_MODEL_ENDPOINTS,
};

pub struct ProfileEditModal {
Expand All @@ -50,10 +51,10 @@ impl ProfileEditModal {
.get_profile_settings(profile, MaskMode::Mask)
.await?
} else {
Value::Object(serde_json::Map::new())
JsonValue::Object(serde_json::Map::new())
};
let settings_editor = SettingsEditor::new(settings);
let new_profile_creator = NewProfileCreator::new();
let new_profile_creator = NewProfileCreator::new(db_handler.clone());

Ok(Self {
profile_list,
Expand Down Expand Up @@ -312,23 +313,130 @@ impl ProfileEditModal {
self.new_profile_creator.selected_type += 1;
}
}
KeyCode::Enter => {
if self
.new_profile_creator
.prepare_for_model_selection()
.await?
{
self.ui_state.set_focus(Focus::ModelSelection);
} else {
// If no model selection is needed, create the profile without a model
let profile_count = self.profile_list.total_items();
self.new_profile_creator
.create_new_profile(&self.db_handler, profile_count)
.await?;
self.ui_state.set_focus(Focus::ProfileList);
}
return Ok(WindowEvent::Modal(ModalAction::Refresh));
}
KeyCode::Esc => {
self.ui_state.set_focus(Focus::ProfileList);
}
_ => {}
}
Ok(WindowEvent::Modal(ModalAction::WaitForKeyEvent))
}

async fn handle_model_selection_input(
&mut self,
key_code: KeyCode,
) -> Result<WindowEvent, ApplicationError> {
match key_code {
KeyCode::Up => {
self.new_profile_creator.move_model_selection_up();
}
KeyCode::Down => {
self.new_profile_creator.move_model_selection_down();
}
KeyCode::Enter => {
let profile_count = self.profile_list.total_items();
self.new_profile_creator
.create_new_profile(&self.db_handler, profile_count)
.await?;
self.ui_state.set_focus(Focus::ProfileList);
return Ok(WindowEvent::Modal(ModalAction::Refresh));
}
KeyCode::Esc => {
self.ui_state.set_edit_mode(EditMode::NotEditing);
// Cancel model selection, create profile without a model
self.new_profile_creator.model_selection_pending = false;
let profile_count = self.profile_list.total_items();
self.new_profile_creator
.create_new_profile(&self.db_handler, profile_count)
.await?;
self.ui_state.set_focus(Focus::ProfileList);
return Ok(WindowEvent::Modal(ModalAction::Refresh));
}
_ => {}
}

Ok(WindowEvent::Modal(ModalAction::WaitForKeyEvent))
}

fn render_model_selection(&self, f: &mut Frame, area: Rect) {
let models = &self.new_profile_creator.available_models;
let items: Vec<ListItem> = models
.iter()
.enumerate()
.map(|(i, model)| {
let style = if i
== self.new_profile_creator.selected_model_index
{
Style::default().bg(Color::Rgb(40, 40, 40)).fg(Color::White)
} else {
Style::default().bg(Color::Black).fg(Color::Cyan)
};
ListItem::new(Line::from(vec![Span::styled(
&model.identifier.0,
style,
)]))
})
.collect();

let list = List::new(items)
.block(Block::default().borders(Borders::ALL).title("Select Model"))
.highlight_style(Style::default().add_modifier(Modifier::BOLD))
.highlight_symbol(">> ");

let mut state = ListState::default();
state.select(Some(self.new_profile_creator.selected_model_index));

f.render_stateful_widget(list, area, &mut state);
}

fn render_new_profile_type(&self, f: &mut Frame, area: Rect) {
let items: Vec<ListItem> = self
.new_profile_creator
.predefined_types
.iter()
.enumerate()
.map(|(i, profile_type)| {
let style = if i == self.new_profile_creator.selected_type {
Style::default().bg(Color::Rgb(40, 40, 40)).fg(Color::White)
} else {
Style::default().bg(Color::Black).fg(Color::Cyan)
};
ListItem::new(Line::from(vec![Span::styled(
profile_type,
style,
)]))
})
.collect();

let list = List::new(items)
.block(
Block::default()
.borders(Borders::ALL)
.title("Select Profile Type"),
)
.highlight_style(Style::default().add_modifier(Modifier::BOLD))
.highlight_symbol(">> ");

let mut state = ListState::default();
state.select(Some(self.new_profile_creator.selected_type));

f.render_stateful_widget(list, area, &mut state);
}

fn cancel_edit(&mut self) {
self.settings_editor.cancel_edit();
self.ui_state.set_edit_mode(EditMode::NotEditing);
Expand Down Expand Up @@ -366,10 +474,13 @@ impl ModalWindowTrait for ProfileEditModal {
self.renderer
.render_profile_list(frame, content_chunks[0], self);

match self.ui_state.edit_mode {
EditMode::CreatingNewProfile => self
.renderer
.render_new_profile_type(frame, content_chunks[1], self),
match self.ui_state.focus {
Focus::NewProfileType => {
self.render_new_profile_type(frame, content_chunks[1])
}
Focus::ModelSelection => {
self.render_model_selection(frame, content_chunks[1])
}
_ => self.renderer.render_settings_list(
frame,
content_chunks[1],
Expand Down Expand Up @@ -438,7 +549,6 @@ impl ModalWindowTrait for ProfileEditModal {
_handler: &mut ConversationDbHandler,
) -> Result<WindowEvent, ApplicationError> {
let key_code = key_event.current_key().code;

let result = match self.ui_state.focus {
Focus::ProfileList => match key_code {
KeyCode::Right | KeyCode::Tab => {
Expand Down Expand Up @@ -470,8 +580,10 @@ impl ModalWindowTrait for ProfileEditModal {
Focus::RenamingProfile => {
Ok(self.handle_profile_list_input(key_code).await?)
}
Focus::ModelSelection => {
Ok(self.handle_model_selection_input(key_code).await?)
}
};

result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,65 @@ use super::*;
pub enum BackgroundTaskResult {
ProfileCreated(Result<(), ApplicationError>),
}

pub struct NewProfileCreator {
pub predefined_types: Vec<String>,
pub selected_type: usize,
pub selected_model_index: usize, // New field for model selection
pub background_task: Option<mpsc::Receiver<BackgroundTaskResult>>,
pub task_start_time: Option<Instant>,
pub spinner_state: usize,
pub new_profile_name: Option<String>,
pub available_models: Vec<ModelSpec>,
pub model_selection_pending: bool,
pub db_handler: UserProfileDbHandler,
}

impl NewProfileCreator {
pub fn new() -> Self {
pub fn new(db_handler: UserProfileDbHandler) -> Self {
Self {
predefined_types: vec![
"Custom".to_string(),
"OpenAI".to_string(),
"Anthropic".to_string(),
],
predefined_types: SUPPORTED_MODEL_ENDPOINTS
.iter()
.map(|s| s.to_string())
.collect(),
selected_type: 0,
selected_model_index: 0, // Initialize the new field
background_task: None,
task_start_time: None,
spinner_state: 0,
new_profile_name: None,
available_models: Vec::new(),
model_selection_pending: false,
db_handler,
}
}

pub async fn prepare_for_model_selection(
&mut self,
) -> Result<bool, ApplicationError> {
let profile_type = &self.predefined_types[self.selected_type];
let model_server = ModelServer::from_str(profile_type)?;

match model_server.list_models().await {
Ok(models) if !models.is_empty() => {
self.available_models = models;
self.model_selection_pending = true;
self.selected_model_index = 0; // Reset model selection index
Ok(true)
}
Ok(_) => {
println!("No models available for this server.");
self.model_selection_pending = false;
Ok(false)
}
Err(ApplicationError::NotReady(msg)) => {
println!(
"Server not ready: {}. Model selection will be skipped.",
msg
);
self.model_selection_pending = false;
Ok(false)
}
Err(e) => Err(e),
}
}

Expand All @@ -36,52 +72,58 @@ impl NewProfileCreator {
) -> Result<(), ApplicationError> {
let new_profile_name = format!("New_Profile_{}", profile_count + 1);
let profile_type = &self.predefined_types[self.selected_type];

let mut settings = Map::new();
settings.insert("__PROFILE_TYPE".to_string(), json!(profile_type));

// Add default settings based on the profile type
match profile_type.as_str() {
"OpenAI" => {
settings.insert("api_key".to_string(), json!(""));
settings.insert("model".to_string(), json!("gpt-3.5-turbo"));
}
"Anthropic" => {
settings.insert("api_key".to_string(), json!(""));
settings.insert("model".to_string(), json!("claude-2"));
let model_server = ModelServer::from_str(profile_type)?;
let server_settings = model_server.get_profile_settings();
if let JsonValue::Object(map) = server_settings {
for (key, value) in map {
settings.insert(key, value);
}
"Custom" => {}
_ => {
return Err(ApplicationError::InvalidInput(
"Unknown profile type".to_string(),
))
}

if self.model_selection_pending {
if let Some(selected_model) =
self.available_models.get(self.selected_model_index)
{
settings.insert(
"__MODEL_IDENTIFIER".to_string(),
json!(selected_model.identifier.0),
);
}
}

let mut db_handler = db_handler.clone();
let (tx, rx) = mpsc::channel(1);

let new_profile_name_clone = new_profile_name.clone();
let settings_clone = settings.clone();
tokio::spawn(async move {
let result = db_handler
.create_or_update(&new_profile_name_clone, &json!(settings))
.create_or_update(
&new_profile_name_clone,
&json!(settings_clone),
)
.await;
let _ = tx.send(BackgroundTaskResult::ProfileCreated(result)).await;
});

self.background_task = Some(rx);
self.task_start_time = Some(Instant::now());
self.spinner_state = 0;
self.new_profile_name = Some(new_profile_name);

Ok(())
}

pub fn get_predefined_types(&self) -> &[String] {
&self.predefined_types
// Add methods to manipulate selected_model_index
pub fn move_model_selection_up(&mut self) {
if self.selected_model_index > 0 {
self.selected_model_index -= 1;
}
}

pub fn get_selected_type(&self) -> usize {
self.selected_type
pub fn move_model_selection_down(&mut self) {
if self.selected_model_index < self.available_models.len() - 1 {
self.selected_model_index += 1;
}
}
}
Loading

0 comments on commit d0ba08e

Please sign in to comment.