From 87bfae15bafafe0a83b517cc2b29980ca2a8765c Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Wed, 18 Dec 2024 17:58:08 +0100 Subject: [PATCH] wip: design draft --- fastembed/image/image_embedding.py | 3 +- .../late_interaction_multimodal/colpali.py | 266 ++++++++++++++++++ .../late_interaction_multimodal_embedding.py | 123 ++++++++ ...e_interaction_multimodal_embedding_base.py | 65 +++++ .../onnx_multimodal_model.py | 237 ++++++++++++++++ 5 files changed, 692 insertions(+), 2 deletions(-) create mode 100644 fastembed/late_interaction_multimodal/colpali.py create mode 100644 fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py create mode 100644 fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py create mode 100644 fastembed/late_interaction_multimodal/onnx_multimodal_model.py diff --git a/fastembed/image/image_embedding.py b/fastembed/image/image_embedding.py index aa4c91b4..23d39a3e 100644 --- a/fastembed/image/image_embedding.py +++ b/fastembed/image/image_embedding.py @@ -80,8 +80,7 @@ def embed( **kwargs, ) -> Iterable[np.ndarray]: """ - Encode a list of documents into list of embeddings. - We use mean pooling with attention so that the model can handle variable-length inputs. + Encode a list of images into list of embeddings. Args: images: Iterator of image paths or single image path to embed diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py new file mode 100644 index 00000000..d3508194 --- /dev/null +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -0,0 +1,266 @@ +from typing import Any, Iterable, Optional, Sequence, Type, Union + +import numpy as np +from tokenizers import Encoding + +from fastembed.common import OnnxProvider, ImageInput +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.utils import define_cache_dir +from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( + LateInteractionMultimodalEmbeddingBase, +) +from fastembed.late_interaction_multimodal.onnx_multimodal_model import ( + OnnxMultimodalModel, + TextEmbeddingWorker, + ImageEmbeddingWorker, +) + + +supported_colbert_models = [ + { + "model": "colpali", + "dim": ..., + "description": "Late interaction model", + "license": "mit", + "size_in_GB": 6.06, + "sources": { + "hf": "colpali", + }, + "model_file": "model.onnx", + }, +] + + +class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[np.ndarray]): + DOCUMENT_MARKER_TOKEN_ID = 2 + QUERY_PREFIX = "Query: " + BOS_TOKEN = "" + PAD_TOKEN = "" + QUERY_MARKER_TOKEN_ID = [2, 9413] + IMAGE_PLACEHOLDER_SIZE = (3, 448, 448) + EMPTY_TEXT_PLACEHOLDER = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) + EVEN_ATTENTION_MASK = np.array([1] * 1030) + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + device_id: Optional[int] = None, + **kwargs, + ): + """ + Args: + model_name (str): The name of the model to use. + cache_dir (str, optional): The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. 
+ providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use. + Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None. + cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers` + Defaults to False. + device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in + workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None. + lazy_load (bool, optional): Whether to load the model during class initialization or on demand. + Should be set to True when using multiple-gpu and parallel encoding. Defaults to False. + device_id (Optional[int], optional): The device id to use for loading the model in the worker process. + + Raises: + ValueError: If the model_name is not in the format / e.g. BAAI/bge-base-en. + """ + + super().__init__(model_name, cache_dir, threads, **kwargs) + self.providers = providers + self.lazy_load = lazy_load + + # List of device ids, that can be used for data parallel processing in workers + self.device_ids = device_ids + self.cuda = cuda + + # This device_id will be used if we need to load model in current process + if device_id is not None: + self.device_id = device_id + elif self.device_ids is not None: + self.device_id = self.device_ids[0] + else: + self.device_id = None + + self.model_description = self._get_model_description(model_name) + self.cache_dir = define_cache_dir(cache_dir) + + self._model_dir = self.download_model( + self.model_description, self.cache_dir, local_files_only=self._local_files_only + ) + self.mask_token_id = None + self.pad_token_id = None + self.skip_list = set() + + if not self.lazy_load: + self.load_onnx_model() + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_colbert_models + + def load_onnx_model(self) -> None: + self._load_onnx_model( + model_dir=self._model_dir, + model_file=self.model_description["model_file"], + threads=self.threads, + providers=self.providers, + cuda=self.cuda, + device_id=self.device_id, + ) + + def _post_process_onnx_image_output( + self, + output: OnnxOutputContext, + ) -> Iterable[np.ndarray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[np.ndarray]: Post-processed output as NumPy arrays. + """ + return output.model_output.astype(np.float32) + + def _post_process_onnx_text_output( + self, + output: OnnxOutputContext, + ) -> Iterable[np.ndarray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[np.ndarray]: Post-processed output as NumPy arrays. 
+ """ + return output.model_output.astype(np.float32) + + def tokenize(self, documents: list[str], **_) -> list[Encoding]: + texts_query: list[str] = [] + + for query in documents: + query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 + query += "\n" + + texts_query.append(query) + encoded = self.tokenizer.encode_batch(documents) + return encoded + + def _preprocess_onnx_text_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + onnx_input["input_ids"] = np.array( + [self.QUERY_MARKER_TOKEN_ID + input_ids[2:] for input_ids in onnx_input["input_ids"]] + ) + return onnx_input + + def embed_text( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self._embed_documents( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + documents=documents, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + + def embed_images( + self, + images: ImageInput, + batch_size: int = 16, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of images into list of embeddings. + + Args: + images: Iterator of image paths or single image path to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. 
+ + Returns: + List of embeddings, one per document + """ + yield from self._embed_images( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + images=images, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + + @classmethod + def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker]: + return ColPaliTextEmbeddingWorker + + @classmethod + def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker]: + return ColPaliImageEmbeddingWorker + + +class ColPaliTextEmbeddingWorker(TextEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: + return ColPali( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) + + +class ColPaliImageEmbeddingWorker(ImageEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: + return ColPali( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py new file mode 100644 index 00000000..3d35c52f --- /dev/null +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -0,0 +1,123 @@ +from typing import Any, Iterable, Optional, Sequence, Type, Union + +import numpy as np + +from fastembed.common import OnnxProvider, ImageInput +from fastembed.late_interaction_multimodal.colpali import ColPali + +from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( + LateInteractionMultimodalEmbeddingBase, +) + + +class LateInteractionMultimodalEmbedding(LateInteractionMultimodalEmbeddingBase): + EMBEDDINGS_REGISTRY: list[Type[LateInteractionMultimodalEmbeddingBase]] = [ColPali] + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """ + Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + + Example: + ``` + [ + { + "model": "colpali", + "dim": ..., + "description": "Late interaction model", + "license": "mit", + "size_in_GB": 6.06, + "sources": { + "hf": "colpali", + }, + "model_file": "model.onnx", + }, + ] + ``` + """ + result = [] + for embedding in cls.EMBEDDINGS_REGISTRY: + result.extend(embedding.list_supported_models()) + return result + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + **kwargs, + ): + super().__init__(model_name, cache_dir, threads, **kwargs) + for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: + supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() + if any(model_name.lower() == model["model"].lower() for model in supported_models): + self.model = EMBEDDING_MODEL_TYPE( + model_name, + cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + **kwargs, + ) + return + + raise ValueError( + f"Model {model_name} is not supported in LateInteractionMultimodalEmbedding." 
+ "Please check the supported models using `LateInteractionMultimodalEmbedding.list_supported_models()`" + ) + + def embed_text( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self.model.embed_text(documents, batch_size, parallel, **kwargs) + + def embed_image( + self, + images: ImageInput, + batch_size: int = 16, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + We use mean pooling with attention so that the model can handle variable-length inputs. + + Args: + images: Iterator of image paths or single image path to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self.model.embed_image(images, batch_size, parallel, **kwargs) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py new file mode 100644 index 00000000..cc1a929b --- /dev/null +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -0,0 +1,65 @@ +from typing import Iterable, Optional, Union + +import numpy as np + +from fastembed.common import ImageInput +from fastembed.common.model_management import ModelManagement + + +class LateInteractionMultimodalEmbeddingBase(ModelManagement): + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs, + ): + self.model_name = model_name + self.cache_dir = cache_dir + self.threads = threads + self._local_files_only = kwargs.pop("local_files_only", False) + + def embed_text( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Embeds a list of documents into a list of embeddings. + + Args: + documents (Iterable[str]): The list of texts to embed. + batch_size (int) - ... + parallel (Optional[int]) - ... + **kwargs: Additional keyword argument to pass to the embed method. + + Yields: + Iterable[np.ndarray]: The embeddings. + """ + raise NotImplementedError() + + def embed_image( + self, + images: ImageInput, + batch_size: int = 16, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + We use mean pooling with attention so that the model can handle variable-length inputs. 
+ + Args: + images: Iterator of image paths or single image path to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + raise NotImplementedError() diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py new file mode 100644 index 00000000..0557a92e --- /dev/null +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -0,0 +1,237 @@ +import contextlib +import os +from multiprocessing import get_all_start_methods +from pathlib import Path +from typing import Any, Iterable, Optional, Sequence, Type, Union + +import numpy as np +from PIL import Image +from tokenizers import Encoding + +from fastembed.common import OnnxProvider, ImageInput +from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T +from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor +from fastembed.common.utils import iter_batch +from fastembed.parallel_processor import ParallelWorkerPool + + +class OnnxMultimodalModel(OnnxModel[T]): + ONNX_OUTPUT_NAMES: Optional[list[str]] = None + + def __init__(self) -> None: + super().__init__() + self.tokenizer = None + self.processor = None + self.special_token_to_id = {} + + def _preprocess_onnx_text_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + """ + Preprocess the onnx input. + """ + return onnx_input + + def _preprocess_onnx_image_input( + self, onnx_input: dict[str, np.ndarray], **kwargs + ) -> dict[str, np.ndarray]: + """ + Preprocess the onnx input. 
+ """ + return onnx_input + + @classmethod + def _get_text_worker_class(cls) -> Type["TextEmbeddingWorker"]: + raise NotImplementedError("Subclasses must implement this method") + + @classmethod + def _get_image_worker_class(cls) -> Type["ImageEmbeddingWorker"]: + raise NotImplementedError("Subclasses must implement this method") + + def _post_process_onnx_image_output(self, output: OnnxOutputContext) -> Iterable[T]: + raise NotImplementedError("Subclasses must implement this method") + + def _post_process_onnx_text_output(self, output: OnnxOutputContext) -> Iterable[T]: + raise NotImplementedError("Subclasses must implement this method") + + def _load_onnx_model( + self, + model_dir: Path, + model_file: str, + threads: Optional[int], + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_id: Optional[int] = None, + ) -> None: + super()._load_onnx_model( + model_dir=model_dir, + model_file=model_file, + threads=threads, + providers=providers, + cuda=cuda, + device_id=device_id, + ) + self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) + self.processor = load_preprocessor(model_dir=model_dir) + + def load_onnx_model(self) -> None: + raise NotImplementedError("Subclasses must implement this method") + + def tokenize(self, documents: list[str], **kwargs) -> list[Encoding]: + return self.tokenizer.encode_batch(documents) + + def onnx_embed_text( + self, + documents: list[str], + **kwargs, + ) -> OnnxOutputContext: + encoded = self.tokenize(documents, **kwargs) + input_ids = np.array([e.ids for e in encoded]) + attention_mask = np.array([e.attention_mask for e in encoded]) + input_names = {node.name for node in self.model.get_inputs()} + onnx_input = { + "input_ids": np.array(input_ids, dtype=np.int64), + } + if "attention_mask" in input_names: + onnx_input["attention_mask"] = np.array(attention_mask, dtype=np.int64) + if "token_type_ids" in input_names: + onnx_input["token_type_ids"] = np.array( + [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 + ) + + onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs) + + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) + return OnnxOutputContext( + model_output=model_output[0], + attention_mask=onnx_input.get("attention_mask", attention_mask), + input_ids=onnx_input.get("input_ids", input_ids), + ) + + def _embed_documents( + self, + model_name: str, + cache_dir: str, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + **kwargs, + ) -> Iterable[T]: + is_small = False + + if isinstance(documents, str): + documents = [documents] + is_small = True + + if isinstance(documents, list): + if len(documents) < batch_size: + is_small = True + + if parallel is None or is_small: + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + for batch in iter_batch(documents, batch_size): + yield from self._post_process_onnx_text_output(self.onnx_embed_text(batch)) + else: + if parallel == 0: + parallel = os.cpu_count() + + start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" + params = { + "model_name": model_name, + "cache_dir": cache_dir, + "providers": providers, + **kwargs, + } + + pool = ParallelWorkerPool( + num_workers=parallel or 1, + worker=self._get_text_worker_class(), + cuda=cuda, + device_ids=device_ids, + start_method=start_method, 
+ ) + for batch in pool.ordered_map(iter_batch(documents, batch_size), **params): + yield from self._post_process_onnx_text_output(batch) + + def _build_onnx_image_input(self, encoded: np.ndarray) -> dict[str, np.ndarray]: + return {node.name: encoded for node in self.model.get_inputs()} + + def onnx_embed_image(self, images: list[ImageInput], **kwargs) -> OnnxOutputContext: + with contextlib.ExitStack(): + image_files = [ + Image.open(image) if not isinstance(image, Image.Image) else image + for image in images + ] + encoded = self.processor(image_files) + onnx_input = self._build_onnx_image_input(encoded) + onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs) + model_output = self.model.run(None, onnx_input) + embeddings = model_output[0].reshape(len(images), -1) + return OnnxOutputContext(model_output=embeddings) + + def _embed_images( + self, + model_name: str, + cache_dir: str, + images: ImageInput, + batch_size: int = 256, + parallel: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + **kwargs, + ) -> Iterable[T]: + is_small = False + + if isinstance(images, (str, Path, Image.Image)): + images = [images] + is_small = True + + if isinstance(images, list) and len(images) < batch_size: + is_small = True + + if parallel is None or is_small: + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for batch in iter_batch(images, batch_size): + yield from self._post_process_onnx_image_output(self.onnx_embed_image(batch)) + else: + if parallel == 0: + parallel = os.cpu_count() + + start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" + params = { + "model_name": model_name, + "cache_dir": cache_dir, + "providers": providers, + **kwargs, + } + + pool = ParallelWorkerPool( + num_workers=parallel or 1, + worker=self._get_image_worker_class(), + cuda=cuda, + device_ids=device_ids, + start_method=start_method, + ) + for batch in pool.ordered_map(iter_batch(images, batch_size), **params): + yield from self._post_process_onnx_image_output(batch) + + +class TextEmbeddingWorker(EmbeddingWorker): + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: + for idx, batch in items: + onnx_output = self.model.onnx_embed_text(batch) + yield idx, onnx_output + + +class ImageEmbeddingWorker(EmbeddingWorker): + def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: + for idx, batch in items: + embeddings = self.model.onnx_embed_image(batch) + yield idx, embeddings
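Below is a minimal usage sketch for the API drafted in this patch; it is not part of the diff itself. Assumptions: the model name "colpali" is the placeholder entry from `supported_colbert_models`, the import path mirrors the new module layout, each embedding comes back as a 2-D `(num_tokens, dim)` array, and the scoring helper is the standard ColBERT/ColPali-style MaxSim late-interaction comparison rather than anything defined in this diff. The sketch calls the wrapper's `embed_text`/`embed_image` methods (note that the ColPali class itself currently names its image method `embed_images`).

```python
# Illustrative sketch only -- assumes the draft API above lands as written.
import numpy as np

from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding import (
    LateInteractionMultimodalEmbedding,
)

# "colpali" is the placeholder model name from supported_colbert_models.
model = LateInteractionMultimodalEmbedding("colpali")

# Both methods yield one multi-vector embedding per input: assumed here to be
# a (num_query_tokens, dim) array per query and a (num_patches, dim) array per image.
query_embeddings = list(model.embed_text(["What is shown in the figure?"]))
image_embeddings = list(model.embed_image(["page_screenshot.png"]))  # hypothetical file


def late_interaction_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """ColBERT/ColPali-style MaxSim: match each query token to its most similar
    image patch token and sum the maxima."""
    similarity = query_emb @ doc_emb.T          # (num_query_tokens, num_patches)
    return float(similarity.max(axis=1).sum())  # best match per query token, summed


scores = [
    late_interaction_score(q, d) for q in query_embeddings for d in image_embeddings
]
```

Keeping text and image encoding as separate generators mirrors the existing fastembed text/image embedding classes, while the per-token outputs leave the late-interaction scoring step to the caller (or to a downstream vector store), as sketched above.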