Adding the Latent Shift attribution method #1024

Status: Open. ieee8023 wants to merge 72 commits into base: master.
Commits (72)
32b38e0
Add Latent Shift
ieee8023 Sep 6, 2022
0e57fe4
video
ieee8023 Sep 6, 2022
7c5025c
align text
ieee8023 Sep 6, 2022
3a89340
cleanup
ieee8023 Sep 11, 2022
d0c833a
clean up docs
ieee8023 Sep 11, 2022
4dc25bf
add support for colab version
ieee8023 Sep 11, 2022
12e78d3
cleanup
ieee8023 Sep 11, 2022
2cf44ba
add more docs
ieee8023 Sep 11, 2022
0a74565
Merge branch 'master' into master
ieee8023 Sep 11, 2022
0ceb34e
cleanup format and add test
ieee8023 Sep 11, 2022
9043790
more cleanup
ieee8023 Sep 11, 2022
2963ae5
cleanup and add more docs
ieee8023 Sep 11, 2022
d4320d2
fix flake8 errors
ieee8023 Sep 11, 2022
a039159
fixing flake8 for real
ieee8023 Sep 11, 2022
907d2d7
fix format and add opion to limit printing
ieee8023 Sep 11, 2022
01bb3b2
fix type error
ieee8023 Sep 11, 2022
222128e
flake8
ieee8023 Sep 11, 2022
4c588b9
autopep8
ieee8023 Sep 12, 2022
c1dd756
make mypy happy
ieee8023 Sep 12, 2022
77963d8
ufmt format
ieee8023 Sep 12, 2022
b0a08d7
I really think flake8 will pass now
ieee8023 Sep 12, 2022
8432ff3
match reference to other references
ieee8023 Sep 12, 2022
5597155
small change to kick off tests again
ieee8023 Sep 14, 2022
5407086
Merge branch 'master' into master
ieee8023 Sep 15, 2022
9b5272e
Merge branch 'master' into master
ieee8023 Sep 21, 2022
7a19759
Merge branch 'master' into master
ieee8023 Oct 9, 2022
7d64a75
Merge branch 'master' into master
ieee8023 Oct 19, 2022
e245cde
Merge branch 'master' into master
ieee8023 Oct 22, 2022
7245048
Merge branch 'master' into master
ieee8023 Nov 4, 2022
03fe557
Merge branch 'master' into master
ieee8023 Nov 8, 2022
298c9e8
add options for extra loops and the cmap value
ieee8023 Nov 20, 2022
04f16ca
Merge branch 'master' into master
ieee8023 Nov 20, 2022
3d6f842
fix flake8
ieee8023 Nov 21, 2022
17bc3af
Add Latent Shift
ieee8023 Sep 6, 2022
7615dcc
video
ieee8023 Sep 6, 2022
558b429
align text
ieee8023 Sep 6, 2022
430888e
cleanup
ieee8023 Sep 11, 2022
bf434a9
clean up docs
ieee8023 Sep 11, 2022
96e8b42
add support for colab version
ieee8023 Sep 11, 2022
34f48f6
cleanup
ieee8023 Sep 11, 2022
554db30
add more docs
ieee8023 Sep 11, 2022
cefc673
cleanup format and add test
ieee8023 Sep 11, 2022
42c2c36
more cleanup
ieee8023 Sep 11, 2022
77c574b
cleanup and add more docs
ieee8023 Sep 11, 2022
8aa3fec
fix flake8 errors
ieee8023 Sep 11, 2022
90ffd8e
fixing flake8 for real
ieee8023 Sep 11, 2022
67a576c
fix format and add opion to limit printing
ieee8023 Sep 11, 2022
2a9cab7
fix type error
ieee8023 Sep 11, 2022
a0f156a
flake8
ieee8023 Sep 11, 2022
435bee8
autopep8
ieee8023 Sep 12, 2022
2f618ff
make mypy happy
ieee8023 Sep 12, 2022
3f9bbdd
ufmt format
ieee8023 Sep 12, 2022
cec2237
I really think flake8 will pass now
ieee8023 Sep 12, 2022
b387097
match reference to other references
ieee8023 Sep 12, 2022
29951a0
small change to kick off tests again
ieee8023 Sep 14, 2022
653a67a
add options for extra loops and the cmap value
ieee8023 Nov 20, 2022
2292efa
fix flake8
ieee8023 Nov 21, 2022
390fee0
Merge branch 'master' of github.com:ieee8023/captum
ieee8023 Jan 12, 2023
74af5c8
refactor image writing
ieee8023 Feb 19, 2023
e9196ed
refactor for batches and just returning heatmaps
ieee8023 Mar 28, 2023
8a24e9b
pep8
ieee8023 Mar 28, 2023
f873da3
ufmt
ieee8023 Mar 28, 2023
92e93a7
format errors
ieee8023 Mar 28, 2023
8f9d8a2
fix typing
ieee8023 Mar 28, 2023
7779497
reduce string length
ieee8023 Mar 28, 2023
cd2c2f5
remove usage of torchvision in tests
ieee8023 Mar 29, 2023
6855d55
Merge branch 'master' into master
ieee8023 Apr 24, 2023
10a43c0
add sigmoid param
ieee8023 May 5, 2023
1d8e9dd
Merge branch 'master' of github.com:ieee8023/captum
ieee8023 May 5, 2023
c0dfdfb
Merge branch 'master' into master
ieee8023 May 24, 2023
4eb4c6d
Merge branch 'master' into master
ieee8023 Jun 10, 2023
d63a803
Merge branch 'master' into master
ieee8023 Aug 7, 2023
2 changes: 2 additions & 0 deletions captum/attr/__init__.py
@@ -12,6 +12,7 @@
from captum.attr._core.input_x_gradient import InputXGradient # noqa
from captum.attr._core.integrated_gradients import IntegratedGradients # noqa
from captum.attr._core.kernel_shap import KernelShap # noqa
from captum.attr._core.latent_shift import LatentShift # noqa
from captum.attr._core.layer.grad_cam import LayerGradCam # noqa
from captum.attr._core.layer.internal_influence import InternalInfluence # noqa
from captum.attr._core.layer.layer_activation import LayerActivation # noqa
@@ -142,4 +143,5 @@
"Max",
"Sum",
"Count",
"LatentShift",
]
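
With these two lines, LatentShift is re-exported from the package root like the other attribution methods. A minimal sanity check, assuming captum is installed from this branch (the expected output shown in the comment is illustrative):

    # Confirm the new export is wired up (assumes this branch is installed).
    from captum.attr import LatentShift

    print(LatentShift.__module__)  # expected: captum.attr._core.latent_shift
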
267 changes: 267 additions & 0 deletions captum/attr/_core/latent_shift.py
@@ -0,0 +1,267 @@
#!/usr/bin/env python3

from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
import torch
from captum.attr._utils.attribution import GradientAttribution
from captum.log import log_usage
from torch import Tensor


class LatentShift(GradientAttribution):
r"""An implementation of the Latent Shift method to generate
counterfactual explanations. This method uses an autoencoder to restrict
the possible adversarial examples to remain in the data space by
adjusting the latent space of the autoencoder using dy/dz instead of
dy/dx in order to change the classifier's prediction.

This class implements a search strategy to determine the lambda needed to
change the prediction of the classifier by a specific amount as well as
the code to generate a video and construct a heatmap representing the
image changes for viewing as an image.

More details regarding the latent shift method can be found in the
original paper:
https://arxiv.org/abs/2102.09475
And the original code repository:
https://github.com/mlmed/gifsplanation
"""

    def __init__(self, forward_func: Callable, autoencoder) -> None:
        r"""
        Args:
            forward_func (callable): The forward function of the model or
                        any modification of it.
            autoencoder: An object with encode and decode functions which
                        maintains a gradient tape (i.e., decoding stays
                        differentiable with respect to the latent code).
        """
        GradientAttribution.__init__(self, forward_func)
        self.ae = autoencoder

        # Check that the autoencoder exposes encode and decode
        assert hasattr(self.ae, "encode")
        assert hasattr(self.ae, "decode")

    @log_usage()
    def attribute(
        self,
        inputs: Tensor,
        target: int,
        fix_range: Union[Tuple, None] = None,
        search_pred_diff: float = 0.8,
        search_step_size: float = 10.0,
        search_max_steps: int = 3000,
        search_max_pixel_diff_pct: float = 0.05,
        lambda_sweep_steps: int = 10,
        heatmap_method: str = "int",
        apply_sigmoid: bool = True,
        verbose: bool = True,
        return_dicts: bool = False,
    ) -> Union[Tensor, List[Dict[str, Any]]]:
r"""
This method performs a search in order to determine the correct lambda
values to generate the shift. The search starts by stepping by
`search_step_size` in the negative direction while trying to determine
if the output of the classifier has changed by `search_pred_diff` or
when the change in the predict in stops going down. In order to avoid
artifacts if the shift is too large or in the wrong direction an extra
stop conditions is added `search_max_pixel_diff` if the change in the
image is too large. To avoid the search from taking too long a
`search_max_steps` will prevent the search from going on endlessly.


Args:

inputs (tensor): Input for which the counterfactual is computed.
target (int): Output indices for which dydz is computed (for
classification cases, this is usually the target class).
fix_range (tuple): Overrides searching and directly specifies the
lambda range to use. e.g. [-100,0].
search_pred_diff (float): The desired change in the classifiers
prediction. For example if the classifer predicts 0.9
and pred_diff=0.8 the search will try to generate a
counterfactual where the prediction is 0.1.
search_step_size (float): When searching for the right lambda to use
this will be the initial step size. This is similar to
a learning rate. Smaller values avoid jumping over the
ideal lambda but the search may take a long time.
search_max_steps (int): The max steps to take when doing the search.
Sometimes steps make a tiny improvement and can go on
forever. This just bounds the time and gives up the
search.
search_max_pixel_diff_pct (float): When searching, stop if the pixel
difference is larger than this amount. This will
prevent large artifacts being introduced into the
image. |img0 - imgx| > |img0|*pct
lambda_sweep_steps (int): How many frames to generate for the video.
heatmap_method: Default: 'int'. Possible methods: 'int': Average
per frame differences. 'mean' : Average difference
between 0 and other lambda frames. 'mm': Difference
between first and last frames. 'max': Max difference
from lambda 0 frame
apply_sigmoid: Default: True. Apply a sigmoid to the output of the
model. Set to false to work with regression models or
if the model already applies a sigmoid.
verbose: True to print debug text
return_dicts (bool): Return a list of dicts containing information
from each image processed. Default False

Returns:
attributions or (if return_dict=True) a list of dicts containing the
follow keys:
generated_images: A list of images generated at each step along
the dydz vector from the smallest lambda to the largest. By
default the smallest lambda represents the counterfactual
image and the largest lambda is 0 (representing no change).
lambdas: A list of the lambda values for each generated image.
preds: A list of the predictions of the model for each generated
image.
heatmap: A heatmap indicating the pixels which change in the
video sequence of images.


Example::

>>> # Load classifier and autoencoder
>>> model = classifiers.FaceAttribute()
>>> ae = autoencoders.VQGAN(weights="faceshq")
>>>
>>> # Load image
>>> x = torch.randn(1, 3, 1024, 1024)
ieee8023 marked this conversation as resolved.
Show resolved Hide resolved
>>>
>>> # Defining Latent Shift module
>>> attr = captum.attr.LatentShift(model, ae)
>>>
>>> # Computes counterfactual for class 3.
>>> output = attr.attribute(x, target=3)

"""

        assert lambda_sweep_steps > 1, "lambda_sweep_steps must be at least 2"

        results = []
        # Cheap batching: process one input at a time
        for idx in range(inputs.shape[0]):
            inp = inputs[idx].unsqueeze(0)
            z = self.ae.encode(inp).detach()
            z.requires_grad = True
            x_lambda0 = self.ae.decode(z)
            pred = self.forward_func(x_lambda0)[:, target]
            if apply_sigmoid:
                pred = torch.sigmoid(pred)
            # Gradient of the prediction with respect to the latent code
            dzdxp = torch.autograd.grad(pred, z)[0]

            # Cache so we can reuse at the sweep stage
            cache = {}

            def compute_shift(lambdax):
                """Compute the shift for a specific lambda"""
                if lambdax not in cache:
                    x_lambdax = self.ae.decode(z + dzdxp * lambdax).detach()
                    pred1 = self.forward_func(x_lambdax)[:, target]
                    if apply_sigmoid:
                        pred1 = torch.sigmoid(pred1)
                    pred1 = pred1.detach().cpu().numpy()
                    cache[lambdax] = x_lambdax, pred1
                return cache[lambdax]

            _, initial_pred = compute_shift(0)

            if fix_range:
                lbound, rbound = fix_range
            else:
                # Search the left (negative lambda) range
                lbound = 0
                last_pred = initial_pred
                pixel_sum = x_lambda0.abs().sum()  # Used for pixel diff
                while True:
                    x_lambdax, cur_pred = compute_shift(lbound)
                    pixel_diff = (
                        torch.abs(x_lambda0 - x_lambdax).sum().detach().cpu()
                    )
                    if verbose:
                        toprint = [
                            f"Shift: {lbound}",
                            f"Pred: {float(cur_pred)}",
                            f"pixel_diff: {float(pixel_diff)}",
                            f"sum*diff_pct: "
                            f"{pixel_sum * search_max_pixel_diff_pct}",
                        ]
                        print(", ".join(toprint))

                    # Stop if the prediction is no longer decreasing
                    if last_pred < cur_pred:
                        break
                    # Stop if the prediction becomes very low
                    if cur_pred < 0.05:
                        break
                    # Stop once the prediction has dropped by search_pred_diff
                    if initial_pred - search_pred_diff > cur_pred:
                        break
                    # Stop if we are moving in the latent space too much
                    if lbound <= -search_max_steps:
                        break
                    # Stop if we move too far and would distort the image
                    if pixel_diff > (pixel_sum * search_max_pixel_diff_pct):
                        break

                    last_pred = cur_pred
                    # Step further left; the lbound // 10 term grows the
                    # step size as lambda becomes more negative.
                    lbound = lbound - search_step_size + lbound // 10

                # Right (positive lambda) range search is not implemented
                rbound = 0

            if verbose:
                print("Selected bounds: ", lbound, rbound)

            # Sweep over the range of lambda values to create a sequence
            lambdas = np.linspace(lbound, rbound, lambda_sweep_steps)
            assert lambda_sweep_steps == len(
                lambdas
            ), "Inconsistent number of lambda steps"

            if verbose:
                print("Lambdas to compute: ", lambdas)

            preds = []
            generated_images = []

            for lam in lambdas:
                x_lambdax, pred = compute_shift(lam)
                generated_images.append(x_lambdax.cpu().numpy()[0])
                preds.append(float(pred))

            params = {}
            params["generated_images"] = np.array(generated_images)
            params["lambdas"] = lambdas
            params["preds"] = preds

            x_lambda0 = x_lambda0.detach().cpu().numpy()
            if heatmap_method == "max":
                # Max difference from lambda 0 frame
                heatmap = np.max(np.abs(x_lambda0 - generated_images), 0)

            elif heatmap_method == "mean":
                # Average difference between 0 and other lambda frames
                heatmap = np.mean(np.abs(x_lambda0 - generated_images), 0)

            elif heatmap_method == "mm":
                # Difference between first and last frames
                heatmap = np.abs(generated_images[0] - generated_images[-1])

            elif heatmap_method == "int":
                # Average per frame differences
                image_changes = []
                for i in range(len(generated_images) - 1):
                    image_changes.append(
                        np.abs(generated_images[i] - generated_images[i + 1])
                    )
                heatmap = np.mean(image_changes, 0)
            else:
                raise Exception("Unknown heatmap_method for 2d image")

            params["heatmap"] = heatmap
            results.append(params)

        if return_dicts:
            return results
        else:
            return torch.tensor([result["heatmap"] for result in results])
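
For reference, a minimal runnable sketch of the encode/decode contract this class expects. The TinyAutoencoder and ToyClassifier below are hypothetical stand-ins (untrained, randomly initialized), not part of this PR; any model pair with the same interfaces should work.

    import torch
    import torch.nn as nn

    from captum.attr import LatentShift


    class TinyAutoencoder(nn.Module):
        # Must expose encode() and decode(); decode() must stay
        # differentiable with respect to the latent code.
        def __init__(self) -> None:
            super().__init__()
            self.enc = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.dec = nn.Conv2d(8, 3, kernel_size=3, padding=1)

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            return self.enc(x)

        def decode(self, z: torch.Tensor) -> torch.Tensor:
            return self.dec(z)


    class ToyClassifier(nn.Module):
        def __init__(self, num_classes: int = 4) -> None:
            super().__init__()
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(3, num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc(self.pool(x).flatten(1))


    model = ToyClassifier()
    ae = TinyAutoencoder()
    attr = LatentShift(model, ae)

    x = torch.randn(1, 3, 64, 64)
    # Returns a tensor of heatmaps by default; pass return_dicts=True to
    # also get the generated images, lambdas, and per-frame predictions.
    heatmaps = attr.attribute(
        x, target=3, lambda_sweep_steps=5, verbose=False
    )
    print(heatmaps.shape)  # one heatmap per input image

With an untrained classifier the search will usually hit one of its stop conditions almost immediately; the point of the sketch is only the required interface, not a meaningful counterfactual.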