Skip to content

Commit

Permalink
Fix: Correct parameter name in _encode method for BGEM3FlagModel (#45)
Browse files Browse the repository at this point in the history
# Fix: BGEM3EmbeddingFunction encode() parameter mismatch

## Issue
When using the `BGEM3EmbeddingFunction` class, users encounter the
following error:
```
M3Embedder.encode() missing 1 required positional argument: 'queries'
```

This occurs because the `_encode()` method is using `sentences` as the
parameter name when calling `self.model.encode()`, but the underlying
`BGEM3FlagModel` expects the parameter to be named `queries`.

## Solution
Changed the parameter name in the `_encode()` method from `sentences` to
`queries` to match the expected parameter name of the underlying model:

```python
# Before
output = self.model.encode(sentences=texts, **self._encode_config)

# After
output = self.model.encode(queries=texts, **self._encode_config)
```

## Testing
- Tested the fix by creating an instance of `BGEM3EmbeddingFunction` and
successfully encoding both documents and queries
- Verified that all return types (dense, sparse, colbert_vecs) work as
expected
- Confirmed the error no longer occurs
- Created a demonstration notebook showing the working implementation:
[Google Colab
Notebook](https://colab.research.google.com/drive/198XBcwA5RWCjvUsqOvTPku2Qk73XWcl7?usp=sharing)

## Additional Context
This fix aligns with the
[FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding)
implementation which uses `queries` as the parameter name in its encode
method.

## Related Issues
Resolves the error reported in various contexts where users attempt to
use the BGE-M3 model with Milvus.

Co-authored-by: flukethitipong <[email protected]>
  • Loading branch information
thititj and flukethitipong authored Nov 19, 2024
1 parent 0e05e5e commit f35dce6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions milvus_model/hybrid/bge_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def dim(self) -> Dict:
}

def _encode(self, texts: List[str]) -> Dict:
output = self.model.encode(sentences=texts, **self._encode_config)
# Change 'sentences' to 'queries' to match the expected parameter
output = self.model.encode(queries=texts, **self._encode_config)
results = {}
if self._encode_config["return_dense"] is True:
results["dense"] = list(output["dense_vecs"])
Expand All @@ -93,11 +94,12 @@ def _encode(self, texts: List[str]) -> Dict:
row_indices = [0] * len(indices)
csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
results["sparse"].append(csr)
results["sparse"] = stack_sparse_embeddings(results["sparse"]).tocsr()
results["sparse"] = stack_sparse_embeddings(results["sparse"]).tocsr()
if self._encode_config["return_colbert_vecs"] is True:
results["colbert_vecs"] = output["colbert_vecs"]
return results


def encode_queries(self, queries: List[str]) -> Dict:
return self._encode(queries)

Expand Down

0 comments on commit f35dce6

Please sign in to comment.