Skip to content

Commit

Permalink
add ChatCompletionOptions to PromptInstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
aprxi committed Jul 22, 2024
1 parent eaecbde commit de8f76a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lumni/src/apps/builtin/llm/prompt/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use super::chat::{
AssistantManager, ChatSession, ConversationDatabaseStore, NewConversation,
PromptInstruction,
};
use super::server::{ModelServer, ModelServerName, ServerManager, ServerTrait};
use super::server::{ModelServer, ModelServerName, ServerTrait};
use super::session::{AppSession, TabSession};
use super::tui::{
ColorScheme, CommandLineAction, ConversationEvent, ConversationReader,
Expand Down
20 changes: 15 additions & 5 deletions lumni/src/apps/builtin/llm/prompt/src/chat/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,31 +77,32 @@ pub struct PromptInstruction {
cache: ConversationCache,
model: Option<ModelSpec>,
conversation_id: Option<ConversationId>,
completion_options: ChatCompletionOptions,
}

impl PromptInstruction {
pub fn new(
new_conversation: NewConversation,
db_conn: &ConversationDatabaseStore,
) -> Result<Self, ApplicationError> {

let mut completion_options = match new_conversation.options {
Some(opts) => {
let mut options = ChatCompletionOptions::default();
options.update(opts)?;
serde_json::to_value(options)?
options
}
None => serde_json::to_value(ChatCompletionOptions::default())?,
None => ChatCompletionOptions::default(),
};
// Update model_server in completion_options
completion_options["model_server"] =
serde_json::to_value(new_conversation.server.0)?;
completion_options.model_server = Some(new_conversation.server);

let conversation_id = if let Some(ref model) = new_conversation.model {
Some(db_conn.new_conversation(
"New Conversation",
new_conversation.parent.as_ref().map(|p| p.id),
new_conversation.parent.as_ref().map(|p| p.fork_message_id),
Some(completion_options),
Some(serde_json::to_value(&completion_options)?),
model,
)?)
} else {
Expand All @@ -112,6 +113,8 @@ impl PromptInstruction {
cache: ConversationCache::new(),
model: new_conversation.model,
conversation_id,
completion_options,

};

if let Some(conversation_id) = prompt_instruction.conversation_id {
Expand Down Expand Up @@ -161,10 +164,17 @@ impl PromptInstruction {
.get_model_spec()
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

let completion_options = reader
.get_completion_options()
.map_err(|e| ApplicationError::DatabaseError(e.to_string()))?;

let completion_options: ChatCompletionOptions = serde_json::from_value(completion_options)?;

let mut prompt_instruction = PromptInstruction {
cache: ConversationCache::new(),
model: Some(model_spec),
conversation_id: Some(conversation_id),
completion_options,
};

prompt_instruction
Expand Down
18 changes: 9 additions & 9 deletions lumni/src/apps/builtin/llm/prompt/src/chat/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@ use super::{ModelServerName, DEFAULT_N_PREDICT, DEFAULT_TEMPERATURE};
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatCompletionOptions {
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
n_keep: Option<usize>,
pub n_keep: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
n_predict: Option<u32>,
pub n_predict: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
cache_prompt: Option<bool>,
pub cache_prompt: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
stop: Option<Vec<String>>,
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
stream: Option<bool>,
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
model_server: Option<ModelServerName>,
pub model_server: Option<ModelServerName>,
}

impl Default for ChatCompletionOptions {
Expand Down

0 comments on commit de8f76a

Please sign in to comment.