Addition of unit tests for metrics (#9)
* Add bare-bones unit test for jury format conversion for references
* Add more tests for jury format conversion
* Fix test function name
* Add misc datasets loader module. See `evalem.misc.datasets`; we now have a `datasets.get_squad_v2(...)` function.
* Add device parameter to DefaultQAWrapper
* Update `__init__` files with respective module imports
* Add metrics for computing semantic similarity. We now have `metrics.semantics.SemanticMetric`, with two implementations for now: `metrics.semantics.BertScore` and `metrics.semantics.BartScore`.
* Add preprocessing and post-processing mechanism to ModelWrapper. Any model wrapper now accepts two kwargs: `inputs_preprocessor` (maps inputs to a specific format, defaults to identity) and `predictions_postprocessor` (maps model outputs to a specific format, defaults to identity); a minimal usage sketch follows this list. Also, `models.HFPipelineWrapperForQuestionAnswering` is created and `models.DefaultQAModelWrapper` is deprecated.
* Update code documentation for HFPipelineWrapper docstring
* Implement HF pipeline wrapper/evaluator for text classification. See `models.defaults.TextClassificationHFPipelineWrapper`. Also improve the construction of the HF pipeline object in the existing wrapper. `evaluators.basics.TextClassificationEvaluator` is also added.
* Add `per_instance_score` flag to the BertScore metric. This flag is used to return precision/recall/f1 scores per prediction instance.
* Default to bert-base-uncased for BertScore
* Bugfix tokenizer parameter in QuestionAnsweringHFPipelineWrapper and TextClassificationHFPipelineWrapper. Previously, the tokenizer was set to some default, which is incorrect: we want the tokenizer the provided model was trained with, so `tokenizer` now defaults to None.
* Add preliminary test cases for metrics
* Improve metric tests by caching the metric result at class-level scope
* Ignore flake8 linting errors for all test cases
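The pre-/post-processing hooks above are not exercised by the tests in this commit, so the following is only a minimal sketch of the intended mechanism. The import path `evalem.models.defaults`, the ability to construct the wrapper without an explicit model, and the hook signatures are assumptions, not something this diff confirms.

# Minimal sketch; import path and default construction are assumptions.
from evalem.models.defaults import TextClassificationHFPipelineWrapper

wrapper = TextClassificationHFPipelineWrapper(
    # both hooks default to identity functions; tokenizer defaults to None,
    # so the tokenizer tied to the provided model is used
    inputs_preprocessor=lambda texts: [t.strip() for t in texts],
    predictions_postprocessor=lambda outputs: outputs,
)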
Showing 7 changed files with 147 additions and 1 deletion.
Empty file.
@@ -0,0 +1,42 @@
# flake8: noqa
#!/usr/bin/env python3

import pytest

from .fixtures import predictions, references


class BaseMetricTest:
    _metric_cls = None
    _key = None

    @pytest.fixture(autouse=True, scope="class")
    def metric_result(self, predictions, references):
        """
        This is used to cache the computation per class
        """
        return self._metric_cls()(predictions=predictions, references=references)

    def test_predictions_references_len(self, predictions, references):
        """
        Test the length of input predictions and references
        """
        assert len(predictions) == len(references)

    def test_metric_return_type(self, metric_result):
        """
        Check if return type of each metric is a dictionary
        """
        assert isinstance(metric_result, dict)

    def test_metric_return_keys(self, metric_result):
        assert self._key in metric_result
        assert "score" in metric_result[self._key]


def main():
    pass


if __name__ == "__main__":
    main()
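Concrete metric tests only need to set `_metric_cls` and `_key`: the autouse, class-scoped `metric_result` fixture computes the metric once per test class on the shared fixtures, and the individual test methods reuse that cached result, as the subclasses in the following files show.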
@@ -0,0 +1,24 @@
# flake8: noqa
#!/usr/bin/env python3

from typing import List

import pytest


@pytest.fixture(autouse=True, scope="module")
def references() -> List[str]:
    return ["A", "B", "C", "D", "A"]


@pytest.fixture(autouse=True, scope="module")
def predictions() -> List[str]:
    return ["A", "B", "C", "D", "B"]


def main():
    pass


if __name__ == "__main__":
    main()
@@ -0,0 +1,57 @@
# flake8: noqa
#!/usr/bin/env python3

import numpy as np

from evalem.metrics import (
    AccuracyMetric,
    ConfusionMatrix,
    F1Metric,
    PrecisionMetric,
    RecallMetric,
)

from ._base import BaseMetricTest, predictions, references


class TestAccuracyMetric(BaseMetricTest):
    _metric_cls = AccuracyMetric
    _key = "accuracy"

    def test_metric_score(self, metric_result):
        assert metric_result[self._key]["score"] >= 0


class TestF1Metric(BaseMetricTest):
    _metric_cls = F1Metric
    _key = "f1"

    def test_metric_score(self, metric_result):
        assert metric_result[self._key]["score"] >= 0


class TestPrecisionMetric(BaseMetricTest):
    _metric_cls = PrecisionMetric
    _key = "precision"

    def test_metric_score(self, metric_result):
        assert metric_result[self._key]["score"] >= 0


class TestRecallMetric(BaseMetricTest):
    _metric_cls = RecallMetric
    _key = "recall"

    def test_metric_score(self, metric_result):
        assert metric_result[self._key]["score"] >= 0


class TestConfusionMatrix(BaseMetricTest):
    _metric_cls = ConfusionMatrix
    _key = "confusion_matrix"

    def test_metric_return_keys(self, metric_result):
        assert self._key in metric_result

    def test_metric_score(self, metric_result):
        assert isinstance(metric_result[self._key], np.ndarray)
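For reference, the call convention these tests exercise looks like this when a metric is used directly; the toy values mirror the fixtures above, and the nested result layout follows the assertions in `BaseMetricTest`:

from evalem.metrics import AccuracyMetric

result = AccuracyMetric()(
    predictions=["A", "B", "C", "D", "B"],
    references=["A", "B", "C", "D", "A"],
)
# result is a dict keyed by metric name, with a nested "score" entry
print(result["accuracy"]["score"])  # a non-negative number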
@@ -0,0 +1,22 @@
# flake8: noqa
#!/usr/bin/env python3

from evalem.metrics import BartScore, BertScore

from ._base import BaseMetricTest, predictions, references


class TestBertScore(BaseMetricTest):
    _metric_cls = BertScore
    _key = "bertscore"

    def test_metric_score(self, metric_result):
        assert -1 <= metric_result[self._key]["score"] <= 1


class TestBartScore(BaseMetricTest):
    _metric_cls = BartScore
    _key = "bartscore"

    def test_metric_score(self, metric_result):
        assert -10 <= metric_result[self._key]["score"] <= 10
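The `per_instance_score` flag from the commit message is not covered by these tests. A rough sketch of how it might be used is below; treating it as a constructor argument and the shape of the per-instance output are assumptions:

from evalem.metrics import BertScore

# bert-base-uncased is the default model per this commit
metric = BertScore(per_instance_score=True)  # assumed constructor kwarg
result = metric(
    predictions=["A", "B", "C", "D", "B"],
    references=["A", "B", "C", "D", "A"],
)
score = result["bertscore"]["score"]  # aggregate score, as asserted above
# with the flag set, per-instance precision/recall/f1 values are also
# returned (exact key layout is not shown in this diff)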
File renamed without changes.