Skip to content

Commit

Permalink
code cleanups, prep refactor easy converstation switch
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Sep 13, 2024
1 parent c2f2549 commit 77bbea6
Show file tree
Hide file tree
Showing 28 changed files with 403 additions and 871 deletions.
4 changes: 2 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ use ratatui::Terminal;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use tokio::signal;
use tokio::sync::Mutex;
use tokio::time::{timeout, Duration};
use tokio::time::Duration;

use super::chat::db::ConversationDatabase;
use super::chat::{
prompt_app, App, ChatEvent, PromptInstruction, PromptInstructionBuilder,
prompt_app, App, PromptInstruction, PromptInstructionBuilder,
ThreadedChatSession,
};
use super::cli::{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use lumni::api::error::ApplicationError;

use super::db::{
ConversationCache, ConversationDbHandler, ConversationId, Message,
MessageId, ModelSpec, Timestamp, WorkspaceId,
MessageId, ModelSpec, Timestamp, Workspace,
};
use super::prepare::NewConversation;
use super::{
Expand Down Expand Up @@ -33,14 +33,14 @@ impl PromptInstruction {
None => ChatCompletionOptions::default(),
};

let workspace_id: Option<WorkspaceId> = None;
let workspace: Option<Workspace> = None;
let conversation_id = if let Some(ref model) = new_conversation.model {
Some(
db_handler
.new_conversation(
"New Conversation",
new_conversation.parent.as_ref().map(|p| p.id),
workspace_id,
workspace,
new_conversation
.parent
.as_ref()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ pub use prepare::NewConversation;

pub use super::db;
use super::{
ChatCompletionOptions, ChatMessage, ColorScheme, ModelBackend, PromptError,
PromptRole, SimpleString, TextLine, TextSegment,
ChatCompletionOptions, ChatMessage, ColorScheme, PromptError, PromptRole,
SimpleString, TextLine, TextSegment,
};

#[derive(Debug, Clone, PartialEq, Copy)]
Expand Down
123 changes: 104 additions & 19 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/conversations/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,76 @@ use super::*;

#[allow(dead_code)]
impl ConversationDbHandler {
pub async fn fetch_conversation(
&self,
conversation_id: ConversationId,
) -> Result<Option<Conversation>, DatabaseOperationError> {
let query = "SELECT c.id, c.name, c.info, c.completion_options,
c.model_identifier,
c.workspace_id, w.name AS workspace_name, w.path AS \
workspace_path,
c.parent_conversation_id, c.fork_message_id,
c.created_at, c.updated_at,
c.is_deleted, c.is_pinned, c.status, c.message_count, \
c.total_tokens
FROM conversations c
LEFT JOIN workspaces w ON c.workspace_id = w.id
WHERE c.id = ?1 AND c.status != 'deleted'";
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
let result: Result<Option<Conversation>, SqliteError> = {
let mut stmt = tx.prepare(query)?;
let mut rows = stmt.query([conversation_id.0])?;
if let Some(row) = rows.next()? {
Ok(Some(Conversation {
id: ConversationId(row.get(0)?),
name: row.get(1)?,
info: serde_json::from_str(&row.get::<_, String>(2)?)
.unwrap_or_default(),
completion_options: row
.get::<_, Option<String>>(3)?
.map(|s| {
serde_json::from_str(&s).unwrap_or_default()
}),
model_identifier: ModelIdentifier(row.get(4)?),
workspace: row.get::<_, Option<i64>>(5)?.and_then(
|id| {
let name =
row.get::<_, Option<String>>(6).ok()?;
let path =
row.get::<_, Option<String>>(7).ok()?;
Some(Workspace {
id: WorkspaceId(id),
name: name?,
directory_path: path.map(PathBuf::from),
})
},
),
parent_conversation_id: row
.get::<_, Option<i64>>(8)?
.map(ConversationId),
fork_message_id: row
.get::<_, Option<i64>>(9)?
.map(MessageId),
created_at: row.get(10)?,
updated_at: row.get(11)?,
is_deleted: row.get::<_, i64>(12)? != 0,
is_pinned: row.get::<_, i64>(13)? != 0,
status: ConversationStatus::from_str(
&row.get::<_, String>(14)?,
)
.unwrap_or(ConversationStatus::Active),
message_count: row.get(15)?,
total_tokens: row.get(16)?,
}))
} else {
Ok(None)
}
};
result.map_err(DatabaseOperationError::from)
})
}

pub async fn fetch_completion_options(
&self,
) -> Result<JsonValue, DatabaseOperationError> {
Expand Down Expand Up @@ -225,13 +295,18 @@ impl ConversationDbHandler {
limit: usize,
) -> Result<Vec<Conversation>, DatabaseOperationError> {
let query = format!(
"SELECT id, name, info, completion_options, model_identifier,
workspace_id, parent_conversation_id, fork_message_id, \
created_at, updated_at,
is_deleted, is_pinned, status, message_count, total_tokens
FROM conversations
WHERE is_deleted = FALSE
ORDER BY is_pinned DESC, updated_at DESC
"SELECT c.id, c.name, c.info, c.completion_options, \
c.model_identifier,
c.workspace_id, w.name AS workspace_name, w.path AS \
workspace_path,
c.parent_conversation_id, c.fork_message_id,
c.created_at, c.updated_at,
c.is_deleted, c.is_pinned, c.status, c.message_count, \
c.total_tokens
FROM conversations c
LEFT JOIN workspaces w ON c.workspace_id = w.id
WHERE c.is_deleted = FALSE
ORDER BY c.is_pinned DESC, c.updated_at DESC
LIMIT {}",
limit
);
Expand All @@ -251,25 +326,35 @@ impl ConversationDbHandler {
serde_json::from_str(&s).unwrap_or_default()
}),
model_identifier: ModelIdentifier(row.get(4)?),
workspace_id: row
.get::<_, Option<i64>>(5)?
.map(WorkspaceId),
workspace: row.get::<_, Option<i64>>(5)?.and_then(
|id| {
let name =
row.get::<_, Option<String>>(6).ok()?;
let path =
row.get::<_, Option<String>>(7).ok()?;
Some(Workspace {
id: WorkspaceId(id),
name: name?,
directory_path: path.map(PathBuf::from),
})
},
),
parent_conversation_id: row
.get::<_, Option<i64>>(6)?
.get::<_, Option<i64>>(8)?
.map(ConversationId),
fork_message_id: row
.get::<_, Option<i64>>(7)?
.get::<_, Option<i64>>(9)?
.map(MessageId),
created_at: row.get(8)?,
updated_at: row.get(9)?,
is_deleted: row.get::<_, i64>(10)? != 0,
is_pinned: row.get::<_, i64>(11)? != 0,
created_at: row.get(10)?,
updated_at: row.get(11)?,
is_deleted: row.get::<_, i64>(12)? != 0,
is_pinned: row.get::<_, i64>(13)? != 0,
status: ConversationStatus::from_str(
&row.get::<_, String>(12)?,
&row.get::<_, String>(14)?,
)
.unwrap_or(ConversationStatus::Active),
message_count: row.get(13)?,
total_tokens: row.get(14)?,
message_count: row.get(15)?,
total_tokens: row.get(16)?,
})
})?;
rows.collect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ impl ConversationDbHandler {
&mut self,
name: &str,
parent_id: Option<ConversationId>,
workspace_id: Option<WorkspaceId>,
workspace: Option<Workspace>,
fork_message_id: Option<MessageId>,
completion_options: Option<serde_json::Value>,
model: &ModelSpec,
Expand All @@ -14,7 +14,7 @@ impl ConversationDbHandler {
let mut db = self.db.lock().await;
db.process_queue_with_result(|tx| {
let result: Result<ConversationId, SqliteError> = {
// Ensure the model exists
// Ensure the model exists (unchanged)
let exists: bool = tx
.query_row(
"SELECT 1 FROM models WHERE identifier = ?",
Expand Down Expand Up @@ -53,7 +53,7 @@ impl ConversationDbHandler {
name: name.to_string(),
info: serde_json::Value::Null,
model_identifier: model.identifier.clone(),
workspace_id,
workspace,
parent_conversation_id: parent_id,
fork_message_id,
completion_options,
Expand All @@ -66,6 +66,22 @@ impl ConversationDbHandler {
status: ConversationStatus::Active,
};

// Insert workspace if it doesn't exist
if let Some(workspace) = &conversation.workspace {
tx.execute(
"INSERT OR IGNORE INTO workspaces (id, name, path) \
VALUES (?, ?, ?)",
params![
workspace.id.0,
workspace.name,
workspace
.directory_path
.as_ref()
.map(|p| p.to_string_lossy().to_string()),
],
)?;
}

tx.execute(
"INSERT INTO conversations (
name, info, model_identifier, workspace_id, \
Expand All @@ -81,7 +97,7 @@ impl ConversationDbHandler {
serde_json::to_string(&conversation.info)
.unwrap_or_default(),
conversation.model_identifier.0,
conversation.workspace_id.map(|id| id.0),
conversation.workspace.as_ref().map(|w| w.id.0),
conversation.parent_conversation_id.map(|id| id.0),
conversation.fork_message_id.map(|id| id.0),
conversation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod fetch;
mod insert;
mod update;

use std::path::PathBuf;
use std::sync::Arc;

use lumni::api::error::ApplicationError;
Expand All @@ -15,7 +16,7 @@ use super::encryption::EncryptionHandler;
use super::{
Attachment, AttachmentData, AttachmentId, Conversation, ConversationId,
ConversationStatus, Message, MessageId, ModelIdentifier, ModelSpec,
Timestamp, WorkspaceId,
Timestamp, Workspace, WorkspaceId,
};
use crate::external as lumni;

Expand Down
7 changes: 4 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ impl ConversationDatabase {
&self,
) -> Result<(), ApplicationError> {
if let Some((conversation, messages)) =
self.fetch_conversation(None, None).await?
self.fetch_conversation_with_messages(None, None).await?
{
display_conversation_with_messages(&conversation, &messages);
} else {
Expand Down Expand Up @@ -41,8 +41,9 @@ impl ConversationDatabase {
))
})?);

if let Some((conversation, messages)) =
self.fetch_conversation(Some(conversation_id), None).await?
if let Some((conversation, messages)) = self
.fetch_conversation_with_messages(Some(conversation_id), None)
.await?
{
display_conversation_with_messages(&conversation, &messages);
} else {
Expand Down
10 changes: 9 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::error::Error;
use std::fmt;
use std::path::PathBuf;

mod connector;
mod conversations;
Expand Down Expand Up @@ -77,13 +78,20 @@ impl ConversationStatus {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Workspace {
pub id: WorkspaceId,
pub name: String,
pub directory_path: Option<PathBuf>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
pub id: ConversationId,
pub name: String,
pub info: serde_json::Value,
pub model_identifier: ModelIdentifier,
pub workspace_id: Option<WorkspaceId>,
pub workspace: Option<Workspace>, // this used to be a WorkspaceId
pub parent_conversation_id: Option<ConversationId>,
pub fork_message_id: Option<MessageId>, // New field
pub completion_options: Option<serde_json::Value>,
Expand Down
Loading

0 comments on commit 77bbea6

Please sign in to comment.