Skip to content

Commit

Permalink
feat: Add JinaAI embedding and reranker models (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored May 9, 2024
1 parent 8202e8f commit f7d43b2
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 0 deletions.
2 changes: 2 additions & 0 deletions milvus_model/dense/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .openai import OpenAIEmbeddingFunction
from .sentence_transformer import SentenceTransformerEmbeddingFunction
from .voyageai import VoyageEmbeddingFunction
from .jinaai import JinaEmbeddingFunction

__all__ = [
"OpenAIEmbeddingFunction",
"SentenceTransformerEmbeddingFunction",
"VoyageEmbeddingFunction",
"JinaEmbeddingFunction",
]
61 changes: 61 additions & 0 deletions milvus_model/dense/jinaai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import requests
from typing import List, Optional

import numpy as np

from milvus_model.base import BaseEmbeddingFunction

API_URL = "https://api.jina.ai/v1/embeddings"


class JinaEmbeddingFunction(BaseEmbeddingFunction):
def __init__(self, model_name: str = "jina-embeddings-v2-base-en", api_key: Optional[str] = None, **kwargs):
if api_key is None:
if 'JINAAI_API_KEY' in os.environ and os.environ['JINAAI_API_KEY']:
self.api_key = os.environ['JINAAI_API_KEY']
else:
raise ValueError(
f"Did not find api_key, please add an environment variable"
f" `JINAAI_API_KEY` which contains it, or pass"
f" `api_key` as a named parameter."
)
else:
self.api_key = api_key
self.model_name = model_name
self._session = requests.Session()
self._session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)
self.model_name = model_name
self._dim = None

@property
def dim(self):
if self._dim is None:
self._dim = self._call_jina_api([""])[0].shape[0]
return self._dim

def encode_queries(self, queries: List[str]) -> List[np.array]:
return self._call_jina_api(queries)

def encode_documents(self, documents: List[str]) -> List[np.array]:
return self._call_jina_api(documents)

def __call__(self, texts: List[str]) -> List[np.array]:
return self._call_jina_api(texts)

def _call_jina_api(self, texts: List[str]):
resp = self._session.post( # type: ignore
API_URL,
json={"input": texts, "model": self.model_name},
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])

embeddings = resp["data"]

# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore

return [np.array(result["embedding"]) for result in sorted_embeddings]
2 changes: 2 additions & 0 deletions milvus_model/reranker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from .cohere import CohereRerankFunction
from .cross_encoder import CrossEncoderRerankFunction
from .voyageai import VoyageRerankFunction
from .jinaai import JinaRerankFunction

__all__ = [
"CohereRerankFunction",
"BGERerankFunction",
"VoyageRerankFunction",
"CrossEncoderRerankFunction",
"JinaRerankFunction",
]
50 changes: 50 additions & 0 deletions milvus_model/reranker/jinaai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import requests
from typing import List, Optional

from milvus_model.base import BaseRerankFunction, RerankResult

API_URL = "https://api.jina.ai/v1/rerank"


class JinaRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "jina-reranker-v1-base-en", api_key: Optional[str] = None):
if api_key is None:
if 'JINAAI_API_KEY' in os.environ and os.environ['JINAAI_API_KEY']:
self.api_key = os.environ['JINAAI_API_KEY']
else:
raise ValueError(
f"Did not find api_key, please add an environment variable"
f" `JINAAI_API_KEY` which contains it, or pass"
f" `api_key` as a named parameter."
)
else:
self.api_key = api_key
self.model_name = model_name
self._session = requests.Session()
self._session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)
self.model_name = model_name

def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
resp = self._session.post( # type: ignore
API_URL,
json={
"query": query,
"documents": documents,
"model": self.model_name,
"top_n": top_k,
},
).json()
if "results" not in resp:
raise RuntimeError(resp["detail"])

results = []
for res in resp["results"]:
results.append(
RerankResult(
text=res['document']['text'], score=res['relevance_score'], index=res['index']
)
)
return results

0 comments on commit f7d43b2

Please sign in to comment.