Skip to content

Commit

Permalink
fix implicit methods
Browse files Browse the repository at this point in the history
  • Loading branch information
sungyubkim committed Dec 9, 2019
1 parent c87132d commit 1577e17
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
7 changes: 2 additions & 5 deletions gbml/imaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ def __init__(self, args):
super().__init__(args)
self._init_net()
self._init_opt()
self.lamb = 2.0
self.n_cg = 3
self.lamb = 100.0
self.n_cg = 1
return None

@torch.enable_grad()
def inner_loop(self, fmodel, diffopt, train_input, train_target):
    """One inner-loop adaptation step for iMAML.

    Computes the task loss on the support set plus the iMAML proximal
    regularizer (lamb/2) * ||phi - theta||^2, where phi are the fast
    weights held by `fmodel` and theta are the (detached) meta-parameters
    in `self.network`, then takes one differentiable optimizer step.

    Args:
        fmodel: functional/patched copy of the network holding the fast weights.
        diffopt: differentiable inner-loop optimizer; `step(loss)` updates fmodel.
        train_input: support-set inputs.
        train_target: support-set integer class targets.

    Returns:
        None (fmodel is updated in place via diffopt).
    """
    train_logit = fmodel(train_input)
    inner_loss = F.cross_entropy(train_logit, train_target)
    # BUG FIX: the original subtracted self.network's parameter vector from
    # its own .detach(), which is identically zero in both value and gradient,
    # silently disabling the proximal term. The live term must be the fast
    # weights (fmodel), anchored to the detached meta-parameters.
    fast_vec = torch.nn.utils.parameters_to_vector(fmodel.parameters())
    meta_vec = torch.nn.utils.parameters_to_vector(self.network.parameters()).detach()
    inner_loss += (self.lamb / 2.) * ((fast_vec - meta_vec) ** 2).sum()
    diffopt.step(inner_loss)

    return None
Expand All @@ -41,8 +40,6 @@ def cg(self, in_grad, outer_grad, params):
beta = (r_new @ r_new)/(r @ r)
p = r_new + beta * p
r = r_new.clone().detach()
# print(alpha, beta ,r @ r);input()
# print('end')
return self.vec_to_grad(x)

def vec_to_grad(self, vec):
Expand Down
5 changes: 3 additions & 2 deletions gbml/neumann.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def __init__(self, args):
super().__init__(args)
self._init_net()
self._init_opt()
self.n_series = 3
self.lamb = 100.0
self.n_series = 1
return None

@torch.enable_grad()
Expand Down Expand Up @@ -48,7 +49,7 @@ def vec_to_grad(self, vec):
def hv_prod(self, in_grad, x, params):
    """Hessian-vector product for the Neumann-series implicit gradient.

    Differentiates `in_grad` (the inner-loss gradient, built with
    create_graph=True) w.r.t. `params` with `x` as the cotangent, giving
    H @ x without materializing H, then flattens and scales it.

    NOTE(review): the diff residue contained both the old scaling
    `(-1. * self.args.inner_lr)` and the new `(-1. / self.lamb)`; applying
    both would double-scale. This keeps only the commit's final version,
    which assumes `self.lamb` is set in __init__ — confirm against the
    full file.

    Args:
        in_grad: gradient tensor(s) of the inner loss, with graph attached.
        x: vector (cotangent) to multiply the Hessian by.
        params: parameters to differentiate with respect to.

    Returns:
        Detached flat tensor equal to (-1/lamb) * H @ x.
    """
    hv = torch.autograd.grad(in_grad, params, retain_graph=True, grad_outputs=x)
    hv = torch.nn.utils.parameters_to_vector(hv)
    # Scale by -1/lambda so the Neumann series converges.
    hv = (-1. / self.lamb) * hv
    return hv.detach()

def outer_loop(self, batch, is_train):
Expand Down

0 comments on commit 1577e17

Please sign in to comment.