Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak the ConvLSTM architecture until it works #14

Merged
Merged 42 commits on May 28, 2024 (the source and target branch names were not captured in this page extract).
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
4717422
Implement LSTM
rhaps0dy May 2, 2024
f677f9d
LSTM reference test
rhaps0dy May 3, 2024
9a9d233
Tests say it's fine.
rhaps0dy May 4, 2024
94be8b4
Fix ConvLSTM config
rhaps0dy May 4, 2024
c39576e
Check that the CLI parses without launching it
rhaps0dy May 4, 2024
0cd099b
Can we straightforwardly compensate for LSTM low variance?
rhaps0dy May 4, 2024
2a54374
Incorporate the vision shortcut
rhaps0dy May 4, 2024
155a9df
Hackily set up the devbox
rhaps0dy May 4, 2024
47176e3
Remove multiplier
rhaps0dy May 4, 2024
64e701e
Try and make the ConvLSTM more balanced in variance
rhaps0dy May 4, 2024
aa179e9
Use horizontal connections for pool-and-inject
rhaps0dy May 4, 2024
a11fd26
pool_and_inject_horizontal
rhaps0dy May 4, 2024
1cc9423
Squash missing `head_scale` bug
rhaps0dy May 5, 2024
9b4b8dd
Get rid of generics and work around https://github.com/NiklasRosenste…
rhaps0dy May 5, 2024
765482e
Try more things
rhaps0dy May 5, 2024
31cc510
Change notebook params
rhaps0dy May 5, 2024
1436bfd
Torch initialization maybe
rhaps0dy May 5, 2024
1f45d9e
Go back to separated convs
rhaps0dy May 5, 2024
760f29f
More architectural variation for the ConvLSTM
rhaps0dy May 5, 2024
d4ae3c3
Fencepad
rhaps0dy May 5, 2024
d3587f3
Copy hyperparams from the best lstm runs
rhaps0dy May 5, 2024
0fd2af6
Config reflects actual practice
rhaps0dy May 5, 2024
a5326b8
Wild variations experiment
rhaps0dy May 5, 2024
cfd7831
Run good one for longer
rhaps0dy May 5, 2024
e9dc974
Store RMS gradient of every parameter
rhaps0dy May 5, 2024
bda7fd3
Have to commit new variations
rhaps0dy May 5, 2024
0ec2666
Try updating the actor again
rhaps0dy May 5, 2024
497430b
What if it's a residual LSTM
rhaps0dy May 5, 2024
ac9b919
Try again
rhaps0dy May 5, 2024
2803e78
Tweak the residual stream
rhaps0dy May 5, 2024
56756b5
Load_path
rhaps0dy May 5, 2024
578f999
Continue the old run hopefully
rhaps0dy May 6, 2024
014e7d3
Load_path maybe loads now
rhaps0dy May 6, 2024
39b8361
don't pin commit
rhaps0dy May 6, 2024
d76a7b0
Unreplicate loss before proceeding
rhaps0dy May 6, 2024
568a40a
Even more tweaks
rhaps0dy May 6, 2024
fb95771
Unreplicate train_state
rhaps0dy May 6, 2024
f79b5cb
Reduce randomness in planning evaluations
rhaps0dy May 6, 2024
a10f956
Test that the new evaluation is well-behaved
rhaps0dy May 6, 2024
3cd1daf
Get more planners
rhaps0dy May 6, 2024
bc2bb2e
Delete empty file
rhaps0dy May 9, 2024
e032380
Fix version for farconf
rhaps0dy May 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions cleanba/cleanba_impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,16 @@ def train(args: Args, *, writer: Optional[WandbWriter] = None):

policy, _, agent_params = args.net.init_params(envs, agent_params_subkey)

agent_state = TrainState.create(
apply_fn=None,
params=agent_params,
tx=make_optimizer(args, agent_params, total_updates=runtime_info.num_updates),
)
if args.load_path is None:
agent_state = TrainState.create(
apply_fn=None,
params=agent_params,
tx=make_optimizer(args, agent_params, total_updates=runtime_info.num_updates),
)
else:
old_args, agent_state = load_train_state(args.load_path)
print(f"Loaded TrainState from {args.load_path}. Here are the differences from `args` to the loaded args:")
farconf.config_diff(farconf.to_dict(args), farconf.to_dict(old_args))

multi_device_update = jax.pmap(
jax.jit(
Expand Down Expand Up @@ -716,6 +721,7 @@ def load_train_state(dir: Path) -> tuple[Args, TrainState]:
with open(dir / "model", "rb") as f:
train_state = flax.serialization.from_bytes(target_state, f.read())
assert isinstance(train_state, TrainState)
train_state = unreplicate(train_state)
return args, train_state


Expand Down
36 changes: 29 additions & 7 deletions cleanba/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import dataclasses
from dataclasses import field
from pathlib import Path
from typing import List
from typing import List, Optional

from cleanba.convlstm import ConvConfig, ConvLSTMConfig
from cleanba.convlstm import ConvConfig, ConvLSTMCellConfig, ConvLSTMConfig
from cleanba.environments import AtariEnv, EnvConfig, EnvpoolBoxobanConfig, random_seed
from cleanba.evaluate import EvalConfig
from cleanba.impala_loss import (
Expand Down Expand Up @@ -65,6 +65,8 @@ class Args:
distributed: bool = False # whether to use `jax.distributed`
concurrency: bool = True # whether to run the actor and learner concurrently

load_path: Optional[Path] = None # Where to load the initial training state from


def sokoban_resnet() -> Args:
CACHE_PATH = Path("/opt/sokoban_cache")
Expand Down Expand Up @@ -133,7 +135,7 @@ def sokoban_resnet() -> Args:
)


def sokoban_drc(num_layers: int, num_repeats: int) -> Args:
def sokoban_drc(n_recurrent: int, num_repeats: int) -> Args:
CACHE_PATH = Path("/opt/sokoban_cache")
return Args(
train_env=EnvpoolBoxobanConfig(
Expand Down Expand Up @@ -170,15 +172,35 @@ def sokoban_drc(num_layers: int, num_repeats: int) -> Args:
),
),
log_frequency=10,
sync_frequency=int(4e9),
net=ConvLSTMConfig(
embed=[ConvConfig(32, (4, 4), (1, 1), "SAME", True)] * 2,
recurrent=[ConvConfig(32, (3, 3), (1, 1), "SAME", True)] * num_layers,
recurrent=ConvLSTMCellConfig(
ConvConfig(32, (3, 3), (1, 1), "SAME", True), pool_and_inject="horizontal", fence_pad="same"
),
n_recurrent=n_recurrent,
mlp_hiddens=(256,),
repeats_per_step=num_repeats,
pool_and_inject=True,
add_one_to_forget=True,
),
loss=ImpalaLossConfig(
vtrace_lambda=0.97,
weight_l2_coef=1.5625e-07,
gamma=0.97,
logit_l2_coef=1.5625e-05,
),
actor_update_cutoff=100000000000000000000,
sync_frequency=100000000000000000000,
num_minibatches=8,
rmsprop_eps=1.5625e-07,
local_num_envs=256,
total_timesteps=80117760,
base_run_dir=Path("/training/cleanba"),
learning_rate=0.0004,
eval_frequency=978,
optimizer="adam",
base_fan_in=1,
anneal_lr=True,
max_grad_norm=0.015,
num_actor_threads=1,
)


Expand Down
Loading