-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[alpha] Improvements to ModelWrapper and better QA/Classification imp…
…lementation (#8) * Add barebone unit test for jury format conversion for references * Add more tests for jury format conversion * Fix test function name * Add misc datasets loader module See `evalem.misc.datasets`. We have `datasets.get_squad_v2(...)` function. * Add device parameter to DefaultQAWrapper * Update __init__ files with respective module imports * Add metrics for computing semantics similarity Now we have `metrics.semantics.SemanticMetric`. There are 2 implementation for now: - `metrics.semantics.BertScore` - `metrics.semantics.BartScore` * Add preprocessing and post-processing mechanism to ModelWrapper We make use of 2 kwargs to any model wrapper: - `inputs_preprocessor` (maps inputs to a specific format, defaults to identity) - `predictions_postprocessor` (maps model outputs to a specific format, defaults to identity) Also `models.HFPipelineWrapperForQuestionAnswering` is created. `models.DefaultQAModelWrapper` is deprecated. * Update code documentation for HFPipelineWrapper docstring * Implement HF pipeline wrapper/evaluator for text classification See `models.defaults.TextClassificationHFPipelineWrapper`. Also improve the concstruction of hf pipeline object in existing wrapper. `evaluators.basics.TextClassificationEvaluator` is also added. * Add per_instance_score flag to BertScore metric This flag is used to return precision/recall/f1 score per prediction instance. * Default to bert-base-uncased for BertScore * Bugfix tokenizer parameter in QuestionAnsweringHFPipelineWrapper and TextClassificationHFPipelineWrapper Previously, tokenizer was set to some defaults. However, that is incorrect. We want tokenizer to be the one for which provided model was trained on. So, now `tokenizer` is set to None by default. * Change BertScore's per_instance_score behaviour to compute mean
- Loading branch information
Showing
9 changed files
with
284 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
# flake8: noqa | ||
from ._base import Evaluator | ||
from .basics import QAEvaluator, TextClassificationEvaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# flake8: noqa | ||
from ._base import HFLMWrapper, HFPipelineWrapper, ModelWrapper | ||
from .defaults import DefaultQAModelWrapper | ||
from .defaults import ( | ||
DefaultQAModelWrapper, | ||
QuestionAnsweringHFPipelineWrapper, | ||
TextClassificationHFPipelineWrapper, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.