[alpha] Addition of test suites for QA and text classification model wrapper (#10)

* Add barebones unit test for jury format conversion for references

* Add more tests for jury format conversion

* Fix test function name

* Add misc datasets loader module

See `evalem.misc.datasets`.
We have a `datasets.get_squad_v2(...)` function.
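
A minimal usage sketch (the loader name is from this commit; the call
signature and the returned "inputs"/"references" structure are assumptions
based on the test fixtures below):

# Hypothetical sketch; the return shape is an assumption.
from evalem.misc import datasets

data = datasets.get_squad_v2()
inputs, references = data["inputs"], data["references"]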

* Add device parameter to DefaultQAWrapper

* Update __init__ files with respective module imports

* Add metrics for computing semantic similarity

Now we have `metrics.semantics.SemanticMetric`.
There are 2 implementations for now:
- `metrics.semantics.BertScore`
- `metrics.semantics.BartScore`
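
A minimal sketch of invoking these metrics (only the class paths are from
this commit; the callable interface is an assumption):

# Hypothetical sketch; the call signature is an assumption.
from evalem.metrics.semantics import BartScore, BertScore

for metric_cls in (BertScore, BartScore):
    metric = metric_cls()
    result = metric(
        predictions=["The cat sat on the mat."],
        references=["A cat was sitting on the mat."],
    )
    print(metric_cls.__name__, result)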

* Add preprocessing and post-processing mechanism to ModelWrapper

We make use of 2 kwargs on any model wrapper (see the sketch after this list):
- `inputs_preprocessor` (maps inputs to a specific format, defaults to
  identity)
- `predictions_postprocessor` (maps model outputs to a specific format,
  defaults to identity)

Also `models.HFPipelineWrapperForQuestionAnswering` is created.
`models.DefaultQAModelWrapper` is deprecated.
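
A minimal sketch of the two hooks (the wrapper name and kwarg names are
from this commit; the hook bodies and record shapes are assumptions):

# Hypothetical sketch; lambda bodies and data shapes are assumptions.
from evalem.models import HFPipelineWrapperForQuestionAnswering

wrapper = HFPipelineWrapperForQuestionAnswering(
    # Map raw records into the {question, context} format the HF
    # question-answering pipeline expects.
    inputs_preprocessor=lambda xs: [
        dict(question=x["question"], context=x["context"]) for x in xs
    ],
    # Identity here mirrors the default behaviour when omitted.
    predictions_postprocessor=lambda ys: ys,
)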

* Update code documentation for HFPipelineWrapper docstring

* Implement HF pipeline wrapper/evaluator for text classification

See `models.defaults.TextClassificationHFPipelineWrapper`.
Also improve the construction of the HF pipeline object in the existing
wrapper.

`evaluators.basics.TextClassificationEvaluator` is also added.
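
A minimal end-to-end sketch (the `hf_params` kwarg mirrors the test diff
below; the evaluator's call signature is an assumption):

# Hypothetical sketch; the evaluator call signature is an assumption.
from evalem.evaluators import TextClassificationEvaluator
from evalem.models import TextClassificationHFPipelineWrapper

model = TextClassificationHFPipelineWrapper(hf_params=dict(truncation=True))
evaluator = TextClassificationEvaluator()

predictions = model(["I loved this movie!", "Utterly boring."])
results = evaluator(predictions=predictions, references=["positive", "negative"])
print(results)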

* Add per_instance_score flag to BertScore metric

This flag is used to return precision/recall/f1 score per prediction
instance.
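
A sketch of the flag in use (only `per_instance_score` is from this commit;
the call signature is an assumption):

# Hypothetical sketch; the call signature is an assumption.
from evalem.metrics.semantics import BertScore

metric = BertScore(per_instance_score=True)
result = metric(
    predictions=["the cat sat on the mat"],
    references=["a cat was sitting on the mat"],
)
# With the flag set, precision/recall/f1 are assumed to be reported
# per prediction instance instead of as a single aggregate.
print(result)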

* Default to bert-base-uncased for BertScore

* Bugfix tokenizer parameter in QuestionAnsweringHFPipelineWrapper and
TextClassificationHFPipelineWrapper

Previously, the tokenizer was set to some default. However, that is
incorrect: we want the tokenizer to be the one the provided model was
trained with. So `tokenizer` now defaults to None.
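
A sketch of the intended usage (the model identifier is the one commented
out in the QA test below; the kwarg names are assumptions):

# Hypothetical sketch; kwarg names are assumptions.
from evalem.models import QuestionAnsweringHFPipelineWrapper

# Leave `tokenizer` as None (the default) so the underlying HF pipeline
# resolves the tokenizer the provided model was trained with.
wrapper = QuestionAnsweringHFPipelineWrapper(
    model="distilbert-base-cased-distilled-squad",
)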

* Add preliminary test cases for metrics

* Improve metric tests by caching metric result at class-level scope

* Ignore flake8 linting error for all test cases

* Add initial basic test for default QA model wrapper

* Add default test suite for text classification

All the model test suites are marked as `models` and can be run with
`pytest -m models -v`.
NISH1001 authored Mar 20, 2023
1 parent f0bdd66 commit e6b25c3
Showing 6 changed files with 130 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytest.ini
@@ -0,0 +1,4 @@
[pytest]
markers =
    metrics: For testing only the metrics.
    models: for testing only the models.
1 change: 1 addition & 0 deletions tests/metrics/_base.py
@@ -6,6 +6,7 @@
from .fixtures import predictions, references


@pytest.mark.metrics
class BaseMetricTest:
    _metric_cls = None
    _key = None
Empty file added tests/models/__init__.py
32 changes: 32 additions & 0 deletions tests/models/fixtures.py
@@ -0,0 +1,32 @@
#!/usr/bin/env python3

import json
from pathlib import Path

import pytest


@pytest.fixture(autouse=True, scope="module")
def squad_v2_data():
    path = Path(__file__).parent.parent.joinpath("data/squad_v2.json")
    data = {}
    with open(path, "r") as f:
        data = json.load(f)
    return data


@pytest.fixture(autouse=True, scope="module")
def imdb_data():
    path = Path(__file__).parent.parent.joinpath("data/imdb.json")
    data = {}
    with open(path, "r") as f:
        data = json.load(f)
    return data


def main():
    pass


if __name__ == "__main__":
    main()
46 changes: 46 additions & 0 deletions tests/models/test_classification.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

from pprint import pprint
from typing import Iterable

import pytest

from evalem.models import TextClassificationHFPipelineWrapper
from evalem.structures import PredictionDTO

from .fixtures import imdb_data


@pytest.mark.models
class TestDefaultTextClassificationWrapper:
    @pytest.fixture(autouse=True, scope="class")
    def inputs(self, imdb_data):
        return imdb_data.get("inputs", [])

    @pytest.fixture(autouse=True, scope="class")
    def model(self):
        yield TextClassificationHFPipelineWrapper(hf_params=dict(truncation=True))

    @pytest.fixture(autouse=True, scope="class")
    def references(self, imdb_data):
        return imdb_data.get("references", [])

    @pytest.fixture(autouse=True, scope="class")
    def predictions(self, model, imdb_data):
        return model(imdb_data["inputs"])

    def test_predictions_format(self, predictions):
        assert isinstance(predictions, Iterable)
        assert isinstance(predictions[0], PredictionDTO)

    def test_predictions_len(self, predictions, references):
        pprint(f"Predictions | {predictions}")
        assert len(predictions) == len(references)


def main():
    pass


if __name__ == "__main__":
    main()
47 changes: 47 additions & 0 deletions tests/models/test_qa.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

from typing import Iterable

import pytest

from evalem.models import QuestionAnsweringHFPipelineWrapper
from evalem.structures import QAPredictionDTO

from .fixtures import squad_v2_data


@pytest.mark.models
class TestDefaultQAWrapper:
    @pytest.fixture(autouse=True, scope="class")
    def inputs(self, squad_v2_data):
        return squad_v2_data.get("inputs", [])

    @pytest.fixture(autouse=True, scope="class")
    def model(self):
        yield QuestionAnsweringHFPipelineWrapper(
            # model="distilbert-base-cased-distilled-squad"
        )

    @pytest.fixture(autouse=True, scope="class")
    def references(self, squad_v2_data):
        return squad_v2_data.get("references", [])

    @pytest.fixture(autouse=True, scope="class")
    def predictions(self, model, squad_v2_data):
        return model(squad_v2_data["inputs"])

    def test_predictions_format(self, predictions):
        assert isinstance(predictions, Iterable)
        assert isinstance(predictions[0], QAPredictionDTO)

    def test_predictions_len(self, predictions, references):
        print(predictions)
        assert len(predictions) == len(references)


def main():
    pass


if __name__ == "__main__":
    main()
