
Merge pull request #14 from AlignmentResearch/aga/lstm-var
Tweak the ConvLSTM architecture until it works
rhaps0dy authored May 28, 2024
2 parents 8d00215 + e032380 commit 51ccefb
Showing 27 changed files with 2,192 additions and 375 deletions.
16 changes: 11 additions & 5 deletions cleanba/cleanba_impala.py
@@ -554,11 +554,16 @@ def train(args: Args, *, writer: Optional[WandbWriter] = None):
 
     policy, _, agent_params = args.net.init_params(envs, agent_params_subkey)
 
-    agent_state = TrainState.create(
-        apply_fn=None,
-        params=agent_params,
-        tx=make_optimizer(args, agent_params, total_updates=runtime_info.num_updates),
-    )
+    if args.load_path is None:
+        agent_state = TrainState.create(
+            apply_fn=None,
+            params=agent_params,
+            tx=make_optimizer(args, agent_params, total_updates=runtime_info.num_updates),
+        )
+    else:
+        old_args, agent_state = load_train_state(args.load_path)
+        print(f"Loaded TrainState from {args.load_path}. Here are the differences from `args` to the loaded args:")
+        farconf.config_diff(farconf.to_dict(args), farconf.to_dict(old_args))
 
     multi_device_update = jax.pmap(
         jax.jit(
@@ -716,6 +721,7 @@ def load_train_state(dir: Path) -> tuple[Args, TrainState]:
     with open(dir / "model", "rb") as f:
         train_state = flax.serialization.from_bytes(target_state, f.read())
     assert isinstance(train_state, TrainState)
+    train_state = unreplicate(train_state)
     return args, train_state


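The new `unreplicate` call matters because checkpoints are presumably written from the `jax.pmap`-replicated TrainState used during training, so every leaf carries a leading device axis on disk. A minimal sketch of that round-trip, under that assumption; `make_tx` and the toy params are illustrative stand-ins for cleanba's `make_optimizer` and the real agent params:

import jax.numpy as jnp
import optax
import flax.serialization
from flax.jax_utils import replicate, unreplicate
from flax.training.train_state import TrainState

def make_tx():
    return optax.adam(4e-4)  # stand-in for cleanba's make_optimizer

state = TrainState.create(apply_fn=None, params={"w": jnp.zeros(3)}, tx=make_tx())

# Saving from a replicated state gives every leaf a leading device axis.
saved_bytes = flax.serialization.to_bytes(replicate(state))

# from_bytes restores leaves with the shapes they were saved with, so the
# loaded TrainState still carries that device axis...
target = TrainState.create(apply_fn=None, params={"w": jnp.zeros(3)}, tx=make_tx())
restored = flax.serialization.from_bytes(target, saved_bytes)

# ...and the new line in the diff strips it off before training resumes.
restored = unreplicate(restored)
assert restored.params["w"].shape == (3,)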
36 changes: 29 additions & 7 deletions cleanba/config.py
@@ -1,9 +1,9 @@
 import dataclasses
 from dataclasses import field
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 
-from cleanba.convlstm import ConvConfig, ConvLSTMConfig
+from cleanba.convlstm import ConvConfig, ConvLSTMCellConfig, ConvLSTMConfig
 from cleanba.environments import AtariEnv, EnvConfig, EnvpoolBoxobanConfig, random_seed
 from cleanba.evaluate import EvalConfig
 from cleanba.impala_loss import (
@@ -65,6 +65,8 @@ class Args:
     distributed: bool = False  # whether to use `jax.distributed`
     concurrency: bool = True  # whether to run the actor and learner concurrently
 
+    load_path: Optional[Path] = None  # Where to load the initial training state from
+
 
 def sokoban_resnet() -> Args:
     CACHE_PATH = Path("/opt/sokoban_cache")
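For context, a sketch of how the new `load_path` field would be used to resume a run. The checkpoint path below is hypothetical; the directory must contain the "model" file that `load_train_state` in cleanba_impala.py reads:

import dataclasses
from pathlib import Path
from cleanba.config import sokoban_drc

args = sokoban_drc(n_recurrent=3, num_repeats=3)
# Hypothetical checkpoint directory from an earlier run.
args = dataclasses.replace(args, load_path=Path("/training/cleanba/run-001"))
# Passing these args to train() now restores the TrainState instead of
# initializing it from scratch.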
@@ -133,7 +135,7 @@ def sokoban_resnet() -> Args:
     )
 
 
-def sokoban_drc(num_layers: int, num_repeats: int) -> Args:
+def sokoban_drc(n_recurrent: int, num_repeats: int) -> Args:
     CACHE_PATH = Path("/opt/sokoban_cache")
     return Args(
         train_env=EnvpoolBoxobanConfig(
@@ -170,15 +172,35 @@ def sokoban_drc(num_layers: int, num_repeats: int) -> Args:
             ),
         ),
         log_frequency=10,
-        sync_frequency=int(4e9),
         net=ConvLSTMConfig(
             embed=[ConvConfig(32, (4, 4), (1, 1), "SAME", True)] * 2,
-            recurrent=[ConvConfig(32, (3, 3), (1, 1), "SAME", True)] * num_layers,
+            recurrent=ConvLSTMCellConfig(
+                ConvConfig(32, (3, 3), (1, 1), "SAME", True), pool_and_inject="horizontal", fence_pad="same"
+            ),
+            n_recurrent=n_recurrent,
             mlp_hiddens=(256,),
             repeats_per_step=num_repeats,
-            pool_and_inject=True,
-            add_one_to_forget=True,
         ),
+        loss=ImpalaLossConfig(
+            vtrace_lambda=0.97,
+            weight_l2_coef=1.5625e-07,
+            gamma=0.97,
+            logit_l2_coef=1.5625e-05,
+        ),
+        actor_update_cutoff=100000000000000000000,
+        sync_frequency=100000000000000000000,
+        num_minibatches=8,
+        rmsprop_eps=1.5625e-07,
+        local_num_envs=256,
+        total_timesteps=80117760,
+        base_run_dir=Path("/training/cleanba"),
+        learning_rate=0.0004,
+        eval_frequency=978,
+        optimizer="adam",
+        base_fan_in=1,
+        anneal_lr=True,
+        max_grad_norm=0.015,
+        num_actor_threads=1,
     )
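The pool_and_inject="horizontal" option refers to the pool-and-inject mechanism of the DRC architecture (Guez et al., 2019): spatial max- and mean-pooled summaries of the hidden state are broadcast back to every position, so each cell sees a global summary at every tick. A conceptual sketch under that reading, not the repository's implementation (`fence_pad` is repository-specific and not sketched here):

import jax.numpy as jnp

def pool_and_inject(h: jnp.ndarray) -> jnp.ndarray:
    """h: (H, W, C) hidden state -> (H, W, 3C) with pooled summaries injected."""
    pooled = jnp.concatenate([h.max(axis=(0, 1)), h.mean(axis=(0, 1))])  # (2C,)
    tiled = jnp.broadcast_to(pooled, h.shape[:2] + pooled.shape)         # (H, W, 2C)
    return jnp.concatenate([h, tiled], axis=-1)

The "horizontal" variant presumably feeds this into the recurrent (hidden-to-hidden) path of the cell rather than the bottom-up input.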


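The new `ImpalaLossConfig` values parameterize the V-trace return (Espeholt et al., 2018) that the IMPALA loss bootstraps from: `gamma` is the discount and `vtrace_lambda` scales the truncated importance weights c_t. A rough sketch of the target computation under those definitions, with the clipping thresholds rho_bar and c_bar at the paper's default of 1 and episode boundaries ignored for brevity; this is not cleanba's actual implementation:

import jax.numpy as jnp

def vtrace_targets(v, r, rho, gamma=0.97, lam=0.97, rho_bar=1.0, c_bar=1.0):
    """v: values V(x_0..x_T), shape (T+1,); r: rewards, shape (T,);
    rho: per-step importance ratios pi/mu, shape (T,)."""
    clipped_rho = jnp.minimum(rho_bar, rho)
    cs = lam * jnp.minimum(c_bar, rho)
    deltas = clipped_rho * (r + gamma * v[1:] - v[:-1])  # clipped TD errors
    acc, out = 0.0, []
    for t in reversed(range(r.shape[0])):
        acc = deltas[t] + gamma * cs[t] * acc  # v_t - V(x_t), backward recursion
        out.append(v[t] + acc)
    return jnp.stack(out[::-1])

targets = vtrace_targets(v=jnp.zeros(6), r=jnp.ones(5), rho=jnp.ones(5))  # shape (5,)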
(Diffs for the remaining 25 changed files are not shown here.)
