From 3d532d344460c972142b06dbd71ceed81035aab3 Mon Sep 17 00:00:00 2001 From: Dan Nelson <41842184+danwritecode@users.noreply.github.com> Date: Mon, 11 Nov 2024 22:21:17 -0500 Subject: [PATCH] refactor!: Async Implementation (#13) BREAKING --- .github/workflows/test.yaml | 24 +-- Cargo.lock | 40 ++++- Cargo.toml | 9 +- README.md | 18 +-- src/lib.rs | 66 ++++---- src/v1/api.rs | 118 -------------- src/v2/api.rs | 107 +++++++++++++ src/{v1 => v2}/client.rs | 105 +++++++------ src/{v1 => v2}/collection.rs | 228 ++++++++++++++++------------ src/{v1 => v2}/commons.rs | 0 src/{v1 => v2}/embeddings/bert.rs | 10 +- src/{v1 => v2}/embeddings/mod.rs | 7 +- src/{v1 => v2}/embeddings/openai.rs | 64 ++++---- src/{v1 => v2}/mod.rs | 0 14 files changed, 443 insertions(+), 353 deletions(-) delete mode 100644 src/v1/api.rs create mode 100644 src/v2/api.rs rename src/{v1 => v2}/client.rs (71%) rename src/{v1 => v2}/collection.rs (87%) rename src/{v1 => v2}/commons.rs (100%) rename src/{v1 => v2}/embeddings/bert.rs (87%) rename src/{v1 => v2}/embeddings/mod.rs (63%) rename src/{v1 => v2}/embeddings/openai.rs (63%) rename src/{v1 => v2}/mod.rs (100%) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cdc8e21..6c2ee96 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,20 +1,26 @@ name: "Cargo Tests" on: - pull_request: - types: - - opened - - edited - - synchronize - - reopened + pull_request: + types: + - opened + - edited + - synchronize + - reopened env: CARGO_TERM_COLOR: always + jobs: test: runs-on: ubuntu-latest + + services: + chroma: + image: 'chromadb/chroma:0.5.0' + ports: + - '8000:8000' + steps: - uses: actions/checkout@v3 - - name: ChromaDB - uses: CakeCrusher/chroma@v1.0.3 - name: Run tests - run: cargo test \ No newline at end of file + run: cargo test diff --git a/Cargo.lock b/Cargo.lock index 0ff296b..ca3ce4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,6 +54,17 @@ version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +[[package]] +name = "async-trait" +version = "0.1.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -193,14 +204,17 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chromadb" -version = "0.4.4" +version = "0.5.0" dependencies = [ "anyhow", + "async-trait", "base64 0.22.0", "minreq", + "reqwest", "rust-bert", "serde", "serde_json", + "tokio", ] [[package]] @@ -1071,9 +1085,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -1086,9 +1100,9 @@ checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" [[package]] name = "quote" -version = "1.0.33" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -1542,9 +1556,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" -version = "2.0.29" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1657,9 +1671,21 @@ dependencies = [ "num_cpus", "pin-project-lite", "socket2 0.5.3", + "tokio-macros", "windows-sys 0.48.0", ] +[[package]] +name = "tokio-macros" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-native-tls" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index da4aaa5..aca1567 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chromadb" -authors = ["Anush008"] +authors = ["Anush008", "danwritecode"] description = "A Rust client library for the ChromaDB vector database." edition = "2021" license = "MIT" @@ -30,6 +30,13 @@ version = "0.21" features = ["download-libtorch"] optional = true +[dependencies] +async-trait = "0.1.83" +reqwest = { version = "0.11", features = ["json"] } + +[dev-dependencies] +tokio = { version = "1.0", features = ["rt", "macros"] } + [features] bert = ["dep:rust-bert"] openai = [] diff --git a/README.md b/README.md index 6110092..bbfa591 100644 --- a/README.md +++ b/README.md @@ -39,8 +39,8 @@ The library reference can be found [here](https://docs.rs/chromadb). #### You can connect to ChromaDB by instantiating a [ChromaClient](https://docs.rs/chromadb/latest/chromadb/v1/client/struct.ChromaClient.html) ```rust -use chromadb::v1::ChromaClient; -use chromadb::v1::collection::{ChromaCollection, GetQuery, GetResult, CollectionEntries}; +use chromadb::v2::ChromaClient; +use chromadb::v2::collection::{ChromaCollection, GetQuery, GetResult, CollectionEntries}; use serde_json::json; // With default ChromaClientOptions @@ -55,7 +55,7 @@ let client: ChromaClient = ChromaClient::new(ChromaClientOptions { url: " anyhow::Result<()> { @@ -16,7 +16,14 @@ //! let client: ChromaClient = ChromaClient::new(Default::default()); //! //! // With custom ChromaClientOptions -//! let client: ChromaClient = ChromaClient::new(ChromaClientOptions { url: "".into() }); +//! let auth = ChromaAuthMethod::TokenAuth { +//! token: "".to_string(), +//! header: chromadb::v2::client::ChromaTokenHeader::Authorization +//! }; +//! let client: ChromaClient = ChromaClient::new(ChromaClientOptions { +//! url: "".into(), +//! auth +//! }); //! //! # Ok(()) //! # } @@ -26,12 +33,12 @@ //! ### Collection Queries //! //! ``` -//!# use chromadb::v1::ChromaClient; -//!# use chromadb::v1::collection::{ChromaCollection, GetResult, CollectionEntries, GetOptions}; +//!# use chromadb::v2::ChromaClient; +//!# use chromadb::v2::collection::{ChromaCollection, GetResult, CollectionEntries, GetOptions}; //!# use serde_json::json; -//!# fn doc_client_create_collection(client: &ChromaClient) -> anyhow::Result<()> { +//!# async fn doc_client_create_collection(client: &ChromaClient) -> anyhow::Result<()> { //! // Get or create a collection with the given name and no metadata. -//! let collection: ChromaCollection = client.get_or_create_collection("my_collection", None)?; +//! let collection: ChromaCollection = client.get_or_create_collection("my_collection", None).await?; //! //! // Get the UUID of the collection //! let collection_uuid = collection.id(); @@ -48,7 +55,7 @@ //! ]) //! }; //! -//! let result: bool = collection.upsert(collection_entries, None)?; +//! let result = collection.upsert(collection_entries, None).await?; //! //! // Create a filter object to filter by document content. //! let where_document = json!({ @@ -67,19 +74,19 @@ //! include: Some(vec!["documents".into(),"embeddings".into()]) //! }; //! -//! let get_result: GetResult = collection.get(get_query)?; +//! let get_result: GetResult = collection.get(get_query).await?; //! println!("Get result: {:?}", get_result); //!# Ok(()) //!# } //! ``` -//!Find more information about on the available filters and options in the [get()](crate::v1::ChromaCollection::get) documentation. +//!Find more information about on the available filters and options in the [get()](crate::v2::ChromaCollection::get) documentation. //! //! //! ### Perform a similarity search. //! ``` -//!# use chromadb::v1::collection::{ChromaCollection, QueryResult, QueryOptions}; +//!# use chromadb::v2::collection::{ChromaCollection, QueryResult, QueryOptions}; //!# use serde_json::json; -//!# fn doc_query_collection(collection: &ChromaCollection) -> anyhow::Result<()> { +//!# async fn doc_query_collection(collection: &ChromaCollection) -> anyhow::Result<()> { //! //Instantiate QueryOptions to perform a similarity search on the collection //! //Alternatively, an embedding_function can also be provided with query_texts to perform the search //! let query = QueryOptions { @@ -91,7 +98,7 @@ //! include: None, //! }; //! -//! let query_result: QueryResult = collection.query(query, None)?; +//! let query_result: QueryResult = collection.query(query, None).await?; //! println!("Query result: {:?}", query_result); //!# Ok(()) //!# } @@ -103,12 +110,13 @@ //! To use [OpenAI](https://platform.openai.com/docs/guides/embeddings) embeddings, enable the `openai` feature in your Cargo.toml. //! //! ```ignore -//!# use chromadb::v1::ChromaClient; -//!# use chromadb::v1::collection::{ChromaCollection, GetResult, CollectionEntries, GetOptions}; -//!# use chromadb::v1::embeddings::openai::OpenAIEmbeddings; +//!# use chromadb::v2::ChromaClient; +//!# use chromadb::v2::collection::{ChromaCollection, GetResult, CollectionEntries, GetOptions}; +//!# use chromadb::v2::embeddings::openai::OpenAIEmbeddings; //!# use serde_json::json; -//!# fn doc_client_create_collection(client: &ChromaClient) -> anyhow::Result<()> { -//! let collection: ChromaCollection = client.get_or_create_collection("openai_collection", None)?; +//!# async fn doc_client_create_collection(client: &ChromaClient) -> anyhow::Result<()> { +//! let collection: ChromaCollection = client.get_or_create_collection("openai_collection", +//! None).await?; //! //! let collection_entries = CollectionEntries { //! ids: vec!["demo-id-1", "demo-id-2"], @@ -121,7 +129,7 @@ //! //! // Use OpenAI embeddings //! let openai_embeddings = OpenAIEmbeddings::new(Default::default()); -//! collection.upsert(collection_entries, Some(Box::new(openai_embeddings)))?; +//! collection.upsert(collection_entries, Some(Box::new(openai_embeddings))).await?; //! Ok(()) //!# } //! ``` @@ -129,12 +137,13 @@ //! To use [SBERT](https://docs.rs/crate/rust-bert/latest) embeddings, enable the `bert` feature in your Cargo.toml. //! //! ```ignore -//!# use chromadb::v1::ChromaClient; -//!# use chromadb::v1::collection::{ChromaCollection, GetResult, CollectionEntries, GetOptions}; +//!# use chromadb::v2::ChromaClient; +//!# use chromadb::v2::collection::{ChromaCollection, GetResult, CollectionEntries, GetOptions}; //!# use serde_json::json; -//!# use chromadb::v1::embeddings::bert::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType}; -//!# fn doc_client_create_collection(client: &ChromaClient) -> anyhow::Result<()> { -//! let collection: ChromaCollection = client.get_or_create_collection("sbert_collection", None)?; +//!# use chromadb::v2::embeddings::bert::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType}; +//!# async fn doc_client_create_collection(client: &ChromaClient) -> anyhow::Result<()> { +//! let collection: ChromaCollection = client.get_or_create_collection("sbert_collection", +//! None).await?; //! //! let collection_entries = CollectionEntries { //! ids: vec!["demo-id-1", "demo-id-2"], @@ -150,8 +159,9 @@ //! SentenceEmbeddingsModelType::AllMiniLmL6V2 //! ).create_model()?; //! -//! collection.upsert(collection_entries, Some(Box::new(sbert_embeddings)))?; +//! collection.upsert(collection_entries, Some(Box::new(sbert_embeddings))).await?; //!# Ok(()) //!# } //! ``` -pub mod v1; + +pub mod v2; diff --git a/src/v1/api.rs b/src/v1/api.rs deleted file mode 100644 index 2048171..0000000 --- a/src/v1/api.rs +++ /dev/null @@ -1,118 +0,0 @@ -use super::commons::Result; -use base64::prelude::*; -use minreq::Response; -use serde_json::Value; - -/// Which header to send the token if using `ChromaAuthMethod::TokenAuth`. -#[derive(Clone, Debug)] -pub enum ChromaTokenHeader { - /// Authorization: Bearer - Authorization, - /// X-Chroma-Token - XChromaToken, -} - -/// Authentication options, currently only supported in server/client mode. -#[derive(Clone, Debug)] -pub enum ChromaAuthMethod { - /// No authentication - None, - - /// Basic authentication: RFC 7617 - BasicAuth { username: String, password: String }, - - /// Static token authentication: RFC 6750 - TokenAuth { - token: String, - header: ChromaTokenHeader, - }, -} - -impl Default for ChromaAuthMethod { - fn default() -> Self { - Self::None - } -} - -#[derive(Clone, Default, Debug)] -pub(super) struct APIClientV1 { - pub(super) api_endpoint: String, - pub(super) auth_method: ChromaAuthMethod, -} - -impl APIClientV1 { - pub fn new(endpoint: String, auth_method: ChromaAuthMethod) -> Self { - Self { - api_endpoint: format!("{}/api/v1", endpoint), - auth_method, - } - } - - pub fn post(&self, path: &str, json_body: Option) -> Result { - self.send_request("POST", path, json_body) - } - - pub fn get(&self, path: &str) -> Result { - self.send_request("GET", path, None) - } - - pub fn put(&self, path: &str, json_body: Option) -> Result { - self.send_request("PUT", path, json_body) - } - - pub fn delete(&self, path: &str) -> Result { - self.send_request("DELETE", path, None) - } - - fn send_request(&self, method: &str, path: &str, json_body: Option) -> Result { - let url = format!( - "{api_endpoint}{path}", - api_endpoint = self.api_endpoint, - path = path - ); - - let request = match method { - "POST" => minreq::post(url), - "PUT" => minreq::put(url), - "DELETE" => minreq::delete(url), - _ => minreq::get(url), - }; - - let request = if let Some(body) = json_body { - request - .with_header("Content-Type", "application/json") - .with_json(&body)? - } else { - request - }; - - let request = match &self.auth_method { - ChromaAuthMethod::None => request, - ChromaAuthMethod::BasicAuth { username, password } => { - let credentials = BASE64_STANDARD.encode(format!("{username}:{password}")); - request.with_header("Authorization", format!("Basic {credentials}")) - } - ChromaAuthMethod::TokenAuth { - token, - header: token_header, - } => match token_header { - ChromaTokenHeader::Authorization => { - request.with_header("Authorization", format!("Bearer {token}")) - } - ChromaTokenHeader::XChromaToken => request.with_header("X-Chroma-Token", token), - }, - }; - - let res = request.send()?; - - match res.status_code { - 200..=299 => Ok(res), - _ => anyhow::bail!( - "{} {}: {}", - res.status_code, - res.reason_phrase, - res.as_str().unwrap() - ), - } - } -} diff --git a/src/v2/api.rs b/src/v2/api.rs new file mode 100644 index 0000000..dd36787 --- /dev/null +++ b/src/v2/api.rs @@ -0,0 +1,107 @@ +use super::commons::Result; +use base64::prelude::*; +use reqwest::{Client, Response, Method}; +use serde_json::Value; + +#[derive(Clone, Debug)] +pub enum ChromaTokenHeader { + Authorization, + XChromaToken, +} + +#[derive(Clone, Debug)] +pub enum ChromaAuthMethod { + None, + BasicAuth { username: String, password: String }, + TokenAuth { + token: String, + header: ChromaTokenHeader, + }, +} + +impl Default for ChromaAuthMethod { + fn default() -> Self { + Self::None + } +} + +#[derive(Clone, Default, Debug)] +pub(super) struct APIClientAsync { + pub(super) api_endpoint: String, + pub(super) auth_method: ChromaAuthMethod, + client: Client, +} + +impl APIClientAsync { + pub fn new(endpoint: String, auth_method: ChromaAuthMethod) -> Self { + Self { + api_endpoint: format!("{}/api/v1", endpoint), + auth_method, + client: Client::new(), + } + } + + pub async fn post(&self, path: &str, json_body: Option) -> Result { + self.send_request(Method::POST, path, json_body).await + } + + pub async fn get(&self, path: &str) -> Result { + self.send_request(Method::GET, path, None).await + } + + pub async fn put(&self, path: &str, json_body: Option) -> Result { + self.send_request(Method::PUT, path, json_body).await + } + + pub async fn delete(&self, path: &str) -> Result { + self.send_request(Method::DELETE, path, None).await + } + + async fn send_request(&self, method: Method, path: &str, json_body: Option) -> Result { + let url = format!("{}{}", self.api_endpoint, path); + + let mut request = self.client + .request(method, &url); + + // Add auth headers if needed + match &self.auth_method { + ChromaAuthMethod::None => {}, + ChromaAuthMethod::BasicAuth { username, password } => { + let credentials = BASE64_STANDARD.encode(format!("{username}:{password}")); + request = request.header("Authorization", format!("Basic {credentials}")); + } + ChromaAuthMethod::TokenAuth { token, header } => { + match header { + ChromaTokenHeader::Authorization => { + request = request.header("Authorization", format!("Bearer {token}")); + } + ChromaTokenHeader::XChromaToken => { + request = request.header("X-Chroma-Token", token); + } + } + } + } + + // Add JSON body if present + if let Some(body) = json_body { + request = request + .header("Content-Type", "application/json") + .json(&body); + } + + let response = request.send().await?; + let status = response.status(); + + if status.is_success() { + Ok(response) + } else { + let error_text = response.text().await?; + anyhow::bail!( + "{} {}: {}", + status.as_u16(), + status.canonical_reason().unwrap_or("Unknown"), + error_text + ) + } + } +} diff --git a/src/v1/client.rs b/src/v2/client.rs similarity index 71% rename from src/v1/client.rs rename to src/v2/client.rs index 00f84df..7d9d9dc 100644 --- a/src/v1/client.rs +++ b/src/v2/client.rs @@ -2,7 +2,7 @@ use std::sync::Arc; pub use super::api::{ChromaAuthMethod, ChromaTokenHeader}; use super::{ - api::APIClientV1, + api::APIClientAsync, commons::{Metadata, Result}, ChromaCollection, }; @@ -14,7 +14,7 @@ const DEFAULT_ENDPOINT: &str = "http://localhost:8000"; // A client representation for interacting with ChromaDB. pub struct ChromaClient { - api: Arc, + api: Arc, } /// The options for instantiating ChromaClient. @@ -35,7 +35,7 @@ impl ChromaClient { }; ChromaClient { - api: Arc::new(APIClientV1::new(endpoint, auth)), + api: Arc::new(APIClientAsync::new(endpoint, auth)), } } @@ -51,7 +51,7 @@ impl ChromaClient { /// /// * If the collection already exists and get_or_create is false /// * If the collection name is invalid - pub fn create_collection( + pub async fn create_collection( &self, name: &str, metadata: Option, @@ -62,8 +62,8 @@ impl ChromaClient { "metadata": metadata, "get_or_create": get_or_create, }); - let response = self.api.post("/collections", Some(request_body))?; - let mut collection = response.json::()?; + let response = self.api.post("/collections", Some(request_body)).await?; + let mut collection = response.json::().await?; collection.api = self.api.clone(); Ok(collection) } @@ -78,18 +78,18 @@ impl ChromaClient { /// # Errors /// /// * If the collection name is invalid - pub fn get_or_create_collection( + pub async fn get_or_create_collection( &self, name: &str, metadata: Option, ) -> Result { - self.create_collection(name, metadata, true) + self.create_collection(name, metadata, true).await } /// List all collections - pub fn list_collections(&self) -> Result> { - let response = self.api.get("/collections")?; - let collections = response.json::>()?; + pub async fn list_collections(&self) -> Result> { + let response = self.api.get("/collections").await?; + let collections = response.json::>().await?; let collections = collections .into_iter() .map(|mut collection| { @@ -110,9 +110,9 @@ impl ChromaClient { /// /// * If the collection name is invalid /// * If the collection does not exist - pub fn get_collection(&self, name: &str) -> Result { - let response = self.api.get(&format!("/collections/{}", name))?; - let mut collection = response.json::()?; + pub async fn get_collection(&self, name: &str) -> Result { + let response = self.api.get(&format!("/collections/{}", name)).await?; + let mut collection = response.json::().await?; collection.api = self.api.clone(); Ok(collection) } @@ -127,29 +127,29 @@ impl ChromaClient { /// /// * If the collection name is invalid /// * If the collection does not exist - pub fn delete_collection(&self, name: &str) -> Result<()> { - self.api.delete(&format!("/collections/{}", name))?; + pub async fn delete_collection(&self, name: &str) -> Result<()> { + self.api.delete(&format!("/collections/{}", name)).await?; Ok(()) } /// Resets the database. This will delete all collections and entries. - pub fn reset(&self) -> Result { - let respones = self.api.post("/reset", None)?; - let result = respones.json::()?; + pub async fn reset(&self) -> Result { + let respones = self.api.post("/reset", None).await?; + let result = respones.json::().await?; Ok(result) } /// The version of Chroma - pub fn version(&self) -> Result { - let response = self.api.get("/version")?; - let version = response.json::()?; + pub async fn version(&self) -> Result { + let response = self.api.get("/version").await?; + let version = response.json::().await?; Ok(version) } /// Get the current time in nanoseconds since epoch. Used to check if the server is alive. - pub fn heartbeat(&self) -> Result { - let response = self.api.get("/heartbeat")?; - let json = response.json::()?; + pub async fn heartbeat(&self) -> Result { + let response = self.api.get("/heartbeat").await?; + let json = response.json::().await?; Ok(json.heartbeat) } } @@ -163,74 +163,85 @@ struct HeartbeatResponse { #[cfg(test)] mod tests { use super::*; + use tokio; + const TEST_COLLECTION: &str = "8-recipies-for-octopus"; - #[test] - fn test_heartbeat() { + #[tokio::test] + async fn test_heartbeat() { let client: ChromaClient = ChromaClient::new(Default::default()); - let heartbeat = client.heartbeat().unwrap(); + let heartbeat = client.heartbeat().await.unwrap(); assert!(heartbeat > 0); } - #[test] - fn test_version() { + #[tokio::test] + async fn test_version() { let client: ChromaClient = ChromaClient::new(Default::default()); - let version = client.version().unwrap(); + let version = client.version().await.unwrap(); assert_eq!(version.split('.').count(), 3); } - #[test] - fn test_reset() { + #[tokio::test] + async fn test_reset() { let client: ChromaClient = ChromaClient::new(Default::default()); - let result = client.reset(); + let result = client.reset().await; assert!(result.is_err_and(|e| e .to_string() .contains("Resetting is not allowed by this configuration"))); } - #[test] - fn test_create_collection() { + #[tokio::test] + async fn test_create_collection() { let client: ChromaClient = ChromaClient::new(Default::default()); let result = client .create_collection(TEST_COLLECTION, None, true) + .await .unwrap(); assert_eq!(result.name(), TEST_COLLECTION); } - #[test] - fn test_get_collection() { + #[tokio::test] + async fn test_get_collection() { let client: ChromaClient = ChromaClient::new(Default::default()); + + const GET_TEST_COLLECTION: &str = "100-recipes-for-octopus"; + + client + .create_collection(GET_TEST_COLLECTION, None, true) + .await + .unwrap(); - let collection = client.get_collection(TEST_COLLECTION).unwrap(); - assert_eq!(collection.name(), TEST_COLLECTION); + let collection = client.get_collection(GET_TEST_COLLECTION).await.unwrap(); + assert_eq!(collection.name(), GET_TEST_COLLECTION); } - #[test] - fn test_list_collection() { + #[tokio::test] + async fn test_list_collection() { let client: ChromaClient = ChromaClient::new(Default::default()); - let result = client.list_collections().unwrap(); + let result = client.list_collections().await.unwrap(); assert!(result.len() > 0); } - #[test] - fn test_delete_collection() { + #[tokio::test] + async fn test_delete_collection() { let client: ChromaClient = ChromaClient::new(Default::default()); const DELETE_TEST_COLLECTION: &str = "6-recipies-for-octopus"; client .get_or_create_collection(DELETE_TEST_COLLECTION, None) + .await .unwrap(); - let collection = client.delete_collection(DELETE_TEST_COLLECTION); + let collection = client.delete_collection(DELETE_TEST_COLLECTION).await; assert!(collection.is_ok()); - let collection = client.delete_collection(DELETE_TEST_COLLECTION); + let collection = client.delete_collection(DELETE_TEST_COLLECTION).await; assert!(collection.is_err()); } } diff --git a/src/v1/collection.rs b/src/v2/collection.rs similarity index 87% rename from src/v1/collection.rs rename to src/v2/collection.rs index 657c70d..41bcc2f 100644 --- a/src/v1/collection.rs +++ b/src/v2/collection.rs @@ -4,7 +4,7 @@ use serde_json::{json, Value}; use std::{collections::HashSet, sync::Arc, vec}; use super::{ - api::APIClientV1, + api::APIClientAsync, commons::{Documents, Embedding, Embeddings, Metadata, Metadatas, Result}, embeddings::EmbeddingFunction, }; @@ -13,7 +13,7 @@ use super::{ #[derive(Deserialize, Debug)] pub struct ChromaCollection { #[serde(skip)] - pub(super) api: Arc, + pub(super) api: Arc, pub(super) id: String, pub(super) metadata: Option, pub(super) name: String, @@ -36,10 +36,10 @@ impl ChromaCollection { } /// The total number of embeddings added to the database. - pub fn count(&self) -> Result { + pub async fn count(&self) -> Result { let path = format!("/collections/{}/count", self.id); - let response = self.api.get(&path)?; - let count = response.json::()?; + let response = self.api.get(&path).await?; + let count = response.json::().await?; Ok(count) } @@ -53,13 +53,13 @@ impl ChromaCollection { /// # Errors /// /// * If the collection name is invalid - pub fn modify(&self, name: Option<&str>, metadata: Option<&Metadata>) -> Result<()> { + pub async fn modify(&self, name: Option<&str>, metadata: Option<&Metadata>) -> Result<()> { let json_body = json!({ "new_name": name, "new_metadata": metadata, }); let path = format!("/collections/{}", self.id); - self.api.put(&path, Some(json_body))?; + self.api.put(&path, Some(json_body)).await?; Ok(()) } @@ -82,12 +82,12 @@ impl ChromaCollection { /// * If you provide an embedding function and don't provide documents /// * If you provide both embeddings and embedding_function /// - pub fn add( + pub async fn add<'a>( &self, - collection_entries: CollectionEntries, + collection_entries: CollectionEntries<'a>, embedding_function: Option>, ) -> Result { - let collection_entries = validate(true, collection_entries, embedding_function)?; + let collection_entries = validate(true, collection_entries, embedding_function).await?; let CollectionEntries { ids, @@ -104,8 +104,8 @@ impl ChromaCollection { }); let path = format!("/collections/{}/add", self.id); - let response = self.api.post(&path, Some(json_body))?; - let response = response.json::()?; + let response = self.api.post(&path, Some(json_body)).await?; + let response = response.json::().await?; Ok(response) } @@ -129,12 +129,12 @@ impl ChromaCollection { /// * If you provide an embedding function and don't provide documents /// * If you provide both embeddings and embedding_function /// - pub fn upsert( + pub async fn upsert<'a>( &self, - collection_entries: CollectionEntries, + collection_entries: CollectionEntries<'a>, embedding_function: Option>, ) -> Result { - let collection_entries = validate(true, collection_entries, embedding_function)?; + let collection_entries = validate(true, collection_entries, embedding_function).await?; let CollectionEntries { ids, @@ -151,8 +151,8 @@ impl ChromaCollection { }); let path = format!("/collections/{}/upsert", self.id); - let response = self.api.post(&path, Some(json_body))?; - let response = response.json::()?; + let response = self.api.post(&path, Some(json_body)).await?; + let response = response.json::().await?; Ok(response) } @@ -168,7 +168,7 @@ impl ChromaCollection { /// * `where_document` - Used to filter by the documents. E.g. {"$contains": "hello"}. See for more information on document content filters. Optional. /// * `include` - A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`. Ids are always included. Defaults to `["metadatas", "documents"]`. Optional. /// - pub fn get(&self, get_options: GetOptions) -> Result { + pub async fn get(&self, get_options: GetOptions) -> Result { let GetOptions { ids, where_metadata, @@ -192,8 +192,8 @@ impl ChromaCollection { .retain(|_, v| !v.is_null()); let path = format!("/collections/{}/get", self.id); - let response = self.api.post(&path, Some(json_body))?; - let get_result = response.json::()?; + let response = self.api.post(&path, Some(json_body)).await?; + let get_result = response.json::().await?; Ok(get_result) } @@ -215,12 +215,12 @@ impl ChromaCollection { /// * If you provide an embedding function and don't provide documents /// * If you provide both embeddings and embedding_function /// - pub fn update( + pub async fn update<'a>( &self, - collection_entries: CollectionEntries, + collection_entries: CollectionEntries<'a>, embedding_function: Option>, - ) -> Result { - let collection_entries = validate(false, collection_entries, embedding_function)?; + ) -> Result<()> { + let collection_entries = validate(false, collection_entries, embedding_function).await?; let CollectionEntries { ids, @@ -237,10 +237,14 @@ impl ChromaCollection { }); let path = format!("/collections/{}/update", self.id); - let response = self.api.post(&path, Some(json_body))?; - let response = response.json::()?; + let response = self.api.post(&path, Some(json_body)).await?; - Ok(response) + match response.error_for_status() { + Ok(_) => Ok(()), + Err(e) => { + Err(e.into()) + } + } } ///Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts. @@ -261,9 +265,9 @@ impl ChromaCollection { /// * If you provide both query_embeddings and query_texts /// * If you provide query_texts and don't provide an embedding function when embeddings is None /// - pub fn query( + pub async fn query<'a>( &self, - query_options: QueryOptions, + query_options: QueryOptions<'a>, embedding_function: Option>, ) -> Result { let QueryOptions { @@ -284,7 +288,7 @@ impl ChromaCollection { query_embeddings = Some( embedding_function .unwrap() - .embed(query_texts.as_ref().unwrap())?, + .embed(query_texts.as_ref().unwrap()).await?, ); }; @@ -302,8 +306,8 @@ impl ChromaCollection { .retain(|_, v| !v.is_null()); let path = format!("/collections/{}/query", self.id); - let response = self.api.post(&path, Some(json_body))?; - let query_result = response.json::()?; + let response = self.api.post(&path, Some(json_body)).await?; + let query_result = response.json::().await?; Ok(query_result) } @@ -313,7 +317,7 @@ impl ChromaCollection { /// /// * `limit` - The number of entries to return. /// - pub fn peek(&self, limit: usize) -> Result { + pub async fn peek(&self, limit: usize) -> Result { let get_query = GetOptions { ids: vec![], where_metadata: None, @@ -322,7 +326,7 @@ impl ChromaCollection { where_document: None, include: None, }; - self.get(get_query) + self.get(get_query).await } /// Delete the embeddings based on ids and/or a where filter. Deletes all the entries if None are provided @@ -333,12 +337,12 @@ impl ChromaCollection { /// * `where_metadata` - Used to filter deletion by metadata. E.g. {"$and": ["color" : "red", "price": {"$gte": 4.20}]}. Optional. /// * `where_document` - Used to filter the deletion by the document content. E.g. {$contains: "some text"}. Optional.. Optional. /// - pub fn delete( + pub async fn delete( &self, ids: Option>, where_metadata: Option, where_document: Option, - ) -> Result> { + ) -> Result<()> { let json_body = json!({ "ids": ids, "where": where_metadata, @@ -346,10 +350,14 @@ impl ChromaCollection { }); let path = format!("/collections/{}/delete", self.id); - let response = self.api.post(&path, Some(json_body))?; - let response = response.json::>()?; + let response = self.api.post(&path, Some(json_body)).await?; - Ok(response) + match response.error_for_status() { + Ok(_) => Ok(()), + Err(e) => { + Err(e.into()) + } + } } } @@ -398,9 +406,9 @@ pub struct CollectionEntries<'a> { pub embeddings: Option, } -fn validate( +async fn validate<'a>( require_embeddings_or_documents: bool, - collection_entries: CollectionEntries, + collection_entries: CollectionEntries<'a>, embedding_function: Option>, ) -> Result { let CollectionEntries { @@ -427,7 +435,7 @@ fn validate( embeddings = Some( embedding_function .unwrap() - .embed(documents.as_ref().unwrap())?, + .embed(documents.as_ref().unwrap()).await?, ); } @@ -467,7 +475,7 @@ fn validate( mod tests { use serde_json::json; - use crate::v1::{ + use crate::v2::{ collection::{CollectionEntries, GetOptions, QueryOptions}, embeddings::MockEmbeddingProvider, ChromaClient, @@ -475,17 +483,19 @@ mod tests { const TEST_COLLECTION: &str = "21-recipies-for-octopus"; - #[test] - fn test_modify_collection() { + #[tokio::test] + async fn test_modify_collection() { let client = ChromaClient::new(Default::default()); let collection = client .get_or_create_collection(TEST_COLLECTION, None) + .await .unwrap(); //Test for setting invalid collection name. Should fail. assert!(collection .modify(Some("new name for test collection"), None) + .await .is_err()); //Test for setting new metadata. Should pass. @@ -500,36 +510,17 @@ mod tests { .unwrap() ) ) + .await .is_ok()); } - #[test] - fn test_get_from_collection() { - let client = ChromaClient::new(Default::default()); - - let collection = client - .get_or_create_collection(TEST_COLLECTION, None) - .unwrap(); - assert!(collection.count().is_ok()); - - let get_query = GetOptions { - ids: vec![], - where_metadata: None, - limit: None, - offset: None, - where_document: None, - include: None, - }; - let get_result = collection.get(get_query).unwrap(); - assert_eq!(get_result.ids.len(), collection.count().unwrap()); - } - - #[test] - fn test_add_to_collection() { + #[tokio::test] + async fn test_add_to_collection() { let client = ChromaClient::new(Default::default()); let collection = client .get_or_create_collection(TEST_COLLECTION, None) + .await .unwrap(); let invalid_collection_entries = CollectionEntries { @@ -544,7 +535,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_err(), + response.await.is_err(), "Embeddings and documents cannot both be None" ); @@ -562,7 +553,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_err(), + response.await.is_err(), "IDs, embeddings, metadatas, and documents must all be the same length" ); @@ -580,7 +571,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_ok(), + response.await.is_ok(), "IDs, embeddings, metadatas, and documents must all be the same length" ); @@ -597,7 +588,7 @@ mod tests { invalid_collection_entries, Some(Box::new(MockEmbeddingProvider)), ); - assert!(response.is_err(), "Empty IDs not allowed"); + assert!(response.await.is_err(), "Empty IDs not allowed"); let invalid_collection_entries = CollectionEntries { ids: vec!["test".into(), "test".into()], @@ -610,7 +601,7 @@ mod tests { }; let response = collection.add(invalid_collection_entries, None); assert!( - response.is_err(), + response.await.is_err(), "Expected IDs to be unique. Duplicates not allowed" ); @@ -625,7 +616,7 @@ mod tests { }; let response = collection.add(collection_entries, None); assert!( - response.is_err(), + response.await.is_err(), "embedding_function cannot be None if documents are provided and embeddings are None" ); @@ -640,17 +631,18 @@ mod tests { }; let response = collection.add(collection_entries, Some(Box::new(MockEmbeddingProvider))); assert!( - response.is_ok(), + response.await.is_ok(), "Embeddings are computed by the embedding_function if embeddings are None and documents are provided" ); } - #[test] - fn test_upsert_collection() { + #[tokio::test] + async fn test_upsert_collection() { let client = ChromaClient::new(Default::default()); let collection = client .get_or_create_collection(TEST_COLLECTION, None) + .await .unwrap(); let invalid_collection_entries = CollectionEntries { @@ -665,7 +657,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_err(), + response.await.is_err(), "Embeddings and documents cannot both be None" ); @@ -683,7 +675,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_err(), + response.await.is_err(), "IDs, embeddings, metadatas, and documents must all be the same length" ); @@ -701,7 +693,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_ok(), + response.await.is_ok(), "IDs, embeddings, metadatas, and documents must all be the same length" ); @@ -718,7 +710,7 @@ mod tests { invalid_collection_entries, Some(Box::new(MockEmbeddingProvider)), ); - assert!(response.is_err(), "Empty IDs not allowed"); + assert!(response.await.is_err(), "Empty IDs not allowed"); let invalid_collection_entries = CollectionEntries { ids: vec!["test".into(), "test".into()], @@ -731,7 +723,7 @@ mod tests { }; let response = collection.upsert(invalid_collection_entries, None); assert!( - response.is_err(), + response.await.is_err(), "Expected IDs to be unique. Duplicates not allowed" ); @@ -746,7 +738,7 @@ mod tests { }; let response = collection.upsert(collection_entries, None); assert!( - response.is_err(), + response.await.is_err(), "embedding_function cannot be None if documents are provided and embeddings are None" ); @@ -761,17 +753,18 @@ mod tests { }; let response = collection.upsert(collection_entries, Some(Box::new(MockEmbeddingProvider))); assert!( - response.is_ok(), + response.await.is_ok(), "Embeddings are computed by the embedding_function if embeddings are None and documents are provided" ); } - #[test] - fn test_update_collection() { + #[tokio::test] + async fn test_update_collection() { let client = ChromaClient::new(Default::default()); let collection = client .get_or_create_collection(TEST_COLLECTION, None) + .await .unwrap(); let valid_collection_entries = CollectionEntries { @@ -784,7 +777,10 @@ mod tests { let response = collection.update( valid_collection_entries, Some(Box::new(MockEmbeddingProvider)), - ); + ).await; + + println!("{:?}", response); + assert!( response.is_ok(), "Embeddings and documents can both be None" @@ -804,7 +800,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_err(), + response.await.is_err(), "IDs, embeddings, metadatas, and documents must all be the same length" ); @@ -822,7 +818,7 @@ mod tests { Some(Box::new(MockEmbeddingProvider)), ); assert!( - response.is_ok(), + response.await.is_ok(), "IDs, embeddings, metadatas, and documents must all be the same length" ); @@ -839,7 +835,7 @@ mod tests { invalid_collection_entries, Some(Box::new(MockEmbeddingProvider)), ); - assert!(response.is_err(), "Empty IDs not allowed"); + assert!(response.await.is_err(), "Empty IDs not allowed"); let invalid_collection_entries = CollectionEntries { ids: vec!["test".into(), "test".into()], @@ -852,7 +848,7 @@ mod tests { }; let response = collection.update(invalid_collection_entries, None); assert!( - response.is_err(), + response.await.is_err(), "Expected IDs to be unique. Duplicates not allowed" ); @@ -867,7 +863,7 @@ mod tests { }; let response = collection.update(collection_entries, None); assert!( - response.is_err(), + response.await.is_err(), "embedding_function cannot be None if documents are provided and embeddings are None" ); @@ -882,19 +878,20 @@ mod tests { }; let response = collection.update(collection_entries, Some(Box::new(MockEmbeddingProvider))); assert!( - response.is_ok(), + response.await.is_ok(), "Embeddings are computed by the embedding_function if embeddings are None and documents are provided" ); } - #[test] - fn test_query_collection() { + #[tokio::test] + async fn test_query_collection() { let client = ChromaClient::new(Default::default()); let collection = client .get_or_create_collection(TEST_COLLECTION, None) + .await .unwrap(); - assert!(collection.count().is_ok()); + assert!(collection.count().await.is_ok()); let query = QueryOptions { query_texts: None, @@ -906,7 +903,7 @@ mod tests { }; let query_result = collection.query(query, None); assert!( - query_result.is_err(), + query_result.await.is_err(), "query_texts and query_embeddings cannot both be None" ); @@ -923,7 +920,7 @@ mod tests { }; let query_result = collection.query(query, Some(Box::new(MockEmbeddingProvider))); assert!( - query_result.is_ok(), + query_result.await.is_ok(), "query_embeddings will be computed from query_texts if embedding_function is provided" ); @@ -940,7 +937,7 @@ mod tests { }; let query_result = collection.query(query, Some(Box::new(MockEmbeddingProvider))); assert!( - query_result.is_err(), + query_result.await.is_err(), "Both query_embeddings and query_texts cannot be provided" ); @@ -954,8 +951,39 @@ mod tests { }; let query_result = collection.query(query, None); assert!( - query_result.is_ok(), + query_result.await.is_ok(), "Use provided query_embeddings if embedding_function is None" ); } + + #[tokio::test] + async fn test_delete_from_collection() { + let client = ChromaClient::new(Default::default()); + + let collection = client + .get_or_create_collection(TEST_COLLECTION, None) + .await + .unwrap(); + + let valid_collection_entries = CollectionEntries { + ids: vec!["123ABC".into()], + metadatas: None, + documents: Some(vec![ + "Document content 1".into(), + ]), + embeddings: None, + }; + + let response = collection.add( + valid_collection_entries, + Some(Box::new(MockEmbeddingProvider)), + ); + assert!(response.await.is_ok()); + + let response = collection.delete(Some(vec!["123ABC"]), None, None).await; + + assert!( + response.is_ok(), + ); + } } diff --git a/src/v1/commons.rs b/src/v2/commons.rs similarity index 100% rename from src/v1/commons.rs rename to src/v2/commons.rs diff --git a/src/v1/embeddings/bert.rs b/src/v2/embeddings/bert.rs similarity index 87% rename from src/v1/embeddings/bert.rs rename to src/v2/embeddings/bert.rs index a42acf4..7c6d5a1 100644 --- a/src/v1/embeddings/bert.rs +++ b/src/v2/embeddings/bert.rs @@ -10,16 +10,17 @@ impl EmbeddingFunction for SentenceEmbeddingsModel { #[cfg(test)] mod tests { - use crate::v1::collection::CollectionEntries; - use crate::v1::ChromaClient; + use crate::v2::collection::CollectionEntries; + use crate::v2::ChromaClient; use super::*; - #[test] - fn test_sbert_embeddings() { + #[tokio::test] + async fn test_sbert_embeddings() { let client = ChromaClient::new(Default::default()); let collection = client .get_or_create_collection("sbert-test-collection", None) + .await .unwrap(); let sbert_embedding = @@ -42,6 +43,7 @@ mod tests { collection .upsert(collection_entries, Some(Box::new(sbert_embedding))) + .await .unwrap(); } } diff --git a/src/v1/embeddings/mod.rs b/src/v2/embeddings/mod.rs similarity index 63% rename from src/v1/embeddings/mod.rs rename to src/v2/embeddings/mod.rs index 3dd6167..7eadbf4 100644 --- a/src/v1/embeddings/mod.rs +++ b/src/v2/embeddings/mod.rs @@ -1,5 +1,6 @@ use super::commons::Embedding; use anyhow::Result; +use async_trait::async_trait; #[cfg(feature = "bert")] pub mod bert; @@ -7,15 +8,17 @@ pub mod bert; #[cfg(feature = "openai")] pub mod openai; +#[async_trait] pub trait EmbeddingFunction { - fn embed(&self, docs: &[&str]) -> Result>; + async fn embed(&self, docs: &[&str]) -> Result>; } #[derive(Clone)] pub(super) struct MockEmbeddingProvider; +#[async_trait] impl EmbeddingFunction for MockEmbeddingProvider { - fn embed(&self, docs: &[&str]) -> Result> { + async fn embed(&self, docs: &[&str]) -> Result> { Ok(docs.iter().map(|_| vec![0.0_f32; 768]).collect()) } } diff --git a/src/v1/embeddings/openai.rs b/src/v2/embeddings/openai.rs similarity index 63% rename from src/v1/embeddings/openai.rs rename to src/v2/embeddings/openai.rs index d6203b0..de775cf 100644 --- a/src/v1/embeddings/openai.rs +++ b/src/v2/embeddings/openai.rs @@ -1,10 +1,12 @@ +use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use serde_json::Value; use super::EmbeddingFunction; -use crate::v1::commons::Embedding; +use crate::v2::commons::Embedding; const OPENAI_EMBEDDINGS_ENDPOINT: &str = "https://api.openai.com/v1/embeddings"; -const OPENAI_EMBEDDINGS_MODEL: &str = "text-embedding-ada-002"; +const OPENAI_EMBEDDINGS_MODEL: &str = "text-embedding-3-small"; #[derive(Debug, Deserialize)] struct EmbeddingData { @@ -24,10 +26,10 @@ struct EmbeddingResponse { /// Represents the OpenAI Embeddings provider pub struct OpenAIEmbeddings { - config: OpenAIConfig, + config: OpenAIConfig } -/// Defaults to the "text-embedding-ada-002" model +/// Defaults to the "text-embedding-3-small" model /// The API key can be set in the OPENAI_API_KEY environment variable pub struct OpenAIConfig { pub api_endpoint: String, @@ -50,53 +52,58 @@ impl OpenAIEmbeddings { Self { config } } - fn post(&self, json_body: T) -> anyhow::Result { - let res = minreq::post(&self.config.api_endpoint) - .with_header("Content-Type", "application/json") - .with_header("Authorization", format!("Bearer {}", self.config.api_key)) - .with_json(&json_body)? - .send()?; - - match res.status_code { - 200..=299 => Ok(res), - _ => anyhow::bail!( - "{} {}: {}", - res.status_code, - res.reason_phrase, - res.as_str().unwrap() - ), + async fn post(&self, json_body: T) -> anyhow::Result { + let client = reqwest::Client::new(); + let res = client.post(&self.config.api_endpoint) + .body("the exact body that is sent") + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.config.api_key)) + .json(&json_body) + .send() + .await?; + + match res.error_for_status() { + Ok(res) => { + Ok(res.json().await?) + }, + Err(e) => { + Err(e.into()) + } } } } +#[async_trait] impl EmbeddingFunction for OpenAIEmbeddings { - fn embed(&self, docs: &[&str]) -> anyhow::Result> { + async fn embed(&self, docs: &[&str]) -> anyhow::Result> { let mut embeddings = Vec::new(); - docs.iter().for_each(|doc| { + for doc in docs { let req = EmbeddingRequest { model: &self.config.model, input: &doc, }; - let res = self.post(req).unwrap(); - let body = res.json::().unwrap(); + let res = self.post(req).await?; + let body = serde_json::from_value::(res)?; embeddings.push(body.data[0].embedding.clone()); - }); + } + Ok(embeddings) } } #[cfg(test)] mod tests { - use crate::v1::collection::CollectionEntries; + use crate::v2::collection::CollectionEntries; use super::*; - use crate::v1::ChromaClient; + use crate::v2::ChromaClient; - #[test] - fn test_openai_embeddings() { + #[tokio::test] + async fn test_openai_embeddings() { let client = ChromaClient::new(Default::default()); let collection = client .get_or_create_collection("open-ai-test-collection", None) + .await .unwrap(); let openai_embeddings = OpenAIEmbeddings::new(Default::default()); @@ -115,6 +122,7 @@ mod tests { collection .upsert(collection_entries, Some(Box::new(openai_embeddings))) + .await .unwrap(); } } diff --git a/src/v1/mod.rs b/src/v2/mod.rs similarity index 100% rename from src/v1/mod.rs rename to src/v2/mod.rs