From 66a3d9a25aa6c2017ab654e31df4bbf45b1d745a Mon Sep 17 00:00:00 2001 From: wxywb Date: Fri, 16 Aug 2024 18:54:52 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20Fix=20normalize=20embedding,=20Mistral's?= =?UTF-8?q?=20client=20API,=20changes=20BM25=20defa=E2=80=A6=20(#31)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ult worker num. Signed-off-by: wxywb --- milvus_model/dense/instructor.py | 8 +++++--- milvus_model/dense/mistralai.py | 2 +- milvus_model/dense/sentence_transformer.py | 6 +++++- milvus_model/sparse/bm25/bm25.py | 6 ++---- milvus_model/utils/__init__.py | 1 + 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/milvus_model/dense/instructor.py b/milvus_model/dense/instructor.py index 6aebf19..ec99f78 100644 --- a/milvus_model/dense/instructor.py +++ b/milvus_model/dense/instructor.py @@ -1,6 +1,4 @@ from typing import List, Optional -import struct -from collections import defaultdict import numpy as np from milvus_model.base import BaseEmbeddingFunction @@ -36,7 +34,11 @@ def __call__(self, texts: List[str]) -> List[np.array]: def _encode(self, texts: List[str]) -> List[np.array]: embs = self.model.encode( - texts, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True, + texts, + batch_size=self.batch_size, + show_progress_bar=False, + convert_to_numpy=True, + normalize_embeddings=self.normalize_embeddings ) return list(embs) diff --git a/milvus_model/dense/mistralai.py b/milvus_model/dense/mistralai.py index d438e90..19e1689 100644 --- a/milvus_model/dense/mistralai.py +++ b/milvus_model/dense/mistralai.py @@ -32,7 +32,7 @@ def __init__( else: self.api_key = api_key self.model_name = model_name - self.client = MistralClient(api_key=api_key) + self.client = Mistral(api_key=api_key) self._encode_config = {"model": model_name, **kwargs} def encode_queries(self, queries: List[str]) -> List[np.array]: diff --git a/milvus_model/dense/sentence_transformer.py b/milvus_model/dense/sentence_transformer.py index 08cbb69..b7c8ba0 100644 --- a/milvus_model/dense/sentence_transformer.py +++ b/milvus_model/dense/sentence_transformer.py @@ -33,7 +33,11 @@ def __call__(self, texts: List[str]) -> List[np.array]: def _encode(self, texts: List[str]) -> List[np.array]: embs = self.model.encode( - texts, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True, + texts, + batch_size=self.batch_size, + show_progress_bar=False, + convert_to_numpy=True, + normalize_embeddings=self.normalize_embeddings ) return list(embs) diff --git a/milvus_model/sparse/bm25/bm25.py b/milvus_model/sparse/bm25/bm25.py index f073658..0b0a2ea 100644 --- a/milvus_model/sparse/bm25/bm25.py +++ b/milvus_model/sparse/bm25/bm25.py @@ -18,7 +18,7 @@ import logging import math from collections import defaultdict -from multiprocessing import Pool, cpu_count +from multiprocessing import Pool from pathlib import Path from typing import Dict, List, Optional @@ -44,7 +44,7 @@ def __init__( k1: float = 1.5, b: float = 0.75, epsilon: float = 0.25, - num_workers: Optional[int] = None, + num_workers: int = 1, ): if analyzer is None: analyzer = build_default_analyzer(language="en") @@ -55,8 +55,6 @@ def __init__( self.k1 = k1 self.b = b self.epsilon = epsilon - if num_workers is None: - self.num_workers = cpu_count() self.num_workers = num_workers if analyzer and corpus is not None: diff --git a/milvus_model/utils/__init__.py b/milvus_model/utils/__init__.py index ff6dde8..6587f9d 100644 --- a/milvus_model/utils/__init__.py +++ b/milvus_model/utils/__init__.py @@ -31,6 +31,7 @@ def import_sentence_transformers(): _check_library("sentence_transformers", package="sentence-transformers") def import_FlagEmbedding(): + _check_library("peft", package="peft") _check_library("FlagEmbedding", package="FlagEmbedding>=1.2.2") def import_nltk():