Addition of unit tests for metrics (#9)
* Add barebones unit test for jury format conversion for references

* Add more tests for jury format conversion

* Fix test function name

* Add misc datasets loader module

See `evalem.misc.datasets`.
We now have a `datasets.get_squad_v2(...)` function.
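
For illustration, a minimal usage sketch; the call arguments and the
return shape shown here are assumptions, not guaranteed by this commit:

```python
from evalem.misc import datasets

# Hypothetical call: the real signature of get_squad_v2 may differ.
data = datasets.get_squad_v2()

# Assumed return shape: something that pairs model inputs with gold
# references so they can be fed to a model wrapper and an evaluator.
inputs, references = data["inputs"], data["references"]
```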

* Add device parameter to DefaultQAWrapper

* Update __init__ files with respective module imports

* Add metrics for computing semantics similarity

Now we have `metrics.semantics.SemanticMetric`.
There are 2 implementations for now:
- `metrics.semantics.BertScore`
- `metrics.semantics.BartScore`
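
A usage sketch, assuming these metrics expose the same callable
interface that the tests added in this commit rely on
(instantiate, then call with `predictions=...` and `references=...`):

```python
from evalem.metrics import BartScore, BertScore

predictions = ["A", "B", "C", "D", "B"]
references = ["A", "B", "C", "D", "A"]

# Mirrors how tests/metrics/_base.py invokes each metric.
bert_result = BertScore()(predictions=predictions, references=references)
bart_result = BartScore()(predictions=predictions, references=references)

# Result keys mirror those asserted in tests/metrics/test_semantics.py.
print(bert_result["bertscore"]["score"])
print(bart_result["bartscore"]["score"])
```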

* Add preprocessing and post-processing mechanism to ModelWrapper

Any model wrapper now accepts 2 kwargs:
- `inputs_preprocessor` (maps inputs to a specific format, defaults to
  identity)
- `predictions_postprocessor` (maps model outputs to a specific format,
  defaults to identity)

Also `models.HFPipelineWrapperForQuestionAnswering` is created.
`models.DefaultQAModelWrapper` is deprecated.
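
A rough sketch of how the two kwargs could be passed; the example
callables and any other constructor arguments are assumptions for
illustration only:

```python
from evalem.models import HFPipelineWrapperForQuestionAnswering

def lowercase_questions(inputs):
    # Hypothetical preprocessor: normalize incoming QA inputs.
    return [{**item, "question": item["question"].lower()} for item in inputs]

def keep_answer_text(outputs):
    # Hypothetical postprocessor: keep only the answer strings.
    return [out["answer"] for out in outputs]

# Both kwargs default to identity, so either may be omitted. Other
# constructor arguments (e.g. the underlying model/pipeline) are left out.
wrapper = HFPipelineWrapperForQuestionAnswering(
    inputs_preprocessor=lowercase_questions,
    predictions_postprocessor=keep_answer_text,
)
```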

* Update code documentation for HFPipelineWrapper docstring

* Implement HF pipeline wrapper/evaluator for text classification

See `models.defaults.TextClassificationHFPipelineWrapper`.
Also improve the construction of the hf pipeline object in the
existing wrapper.

`evaluators.basics.TextClassificationEvaluator` is also added.
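
As a sketch of how the new wrapper and evaluator might be combined
(import paths, constructor arguments, and the call flow below are
assumptions):

```python
from evalem.evaluators import TextClassificationEvaluator
from evalem.models import TextClassificationHFPipelineWrapper

# Hypothetical construction: the wrapper presumably accepts a Hugging Face
# model identifier (or model object) from which it builds the pipeline.
wrapper = TextClassificationHFPipelineWrapper(
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
evaluator = TextClassificationEvaluator()

texts = ["a delightful film", "a complete waste of time"]
references = ["POSITIVE", "NEGATIVE"]

# Assumed flow: run the wrapper to get predictions, then evaluate them.
predictions = wrapper(texts)
results = evaluator(predictions=predictions, references=references)
```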

* Add per_instance_score flag to BertScore metric

This flag is used to return precision/recall/f1 score per prediction
instance.
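
A hedged example of the flag; the call interface follows the tests in
this commit, while the per-instance result layout is an assumption:

```python
from evalem.metrics import BertScore

predictions = ["the cat sat on the mat", "hello world"]
references = ["a cat sat on the mat", "hello there, world"]

# per_instance_score=True is expected to return precision/recall/f1 per
# prediction instance alongside the aggregate score.
metric = BertScore(per_instance_score=True)
result = metric(predictions=predictions, references=references)
print(result["bertscore"]["score"])
```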

* Default to bert-base-uncased for BertScore

* Bugfix tokenizer parameter in QuestionAnsweringHFPipelineWrapper and
TextClassificationHFPipelineWrapper

Previously, the tokenizer was set to some default. However, that is
incorrect: we want the tokenizer to be the one the provided model was
trained with. So `tokenizer` now defaults to None.
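
For illustration, a sketch of the two ways the wrapper can now be
constructed; argument names and values are assumptions:

```python
from evalem.models import QuestionAnsweringHFPipelineWrapper

# With tokenizer=None (the new default), the underlying HF pipeline is
# expected to resolve the tokenizer that matches the provided model.
wrapper = QuestionAnsweringHFPipelineWrapper(
    model="deepset/roberta-base-squad2",
)

# Passing an explicit, matching model/tokenizer pair still works.
wrapper = QuestionAnsweringHFPipelineWrapper(
    model="deepset/roberta-base-squad2",
    tokenizer="deepset/roberta-base-squad2",
)
```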

* Add preliminary test cases for metrics

* Improve metric tests by caching metric result at class-level scope

* Ignore flake8 linting error for all test cases
NISH1001 authored Mar 20, 2023
1 parent 6e3c4a6 commit f0bdd66
Showing 7 changed files with 147 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -6,7 +6,8 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: debug-statements
- id: name-tests-test
# - id: name-tests-test
# args: [--pytest-test-first]
- id: requirements-txt-fixer
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v2.2.0
Empty file added tests/metrics/__init__.py
42 changes: 42 additions & 0 deletions tests/metrics/_base.py
@@ -0,0 +1,42 @@
# flake8: noqa
#!/usr/bin/env python3

import pytest

from .fixtures import predictions, references


class BaseMetricTest:
_metric_cls = None
_key = None

@pytest.fixture(autouse=True, scope="class")
def metric_result(self, predictions, references):
"""
This is used to cache the computation per class
"""
return self._metric_cls()(predictions=predictions, references=references)

def test_predictions_references_len(self, predictions, references):
"""
Test the length of input predictions and references
"""
assert len(predictions) == len(references)

def test_metric_return_type(self, metric_result):
"""
Check if return type of each metric is a dictionary
"""
assert isinstance(metric_result, dict)

def test_metric_return_keys(self, metric_result):
assert self._key in metric_result
assert "score" in metric_result[self._key]


def main():
pass


if __name__ == "__main__":
main()
24 changes: 24 additions & 0 deletions tests/metrics/fixtures.py
@@ -0,0 +1,24 @@
# flake8: noqa
#!/usr/bin/env python3

from typing import List

import pytest


@pytest.fixture(autouse=True, scope="module")
def references() -> List[str]:
return ["A", "B", "C", "D", "A"]


@pytest.fixture(autouse=True, scope="module")
def predictions() -> List[str]:
return ["A", "B", "C", "D", "B"]


def main():
pass


if __name__ == "__main__":
main()
57 changes: 57 additions & 0 deletions tests/metrics/test_basics.py
@@ -0,0 +1,57 @@
# flake8: noqa
#!/usr/bin/env python3

import numpy as np

from evalem.metrics import (
AccuracyMetric,
ConfusionMatrix,
F1Metric,
PrecisionMetric,
RecallMetric,
)

from ._base import BaseMetricTest, predictions, references


class TestAccuracyMetric(BaseMetricTest):
_metric_cls = AccuracyMetric
_key = "accuracy"

def test_metric_score(self, metric_result):
assert metric_result[self._key]["score"] >= 0


class TestF1Metric(BaseMetricTest):
_metric_cls = F1Metric
_key = "f1"

def test_metric_score(self, metric_result):
assert metric_result[self._key]["score"] >= 0


class TestPrecisionMetric(BaseMetricTest):
_metric_cls = PrecisionMetric
_key = "precision"

def test_metric_score(self, metric_result):
assert metric_result[self._key]["score"] >= 0


class TestRecallMetric(BaseMetricTest):
_metric_cls = RecallMetric
_key = "recall"

def test_metric_score(self, metric_result):
assert metric_result[self._key]["score"] >= 0


class TestConfusionMatrix(BaseMetricTest):
_metric_cls = ConfusionMatrix
_key = "confusion_matrix"

def test_metric_return_keys(self, metric_result):
assert self._key in metric_result

def test_metric_score(self, metric_result):
assert isinstance(metric_result[self._key], np.ndarray)
22 changes: 22 additions & 0 deletions tests/metrics/test_semantics.py
@@ -0,0 +1,22 @@
# flake8: noqa
#!/usr/bin/env python3

from evalem.metrics import BartScore, BertScore

from ._base import BaseMetricTest, predictions, references


class TestBertScore(BaseMetricTest):
_metric_cls = BertScore
_key = "bertscore"

def test_metric_score(self, metric_result):
assert -1 <= metric_result[self._key]["score"] <= 1


class TestBartScore(BaseMetricTest):
_metric_cls = BartScore
_key = "bartscore"

def test_metric_score(self, metric_result):
assert -10 <= metric_result[self._key]["score"] <= 10
File renamed without changes.
