diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..8c09f1a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + metrics: For testing only the metrics. + models: for testing only the models. diff --git a/tests/metrics/_base.py b/tests/metrics/_base.py index a51a0c8..089abdc 100755 --- a/tests/metrics/_base.py +++ b/tests/metrics/_base.py @@ -6,6 +6,7 @@ from .fixtures import predictions, references +@pytest.mark.metrics class BaseMetricTest: _metric_cls = None _key = None diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/fixtures.py b/tests/models/fixtures.py new file mode 100755 index 0000000..869aba3 --- /dev/null +++ b/tests/models/fixtures.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +import json +from pathlib import Path + +import pytest + + +@pytest.fixture(autouse=True, scope="module") +def squad_v2_data(): + path = Path(__file__).parent.parent.joinpath("data/squad_v2.json") + data = {} + with open(path, "r") as f: + data = json.load(f) + return data + + +@pytest.fixture(autouse=True, scope="module") +def imdb_data(): + path = Path(__file__).parent.parent.joinpath("data/imdb.json") + data = {} + with open(path, "r") as f: + data = json.load(f) + return data + + +def main(): + pass + + +if __name__ == "__main__": + main() diff --git a/tests/models/test_classification.py b/tests/models/test_classification.py new file mode 100755 index 0000000..5451af7 --- /dev/null +++ b/tests/models/test_classification.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +from pprint import pprint +from typing import Iterable + +import pytest + +from evalem.models import TextClassificationHFPipelineWrapper +from evalem.structures import PredictionDTO + +from .fixtures import imdb_data + + +@pytest.mark.models +class TestDefaultTextClassificationWrapper: + @pytest.fixture(autouse=True, scope="class") + def inputs(self, imdb_data): + return imdb_data.get("inputs", []) + + @pytest.fixture(autouse=True, scope="class") + def model(self): + yield TextClassificationHFPipelineWrapper(hf_params=dict(truncation=True)) + + @pytest.fixture(autouse=True, scope="class") + def references(self, imdb_data): + return imdb_data.get("references", []) + + @pytest.fixture(autouse=True, scope="class") + def predictions(self, model, imdb_data): + return model(imdb_data["inputs"]) + + def test_predictions_format(self, predictions): + assert isinstance(predictions, Iterable) + assert isinstance(predictions[0], PredictionDTO) + + def test_predictions_len(self, predictions, references): + pprint(f"Predictions | {predictions}") + assert len(predictions) == len(references) + + +def main(): + pass + + +if __name__ == "__main__": + main() diff --git a/tests/models/test_qa.py b/tests/models/test_qa.py new file mode 100755 index 0000000..4e46aee --- /dev/null +++ b/tests/models/test_qa.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +from typing import Iterable + +import pytest + +from evalem.models import QuestionAnsweringHFPipelineWrapper +from evalem.structures import QAPredictionDTO + +from .fixtures import squad_v2_data + + +@pytest.mark.models +class TestDefaultQAWrapper: + @pytest.fixture(autouse=True, scope="class") + def inputs(self, squad_v2_data): + return squad_v2_data.get("inputs", []) + + @pytest.fixture(autouse=True, scope="class") + def model(self): + yield QuestionAnsweringHFPipelineWrapper( + # model="distilbert-base-cased-distilled-squad" + ) + + @pytest.fixture(autouse=True, scope="class") + def references(self, squad_v2_data): + return squad_v2_data.get("references", []) + + @pytest.fixture(autouse=True, scope="class") + def predictions(self, model, squad_v2_data): + return model(squad_v2_data["inputs"]) + + def test_predictions_format(self, predictions): + assert isinstance(predictions, Iterable) + assert isinstance(predictions[0], QAPredictionDTO) + + def test_predictions_len(self, predictions, references): + print(predictions) + assert len(predictions) == len(references) + + +def main(): + pass + + +if __name__ == "__main__": + main()