Skip to content

Commit

Permalink
make chatsession run in a separate thread
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Aug 2, 2024
1 parent 10e840c commit 97c1345
Show file tree
Hide file tree
Showing 27 changed files with 678 additions and 536 deletions.
46 changes: 27 additions & 19 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub async fn run_cli(
let config_dir =
env.get_config_dir().expect("Config directory not defined");
let sqlite_file = config_dir.join("chat.db");
let db_conn = ConversationDatabase::new(&sqlite_file)?;
let db_conn = Arc::new(ConversationDatabase::new(&sqlite_file)?);

if let Some(db_matches) = matches.subcommand_matches("db") {
if db_matches.contains_id("list") {
Expand Down Expand Up @@ -154,42 +154,50 @@ pub async fn run_cli(

let mut db_handler = db_conn.get_conversation_handler(None);

let prompt_instruction = db_conn
.fetch_last_conversation_id()?
.and_then(|conversation_id| {
db_handler.set_conversation_id(conversation_id);
// Convert Result to Option using .ok()
if new_conversation.is_equal(&db_handler).ok()? {
let conversation_id = db_conn.fetch_last_conversation_id().await?;

let prompt_instruction = if let Some(conversation_id) = conversation_id {
db_handler.set_conversation_id(conversation_id);

match new_conversation.is_equal(&db_handler).await {
Ok(true) => {
log::debug!("Continuing last conversation");
Some(PromptInstruction::from_reader(&db_handler))
} else {
None
Some(PromptInstruction::from_reader(&db_handler).await?)
}
})
.unwrap_or_else(|| {
log::debug!("Starting new conversation");
PromptInstruction::new(new_conversation, &mut db_handler)
})?;
Ok(_) => None,
Err(e) => return Err(e.into()),
}
} else {
None
};

let chat_session = ChatSession::new(prompt_instruction);
let prompt_instruction = match prompt_instruction {
Some(instruction) => instruction,
None => {
log::debug!("Starting new conversation");
PromptInstruction::new(new_conversation, &mut db_handler).await?
}
};

match poll(Duration::from_millis(0)) {
Ok(_) => {
// Starting interactive session
let app = App::new(chat_session).await?;
let app =
App::new(prompt_instruction, Arc::clone(&db_conn)).await?;
interactive_mode(app, db_conn).await
}
Err(_) => {
// potential non-interactive input detected due to poll error.
// attempt to use in non interactive mode
let chat_session = ChatSession::new(prompt_instruction);
process_non_interactive_input(chat_session, db_conn).await
}
}
}

async fn interactive_mode(
app: App<'_>,
db_conn: ConversationDatabase,
db_conn: Arc<ConversationDatabase>,
) -> Result<(), ApplicationError> {
println!("Interactive mode detected. Starting interactive session:");
let mut stdout = io::stdout().lock();
Expand Down Expand Up @@ -235,7 +243,7 @@ async fn interactive_mode(

async fn process_non_interactive_input(
chat: ChatSession,
db_conn: ConversationDatabase,
db_conn: Arc<ConversationDatabase>,
) -> Result<(), ApplicationError> {
let chat = Arc::new(Mutex::new(chat));
let stdin = tokio::io::stdin();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ pub struct PromptInstruction {
}

impl PromptInstruction {
pub fn new(
pub async fn new(
new_conversation: NewConversation,
db_handler: &mut ConversationDbHandler<'_>,
db_handler: &mut ConversationDbHandler,
) -> Result<Self, ApplicationError> {
let completion_options = match new_conversation.options {
Some(opts) => {
Expand All @@ -33,13 +33,20 @@ impl PromptInstruction {
};

let conversation_id = if let Some(ref model) = new_conversation.model {
Some(db_handler.new_conversation(
"New Conversation",
new_conversation.parent.as_ref().map(|p| p.id),
new_conversation.parent.as_ref().map(|p| p.fork_message_id),
Some(serde_json::to_value(&completion_options)?),
model,
)?)
Some(
db_handler
.new_conversation(
"New Conversation",
new_conversation.parent.as_ref().map(|p| p.id),
new_conversation
.parent
.as_ref()
.map(|p| p.fork_message_id),
Some(serde_json::to_value(&completion_options)?),
model,
)
.await?,
)
} else {
None
};
Expand Down Expand Up @@ -85,18 +92,19 @@ impl PromptInstruction {
.set_preloaded_messages(messages_to_insert.len());

// Insert messages into the database
db_handler.put_new_messages(&messages_to_insert)?;
db_handler.put_new_messages(&messages_to_insert).await?;
} else if let Some(system_prompt) = new_conversation.system_prompt {
// add system_prompt as the first message
prompt_instruction
.add_system_message(system_prompt, db_handler)?;
.add_system_message(system_prompt, db_handler)
.await?;
}
}
Ok(prompt_instruction)
}

pub fn from_reader(
reader: &ConversationDbHandler<'_>,
pub async fn from_reader(
reader: &ConversationDbHandler,
) -> Result<Self, ApplicationError> {
// if conversation_id is none, it should err
let conversation_id =
Expand All @@ -107,10 +115,12 @@ impl PromptInstruction {
})?;
let model_spec = reader
.fetch_model_spec()
.await
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

let completion_options = reader
.fetch_completion_options()
.await
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

let completion_options: ChatCompletionOptions =
Expand Down Expand Up @@ -138,6 +148,7 @@ impl PromptInstruction {
// Load messages
let messages = reader
.fetch_messages()
.await
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
for message in messages {
prompt_instruction.cache.add_message(message);
Expand All @@ -146,6 +157,7 @@ impl PromptInstruction {
// Load attachments
let attachments = reader
.fetch_attachments()
.await
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;
for attachment in attachments {
prompt_instruction.cache.add_attachment(attachment);
Expand All @@ -169,10 +181,10 @@ impl PromptInstruction {
&self.completion_options
}

fn add_system_message(
async fn add_system_message(
&mut self,
content: String,
db_handler: &ConversationDbHandler<'_>,
db_handler: &ConversationDbHandler,
) -> Result<(), ApplicationError> {
let timestamp = Timestamp::from_system_time()?.as_millis();
let message = Message {
Expand All @@ -191,30 +203,11 @@ impl PromptInstruction {
is_deleted: false,
};
// put system message directly into the database
db_handler.put_new_message(&message)?;
db_handler.put_new_message(&message).await?;
self.cache.add_message(message);
Ok(())
}

pub fn reset_history(
&mut self,
db_handler: &mut ConversationDbHandler<'_>,
) -> Result<(), ApplicationError> {
// reset by creating a new conversation
// TODO: clone previous conversation settings
if let Some(ref model) = &self.model {
let current_conversation_id = db_handler.new_conversation(
"New Conversation",
None,
None,
None,
model,
)?;
self.cache.set_conversation_id(current_conversation_id);
};
Ok(())
}

pub fn append_last_response(&mut self, answer: &str) {
if let Some(last_message) = self.cache.get_last_message() {
if last_message.role == PromptRole::Assistant {
Expand All @@ -239,11 +232,11 @@ impl PromptInstruction {
.map(|msg| msg.content.clone())
}

pub fn put_last_response(
pub async fn put_last_response(
&mut self,
answer: &str,
tokens_predicted: Option<usize>,
db_handler: &ConversationDbHandler<'_>,
db_handler: &ConversationDbHandler,
) -> Result<(), ApplicationError> {
let (user_message, assistant_message) =
self.finalize_last_messages(answer, tokens_predicted)?;
Expand All @@ -262,6 +255,7 @@ impl PromptInstruction {
// Insert messages into the database
db_handler
.put_new_messages(&messages_to_insert)
.await
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

Ok(())
Expand Down
23 changes: 12 additions & 11 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,26 @@ pub struct NewConversation {
}

impl NewConversation {
pub fn new(
pub async fn new(
new_server: ModelServerName,
new_model: ModelSpec,
conversation_handler: &ConversationDbHandler<'_>,
conversation_handler: &ConversationDbHandler,
) -> Result<NewConversation, ApplicationError> {
match conversation_handler.get_conversation_id() {
Some(current_conversation_id) => {
// Fork from an existing conversation
let mut current_completion_options =
conversation_handler.fetch_completion_options()?;
conversation_handler.fetch_completion_options().await?;
current_completion_options["model_server"] =
serde_json::to_value(new_server.clone())?;

let parent = conversation_handler.fetch_last_message_id()?.map(
|last_message_id| ParentConversation {
let parent = conversation_handler
.fetch_last_message_id()
.await?
.map(|last_message_id| ParentConversation {
id: current_conversation_id,
fork_message_id: last_message_id,
},
);
});

create_new_conversation(
new_server,
Expand All @@ -57,20 +58,20 @@ impl NewConversation {
}
}

pub fn is_equal(
pub async fn is_equal(
&self,
handler: &ConversationDbHandler,
) -> Result<bool, ApplicationError> {
// check if conversation settings are equal to the conversation stored in the database

// Compare model
let last_model = handler.fetch_model_spec()?;
let last_model = handler.fetch_model_spec().await?;
if self.model.as_ref() != Some(&last_model) {
return Ok(false);
}

// Compare completion options (which includes server name and assistant)
let last_options = handler.fetch_completion_options()?;
let last_options = handler.fetch_completion_options().await?;
let new_options = match &self.options {
Some(opts) => opts.clone(),
None => serde_json::json!({}),
Expand All @@ -79,7 +80,7 @@ impl NewConversation {
return Ok(false);
}
// Compare system prompt. If the system prompt is not set in the new conversation, we check by first system prompt in the initial messages
let last_system_prompt = handler.fetch_system_prompt()?;
let last_system_prompt = handler.fetch_system_prompt().await?;
let new_system_prompt = match &self.system_prompt {
Some(prompt) => Some(prompt.as_str()),
None => self.initial_messages.as_ref().and_then(|messages| {
Expand Down
6 changes: 3 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ impl ConversationDatabase {
&self,
) -> Result<(), ApplicationError> {
if let Some((conversation, messages)) =
self.fetch_conversation(None, None)?
self.fetch_conversation(None, None).await?
{
display_conversation_with_messages(&conversation, &messages);
} else {
Expand All @@ -21,7 +21,7 @@ impl ConversationDatabase {
&self,
limit: usize,
) -> Result<(), ApplicationError> {
let conversations = self.fetch_conversation_list(limit)?;
let conversations = self.fetch_conversation_list(limit).await?;
for conversation in conversations {
println!(
"ID: {}, Name: {}, Updated: {}",
Expand All @@ -42,7 +42,7 @@ impl ConversationDatabase {
})?);

if let Some((conversation, messages)) =
self.fetch_conversation(Some(conversation_id), None)?
self.fetch_conversation(Some(conversation_id), None).await?
{
display_conversation_with_messages(&conversation, &messages);
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::*;

impl<'a> ConversationDbHandler<'a> {
pub fn permanent_delete_conversation(
impl ConversationDbHandler {
pub async fn permanent_delete_conversation(
&mut self,
conversation_id: Option<ConversationId>,
) -> Result<(), SqliteError> {
let target_conversation_id = conversation_id.or(self.conversation_id);

if let Some(id) = target_conversation_id {
let mut db = self.db.lock().unwrap();
let mut db = self.db.lock().await;
let result = db.process_queue_with_result(|tx| {
// Delete all attachments for the conversation
tx.execute(
Expand Down
Loading

0 comments on commit 97c1345

Please sign in to comment.