Skip to content

Commit

Permalink
Rename RerankerFunction to RerankFunction and optimize splade code. (#10
Browse files Browse the repository at this point in the history
)

Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored Apr 1, 2024
1 parent 56a379b commit 4446a2b
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 157 deletions.
3 changes: 2 additions & 1 deletion _version_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def custom_version(version: ScmVersion) -> str:
# We're in a release/maintenance branch, next is a patch/rc/beta bump:
return version.format_next_version(guess_next_version, fmt=fmt)
# We're in a development branch, next is a minor bump:
return version.format_next_version(guess_next_simple_semver, retain=SEMVER_MINOR, fmt=fmt)
#return version.format_next_version(guess_next_simple_semver, retain=SEMVER_MINOR, fmt=fmt)
return version.format_next_version(guess_next_version, fmt=fmt)


def scm_version() -> str:
Expand Down
2 changes: 1 addition & 1 deletion milvus_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def encode_queries(self, queries: List[str]):
""" """


class BaseRerankerFunction:
class BaseRerankFunction:
@abstractmethod
def __call__(self, query: str, documents: List[str], top_k: int):
""" """
Expand Down
16 changes: 8 additions & 8 deletions milvus_model/reranker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .cohere import CohereRerankerFunction
from .cross_encoder import CrossEncoderRerankerFunction
from .flagreranker import FlagRerankerFunction
from .voyageai import VoyageRerankerFunction
from .bgereranker import BGERerankFunction
from .cohere import CohereRerankFunction
from .cross_encoder import CrossEncoderRerankFunction
from .voyageai import VoyageRerankFunction

__all__ = [
"CohereRerankerFunction",
"FlagRerankerFunction",
"VoyageRerankerFunction",
"CrossEncoderRerankerFunction",
"CohereRerankFunction",
"BGERerankFunction",
"VoyageRerankFunction",
"CrossEncoderRerankFunction",
]
4 changes: 2 additions & 2 deletions milvus_model/reranker/cohere.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import List, Optional

from milvus_model.base import BaseRerankerFunction, RerankResult
from milvus_model.base import BaseRerankFunction, RerankResult

try:
import cohere
except ImportError:
cohere = None


class CohereRerankerFunction(BaseRerankerFunction):
class CohereRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "rerank-english-v2.0", api_key: Optional[str] = None):
if cohere is None:
error_message = "cohere is not installed."
Expand Down
4 changes: 2 additions & 2 deletions milvus_model/reranker/cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Any, List

from milvus_model.base import BaseRerankerFunction, RerankResult
from milvus_model.base import BaseRerankFunction, RerankResult

try:
import sentence_transformers
except ImportError:
sentence_transformers = None


class CrossEncoderRerankerFunction(BaseRerankerFunction):
class CrossEncoderRerankFunction(BaseRerankFunction):
def __init__(
self,
model_name: str = "",
Expand Down
110 changes: 0 additions & 110 deletions milvus_model/reranker/flagreranker.py

This file was deleted.

4 changes: 2 additions & 2 deletions milvus_model/reranker/voyageai.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import List, Optional

from milvus_model.base import BaseRerankerFunction, RerankResult
from milvus_model.base import BaseRerankFunction, RerankResult

try:
import voyageai
except ImportError:
voyageai = None


class VoyageRerankerFunction(BaseRerankerFunction):
class VoyageRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "rerank-lite-1", api_key: Optional[str] = None):
if voyageai is None:
error_message = "voyageai is not installed."
Expand Down
58 changes: 28 additions & 30 deletions milvus_model/sparse/splade.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import logging
from typing import Dict, List, Optional

import numpy as np
import torch
from scipy.sparse import csr_array, vstack

Expand Down Expand Up @@ -145,17 +144,14 @@ def forward(self, texts: List[str], k_tokens: int) -> csr_array:
logits = self._encode(texts=batch_texts)
activations = self._get_activation(logits=logits)
if k_tokens is None:
nonzero_indices = [
torch.nonzero(activations["sparse_activations"][i]).t()[0]
for i in range(len(batch_texts))
]
nonzero_indices = torch.nonzero(activations["sparse_activations"])
activations["activations"] = nonzero_indices
else:
activations = self._update_activations(**activations, k_tokens=k_tokens)
batch_csr = self._convert_to_csr_array(activations)
sparse_embs.extend(batch_csr)

return vstack(sparse_embs)
return vstack(sparse_embs).tocsr()

def _get_activation(self, logits: torch.Tensor) -> Dict[str, torch.Tensor]:
return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}
Expand All @@ -170,6 +166,16 @@ def _update_activations(self, sparse_activations: torch.Tensor, k_tokens: int) -
device=self.device,
).scatter_(dim=1, index=activations.long(), value=1)

activations = torch.cat(
(
torch.arange(activations.shape[0], device=activations.device)
.repeat_interleave(activations.shape[1])
.reshape(-1, 1),
activations.reshape((-1, 1)),
),
dim=1,
)

return {
"activations": activations,
"sparse_activations": sparse_activations,
Expand All @@ -182,27 +188,19 @@ def _filter_activations(
return activations

def _convert_to_csr_array(self, activations: Dict):
csr_array_list = []

if activations["sparse_activations"].shape[0] != len(activations["activations"]):
error_msg = (
"The shape of 'sparse_activations' does not match the length of 'activations'"
)
raise ValueError(error_msg)

for i, column_indices in enumerate(activations["activations"]):
values = (
torch.gather(activations["sparse_activations"][i], 0, column_indices)
.cpu()
.detach()
.numpy()
)
row_indices = np.zeros(len(activations["activations"][i]))
col_indices = activations["activations"][i].cpu().detach().numpy()
csr_array_list.append(
csr_array(
(values.flatten(), (row_indices, col_indices)),
shape=(1, activations["sparse_activations"].shape[1]),
)
)
return csr_array_list

values = (
activations["sparse_activations"][
activations["activations"][:, 0], activations["activations"][:, 1]
]
.cpu()
.detach()
.numpy()
)

row_indices = activations["activations"][:, 0].cpu().detach().numpy()
col_indices = activations["activations"][:, 1].cpu().detach().numpy()
return csr_array(
(values.flatten(), (row_indices, col_indices)),
shape=activations["sparse_activations"].shape,
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"sentence-transformers",
"FlagEmbedding >= 1.2.2",
"nltk",
"transformers >= 4.33.0",
"transformers >= 4.36.0",
"jieba",
"konlpy",
"mecab-python3",
Expand Down

0 comments on commit 4446a2b

Please sign in to comment.