Adding the Latent Shift attribution method #1024

Open · wants to merge 72 commits into master
Conversation

ieee8023 (Author):

(making a new PR to fix the CLA issue)
RE: #694
I got ahead of myself with the code here but I will use the PR as a design doc.

Background

The Latent Shift method explains a neural network's prediction by generating a counterfactual. What differentiates this approach from others is that it is modular and tries to be as simple as possible. To generate the counterfactual, the method uses an autoencoder to constrain the perturbed examples to remain in the data space: instead of perturbing the input directly using dy/dx, it adjusts the latent space of the autoencoder using dy/dz, shifting the latent representation in the direction that changes the classifier's prediction.
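To make the dy/dz step concrete, here is a minimal sketch of the core computation. The helper is hypothetical: it assumes the autoencoder exposes encode/decode methods and the classifier returns per-class scores, and lam stands in for the shift amount chosen by the search described below.

import torch

def latent_shift_step(model, ae, x, target, lam):
    # Encode the input into the autoencoder's latent space
    z = ae.encode(x).detach()
    z.requires_grad_(True)

    # Classifier score on the reconstruction, so gradients flow through the decoder
    y = model(ae.decode(z))[:, target]

    # dy/dz: the latent direction that changes the classifier's prediction
    dydz = torch.autograd.grad(y.sum(), z)[0]

    # Shift the latent code and decode an on-manifold counterfactual
    return ae.decode(z + lam * dydz)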

Proposed Captum API Design

The proposed design is an attribution method which takes both the model and an autoencoder as input. Here is the proposed interface:

# Load classifier and autoencoder
model = classifiers.FaceAttribute()
ae = autoencoders.Transformer(weights="faceshq")

# Load image
input = torch.randn(1, 3, 1024, 1024)

# Define the Latent Shift module
attr = captum.attr.LatentShift(model, ae)

# Compute the counterfactual for class 3
output = attr.attribute(input, target=3)

This example corresponds to tutorial code and models provided in this repo: https://github.com/ieee8023/latentshift and is available as a Colab notebook: https://colab.research.google.com/github/ieee8023/latentshift/blob/main/example.ipynb.

The call to attr.attribute returns a dictionary with multiple aspects of the search that can be used later. The basic outputs are a heatmap in output['heatmap'] and a sequence of images in output['generated_images'].

The generated_images can be stitched together into a video using a provided attr.generate_video(output, "filename") function, which uses ffmpeg to combine the images into a video. It will also overlay the model's output probability in the upper left-hand corner to make the video easier to interpret.
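For illustration, consuming the result from the example above might look like this. The dict keys and generate_video are as described; the filename "pointy_nose_counterfactual" is arbitrary.

heatmap = output['heatmap']           # 2D attribution map
frames = output['generated_images']   # counterfactual image sequence

# Stitch the frames into a video with the prediction overlaid
attr.generate_video(output, "pointy_nose_counterfactual")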

Generating the attribution and heatmaps involves many parameters, all of which have defaults set. The search that is performed is an iterative, heuristic-based search and will likely need to be tuned when new autoencoders and models are used. The search determines the correct lambda values to generate the shift. It starts by stepping by search_step_size in the negative direction, stopping once the output of the classifier has changed by search_pred_diff or when the change in the prediction stops decreasing. To avoid artifacts when the shift is too large or in the wrong direction, an extra stop condition, search_max_pixel_diff, halts the search if the change in the image is too large. Finally, search_max_steps caps the number of iterations so the search cannot go on endlessly.
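As a sketch, a tuned call could look like the one below. The parameter names are the ones described above, but the values are illustrative and not the actual defaults.

output = attr.attribute(
    input,
    target=3,
    search_step_size=10.0,         # step taken each iteration (negative direction first)
    search_pred_diff=0.3,          # stop once the prediction has changed this much
    search_max_pixel_diff=5000.0,  # stop if the image changes too much (artifact guard)
    search_max_steps=100,          # hard cap on the number of iterations
)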

Also, here is a side-by-side comparison of this method to other methods in Captum for the target pointy_nose:
[Screenshot: side-by-side comparison of attribution methods, 2022-09-11]

aobo-y (Contributor) left a comment:


Sorry for our delayed review. This looks great. Thanks for your contribution! I have left some comments about the high-level design choices and will continue the review in the following days.

Curious if you have checked captum.robust. I feel LatentShift also fits well there, as it generates adversarial inputs. Do you consider LatentShift more of an attribution method?

captum/attr/_core/latent_shift.py (review thread, resolved)

params["heatmap"] = heatmap

return params
Contributor:

If we treat LatentShift as an attribution method, the single most important return should be this heatmap, i.e., attr in Captum's wording. To align with our convention, that should be the single return of attribute().

But of course, for LatentShift, it is also valuable to have the generated frames. Maybe we can consider the design used in the current IntegratedGradients (https://captum.ai/api/integrated_gradients.html), where a flag return_convergence_delta is used to indicate if we need a 2nd return.
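For reference, the IntegratedGradients pattern mentioned above looks like this:

from captum.attr import IntegratedGradients

ig = IntegratedGradients(model)
# One attribution tensor by default; the flag opts in to a second return
attributions, delta = ig.attribute(input, target=3, return_convergence_delta=True)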

ieee8023 (Author):

I refactored it so it returns a single batched result of heatmaps. If a return_dicts argument is set to true, it will return a list of dicts with the details of the computation (needed for videos). What do you think?
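A sketch of the refactored interface; only the return_dicts argument is named in this thread, so the exact shapes are assumptions:

# Default: a single batched tensor of heatmaps, matching Captum's convention
heatmaps = attr.attribute(input, target=3)

# Opt in to the per-example details needed to build the videos
dicts = attr.attribute(input, target=3, return_dicts=True)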

captum/attr/_core/latent_shift.py (review thread, outdated, resolved)
aobo-y self-assigned this Mar 1, 2023
ieee8023 (Author):

> captum.robust

Yeah, I agree it also fits there. But I would vote for it to be an attribution method so it is easier for people to use it for that purpose.

ieee8023 (Author) commented Apr 4, 2023:

@aobo-y Hey just pinging you that I think I resolved the last round of comments on the code!

ieee8023 (Author):

@aobo-y I just updated the demo notebook so it works with the new interface and the relocated video utils: https://colab.research.google.com/github/ieee8023/latentshift/blob/main/example.ipynb

It seems the main branch has broken tests so I can't make them pass at the moment.

ieee8023 (Author):

@aobo-y tests are passing now!

facebook-github-bot (Contributor):

@aobo-y has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
