diff --git a/README.md b/README.md index 649cc92..8136745 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,8 @@ $ export OPENAI_API_KEY=sk-xxxxxxx ### Create client ```rust -let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); +let client = OpenAIClient::builder().with_api_key(api_key).build()?; ``` ### Create request @@ -57,7 +58,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = ChatCompletionRequest::new( GPT4_O.to_string(), diff --git a/examples/assistant.rs b/examples/assistant.rs index 133922e..2e212b8 100644 --- a/examples/assistant.rs +++ b/examples/assistant.rs @@ -9,7 +9,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut tools = HashMap::new(); tools.insert("type".to_string(), "code_interpreter".to_string()); diff --git a/examples/audio_speech.rs b/examples/audio_speech.rs index c6541d1..9e3af08 100644 --- a/examples/audio_speech.rs +++ b/examples/audio_speech.rs @@ -4,7 +4,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = AudioSpeechRequest::new( TTS_1.to_string(), diff --git a/examples/audio_transcriptions.rs b/examples/audio_transcriptions.rs index 5a495c8..b2c7ced 100644 --- a/examples/audio_transcriptions.rs +++ b/examples/audio_transcriptions.rs @@ -4,7 +4,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = AudioTranscriptionRequest::new( "examples/data/problem.mp3".to_string(), diff --git a/examples/audio_translations.rs b/examples/audio_translations.rs index 89bf87c..13b57e0 100644 --- a/examples/audio_translations.rs +++ b/examples/audio_translations.rs @@ -4,7 +4,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = AudioTranslationRequest::new( "examples/data/problem_cn.mp3".to_string(), diff --git a/examples/batch.rs b/examples/batch.rs index 0924d4f..f14617e 100644 --- a/examples/batch.rs +++ b/examples/batch.rs @@ -9,7 +9,8 @@ use std::str; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = FileUploadRequest::new( "examples/data/batch_request.json".to_string(), diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index 95a6e7d..7a5791c 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -5,7 +5,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = ChatCompletionRequest::new( GPT4_O_MINI.to_string(), diff --git a/examples/completion.rs b/examples/completion.rs index e0fab80..138c1fe 100644 --- a/examples/completion.rs +++ b/examples/completion.rs @@ -4,7 +4,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = CompletionRequest::new( completion::GPT3_TEXT_DAVINCI_003.to_string(), diff --git a/examples/data/problem.mp3 b/examples/data/problem.mp3 index 9fce89a..4f524d9 100644 Binary files a/examples/data/problem.mp3 and b/examples/data/problem.mp3 differ diff --git a/examples/embedding.rs b/examples/embedding.rs index f93a153..8615bdb 100644 --- a/examples/embedding.rs +++ b/examples/embedding.rs @@ -5,7 +5,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut req = EmbeddingRequest::new( TEXT_EMBEDDING_3_SMALL.to_string(), diff --git a/examples/function_call.rs b/examples/function_call.rs index 95d263f..1858465 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -17,7 +17,8 @@ fn get_coin_price(coin: &str) -> f64 { #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut properties = HashMap::new(); properties.insert( diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index 613cd4e..1afc91a 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -17,7 +17,8 @@ fn get_coin_price(coin: &str) -> f64 { #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut properties = HashMap::new(); properties.insert( diff --git a/examples/vision.rs b/examples/vision.rs index cbf9e95..6a92c43 100644 --- a/examples/vision.rs +++ b/examples/vision.rs @@ -5,7 +5,8 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { - let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = ChatCompletionRequest::new( GPT4_O.to_string(), diff --git a/src/v1/api.rs b/src/v1/api.rs index d530ecf..c1b5592 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -38,11 +38,13 @@ use crate::v1::run::{ use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject}; use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::multipart::{Form, Part}; use reqwest::{Client, Method, Response}; use serde::Serialize; use serde_json::Value; +use std::error::Error; use std::fs::{create_dir_all, File}; use std::io::Read; use std::io::Write; @@ -50,61 +52,84 @@ use std::path::Path; const API_URL_V1: &str = "https://api.openai.com/v1"; +#[derive(Default)] +pub struct OpenAIClientBuilder { + api_endpoint: Option, + api_key: Option, + organization: Option, + proxy: Option, + timeout: Option, + headers: Option, +} + pub struct OpenAIClient { - pub api_endpoint: String, - pub api_key: String, - pub organization: Option, - pub proxy: Option, - pub timeout: Option, + api_endpoint: String, + api_key: String, + organization: Option, + proxy: Option, + timeout: Option, + headers: Option, } -impl OpenAIClient { - pub fn new(api_key: String) -> Self { - let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()); - Self::new_with_endpoint(endpoint, api_key) +impl OpenAIClientBuilder { + pub fn new() -> Self { + Self::default() } - pub fn new_with_endpoint(api_endpoint: String, api_key: String) -> Self { - Self { - api_endpoint, - api_key, - organization: None, - proxy: None, - timeout: None, - } + pub fn with_api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self } - pub fn new_with_organization(api_key: String, organization: String) -> Self { - let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()); - Self { - api_endpoint: endpoint, - api_key, - organization: Some(organization), - proxy: None, - timeout: None, - } + pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { + self.api_endpoint = Some(endpoint.into()); + self } - pub fn new_with_proxy(api_key: String, proxy: String) -> Self { - let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()); - Self { - api_endpoint: endpoint, - api_key, - organization: None, - proxy: Some(proxy), - timeout: None, - } + pub fn with_organization(mut self, organization: impl Into) -> Self { + self.organization = Some(organization.into()); + self + } + + pub fn with_proxy(mut self, proxy: impl Into) -> Self { + self.proxy = Some(proxy.into()); + self + } + + pub fn with_timeout(mut self, timeout: u64) -> Self { + self.timeout = Some(timeout); + self } - pub fn new_with_timeout(api_key: String, timeout: u64) -> Self { - let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()); - Self { - api_endpoint: endpoint, + pub fn with_header(mut self, key: impl Into, value: impl Into) -> Self { + let headers = self.headers.get_or_insert_with(HeaderMap::new); + headers.insert( + HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"), + HeaderValue::from_str(&value.into()).expect("Invalid header value"), + ); + self + } + + pub fn build(self) -> Result> { + let api_key = self.api_key.ok_or("API key is required")?; + let api_endpoint = self.api_endpoint.unwrap_or_else(|| { + std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()) + }); + + Ok(OpenAIClient { + api_endpoint, api_key, - organization: None, - proxy: None, - timeout: Some(timeout), - } + organization: self.organization, + proxy: self.proxy, + timeout: self.timeout, + headers: self.headers, + }) + } +} + +impl OpenAIClient { + pub fn builder() -> OpenAIClientBuilder { + OpenAIClientBuilder::new() } async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder { @@ -127,13 +152,18 @@ impl OpenAIClient { let mut request = client .request(method, url) - // .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", self.api_key)); if let Some(organization) = &self.organization { request = request.header("openai-organization", organization); } + if let Some(headers) = &self.headers { + for (key, value) in headers { + request = request.header(key, value); + } + } + if Self::is_beta(path) { request = request.header("OpenAI-Beta", "assistants=v2"); }