RuntimeError: Trying to backward through the graph a second time when setting opt_mode to fisher_diag #22

Open
a1trl9 opened this issue Oct 19, 2021 · 0 comments

a1trl9 commented Oct 19, 2021

Hi Yuhang,

Thank you for open-sourcing this project.

As noted in the paper, a diagonal Fisher information matrix is used to replace the pre-activation Hessian, so we tried setting opt_mode to fisher_diag instead of mse for reconstruction. However, a runtime error is thrown:

File "xxxx/quant/data_utils.py", line 184, in __call__
    loss.backward()
  File "xxxx/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "xxxx/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
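
For reference, this error is easy to trigger in isolation whenever backward runs twice over the same graph; a minimal sketch, independent of this repo:

import torch

x = torch.randn(4, requires_grad=True)
y = (x ** 2).sum()
y.backward()  # the first backward frees the intermediate buffers saved for the graph
y.backward()  # raises the same RuntimeError, since the graph is already freed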

It seems to occur during the backward pass used to save gradients:

handle = self.layer.register_backward_hook(self.data_saver)
with torch.enable_grad():
    try:
        self.model.zero_grad()
        inputs = model_input.to(self.device)
        self.model.set_quant_state(False, False)
        out_fp = self.model(inputs)
        quantize_model_till(self.model, self.layer, self.act_quant)
        out_q = self.model(inputs)
        loss = F.kl_div(F.log_softmax(out_q, dim=1), F.softmax(out_fp, dim=1), reduction='batchmean')
        # the error is raised here
        loss.backward()
    except StopForwardException:
        pass
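
As the error message suggests, the first failure can be silenced by retaining the graph (a workaround sketch only, not necessarily the intended fix):

loss.backward(retain_graph=True)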

As the error indicates, the first backward succeeds but the second fails.

We created a very simple network to reproduce the issue, and the error still shows up:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DummyNet(nn.Module):
    def __init__(self):
        super(DummyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 3)
        self.conv2 = nn.Conv2d(32, 32, 3, 3)
        self.conv3 = nn.Conv2d(32, 1, 3, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        output = F.log_softmax(x, dim=0)
        return output
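
For completeness, a quick shape sanity check (the input size is arbitrary, as long as it survives three stride-3 convolutions):

x = torch.randn(1, 3, 224, 224)
print(DummyNet()(x).shape)  # torch.Size([1, 1, 8, 8])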

The recon_model function is the same as the one in the main_imagenet file:

def recon_model(model: nn.Module):
    """
    Block reconstruction. For the first and last layers, we can only apply layer reconstruction.
    """
    for name, module in model.named_children():
        if isinstance(module, QuantModule):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of layer {}'.format(name))
                continue
            else:
                layer_reconstruction(qnn, module, **kwargs)
        elif isinstance(module, BaseQuantBlock):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of block {}'.format(name))
                continue
            else:
                print('Reconstruction for block {}'.format(name))
                block_reconstruction(qnn, module, **kwargs)
        else:
            recon_model(module)

We are not quite sure why PyTorch complains here, since backward is only called once per batch... But we also noticed that after save_grad_data is called, the gradients are cached for the later loss calculation:

# in block_reconstruction
err = loss_func(out_quant, cur_out, cur_grad)
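
Our reading is that cur_grad then weights the squared reconstruction error roughly as follows (a paraphrased sketch of the fisher_diag branch in the loss function; variable names may differ from the repo):

# paraphrased sketch of the fisher_diag reconstruction loss
rec_loss = ((out_quant - cur_out).pow(2) * cur_grad.pow(2)).sum(1).mean()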

Are the intermediate gradients still available at this point, given that backward has already been called? In our case, even after working around the first error inside save_grad_data, we hit the same error here (i.e., backward through the graph a second time).
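
One guess on our side (purely a guess, with hypothetical names): if the cached tensors are still attached to the freed graph, detaching them at save time might be relevant:

# hypothetical: detach before caching so the later loss computation
# cannot backpropagate into the already-freed graph
cached_grads.append(grad.detach())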

Environment

Ubuntu 16.04 / Python 3.6.8 / PyTorch 1.7.1 / CUDA 10.1

Any advice would be appreciated.
