Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explain predictions of Keras image classifiers (Grad-CAM) #315

Merged
merged 165 commits into from
Aug 5, 2019

Conversation

teabolt
Copy link
Contributor

@teabolt teabolt commented Jun 10, 2019

This PR adds explanations for Keras models that are used to classify images. Specifically we implement Grad-CAM.

For example, the following piece of code:

import keras
import numpy as np
import eli5

# load model
xception = keras.applications.xception.Xception(include_top=True, weights='imagenet', classes=1000)

# load image
im = keras.preprocessing.image.load_img('../eli5_examples/motorcycle.jpg', target_size=(299, 299))
doc = keras.preprocessing.image.img_to_array(im)
doc = np.expand_dims(doc, axis=0)
keras.applications.xception.preprocess_input(doc)

# explain
eli5.show_prediction(xception, doc)

produces this explanation for the class 'motor scooter':
motorcycle_pr_2

The following features are added:

  • Add keras package with explain_prediction_keras() and formatters.image module with format_as_image() (requires matplotlib and Pillow).
  • Add .image attribute to base.Explanation and .heatmap to base.TargetExplanation.
  • Make ipython.show_prediction() dispatch to an image display function for image explanations.

TODO items before this PR is finalized:

  • Resolve reviews
  • Coverage
  • Pass CI
  • Mypy type annotations
  • Docs (formatting, tutorial)
  • Integration and unit tests

teabolt added 30 commits June 1, 2019 14:16
…ications preprocess_input. Fix using callable to find target layer
…ain_prediction via approximate attention over area
@@ -13,10 +13,11 @@


def get_weighted_spans(doc, vec, feature_weights):
# type: (Any, Any, FeatureWeights) -> Optional[WeightedSpans]
# type: (Any, Any, Union[FeatureWeights, None]) -> Optional[WeightedSpans]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems this function requires FeatureWeights to be not None, so maybe it makes sense to keep type signature the same, but move assert to the caller code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the assert to add_weighted_spans. Good one. Strange that https://github.com/TeamHG-Memex/eli5/search?q=get_weighted_spans did not show add_weighted_spans's call to get_weighted_spans.

913d415



def _get_target_prediction(targets, estimator):
# type: (Union[None, list], Model) -> K.variable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional[List]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I was thinking whether to use this. My thinking was that targets is not actually optional in the parameter list so it might be confusing. But I see that the rest of the library uses Optional (i.e. for functions that could return None) so I will change it.

e8a34f1

eli5/keras/gradcam.py Outdated Show resolved Hide resolved
a valid keras layer name, layer index, or an instance of a Keras layer.

If None, a suitable layer is attempted to be retrieved.
See :func:`eli5.keras._search_layer_backwards` for details.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if we're documenting a function, it makes sense to make it public - or just document the behavior without mentioning a function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Forgot that I made the function private.

bcaf7ca

An input image as a tensor to ``estimator``,
from which prediction will be done and explained.

For example a ``numpy.ndarray``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are there other supported data types, why "for example"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will explicitly mention that numpy arrays are required.

I think it is possible to use other input types (https://github.com/keras-team/keras/blob/ed07472bc5fc985982db355135d37059a1f887a9/keras/engine/training.py#L1315), i.e. tensorflow tensor. However, I haven't tested with other types and I think I have some numpy dependencies in my code. Adding more input types could be a separate GitHub issue?

9d2d22a

@teabolt
Copy link
Contributor Author

teabolt commented Jul 12, 2019

Currently the parameter resampling_filter (previously called interpolation) of eli5.format_as_image() takes an integer from https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters, i.e. the user passes something like resampling_filter=PIL.Image.BOX. It could be clearer to take the filter as a string, i.e. let the user say resampling_filter="BOX"?

@lopuhin
Copy link
Contributor

lopuhin commented Jul 29, 2019

resampling_filter=PIL.Image.BOX. It could be clearer to take the filter as a string, i.e. let the user say resampling_filter="BOX"?

I think both options are fine, to me a constant looks a bit better than a string, and I think it's fine to use a PIL constant here.

Copy link
Contributor

@lopuhin lopuhin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@teabolt sorry for a long review - there is just one minor thing which is preventing the merge now, updating the some docs after some attributes were moved to TargetExplanation:

eli5/keras/explain_prediction.py Outdated Show resolved Hide resolved
docs/source/libraries/keras.rst Show resolved Hide resolved
eli5/formatters/image.py Show resolved Hide resolved
eli5/formatters/image.py Outdated Show resolved Hide resolved
Copy link
Contributor

@lopuhin lopuhin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect, thanks @teabolt 👍

@lopuhin
Copy link
Contributor

lopuhin commented Aug 5, 2019

Thanks for a great new feature @teabolt , and thanks for review @kmike , merging 🎉

@@ -45,6 +45,9 @@ following machine learning frameworks and packages:
* :ref:`library-sklearn-crfsuite`. ELI5 allows to check weights of
sklearn_crfsuite.CRF models.

* :ref:`library-keras` - explain predictions of image classifiers
via Grad-CAM visualizations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for a late comment: could you please copy overview.rst changes to README file in the repo root?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do that in #329. Thanks

@kmike
Copy link
Contributor

kmike commented Aug 5, 2019

Thanks @teabolt and @lopuhin, great work!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants