Hi Yuhang,

Thank you for open sourcing this project.

As noted in the paper, a diagonal Fisher information matrix is applied to replace the pre-activation Hessian, so we tried to set opt_mode to fisher_diag instead of mse for reconstruction. However, a runtime error is thrown. It seems to occur during the backward pass that saves the gradients:
File "xxxx/quant/data_utils.py", line 184, in __call__
loss.backward()
File "xxxx/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "xxxx/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
As indicated by the error, the first backward succeeds but the second fails. We tried to create a very simple network for reproducing it, and the error keeps showing; a minimal illustration of the error itself is sketched below.
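For reference, even a toy example like the following (our own sketch, unrelated to this repository's modules) reproduces the message by backpropagating twice through one graph:

import torch
import torch.nn as nn

# Toy network; the names here are ours, not from this repository.
net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(4, 8)

out = net(x)              # forward builds the autograd graph once
loss1 = out.pow(2).mean()
loss1.backward()          # first backward succeeds and frees the intermediates

loss2 = out.mean()        # reuses the same graph through `out`
loss2.backward()          # RuntimeError: Trying to backward through the graph a second time

Passing retain_graph=True to the first backward makes this toy example run, but we are unsure whether that is the intended fix inside save_grad_data.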
The recon_model function is the same as that in the main_imagenet file:

def recon_model(model: nn.Module):
    """Block reconstruction. For the first and last layers, we can only apply layer reconstruction."""
    for name, module in model.named_children():
        if isinstance(module, QuantModule):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of layer {}'.format(name))
                continue
            else:
                layer_reconstruction(qnn, module, **kwargs)
        elif isinstance(module, BaseQuantBlock):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of block {}'.format(name))
                continue
            else:
                print('Reconstruction for block {}'.format(name))
                block_reconstruction(qnn, module, **kwargs)
        else:
            recon_model(module)
We are not quite sure why PyTorch complains here, since backward is only called once per batch... But we also noticed that after calling save_grad_data, the gradients are cached for a later loss calculation:
# in block_reconstruction
err = loss_func(out_quant, cur_out, cur_grad)
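For context, our understanding of what the fisher_diag objective computes with that cached gradient (a sketch of the paper's formulation as we read it, not necessarily the actual loss_func) is:

import torch

# Hypothetical sketch (our reading of the paper, not the repository's code):
# squared reconstruction error of the block output, weighted elementwise by
# the squared gradients, i.e. by the diagonal Fisher information matrix that
# stands in for the pre-activation Hessian.
def fisher_diag_loss(out_quant: torch.Tensor,
                     cur_out: torch.Tensor,
                     cur_grad: torch.Tensor) -> torch.Tensor:
    weighted_sq_err = (out_quant - cur_out).pow(2) * cur_grad.pow(2)
    return weighted_sq_err.flatten(1).sum(1).mean()  # sum per sample, mean over batch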
Is the intermediate gradient still available at this point, given that backward has already been called? In our case, even when we work around the first error inside save_grad_data, we get the same error here (i.e., backward through the graph a second time).
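Our current hypothesis is that the cached tensors may still be attached to the autograd graph. The pattern we would have expected looks roughly like this (hypothetical helper names, not the actual save_grad_data):

import torch

# Sketch of the caching pattern we would have expected: once the gradient is
# detached and cloned, the autograd graph can be freed, and reusing the cached
# tensor in a later loss calculation cannot trigger a second backward.
cached_grads = []

def cache_output_grad(model, batch, criterion, targets):
    out = model(batch)                        # fresh graph for this batch
    loss = criterion(out, targets)
    grad, = torch.autograd.grad(loss, out)    # gradient w.r.t. the block output
    cached_grads.append(grad.detach().clone())  # cut from the graph before caching

If save_grad_data instead stores tensors that are still attached to the graph, that might explain why a later computation tries to backward through it again.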
Environment
Ubuntu 16.04 / Python 3.6.8 / PyTorch 1.7.1 / CUDA 10.1
Any advice would be appreciated.