Skip to content

Commit

Permalink
remove modelspec from server structs, add to promptinstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 20, 2024
1 parent e27da4c commit 8a2c000
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 135 deletions.
18 changes: 10 additions & 8 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ use super::server::{
};
use super::session::{AppSession, TabSession};
use super::tui::{
ColorScheme, CommandLineAction, KeyEventHandler, PromptAction, TabUi,
TextWindowTrait, WindowEvent, WindowKind,
ColorScheme, CommandLineAction, ConversationReader, KeyEventHandler, PromptAction, TabUi, TextWindowTrait, WindowEvent, WindowKind
};
pub use crate::external as lumni;

Expand All @@ -57,8 +56,11 @@ async fn prompt_app<B: Backend>(
let mut redraw_ui = true;

// TODO: reader should be updated when conversation_id changes
let conversation_id = tab.chat.get_conversation_id();
let mut reader = db_conn.get_conversation_reader(conversation_id);
let mut reader: Option<ConversationReader> = if let Some(conversation_id) = tab.chat.get_conversation_id() {
Some(db_conn.get_conversation_reader(conversation_id))
} else {
None
};

// Buffer to store the trimmed trailing newlines or empty spaces
let mut trim_buffer: Option<String> = None;
Expand All @@ -83,7 +85,7 @@ async fn prompt_app<B: Backend>(
&mut tab.chat,
mode,
keep_running.clone(),
&reader,
reader.as_ref(),
).await?
} else {
None
Expand Down Expand Up @@ -378,12 +380,12 @@ pub async fn run_cli(

// setup prompt, server and chat session
let prompt_instruction =
PromptInstruction::new(instruction, assistant, options, &db_conn)?;
PromptInstruction::new(default_model, instruction, assistant, options, &db_conn)?;
let conversation_id = prompt_instruction.get_conversation_id();

if let Some(model) = default_model {
if let Some(conversation_id) = conversation_id {
let reader = db_conn.get_conversation_reader(conversation_id);
server.setup_and_initialize(model, &reader).await?;
server.setup_and_initialize(&reader).await?;
}

let chat_session =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct ConversationCache {
impl ConversationCache {
pub fn new() -> Self {
ConversationCache {
conversation_id: ConversationId(0),
conversation_id: ConversationId(-1),
models: HashMap::new(),
messages: Vec::new(),
attachments: HashMap::new(),
Expand Down
25 changes: 24 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex};
use rusqlite::{params, Error as SqliteError, OptionalExtension};

use super::connector::DatabaseConnector;
use super::conversation::ConversationId;
use super::conversation::{ConversationId, ModelIdentifier};

pub struct ConversationReader<'a> {
conversation_id: ConversationId,
Expand All @@ -23,6 +23,29 @@ impl<'a> ConversationReader<'a> {
}

impl<'a> ConversationReader<'a> {
/// Return the validated model identifier recorded for this reader's
/// conversation.
///
/// The conversation row is joined against `models`, so an identifier is
/// only produced when the referenced model row actually exists. An
/// identifier that fails `ModelIdentifier` validation is surfaced as a
/// `FromSqlConversionFailure` so it propagates through `query_row`
/// like any other row-mapping error.
pub fn get_model_identifier(&self) -> Result<ModelIdentifier, SqliteError> {
    let query = "
SELECT m.identifier
FROM conversations c
JOIN models m ON c.model_identifier = m.identifier
WHERE c.id = ?
";

    // Wrap a validation error in rusqlite's conversion-failure variant so
    // it can be returned from inside the row-mapping closure.
    let into_sqlite_err = |e| {
        SqliteError::FromSqlConversionFailure(
            0,
            rusqlite::types::Type::Text,
            Box::new(e),
        )
    };

    let mut guard = self.db.lock().unwrap();
    guard.process_queue_with_result(|tx| {
        tx.query_row(query, params![self.conversation_id.0], |row| {
            let raw: String = row.get(0)?;
            ModelIdentifier::new(&raw).map_err(into_sqlite_err)
        })
    })
}

pub fn get_completion_options(
&self,
) -> Result<serde_json::Value, SqliteError> {
Expand Down
5 changes: 3 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl ConversationDatabaseStore {
parent_id: Option<ConversationId>,
fork_message_id: Option<MessageId>,
completion_options: Option<serde_json::Value>,
model: ModelSpec,
model: &ModelSpec,
model_server: ModelServerName,
) -> Result<ConversationId, SqliteError> {
let mut db = self.db.lock().unwrap();
Expand Down Expand Up @@ -65,7 +65,7 @@ impl ConversationDatabaseStore {
id: ConversationId(-1), // Temporary ID
name: name.to_string(),
info: serde_json::Value::Null,
model_identifier: model.identifier,
model_identifier: model.identifier.clone(),
model_server,
parent_conversation_id: parent_id,
fork_message_id,
Expand Down Expand Up @@ -100,6 +100,7 @@ impl ConversationDatabaseStore {
Ok(ConversationId(id))
})
}

pub fn put_new_message(
&self,
message: &Message,
Expand Down
73 changes: 44 additions & 29 deletions lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@ use lumni::api::error::ApplicationError;

use super::db::ConversationDatabaseStore;
use super::conversation::{
ConversationCache, ConversationId,
Message, MessageId, ModelSpec, ModelIdentifier, ModelServerName,
ConversationCache, ConversationId, Message, MessageId,
ModelServerName, ModelSpec
};
use super::prompt::Prompt;
use super::{ChatCompletionOptions, ChatMessage, PromptRole, PERSONAS};
pub use crate::external as lumni;

/// State backing a prompt/chat session: the in-memory conversation cache,
/// the model the prompt is bound to (if any), and the id of the backing
/// database conversation (if one was created).
pub struct PromptInstruction {
// In-memory cache of conversation messages/models/attachments.
cache: ConversationCache,
// Model this prompt targets; None when no model was supplied.
model: Option<ModelSpec>,
// Optional template used when rendering prompts; None means raw input.
prompt_template: Option<String>,
// Database conversation id; None when no conversation was created
// (a conversation is only created when a model is available).
conversation_id: Option<ConversationId>,
}

impl PromptInstruction {
pub fn new(
model: Option<ModelSpec>,
instruction: Option<String>,
assistant: Option<String>,
options: Option<&String>,
Expand All @@ -36,27 +39,34 @@ impl PromptInstruction {
None => serde_json::to_value(ChatCompletionOptions::default())?,
};

// Create a new Conversation in the database
let model = ModelSpec::new_with_validation("foo-provider::bar-model")?;
let conversation_id = {
db_conn.new_conversation(
"New Conversation",
None, // parent_id, None for new conversation
None, // fork_message_id, None for new conversation
Some(completion_options), // completion_options
model,
ModelServerName("ollama".to_string()),
)?
let conversation_id = if let Some(ref model) = &model {
// Create a new Conversation in the database
Some({
db_conn.new_conversation(
"New Conversation",
None, // parent_id, None for new conversation
None, // fork_message_id, None for new conversation
Some(completion_options), // completion_options
model,
ModelServerName("ollama".to_string()),
)?
})
} else {
None
};

let mut prompt_instruction = PromptInstruction {
cache: ConversationCache::new(),
model,
prompt_template: None,
conversation_id,
};

prompt_instruction
.cache
.set_conversation_id(conversation_id);
if let Some(conversation_id) = prompt_instruction.conversation_id {
prompt_instruction
.cache
.set_conversation_id(conversation_id);
}

if let Some(assistant) = assistant {
prompt_instruction.preload_from_assistant(
Expand All @@ -71,8 +81,12 @@ impl PromptInstruction {
Ok(prompt_instruction)
}

pub fn get_conversation_id(&self) -> ConversationId {
self.cache.get_conversation_id()
/// Borrow the model spec this prompt instruction is bound to, if any.
pub fn get_model(&self) -> Option<&ModelSpec> {
self.model.as_ref()
}

/// The backing database conversation id, or None when no conversation
/// was created (e.g. no model was available at construction time).
pub fn get_conversation_id(&self) -> Option<ConversationId> {
self.conversation_id
}

fn add_system_message(
Expand Down Expand Up @@ -140,17 +154,18 @@ impl PromptInstruction {
) -> Result<(), ApplicationError> {
// reset by creating a new conversation
// TODO: clone previous conversation settings
let model = ModelSpec::new_with_validation("foo-provider::bar-model")?;
let current_conversation_id =
db_conn.new_conversation(
"New Conversation",
None,
None,
None,
model,
ModelServerName("ollama".to_string()),
)?;
self.cache.set_conversation_id(current_conversation_id);
if let Some(ref model) = &self.model {
let current_conversation_id =
db_conn.new_conversation(
"New Conversation",
None,
None,
None,
model,
ModelServerName("ollama".to_string()),
)?;
self.cache.set_conversation_id(current_conversation_id);
};
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub mod conversation;
mod session;

pub use db::{ConversationDatabaseStore, ConversationReader};
pub use conversation::{ConversationId, ModelSpec};
pub use conversation::{ConversationId, ModelIdentifier, ModelSpec};
pub use instruction::PromptInstruction;
pub use options::ChatCompletionOptions;
use prompt::Prompt;
Expand Down
30 changes: 20 additions & 10 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,36 @@ impl ChatSession {
self.server.server_name()
}

pub fn get_conversation_id(&self) -> ConversationId {
/// Conversation id of the current prompt instruction, if one exists.
pub fn get_conversation_id(&self) -> Option<ConversationId> {
self.prompt_instruction.get_conversation_id()
}

pub async fn change_server(
&mut self,
mut server: Box<dyn ServerManager>,
reader: &ConversationReader<'_>,
reader: Option<&ConversationReader<'_>>,
) -> Result<Option<WindowEvent>, ApplicationError> {
log::debug!("switching server: {}", server.server_name());
self.stop();

// TODO: update prompt instruction with new server / conversation
let model = server.get_default_model().await;
if let Some(model) = model {
server.setup_and_initialize(model, &reader).await?;
//let model = server.get_default_model().await;
if let Some(reader) = reader {
//self.prompt_instruction.set_conversation_id(reader.get_conversation_id());
server.setup_and_initialize(reader).await?;
}
self.server = server;
// TODO:
// add new events to handle server / conversation change
// if conversation_id changes, return new conversation_id as
// well to create a new ConversationReader
let new_conversation_id = self.get_conversation_id();
Ok(Some(WindowEvent::PromptWindow(Some(ConversationEvent::New(
new_conversation_id,
)))))
if let Some(new_conversation_id) = self.get_conversation_id() {
return Ok(Some(WindowEvent::PromptWindow(Some(ConversationEvent::New(
new_conversation_id,
)))));
} else {
return Ok(Some(WindowEvent::PromptWindow(None)));
}
}

pub fn stop(&mut self) {
Expand Down Expand Up @@ -99,6 +103,12 @@ impl ChatSession {
tx: mpsc::Sender<Bytes>,
question: &str,
) -> Result<(), ApplicationError> {
let model = if let Some(model) = self.prompt_instruction.get_model().cloned() {
model
} else {
return Err(ApplicationError::NotReady("Model not available".to_string()));
};

let max_token_length = self.server.get_max_context_size().await?;
let user_question = self.initiate_new_exchange(question).await?;
let messages = self
Expand All @@ -109,7 +119,7 @@ impl ChatSession {
self.cancel_tx = Some(cancel_tx); // channel to cancel

self.server
.completion(&messages, Some(tx), Some(cancel_rx))
.completion(&messages, &model, Some(tx), Some(cancel_rx))
.await?;
Ok(())
}
Expand Down
11 changes: 1 addition & 10 deletions lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub struct Bedrock {
spec: BedrockSpec,
http_client: HttpClient,
endpoints: Endpoints,
model: Option<ModelSpec>,
}

impl Bedrock {
Expand All @@ -49,7 +48,6 @@ impl Bedrock {
http_client: HttpClient::new()
.with_error_handler(Arc::new(AWSErrorHandler)),
endpoints,
model: None,
})
}

Expand Down Expand Up @@ -125,17 +123,11 @@ impl ServerTrait for Bedrock {

/// No-op initialization for Bedrock.
///
/// The commit this diff belongs to removed the cached `model` field from
/// the server struct; the model is now passed per request (see
/// `completion`, which takes `model: &ModelSpec`), so there is nothing
/// to set up here.
///
/// NOTE(review): the scraped diff had fused the removed `model: ModelSpec`
/// parameter and `self.model = Some(model)` line into this hunk; this is
/// the reconstructed post-commit version — confirm against commit 8a2c000.
async fn initialize_with_model(
    &mut self,
    _reader: &ConversationReader,
) -> Result<(), ApplicationError> {
    Ok(())
}

fn get_model(&self) -> Option<&ModelSpec> {
self.model.as_ref()
}

fn process_response(
&mut self,
mut response_bytes: Bytes,
Expand Down Expand Up @@ -205,11 +197,10 @@ impl ServerTrait for Bedrock {
async fn completion(
&self,
messages: &Vec<ChatMessage>,
model: &ModelSpec,
tx: Option<mpsc::Sender<Bytes>>,
cancel_rx: Option<oneshot::Receiver<()>>,
) -> Result<(), ApplicationError> {
let model = self.get_selected_model()?;

let resource = HttpClient::percent_encode_with_exclusion(
&format!("/model/{}.{}/converse-stream", model.get_model_provider(), model.get_model_name()),
Some(&[b'/', b'.', b'-']),
Expand Down
Loading

0 comments on commit 8a2c000

Please sign in to comment.