Support to import data set (#182)
Signed-off-by: SimFG <[email protected]>
SimFG authored Apr 12, 2023
1 parent e67c2a2 commit e21ca1f
Showing 17 changed files with 196 additions and 125 deletions.
23 changes: 4 additions & 19 deletions Makefile
@@ -23,23 +23,8 @@ create_conda_env:
remove_conda_env:
	@bash ./scripts/manage_conda_env.sh remove

docs_build:
	cd docs && poetry run make html

docs_clean:
	cd docs && poetry run make clean

docs_linkcheck:
	poetry run linkchecker docs/_build/html/index.html

PYTHON_FILES=.
lint: PYTHON_FILES=.
lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$')

lint lint_diff:
	poetry run mypy $(PYTHON_FILES)
	poetry run black $(PYTHON_FILES) --check
	poetry run ruff .

pylint_check:
	pylint --rcfile=pylint.conf --output-format=colorized gptcache

pytest:
	pytest tests/
8 changes: 2 additions & 6 deletions examples/benchmark/benchmark_sqlite_faiss_onnx.py
@@ -65,12 +65,8 @@ def range(self):

if not has_data:
    print('insert data')
    id_origin = {}
    for pair in mock_data:
        question = pair['origin']
        answer = pair['id']
        id_origin[answer] = question
        cache.data_manager.save(question, answer, cache.embedding_func(question))
    questions, answers = zip(*((pair['origin'], pair['id']) for pair in mock_data))
    cache.import_data(questions=questions, answers=answers)
    print('end insert data')

all_time = 0.0
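The change above replaces the per-pair save loop with one bulk import_data call. For readers unfamiliar with the zip(*...) transpose idiom it uses, here is a minimal standalone sketch (the mock_data values are made up):

# Transpose a list of pairs into two parallel tuples.
mock_data = [
    {"origin": "what is 1+1?", "id": "2"},
    {"origin": "what is 2+2?", "id": "4"},
]

questions, answers = zip(*((pair["origin"], pair["id"]) for pair in mock_data))

assert questions == ("what is 1+1?", "what is 2+2?")
assert answers == ("2", "4")

Note that zip returns tuples rather than lists, which is fine here since import_data only needs to measure and iterate over its arguments.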
21 changes: 19 additions & 2 deletions gptcache/__init__.py
@@ -1,6 +1,8 @@
import atexit
import os
import time
from typing import List, Any, Optional

import openai

from gptcache.embedding.string import to_embeddings as string_embedding
@@ -55,7 +57,9 @@ def __init__(
        similarity_threshold=0.8,
    ):
        if similarity_threshold < 0 or similarity_threshold > 1:
            raise CacheError("Invalid the similarity threshold param, reasonable range: 0-1")
            raise CacheError(
                "Invalid the similarity threshold param, reasonable range: 0-1"
            )
        self.log_time_func = log_time_func
        self.similarity_threshold = similarity_threshold
@@ -134,7 +138,7 @@ def __init__(self):
        self.cache_enable_func = None
        self.pre_embedding_func = None
        self.embedding_func = None
        self.data_manager = None
        self.data_manager: Optional[DataManager] = None
        self.post_process_messages_func = None
        self.config = Config()
        self.report = Report()
@@ -179,6 +183,19 @@ def close():
        except Exception as e:  # pylint: disable=W0703
            print(e)

    def import_data(self, questions: List[Any], answers: List[Any]) -> None:
        """Import data into GPTCache.

        :param questions: list of preprocessed question data
        :param answers: list of answers to the questions
        :return: None
        """
        self.data_manager.import_data(
            questions=questions,
            answers=answers,
            embedding_datas=[self.embedding_func(question) for question in questions],
        )

    @staticmethod
    def set_openai_key():
        openai.api_key = os.getenv("OPENAI_API_KEY")
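A minimal usage sketch of the new Cache.import_data API, assuming the global cache object has already been initialized with an embedding function and a data manager (the question/answer strings below are made up):

from gptcache import cache

# Assumes cache.init(...) has already configured embedding_func and
# data_manager; import_data relies on both.
questions = ["what is gptcache?", "how do I report a bug?"]
answers = [
    "GPTCache is a semantic cache for LLM responses.",
    "Open an issue on the GPTCache repository.",
]

# One embedding_func call per question; storage happens in a single
# batched import underneath.
cache.import_data(questions=questions, answers=answers)

Computing the embeddings inside import_data keeps the public API symmetrical with save: callers pass raw questions and answers, never embeddings.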
64 changes: 53 additions & 11 deletions gptcache/manager/data_manager.py
@@ -1,11 +1,13 @@
from abc import abstractmethod, ABCMeta
import pickle
from typing import List, Any

import cachetools
import numpy as np

from gptcache.utils.error import CacheError
from gptcache.manager.scalar_data.base import CacheStorage
from gptcache.manager.vector_data.base import VectorBase, ClearStrategy
from gptcache.utils.error import CacheError, NotFoundStrategyError, ParamError
from gptcache.manager.scalar_data.base import CacheStorage, CacheData
from gptcache.manager.vector_data.base import VectorBase, ClearStrategy, VectorData
from gptcache.manager.eviction import EvictionManager


@@ -16,6 +18,12 @@ class DataManager(metaclass=ABCMeta):
    def save(self, question, answer, embedding_data, **kwargs):
        pass

    @abstractmethod
    def import_data(
        self, questions: List[Any], answers: List[Any], embedding_datas: List[Any]
    ):
        pass

    # should return the tuple, (question, answer)
    @abstractmethod
    def get_scalar_data(self, res_data, **kwargs):
@@ -49,12 +57,20 @@ def init(self):
                return
        except PermissionError:
            raise CacheError(  # pylint: disable=W0707
                f"You don't have permission to access this file <${self.data_path}>."
                f"You don't have permission to access this file <{self.data_path}>."
            )

    def save(self, question, answer, embedding_data, **kwargs):
        self.data[embedding_data] = (question, answer)

    def import_data(
        self, questions: List[Any], answers: List[Any], embedding_datas: List[Any]
    ):
        if len(questions) != len(answers) or len(questions) != len(embedding_datas):
            raise ParamError("Make sure that all parameters have the same length")
        for i, embedding_data in enumerate(embedding_datas):
            self.data[embedding_data] = (questions[i], answers[i])

    def get_scalar_data(self, res_data, **kwargs):
        return res_data

@@ -139,13 +155,39 @@ def save(self, question, answer, embedding_data, **kwargs):

        if self.cur_size >= self.max_size:
            self._clear()
        embedding_data = normalize(embedding_data)
        if self.v.clear_strategy() == ClearStrategy.DELETE:
            key = self.s.insert(question, answer)
        elif self.v.clear_strategy() == ClearStrategy.REBUILD:
            key = self.s.insert(question, answer, embedding_data.astype("float32"))
        self.v.add(key, embedding_data)
        self.cur_size += 1

        self.import_data([question], [answer], [embedding_data])

    def import_data(
        self, questions: List[Any], answers: List[Any], embedding_datas: List[Any]
    ):
        if len(questions) != len(answers) or len(questions) != len(embedding_datas):
            raise ParamError("Make sure that all parameters have the same length")
        cache_datas = []
        embedding_datas = [
            normalize(embedding_data) for embedding_data in embedding_datas
        ]
        for i, embedding_data in enumerate(embedding_datas):
            if self.v.clear_strategy() == ClearStrategy.DELETE:
                cache_datas.append(CacheData(question=questions[i], answer=answers[i]))
            elif self.v.clear_strategy() == ClearStrategy.REBUILD:
                cache_datas.append(
                    CacheData(
                        question=questions[i],
                        answer=answers[i],
                        embedding_data=embedding_data.astype("float32"),
                    )
                )
            else:
                raise NotFoundStrategyError(self.v.clear_strategy())
        ids = self.s.batch_insert(cache_datas)
        self.v.mul_add(
            [
                VectorData(id=ids[i], data=embedding_data)
                for i, embedding_data in enumerate(embedding_datas)
            ]
        )
        self.cur_size += len(questions)

    def get_scalar_data(self, res_data, **kwargs):
        return self.s.get_data_by_id(res_data[1])
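SSDataManager.import_data normalizes every embedding before indexing it. The normalize helper is defined elsewhere in data_manager.py and is not part of this diff; presumably it is a plain L2 normalization along these lines:

import numpy as np

def normalize(vec: np.ndarray) -> np.ndarray:
    # Scale to unit length so inner-product search behaves like cosine
    # similarity. A sketch of the helper's likely behavior, not the
    # repo's actual code.
    magnitude = np.linalg.norm(vec)
    return vec if magnitude == 0 else vec / magnitude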
11 changes: 10 additions & 1 deletion gptcache/manager/scalar_data/base.py
@@ -1,8 +1,17 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Optional, Any, List

import numpy as np


@dataclass
class CacheData:
    question: Any
    answer: Any
    embedding_data: Optional[np.ndarray] = None


class CacheStorage(metaclass=ABCMeta):
    """
    BaseStorage for scalar data.
@@ -13,7 +22,7 @@ def create(self):
        pass

    @abstractmethod
    def insert(self, data, reply, embedding_data: np.ndarray = None):
    def batch_insert(self, datas: List[CacheData]):
        pass

    @abstractmethod
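CacheData is the unit of work for the new batch_insert interface. embedding_data is optional because only REBUILD-strategy vector stores need the vector persisted alongside the scalar row. A quick construction sketch (values are illustrative):

import numpy as np

from gptcache.manager.scalar_data.base import CacheData

# DELETE-strategy stores persist only the question/answer pair.
plain = CacheData(question="hello", answer="world")

# REBUILD-strategy stores also carry the embedding.
with_vector = CacheData(
    question="hi",
    answer="there",
    embedding_data=np.random.rand(8).astype("float32"),
)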
45 changes: 26 additions & 19 deletions gptcache/manager/scalar_data/sqlalchemy.py
@@ -1,8 +1,9 @@
import numpy as np
from typing import List

from datetime import datetime

from gptcache.utils import import_sqlalchemy
from gptcache.manager.scalar_data.base import CacheStorage
from gptcache.manager.scalar_data.base import CacheStorage, CacheData

import_sqlalchemy()

@@ -29,8 +30,8 @@ class CacheTable(Base):
    __table_args__ = {"extend_existing": True}

    id = Column(Integer, primary_key=True, autoincrement=True)
    data = Column(String(1000), nullable=False)
    reply = Column(String(1000), nullable=False)
    question = Column(String(1000), nullable=False)
    answer = Column(String(1000), nullable=False)
    create_on = Column(DateTime, default=datetime.now)
    last_access = Column(DateTime, default=datetime.now)
    embedding_data = Column(LargeBinary, nullable=True)
@@ -48,8 +49,8 @@ class CacheTableSequence(Base):
    id = Column(
        Integer, Sequence("id_seq", start=1), primary_key=True, autoincrement=True
    )
    data = Column(String(1000), nullable=False)
    reply = Column(String(1000), nullable=False)
    question = Column(String(1000), nullable=False)
    answer = Column(String(1000), nullable=False)
    create_on = Column(DateTime, default=datetime.now)
    last_access = Column(DateTime, default=datetime.now)
    embedding_data = Column(LargeBinary, nullable=True)
@@ -68,10 +69,10 @@ class SQLDataBase(CacheStorage):
"""

def __init__(
self,
db_type: str = "sqlite",
url: str = "sqlite:///./sqlite.db",
table_name: str = "gptcache",
self,
db_type: str = "sqlite",
url: str = "sqlite:///./sqlite.db",
table_name: str = "gptcache",
):
self._url = url
self._model = get_model(table_name, db_type)
@@ -83,19 +84,25 @@ def __init__(
    def create(self):
        self._model.__table__.create(bind=self._engine, checkfirst=True)

    def insert(self, data, reply, embedding_data: np.ndarray = None):
        if embedding_data is None:
            model_obj = self._model(data=data, reply=reply)
        else:
            embedding_data = embedding_data.tobytes()
            model_obj = self._model(data=data, reply=reply, embedding_data=embedding_data)
        self._session.add(model_obj)
    def batch_insert(self, datas: List[CacheData]):
        model_objs = []
        for data in datas:
            model_obj = self._model(
                question=data.question,
                answer=data.answer,
                embedding_data=data.embedding_data.tobytes()
                if data.embedding_data is not None
                else None,
            )
            model_objs.append(model_obj)

        self._session.add_all(model_objs)
        self._session.commit()
        return model_obj.id
        return [model_obj.id for model_obj in model_objs]

    def get_data_by_id(self, key):
        res = (
            self._session.query(self._model.data, self._model.reply)
            self._session.query(self._model.question, self._model.answer)
            .filter(self._model.id == key)
            .filter(self._model.state == 0)
            .first()
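A sketch of calling the new batch_insert directly, assuming the default SQLite settings and that the table already exists (the diff confirms batch_insert's signature, but not the setup steps):

from gptcache.manager.scalar_data.base import CacheData
from gptcache.manager.scalar_data.sqlalchemy import SQLDataBase

db = SQLDataBase()  # defaults: sqlite:///./sqlite.db, table "gptcache"
ids = db.batch_insert([
    CacheData(question="q1", answer="a1"),
    CacheData(question="q2", answer="a2"),
])
print(ids)  # autoincrement primary keys, one per inserted row

The batching is the point of the change: the old insert added and committed one row per call, while batch_insert stages every row with add_all and commits once, so importing N pairs costs one transaction instead of N.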
16 changes: 13 additions & 3 deletions gptcache/manager/vector_data/base.py
@@ -1,28 +1,38 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List

import numpy as np


class ClearStrategy(Enum):
    REBUILD = 0
    DELETE = 1


@dataclass
class VectorData:
    id: int
    data: np.ndarray


class VectorBase(ABC):
    """VectorBase: base vector store interface"""

    @abstractmethod
    def add(self, key: str, data: "ndarray"):
    def mul_add(self, datas: List[VectorData]):
        pass

    @abstractmethod
    def search(self, data: "ndarray"):
    def search(self, data: np.ndarray):
        pass

    @abstractmethod
    def clear_strategy(self):
        pass

    def rebuild(self) -> bool:
    def rebuild(self, all_data, keys) -> bool:
        raise NotImplementedError

    def delete(self, ids) -> bool:
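To make the revised interface concrete, here is a hypothetical in-memory VectorBase implementation; the (distance, id) return shape of search is an assumption, since the diff does not pin down search's contract:

from typing import List

import numpy as np

from gptcache.manager.vector_data.base import ClearStrategy, VectorBase, VectorData


class InMemoryVector(VectorBase):
    # Toy store: a dict of vectors with brute-force nearest-neighbor search.

    def __init__(self):
        self._store = {}  # id -> np.ndarray

    def mul_add(self, datas: List[VectorData]):
        for item in datas:
            self._store[item.id] = item.data

    def search(self, data: np.ndarray):
        if not self._store:
            return None
        best_id, best_vec = min(
            self._store.items(), key=lambda kv: np.linalg.norm(kv[1] - data)
        )
        return np.linalg.norm(best_vec - data), best_id

    def clear_strategy(self):
        return ClearStrategy.DELETE

    def delete(self, ids) -> bool:
        for i in ids:
            self._store.pop(i, None)
        return True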
9 changes: 6 additions & 3 deletions gptcache/manager/vector_data/chroma.py
@@ -1,4 +1,6 @@
from gptcache.manager.vector_data.base import VectorBase, ClearStrategy
from typing import List

from gptcache.manager.vector_data.base import VectorBase, ClearStrategy, VectorData
from gptcache.utils import import_chromadb

import_chromadb()
@@ -30,8 +32,9 @@ def __init__(
        self._persist_directory = persist_directory
        self._collection = self._client.get_or_create_collection(name=collection_name)

    def add(self, key, data):
        self._collection.add(embeddings=[data], ids=[key])
    def mul_add(self, datas: List[VectorData]):
        data_array, id_array = map(list, zip(*((data.data, str(data.id)) for data in datas)))
        self._collection.add(embeddings=data_array, ids=id_array)

    def search(self, data):
        if self._collection.count() == 0:
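A usage sketch for the Chroma backend's batched method. The Chromadb class name and constructor arguments are assumptions based on the file path; only mul_add's body is confirmed by the diff. Note that the implementation converts ids with str(data.id) because Chroma expects string ids:

import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.chroma import Chromadb  # assumed class name

store = Chromadb(collection_name="gptcache")  # assumed constructor signature
store.mul_add([
    VectorData(id=i, data=np.random.rand(8).astype("float32"))
    for i in range(3)
])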
(Diffs for the remaining 9 of the 17 changed files are not shown.)
