From f35dce6dbbb42510f5e323da60809ae18b46ad67 Mon Sep 17 00:00:00 2001 From: Thitipong Jantawee <97588968+thititj@users.noreply.github.com> Date: Tue, 19 Nov 2024 13:46:37 +0700 Subject: [PATCH] Fix: Correct parameter name in _encode method for BGEM3FlagModel (#45) # Fix: BGEM3EmbeddingFunction encode() parameter mismatch ## Issue When using the `BGEM3EmbeddingFunction` class, users encounter the following error: ``` M3Embedder.encode() missing 1 required positional argument: 'queries' ``` This occurs because the `_encode()` method is using `sentences` as the parameter name when calling `self.model.encode()`, but the underlying `BGEM3FlagModel` expects the parameter to be named `queries`. ## Solution Changed the parameter name in the `_encode()` method from `sentences` to `queries` to match the expected parameter name of the underlying model: ```python # Before output = self.model.encode(sentences=texts, **self._encode_config) # After output = self.model.encode(queries=texts, **self._encode_config) ``` ## Testing - Tested the fix by creating an instance of `BGEM3EmbeddingFunction` and successfully encoding both documents and queries - Verified that all return types (dense, sparse, colbert_vecs) work as expected - Confirmed the error no longer occurs - Created a demonstration notebook showing the working implementation: [Google Colab Notebook](https://colab.research.google.com/drive/198XBcwA5RWCjvUsqOvTPku2Qk73XWcl7?usp=sharing) ## Additional Context This fix aligns with the [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) implementation which uses `queries` as the parameter name in its encode method. ## Related Issues Resolves the error reported in various contexts where users attempt to use the BGE-M3 model with Milvus. Co-authored-by: flukethitipong --- milvus_model/hybrid/bge_m3.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/milvus_model/hybrid/bge_m3.py b/milvus_model/hybrid/bge_m3.py index 6d4f6d1..b59b688 100644 --- a/milvus_model/hybrid/bge_m3.py +++ b/milvus_model/hybrid/bge_m3.py @@ -80,7 +80,8 @@ def dim(self) -> Dict: } def _encode(self, texts: List[str]) -> Dict: - output = self.model.encode(sentences=texts, **self._encode_config) + # Change 'sentences' to 'queries' to match the expected parameter + output = self.model.encode(queries=texts, **self._encode_config) results = {} if self._encode_config["return_dense"] is True: results["dense"] = list(output["dense_vecs"]) @@ -93,11 +94,12 @@ def _encode(self, texts: List[str]) -> Dict: row_indices = [0] * len(indices) csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim)) results["sparse"].append(csr) - results["sparse"] = stack_sparse_embeddings(results["sparse"]).tocsr() + results["sparse"] = stack_sparse_embeddings(results["sparse"]).tocsr() if self._encode_config["return_colbert_vecs"] is True: results["colbert_vecs"] = output["colbert_vecs"] return results + def encode_queries(self, queries: List[str]) -> Dict: return self._encode(queries)