From e6b25c3bdcca6c78494ad4dd411c1acec909860e Mon Sep 17 00:00:00 2001
From: Nish
Date: Mon, 20 Mar 2023 15:01:33 -0500
Subject: [PATCH] [alpha] Addition of test suites for QA and text classification model wrapper (#10)

* Add barebones unit test for jury format conversion for references

* Add more tests for jury format conversion

* Fix test function name

* Add misc datasets loader module

See `evalem.misc.datasets`. We now have a `datasets.get_squad_v2(...)` function.

* Add device parameter to DefaultQAWrapper

* Update __init__ files with respective module imports

* Add metrics for computing semantic similarity

Now we have `metrics.semantics.SemanticMetric`. There are two implementations for now:
- `metrics.semantics.BertScore`
- `metrics.semantics.BartScore`

* Add preprocessing and post-processing mechanism to ModelWrapper

We make use of two kwargs on any model wrapper:
- `inputs_preprocessor` (maps inputs to a specific format, defaults to identity)
- `predictions_postprocessor` (maps model outputs to a specific format, defaults to identity)

Also, `models.HFPipelineWrapperForQuestionAnswering` is created and
`models.DefaultQAModelWrapper` is deprecated.

* Update code documentation for HFPipelineWrapper docstring

* Implement HF pipeline wrapper/evaluator for text classification

See `models.defaults.TextClassificationHFPipelineWrapper`. Also improve the
construction of the HF pipeline object in the existing wrapper.
`evaluators.basics.TextClassificationEvaluator` is also added.

* Add per_instance_score flag to BertScore metric

This flag is used to return precision/recall/f1 scores per prediction instance.

* Default to bert-base-uncased for BertScore

* Bugfix tokenizer parameter in QuestionAnsweringHFPipelineWrapper and TextClassificationHFPipelineWrapper

Previously, the tokenizer was set to some default. That is incorrect: we want
the tokenizer to be the one the provided model was trained with. So
`tokenizer` now defaults to None.

* Add preliminary test cases for metrics

* Improve metric tests by caching metric result at class-level scope

* Ignore flake8 linting errors for all test cases

* Add initial basic test for default QA model wrapper

* Add default test suite for text classification

All the model test suites are marked as `models` and can be run with
`pytest -m models -v`.
---
 pytest.ini                          |  4 +++
 tests/metrics/_base.py              |  1 +
 tests/models/__init__.py            |  0
 tests/models/fixtures.py            | 32 ++++++++++++++++++++
 tests/models/test_classification.py | 46 ++++++++++++++++++++++++++++
 tests/models/test_qa.py             | 47 +++++++++++++++++++++++++++++
 6 files changed, 130 insertions(+)
 create mode 100644 pytest.ini
 create mode 100644 tests/models/__init__.py
 create mode 100755 tests/models/fixtures.py
 create mode 100755 tests/models/test_classification.py
 create mode 100755 tests/models/test_qa.py

diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..8c09f1a
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+    metrics: For testing only the metrics.
+    models: For testing only the models.
diff --git a/tests/metrics/_base.py b/tests/metrics/_base.py
index a51a0c8..089abdc 100755
--- a/tests/metrics/_base.py
+++ b/tests/metrics/_base.py
@@ -6,6 +6,7 @@
 from .fixtures import predictions, references
 
 
+@pytest.mark.metrics
 class BaseMetricTest:
     _metric_cls = None
     _key = None
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/fixtures.py b/tests/models/fixtures.py
new file mode 100755
index 0000000..869aba3
--- /dev/null
+++ b/tests/models/fixtures.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+
+import json
+from pathlib import Path
+
+import pytest
+
+
+@pytest.fixture(autouse=True, scope="module")
+def squad_v2_data():
+    path = Path(__file__).parent.parent.joinpath("data/squad_v2.json")
+    data = {}
+    with open(path, "r") as f:
+        data = json.load(f)
+    return data
+
+
+@pytest.fixture(autouse=True, scope="module")
+def imdb_data():
+    path = Path(__file__).parent.parent.joinpath("data/imdb.json")
+    data = {}
+    with open(path, "r") as f:
+        data = json.load(f)
+    return data
+
+
+def main():
+    pass
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/models/test_classification.py b/tests/models/test_classification.py
new file mode 100755
index 0000000..5451af7
--- /dev/null
+++ b/tests/models/test_classification.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+from pprint import pprint
+from typing import Iterable
+
+import pytest
+
+from evalem.models import TextClassificationHFPipelineWrapper
+from evalem.structures import PredictionDTO
+
+from .fixtures import imdb_data
+
+
+@pytest.mark.models
+class TestDefaultTextClassificationWrapper:
+    @pytest.fixture(autouse=True, scope="class")
+    def inputs(self, imdb_data):
+        return imdb_data.get("inputs", [])
+
+    @pytest.fixture(autouse=True, scope="class")
+    def model(self):
+        yield TextClassificationHFPipelineWrapper(hf_params=dict(truncation=True))
+
+    @pytest.fixture(autouse=True, scope="class")
+    def references(self, imdb_data):
+        return imdb_data.get("references", [])
+
+    @pytest.fixture(autouse=True, scope="class")
+    def predictions(self, model, imdb_data):
+        return model(imdb_data["inputs"])
+
+    def test_predictions_format(self, predictions):
+        assert isinstance(predictions, Iterable)
+        assert isinstance(predictions[0], PredictionDTO)
+
+    def test_predictions_len(self, predictions, references):
+        pprint(f"Predictions | {predictions}")
+        assert len(predictions) == len(references)
+
+
+def main():
+    pass
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/models/test_qa.py b/tests/models/test_qa.py
new file mode 100755
index 0000000..4e46aee
--- /dev/null
+++ b/tests/models/test_qa.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+
+from typing import Iterable
+
+import pytest
+
+from evalem.models import QuestionAnsweringHFPipelineWrapper
+from evalem.structures import QAPredictionDTO
+
+from .fixtures import squad_v2_data
+
+
+@pytest.mark.models
+class TestDefaultQAWrapper:
+    @pytest.fixture(autouse=True, scope="class")
+    def inputs(self, squad_v2_data):
+        return squad_v2_data.get("inputs", [])
+
+    @pytest.fixture(autouse=True, scope="class")
+    def model(self):
+        yield QuestionAnsweringHFPipelineWrapper(
+            # model="distilbert-base-cased-distilled-squad"
+        )
+
+    @pytest.fixture(autouse=True, scope="class")
+    def references(self, squad_v2_data):
+        return squad_v2_data.get("references", [])
+
+    @pytest.fixture(autouse=True, scope="class")
+    def predictions(self, model, squad_v2_data):
+        return model(squad_v2_data["inputs"])
+
+    def test_predictions_format(self, predictions):
+        assert isinstance(predictions, Iterable)
+        assert isinstance(predictions[0], QAPredictionDTO)
+
+    def test_predictions_len(self, predictions, references):
+        print(predictions)
+        assert len(predictions) == len(references)
+
+
+def main():
+    pass
+
+
+if __name__ == "__main__":
+    main()
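
For reviewers, the wrappers exercised by these tests are meant to be used roughly as sketched below. This is an illustrative sketch based on the commit message and the tests above, not code from the patch: the identity lambdas and the QA input dictionaries are placeholder assumptions.

```python
# Illustrative sketch only (not part of the patch); assumes the evalem API
# described in the commit message above.
from evalem.models import (
    QuestionAnsweringHFPipelineWrapper,
    TextClassificationHFPipelineWrapper,
)

# `tokenizer` now defaults to None, so the tokenizer bundled with the chosen
# model is used rather than a hard-coded default.
qa_model = QuestionAnsweringHFPipelineWrapper()

# Any wrapper accepts two optional kwargs, both defaulting to identity:
#   inputs_preprocessor       - maps raw inputs into the wrapper's format
#   predictions_postprocessor - maps model outputs into prediction DTOs
# (single-argument lambdas are an assumed signature)
classifier = TextClassificationHFPipelineWrapper(
    hf_params=dict(truncation=True),
    inputs_preprocessor=lambda inputs: inputs,
    predictions_postprocessor=lambda outputs: outputs,
)

# Input format for the QA wrapper is assumed here; the tests feed it
# squad_v2_data["inputs"] loaded from tests/data/squad_v2.json.
inputs = [
    {"question": "Who wrote Hamlet?", "context": "Hamlet was written by Shakespeare."},
]
predictions = qa_model(inputs)  # iterable of QAPredictionDTO, as the tests assert
```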
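The semantic metrics mentioned in the commit message could be exercised along similar lines. Only the class names, the bert-base-uncased default, and the `per_instance_score` flag come from the commit message; the constructor and call signatures below are assumptions.

```python
# Illustrative sketch only; exact constructor/call signatures are assumptions.
from evalem.metrics.semantics import BartScore, BertScore

# BertScore uses bert-base-uncased by default; per_instance_score=True returns
# precision/recall/f1 for each prediction instead of only aggregate scores.
bert_score = BertScore(per_instance_score=True)
bart_score = BartScore()

references = ["Shakespeare"]
predictions = ["William Shakespeare"]

print(bert_score(predictions=predictions, references=references))
print(bart_score(predictions=predictions, references=references))
```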