remove legacy prompt_options
aprxi committed Jul 17, 2024
1 parent d1c03c0 commit 8f35b84
Showing 8 changed files with 49 additions and 135 deletions.
1 change: 0 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs
@@ -35,7 +35,6 @@ pub struct Conversation {
pub parent_conversation_id: Option<ConversationId>,
pub fork_exchange_id: Option<ExchangeId>,
pub completion_options: Option<serde_json::Value>,
pub prompt_options: Option<serde_json::Value>,
pub schema_version: i64,
pub created_at: i64,
pub updated_at: i64,
1 change: 0 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.sql
@@ -20,7 +20,6 @@ CREATE TABLE conversations (
parent_conversation_id INTEGER,
fork_exchange_id INTEGER,
completion_options TEXT, -- JSON string
prompt_options TEXT, -- JSON string
schema_version INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
53 changes: 19 additions & 34 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
@@ -25,7 +25,6 @@ impl ConversationDatabaseStore {
name: &str,
parent_id: Option<ConversationId>,
completion_options: Option<serde_json::Value>,
prompt_options: Option<serde_json::Value>,
) -> Result<ConversationId, SqliteError> {
let conversation = Conversation {
id: ConversationId(-1), // Temporary ID
@@ -35,7 +34,6 @@
parent_conversation_id: parent_id,
fork_exchange_id: None,
completion_options,
prompt_options,
schema_version: 1,
created_at: 0,
updated_at: 0,
@@ -74,10 +72,10 @@
"INSERT INTO conversations (
name, metadata, model_id, parent_conversation_id, \
fork_exchange_id,
completion_options, prompt_options, schema_version,
completion_options, schema_version,
created_at, updated_at, is_deleted
)
VALUES ('{}', {}, {}, {}, {}, {}, {}, {}, {}, {}, {});",
VALUES ('{}', {}, {}, {}, {}, {}, {}, {}, {}, {});",
conversation.name.replace("'", "''"),
serde_json::to_string(&conversation.metadata)
.map(|s| format!("'{}'", s.replace("'", "''")))
@@ -96,13 +94,6 @@
serde_json::to_string(v).unwrap().replace("'", "''")
)
),
conversation.prompt_options.as_ref().map_or(
"NULL".to_string(),
|v| format!(
"'{}'",
serde_json::to_string(v).unwrap().replace("'", "''")
)
),
conversation.schema_version,
conversation.created_at,
conversation.updated_at,
@@ -282,26 +273,23 @@ impl ConversationDatabaseStore {
completion_options: row
.get::<_, Option<String>>(6)?
.map(|s| serde_json::from_str(&s).unwrap_or_default()),
prompt_options: row
.get::<_, Option<String>>(7)?
.map(|s| serde_json::from_str(&s).unwrap_or_default()),
schema_version: row.get(8)?,
created_at: row.get(9)?,
updated_at: row.get(10)?,
is_deleted: row.get(11)?,
schema_version: row.get(7)?,
created_at: row.get(8)?,
updated_at: row.get(9)?,
is_deleted: row.get(10)?,
};

let message = if !row.get::<_, Option<i64>>(14)?.is_none() {
let message = if !row.get::<_, Option<i64>>(13)?.is_none() {
Some(Message {
id: MessageId(row.get(14)?),
id: MessageId(row.get(13)?),
conversation_id: conversation.id,
exchange_id: ExchangeId(row.get(0)?),
role: row.get(15)?,
message_type: row.get(16)?,
content: row.get(17)?,
has_attachments: row.get(18)?,
token_length: row.get(19)?,
created_at: row.get(20)?,
role: row.get(14)?,
message_type: row.get(15)?,
content: row.get(16)?,
has_attachments: row.get(17)?,
token_length: row.get(18)?,
created_at: row.get(19)?,
is_deleted: false,
})
} else {
@@ -335,7 +323,7 @@ impl ConversationDatabaseStore {
let query = format!(
"SELECT id, name, metadata, model_id, parent_conversation_id, \
fork_exchange_id,
completion_options, prompt_options, schema_version,
completion_options, schema_version,
created_at, updated_at, is_deleted
FROM conversations
WHERE is_deleted = FALSE
@@ -363,13 +351,10 @@
completion_options: row
.get::<_, Option<String>>(6)?
.map(|s| serde_json::from_str(&s).unwrap_or_default()),
prompt_options: row
.get::<_, Option<String>>(7)?
.map(|s| serde_json::from_str(&s).unwrap_or_default()),
schema_version: row.get(8)?,
created_at: row.get(9)?,
updated_at: row.get(10)?,
is_deleted: row.get(11)?,
schema_version: row.get(7)?,
created_at: row.get(8)?,
updated_at: row.get(9)?,
is_deleted: row.get(10)?,
})
})?;

47 changes: 13 additions & 34 deletions lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
@@ -3,20 +3,19 @@ use std::collections::HashMap;
use lumni::api::error::ApplicationError;

use super::db::{
self, ConversationCache, ConversationDatabaseStore, ConversationId,
Exchange, ExchangeId, Message,
ConversationCache, ConversationDatabaseStore, ConversationId, Exchange,
ExchangeId, Message,
};
use super::prompt::Prompt;
use super::{
ChatCompletionOptions, ChatMessage, LLMDefinition, PromptOptions,
PromptRole, DEFAULT_N_PREDICT, DEFAULT_TEMPERATURE, PERSONAS,
ChatCompletionOptions, ChatMessage, PromptRole,
DEFAULT_N_PREDICT, DEFAULT_TEMPERATURE, PERSONAS,
};
pub use crate::external as lumni;

pub struct PromptInstruction {
cache: ConversationCache,
completion_options: ChatCompletionOptions,
prompt_options: PromptOptions,
prompt_template: Option<String>,
}

@@ -31,7 +30,6 @@ impl Default for PromptInstruction {
PromptInstruction {
cache: ConversationCache::new(),
completion_options,
prompt_options: PromptOptions::default(),
prompt_template: None,
}
}
@@ -45,8 +43,9 @@ impl PromptInstruction {
db_conn: &ConversationDatabaseStore,
) -> Result<Self, ApplicationError> {
let mut prompt_instruction = PromptInstruction::default();
//let mut prompt_options = PromptOptions::default();
if let Some(json_str) = options {
prompt_instruction.prompt_options.update_from_json(json_str);
//prompt_options.update_from_json(json_str);
prompt_instruction
.completion_options
.update_from_json(json_str);
@@ -68,7 +67,7 @@ impl PromptInstruction {
None,
serde_json::to_value(&prompt_instruction.completion_options)
.ok(),
serde_json::to_value(&prompt_instruction.prompt_options).ok(),
//serde_json::to_value(&prompt_options).ok(),
)?
};
prompt_instruction
@@ -152,7 +151,7 @@ impl PromptInstruction {
// reset by creating a new conversation
// TODO: clone previous conversation settings
let current_conversation_id =
db_conn.new_conversation("New Conversation", None, None, None)?;
db_conn.new_conversation("New Conversation", None, None)?;
self.cache.set_conversation_id(current_conversation_id);
Ok(())
}
@@ -231,9 +230,6 @@ impl PromptInstruction {
question: &str,
max_token_length: usize,
) -> Vec<ChatMessage> {
// token budget for the system prompt
let system_prompt_token_length = self.get_n_keep().unwrap_or(0);

// add the partial exchange (question) to the conversation
let exchange = self.subsequent_exchange();

@@ -257,7 +253,7 @@ impl PromptInstruction {

// Collect messages while respecting token limits
let mut messages: Vec<ChatMessage> = Vec::new();
let mut total_tokens = system_prompt_token_length;
let mut total_tokens = 0;

let mut system_message: Option<ChatMessage> = None;

@@ -273,11 +269,14 @@ impl PromptInstruction {
msg.token_length.map(|len| len as usize).unwrap_or(0);

if msg.role == PromptRole::System {
// store system_prompt for later insertion at the beginning
system_message = Some(ChatMessage {
role: msg.role,
content: msg.content.clone(),
});
continue; // system prompt is included separately
// system prompt is always included
total_tokens += msg_token_length;
continue;
}
if total_tokens + msg_token_length <= max_token_length {
total_tokens += msg_token_length;
@@ -309,26 +308,6 @@ impl PromptInstruction {
&self.completion_options
}

pub fn set_model(&mut self, model: &LLMDefinition) {
self.completion_options.update_from_model(model);
}

pub fn get_role_prefix(&self, role: PromptRole) -> &str {
self.prompt_options.get_role_prefix(role)
}

pub fn get_context_size(&self) -> Option<usize> {
self.prompt_options.get_context_size()
}

pub fn set_context_size(&mut self, context_size: usize) {
self.prompt_options.set_context_size(context_size);
}

pub fn get_n_keep(&self) -> Option<usize> {
self.completion_options.get_n_keep()
}

pub fn get_prompt_template(&self) -> Option<&str> {
self.prompt_template.as_deref()
}
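For context, the message-selection logic touched by the hunks above now counts the system prompt against the same token budget as every other message instead of reserving a separate n_keep allowance. A minimal, self-contained sketch of that selection step follows; the types and the synchronous free function are simplified stand-ins for the ones in instruction.rs, not the project's actual code.

```rust
// Simplified stand-ins for the types used in instruction.rs.
#[derive(Clone, Copy, PartialEq, Debug)]
enum PromptRole {
    System,
    User,
    Assistant,
}

struct Msg {
    role: PromptRole,
    content: String,
    token_length: Option<usize>,
}

struct ChatMessage {
    role: PromptRole,
    content: String,
}

/// Keep the newest messages that fit within `max_token_length`.
/// The system prompt is always included and its tokens count against the budget.
fn collect_messages(history: &[Msg], max_token_length: usize) -> Vec<ChatMessage> {
    let system_message = history
        .iter()
        .find(|m| m.role == PromptRole::System)
        .map(|m| ChatMessage { role: m.role, content: m.content.clone() });
    let mut total_tokens: usize = history
        .iter()
        .filter(|m| m.role == PromptRole::System)
        .filter_map(|m| m.token_length)
        .sum();

    // Walk the non-system messages newest-to-oldest, keeping as many as fit.
    let mut messages: Vec<ChatMessage> = Vec::new();
    for msg in history.iter().rev().filter(|m| m.role != PromptRole::System) {
        let msg_token_length = msg.token_length.unwrap_or(0);
        if total_tokens + msg_token_length > max_token_length {
            break; // budget exhausted
        }
        total_tokens += msg_token_length;
        messages.push(ChatMessage { role: msg.role, content: msg.content.clone() });
    }

    messages.reverse(); // restore chronological order
    if let Some(system) = system_message {
        messages.insert(0, system); // system prompt goes first
    }
    messages
}

fn main() {
    let history = vec![
        Msg { role: PromptRole::System, content: "You are helpful.".into(), token_length: Some(4) },
        Msg { role: PromptRole::User, content: "First question".into(), token_length: Some(10) },
        Msg { role: PromptRole::Assistant, content: "First answer".into(), token_length: Some(12) },
        Msg { role: PromptRole::User, content: "Second question".into(), token_length: Some(10) },
    ];
    let selected = collect_messages(&history, 30);
    if let Some(first) = selected.first() {
        println!("first message ({:?}): {}", first.role, first.content);
    }
    println!("kept {} of {} messages", selected.len(), history.len());
}
```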
19 changes: 0 additions & 19 deletions lumni/src/apps/builtin/llm/prompt/src/chat/options.rs
@@ -111,16 +111,6 @@ impl Default for RolePrefix {
}
}

impl RolePrefix {
fn get_role_prefix(&self, prompt_role: PromptRole) -> &str {
match prompt_role {
PromptRole::User => &self.user,
PromptRole::Assistant => &self.assistant,
PromptRole::System => &self.system,
}
}
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptOptions {
n_ctx: Option<usize>,
Expand Down Expand Up @@ -154,13 +144,4 @@ impl PromptOptions {
pub fn get_context_size(&self) -> Option<usize> {
self.n_ctx
}

pub fn set_context_size(&mut self, context_size: usize) -> &mut Self {
self.n_ctx = Some(context_size);
self
}

pub fn get_role_prefix(&self, prompt_role: PromptRole) -> &str {
self.role_prefix.get_role_prefix(prompt_role)
}
}
5 changes: 1 addition & 4 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
@@ -96,10 +96,7 @@ impl ChatSession {
tx: mpsc::Sender<Bytes>,
question: &str,
) -> Result<(), ApplicationError> {
let max_token_length = self
.server
.get_context_size(&mut self.prompt_instruction)
.await?;
let max_token_length = self.server.get_max_context_size().await?;
let user_question = self.initiate_new_exchange(question).await?;
let messages = self
.prompt_instruction
31 changes: 8 additions & 23 deletions lumni/src/apps/builtin/llm/prompt/src/server/llama/mod.rs
@@ -54,12 +54,8 @@ impl Llama {

let system_prompt = LlamaServerSystemPrompt::new(
instruction.to_string(),
prompt_instruction
.get_role_prefix(PromptRole::User)
.to_string(),
prompt_instruction
.get_role_prefix(PromptRole::Assistant)
.to_string(),
format!("### {}", PromptRole::User.to_string()),
format!("### {}", PromptRole::Assistant.to_string()),
);
let payload = LlamaServerPayload {
prompt: "",
@@ -206,23 +202,12 @@ impl ServerTrait for Llama {
self.model.as_ref()
}

async fn get_context_size(
&self,
prompt_instruction: &mut PromptInstruction,
) -> Result<usize, ApplicationError> {
let context_size = prompt_instruction.get_context_size();
match context_size {
Some(size) => Ok(size), // Return the context size if it's already set
None => {
// fetch the context size, and store it in the prompt options
let context_size = match self.get_props().await {
Ok(props) => props.get_n_ctx(),
Err(_) => DEFAULT_CONTEXT_SIZE,
};
prompt_instruction.set_context_size(context_size);
Ok(context_size)
}
}
async fn get_max_context_size(&self) -> Result<usize, ApplicationError> {
let context_size = match self.get_props().await {
Ok(props) => props.get_n_ctx(),
Err(_) => DEFAULT_CONTEXT_SIZE,
};
Ok(context_size)
}
}

27 changes: 8 additions & 19 deletions lumni/src/apps/builtin/llm/prompt/src/server/mod.rs
@@ -132,23 +132,14 @@ impl ServerTrait for ModelServer {
}
}

async fn get_context_size(
&self,
prompt_instruction: &mut PromptInstruction,
) -> Result<usize, ApplicationError> {
async fn get_max_context_size(&self) -> Result<usize, ApplicationError> {
match self {
ModelServer::Llama(llama) => {
llama.get_context_size(prompt_instruction).await
}
ModelServer::Ollama(ollama) => {
ollama.get_context_size(prompt_instruction).await
}
ModelServer::Llama(llama) => llama.get_max_context_size().await,
ModelServer::Ollama(ollama) => ollama.get_max_context_size().await,
ModelServer::Bedrock(bedrock) => {
bedrock.get_context_size(prompt_instruction).await
}
ModelServer::OpenAI(openai) => {
openai.get_context_size(prompt_instruction).await
bedrock.get_max_context_size().await
}
ModelServer::OpenAI(openai) => openai.get_max_context_size().await,
}
}

@@ -260,10 +251,7 @@ pub trait ServerTrait: Send + Sync {
start_of_stream: bool,
) -> Option<CompletionResponse>;

async fn get_context_size(
&self,
_prompt_instruction: &mut PromptInstruction,
) -> Result<usize, ApplicationError> {
async fn get_max_context_size(&self) -> Result<usize, ApplicationError> {
Ok(DEFAULT_CONTEXT_SIZE)
}

@@ -285,7 +273,8 @@ pub trait ServerManager: ServerTrait {
) -> Result<(), ApplicationError> {
log::debug!("Initializing server with model: {:?}", model);
// update completion options from the model, i.e. stop tokens
prompt_instruction.set_model(&model);
// TODO: prompt_instruction should be re-initialized with the model
//prompt_instruction.set_model(&model);
self.initialize_with_model(model, prompt_instruction).await
}


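For reference, a rough sketch of the server API shape after this change: the context size is now a property of the server alone, with a default fallback, rather than something read from and written back into PromptInstruction. The real trait method is async and returns an ApplicationError; the synchronous trait, the LlamaLike stub, and the DEFAULT_CONTEXT_SIZE value below are illustrative assumptions, not the project's actual code.

```rust
// Placeholder constant standing in for the crate's DEFAULT_CONTEXT_SIZE.
const DEFAULT_CONTEXT_SIZE: usize = 512;

trait ServerTrait {
    // Servers that cannot report their context window fall back to a default,
    // mirroring the default method on the real (async) trait in server/mod.rs.
    fn get_max_context_size(&self) -> Result<usize, String> {
        Ok(DEFAULT_CONTEXT_SIZE)
    }
}

// Hypothetical backend that, like the llama.cpp server, can report its own n_ctx.
struct LlamaLike {
    n_ctx: Option<usize>, // stands in for the value fetched from the server's props
}

impl ServerTrait for LlamaLike {
    fn get_max_context_size(&self) -> Result<usize, String> {
        Ok(self.n_ctx.unwrap_or(DEFAULT_CONTEXT_SIZE))
    }
}

fn main() -> Result<(), String> {
    let server = LlamaLike { n_ctx: Some(4096) };
    // The caller no longer threads `&mut PromptInstruction` through; it just asks the server.
    let max_token_length = server.get_max_context_size()?;
    println!("max context size: {max_token_length}");
    Ok(())
}
```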