-
Notifications
You must be signed in to change notification settings - Fork 510
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added support for weaviate vector database (#493)
* added support for weaviate vector database Signed-off-by: pranaychandekar <[email protected]> * added support for an embedded local db for the weaviate vector store Signed-off-by: pranaychandekar <[email protected]> * added unit test case for weaviate vector store Signed-off-by: pranaychandekar <[email protected]> * resolved unit test case error for weaviate vector store Signed-off-by: pranaychandekar <[email protected]> * increased code coverage resolved pylint issues pylint: disabled C0413 Signed-off-by: pranaychandekar <[email protected]> --------- Signed-off-by: pranaychandekar <[email protected]>
- Loading branch information
1 parent
f60c303
commit 07db497
Showing
6 changed files
with
255 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
from typing import List, Optional, Tuple, Union | ||
import numpy as np | ||
|
||
from gptcache.utils import import_weaviate | ||
from gptcache.utils.log import gptcache_log | ||
from gptcache.manager.vector_data.base import VectorBase, VectorData | ||
|
||
import_weaviate() | ||
|
||
from weaviate import Client | ||
from weaviate.auth import AuthCredentials | ||
from weaviate.config import Config | ||
from weaviate.embedded import EmbeddedOptions | ||
from weaviate.types import NUMBERS | ||
|
||
|
||
class Weaviate(VectorBase):
    """Vector store: Weaviate.

    Each cached vector is stored as a Weaviate data object whose ``data_id``
    property holds the integer id assigned by GPTCache.  When neither ``url``
    nor ``embedded_options`` is supplied, an embedded (in-process) Weaviate
    instance is started instead of connecting to a remote server.

    :param url: endpoint of a running Weaviate instance; ``None`` selects the
        embedded instance.
    :param auth_client_secret: credentials forwarded to the Weaviate client.
    :param timeout_config: ``(connect, read)`` timeouts, or a single number
        used for both.
    :param proxies: proxy configuration forwarded to the client.
    :param trust_env: whether the client reads proxy settings from
        environment variables.
    :param additional_headers: extra HTTP headers sent with every request.
    :param startup_period: seconds to wait for the server to become ready.
    :param embedded_options: configuration for the embedded instance.
    :param additional_config: extra low-level client configuration.
    :param top_k: default number of nearest neighbours returned by
        :meth:`search`.
    """

    TIMEOUT_TYPE = Union[Tuple[NUMBERS, NUMBERS], NUMBERS]

    def __init__(
        self,
        url: Optional[str] = None,
        auth_client_secret: Optional[AuthCredentials] = None,
        timeout_config: TIMEOUT_TYPE = (10, 60),
        proxies: Union[dict, str, None] = None,
        trust_env: bool = False,
        additional_headers: Optional[dict] = None,
        startup_period: Optional[int] = 5,
        embedded_options: Optional[EmbeddedOptions] = None,
        additional_config: Optional[Config] = None,
        top_k: Optional[int] = 1,
    ) -> None:
        # No URL and no explicit embedded config: fall back to an embedded,
        # in-process Weaviate instance with default options.
        if url is None and embedded_options is None:
            embedded_options = EmbeddedOptions()

        self.client = Client(
            url=url,
            auth_client_secret=auth_client_secret,
            timeout_config=timeout_config,
            proxies=proxies,
            trust_env=trust_env,
            additional_headers=additional_headers,
            startup_period=startup_period,
            embedded_options=embedded_options,
            additional_config=additional_config,
        )

        self._create_class()
        self.top_k = top_k

    def _create_class(self):
        """Ensure the GPTCache class (collection) exists in the schema.

        Reuses the class when it is already present; otherwise creates it
        from the default schema.  Sets ``self.class_name`` as a side effect.
        """
        class_schema = self._get_default_class_schema()

        self.class_name = class_schema.get("class")

        if self.client.schema.exists(self.class_name):
            gptcache_log.warning(
                "The %s collection already exists, and it will be used directly.",
                self.class_name,
            )
        else:
            self.client.schema.create_class(class_schema)

    @staticmethod
    def _get_default_class_schema() -> dict:
        """Return the default Weaviate class schema used by GPTCache."""
        return {
            "class": "GPTCache",
            "description": "LLM response cache",
            "properties": [
                {
                    "name": "data_id",
                    "dataType": ["int"],
                    "description": "The data-id generated by GPTCache for vectors.",
                }
            ],
            # Cosine distance matches the similarity metric GPTCache expects.
            "vectorIndexConfig": {"distance": "cosine"},
        }

    def mul_add(self, datas: List[VectorData]):
        """Insert multiple vectors in one batched request.

        :param datas: vectors to store, each carrying its GPTCache data id.
        """
        with self.client.batch(batch_size=100, dynamic=True) as batch:
            for data in datas:
                properties = {
                    "data_id": data.id,
                }

                batch.add_data_object(
                    data_object=properties, class_name=self.class_name, vector=data.data
                )

    def search(self, data: np.ndarray, top_k: int = -1):
        """Return the ``top_k`` nearest neighbours of ``data``.

        :param data: query vector.
        :param top_k: number of results; ``-1`` uses the instance default.
        :return: list of ``(distance, data_id)`` tuples.
        """
        if top_k == -1:
            top_k = self.top_k

        result = (
            self.client.query.get(class_name=self.class_name, properties=["data_id"])
            .with_near_vector(content={"vector": data})
            .with_additional(["distance"])
            .with_limit(top_k)
            .do()
        )

        return list(
            map(
                lambda x: (x["_additional"]["distance"], x["data_id"]),
                result["data"]["Get"][self.class_name],
            )
        )

    def _get_uuids(self, data_ids):
        """Map GPTCache data ids to Weaviate object uuids.

        Ids with no matching object are skipped (with a warning) instead of
        raising ``IndexError``, so deleting an unknown id is a no-op.

        :param data_ids: iterable of integer data ids.
        :return: list of uuid strings for the ids that were found.
        """
        uuid_list = []

        for data_id in data_ids:
            res = (
                self.client.query.get(
                    class_name=self.class_name, properties=["data_id"]
                )
                .with_where(
                    {"path": ["data_id"], "operator": "Equal", "valueInt": data_id}
                )
                .with_additional(["id"])
                .do()
            )

            matches = res["data"]["Get"][self.class_name]
            if not matches:
                # Previously this raised IndexError; treat a missing id as
                # "nothing to delete" and keep going.
                gptcache_log.warning(
                    "No object found in %s for data_id %s; skipping.",
                    self.class_name,
                    data_id,
                )
                continue

            uuid_list.append(matches[0]["_additional"]["id"])

        return uuid_list

    def delete(self, ids):
        """Delete the objects whose ``data_id`` is in ``ids``.

        Unknown ids are ignored (see :meth:`_get_uuids`).
        """
        uuids = self._get_uuids(ids)

        for uuid in uuids:
            self.client.data_object.delete(class_name=self.class_name, uuid=uuid)

    def rebuild(self, ids=None):
        """No-op: Weaviate maintains its index automatically."""
        return

    def flush(self):
        """Flush any pending batched writes to the server."""
        self.client.batch.flush()

    def close(self):
        """Flush outstanding writes; the client itself needs no teardown."""
        self.flush()

    def get_embeddings(self, data_id: int):
        """Fetch the stored vector for ``data_id``.

        :return: the vector as a float32 ``np.ndarray``, or ``None`` when no
            object with that id exists.
        """
        results = (
            self.client.query.get(class_name=self.class_name, properties=["data_id"])
            .with_where(
                {
                    "path": ["data_id"],
                    "operator": "Equal",
                    "valueInt": data_id,
                }
            )
            .with_additional(["vector"])
            .with_limit(1)
            .do()
        )

        results = results["data"]["Get"][self.class_name]

        if len(results) < 1:
            return None

        vec_emb = np.asarray(results[0]["_additional"]["vector"], dtype="float32")
        return vec_emb

    def update_embeddings(self, data_id: int, emb: np.ndarray):
        """Replace the vector stored under ``data_id`` with ``emb``.

        Implemented as delete-then-create because Weaviate vectors are
        immutable per object.
        """
        self.delete([data_id])

        properties = {
            "data_id": data_id,
        }

        self.client.data_object.create(
            data_object=properties, class_name=self.class_name, vector=emb
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import unittest | ||
import numpy as np | ||
|
||
from gptcache.manager.vector_data import VectorBase | ||
from gptcache.manager.vector_data.base import VectorData | ||
|
||
|
||
class TestWeaviateDB(unittest.TestCase):
    """End-to-end exercise of the Weaviate vector store via VectorBase."""

    def test_normal(self):
        num_vectors = 1000
        vector_dim = 512
        top_k = 10

        store = VectorBase(
            "weaviate",
            top_k=top_k
        )

        # Calling _create_class again is safe: an existing class is reused.
        store._create_class()
        vectors = np.random.randn(num_vectors, vector_dim).astype(np.float32)
        store.mul_add([VectorData(id=idx, data=vec) for idx, vec in enumerate(vectors)])
        self.assertEqual(len(store.search(vectors[0])), top_k)

        # Insert a duplicate of vector 0 under a fresh id; the two nearest
        # hits for vectors[0] must then be ids 0 and num_vectors (either order).
        store.mul_add([VectorData(id=num_vectors, data=vectors[0])])
        hits = store.search(vectors[0])
        self.assertIn(hits[0][1], [0, num_vectors])
        self.assertIn(hits[1][1], [0, num_vectors])

        # After deleting both copies, neither id may come back first.
        store.delete([0, 1, 2, 3, 4, 5, num_vectors])
        hits = store.search(vectors[0])
        self.assertNotIn(hits[0][1], [0, num_vectors])

        store.rebuild()
        store.update_embeddings(6, vectors[7])
        fetched = store.get_embeddings(6)
        self.assertEqual(fetched.tolist(), vectors[7].tolist())
        fetched = store.get_embeddings(0)
        self.assertIsNone(fetched)
        store.close()