Skip to content

Commit

Permalink
Merge branch 'keras-gradcam-text' of github.com:teabolt/eli5 into pyt…
Browse files Browse the repository at this point in the history
…orch-gradcam
  • Loading branch information
teabolt committed Jun 6, 2020
2 parents 9c8d7c6 + ec0f51c commit 02d1ea3
Show file tree
Hide file tree
Showing 54 changed files with 464 additions and 345 deletions.
12 changes: 12 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
Changelog
=========

0.10.1 (2019-08-29)
-------------------

* Don't include typing dependency on Python 3.5+
to fix installation on Python 3.7

0.10.0 (2019-08-21)
-------------------

* Keras image classifiers: explaining predictions with Grad-CAM
(GSoC-2019 project by @teabolt).

0.9.0 (2019-07-05)
------------------

Expand Down
7 changes: 5 additions & 2 deletions eli5/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import

__version__ = '0.9.0'
__version__ = '0.10.1'

from .formatters import (
format_as_html,
Expand Down Expand Up @@ -96,6 +96,7 @@
except ImportError:
# keras is not available
pass
<<<<<<< HEAD


try:
Expand All @@ -104,4 +105,6 @@
)
except ImportError:
# pytorch is not available
pass
pass
=======
>>>>>>> ec0f51c60aaf360327ca18e3e0cdae2222cec6bf
4 changes: 2 additions & 2 deletions eli5/_feature_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
Union, Callable, Pattern
)

import numpy as np # type: ignore
import scipy.sparse as sp # type: ignore
import numpy as np
import scipy.sparse as sp


class FeatureNames(Sized, Iterable):
Expand Down
2 changes: 1 addition & 1 deletion eli5/_feature_weights.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import

import numpy as np # type: ignore
import numpy as np

from eli5.base import FeatureWeights, FeatureWeight
from .utils import argsort_k_largest_positive, argsort_k_smallest, mask
Expand Down
2 changes: 1 addition & 1 deletion eli5/_graphviz.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
import graphviz # type: ignore
import graphviz


def is_supported():
Expand Down
2 changes: 1 addition & 1 deletion eli5/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from typing import Any, List, Tuple, Union, Optional

import numpy as np # type: ignore
import numpy as np

from .base_utils import attrs
from .formatters.features import FormattedFeatureName
Expand Down
4 changes: 2 additions & 2 deletions eli5/base_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import inspect

import attr # type: ignore
import attr

try:
from functools import singledispatch # type: ignore
from functools import singledispatch
except ImportError:
from singledispatch import singledispatch # type: ignore

Expand Down
4 changes: 2 additions & 2 deletions eli5/catboost.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division

import numpy as np # type: ignore
import catboost # type: ignore
import numpy as np
import catboost

from eli5.explain import explain_weights
from eli5._feature_importances import get_feature_importance_explanation
Expand Down
2 changes: 1 addition & 1 deletion eli5/formatters/as_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional
import warnings

import pandas as pd # type: ignore
import pandas as pd

import eli5
from eli5.base import (
Expand Down
4 changes: 2 additions & 2 deletions eli5/formatters/as_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import six

import attr # type: ignore
import numpy as np # type: ignore
import attr
import numpy as np

from .features import FormattedFeatureName

Expand Down
4 changes: 2 additions & 2 deletions eli5/formatters/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from itertools import groupby
from typing import List, Optional, Tuple

import numpy as np # type: ignore
from jinja2 import Environment, PackageLoader # type: ignore
import numpy as np
from jinja2 import Environment, PackageLoader

from eli5 import _graphviz
from eli5.base import (Explanation, TargetExplanation, FeatureWeights,
Expand Down
19 changes: 7 additions & 12 deletions eli5/formatters/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
from __future__ import absolute_import
from typing import Union, Optional, Callable

import numpy as np # type: ignore
from PIL import Image # type: ignore
import matplotlib.cm # type: ignore
import numpy as np
from PIL import Image
import matplotlib.cm

from eli5.base import Explanation
from eli5.nn.gradcam import (
_validate_heatmap,
)


def format_as_image(expl, # type: Explanation
Expand Down Expand Up @@ -287,14 +290,6 @@ def _validate_image(image):
'Got: {}'.format(image))


def _validate_heatmap(heatmap):
# type: (np.ndarray) -> None
"""Check that ``heatmap`` has the right type."""
if not isinstance(heatmap, np.ndarray):
raise TypeError('heatmap must be a numpy.ndarray instance. '
'Got: {}'.format(heatmap))


def _needs_normalization(heatmap):
# type: (np.ndarray) -> bool
"""Return whether ``heatmap`` values are in the interval [0, 1]."""
Expand All @@ -311,4 +306,4 @@ def _normalize_heatmap(h, epsilon=1e-07):
# https://datascience.stackexchange.com/questions/5885/how-to-scale-an-array-of-signed-integers-to-range-from-0-to-1
# add eps to avoid division by zero in case heatmap is all 0's
# this also means that lmap max will be slightly less than the 'true' max
return (h - h.min()) / (h.max() - h.min() + epsilon)
return (h - h.min()) / (h.max() - h.min() + epsilon)
7 changes: 4 additions & 3 deletions eli5/formatters/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
from __future__ import absolute_import
from itertools import chain
import six
from tabulate import tabulate
from typing import List, Optional, Iterator

from eli5.base import Explanation, FeatureImportances
from . import fields
from .features import FormattedFeatureName
from .utils import (
format_signed, format_value, format_weight, has_any_values_for_weights,
replace_spaces, should_highlight_spaces, tabulate)
replace_spaces, should_highlight_spaces)
from .utils import tabulate as eli5_tabulate
from .trees import tree2text


Expand Down Expand Up @@ -153,7 +155,6 @@ def _decision_tree_lines(explanation):

def _transition_features_lines(explanation):
# type: (Explanation) -> List[str]
from tabulate import tabulate # type: ignore
tf = explanation.transition_features
assert tf is not None
return [
Expand Down Expand Up @@ -203,7 +204,7 @@ def _targets_lines(explanation, # type: Explanation

w = target.feature_weights
assert w is not None
table = tabulate(
table = eli5_tabulate(
[table_line(fw) for fw in chain(w.pos, reversed(w.neg))],
header=table_header,
col_align=col_align,
Expand Down
2 changes: 1 addition & 1 deletion eli5/formatters/text_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import Counter
from typing import List, Optional

import numpy as np # type: ignore
import numpy as np

from eli5.base import TargetExplanation, WeightedSpans, DocWeightedSpans
from eli5.base_utils import attrs
Expand Down
2 changes: 2 additions & 0 deletions eli5/formatters/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def p(*args):
else:
assert node.left is not None
assert node.right is not None
assert node.threshold is not None

feat_name = node.feature_name

if depth > 0:
Expand Down
7 changes: 3 additions & 4 deletions eli5/formatters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from itertools import chain
import re
import six
from numbers import Real
from typing import Any, Union, List, Dict, Callable, Match, Optional

import numpy as np # type: ignore
import numpy as np

from eli5.base import Explanation
from .features import FormattedFeatureName
Expand Down Expand Up @@ -143,12 +142,12 @@ def tabulate(data, # type: List[List[Any]]


def format_weight(value):
# type: (Real) -> str
# type: (float) -> str
return '{:+.3f}'.format(value)


def format_value(value):
# type: (Optional[Real]) -> str
# type: (Optional[float]) -> str
if value is None:
return ''
elif np.isnan(value):
Expand Down
4 changes: 2 additions & 2 deletions eli5/ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from typing import Any, Dict, Tuple
import warnings

from IPython.display import HTML, Image # type: ignore
from IPython.display import HTML, Image

from .explain import explain_weights, explain_prediction
from .formatters import format_as_html, fields
try:
from .formatters.image import format_as_image
except ImportError as e:
# missing dependencies
format_as_image = e # type: ignore
format_as_image = e # type: ignore


FORMAT_KWARGS = {'include_styles', 'force_weights',
Expand Down
Loading

0 comments on commit 02d1ea3

Please sign in to comment.