Commit
db store fixes, prev exchange id, token count
aprxi committed Jul 16, 2024
1 parent d58b5ed commit e9c332e
Showing 5 changed files with 93 additions and 55 deletions.
2 changes: 0 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.rs
@@ -47,8 +47,6 @@ pub struct Exchange {
pub system_prompt: Option<String>,
pub completion_options: Option<serde_json::Value>,
pub prompt_options: Option<serde_json::Value>,
pub completion_tokens: Option<i64>,
pub prompt_tokens: Option<i64>,
pub created_at: i64,
pub previous_exchange_id: Option<ExchangeId>,
pub is_deleted: bool,
7 changes: 5 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.sql
@@ -21,6 +21,8 @@ CREATE TABLE conversations (
schema_version INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
exchange_count INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
is_deleted BOOLEAN DEFAULT FALSE,
FOREIGN KEY (parent_conversation_id) REFERENCES conversations(id),
FOREIGN KEY (fork_exchange_id) REFERENCES exchanges(id)
Expand All @@ -33,11 +35,10 @@ CREATE TABLE exchanges (
system_prompt TEXT,
completion_options TEXT, -- JSON string
prompt_options TEXT, -- JSON string
completion_tokens INTEGER,
prompt_tokens INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
previous_exchange_id INTEGER,
is_deleted BOOLEAN DEFAULT FALSE,
is_latest BOOLEAN DEFAULT TRUE,
FOREIGN KEY (conversation_id) REFERENCES conversations(id),
FOREIGN KEY (model_id) REFERENCES models(model_id),
FOREIGN KEY (previous_exchange_id) REFERENCES exchanges(id)
@@ -79,6 +80,8 @@ CREATE INDEX idx_model_service ON models(model_service);
CREATE INDEX idx_conversation_id ON exchanges(conversation_id);
CREATE INDEX idx_exchange_id ON messages(exchange_id);
CREATE INDEX idx_parent_conversation ON conversations(parent_conversation_id);
CREATE INDEX idx_exchange_conversation_latest ON exchanges(conversation_id, is_latest);
CREATE INDEX idx_exchange_created_at ON exchanges(created_at);
CREATE INDEX idx_fork_exchange ON conversations(fork_exchange_id);
CREATE INDEX idx_model_id ON exchanges(model_id);
CREATE INDEX idx_conversation_created_at ON exchanges(conversation_id, created_at);
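A minimal usage sketch of the columns and indexes added above (not part of this commit): it assumes an already-opened rusqlite Connection and an existing conversation row, and reads the latest exchange via is_latest together with the new per-conversation aggregates.

```rust
use rusqlite::{params, Connection, OptionalExtension, Result};

// Sketch only: fetch the current latest exchange id for a conversation
// (served by idx_exchange_conversation_latest) and the aggregate counters
// maintained on the conversations row.
fn latest_exchange_and_totals(
    conn: &Connection,
    conversation_id: i64,
) -> Result<(Option<i64>, i64, i64)> {
    let latest_exchange_id: Option<i64> = conn
        .query_row(
            "SELECT id FROM exchanges
             WHERE conversation_id = ?1 AND is_latest = TRUE
             LIMIT 1",
            params![conversation_id],
            |row| row.get(0),
        )
        .optional()?;

    let (exchange_count, total_tokens): (i64, i64) = conn.query_row(
        "SELECT exchange_count, total_tokens FROM conversations WHERE id = ?1",
        params![conversation_id],
        |row| Ok((row.get(0)?, row.get(1)?)),
    )?;

    Ok((latest_exchange_id, exchange_count, total_tokens))
}
```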
88 changes: 62 additions & 26 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
@@ -1,10 +1,11 @@
use std::path::PathBuf;

use rusqlite::Error as SqliteError;
use rusqlite::{params, Error as SqliteError, OptionalExtension};

use super::connector::DatabaseConnector;
use super::schema::{
Attachment, AttachmentData, Conversation, ConversationId, Exchange, Message,
Attachment, AttachmentData, Conversation, ConversationId, Exchange,
ExchangeId, Message,
};

pub struct ConversationDatabaseStore {
@@ -57,12 +58,33 @@ impl ConversationDatabaseStore {
messages: &[Message],
attachments: &[Attachment],
) -> Result<(), SqliteError> {
// Insert the exchange
// Update the previous exchange to set is_latest to false
let last_exchange_id: Option<i64> =
self.db.process_queue_with_result(|tx| {
tx.query_row(
"SELECT id FROM exchanges WHERE conversation_id = ? AND \
is_latest = TRUE LIMIT 1",
params![exchange.conversation_id.0],
|row| row.get(0),
)
.optional()
})?;

// Update the previous exchange to set is_latest to false
if let Some(prev_id) = last_exchange_id {
let update_prev_sql = format!(
"UPDATE exchanges SET is_latest = FALSE WHERE id = {};",
prev_id
);
self.db.queue_operation(update_prev_sql);
}

// Insert the exchange (without token-related fields)
let exchange_sql = format!(
"INSERT INTO exchanges (conversation_id, model_id, system_prompt,
completion_options, prompt_options, completion_tokens,
prompt_tokens, created_at, previous_exchange_id, is_deleted)
VALUES ({}, {}, {}, {}, {}, {}, {}, {}, {}, {});",
completion_options, prompt_options, created_at, previous_exchange_id, \
is_deleted, is_latest)
VALUES ({}, {}, {}, {}, {}, {}, {}, {}, TRUE);",
exchange.conversation_id.0,
exchange.model_id.0,
exchange.system_prompt.as_ref().map_or(
@@ -77,29 +99,27 @@ impl ConversationDatabaseStore {
"NULL".to_string(),
|v| format!("'{}'", v.to_string().replace("'", "''"))
),
exchange
.completion_tokens
.map_or("NULL".to_string(), |t| t.to_string()),
exchange
.prompt_tokens
.map_or("NULL".to_string(), |t| t.to_string()),
exchange.created_at,
exchange
.previous_exchange_id
.map_or("NULL".to_string(), |id| id.0.to_string()),
last_exchange_id.map_or("NULL".to_string(), |id| id.to_string()),
exchange.is_deleted
);
self.db.queue_operation(exchange_sql);

// Insert messages
// Get the actual exchange_id from the database
let exchange_id = self.db.process_queue_with_result(|tx| {
Ok(ExchangeId(tx.last_insert_rowid()))
})?;

// Insert messages and calculate total token length
let mut total_tokens = 0;
for message in messages {
let message_sql = format!(
"INSERT INTO messages (conversation_id, exchange_id, role,
message_type, content, has_attachments, token_length,
created_at, is_deleted)
VALUES ({}, {}, '{}', '{}', '{}', {}, {}, {}, {});",
message.conversation_id.0,
message.exchange_id.0,
message_type, content, has_attachments, token_length,
created_at, is_deleted)
VALUES ({}, {}, '{}', '{}', '{}', {}, {}, {}, {});",
exchange.conversation_id.0,
exchange_id.0,
message.role.to_string(),
message.message_type,
message.content.replace("'", "''"),
@@ -111,18 +131,23 @@ impl ConversationDatabaseStore {
message.is_deleted
);
self.db.queue_operation(message_sql);

// Sum up token lengths
if let Some(token_length) = message.token_length {
total_tokens += token_length;
}
}

// Insert attachments
for attachment in attachments {
let attachment_sql = format!(
"INSERT INTO attachments (message_id, conversation_id,
exchange_id, file_uri, file_data, file_type, metadata,
created_at, is_deleted)
VALUES ({}, {}, {}, {}, {}, '{}', {}, {}, {});",
exchange_id, file_uri, file_data, file_type, metadata,
created_at, is_deleted)
VALUES ({}, {}, {}, {}, {}, '{}', {}, {}, {});",
attachment.message_id.0,
attachment.conversation_id.0,
attachment.exchange_id.0,
exchange.conversation_id.0,
exchange_id.0,
match &attachment.data {
AttachmentData::Uri(uri) =>
format!("'{}'", uri.replace("'", "''")),
@@ -144,6 +169,17 @@ impl ConversationDatabaseStore {
self.db.queue_operation(attachment_sql);
}

// Update conversation
let update_conversation_sql = format!(
"UPDATE conversations
SET updated_at = {},
exchange_count = exchange_count + 1,
total_tokens = total_tokens + {}
WHERE id = {};",
exchange.created_at, total_tokens, exchange.conversation_id.0
);
self.db.queue_operation(update_conversation_sql);

// Commit the transaction
self.commit_queued_operations()?;

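The bookkeeping queued above reduces to two UPDATE statements around the exchange INSERT; a standalone sketch with bound parameters follows (hypothetical: it bypasses the store's queued-operation connector, and the function name is illustrative only).

```rust
use rusqlite::{params, Connection, Result};

// Hypothetical standalone form of the new bookkeeping: demote the previous
// latest exchange, then keep the conversation counters in sync with the
// exchange that was just inserted.
fn apply_exchange_bookkeeping(
    conn: &Connection,
    conversation_id: i64,
    new_exchange_created_at: i64,
    new_exchange_total_tokens: i64,
) -> Result<()> {
    // The previously latest exchange (if any) is no longer the latest.
    conn.execute(
        "UPDATE exchanges SET is_latest = FALSE
         WHERE conversation_id = ?1 AND is_latest = TRUE",
        params![conversation_id],
    )?;

    // Bump updated_at, exchange_count, and total_tokens on the conversation.
    conn.execute(
        "UPDATE conversations
         SET updated_at = ?1,
             exchange_count = exchange_count + 1,
             total_tokens = total_tokens + ?2
         WHERE id = ?3",
        params![
            new_exchange_created_at,
            new_exchange_total_tokens,
            conversation_id
        ],
    )?;

    Ok(())
}
```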
41 changes: 22 additions & 19 deletions lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
@@ -124,8 +124,6 @@ impl PromptInstruction {
completion_options: serde_json::to_value(&self.completion_options)
.ok(),
prompt_options: serde_json::to_value(&self.prompt_options).ok(),
completion_tokens: None,
prompt_tokens: None,
created_at: 0,
previous_exchange_id: None,
is_deleted: false,
@@ -141,8 +139,6 @@ impl PromptInstruction {
system_prompt: last.system_prompt.clone(),
completion_options: last.completion_options.clone(),
prompt_options: last.prompt_options.clone(),
completion_tokens: None,
prompt_tokens: None,
created_at: 0,
previous_exchange_id: Some(last.id),
is_deleted: false,
@@ -159,7 +155,10 @@ impl PromptInstruction {
message_type: "text".to_string(),
content: self.system_prompt.clone(),
has_attachments: false,
token_length: Some(simple_token_estimator(&self.system_prompt, None)),
token_length: Some(simple_token_estimator(
&self.system_prompt,
None,
)),
created_at: 0,
is_deleted: false,
};
@@ -170,8 +169,6 @@ impl PromptInstruction {
// return subsequent exchange
Exchange {
id: self.cache.new_exchange_id(),
prompt_tokens: None,
completion_tokens: None,
previous_exchange_id: Some(exchange.id),
..exchange
}
@@ -456,31 +453,38 @@ impl ExchangeHandler {
let exchange_data = {
let last_exchange = cache.get_last_exchange()?;
let messages = cache.get_exchange_messages(last_exchange.id);
let user_message = messages.iter().find(|m| m.role == PromptRole::User)?;
let assistant_message = messages.iter().find(|m| m.role == PromptRole::Assistant)?;

let user_message =
messages.iter().find(|m| m.role == PromptRole::User)?;
let assistant_message =
messages.iter().find(|m| m.role == PromptRole::Assistant)?;

(
last_exchange.clone(),
assistant_message.id,
user_message.id,
user_message.content.clone(),
)
};

let (exchange, assistant_message_id, user_message_id, user_content) = exchange_data;


let (exchange, assistant_message_id, user_message_id, user_content) =
exchange_data;

// Calculate user token length
let user_token_length = tokens_predicted.map(|tokens| {
let chars_per_token = calculate_chars_per_token(answer, tokens);
simple_token_estimator(&user_content, Some(chars_per_token))
});

// Perform all updates in a single mutable borrow
{
if let Some(tokens) = tokens_predicted {
// Update assistant's message
cache.update_message_by_id(assistant_message_id, answer, Some(tokens as i64));

cache.update_message_by_id(
assistant_message_id,
answer,
Some(tokens as i64),
);

// Update user's message token length
if let Some(length) = user_token_length {
cache.update_message_token_length(user_message_id, length);
@@ -490,10 +494,9 @@ impl ExchangeHandler {
cache.update_message_by_id(assistant_message_id, answer, None);
}
}

Some(exchange)
}

}

fn simple_token_estimator(input: &str, chars_per_token: Option<f32>) -> i64 {
@@ -506,4 +509,4 @@ fn simple_token_estimator(input: &str, chars_per_token: Option<f32>) -> i64 {
fn calculate_chars_per_token(answer: &str, tokens_predicted: usize) -> f32 {
let char_count = answer.chars().count() as f32;
char_count / tokens_predicted as f32
}
}
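As a quick worked check of the estimator plumbing: a 420-character answer reported as 100 predicted tokens gives 4.2 characters per token, and that ratio is then passed to simple_token_estimator (body collapsed in this diff) to approximate the user prompt's token length. A self-contained sketch of the helper shown above, with illustrative sample values:

```rust
// Self-contained copy of the helper shown in the diff above; main() and the
// sample values are illustrative only.
fn calculate_chars_per_token(answer: &str, tokens_predicted: usize) -> f32 {
    let char_count = answer.chars().count() as f32;
    char_count / tokens_predicted as f32
}

fn main() {
    // A 420-character answer reported as 100 predicted tokens -> 4.2 chars/token.
    let answer = "x".repeat(420);
    let ratio = calculate_chars_per_token(&answer, 100);
    assert!((ratio - 4.2).abs() < 1e-6);
    // This ratio would then be passed as Some(ratio) to simple_token_estimator
    // to approximate the user prompt's token count.
    println!("chars per token: {ratio}");
}
```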
10 changes: 4 additions & 6 deletions lumni/src/apps/builtin/llm/prompt/src/chat/session.rs
@@ -100,12 +100,10 @@ impl ChatSession {
.server
.get_context_size(&mut self.prompt_instruction)
.await?;
let user_question =
self.initiate_new_exchange(question).await?;
let messages = self.prompt_instruction.new_exchange(
&user_question,
max_token_length,
);
let user_question = self.initiate_new_exchange(question).await?;
let messages = self
.prompt_instruction
.new_exchange(&user_question, max_token_length);

let (cancel_tx, cancel_rx) = oneshot::channel();
self.cancel_tx = Some(cancel_tx); // channel to cancel
