Skip to content

Commit

Permalink
add status and tags to conversation, fix read converstation
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 24, 2024
1 parent 27ec9d5 commit cb0ec8a
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 156 deletions.
51 changes: 48 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::error::Error;
use std::fmt;

mod connector;
mod display;
mod helpers;
Expand Down Expand Up @@ -36,6 +39,24 @@ pub struct MessageId(pub i64);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AttachmentId(pub i64);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConversationStatus {
Active,
Archived,
Deleted,
}

impl ConversationStatus {
pub fn from_str(s: &str) -> Result<Self, ConversionError> {
match s {
"active" => Ok(ConversationStatus::Active),
"archived" => Ok(ConversationStatus::Archived),
"deleted" => Ok(ConversationStatus::Deleted),
_ => Err(ConversionError::new("status", s)),
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
pub id: ConversationId,
Expand All @@ -48,6 +69,7 @@ pub struct Conversation {
pub created_at: i64,
pub updated_at: i64,
pub is_deleted: bool,
pub status: ConversationStatus,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -61,9 +83,9 @@ pub struct Message {
pub token_length: Option<i64>,
pub previous_message_id: Option<MessageId>,
pub created_at: i64,
pub vote: i64, // New field
pub include_in_prompt: bool, // New field
pub is_hidden: bool, // New field
pub vote: i64,
pub include_in_prompt: bool,
pub is_hidden: bool,
pub is_deleted: bool,
}

Expand All @@ -84,3 +106,26 @@ pub struct Attachment {
pub created_at: i64,
pub is_deleted: bool,
}

#[derive(Debug)]
pub struct ConversionError {
field: String,
value: String,
}

impl ConversionError {
pub fn new<T: fmt::Display>(field: &str, value: T) -> Self {
ConversionError {
field: field.to_string(),
value: value.to_string(),
}
}
}

impl fmt::Display for ConversionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Invalid {} value: '{}'", self.field, self.value)
}
}

impl Error for ConversionError {}
39 changes: 37 additions & 2 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use rusqlite::{params, Error as SqliteError, OptionalExtension};

use super::connector::DatabaseConnector;
use super::{
Attachment, AttachmentData, AttachmentId, ConversationId, Message,
MessageId, ModelIdentifier, ModelSpec,
Attachment, AttachmentData, AttachmentId, ConversationId,
ConversationStatus, Message, MessageId, ModelIdentifier, ModelSpec,
};

pub struct ConversationReader<'a> {
Expand Down Expand Up @@ -76,6 +76,41 @@ impl<'a> ConversationReader<'a> {
})
}

pub fn get_conversation_status(
&self,
) -> Result<ConversationStatus, SqliteError> {
let query = "SELECT status FROM conversations WHERE id = ?";
let mut db = self.db.lock().unwrap();
db.process_queue_with_result(|tx| {
tx.query_row(query, params![self.conversation_id.0], |row| {
let status: String = row.get(0)?;
ConversationStatus::from_str(&status).map_err(|e| {
SqliteError::FromSqlConversionFailure(
0,
rusqlite::types::Type::Text,
Box::new(e),
)
})
})
})
}

pub fn get_conversation_tags(&self) -> Result<Vec<String>, SqliteError> {
let query = "
SELECT t.name
FROM tags t
JOIN conversation_tags ct ON t.id = ct.tag_id
WHERE ct.conversation_id = ?
ORDER BY t.name
";
let mut db = self.db.lock().unwrap();
db.process_queue_with_result(|tx| {
tx.prepare(query)?
.query_map(params![self.conversation_id.0], |row| row.get(0))?
.collect()
})
}

pub fn get_model_spec(&self) -> Result<ModelSpec, SqliteError> {
let query = "
SELECT m.identifier, m.info, m.config, m.context_window_size, \
Expand Down
25 changes: 20 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ CREATE TABLE conversations (
message_count INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
is_deleted BOOLEAN DEFAULT FALSE,
status TEXT CHECK(status IN ('active', 'archived', 'deleted')) DEFAULT 'active',
FOREIGN KEY (parent_conversation_id) REFERENCES conversations(id),
FOREIGN KEY (model_identifier) REFERENCES models(identifier),
FOREIGN KEY (fork_message_id) REFERENCES messages(id)
FOREIGN KEY (fork_message_id) REFERENCES messages(id),
CONSTRAINT check_message_count CHECK (message_count >= 0),
CONSTRAINT check_total_tokens CHECK (total_tokens >= 0)
);

CREATE TABLE messages (
Expand All @@ -41,12 +44,13 @@ CREATE TABLE messages (
token_length INTEGER,
previous_message_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
vote INTEGER DEFAULT 0, -- ADDED
include_in_prompt BOOLEAN DEFAULT TRUE, -- ADDED
is_hidden BOOLEAN DEFAULT FALSE, -- ADDED
vote INTEGER DEFAULT 0,
include_in_prompt BOOLEAN DEFAULT TRUE,
is_hidden BOOLEAN DEFAULT FALSE,
is_deleted BOOLEAN DEFAULT FALSE,
FOREIGN KEY (conversation_id) REFERENCES conversations(id),
FOREIGN KEY (previous_message_id) REFERENCES messages(id)
FOREIGN KEY (previous_message_id) REFERENCES messages(id),
CONSTRAINT check_token_length CHECK (token_length >= 0)
);

CREATE TABLE attachments (
Expand All @@ -64,6 +68,17 @@ CREATE TABLE attachments (
CHECK ((file_uri IS NULL) != (file_data IS NULL))
);

CREATE TABLE tags (
id INTEGER PRIMARY KEY,
name TEXT UNIQUE NOT NULL
);

CREATE TABLE conversation_tags (
conversation_id INTEGER REFERENCES conversations(id),
tag_id INTEGER REFERENCES tags(id),
PRIMARY KEY (conversation_id, tag_id)
);

CREATE INDEX idx_parent_conversation ON conversations(parent_conversation_id);
CREATE INDEX idx_conversation_model_identifier ON conversations(model_identifier);
CREATE INDEX idx_attachment_message ON attachments(message_id);
Expand Down
Loading

0 comments on commit cb0ec8a

Please sign in to comment.