Skip to content

Commit

Permalink
reorder schema-files, expand on LLMModel
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 20, 2024
1 parent dca9516 commit bf4ae7c
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 120 deletions.
Original file line number Diff line number Diff line change
@@ -1,103 +1,14 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use super::PromptRole;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelIdentifier(pub String);

impl ModelIdentifier {
pub fn new(provider: &str, name: &str) -> Self {
ModelIdentifier(format!("{}::{}", provider, name))
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelServerName(pub String);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ConversationId(pub i64);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MessageId(pub i64);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AttachmentId(pub i64);


#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
pub identifier: ModelIdentifier,
pub info: Option<serde_json::Value>,
pub config: Option<serde_json::Value>,
pub context_window_size: Option<i64>,
pub input_token_limit: Option<i64>,
}

impl Model {
pub fn new(identifier: ModelIdentifier) -> Self {
Model {
identifier,
info: None,
config: None,
context_window_size: None,
input_token_limit: None,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
pub id: ConversationId,
pub name: String,
pub info: serde_json::Value,
pub model_identifier: ModelIdentifier,
pub model_server: ModelServerName,
pub parent_conversation_id: Option<ConversationId>,
pub fork_message_id: Option<MessageId>, // New field
pub completion_options: Option<serde_json::Value>,
pub created_at: i64,
pub updated_at: i64,
pub is_deleted: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub id: MessageId,
pub conversation_id: ConversationId,
pub role: PromptRole,
pub message_type: String,
pub content: String,
pub has_attachments: bool,
pub token_length: Option<i64>,
pub previous_message_id: Option<MessageId>,
pub created_at: i64,
pub is_deleted: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AttachmentData {
Uri(String),
Data(Vec<u8>),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attachment {
pub attachment_id: AttachmentId,
pub message_id: MessageId,
pub conversation_id: ConversationId,
pub data: AttachmentData, // file_uri or file_data
pub file_type: String,
pub metadata: Option<serde_json::Value>,
pub created_at: i64,
pub is_deleted: bool,
}
use super::{ModelIdentifier, LLMModel, PromptRole,
ConversationId, Message,
MessageId, AttachmentId, Attachment
};

#[derive(Debug)]
pub struct ConversationCache {
conversation_id: ConversationId,
models: HashMap<ModelIdentifier, Model>,
models: HashMap<ModelIdentifier, LLMModel>,
messages: Vec<Message>, // messages have to be ordered
attachments: HashMap<AttachmentId, Attachment>,
message_attachments: HashMap<MessageId, Vec<AttachmentId>>,
Expand Down Expand Up @@ -130,7 +41,7 @@ impl ConversationCache {
AttachmentId(self.attachments.len() as i64)
}

pub fn add_model(&mut self, model: Model) {
pub fn add_model(&mut self, model: LLMModel) {
self.models.insert(model.identifier.clone(), model);
}

Expand Down
78 changes: 78 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use serde::{Deserialize, Serialize};

mod model;
mod cache;

pub use model::LLMModel;
pub use cache::ConversationCache;

use super::PromptRole;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelIdentifier(pub String);

impl ModelIdentifier {
pub fn new(provider: &str, name: &str) -> Self {
ModelIdentifier(format!("{}::{}", provider, name))
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelServerName(pub String);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ConversationId(pub i64);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MessageId(pub i64);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AttachmentId(pub i64);


#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
pub id: ConversationId,
pub name: String,
pub info: serde_json::Value,
pub model_identifier: ModelIdentifier,
pub model_server: ModelServerName,
pub parent_conversation_id: Option<ConversationId>,
pub fork_message_id: Option<MessageId>, // New field
pub completion_options: Option<serde_json::Value>,
pub created_at: i64,
pub updated_at: i64,
pub is_deleted: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub id: MessageId,
pub conversation_id: ConversationId,
pub role: PromptRole,
pub message_type: String,
pub content: String,
pub has_attachments: bool,
pub token_length: Option<i64>,
pub previous_message_id: Option<MessageId>,
pub created_at: i64,
pub is_deleted: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AttachmentData {
Uri(String),
Data(Vec<u8>),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attachment {
pub attachment_id: AttachmentId,
pub message_id: MessageId,
pub conversation_id: ConversationId,
pub data: AttachmentData, // file_uri or file_data
pub file_type: String,
pub metadata: Option<serde_json::Value>,
pub created_at: i64,
pub is_deleted: bool,
}
115 changes: 115 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/chat/conversation/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use serde::{Deserialize, Serialize};
use super::ModelIdentifier;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMModel {
pub identifier: ModelIdentifier,
pub info: Option<serde_json::Value>,
pub config: Option<serde_json::Value>,
pub context_window_size: Option<i64>,
pub input_token_limit: Option<i64>,
}

impl LLMModel {
pub fn new(identifier: ModelIdentifier) -> Self {
LLMModel {
identifier,
info: None,
config: None,
context_window_size: None,
input_token_limit: None,
}
}

pub fn identifier(&self) -> &ModelIdentifier {
&self.identifier
}

pub fn info(&self) -> Option<&serde_json::Value> {
self.info.as_ref()
}

pub fn config(&self) -> Option<&serde_json::Value> {
self.config.as_ref()
}

pub fn context_window_size(&self) -> Option<i64> {
self.context_window_size
}

pub fn input_token_limit(&self) -> Option<i64> {
self.input_token_limit
}

pub fn set_info(&mut self, info: serde_json::Value) -> &mut Self {
self.info = Some(info);
self
}

pub fn set_config(&mut self, config: serde_json::Value) -> &mut Self {
self.config = Some(config);
self
}

pub fn set_context_window_size(&mut self, size: i64) -> &mut Self {
self.context_window_size = Some(size);
self
}

pub fn set_input_token_limit(&mut self, limit: i64) -> &mut Self {
self.input_token_limit = Some(limit);
self
}

pub fn set_config_value(&mut self, key: &str, value: serde_json::Value) -> &mut Self {
if let Some(config) = self.config.as_mut() {
if let serde_json::Value::Object(map) = config {
map.insert(key.to_string(), value);
}
} else {
let mut map = serde_json::Map::new();
map.insert(key.to_string(), value);
self.config = Some(serde_json::Value::Object(map));
}
self
}

pub fn get_config_value(&self, key: &str) -> Option<&serde_json::Value> {
self.config.as_ref().and_then(|config| {
if let serde_json::Value::Object(map) = config {
map.get(key)
} else {
None
}
})
}

pub fn set_size(&mut self, size: usize) -> &mut Self {
// model size in bytes
self.set_config_value("size", serde_json::Value::Number(size.into()))
}

pub fn get_size(&self) -> Option<usize> {
self.get_config_value("size")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
}

pub fn set_family(&mut self, family: &str) -> &mut Self {
self.set_config_value("family", serde_json::Value::String(family.to_string()))
}

pub fn get_family(&self) -> Option<&str> {
self.get_config_value("family")
.and_then(|v| v.as_str())
}

pub fn set_description(&mut self, description: &str) -> &mut Self {
self.set_config_value("description", serde_json::Value::String(description.to_string()))
}

pub fn get_description(&self) -> Option<&str> {
self.get_config_value("description")
.and_then(|v| v.as_str())
}
}
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/display.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use lumni::api::error::ApplicationError;

use super::schema::{Conversation, ConversationId, Message};
use super::conversation::{Conversation, ConversationId, Message};
use super::ConversationDatabaseStore;
pub use crate::external as lumni;

Expand Down
5 changes: 1 addition & 4 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
mod connector;
mod display;
mod reader;
mod schema;
mod store;

pub use reader::ConversationReader;
pub use schema::{ConversationCache, ConversationId, Model, ModelIdentifier, ModelServerName, Message, MessageId};
pub use super::conversation;
pub use store::ConversationDatabaseStore;

pub use super::PromptRole;
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/db/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex};
use rusqlite::{params, Error as SqliteError, OptionalExtension};

use super::connector::DatabaseConnector;
use super::schema::ConversationId;
use super::conversation::ConversationId;

pub struct ConversationReader<'a> {
conversation_id: ConversationId,
Expand Down
6 changes: 3 additions & 3 deletions lumni/src/apps/builtin/llm/prompt/src/chat/db/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use rusqlite::{params, Error as SqliteError, OptionalExtension};

use super::connector::DatabaseConnector;
use super::reader::ConversationReader;
use super::schema::{
use super::conversation::{
Attachment, AttachmentData, AttachmentId, Conversation, ConversationId,
Message, MessageId, Model, ModelIdentifier, ModelServerName,
Message, MessageId, LLMModel, ModelIdentifier, ModelServerName,
};

pub struct ConversationDatabaseStore {
Expand All @@ -34,7 +34,7 @@ impl ConversationDatabaseStore {
parent_id: Option<ConversationId>,
fork_message_id: Option<MessageId>,
completion_options: Option<serde_json::Value>,
model: Model,
model: LLMModel,
model_server: ModelServerName,
) -> Result<ConversationId, SqliteError> {
let mut db = self.db.lock().unwrap();
Expand Down
11 changes: 6 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use lumni::api::error::ApplicationError;

use super::db::{
ConversationCache, ConversationDatabaseStore, ConversationId,
Message, MessageId, Model, ModelIdentifier, ModelServerName,
use super::db::ConversationDatabaseStore;
use super::conversation::{
ConversationCache, ConversationId,
Message, MessageId, LLMModel, ModelIdentifier, ModelServerName,
};
use super::prompt::Prompt;
use super::{ChatCompletionOptions, ChatMessage, PromptRole, PERSONAS};
Expand Down Expand Up @@ -35,7 +36,7 @@ impl PromptInstruction {
None => serde_json::to_value(ChatCompletionOptions::default())?,
};
// Create a new Conversation in the database
let model = Model::new(
let model = LLMModel::new(
ModelIdentifier::new("foo-provider", "bar-model"),
);

Expand Down Expand Up @@ -141,7 +142,7 @@ impl PromptInstruction {
) -> Result<(), ApplicationError> {
// reset by creating a new conversation
// TODO: clone previous conversation settings
let model = Model::new(
let model = LLMModel::new(
ModelIdentifier::new("foo-provider", "bar-model"),
);

Expand Down
Loading

0 comments on commit bf4ae7c

Please sign in to comment.