-
Notifications
You must be signed in to change notification settings - Fork 332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explain predictions of Keras image classifiers (Grad-CAM) #315
Conversation
…ications preprocess_input. Fix using callable to find target layer
…ain_prediction via approximate attention over area
… be a Layer instance)
Co-Authored-By: Mikhail Korobov <[email protected]>
eli5/sklearn/text.py
Outdated
@@ -13,10 +13,11 @@ | |||
|
|||
|
|||
def get_weighted_spans(doc, vec, feature_weights): | |||
# type: (Any, Any, FeatureWeights) -> Optional[WeightedSpans] | |||
# type: (Any, Any, Union[FeatureWeights, None]) -> Optional[WeightedSpans] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems this function requires FeatureWeights to be not None, so maybe it makes sense to keep type signature the same, but move assert to the caller code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved the assert to add_weighted_spans
. Good one. Strange that https://github.com/TeamHG-Memex/eli5/search?q=get_weighted_spans did not show add_weighted_spans
's call to get_weighted_spans
.
eli5/keras/gradcam.py
Outdated
|
||
|
||
def _get_target_prediction(targets, estimator): | ||
# type: (Union[None, list], Model) -> K.variable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optional[List]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I was thinking whether to use this. My thinking was that targets
is not actually optional in the parameter list so it might be confusing. But I see that the rest of the library uses Optional (i.e. for functions that could return None) so I will change it.
eli5/keras/explain_prediction.py
Outdated
a valid keras layer name, layer index, or an instance of a Keras layer. | ||
|
||
If None, a suitable layer is attempted to be retrieved. | ||
See :func:`eli5.keras._search_layer_backwards` for details. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if we're documenting a function, it makes sense to make it public - or just document the behavior without mentioning a function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. Forgot that I made the function private.
eli5/keras/explain_prediction.py
Outdated
An input image as a tensor to ``estimator``, | ||
from which prediction will be done and explained. | ||
|
||
For example a ``numpy.ndarray``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are there other supported data types, why "for example"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will explicitly mention that numpy arrays are required.
I think it is possible to use other input types (https://github.com/keras-team/keras/blob/ed07472bc5fc985982db355135d37059a1f887a9/keras/engine/training.py#L1315), i.e. tensorflow tensor. However, I haven't tested with other types and I think I have some numpy dependencies in my code. Adding more input types could be a separate GitHub issue?
Co-Authored-By: Mikhail Korobov <[email protected]>
…nto keras-gradcam-img
Currently the parameter |
I think both options are fine, to me a constant looks a bit better than a string, and I think it's fine to use a PIL constant here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@teabolt sorry for a long review - there is just one minor thing which is preventing the merge now, updating the some docs after some attributes were moved to TargetExplanation
:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect, thanks @teabolt 👍
@@ -45,6 +45,9 @@ following machine learning frameworks and packages: | |||
* :ref:`library-sklearn-crfsuite`. ELI5 allows to check weights of | |||
sklearn_crfsuite.CRF models. | |||
|
|||
* :ref:`library-keras` - explain predictions of image classifiers | |||
via Grad-CAM visualizations. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for a late comment: could you please copy overview.rst changes to README file in the repo root?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do that in #329. Thanks
This PR adds explanations for Keras models that are used to classify images. Specifically we implement Grad-CAM.
For example, the following piece of code:
produces this explanation for the class 'motor scooter':
The following features are added:
keras
package withexplain_prediction_keras()
andformatters.image
module withformat_as_image()
(requires matplotlib and Pillow)..image
attribute tobase.Explanation
and.heatmap
tobase.TargetExplanation
.ipython.show_prediction()
dispatch to an image display function for image explanations.TODO items before this PR is finalized:
CoveragePass CIMypy type annotationsDocs (formatting, tutorial)Integration and unit tests