Initial OpenAI support, not yet complete
aprxi committed Jul 1, 2024
1 parent b9813a0 commit a5b070e
Showing 7 changed files with 350 additions and 2 deletions.
1 change: 0 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/chat/send.rs
@@ -23,7 +23,6 @@ pub async fn http_post(
"application/json".to_string(),
)])
};

let payload_bytes = Bytes::from(payload.into_bytes());
tokio::spawn(async move {
match http_client
27 changes: 26 additions & 1 deletion lumni/src/apps/builtin/llm/prompt/src/server/mod.rs
@@ -1,4 +1,5 @@
mod bedrock;
mod openai;
mod endpoints;
mod llama;
mod llm;
@@ -13,6 +14,7 @@ pub use llm::LLMDefinition;
use lumni::api::error::ApplicationError;
pub use lumni::HttpClient;
pub use ollama::Ollama;
pub use openai::OpenAI;
use tokio::sync::{mpsc, oneshot};

pub use super::chat::{
@@ -24,12 +26,13 @@ pub use super::defaults::*;
pub use super::model::{ModelFormatter, ModelFormatterTrait, PromptRole};
use crate::external as lumni;

-pub const SUPPORTED_MODEL_ENDPOINTS: [&str; 3] = ["llama", "ollama", "bedrock"];
+pub const SUPPORTED_MODEL_ENDPOINTS: [&str; 4] = ["llama", "ollama", "bedrock", "openai"];

pub enum ModelServer {
Llama(Llama),
Ollama(Ollama),
Bedrock(Bedrock),
OpenAI(OpenAI),
}

impl ModelServer {
@@ -48,6 +51,11 @@ impl ModelServer {
ApplicationError::ServerConfigurationError(e.to_string())
})?))
}
"openai" => {
Ok(ModelServer::OpenAI(OpenAI::new().map_err(|e| {
ApplicationError::ServerConfigurationError(e.to_string())
})?))
}
_ => Err(ApplicationError::NotImplemented(format!(
"{}. Supported server types: {:?}",
s, SUPPORTED_MODEL_ENDPOINTS
@@ -79,6 +87,11 @@ impl ServerTrait for ModelServer {
.initialize_with_model(model, prompt_instruction)
.await
}
ModelServer::OpenAI(openai) => {
openai
.initialize_with_model(model, prompt_instruction)
.await
}
}
}

@@ -90,6 +103,7 @@ impl ServerTrait for ModelServer {
ModelServer::Llama(llama) => llama.process_response(response),
ModelServer::Ollama(ollama) => ollama.process_response(response),
ModelServer::Bedrock(bedrock) => bedrock.process_response(response),
ModelServer::OpenAI(openai) => openai.process_response(response),
}
}

@@ -107,6 +121,9 @@ impl ServerTrait for ModelServer {
ModelServer::Bedrock(bedrock) => {
bedrock.get_context_size(prompt_instruction).await
}
ModelServer::OpenAI(openai) => {
openai.get_context_size(prompt_instruction).await
}
}
}

@@ -118,6 +135,7 @@ impl ServerTrait for ModelServer {
ModelServer::Llama(llama) => llama.tokenizer(content).await,
ModelServer::Ollama(ollama) => ollama.tokenizer(content).await,
ModelServer::Bedrock(bedrock) => bedrock.tokenizer(content).await,
ModelServer::OpenAI(openai) => openai.tokenizer(content).await,
}
}

@@ -144,6 +162,11 @@ impl ServerTrait for ModelServer {
.completion(exchanges, prompt_instruction, tx, cancel_rx)
.await
}
ModelServer::OpenAI(openai) => {
openai
.completion(exchanges, prompt_instruction, tx, cancel_rx)
.await
}
}
}

@@ -154,6 +177,7 @@ impl ServerTrait for ModelServer {
ModelServer::Llama(llama) => llama.list_models().await,
ModelServer::Ollama(ollama) => ollama.list_models().await,
ModelServer::Bedrock(bedrock) => bedrock.list_models().await,
ModelServer::OpenAI(openai) => openai.list_models().await,
}
}

@@ -162,6 +186,7 @@ impl ServerTrait for ModelServer {
ModelServer::Llama(llama) => llama.get_model(),
ModelServer::Ollama(ollama) => ollama.get_model(),
ModelServer::Bedrock(bedrock) => bedrock.get_model(),
ModelServer::OpenAI(openai) => openai.get_model(),
}
}
}
25 changes: 25 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/server/openai/credentials.rs
@@ -0,0 +1,25 @@
use std::env;

use lumni::api::error::ApplicationError;
pub use crate::external as lumni;


#[derive(Clone)]
pub struct OpenAICredentials {
api_key: String,
}

impl OpenAICredentials {
pub fn from_env() -> Result<OpenAICredentials, ApplicationError> {
let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
ApplicationError::InvalidCredentials(
"OPENAI_API_KEY not found in environment".to_string(),
)
})?;
Ok(OpenAICredentials { api_key })
}

pub fn get_api_key(&self) -> &str {
&self.api_key
}
}
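
A minimal sketch of a unit test that could sit alongside credentials.rs, assuming it is acceptable for the test to mutate the process environment (env-var tests can interfere with each other when run in parallel):

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reads_api_key_from_env() {
        // Placeholder key, set only for the duration of this test.
        std::env::set_var("OPENAI_API_KEY", "sk-test-placeholder");
        let credentials =
            OpenAICredentials::from_env().expect("OPENAI_API_KEY should be set");
        assert_eq!(credentials.get_api_key(), "sk-test-placeholder");

        // Without the variable, from_env is expected to return an error.
        std::env::remove_var("OPENAI_API_KEY");
        assert!(OpenAICredentials::from_env().is_err());
    }
}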
16 changes: 16 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/server/openai/error.rs
@@ -0,0 +1,16 @@
use lumni::{HttpClientError, HttpClientErrorHandler, HttpClientResponse};

pub use crate::external as lumni;

pub struct OpenAIErrorHandler;

impl HttpClientErrorHandler for OpenAIErrorHandler {
fn handle_error(
&self,
response: HttpClientResponse,
canonical_reason: String,
) -> HttpClientError {
// Fallback if no special handling is needed
HttpClientError::HttpError(response.status_code(), canonical_reason)
}
}
171 changes: 171 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/server/openai/mod.rs
@@ -0,0 +1,171 @@
mod error;
mod request;
mod response;
mod credentials;

use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use error::OpenAIErrorHandler;
use lumni::api::error::ApplicationError;
use lumni::HttpClient;
use tokio::sync::{mpsc, oneshot};
use url::Url;

use super::{
http_post, ChatExchange, ChatHistory, ChatMessage, Endpoints,
LLMDefinition, PromptInstruction, ServerTrait,
};
use credentials::OpenAICredentials;
use request::OpenAIRequestPayload;
use response::OpenAIResponsePayload;

pub use crate::external as lumni;

pub struct OpenAI {
http_client: HttpClient,
endpoints: Endpoints,
model: Option<LLMDefinition>,
}

const OPENAI_COMPLETION_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";


impl OpenAI {
pub fn new() -> Result<Self, Box<dyn Error>> {
let endpoints = Endpoints::new()
.set_completion(Url::parse(OPENAI_COMPLETION_ENDPOINT)?);

Ok(OpenAI {
http_client: HttpClient::new()
.with_error_handler(Arc::new(OpenAIErrorHandler)),
endpoints,
model: None,
})
}

fn completion_api_payload(
&self,
model: &LLMDefinition,
exchanges: &Vec<ChatExchange>,
system_prompt: Option<&str>,
) -> Result<String, serde_json::Error> {
let messages: Vec<ChatMessage> =
ChatHistory::exchanges_to_messages(
exchanges,
system_prompt,
&|role| self.get_role_name(role),
);

let openai_request_payload = OpenAIRequestPayload {
model: model.get_name().to_string(),
messages,
stream: true,
frequency_penalty: None,
stop: None,
temperature: Some(0.7),
top_p: None,
max_tokens: None,
presence_penalty: None,
logprobs: None,
best_of: None,
};
openai_request_payload.to_json()
}
}

#[async_trait]
impl ServerTrait for OpenAI {
async fn initialize_with_model(
&mut self,
model: LLMDefinition,
_prompt_instruction: &PromptInstruction,
) -> Result<(), ApplicationError> {
self.model = Some(model);
Ok(())
}

fn get_model(&self) -> Option<&LLMDefinition> {
self.model.as_ref()
}

fn process_response(
&self,
response_bytes: Bytes,
) -> (Option<String>, bool, Option<usize>) {
// TODO: OpenAI sends back split responses, which we need to concatenate first
match OpenAIResponsePayload::extract_content(response_bytes) {
Ok(chat) => {
let choices = chat.choices;
if choices.is_empty() {
return (None, false, None);
}
let chat_message = &choices[0];
let delta = &chat_message.delta;
// if finish_reason is present, then always stop
let stop = chat_message.finish_reason.is_some();
(delta.content.clone(), stop, None)
}
Err(e) => {
(Some(format!("Failed to parse JSON: {}", e)), true, None)
}
}
}

async fn completion(
&self,
exchanges: &Vec<ChatExchange>,
prompt_instruction: &PromptInstruction,
tx: Option<mpsc::Sender<Bytes>>,
cancel_rx: Option<oneshot::Receiver<()>>,
) -> Result<(), ApplicationError> {
let model = self.get_selected_model()?;
let system_prompt = prompt_instruction.get_instruction();

let completion_endpoint = self.endpoints.get_completion_endpoint()?;
let data_payload = self
.completion_api_payload(model, exchanges, Some(system_prompt))
.map_err(|e| {
ApplicationError::InvalidUserConfiguration(e.to_string())
})?;

let credentials = OpenAICredentials::from_env()?;

let mut headers = HashMap::new();
headers.insert(
"Content-Type".to_string(),
"application/json".to_string(),
);
headers.insert(
"Authorization".to_string(),
format!("Bearer {}", credentials.get_api_key()),
);

http_post(
completion_endpoint,
self.http_client.clone(),
tx,
data_payload,
Some(headers),
cancel_rx,
)
.await;
Ok(())
}

async fn list_models(
&self,
) -> Result<Vec<LLMDefinition>, ApplicationError> {
let model = LLMDefinition::new(
"gpt-3.5-turbo".to_string(),
);
Ok(vec![model])
}
}
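
response.rs is not expanded in this view, so the following is only a sketch of what OpenAIResponsePayload and extract_content might look like, inferred from how process_response uses them (choices[0].delta.content plus finish_reason) and from the general shape of OpenAI's streamed chat-completion chunks; the field names, the `data: ` prefix handling, and the serde_json error type are assumptions, not the actual file contents:

// Sketch only: assumed shape for a streamed chat-completion chunk.
use bytes::Bytes;
use serde::Deserialize;

#[derive(Deserialize)]
pub struct OpenAIResponsePayload {
    pub choices: Vec<ChatCompletionChoice>,
}

#[derive(Deserialize)]
pub struct ChatCompletionChoice {
    pub delta: ChatCompletionDelta,
    pub finish_reason: Option<String>,
}

#[derive(Deserialize)]
pub struct ChatCompletionDelta {
    pub content: Option<String>,
}

impl OpenAIResponsePayload {
    // Assumes each chunk arrives as a single SSE line such as
    // `data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}`;
    // multi-line chunks and the final `data: [DONE]` marker are not handled
    // here, which is what the TODO in process_response refers to.
    pub fn extract_content(
        bytes: Bytes,
    ) -> Result<OpenAIResponsePayload, serde_json::Error> {
        let text = String::from_utf8_lossy(&bytes);
        let json = text.trim().trim_start_matches("data: ");
        serde_json::from_str(json)
    }
}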
33 changes: 33 additions & 0 deletions lumni/src/apps/builtin/llm/prompt/src/server/openai/request.rs
@@ -0,0 +1,33 @@
use serde::Serialize;

use super::ChatMessage;


#[derive(Debug, Serialize)]
pub struct OpenAIRequestPayload {
pub model: String,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>, // up to 4 stop sequences
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
}

impl OpenAIRequestPayload {
pub fn to_json(&self) -> Result<String, serde_json::Error> {
serde_json::to_string(&self)
}
}
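
Because of the skip_serializing_if attributes, unset options never reach the request body. A small standalone sketch that shows the effect, using a cut-down copy of the payload struct and an assumed two-field stand-in for ChatMessage (that type lives elsewhere in the crate):

use serde::Serialize;

// Assumed minimal stand-in for the crate's ChatMessage type.
#[derive(Debug, Serialize)]
struct ChatMessage {
    role: String,
    content: String,
}

#[derive(Debug, Serialize)]
struct OpenAIRequestPayload {
    model: String,
    messages: Vec<ChatMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

fn main() -> Result<(), serde_json::Error> {
    let payload = OpenAIRequestPayload {
        model: "gpt-3.5-turbo".to_string(),
        messages: vec![ChatMessage {
            role: "user".to_string(),
            content: "Hello".to_string(),
        }],
        stream: true,
        temperature: Some(0.7),
        max_tokens: None,
    };
    // max_tokens is None, so it is omitted from the serialized JSON entirely:
    // {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hello"}],
    //  "stream":true,"temperature":0.7}
    println!("{}", serde_json::to_string(&payload)?);
    Ok(())
}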