Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 593217203
  • Loading branch information
learned_optimization authors committed Dec 23, 2023
1 parent a7322af commit 4bcaeb0
Showing 1 changed file with 54 additions and 10 deletions.
64 changes: 54 additions & 10 deletions learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,34 @@ def __call__(self, inp_features):
return self.mod(inp_features, self.perm_spec.unfreeze())


class HybridMLPNFN(nn.Module):
  """Learned-optimizer network: an MLP trunk followed by one NF linear layer.

  The forward pass feeds the input features through `MLPForOpt`, applies
  `universal_layers.nf_relu` to the result, and then applies a single
  `universal_layers` linear layer that is also given the permutation spec.

  Attributes:
    in_channels: Declared for interface parity with the other networks in
      this file; it is not read anywhere inside this class.
    hidden_channels: Passed as both of the first two arguments of `MLPForOpt`
      and as the second argument of the final layer.
    out_channels: Passed as the first argument of the final layer.
    num_layers: `MLPForOpt` is constructed with `num_layers - 1`; the final
      NF linear layer is built separately.
    perm_spec: Permutation spec; `perm_spec.unfreeze()` is handed to the final
      layer on every call (presumably a frozen container -- confirm against
      the caller).
    ptwise_init: If True the final layer is a `PointwiseInitNFLinear`,
      otherwise an `NFLinear` constructed with `w_init='lecun'`.
  """

  in_channels: int
  hidden_channels: int
  out_channels: int
  num_layers: int
  perm_spec: Any
  ptwise_init: bool = False

  def setup(self):
    """Builds the MLP trunk and the final NF linear layer."""
    width = self.hidden_channels

    # Trunk first, then the head: same construction order as before.
    self.mlp = MLPForOpt(width, width, self.num_layers - 1)

    # Only the constructor of the final layer depends on `ptwise_init`.
    if self.ptwise_init:
      self.final = universal_layers.PointwiseInitNFLinear(
          self.out_channels, width
      )
    else:
      self.final = universal_layers.NFLinear(
          self.out_channels, width, w_init='lecun'
      )

  def __call__(self, inp_features):
    """Runs the MLP trunk, `nf_relu`, then the final layer on `inp_features`."""
    trunk_out = self.mlp(inp_features)
    activated = universal_layers.nf_relu(trunk_out)
    return self.final(activated, self.perm_spec.unfreeze())


class SGDControl(lopt_base.LearnedOptimizer):
"""SGD where per-parameter learning rates are controlled by a network."""

Expand Down Expand Up @@ -457,7 +485,13 @@ class ResidualOptNFN(ResidualOpt):
"""NFN learning a residual on base optimizer."""

def __init__(
self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False, pos_emb=False
self,
task,
step_mult=0.1,
out_mult=1e-4,
ptwise_init=False,
pos_emb=False,
hybrid=False,
):
example_params = task.init(jax.random.PRNGKey(0))
if 'conv2_d' in example_params:
Expand All @@ -468,15 +502,25 @@ def __init__(
perm_spec = make_hk_transformer_perm_spec(example_params)
else:
perm_spec = make_hk_perm_spec(example_params)
network = UnivNFNForOpt(
in_channels=19,
hidden_channels=32,
out_channels=1,
num_layers=4,
perm_spec=perm_spec,
ptwise_init=ptwise_init,
pos_emb=pos_emb,
)
if hybrid:
assert not pos_emb
network = HybridMLPNFN(
in_channels=19,
hidden_channels=32,
out_channels=1,
num_layers=4,
perm_spec=perm_spec,
)
else:
network = UnivNFNForOpt(
in_channels=19,
hidden_channels=32,
out_channels=1,
num_layers=4,
perm_spec=perm_spec,
ptwise_init=ptwise_init,
pos_emb=pos_emb,
)
super().__init__(
network, example_params, step_mult=step_mult, out_mult=out_mult
)
Expand Down

0 comments on commit 4bcaeb0

Please sign in to comment.