diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..0392032 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,9 @@ +name: Ruff +on: pull_request +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: "Linting & Flaking" + uses: chartboost/ruff-action@v1 diff --git a/byaldi/RAGModel.py b/byaldi/RAGModel.py index 58f99dd..0b87bd8 100644 --- a/byaldi/RAGModel.py +++ b/byaldi/RAGModel.py @@ -1,6 +1,5 @@ from pathlib import Path -from typing import Any, List, Optional, Union, Dict -from uuid import uuid4 +from typing import Dict, List, Optional, Union from PIL import Image @@ -165,4 +164,4 @@ def search( return self.model.search(query, k, return_base64_results) def get_doc_ids_to_file_names(self): - return self.model.get_doc_ids_to_file_names() \ No newline at end of file + return self.model.get_doc_ids_to_file_names() diff --git a/byaldi/__init__.py b/byaldi/__init__.py index 0f7a779..3b8a6cf 100644 --- a/byaldi/__init__.py +++ b/byaldi/__init__.py @@ -1,5 +1,6 @@ -from .RAGModel import RAGMultiModalModel from importlib.metadata import version +from .RAGModel import RAGMultiModalModel + __version__ = version("Byaldi") __all__ = ["RAGMultiModalModel"] diff --git a/byaldi/colpali.py b/byaldi/colpali.py index 183068f..cc08fef 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -1,24 +1,27 @@ import os -import srsly -import torch import shutil -from typing import Optional, Union, List, Dict + +# Import version directly from the package metadata +from importlib.metadata import version from pathlib import Path -from PIL import Image -from pdf2image import convert_from_path -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import AutoProcessor +from typing import Dict, List, Optional, Union + +import srsly +import torch from colpali_engine.models.paligemma_colbert_architecture import ColPali from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator from colpali_engine.utils.colpali_processing_utils import ( process_images, process_queries, ) +from pdf2image import convert_from_path +from PIL import Image +from transformers import AutoProcessor + from byaldi.objects import Result + from .utils import capture_print -# Import version directly from the package metadata -from importlib.metadata import version + VERSION = version("Byaldi") @@ -43,12 +46,20 @@ def __init__( ) if verbose > 0: - print(f"Verbosity is set to {verbose} ({'active' if verbose == 1 else 'loud'}). Pass verbose=0 to make quieter.") + print( + f"Verbosity is set to {verbose} ({'active' if verbose == 1 else 'loud'}). Pass verbose=0 to make quieter." + ) self.pretrained_model_name_or_path = pretrained_model_name_or_path self.model_name = self.pretrained_model_name_or_path self.n_gpu = torch.cuda.device_count() if n_gpu == -1 else n_gpu - device = device or "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + device = ( + device or "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" + ) self.index_name = index_name self.verbose = verbose self.load_from_index = load_from_index @@ -71,13 +82,11 @@ def __init__( # token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), # ) - # if verbose > 0: # print("Loading adapter...") # print("Adapter name: ", self.pretrained_model_name_or_path) # self.model.load_adapter(self.pretrained_model_name_or_path) - self.model = ColPali.from_pretrained( self.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, @@ -113,7 +122,8 @@ def __init__( if self.full_document_collection: collection_path = index_path / "collection" json_files = sorted( - collection_path.glob("*.json.gz"), key=lambda x: int(x.stem.split('.')[0]) + collection_path.glob("*.json.gz"), + key=lambda x: int(x.stem.split(".")[0]), ) for json_file in json_files: @@ -132,20 +142,35 @@ def __init__( ) embeddings_path = index_path / "embeddings" - embedding_files = sorted(embeddings_path.glob("embeddings_*.pt"), key=lambda x: int(x.stem.split('_')[1])) + embedding_files = sorted( + embeddings_path.glob("embeddings_*.pt"), + key=lambda x: int(x.stem.split("_")[1]), + ) self.indexed_embeddings = [] for file in embedding_files: self.indexed_embeddings.extend(torch.load(file)) - self.embed_id_to_doc_id = srsly.read_gzip_json(index_path / "embed_id_to_doc_id.json.gz") + self.embed_id_to_doc_id = srsly.read_gzip_json( + index_path / "embed_id_to_doc_id.json.gz" + ) # Restore keys to integers - self.embed_id_to_doc_id = {int(k): v for k, v in self.embed_id_to_doc_id.items()} - self.highest_doc_id = max(int(entry["doc_id"]) for entry in self.embed_id_to_doc_id.values()) - self.doc_ids = set(int(entry["doc_id"]) for entry in self.embed_id_to_doc_id.values()) + self.embed_id_to_doc_id = { + int(k): v for k, v in self.embed_id_to_doc_id.items() + } + self.highest_doc_id = max( + int(entry["doc_id"]) for entry in self.embed_id_to_doc_id.values() + ) + self.doc_ids = set( + int(entry["doc_id"]) for entry in self.embed_id_to_doc_id.values() + ) try: # We don't want this error out with indexes created prior to 0.0.2 - self.doc_ids_to_file_names = srsly.read_gzip_json(index_path / "doc_ids_to_file_names.json.gz") - self.doc_ids_to_file_names = {int(k): v for k, v in self.doc_ids_to_file_names.items()} + self.doc_ids_to_file_names = srsly.read_gzip_json( + index_path / "doc_ids_to_file_names.json.gz" + ) + self.doc_ids_to_file_names = { + int(k): v for k, v in self.doc_ids_to_file_names.items() + } except FileNotFoundError: pass @@ -154,7 +179,9 @@ def __init__( if metadata_path.exists(): self.doc_id_to_metadata = srsly.read_gzip_json(metadata_path) # Convert metadata keys to integers - self.doc_id_to_metadata = {int(k): v for k, v in self.doc_id_to_metadata.items()} + self.doc_id_to_metadata = { + int(k): v for k, v in self.doc_id_to_metadata.items() + } else: self.doc_id_to_metadata = {} @@ -188,7 +215,7 @@ def from_index( index_root: str = ".byaldi", **kwargs, ): - index_path = Path(index_root) / Path(index_path) + index_path = Path(index_root) / Path(index_path) index_config = srsly.read_gzip_json(index_path / "index_config.json.gz") instance = cls( @@ -217,7 +244,7 @@ def _export_index(self): num_embeddings = len(self.indexed_embeddings) chunk_size = 500 for i in range(0, num_embeddings, chunk_size): - chunk = self.indexed_embeddings[i:i+chunk_size] + chunk = self.indexed_embeddings[i : i + chunk_size] torch.save(chunk, embeddings_path / f"embeddings_{i}.pt") # Save index config @@ -225,7 +252,9 @@ def _export_index(self): "model_name": self.model_name, "full_document_collection": self.full_document_collection, "highest_doc_id": self.highest_doc_id, - "resize_stored_images": True if self.max_image_width and self.max_image_height else False, + "resize_stored_images": True + if self.max_image_width and self.max_image_height + else False, "max_image_width": self.max_image_width, "max_image_height": self.max_image_height, "library_version": VERSION, @@ -233,10 +262,14 @@ def _export_index(self): srsly.write_gzip_json(index_path / "index_config.json.gz", index_config) # Save embed_id_to_doc_id mapping - srsly.write_gzip_json(index_path / "embed_id_to_doc_id.json.gz", self.embed_id_to_doc_id) + srsly.write_gzip_json( + index_path / "embed_id_to_doc_id.json.gz", self.embed_id_to_doc_id + ) # Save doc_ids_to_file_names - srsly.write_gzip_json(index_path / "doc_ids_to_file_names.json.gz", self.doc_ids_to_file_names) + srsly.write_gzip_json( + index_path / "doc_ids_to_file_names.json.gz", self.doc_ids_to_file_names + ) # Save metadata srsly.write_gzip_json(index_path / "metadata.json.gz", self.doc_id_to_metadata) @@ -251,7 +284,7 @@ def _export_index(self): if self.verbose > 0: print(f"Index exported to {index_path}") - + def index( self, input_path: Union[str, Path], @@ -279,18 +312,22 @@ def index( raise ValueError("index_name must be specified to create a new index.") if store_collection_with_index: self.full_document_collection = True - + index_path = Path(self.index_root) / Path(index_name) if index_path.exists(): if overwrite is False: - raise ValueError(f"An index named {index_name} already exists.", - "Use overwrite=True to delete the existing index and build a new one.", - "Exiting indexing without doing anything...") + raise ValueError( + f"An index named {index_name} already exists.", + "Use overwrite=True to delete the existing index and build a new one.", + "Exiting indexing without doing anything...", + ) return None else: - print(f"overwrite is on. Deleting existing index {index_name} to build a new one.") + print( + f"overwrite is on. Deleting existing index {index_name} to build a new one." + ) shutil.rmtree(index_path) - + self.index_name = index_name self.max_image_width = max_image_width self.max_image_height = max_image_height @@ -313,19 +350,31 @@ def index( print(f"Indexing file: {item}") doc_id = doc_ids[i] if doc_ids else self.highest_doc_id + 1 doc_metadata = metadata[doc_id] if metadata else None - self.add_to_index(item, store_collection_with_index, doc_id=doc_id, metadata=doc_metadata) + self.add_to_index( + item, + store_collection_with_index, + doc_id=doc_id, + metadata=doc_metadata, + ) self.doc_ids_to_file_names[doc_id] = str(item) else: if metadata is not None and len(metadata) != 1: - raise ValueError("For a single document, metadata should be a list with one dictionary") + raise ValueError( + "For a single document, metadata should be a list with one dictionary" + ) doc_id = doc_ids[0] if doc_ids else self.highest_doc_id + 1 doc_metadata = metadata[0] if metadata else None - self.add_to_index(input_path, store_collection_with_index, doc_id=doc_id, metadata=doc_metadata) + self.add_to_index( + input_path, + store_collection_with_index, + doc_id=doc_id, + metadata=doc_metadata, + ) self.doc_ids_to_file_names[doc_id] = str(input_path) self._export_index() return self.doc_ids_to_file_names - + def add_to_index( self, input_item: Union[str, Path, Image.Image, List[Union[str, Path, Image.Image]]], @@ -334,22 +383,34 @@ def add_to_index( metadata: Optional[List[Dict[str, Union[str, int]]]] = None, ) -> Dict[int, str]: if self.index_name is None: - raise ValueError("No index loaded. Use index() to create or load an index first.") + raise ValueError( + "No index loaded. Use index() to create or load an index first." + ) if not hasattr(self, "highest_doc_id"): self.highest_doc_id = -1 # Convert single inputs to lists for uniform processing if isinstance(input_item, (str, Path)) and Path(input_item).is_dir(): input_items = list(Path(input_item).iterdir()) else: - input_items = [input_item] if not isinstance(input_item, list) else input_item - - doc_ids = [doc_id] if isinstance(doc_id, int) else (doc_id if doc_id is not None else None) + input_items = ( + [input_item] if not isinstance(input_item, list) else input_item + ) + + doc_ids = ( + [doc_id] + if isinstance(doc_id, int) + else (doc_id if doc_id is not None else None) + ) # Validate input lengths if doc_ids and len(doc_ids) != len(input_items): - raise ValueError(f"Number of doc_ids ({len(doc_ids)}) does not match number of input items ({len(input_items)})") + raise ValueError( + f"Number of doc_ids ({len(doc_ids)}) does not match number of input items ({len(input_items)})" + ) if metadata and len(metadata) != len(input_items): - raise ValueError(f"Number of metadata entries ({len(metadata)}) does not match number of input items ({len(input_items)})") + raise ValueError( + f"Number of metadata entries ({len(metadata)}) does not match number of input items ({len(input_items)})" + ) # Process each input item for i, item in enumerate(input_items): @@ -357,19 +418,33 @@ def add_to_index( current_metadata = metadata[i] if metadata else None if current_doc_id in self.doc_ids: - raise ValueError(f"Document ID {current_doc_id} already exists in the index") + raise ValueError( + f"Document ID {current_doc_id} already exists in the index" + ) self.highest_doc_id = max(self.highest_doc_id, current_doc_id) if isinstance(item, (str, Path)): item_path = Path(item) if item_path.is_dir(): - self._process_directory(item_path, store_collection_with_index, current_doc_id, current_metadata) + self._process_directory( + item_path, + store_collection_with_index, + current_doc_id, + current_metadata, + ) else: - self._process_and_add_to_index(item_path, store_collection_with_index, current_doc_id, current_metadata) + self._process_and_add_to_index( + item_path, + store_collection_with_index, + current_doc_id, + current_metadata, + ) self.doc_ids_to_file_names[current_doc_id] = str(item_path) elif isinstance(item, Image.Image): - self._process_and_add_to_index(item, store_collection_with_index, current_doc_id, current_metadata) + self._process_and_add_to_index( + item, store_collection_with_index, current_doc_id, current_metadata + ) self.doc_ids_to_file_names[current_doc_id] = "In-memory Image" else: raise ValueError(f"Unsupported input type: {type(item)}") @@ -377,11 +452,19 @@ def add_to_index( self._export_index() return self.doc_ids_to_file_names - def _process_directory(self, directory: Path, store_collection_with_index: bool, base_doc_id: int, metadata: Optional[Dict[str, Union[str, int]]]): + def _process_directory( + self, + directory: Path, + store_collection_with_index: bool, + base_doc_id: int, + metadata: Optional[Dict[str, Union[str, int]]], + ): for i, item in enumerate(directory.iterdir()): print(f"Indexing file: {item}") current_doc_id = base_doc_id + i - self._process_and_add_to_index(item, store_collection_with_index, current_doc_id, metadata) + self._process_and_add_to_index( + item, store_collection_with_index, current_doc_id, metadata + ) self.doc_ids_to_file_names[current_doc_id] = str(item) def _process_and_add_to_index( @@ -396,14 +479,24 @@ def _process_and_add_to_index( if item.suffix.lower() == ".pdf": images = convert_from_path(item) for i, image in enumerate(images): - self._add_to_index(image, store_collection_with_index, doc_id, page_id=i + 1, metadata=metadata) + self._add_to_index( + image, + store_collection_with_index, + doc_id, + page_id=i + 1, + metadata=metadata, + ) elif item.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp"]: image = Image.open(item) - self._add_to_index(image, store_collection_with_index, doc_id, metadata=metadata) + self._add_to_index( + image, store_collection_with_index, doc_id, metadata=metadata + ) else: raise ValueError(f"Unsupported input type: {item.suffix}") elif isinstance(item, Image.Image): - self._add_to_index(item, store_collection_with_index, doc_id, metadata=metadata) + self._add_to_index( + item, store_collection_with_index, doc_id, metadata=metadata + ) else: raise ValueError(f"Unsupported input type: {type(item)}") @@ -415,8 +508,13 @@ def _add_to_index( page_id: int = 1, metadata: Optional[Dict[str, Union[str, int]]] = None, ): - if any(entry["doc_id"] == doc_id and entry["page_id"] == page_id for entry in self.embed_id_to_doc_id.values()): - raise ValueError(f"Document ID {doc_id} with page ID {page_id} already exists in the index") + if any( + entry["doc_id"] == doc_id and entry["page_id"] == page_id + for entry in self.embed_id_to_doc_id.values() + ): + raise ValueError( + f"Document ID {doc_id} with page ID {page_id} already exists in the index" + ) processed_image = process_images(self.processor, [image]) @@ -431,7 +529,10 @@ def _add_to_index( self.embed_id_to_doc_id[embed_id] = {"doc_id": doc_id, "page_id": int(page_id)} # Update highest_doc_id - self.highest_doc_id = max(self.highest_doc_id, int(doc_id) if isinstance(doc_id, int) else self.highest_doc_id) + self.highest_doc_id = max( + self.highest_doc_id, + int(doc_id) if isinstance(doc_id, int) else self.highest_doc_id, + ) if store_collection_with_index: import base64 @@ -451,9 +552,11 @@ def _add_to_index( new_height = self.max_image_height new_width = int(new_height * aspect_ratio) if self.verbose > 2: - print(f"Resizing image to {new_width}x{new_height}" , - f"(aspect ratio {aspect_ratio:.2f}, original size {img_width}x{img_height}," - f"compression {new_width/img_width * new_height/img_height:.2f})") + print( + f"Resizing image to {new_width}x{new_height}", + f"(aspect ratio {aspect_ratio:.2f}, original size {img_width}x{img_height}," + f"compression {new_width/img_width * new_height/img_height:.2f})", + ) image = image.resize((new_width, new_height), Image.LANCZOS) buffered = io.BytesIO() @@ -531,12 +634,14 @@ def search( return results[0] if isinstance(query, str) else results - def encode_image(self, input_data: Union[str, Image.Image, List[Union[str, Image.Image]]]) -> torch.Tensor: + def encode_image( + self, input_data: Union[str, Image.Image, List[Union[str, Image.Image]]] + ) -> torch.Tensor: """ Compute embeddings for one or more images, PDFs, folders, or image files. Args: - input_data (Union[str, Image.Image, List[Union[str, Image.Image]]]): + input_data (Union[str, Image.Image, List[Union[str, Image.Image]]]): A single image, PDF path, folder path, image file path, or a list of these. Returns: @@ -553,13 +658,17 @@ def encode_image(self, input_data: Union[str, Image.Image, List[Union[str, Image if os.path.isdir(item): # Process folder for file in os.listdir(item): - if file.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')): + if file.lower().endswith( + (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif") + ): images.append(Image.open(os.path.join(item, file))) - elif item.lower().endswith('.pdf'): + elif item.lower().endswith(".pdf"): # Process PDF pdf_images = convert_from_path(item) images.extend(pdf_images) - elif item.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')): + elif item.lower().endswith( + (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif") + ): # Process image file images.append(Image.open(item)) else: @@ -579,7 +688,7 @@ def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor: Compute embeddings for one or more text queries. Args: - query (Union[str, List[str]]): + query (Union[str, List[str]]): A single text query or a list of text queries. Returns: diff --git a/byaldi/objects.py b/byaldi/objects.py index 37ca7d1..3e1f9cd 100644 --- a/byaldi/objects.py +++ b/byaldi/objects.py @@ -2,7 +2,14 @@ class Result: - def __init__(self, doc_id: str, page_num: int, score: float, metadata: Optional[dict] = None, base64: Optional[str] = None): + def __init__( + self, + doc_id: str, + page_num: int, + score: float, + metadata: Optional[dict] = None, + base64: Optional[str] = None, + ): self.doc_id = doc_id self.page_num = page_num self.score = score diff --git a/byaldi/utils.py b/byaldi/utils.py index 3d50cfb..52e37b0 100644 --- a/byaldi/utils.py +++ b/byaldi/utils.py @@ -1,5 +1,6 @@ -from io import StringIO import sys +from io import StringIO + def capture_print(func): def wrapper(*args, **kwargs): @@ -10,4 +11,5 @@ def wrapper(*args, **kwargs): finally: sys.stdout = original_stdout return result - return wrapper \ No newline at end of file + + return wrapper diff --git a/pyproject.toml b/pyproject.toml index 1ef5af0..b3b16b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,51 +1,52 @@ [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +requires = ["setuptools"] +build-backend = "setuptools.build_meta" [tool.setuptools] -packages = [ - "byaldi" -] +packages = ["byaldi"] [project] -name = "Byaldi" -version = "0.0.2post2" -description = "Use late-interaction multi-modal models such as ColPALI in just a few lines of code." +name = "Byaldi" +version = "0.0.4" +description = "Use late-interaction multi-modal models such as ColPali in just a few lines of code." readme = "README.md" -requires-python = ">=3.8" -license = {file = "LICENSE"} -keywords = ["reranking", "retrieval", "rag", "nlp", "colpali", "colbert", "multi-modal"] -authors = [ - {name = "Ben Clavié", email = "bc@answer.ai" } +requires-python = ">=3.9" +license = { file = "LICENSE" } +keywords = [ + "reranking", + "retrieval", + "rag", + "nlp", + "colpali", + "colbert", + "multi-modal", ] +authors = [{ name = "Ben Clavié", email = "bc@answer.ai" }] maintainers = [ - {name = "Ben Clavié", email = "bc@answer.ai" } + { name = "Ben Clavié", email = "bc@answer.ai" }, + { name = "Tony Wu", email = "tony.wu@illuin.tech" }, ] dependencies = [ -"transformers", -"torch", -"ninja", -"pdf2image", -"srsly", -"colpali-engine==0.2.2", -"mteb==1.6.35", + "colpali-engine==0.2.2", + "ml-dtypes", + "mteb==1.6.35", + "ninja", + "pdf2image", + "srsly", + "torch", + "transformers", ] - [project.optional-dependencies] -server = [ - "uvicorn", - "fastapi" -] +dev = ["pytest>=7.4.0", "ruff>=0.1.9"] +server = ["uvicorn", "fastapi"] [project.urls] "Homepage" = "https://github.com/answerdotai/byaldi" [tool.pytest.ini_options] -filterwarnings = [ - "ignore::Warning" -] +filterwarnings = ["ignore::Warning"] [tool.ruff] # Exclude a variety of commonly ignored directories. @@ -83,23 +84,17 @@ target-version = "py39" [tool.ruff.lint] select = [ - # bugbear rules - "B", - "I", - # remove unused imports - "F401", - # bare except statements - "E722", - # unused arguments - "ARG", -] -ignore = [ - "B006", - "B018", + # bugbear rules + "B", + "I", + # remove unused imports + "F401", + # bare except statements + "E722", + # unused arguments + "ARG", ] +ignore = ["B006", "B018"] -unfixable = [ - "T201", - "T203", -] +unfixable = ["T201", "T203"] ignore-init-module-imports = true diff --git a/tests/all.py b/tests/all.py index 4c6dcc6..6369eee 100644 --- a/tests/all.py +++ b/tests/all.py @@ -1,120 +1,137 @@ -import os -from pathlib import Path from byaldi import RAGMultiModalModel + def test_single_pdf(): print("Testing single PDF indexing and retrieval...") - + # Initialize the model model = RAGMultiModalModel.from_pretrained("vidore/colpali") - + # Index a single PDF model.index( input_path="docs/attention.pdf", index_name="attention_index", store_collection_with_index=True, - overwrite=True + overwrite=True, ) - + # Test retrieval queries = [ "How does the positional encoding thing work?", - "what's the BLEU score of this new strange method?" + "what's the BLEU score of this new strange method?", ] - + for query in queries: results = model.search(query, k=3) - + print(f"\nQuery: {query}") for result in results: - print(f"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}") - + print( + f"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}" + ) + # Check if the expected page (6 for positional encoding) is in the top results if "positional encoding" in query.lower(): - assert any(r.page_num == 6 for r in results), "Expected page 6 for positional encoding query" - + assert any( + r.page_num == 6 for r in results + ), "Expected page 6 for positional encoding query" + # Check if the expected pages (8 and 9 for BLEU score) are in the top results if "bleu score" in query.lower(): - assert any(r.page_num in [8, 9] for r in results), "Expected pages 8 or 9 for BLEU score query" - + assert any( + r.page_num in [8, 9] for r in results + ), "Expected pages 8 or 9 for BLEU score query" + print("Single PDF test completed.") + def test_multi_document(): print("\nTesting multi-document indexing and retrieval...") - + # Initialize the model model = RAGMultiModalModel.from_pretrained("vidore/colpali") - + # Index a directory of documents model.index( input_path="docs/", index_name="multi_doc_index", store_collection_with_index=True, - overwrite=True + overwrite=True, ) - + # Test retrieval queries = [ "How does the positional encoding thing work?", - "what's the BLEU score of this new strange method?" + "what's the BLEU score of this new strange method?", ] - + for query in queries: results = model.search(query, k=5) - + print(f"\nQuery: {query}") for result in results: - print(f"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}") - + print( + f"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}" + ) + # Check if the expected page (6 for positional encoding) is in the top results if "positional encoding" in query.lower(): - assert any(r.page_num == 6 for r in results), "Expected page 6 for positional encoding query" - + assert any( + r.page_num == 6 for r in results + ), "Expected page 6 for positional encoding query" + # Check if the expected pages (8 and 9 for BLEU score) are in the top results if "bleu score" in query.lower(): - assert any(r.page_num in [8, 9] for r in results), "Expected pages 8 or 9 for BLEU score query" - + assert any( + r.page_num in [8, 9] for r in results + ), "Expected pages 8 or 9 for BLEU score query" + print("Multi-document test completed.") + def test_add_to_index(): print("\nTesting adding to an existing index...") - + # Load the existing index model = RAGMultiModalModel.from_index("multi_doc_index") - + # Add a new document to the index model.add_to_index( input_item="docs/", store_collection_with_index=True, doc_id=[1002, 1003], - metadata=[{"author": "John Doe", "year": 2023}] * 2 + metadata=[{"author": "John Doe", "year": 2023}] * 2, ) - + # Test retrieval with the updated index - queries = [ - "what's the BLEU score of this new strange method?" - ] - + queries = ["what's the BLEU score of this new strange method?"] + for query in queries: results = model.search(query, k=3) - + print(f"\nQuery: {query}") for result in results: - print(f"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}") + print( + f"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}" + ) print(f"Metadata: {result.metadata}") - + # Check if the expected page (6 for positional encoding) is in the top results if "positional encoding" in query.lower(): - assert any(r.page_num == 6 for r in results), "Expected page 6 for positional encoding query" - + assert any( + r.page_num == 6 for r in results + ), "Expected page 6 for positional encoding query" + # Check if the expected pages (8 and 9 for BLEU score) are in the top results if "bleu score" in query.lower(): - assert any(r.page_num in [8, 9] for r in results), "Expected pages 8 or 9 for BLEU score query" - + assert any( + r.page_num in [8, 9] for r in results + ), "Expected pages 8 or 9 for BLEU score query" + print("Add to index test completed.") if __name__ == "__main__": test_single_pdf() test_multi_document() - test_add_to_index() \ No newline at end of file + test_add_to_index()