Skip to content

Commit

Permalink
fix: Fix normalize embedding, Mistral's client API, changes BM25 defa… (
Browse files Browse the repository at this point in the history
#31)

…ult worker num.

Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored Aug 16, 2024
1 parent 9fe20bf commit 66a3d9a
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 9 deletions.
8 changes: 5 additions & 3 deletions milvus_model/dense/instructor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import List, Optional
import struct
from collections import defaultdict
import numpy as np

from milvus_model.base import BaseEmbeddingFunction
Expand Down Expand Up @@ -36,7 +34,11 @@ def __call__(self, texts: List[str]) -> List[np.array]:

def _encode(self, texts: List[str]) -> List[np.array]:
embs = self.model.encode(
texts, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True,
texts,
batch_size=self.batch_size,
show_progress_bar=False,
convert_to_numpy=True,
normalize_embeddings=self.normalize_embeddings
)
return list(embs)

Expand Down
2 changes: 1 addition & 1 deletion milvus_model/dense/mistralai.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
else:
self.api_key = api_key
self.model_name = model_name
self.client = MistralClient(api_key=api_key)
self.client = Mistral(api_key=api_key)
self._encode_config = {"model": model_name, **kwargs}

def encode_queries(self, queries: List[str]) -> List[np.array]:
Expand Down
6 changes: 5 additions & 1 deletion milvus_model/dense/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def __call__(self, texts: List[str]) -> List[np.array]:

def _encode(self, texts: List[str]) -> List[np.array]:
embs = self.model.encode(
texts, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True,
texts,
batch_size=self.batch_size,
show_progress_bar=False,
convert_to_numpy=True,
normalize_embeddings=self.normalize_embeddings
)
return list(embs)

Expand Down
6 changes: 2 additions & 4 deletions milvus_model/sparse/bm25/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import math
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, List, Optional

Expand All @@ -44,7 +44,7 @@ def __init__(
k1: float = 1.5,
b: float = 0.75,
epsilon: float = 0.25,
num_workers: Optional[int] = None,
num_workers: int = 1,
):
if analyzer is None:
analyzer = build_default_analyzer(language="en")
Expand All @@ -55,8 +55,6 @@ def __init__(
self.k1 = k1
self.b = b
self.epsilon = epsilon
if num_workers is None:
self.num_workers = cpu_count()
self.num_workers = num_workers

if analyzer and corpus is not None:
Expand Down
1 change: 1 addition & 0 deletions milvus_model/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def import_sentence_transformers():
_check_library("sentence_transformers", package="sentence-transformers")

def import_FlagEmbedding():
_check_library("peft", package="peft")
_check_library("FlagEmbedding", package="FlagEmbedding>=1.2.2")

def import_nltk():
Expand Down

0 comments on commit 66a3d9a

Please sign in to comment.