Skip to content

Commit

Permalink
fix(ver): consider torch version for function_call import
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnec committed Jan 4, 2024
1 parent aea18cd commit 8992ffa
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions torchimize/optimizer/gna_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
from torch.optim.optimizer import Optimizer
from typing import List

from distutils.version import LooseVersion
if LooseVersion(torch.__version__) > LooseVersion('2.0.0'):
from torch.func import functional_call
else:
try:
from torch.nn.utils.stateless import functional_call
except ImportError:
from torch.nn.utils._stateless import functional_call



class GNA(Optimizer):
r"""Implements Gauss-Newton.
Expand Down Expand Up @@ -62,7 +72,7 @@ def step(self, x: torch.Tensor, closure=None):
if self.hessian_approx:
# vectorized jacobian (https://github.com/pytorch/pytorch/issues/49171)
def func(*params: torch.Tensor):
out = torch.func.functional_call(self._model, {n: p for n, p in zip(keys, params)}, x)
out = functional_call(self._model, {n: p for n, p in zip(keys, params)}, x)
return out
self._j_list: tuple[torch.Tensor] = torch.autograd.functional.jacobian(func, values, create_graph=False) # NxCxBxCxHxW
# create hessian approximation
Expand All @@ -77,7 +87,7 @@ def func(*params: torch.Tensor):
else:
# vectorized hessian (https://github.com/pytorch/pytorch/issues/49171)
def func(*params: torch.Tensor):
out: torch.Tensor = torch.func.functional_call(self._model, {n: p for n, p in zip(keys, params)}, x)
out: torch.Tensor = functional_call(self._model, {n: p for n, p in zip(keys, params)}, x)
return out.square().sum()
self._h_list: tuple[torch.Tensor] = torch.autograd.functional.hessian(func, tuple(self._model.parameters()), create_graph=False)
self._h_list = [self._h_list[i][i] for i in range(len(self._h_list))] # filter j-th element
Expand Down

0 comments on commit 8992ffa

Please sign in to comment.