From a61a37cb2975f14cd3738957ed8b6630d9b69c9d Mon Sep 17 00:00:00 2001 From: Tony Wu <28306721+tonywu71@users.noreply.github.com> Date: Tue, 17 Sep 2024 09:40:23 +0200 Subject: [PATCH] feat: add guardrails and instructions in tests --- tests/all.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/all.py b/tests/all.py index 9873a46..85ad522 100644 --- a/tests/all.py +++ b/tests/all.py @@ -1,3 +1,5 @@ +from pathlib import Path + from colpali_engine.utils.torch_utils import get_torch_device from byaldi import RAGMultiModalModel @@ -5,6 +7,9 @@ device = get_torch_device("auto") print(f"Using device: {device}") +path_document_1 = Path("docs/attention.pdf") +path_document_2 = Path("docs/attention_copy.pdf") + def test_single_pdf(): print("Testing single PDF indexing and retrieval...") @@ -12,6 +17,11 @@ def test_single_pdf(): # Initialize the model model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device=device) + if not Path("docs/attention.pdf").is_file(): + raise FileNotFoundError( + f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_1}." + ) + # Index a single PDF model.index( input_path="docs/attention.pdf", @@ -56,6 +66,15 @@ def test_multi_document(): # Initialize the model model = RAGMultiModalModel.from_pretrained("vidore/colpali") + if not Path("docs/attention.pdf").is_file(): + raise FileNotFoundError( + f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_1}." + ) + if not Path("docs/attention_copy.pdf").is_file(): + raise FileNotFoundError( + f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_2}." 
+ ) + # Index a directory of documents model.index( input_path="docs/", @@ -137,6 +156,15 @@ def test_add_to_index(): if __name__ == "__main__": + print("Starting tests...") + + print("\n\n----------------- Single PDF test -----------------\n") test_single_pdf() + + print("\n\n----------------- Multi document test -----------------\n") test_multi_document() + + print("\n\n----------------- Add to index test -----------------\n") test_add_to_index() + + print("\nAll tests completed.")