diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ff2d195..2c55b5b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -6,7 +6,8 @@ repos:
     -   id: end-of-file-fixer
     -   id: check-yaml
     -   id: debug-statements
-    -   id: name-tests-test
+    # -   id: name-tests-test
+    #     args: [--pytest-test-first]
     -   id: requirements-txt-fixer
 -   repo: https://github.com/asottile/setup-cfg-fmt
     rev: v2.2.0
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/metrics/_base.py b/tests/metrics/_base.py
new file mode 100755
index 0000000..a51a0c8
--- /dev/null
+++ b/tests/metrics/_base.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# flake8: noqa
+
+import pytest
+
+from .fixtures import predictions, references
+
+
+class BaseMetricTest:
+    _metric_cls = None
+    _key = None
+
+    @pytest.fixture(autouse=True, scope="class")
+    def metric_result(self, predictions, references):
+        """
+        Compute the metric once and cache the result for the whole test class.
+        """
+        return self._metric_cls()(predictions=predictions, references=references)
+
+    def test_predictions_references_len(self, predictions, references):
+        """
+        Check that predictions and references have the same length.
+        """
+        assert len(predictions) == len(references)
+
+    def test_metric_return_type(self, metric_result):
+        """
+        Check that each metric returns a dictionary.
+        """
+        assert isinstance(metric_result, dict)
+
+    def test_metric_return_keys(self, metric_result):
+        assert self._key in metric_result
+        assert "score" in metric_result[self._key]
+
+
+def main():
+    pass
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py
new file mode 100755
index 0000000..d7cd440
--- /dev/null
+++ b/tests/metrics/fixtures.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python3
+# flake8: noqa
+
+from typing import List
+
+import pytest
+
+
+@pytest.fixture(autouse=True, scope="module")
+def references() -> List[str]:
+    return ["A", "B", "C", "D", "A"]
+
+
+@pytest.fixture(autouse=True, scope="module")
+def predictions() -> List[str]:
+    return ["A", "B", "C", "D", "B"]
+
+
+def main():
+    pass
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/metrics/test_basics.py b/tests/metrics/test_basics.py
new file mode 100755
index 0000000..64e1ca1
--- /dev/null
+++ b/tests/metrics/test_basics.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# flake8: noqa
+
+import numpy as np
+
+from evalem.metrics import (
+    AccuracyMetric,
+    ConfusionMatrix,
+    F1Metric,
+    PrecisionMetric,
+    RecallMetric,
+)
+
+from ._base import BaseMetricTest, predictions, references
+
+
+class TestAccuracyMetric(BaseMetricTest):
+    _metric_cls = AccuracyMetric
+    _key = "accuracy"
+
+    def test_metric_score(self, metric_result):
+        assert metric_result[self._key]["score"] >= 0
+
+
+class TestF1Metric(BaseMetricTest):
+    _metric_cls = F1Metric
+    _key = "f1"
+
+    def test_metric_score(self, metric_result):
+        assert metric_result[self._key]["score"] >= 0
+
+
+class TestPrecisionMetric(BaseMetricTest):
+    _metric_cls = PrecisionMetric
+    _key = "precision"
+
+    def test_metric_score(self, metric_result):
+        assert metric_result[self._key]["score"] >= 0
+
+
+class TestRecallMetric(BaseMetricTest):
+    _metric_cls = RecallMetric
+    _key = "recall"
+
+    def test_metric_score(self, metric_result):
+        assert metric_result[self._key]["score"] >= 0
+
+
+class TestConfusionMatrix(BaseMetricTest):
+    _metric_cls = ConfusionMatrix
+    _key = "confusion_matrix"
+
+    def test_metric_return_keys(self, metric_result):
+        assert self._key in metric_result
+
+    def test_metric_score(self, metric_result):
+        assert isinstance(metric_result[self._key], np.ndarray)
diff --git a/tests/metrics/test_semantics.py b/tests/metrics/test_semantics.py
new file mode 100755
index 0000000..96a6db1
--- /dev/null
+++ b/tests/metrics/test_semantics.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+# flake8: noqa
+
+from evalem.metrics import BartScore, BertScore
+
+from ._base import BaseMetricTest, predictions, references
+
+
+class TestBertScore(BaseMetricTest):
+    _metric_cls = BertScore
+    _key = "bertscore"
+
+    def test_metric_score(self, metric_result):
+        assert -1 <= metric_result[self._key]["score"] <= 1
+
+
+class TestBartScore(BaseMetricTest):
+    _metric_cls = BartScore
+    _key = "bartscore"
+
+    def test_metric_score(self, metric_result):
+        assert -10 <= metric_result[self._key]["score"] <= 10
diff --git a/tests/jury_format_test.py b/tests/test_jury_format.py
similarity index 100%
rename from tests/jury_format_test.py
rename to tests/test_jury_format.py
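
For reference, a minimal sketch of how a further test module could plug into the
BaseMetricTest pattern introduced above. ExactMatchMetric and the "exact_match"
result key are hypothetical stand-ins, not part of this change; the only
assumptions carried over from the diff itself are that a metric is constructed
with no arguments, called with predictions/references, and returns a dict keyed
by metric name with a nested "score" entry.

    #!/usr/bin/env python3
    # flake8: noqa
    # tests/metrics/test_exact_match.py -- hypothetical example, not in this diff

    from evalem.metrics import ExactMatchMetric  # hypothetical metric class

    from ._base import BaseMetricTest, predictions, references


    class TestExactMatchMetric(BaseMetricTest):
        _metric_cls = ExactMatchMetric
        _key = "exact_match"

        def test_metric_score(self, metric_result):
            # Assumes exact-match scores are normalized to [0, 1].
            assert 0 <= metric_result[self._key]["score"] <= 1

The inherited checks (input lengths, return type, return keys) run automatically
for the subclass, and the class-scoped metric_result fixture computes the metric
only once per class, so adding a new metric test costs two class attributes plus
any metric-specific score assertions.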