added support for weaviate vector database (#493)
* added support for weaviate vector database

Signed-off-by: pranaychandekar <[email protected]>

* added support for embedded local db for weaviate vector store

Signed-off-by: pranaychandekar <[email protected]>

* added unit test case for weaviate vector store

Signed-off-by: pranaychandekar <[email protected]>

* resolved unit test case error for weaviate vector store

Signed-off-by: pranaychandekar <[email protected]>

* increased code coverage
resolved pylint issues

pylint: disabled C0413 (wrong-import-position)

Signed-off-by: pranaychandekar <[email protected]>

---------

Signed-off-by: pranaychandekar <[email protected]>
pranaychandekar authored Jul 22, 2023
1 parent f60c303 commit 07db497
Showing 6 changed files with 255 additions and 2 deletions.
5 changes: 3 additions & 2 deletions examples/data_manager/vector_store.py
@@ -18,7 +18,8 @@ def run():
'milvus',
'chromadb',
'docarray',
        'redis',
        'weaviate',
]
for vector_store in vector_stores:
cache_base = CacheBase('sqlite')
@@ -40,4 +41,4 @@ def run():


if __name__ == '__main__':
    run()
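
For reference, a minimal sketch of wiring the new store into a data manager, following the same CacheBase/VectorBase pattern as the example above (the top_k value is illustrative):

from gptcache.manager import CacheBase, VectorBase, get_data_manager

cache_base = CacheBase('sqlite')
# 'weaviate' with no url falls back to an embedded local instance (see manager.py below)
vector_base = VectorBase('weaviate', top_k=3)
data_manager = get_data_manager(cache_base, vector_base)
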
28 changes: 28 additions & 0 deletions gptcache/manager/vector_data/manager.py
@@ -27,6 +27,9 @@

COLLECTION_NAME = "gptcache"

WEAVIATE_TIMEOUT_CONFIG = (10, 60)
WEAVIATE_STARTUP_PERIOD = 5


# pylint: disable=import-outside-toplevel
class VectorBase:
@@ -257,6 +260,31 @@ def get(name, **kwargs):
flush_interval_sec=flush_interval_sec,
index_params=index_params,
)
elif name == "weaviate":
from gptcache.manager.vector_data.weaviate import Weaviate

url = kwargs.get("url", None)
auth_client_secret = kwargs.get("auth_client_secret", None)
timeout_config = kwargs.get("timeout_config", WEAVIATE_TIMEOUT_CONFIG)
proxies = kwargs.get("proxies", None)
trust_env = kwargs.get("trust_env", False)
additional_headers = kwargs.get("additional_headers", None)
startup_period = kwargs.get("startup_period", WEAVIATE_STARTUP_PERIOD)
embedded_options = kwargs.get("embedded_options", None)
additional_config = kwargs.get("additional_config", None)

vector_base = Weaviate(
url=url,
auth_client_secret=auth_client_secret,
timeout_config=timeout_config,
proxies=proxies,
trust_env=trust_env,
additional_headers=additional_headers,
startup_period=startup_period,
embedded_options=embedded_options,
additional_config=additional_config,
top_k=top_k,
)
else:
raise NotFoundError("vector store", name)
return vector_base
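
Every keyword argument above is forwarded to the Weaviate client, so pointing the store at a remote server only requires passing the connection details through VectorBase; a sketch with an illustrative address:

from gptcache.manager import VectorBase

vector_base = VectorBase(
    "weaviate",
    url="http://localhost:8080",  # illustrative; omit to use the embedded default
    top_k=5,
)
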
182 changes: 182 additions & 0 deletions gptcache/manager/vector_data/weaviate.py
@@ -0,0 +1,182 @@
from typing import List, Optional, Tuple, Union
import numpy as np

from gptcache.utils import import_weaviate
from gptcache.utils.log import gptcache_log
from gptcache.manager.vector_data.base import VectorBase, VectorData

import_weaviate()

from weaviate import Client
from weaviate.auth import AuthCredentials
from weaviate.config import Config
from weaviate.embedded import EmbeddedOptions
from weaviate.types import NUMBERS


class Weaviate(VectorBase):
"""
vector store: Weaviate
"""

TIMEOUT_TYPE = Union[Tuple[NUMBERS, NUMBERS], NUMBERS]

def __init__(
self,
url: Optional[str] = None,
auth_client_secret: Optional[AuthCredentials] = None,
timeout_config: TIMEOUT_TYPE = (10, 60),
proxies: Union[dict, str, None] = None,
trust_env: bool = False,
additional_headers: Optional[dict] = None,
startup_period: Optional[int] = 5,
embedded_options: Optional[EmbeddedOptions] = None,
additional_config: Optional[Config] = None,
top_k: Optional[int] = 1,
) -> None:

if url is None and embedded_options is None:
embedded_options = EmbeddedOptions()

self.client = Client(
url=url,
auth_client_secret=auth_client_secret,
timeout_config=timeout_config,
proxies=proxies,
trust_env=trust_env,
additional_headers=additional_headers,
startup_period=startup_period,
embedded_options=embedded_options,
additional_config=additional_config,
)

self._create_class()
self.top_k = top_k

def _create_class(self):
class_schema = self._get_default_class_schema()

self.class_name = class_schema.get("class")

if self.client.schema.exists(self.class_name):
gptcache_log.warning(
"The %s collection already exists, and it will be used directly.",
self.class_name,
)
else:
self.client.schema.create_class(class_schema)

@staticmethod
def _get_default_class_schema() -> dict:
return {
"class": "GPTCache",
"description": "LLM response cache",
"properties": [
{
"name": "data_id",
"dataType": ["int"],
"description": "The data-id generated by GPTCache for vectors.",
}
],
"vectorIndexConfig": {"distance": "cosine"},
}

def mul_add(self, datas: List[VectorData]):
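        # Batch-insert the vectors; each Weaviate object carries only the GPTCache data_id as a property.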
with self.client.batch(batch_size=100, dynamic=True) as batch:
for data in datas:
properties = {
"data_id": data.id,
}

batch.add_data_object(
data_object=properties, class_name=self.class_name, vector=data.data
)

def search(self, data: np.ndarray, top_k: int = -1):
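        # Nearest-vector query; returns a list of (cosine distance, data_id) tuples.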
if top_k == -1:
top_k = self.top_k

result = (
self.client.query.get(class_name=self.class_name, properties=["data_id"])
.with_near_vector(content={"vector": data})
.with_additional(["distance"])
.with_limit(top_k)
.do()
)

return list(
map(
lambda x: (x["_additional"]["distance"], x["data_id"]),
result["data"]["Get"][self.class_name],
)
)

def _get_uuids(self, data_ids):
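        # Resolve GPTCache data_ids to the Weaviate object UUIDs needed for deletes.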
uuid_list = []

for data_id in data_ids:
res = (
self.client.query.get(
class_name=self.class_name, properties=["data_id"]
)
.with_where(
{"path": ["data_id"], "operator": "Equal", "valueInt": data_id}
)
.with_additional(["id"])
.do()
)

uuid_list.append(
res["data"]["Get"][self.class_name][0]["_additional"]["id"]
)

return uuid_list

def delete(self, ids):
uuids = self._get_uuids(ids)

for uuid in uuids:
self.client.data_object.delete(class_name=self.class_name, uuid=uuid)

def rebuild(self, ids=None):
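        # No-op: the Weaviate server maintains its vector index on its own.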
return

def flush(self):
self.client.batch.flush()

def close(self):
self.flush()

def get_embeddings(self, data_id: int):
results = (
self.client.query.get(class_name=self.class_name, properties=["data_id"])
.with_where(
{
"path": ["data_id"],
"operator": "Equal",
"valueInt": data_id,
}
)
.with_additional(["vector"])
.with_limit(1)
.do()
)

results = results["data"]["Get"][self.class_name]

if len(results) < 1:
return None

vec_emb = np.asarray(results[0]["_additional"]["vector"], dtype="float32")
return vec_emb

def update_embeddings(self, data_id: int, emb: np.ndarray):
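        # Replace the stored vector by deleting the old object and re-creating it with the same data_id.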
self.delete([data_id])

properties = {
"data_id": data_id,
}

self.client.data_object.create(
data_object=properties, class_name=self.class_name, vector=emb
)
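
As a quick illustration of the class above, a minimal sketch that exercises the embedded default (it assumes weaviate-client is installed and an embedded instance can start on the machine; the 8-dimensional random vectors are illustrative):

import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.weaviate import Weaviate

store = Weaviate(top_k=2)  # no url, so EmbeddedOptions() is used
store.mul_add([VectorData(id=i, data=np.random.rand(8).astype(np.float32)) for i in range(4)])
print(store.search(np.random.rand(8).astype(np.float32)))  # [(distance, data_id), ...]
store.close()
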
5 changes: 5 additions & 0 deletions gptcache/utils/__init__.py
@@ -41,6 +41,7 @@
"import_fastapi",
"import_redis",
"import_qdrant",
"import_weaviate",
]

import importlib.util
@@ -262,3 +263,7 @@ def import_redis():

def import_starlette():
_check_library("starlette")


def import_weaviate():
_check_library("weaviate-client")
1 change: 1 addition & 0 deletions pylint.conf
@@ -148,6 +148,7 @@ disable=abstract-method,
zip-builtin-not-iterating,
missing-module-docstring,
super-init-not-called,
wrong-import-position


[REPORTS]
36 changes: 36 additions & 0 deletions tests/unit_tests/manager/test_weaviate.py
@@ -0,0 +1,36 @@
import unittest
import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestWeaviateDB(unittest.TestCase):
def test_normal(self):
size = 1000
dim = 512
top_k = 10

db = VectorBase(
"weaviate",
top_k=top_k
)

db._create_class()
data = np.random.randn(size, dim).astype(np.float32)
        db.mul_add([VectorData(id=i, data=v) for i, v in enumerate(data)])
self.assertEqual(len(db.search(data[0])), top_k)
db.mul_add([VectorData(id=size, data=data[0])])
ret = db.search(data[0])
self.assertIn(ret[0][1], [0, size])
self.assertIn(ret[1][1], [0, size])
db.delete([0, 1, 2, 3, 4, 5, size])
ret = db.search(data[0])
self.assertNotIn(ret[0][1], [0, size])
db.rebuild()
db.update_embeddings(6, data[7])
emb = db.get_embeddings(6)
self.assertEqual(emb.tolist(), data[7].tolist())
emb = db.get_embeddings(0)
self.assertIsNone(emb)
db.close()
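
Assuming weaviate-client is installed, the test runs end to end against the embedded instance in the usual way:

python -m pytest tests/unit_tests/manager/test_weaviate.py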
