Skip to content

Commit

Permalink
feat: Add CohereEmbeddingFunction. (#23)
Browse files Browse the repository at this point in the history
[fix: Normalize the output of
OnnxEmbeddingFunction.](627abb9)
[feat: Add
CohereEmbeddingFunction.](47f1505)

---------

Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored May 27, 2024
1 parent 3ee078f commit 46e8fbb
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 3 deletions.
5 changes: 5 additions & 0 deletions milvus_model/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"VoyageEmbeddingFunction",
"JinaEmbeddingFunction",
"OnnxEmbeddingFunction",
"CohereEmbeddingFunction"
]

from milvus_model.utils.lazy_import import LazyImport
Expand All @@ -13,6 +14,7 @@
sentence_transformer = LazyImport("sentence_transformer", globals(), "milvus_model.dense.sentence_transformer")
voyageai = LazyImport("voyageai", globals(), "milvus_model.dense.voyageai")
onnx = LazyImport("onnx", globals(), "milvus_model.dense.onnx")
cohere = LazyImport("cohere", globals(), "milvus_model.dense.cohere")

def JinaEmbeddingFunction(*args, **kwargs):
return jinaai.JinaEmbeddingFunction(*args, **kwargs)
Expand All @@ -28,3 +30,6 @@ def VoyageEmbeddingFunction(*args, **kwargs):

def OnnxEmbeddingFunction(*args, **kwargs):
return onnx.OnnxEmbeddingFunction(*args, **kwargs)

def CohereEmbeddingFunction(*args, **kwargs):
return cohere.CohereEmbeddingFunction(*args, **kwargs)
77 changes: 77 additions & 0 deletions milvus_model/dense/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import List, Optional
import struct
from collections import defaultdict
import numpy as np

from milvus_model.base import BaseEmbeddingFunction
from milvus_model.utils import import_cohere

import_cohere()
import cohere

class CohereEmbeddingFunction(BaseEmbeddingFunction):
def __init__(self,
model_name: str = "embed-english-light-v3.0",
api_key: Optional[str] = None,
input_type: str = "search_document",
embedding_types: Optional[List[str]] = None,
truncate: Optional[str] = None,
**kwargs):
self.model_name = model_name
self.input_type = input_type
self.embedding_types = embedding_types
self.truncate = truncate

if isinstance(embedding_types, list):
if len(embedding_types) > 1:
raise ValueError("Only one embedding type can be specified using current PyMilvus model library.")
elif embedding_types[0] == "int8" or embedding_types[0] == "uint8":
raise ValueError("Currently int8 or uint8 is not supported with PyMilvus model library.")
else:
pass

self.client = cohere.Client(api_key, **kwargs)
self._cohereai_model_meta_info = defaultdict(dict)
self._cohereai_model_meta_info["embed-english-v3.0"]["dim"] = 1024
self._cohereai_model_meta_info["embed-english-light-v3.0"]["dim"] = 384
self._cohereai_model_meta_info["embed-english-v2.0"]["dim"] = 4096
self._cohereai_model_meta_info["embed-english-light-v2.0"]["dim"] = 1024
self._cohereai_model_meta_info["embed-multilingual-v3.0"]["dim"] = 1024
self._cohereai_model_meta_info["embed-multilingual-light-v3.0"]["dim"] = 384
self._cohereai_model_meta_info["embed-multilingual-v2.0"]["dim"] = 768

def _call_cohere_api(self, texts: List[str], input_type: str) -> List[np.array]:
embeddings = self.client.embed(
texts=texts,
model=self.model_name,
input_type=input_type,
embedding_types=self.embedding_types,
truncate=self.truncate
).embeddings
if self.embedding_types is None:
results = [np.array(data) for data in embeddings]
else:
results = getattr(embeddings, self.embedding_types[0], None)
if self.embedding_types[0] == "binary":
results = [struct.pack('b' * len(int8_vector), *int8_vector) for int8_vector in results]
elif self.embedding_types[0] == "ubinary":
results = [struct.pack('B' * len(uint8_vector), *uint8_vector) for uint8_vector in results]
elif self.embedding_types[0] == "float":
results = [np.array(result, dtype=np.float32) for result in results]
else:
pass
return results

def encode_documents(self, documents: List[str]) -> List[np.array]:
return self._call_cohere_api(documents, input_type="search_document")

def encode_queries(self, queries: List[str]) -> List[np.array]:
return self._call_cohere_api(queries, input_type="search_query")

def __call__(self, texts: List[str]) -> List[np.array]:
return self._call_cohere_api(texts, self.input_type)

@property
def dim(self):
return self._cohereai_model_meta_info[self.model_name]["dim"]

5 changes: 3 additions & 2 deletions milvus_model/sparse/bm25/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import requests
from scipy.sparse import csr_array, vstack
import numpy as np

from milvus_model.base import BaseEmbeddingFunction
from milvus_model.sparse.bm25.tokenizers import Analyzer, build_default_analyzer
Expand Down Expand Up @@ -134,7 +135,7 @@ def _encode_query(self, query: str) -> csr_array:
values.append(self.idf[term][0])
rows.append(0)
cols.append(self.idf[term][1])
return csr_array((values, (rows, cols)), shape=(1, len(self.idf)))
return csr_array((values, (rows, cols)), shape=(1, len(self.idf))).astype(np.float32)

def _encode_document(self, doc: str) -> csr_array:
terms = self.analyzer(doc)
Expand All @@ -156,7 +157,7 @@ def _encode_document(self, doc: str) -> csr_array:
rows.append(0)
cols.append(self.idf[term][1])
values.append(value)
return csr_array((values, (rows, cols)), shape=(1, len(self.idf)))
return csr_array((values, (rows, cols)), shape=(1, len(self.idf))).astype(np.float32)

def encode_queries(self, queries: List[str]) -> csr_array:
sparse_embs = [self._encode_query(query) for query in queries]
Expand Down
5 changes: 5 additions & 0 deletions milvus_model/sparse/bm25/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def apply(self, text: str):

@register_class("StandardTokenizer")
class StandardTokenizer:
def __init__(self):
try:
word_tokenize("this is a simple test.")
except LookupError:
nltk.download("punkt")
def tokenize(self, text: str):
return word_tokenize(text)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"transformers >= 4.36.0",
"onnxruntime",
"scipy >= 1.10.0",
"protobuf==3.20.2",
"protobuf",
"numpy"
]

Expand Down

0 comments on commit 46e8fbb

Please sign in to comment.