replace legacy LLMDefinition with ModelSpec, move formatters to llama-specific as it's the only endpoint that only accepts raw inputs
aprxi committed Jul 20, 2024
1 parent bf4ae7c commit e27da4c
Showing 20 changed files with 195 additions and 186 deletions.
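
At its core, the change replaces the two-argument LLMModel constructor with a single validated identifier string. A minimal before/after sketch of the call sites, using the placeholder identifiers that appear in the diff below:

    // Before: provider and model name passed separately, no validation
    let model = LLMModel::new(ModelIdentifier::new("foo-provider", "bar-model"));

    // After: one "provider::model_name" string, validated against a regex,
    // so construction is fallible and returns a Result
    let model = ModelSpec::new_with_validation("foo-provider::bar-model")?;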
1 change: 1 addition & 0 deletions lumni/Cargo.toml
@@ -48,6 +48,7 @@ libc = "0.2"
 syntect = { version = "5.2.0", default-features = false, features = ["parsing", "default-fancy"] }
 crc32fast = { version = "1.4" }
 rusqlite = { version = "0.31" }
+lazy_static = { version = "1.5" }
 
 # CLI
 env_logger = { version = "0.9", optional = true }
1 change: 0 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/mod.rs
@@ -4,7 +4,6 @@ pub mod src {
 mod chat;
 mod defaults;
 mod handler;
-mod model;
 mod server;
 mod session;
 mod tui;
@@ -1,14 +1,14 @@
 use std::collections::HashMap;
 
-use super::{ModelIdentifier, LLMModel, PromptRole,
+use super::{ModelIdentifier, ModelSpec, PromptRole,
     ConversationId, Message,
     MessageId, AttachmentId, Attachment
 };
 
 #[derive(Debug)]
 pub struct ConversationCache {
     conversation_id: ConversationId,
-    models: HashMap<ModelIdentifier, LLMModel>,
+    models: HashMap<ModelIdentifier, ModelSpec>,
     messages: Vec<Message>, // messages have to be ordered
     attachments: HashMap<AttachmentId, Attachment>,
     message_attachments: HashMap<MessageId, Vec<AttachmentId>>,
@@ -41,7 +41,7 @@ impl ConversationCache {
         AttachmentId(self.attachments.len() as i64)
     }
 
-    pub fn add_model(&mut self, model: LLMModel) {
+    pub fn add_model(&mut self, model: ModelSpec) {
         self.models.insert(model.identifier.clone(), model);
     }
10 changes: 1 addition & 9 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs
@@ -3,19 +3,11 @@ use serde::{Deserialize, Serialize};
 mod model;
 mod cache;
 
-pub use model::LLMModel;
+pub use model::{ModelIdentifier, ModelSpec};
 pub use cache::ConversationCache;
 
 use super::PromptRole;
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
-pub struct ModelIdentifier(pub String);
-
-impl ModelIdentifier {
-    pub fn new(provider: &str, name: &str) -> Self {
-        ModelIdentifier(format!("{}::{}", provider, name))
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
 pub struct ModelServerName(pub String);
58 changes: 52 additions & 6 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/model.rs
@@ -1,24 +1,70 @@
+use lumni::api::error::ApplicationError;
 use serde::{Deserialize, Serialize};
-use super::ModelIdentifier;
 
+pub use crate::external as lumni;
+
+use lazy_static::lazy_static;
+use regex::Regex;
+
+lazy_static! {
+    static ref IDENTIFIER_REGEX: Regex = Regex::new(
+        r"^[-a-z0-9_]+::[-a-z0-9_][-a-z0-9_:.]*[-a-z0-9_]+$"
+    ).unwrap();
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct ModelIdentifier(pub String);
+
+impl ModelIdentifier {
+    pub fn new(identifier_str: &str) -> Result<Self, ApplicationError> {
+        if IDENTIFIER_REGEX.is_match(identifier_str) {
+            Ok(ModelIdentifier(identifier_str.to_string()))
+        } else {
+            Err(ApplicationError::InvalidInput(format!(
+                "Identifier must be in the format 'provider::model_name', where the provider contains only lowercase letters, numbers, hyphens, underscores, and the model name can include internal colons but not start or end with them. Got: '{}'",
+                identifier_str
+            )))
+        }
+    }
+
+    pub fn get_model_provider(&self) -> &str {
+        // model provider is the first part of the identifier
+        self.0.split("::").next().unwrap()
+    }
+
+    pub fn get_model_name(&self) -> &str {
+        // model name is the second part of the identifier
+        self.0.split("::").nth(1).unwrap()
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct LLMModel {
+pub struct ModelSpec {
     pub identifier: ModelIdentifier,
     pub info: Option<serde_json::Value>,
     pub config: Option<serde_json::Value>,
     pub context_window_size: Option<i64>,
     pub input_token_limit: Option<i64>,
 }
 
-impl LLMModel {
-    pub fn new(identifier: ModelIdentifier) -> Self {
-        LLMModel {
+impl ModelSpec {
+    pub fn new_with_validation(identifier_str: &str) -> Result<Self, ApplicationError> {
+        let identifier = ModelIdentifier::new(identifier_str)?;
+        Ok(ModelSpec {
             identifier,
             info: None,
             config: None,
             context_window_size: None,
             input_token_limit: None,
-        }
+        })
     }
 
+    pub fn get_model_provider(&self) -> &str {
+        self.identifier.get_model_provider()
+    }
+
+    pub fn get_model_name(&self) -> &str {
+        self.identifier.get_model_name()
+    }
+
     pub fn identifier(&self) -> &ModelIdentifier {
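
A short sketch of what the new validation accepts and rejects, based on IDENTIFIER_REGEX above (identifiers chosen for illustration):

    // Accepted: lowercase provider, model name with internal ':' and '.'
    let spec = ModelSpec::new_with_validation(
        "anthropic::claude-3-5-sonnet-20240620-v1:0",
    )?;
    assert_eq!(spec.get_model_provider(), "anthropic");
    assert_eq!(spec.get_model_name(), "claude-3-5-sonnet-20240620-v1:0");

    // Rejected: uppercase characters, or no '::' separator at all
    assert!(ModelSpec::new_with_validation("Anthropic::claude-3").is_err());
    assert!(ModelSpec::new_with_validation("claude-3").is_err());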
4 changes: 2 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
@@ -7,7 +7,7 @@ use super::connector::DatabaseConnector;
 use super::reader::ConversationReader;
 use super::conversation::{
     Attachment, AttachmentData, AttachmentId, Conversation, ConversationId,
-    Message, MessageId, LLMModel, ModelIdentifier, ModelServerName,
+    Message, MessageId, ModelSpec, ModelIdentifier, ModelServerName,
 };
 
 pub struct ConversationDatabaseStore {
@@ -34,7 +34,7 @@ impl ConversationDatabaseStore {
         parent_id: Option<ConversationId>,
         fork_message_id: Option<MessageId>,
         completion_options: Option<serde_json::Value>,
-        model: LLMModel,
+        model: ModelSpec,
         model_server: ModelServerName,
     ) -> Result<ConversationId, SqliteError> {
         let mut db = self.db.lock().unwrap();
13 changes: 4 additions & 9 deletions lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
@@ -3,7 +3,7 @@ use lumni::api::error::ApplicationError;
 use super::db::ConversationDatabaseStore;
 use super::conversation::{
     ConversationCache, ConversationId,
-    Message, MessageId, LLMModel, ModelIdentifier, ModelServerName,
+    Message, MessageId, ModelSpec, ModelIdentifier, ModelServerName,
 };
 use super::prompt::Prompt;
 use super::{ChatCompletionOptions, ChatMessage, PromptRole, PERSONAS};
@@ -35,11 +35,9 @@ impl PromptInstruction {
             }
             None => serde_json::to_value(ChatCompletionOptions::default())?,
         };
-        // Create a new Conversation in the database
-        let model = LLMModel::new(
-            ModelIdentifier::new("foo-provider", "bar-model"),
-        );
 
+        // Create a new Conversation in the database
+        let model = ModelSpec::new_with_validation("foo-provider::bar-model")?;
         let conversation_id = {
             db_conn.new_conversation(
                 "New Conversation",
@@ -142,10 +140,7 @@
     ) -> Result<(), ApplicationError> {
         // reset by creating a new conversation
         // TODO: clone previous conversation settings
-        let model = LLMModel::new(
-            ModelIdentifier::new("foo-provider", "bar-model"),
-        );
-
+        let model = ModelSpec::new_with_validation("foo-provider::bar-model")?;
         let current_conversation_id =
             db_conn.new_conversation(
                 "New Conversation",
7 changes: 4 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/mod.rs
@@ -4,20 +4,21 @@ mod instruction;
 mod options;
 mod prompt;
 mod send;
+mod prompt_role;
 pub mod conversation;
 mod session;
 
 pub use db::{ConversationDatabaseStore, ConversationReader};
-pub use conversation::{ConversationId, LLMModel};
+pub use conversation::{ConversationId, ModelSpec};
 pub use instruction::PromptInstruction;
 pub use options::ChatCompletionOptions;
 use prompt::Prompt;
 pub use send::{http_get_with_response, http_post, http_post_with_response};
 pub use session::ChatSession;
+pub use prompt_role::PromptRole;
 
 pub use super::defaults::*;
-pub use super::model::PromptRole;
-pub use super::server::{CompletionResponse, LLMDefinition, ServerManager};
+pub use super::server::{CompletionResponse, ServerManager};
 pub use super::tui::{WindowEvent, ConversationEvent};
 
 // gets PERSONAS from the generated code
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/options.rs
@@ -1,6 +1,6 @@
 use serde::{Deserialize, Serialize};
 
-use super::{LLMDefinition, DEFAULT_N_PREDICT, DEFAULT_TEMPERATURE};
+use super::{DEFAULT_N_PREDICT, DEFAULT_TEMPERATURE};
 
 #[derive(Debug, Deserialize, Serialize)]
 pub struct ChatCompletionOptions {
40 changes: 40 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/prompt_role.rs
@@ -0,0 +1,40 @@
+use std::fmt::Display;
+
+use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef};
+use serde::{Deserialize, Serialize};
+
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
+pub enum PromptRole {
+    User,
+    Assistant,
+    System,
+}
+
+impl PromptRole {
+    pub fn to_string(&self) -> String {
+        match self {
+            PromptRole::User => "user",
+            PromptRole::Assistant => "assistant",
+            PromptRole::System => "system",
+        }
+        .to_string()
+    }
+}
+
+impl Display for PromptRole {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.to_string())
+    }
+}
+
+impl FromSql for PromptRole {
+    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
+        match value.as_str()? {
+            "user" => Ok(PromptRole::User),
+            "assistant" => Ok(PromptRole::Assistant),
+            "system" => Ok(PromptRole::System),
+            _ => Err(FromSqlError::InvalidType.into()),
+        }
+    }
+}
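
Both Display and the inherent to_string yield the lowercase form that FromSql expects back, so the role round-trips through a TEXT column. A usage sketch (the rusqlite row access is illustrative):

    assert_eq!(PromptRole::Assistant.to_string(), "assistant");
    assert_eq!(format!("{}", PromptRole::System), "system");

    // With FromSql in place, rusqlite can decode the column directly:
    // let role: PromptRole = row.get("role")?;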
5 changes: 0 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/model/mod.rs

This file was deleted.

21 changes: 10 additions & 11 deletions lumni/src/apps/builtin/llm/prompt/src/server/bedrock/mod.rs
@@ -19,7 +19,7 @@ use url::Url;
 
 use super::{
     http_post, ChatMessage, CompletionResponse, CompletionStats,
-    ConversationReader, Endpoints, LLMDefinition,
+    ConversationReader, Endpoints, ModelSpec,
     PromptRole, ServerSpecTrait, ServerTrait,
 };
 pub use crate::external as lumni;
@@ -30,7 +30,7 @@ pub struct Bedrock {
     spec: BedrockSpec,
     http_client: HttpClient,
     endpoints: Endpoints,
-    model: Option<LLMDefinition>,
+    model: Option<ModelSpec>,
 }
 
 impl Bedrock {
@@ -55,7 +55,7 @@ impl Bedrock {
 
     fn completion_api_payload(
         &self,
-        _model: &LLMDefinition,
+        _model: &ModelSpec,
         chat_messages: &Vec<ChatMessage>,
     ) -> Result<String, serde_json::Error> {
         // Check if the first message is a system prompt
@@ -125,14 +125,14 @@ impl ServerTrait for Bedrock {
 
     async fn initialize_with_model(
         &mut self,
-        model: LLMDefinition,
+        model: ModelSpec,
         _reader: &ConversationReader,
     ) -> Result<(), ApplicationError> {
         self.model = Some(model);
         Ok(())
     }
 
-    fn get_model(&self) -> Option<&LLMDefinition> {
+    fn get_model(&self) -> Option<&ModelSpec> {
         self.model.as_ref()
     }
 
@@ -211,7 +211,7 @@
         let model = self.get_selected_model()?;
 
         let resource = HttpClient::percent_encode_with_exclusion(
-            &format!("/model/{}/converse-stream", model.get_name()),
+            &format!("/model/{}.{}/converse-stream", model.get_model_provider(), model.get_model_name()),
            Some(&[b'/', b'.', b'-']),
         );
         let completion_endpoint = self.endpoints.get_completion_endpoint()?;
@@ -259,11 +259,10 @@
 
     async fn list_models(
         &self,
-    ) -> Result<Vec<LLMDefinition>, ApplicationError> {
-        let model = LLMDefinition::new(
-            "anthropic.claude-3-5-sonnet-20240620-v1:0".to_string(),
-        );
-        Ok(vec![model])
+    ) -> Result<Vec<ModelSpec>, ApplicationError> {
+        Ok(vec![
+            ModelSpec::new_with_validation("anthropic::claude-3-5-sonnet-20240620-v1:0")?,
+        ])
     }
 }
 
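Since ModelSpec stores the provider and model name as one "provider::name" identifier, the Bedrock resource path now rejoins the two parts with a dot to rebuild Bedrock's native model ID. Illustratively:

    let spec = ModelSpec::new_with_validation(
        "anthropic::claude-3-5-sonnet-20240620-v1:0",
    )?;
    let resource = format!(
        "/model/{}.{}/converse-stream",
        spec.get_model_provider(),
        spec.get_model_name(),
    );
    assert_eq!(
        resource,
        "/model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse-stream"
    );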
@@ -52,9 +52,8 @@ impl ModelFormatterTrait for Llama3 {
         };
         let mut prompt_message = String::new();
         prompt_message.push_str(&format!(
-            "<|start_header_id|>{}<|end_header_id|>\n{}{}",
+            "<|start_header_id|>{}<|end_header_id|>\n{}",
             role_handle,
-            self.get_role_prefix(prompt_role),
             message
         ));
         if !message.is_empty() {
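
The removed {} placeholder and the matching get_role_prefix(prompt_role) argument mean a turn is now rendered as header plus raw message only. Assuming role_handle resolves to "user" and the message is "Hello", the formatter would now emit:

    <|start_header_id|>user<|end_header_id|>
    Hello

whereas previously a role prefix was interpolated between the header and the message.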