
Problems about make_visualizations.py #2

Open
Ljiaqii opened this issue Dec 24, 2021 · 2 comments

Comments

@Ljiaqii

Ljiaqii commented Dec 24, 2021

Hi! Thank you for sharing your great work! While studying make_visualizations.py, I ran into a problem.
You define activations_hook on line 21 of the program, but self.gradients = grads on line 22 never actually gets set. As far as I can tell, the self.activations_hook reference on line 36 does not call the function defined on line 21, which leads to errors later in the program.
Therefore, do you have a good way to solve this problem, or could you provide the latest version of the code?
Thank you very much; I look forward to your reply! Best regards!

@teddykoker
Owner

Hi!

This is the latest version of my code for doing GradCAM. activations_hook() should be called when .backward() is called on the output predictions (see docs).

If you could provide a stack trace of the error that would be very helpful!
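
As a side note, the hook mechanism being described can be illustrated with a minimal, self-contained snippet (only a sketch, not the repository's code): a function passed to register_hook is invoked by autograd with the tensor's gradient, but only once .backward() runs.

import torch

captured = {}

def activations_hook(grad):
    # Called by autograd during .backward() with the gradient w.r.t. `activations`.
    captured["gradients"] = grad

x = torch.randn(1, 3, 8, 8, requires_grad=True)
activations = torch.relu(x)              # stand-in for a conv feature map
activations.register_hook(activations_hook)

output = activations.sum()               # stand-in for the model's prediction
# Nothing has been stored yet; the hook has only been registered.
output.backward()                        # autograd now invokes activations_hook
print(captured["gradients"].shape)       # torch.Size([1, 3, 8, 8])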

@Ljiaqii
Author

Ljiaqii commented Dec 26, 2021

Hi! Thank you for your quick reply! Regarding the .backward() problem, it turned out that I had carelessly omitted the call; I'm very sorry about that and I have corrected it.
Moreover, I intend to implement 3D image classification visualization based on your code (the input cube in my program is a three-dimensional grayscale image of size 32 × 32 × 32 (z, y, x)).
However, I don't quite understand a few lines of the code. Could you please tell me how to modify them?

Here is the code:

def grad_cam_cube(model, test_data_loader):
    model = GradCAMResnet(model)
    model.eval()

    for batch_idx, (inputs, targets, feat) in enumerate(test_data_loader):
        inputs, targets = inputs.cuda(), targets.cuda()    # batchsize = 1, input [1,1,32,32,32]

        inputs, targets = Variable(inputs, requires_grad=False), Variable(targets)
        outputs = model(inputs)      # outputs: tensor([[ 4.2578, -4.4285]], device='cuda:0', grad_fn=<AddmmBackward>)
        outputs1 = outputs[0, 0]     # outputs1: tensor(4.2578, device='cuda:0', grad_fn=<SelectBackward>)
        outputs2 = outputs[0, 1]     # outputs2: tensor(-4.4285, device='cuda:0', grad_fn=<SelectBackward>)
        outputs[0, 0].backward()     # otherwise outputs[0, 1].backward()
        gradients, activations = model.gradients, model.activations  # gradients [1,2048,1,1,1]    # What does 2048 mean?
        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
        for i in range(activations.shape[1]):
            activations[:, i, :, :] *= pooled_gradients[i]

        heatmap = torch.mean(activations, dim=1).squeeze()   # heatmap: tensor(0.0021, device='cuda:0')
        heatmap = torch.relu(heatmap)
        heatmap /= torch.max(heatmap)    # heatmap = {Tensor} tensor(0.0021, device='cuda:0')  # Does the heatmap need to return 32 × 32 × 32?

The following code is my definition of GradCAMResnet:

class GradCAMResnet(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def activations_hook(self, grads):
        self.gradients = grads

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        # I added the following two lines of code.
        self.activations = x.detach()
        x.register_hook(self.activations_hook)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.model.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

Thank you for your help! Looking forward to your reply!
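
For reference, a rough sketch (not the repository's code) of how the pooled-gradient weighting above generalizes to the 5-D tensors of a 3D network, assuming activations and gradients of shape [1, C, D, H, W] (e.g. C = 2048 for the last layer of a ResNet-50-style backbone); the function name grad_cam_3d and the out_size parameter are hypothetical.

import torch
import torch.nn.functional as F

def grad_cam_3d(activations, gradients, out_size=(32, 32, 32)):
    # Average the gradients over the batch and all three spatial dims -> one weight per channel.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3, 4])            # shape [C]
    # Weight each activation channel by its pooled gradient.
    weighted = activations * pooled_gradients[None, :, None, None, None]
    # Average over channels and keep only positive contributions.
    heatmap = torch.relu(weighted.mean(dim=1, keepdim=True))              # [1, 1, D, H, W]
    heatmap = heatmap / (heatmap.max() + 1e-8)
    # Upsample to the input volume so the heatmap can be overlaid on the 32 x 32 x 32 cube.
    heatmap = F.interpolate(heatmap, size=out_size, mode="trilinear", align_corners=False)
    return heatmap.squeeze()                                              # [32, 32, 32]

Note that if the hooked feature map is 1 × 1 × 1, the upsampled heatmap is uniform, so this only illustrates the tensor shapes involved.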
