Skip to content

Commit

Permalink
allow to change server/ model during chat
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 21, 2024
1 parent 8a2c000 commit 8c1dfab
Show file tree
Hide file tree
Showing 28 changed files with 474 additions and 312 deletions.
7 changes: 6 additions & 1 deletion lumni/src/apps/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ impl std::error::Error for ApplicationError {}

impl From<HttpClientError> for ApplicationError {
fn from(error: HttpClientError) -> Self {
ApplicationError::HttpClientError(error)
match error {
HttpClientError::ConnectionError(e) => {
ApplicationError::NotReady(e.to_string())
}
_ => ApplicationError::HttpClientError(error),
}
}
}

Expand Down
80 changes: 55 additions & 25 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ use tokio::signal;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, timeout, Duration};

use super::chat::{ChatSession, ConversationDatabaseStore};
use super::chat::{ChatSession, ConversationDatabaseStore, NewConversation};
use super::server::{
ModelServer, PromptInstruction, ServerManager, ServerTrait,
ModelServer, ModelServerName, PromptInstruction, ServerManager, ServerTrait,
};
use super::session::{AppSession, TabSession};
use super::tui::{
ColorScheme, CommandLineAction, ConversationReader, KeyEventHandler, PromptAction, TabUi, TextWindowTrait, WindowEvent, WindowKind
ColorScheme, CommandLineAction, ConversationEvent, ConversationReader,
KeyEventHandler, PromptAction, TabUi, TextWindowTrait, WindowEvent,
WindowKind,
};
pub use crate::external as lumni;

Expand All @@ -56,11 +58,12 @@ async fn prompt_app<B: Backend>(
let mut redraw_ui = true;

// TODO: reader should be updated when conversation_id changes
let mut reader: Option<ConversationReader> = if let Some(conversation_id) = tab.chat.get_conversation_id() {
Some(db_conn.get_conversation_reader(conversation_id))
} else {
None
};
let mut reader: Option<ConversationReader> =
if let Some(conversation_id) = tab.chat.get_conversation_id() {
Some(db_conn.get_conversation_reader(conversation_id))
} else {
None
};

// Buffer to store the trimmed trailing newlines or empty spaces
let mut trim_buffer: Option<String> = None;
Expand Down Expand Up @@ -135,11 +138,24 @@ async fn prompt_app<B: Backend>(
}
Some(WindowEvent::PromptWindow(ref event)) => {
match event {
None => {},
Some(converation_event) => {
// TODO: if conversation_id changes, update reader
eprintln!("Conversation event: {:?}", converation_event);
Some(ConversationEvent::NewConversation(new_conversation)) => {
let prompt_instruction = PromptInstruction::new(
new_conversation.clone(),
&db_conn,
)?;
let chat_session = ChatSession::new(
&new_conversation.server.to_string(),
prompt_instruction,
&db_conn,
).await?;
tab.new_conversation(chat_session);
reader = if let Some(conversation_id) = tab.chat.get_conversation_id() {
Some(db_conn.get_conversation_reader(conversation_id))
} else {
None
};
}
_ => {},
}
}
_ => {}
Expand Down Expand Up @@ -367,29 +383,43 @@ pub async fn run_cli(
// optional arguments
let instruction = matches.get_one::<String>("system").cloned();
let assistant = matches.get_one::<String>("assistant").cloned();
let options = matches.get_one::<String>("options");
let options = match matches.get_one::<String>("options") {
Some(s) => {
let value = serde_json::from_str::<serde_json::Value>(s)?;
Some(value)
}
None => None,
};

let server_name = matches
.get_one::<String>("server")
.map(|s| s.to_lowercase())
.unwrap_or_else(|| "ollama".to_lowercase());

// create new (un-initialized) server from requested server name
let mut server = ModelServer::from_str(&server_name)?;
let default_model = server.get_default_model().await;

// setup prompt, server and chat session
let prompt_instruction =
PromptInstruction::new(default_model, instruction, assistant, options, &db_conn)?;
let conversation_id = prompt_instruction.get_conversation_id();
let server = ModelServer::from_str(&server_name)?;
let default_model = match server.get_default_model().await {
Ok(model) => Some(model),
Err(e) => {
log::error!("Failed to get default model during startup: {}", e);
None
}
};

if let Some(conversation_id) = conversation_id {
let reader = db_conn.get_conversation_reader(conversation_id);
server.setup_and_initialize(&reader).await?;
}
let prompt_instruction = PromptInstruction::new(
NewConversation {
server: ModelServerName::from_str(&server_name),
model: default_model,
options,
system_prompt: instruction,
assistant_name: assistant,
parent: None,
},
&db_conn,
)?;

let chat_session =
ChatSession::new(Box::new(server), prompt_instruction).await?;
ChatSession::new(&server_name, prompt_instruction, &db_conn).await?;

match poll(Duration::from_millis(0)) {
Ok(_) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::HashMap;

use super::{ModelIdentifier, ModelSpec, PromptRole,
ConversationId, Message,
MessageId, AttachmentId, Attachment
use super::{
Attachment, AttachmentId, ConversationId, Message, MessageId,
ModelIdentifier, ModelSpec, PromptRole,
};

#[derive(Debug)]
Expand Down
18 changes: 13 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
use serde::{Deserialize, Serialize};

mod model;
mod cache;
mod model;

pub use model::{ModelIdentifier, ModelSpec};
pub use cache::ConversationCache;
pub use model::{ModelIdentifier, ModelSpec};

use super::PromptRole;


#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelServerName(pub String);

impl ModelServerName {
pub fn from_str<T: AsRef<str>>(s: T) -> Self {
ModelServerName(s.as_ref().to_string())
}

pub fn to_string(&self) -> String {
self.0.clone()
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ConversationId(pub i64);

Expand All @@ -21,7 +30,6 @@ pub struct MessageId(pub i64);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AttachmentId(pub i64);


#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
pub id: ConversationId,
Expand Down Expand Up @@ -67,4 +75,4 @@ pub struct Attachment {
pub metadata: Option<serde_json::Value>,
pub created_at: i64,
pub is_deleted: bool,
}
}
43 changes: 28 additions & 15 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/model.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use lazy_static::lazy_static;
use lumni::api::error::ApplicationError;
use regex::Regex;
use serde::{Deserialize, Serialize};

pub use crate::external as lumni;

use lazy_static::lazy_static;
use regex::Regex;

lazy_static! {
static ref IDENTIFIER_REGEX: Regex = Regex::new(
r"^[-a-z0-9_]+::[-a-z0-9_][-a-z0-9_:.]*[-a-z0-9_]+$"
).unwrap();
static ref IDENTIFIER_REGEX: Regex =
Regex::new(r"^[-a-z0-9_]+::[-a-z0-9_][-a-z0-9_:.]*[-a-z0-9_]+$")
.unwrap();
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand All @@ -21,7 +20,10 @@ impl ModelIdentifier {
Ok(ModelIdentifier(identifier_str.to_string()))
} else {
Err(ApplicationError::InvalidInput(format!(
"Identifier must be in the format 'provider::model_name', where the provider contains only lowercase letters, numbers, hyphens, underscores, and the model name can include internal colons but not start or end with them. Got: '{}'",
"Identifier must be in the format 'provider::model_name', \
where the provider contains only lowercase letters, numbers, \
hyphens, underscores, and the model name can include \
internal colons but not start or end with them. Got: '{}'",
identifier_str
)))
}
Expand All @@ -48,12 +50,14 @@ pub struct ModelSpec {
}

impl ModelSpec {
pub fn new_with_validation(identifier_str: &str) -> Result<Self, ApplicationError> {
pub fn new_with_validation(
identifier_str: &str,
) -> Result<Self, ApplicationError> {
let identifier = ModelIdentifier::new(identifier_str)?;
Ok(ModelSpec {
identifier,
info: None,
config: None,
config: None,
context_window_size: None,
input_token_limit: None,
})
Expand Down Expand Up @@ -107,7 +111,11 @@ impl ModelSpec {
self
}

pub fn set_config_value(&mut self, key: &str, value: serde_json::Value) -> &mut Self {
pub fn set_config_value(
&mut self,
key: &str,
value: serde_json::Value,
) -> &mut Self {
if let Some(config) = self.config.as_mut() {
if let serde_json::Value::Object(map) = config {
map.insert(key.to_string(), value);
Expand Down Expand Up @@ -142,20 +150,25 @@ impl ModelSpec {
}

pub fn set_family(&mut self, family: &str) -> &mut Self {
self.set_config_value("family", serde_json::Value::String(family.to_string()))
self.set_config_value(
"family",
serde_json::Value::String(family.to_string()),
)
}

pub fn get_family(&self) -> Option<&str> {
self.get_config_value("family")
.and_then(|v| v.as_str())
self.get_config_value("family").and_then(|v| v.as_str())
}

pub fn set_description(&mut self, description: &str) -> &mut Self {
self.set_config_value("description", serde_json::Value::String(description.to_string()))
self.set_config_value(
"description",
serde_json::Value::String(description.to_string()),
)
}

pub fn get_description(&self) -> Option<&str> {
self.get_config_value("description")
.and_then(|v| v.as_str())
}
}
}
3 changes: 2 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ mod reader;
mod store;

pub use reader::ConversationReader;
pub use super::conversation;
pub use store::ConversationDatabaseStore;

pub use super::conversation;
26 changes: 23 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex};
use rusqlite::{params, Error as SqliteError, OptionalExtension};

use super::connector::DatabaseConnector;
use super::conversation::{ConversationId, ModelIdentifier};
use super::conversation::{ConversationId, MessageId, ModelIdentifier};

pub struct ConversationReader<'a> {
conversation_id: ConversationId,
Expand All @@ -23,6 +23,10 @@ impl<'a> ConversationReader<'a> {
}

impl<'a> ConversationReader<'a> {
pub fn get_conversation_id(&self) -> ConversationId {
self.conversation_id
}

pub fn get_model_identifier(&self) -> Result<ModelIdentifier, SqliteError> {
let query = "
SELECT m.identifier
Expand Down Expand Up @@ -89,9 +93,25 @@ impl<'a> ConversationReader<'a> {
})
}

pub fn get_conversation_stats(
pub fn get_last_message_id(
&self,
) -> Result<(i64, i64), SqliteError> {
) -> Result<Option<MessageId>, SqliteError> {
let query = "
SELECT MAX(id) as last_message_id
FROM messages
WHERE conversation_id = ? AND is_deleted = FALSE
";

let mut db = self.db.lock().unwrap();
db.process_queue_with_result(|tx| {
tx.query_row(query, params![self.conversation_id.0], |row| {
row.get::<_, Option<i64>>(0)
.map(|opt_id| opt_id.map(MessageId))
})
})
}

pub fn get_conversation_stats(&self) -> Result<(i64, i64), SqliteError> {
let query = "SELECT message_count, total_tokens FROM conversations \
WHERE id = ?";
let mut db = self.db.lock().unwrap();
Expand Down
Loading

0 comments on commit 8c1dfab

Please sign in to comment.