Skip to content

Commit

Permalink
Add validation for public function arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
teabolt committed Oct 13, 2019
1 parent 6cfd526 commit c94eb52
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 67 deletions.
13 changes: 4 additions & 9 deletions eli5/formatters/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import matplotlib.cm # type: ignore

from eli5.base import Explanation
from eli5.nn.gradcam import (
_validate_heatmap,
)


def format_as_image(expl, # type: Explanation
Expand Down Expand Up @@ -287,14 +290,6 @@ def _validate_image(image):
'Got: {}'.format(image))


def _validate_heatmap(heatmap):
# type: (np.ndarray) -> None
"""Check that ``heatmap`` has the right type."""
if not isinstance(heatmap, np.ndarray):
raise TypeError('heatmap must be a numpy.ndarray instance. '
'Got: {}'.format(heatmap))


def _needs_normalization(heatmap):
# type: (np.ndarray) -> bool
"""Return whether ``heatmap`` values are in the interval [0, 1]."""
Expand All @@ -311,4 +306,4 @@ def _normalize_heatmap(h, epsilon=1e-07):
# https://datascience.stackexchange.com/questions/5885/how-to-scale-an-array-of-signed-integers-to-range-from-0-to-1
# add eps to avoid division by zero in case heatmap is all 0's
# this also means that lmap max will be slightly less than the 'true' max
return (h - h.min()) / (h.max() - h.min() + epsilon)
return (h - h.min()) / (h.max() - h.min() + epsilon)
2 changes: 1 addition & 1 deletion eli5/keras/explain_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def _validate_params(model, # type: Model
doc, # type: np.ndarray
):
# type: (...) -> None
"""Helper for validating all explanation function parameters."""
"""Helper for validating explanation function parameters."""
_validate_model(model)
_validate_doc(doc)

Expand Down
11 changes: 10 additions & 1 deletion eli5/nn/gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,13 @@ def _validate_classification_target(target, output_shape):
if not (0 <= target < output_nodes):
raise ValueError('Prediction target index is '
'outside the required range [0, {}). ',
'Got {}'.format(output_nodes, target))
'Got {}'.format(output_nodes, target))


def _validate_heatmap(heatmap):
# type: (np.ndarray) -> None
"""Utility function to check that the ``heatmap``
argument has the right type."""
if not isinstance(heatmap, np.ndarray):
raise TypeError('heatmap must be a numpy.ndarray instance. '
'Got: "{}" (type "{}").'.format(heatmap, type(heatmap)))
91 changes: 66 additions & 25 deletions eli5/nn/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
WeightedSpans,
DocWeightedSpans,
)
from eli5.nn.gradcam import (
_validate_heatmap,
)


def gradcam_spans(heatmap, # type: np.ndarray
Expand All @@ -30,9 +33,16 @@ def gradcam_spans(heatmap, # type: np.ndarray
**Should be rank 1 (no batch dimension).**
:raises TypeError: if ``heatmap`` is wrong type.
tokens : numpy.ndarray or list
Tokens that will be highlighted using weights from ``heatmap``.
:raises TypeError: if ``tokens`` is wrong type.
:raises ValueError: if ``tokens`` contents are unexpected.
doc: numpy.ndarray
Original input to the network, from which ``heatmap`` was created.
Expand All @@ -53,9 +63,14 @@ def gradcam_spans(heatmap, # type: np.ndarray
``tokens`` and ``heatmap`` optionally cut from padding.
A :class:`eli5.base.WeightedSpans` object with a weight for each token.
"""
# FIXME: might want to do this when formatting the explanation?
# We call this before returning the explanation, NOT when formatting the explanation
# Because WeightedSpans, etc are attributes of a returned explanation
# TODO: might want to add validation for heatmap and other arguments?
_validate_tokens(doc, tokens)
_validate_tokens(tokens)
_validate_tokens_value(tokens, doc)
if isinstance(tokens, list):
# convert to a common data type
tokens = np.array(tokens)

length = len(tokens)
heatmap = resize_1d(heatmap, length, interpolation_kind=interpolation_kind)
Expand All @@ -64,14 +79,16 @@ def gradcam_spans(heatmap, # type: np.ndarray
if pad_value is not None or pad_token is not None:
# remove padding
pad_indices = _find_padding(pad_value=pad_value, pad_token=pad_token, doc=doc, tokens=tokens)
# If pad_value is not the actual padding value, behaviour is unknown
# If passed padding argument is not the actual padding token/value, behaviour is unknown
tokens, heatmap = _trim_padding(pad_indices, tokens, heatmap)

document = _construct_document(tokens)
spans = _build_spans(tokens, heatmap, document)
weighted_spans = WeightedSpans([
DocWeightedSpans(document, spans=spans)
]) # why list? - for each vectorized - don't need multiple vectorizers?
# multiple highlights? - could do positive and negative expl?
])
# why do we have a list of WeightedSpans? One for each vectorizer?
# But we do not use multiple vectorizers?
return tokens, heatmap, weighted_spans


Expand All @@ -89,6 +106,9 @@ def resize_1d(heatmap, length, interpolation_kind='linear'):
heatmap : numpy.ndarray
Heatmap to be resized.
:raises TypeError: if ``heatmap`` is wrong type.
length : int
Required width.
Expand All @@ -104,6 +124,8 @@ def resize_1d(heatmap, length, interpolation_kind='linear'):
heatmap : numpy.ndarray
The heatmap resized.
"""
_validate_heatmap(heatmap)
_validate_length(length)
if len(heatmap.shape) == 1 and heatmap.shape[0] == 1:
# single weight, no batch
heatmap = heatmap.repeat(length)
Expand Down Expand Up @@ -146,7 +168,7 @@ def _build_spans(tokens, # type: Union[np.ndarray, list]

def _construct_document(tokens):
# type: (Union[list, np.ndarray]) -> str
"""Create a document string by joining ``tokens``."""
"""Create a document string by joining ``tokens`` sequence."""
if _is_character_tokenization(tokens):
sep = ''
else:
Expand All @@ -156,10 +178,7 @@ def _construct_document(tokens):

def _is_character_tokenization(tokens):
# type: (Union[list, np.ndarray]) -> bool
"""
Check whether tokenization is character-level
(returns True) or word-level (returns False).
"""
"""Check whether tokenization is character-level (True) or word-level (False)."""
return any(' ' in t for t in tokens)


Expand All @@ -180,27 +199,27 @@ def _find_padding(pad_value=None, # type: Union[int, float]
else:
raise TypeError('Pass "doc" and "pad_value", '
'or "tokens" and "pad_token".')
# TODO: warn if indices is empty - passed wrong padding char/value?


def _find_padding_values(pad_value, doc):
# type: (Union[int, float], np.ndarray) -> np.ndarray
if not isinstance(pad_value, (int, float)):
raise TypeError('"pad_value" must be int or float. Got "{}"'.format(type(pad_value)))
_validate_doc(doc)
values, indices = np.where(doc == pad_value)
return indices


def _find_padding_tokens(pad_token, tokens):
# type: (str, Union[list, np.ndarray]) -> np.ndarray
# type: (str, np.ndarray) -> np.ndarray
if not isinstance(pad_token, str):
raise TypeError('"pad_token" must be str. Got "{}"'.format(type(pad_token)))
indices = [idx for idx, token in enumerate(tokens) if token == pad_token]
return np.array(indices)
indices = np.where(tokens == pad_token)
return indices


def _trim_padding(pad_indices, # type: np.ndarray
tokens, # type: Union[list, np.ndarray]
tokens, # type: np.ndarray
heatmap, # type: np.ndarray
):
# type: (...) -> Tuple[Union[list, np.ndarray], np.ndarray]
Expand All @@ -217,37 +236,59 @@ def _trim_padding(pad_indices, # type: np.ndarray
return tokens, heatmap


def _validate_doc(doc):
if not isinstance(doc, np.ndarray):
raise TypeError('"doc" must be an instance of numpy.ndarray. '
'Got "{}" (type "{}")'.format(doc, type(doc)))


def _validate_length(length):
if not isinstance(length, int):
raise TypeError('"length" must be an integer. Got "{}" '
'(type "{}")'.format(length, type(length)))
if length < 0:
raise ValueError('"length" must be a non-negative integer. '
'Got "{}"'.format(length))


# TODO:
# docs for raises in here
# coverage tests for new validation


# FIXME: break this function up
def _validate_tokens(doc, tokens):
# type: (np.ndarray, Union[np.ndarray, list]) -> None
def _validate_tokens(tokens):
# type: (Union[np.ndarray, list]) -> None
"""Check that ``tokens`` contains correct items and matches ``doc``."""
if not isinstance(tokens, (list, np.ndarray)):
# wrong type
raise TypeError('"tokens" must be list or numpy.ndarray. '
'Got "{}".'.format(tokens))

batch_size, doc_len = doc.shape[0], doc.shape[1]
if len(tokens) == 0:
# empty list
raise ValueError('"tokens" is empty: {}'.format(tokens))


def _validate_tokens_value(tokens, doc):
# type: (Union[np.ndarray, list], np.ndarray) -> None
doc_batch, doc_len = doc.shape[0], doc.shape[1]
an_entry = tokens[0]
if isinstance(an_entry, str):
# no batch
if batch_size != 1:
if doc_batch != 1:
# doc is batched but tokens is not
raise ValueError('If passing "tokens" without batch dimension, '
'"doc" must have batch size = 1.'
'Got "doc" with batch size = %d.' % batch_size)
'Got "doc" with batch size = %d.' % doc_batch)
tokens_len = len(tokens)
elif isinstance(an_entry, (list, np.ndarray)):
# batched
tokens_batch_size = len(tokens)
if tokens_batch_size != batch_size:
tokens_batch = len(tokens)
if tokens_batch != doc_batch:
# batch lengths do not match
raise ValueError('"tokens" must have same number of samples '
'as in doc batch. Got: "tokens" samples: %d, '
'doc samples: %d' % (tokens_batch_size, batch_size))
'doc samples: %d' % (tokens_batch, doc_batch))

a_token = an_entry[0]
if not isinstance(a_token, str):
Expand All @@ -260,7 +301,7 @@ def _validate_tokens(doc, tokens):
it = iter(tokens)
the_len = len(next(it))
if not all(len(l) == the_len for l in it):
raise ValueError('"tokens" samples do not have the same length.')
raise ValueError('"tokens" samples do not all have the same length.')
tokens_len = the_len
else:
raise TypeError('"tokens" must be an array of strings, '
Expand Down
7 changes: 0 additions & 7 deletions tests/test_formatters_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_cap_alpha,
_overlay_heatmap,
_validate_image,
_validate_heatmap,
)
from .utils_image import assert_pixel_by_pixel_equal
import eli5
Expand Down Expand Up @@ -153,12 +152,6 @@ def test_validate_image():
_validate_image(np.zeros((2, 2, 4,)))


def test_validate_heatmap():
with pytest.raises(TypeError):
# heatmap must be a numpy array, not a Pillow image
_validate_heatmap(PIL.Image.new('L', (2, 2,)))


def test_format_as_image_notransparency(catdog_rgba):
# heatmap with full transparency
expl = Explanation('mock',
Expand Down
7 changes: 3 additions & 4 deletions tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
_autoget_layer_text,
)
from eli5.keras.gradcam import (
_autoget_target_prediction,
_calc_gradient,
)

Expand Down Expand Up @@ -188,12 +187,12 @@ def test_calc_gradient(differentiable_model):

def test_calc_gradient_nondifferentiable(nondifferentiable_model):
with pytest.raises(ValueError):
grads = _calc_gradient(nondifferentiable_model.output,
[nondifferentiable_model.input])
_calc_gradient(nondifferentiable_model.output,
[nondifferentiable_model.input])



# TODO: test_autoget_target_prediction with multiple maximum values, etc
# TODO: test chossing multiple target from multiple maximum values, etc


def test_import():
Expand Down
13 changes: 12 additions & 1 deletion tests/test_nn_gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import pytest
import numpy as np
PIL = pytest.importorskip('PIL')

from eli5.nn.gradcam import (
gradcam_heatmap,
_validate_targets,
_validate_classification_target,
_validate_heatmap,
)


Expand Down Expand Up @@ -70,4 +72,13 @@ def test_validate_classification_target():
_validate_classification_target(2, (1, 2,))
with pytest.raises(ValueError):
# one less
_validate_classification_target(-1, (1, 1,))
_validate_classification_target(-1, (1, 1,))


def test_validate_heatmap():
with pytest.raises(TypeError):
# heatmap must be a numpy array, not a Pillow image
_validate_heatmap(PIL.Image.new('L', (2, 2,)))
with pytest.raises(TypeError):
# heatmap must not be a Python list
_validate_heatmap([2, 3])
Loading

0 comments on commit c94eb52

Please sign in to comment.