Skip to content

Commit

Permalink
separate assistant logic out of promptinstruction, wip - load from re…
Browse files Browse the repository at this point in the history
…ader
  • Loading branch information
aprxi committed Jul 21, 2024
1 parent 8c1dfab commit eaecbde
Show file tree
Hide file tree
Showing 12 changed files with 389 additions and 219 deletions.
29 changes: 23 additions & 6 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ use tokio::signal;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, timeout, Duration};

use super::chat::{ChatSession, ConversationDatabaseStore, NewConversation};
use super::server::{
ModelServer, ModelServerName, PromptInstruction, ServerManager, ServerTrait,
use super::chat::{
AssistantManager, ChatSession, ConversationDatabaseStore, NewConversation,
PromptInstruction,
};
use super::server::{ModelServer, ModelServerName, ServerManager, ServerTrait};
use super::session::{AppSession, TabSession};
use super::tui::{
ColorScheme, CommandLineAction, ConversationEvent, ConversationReader,
Expand Down Expand Up @@ -143,11 +144,18 @@ async fn prompt_app<B: Backend>(
new_conversation.clone(),
&db_conn,
)?;
// assume prompt template does not change
// TODO: should be optional defined via modal
let prompt_template = tab.chat.get_prompt_template();
let chat_session = ChatSession::new(
&new_conversation.server.to_string(),
prompt_instruction,
prompt_template,
&db_conn,
).await?;
// stop current chat session
tab.chat.stop();
// update tab with new chat session
tab.new_conversation(chat_session);
reader = if let Some(conversation_id) = tab.chat.get_conversation_id() {
Some(db_conn.get_conversation_reader(conversation_id))
Expand Down Expand Up @@ -406,20 +414,29 @@ pub async fn run_cli(
}
};

let assistant_manager =
AssistantManager::new(assistant, instruction.clone())?;
let initial_messages = assistant_manager.get_initial_messages().to_vec();

let prompt_instruction = PromptInstruction::new(
NewConversation {
server: ModelServerName::from_str(&server_name),
model: default_model,
options,
system_prompt: instruction,
assistant_name: assistant,
initial_messages: Some(initial_messages),
parent: None,
},
&db_conn,
)?;

let chat_session =
ChatSession::new(&server_name, prompt_instruction, &db_conn).await?;
let chat_session = ChatSession::new(
&server_name,
prompt_instruction,
assistant_manager.get_prompt_template(),
&db_conn,
)
.await?;

match poll(Duration::from_millis(0)) {
Ok(_) => {
Expand Down
145 changes: 145 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/assistant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use lumni::api::error::ApplicationError;

use super::conversation::{ConversationId, Message, MessageId};
use super::prompt::Prompt;
use super::{PromptRole, PERSONAS};
pub use crate::external as lumni;

/// Builds the initial conversation state for a named assistant persona.
///
/// Populated from the static `PERSONAS` YAML dataset; holds the persona's
/// optional prompt template and the seed messages (system prompt plus any
/// example exchanges) used when starting a new conversation.
pub struct AssistantManager {
    // per-assistant template, if the persona defines one
    prompt_template: Option<String>,
    // system message followed by example user/assistant exchange pairs
    initial_messages: Vec<Message>,
}

impl AssistantManager {
pub fn new(
assistant_name: Option<String>,
user_instruction: Option<String>,
) -> Result<Self, ApplicationError> {
let mut manager = AssistantManager {
prompt_template: None,
initial_messages: Vec::new(),
};

// TODO: default should only apply to servers that do no handle this internally
// Use default assistant when both system_promt and assistant_name are None
let assistant_name = assistant_name.or_else(|| {
if user_instruction.is_some() {
// assistant not needed
None
} else {
// no user instruction, use default assistant
Some("Default".to_string())
}
});

if let Some(assistant_name) = assistant_name {
manager.load_assistant(assistant_name, user_instruction)?;
}

Ok(manager)
}

fn load_assistant(
&mut self,
assistant: String,
user_instruction: Option<String>,
) -> Result<(), ApplicationError> {
let assistant_prompts: Vec<Prompt> = serde_yaml::from_str(PERSONAS)
.map_err(|e| {
ApplicationError::Unexpected(format!(
"Failed to parse persona data: {}",
e
))
})?;

let prompt = assistant_prompts
.into_iter()
.find(|p| p.name() == assistant)
.ok_or_else(|| {
ApplicationError::Unexpected(format!(
"Assistant '{}' not found in the dataset",
assistant
))
})?;

let system_prompt = build_system_prompt(&prompt, &user_instruction);

// Add system message
self.initial_messages.push(Message {
id: MessageId(0), // system message is always the first message
conversation_id: ConversationId(0), // temporary conversation id
role: PromptRole::System,
message_type: "text".to_string(),
content: system_prompt,
has_attachments: false,
token_length: None,
previous_message_id: None,
created_at: 0,
is_deleted: false,
});

// Add exchanges if any
if let Some(exchanges) = prompt.exchanges() {
for (index, exchange) in exchanges.iter().enumerate() {
// User message
self.initial_messages.push(Message {
id: MessageId((index * 2 + 1) as i64),
conversation_id: ConversationId(0), // temporary conversation id
role: PromptRole::User,
message_type: "text".to_string(),
content: exchange.question.clone(),
has_attachments: false,
token_length: None,
previous_message_id: Some(MessageId((index * 2) as i64)),
created_at: 0,
is_deleted: false,
});

// Assistant message
self.initial_messages.push(Message {
id: MessageId((index * 2 + 2) as i64),
conversation_id: ConversationId(0), // temporary conversation id
role: PromptRole::Assistant,
message_type: "text".to_string(),
content: exchange.answer.clone(),
has_attachments: false,
token_length: None,
previous_message_id: Some(MessageId(
(index * 2 + 1) as i64,
)),
created_at: 0,
is_deleted: false,
});
}
}

if let Some(prompt_template) = prompt.prompt_template() {
self.prompt_template = Some(prompt_template.to_string());
}
Ok(())
}

pub fn get_prompt_template(&self) -> Option<String> {
self.prompt_template.clone()
}

pub fn get_initial_messages(&self) -> &[Message] {
&self.initial_messages
}
}

/// Combines the assistant's own system prompt with the user's instruction.
///
/// When both are present the user instruction is appended to the
/// (right-trimmed) assistant instruction; when only one is present it is
/// used alone; when neither exists the result is empty.
fn build_system_prompt(
    prompt: &Prompt,
    user_instruction: &Option<String>,
) -> String {
    let user = user_instruction.as_deref();
    if let Some(assistant_instruction) = prompt.system_prompt() {
        match user {
            Some(user_instr) => {
                format!("{} {}", assistant_instruction.trim_end(), user_instr)
            }
            None => assistant_instruction.to_string(),
        }
    } else {
        user.map(str::to_string).unwrap_or_default()
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ pub struct Conversation {
pub name: String,
pub info: serde_json::Value,
pub model_identifier: ModelIdentifier,
pub model_server: ModelServerName,
pub parent_conversation_id: Option<ConversationId>,
pub fork_message_id: Option<MessageId>, // New field
pub completion_options: Option<serde_json::Value>,
Expand Down
98 changes: 97 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::sync::{Arc, Mutex};
use rusqlite::{params, Error as SqliteError, OptionalExtension};

use super::connector::DatabaseConnector;
use super::conversation::{ConversationId, MessageId, ModelIdentifier};
use super::conversation::{
Attachment, AttachmentData, AttachmentId, ConversationId, Message,
MessageId, ModelIdentifier, ModelSpec,
};

pub struct ConversationReader<'a> {
conversation_id: ConversationId,
Expand Down Expand Up @@ -73,6 +76,99 @@ impl<'a> ConversationReader<'a> {
})
}

pub fn get_model_spec(&self) -> Result<ModelSpec, SqliteError> {
let query = "
SELECT m.identifier, m.info, m.config, m.context_window_size, \
m.input_token_limit
FROM conversations c
JOIN models m ON c.model_identifier = m.identifier
WHERE c.id = ?
";
let mut db = self.db.lock().unwrap();
db.process_queue_with_result(|tx| {
tx.query_row(query, params![self.conversation_id.0], |row| {
Ok(ModelSpec {
identifier: ModelIdentifier::new(&row.get::<_, String>(0)?)
.unwrap(),
info: row
.get::<_, Option<String>>(1)?
.map(|s| serde_json::from_str(&s).unwrap()),
config: row
.get::<_, Option<String>>(2)?
.map(|s| serde_json::from_str(&s).unwrap()),
context_window_size: row.get(3)?,
input_token_limit: row.get(4)?,
})
})
})
}

/// Returns every non-deleted message of this conversation, oldest first.
pub fn get_all_messages(&self) -> Result<Vec<Message>, SqliteError> {
    let query = "
        SELECT id, role, message_type, content, has_attachments, \
         token_length, previous_message_id, created_at, is_deleted
        FROM messages
        WHERE conversation_id = ? AND is_deleted = FALSE
        ORDER BY created_at ASC
    ";
    let conversation_id = self.conversation_id;
    let mut db = self.db.lock().unwrap();
    db.process_queue_with_result(|tx| {
        let mut stmt = tx.prepare(query)?;
        let rows = stmt.query_map(params![conversation_id.0], |row| {
            // SQLite stores booleans as integers; decode explicitly.
            let has_attachments: i64 = row.get(4)?;
            let is_deleted: i64 = row.get(8)?;
            Ok(Message {
                id: MessageId(row.get(0)?),
                conversation_id,
                role: row.get(1)?,
                message_type: row.get(2)?,
                content: row.get(3)?,
                has_attachments: has_attachments != 0,
                token_length: row.get(5)?,
                previous_message_id: row
                    .get::<_, Option<i64>>(6)?
                    .map(MessageId),
                created_at: row.get(7)?,
                is_deleted: is_deleted != 0,
            })
        })?;
        rows.collect()
    })
}

/// Returns every non-deleted attachment of this conversation, oldest first.
pub fn get_all_attachments(&self) -> Result<Vec<Attachment>, SqliteError> {
    let query = "
        SELECT attachment_id, message_id, file_uri, file_data, file_type, \
         metadata, created_at, is_deleted
        FROM attachments
        WHERE conversation_id = ? AND is_deleted = FALSE
        ORDER BY created_at ASC
    ";
    let conversation_id = self.conversation_id;
    let mut db = self.db.lock().unwrap();
    db.process_queue_with_result(|tx| {
        let mut stmt = tx.prepare(query)?;
        let rows = stmt.query_map(params![conversation_id.0], |row| {
            // An attachment is stored either as a URI or as inline data:
            // a non-NULL file_uri wins, otherwise file_data is read.
            let uri: Option<String> = row.get(2)?;
            let data = match uri {
                Some(uri) => AttachmentData::Uri(uri),
                None => AttachmentData::Data(row.get(3)?),
            };
            let metadata: Option<String> = row.get(5)?;
            let is_deleted: i64 = row.get(7)?;
            Ok(Attachment {
                attachment_id: AttachmentId(row.get(0)?),
                message_id: MessageId(row.get(1)?),
                conversation_id,
                data,
                file_type: row.get(4)?,
                // malformed metadata JSON falls back to the default value
                metadata: metadata
                    .map(|s| serde_json::from_str(&s).unwrap_or_default()),
                created_at: row.get(6)?,
                is_deleted: is_deleted != 0,
            })
        })?;
        rows.collect()
    })
}

pub fn get_system_prompt(&self) -> Result<Option<String>, SqliteError> {
let query = "
SELECT content
Expand Down
1 change: 0 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ CREATE TABLE conversations (
info TEXT, -- JSON string including description and other metadata
completion_options TEXT, -- JSON string
model_identifier TEXT NOT NULL,
model_server TEXT NOT NULL,
parent_conversation_id INTEGER,
fork_message_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
Expand Down
Loading

0 comments on commit eaecbde

Please sign in to comment.