diff --git a/learned_optimization/learned_optimizers/nn_adam.py b/learned_optimization/learned_optimizers/nn_adam.py
index d374d08..aeb0eaa 100644
--- a/learned_optimization/learned_optimizers/nn_adam.py
+++ b/learned_optimization/learned_optimizers/nn_adam.py
@@ -269,20 +269,24 @@ def init(
         lambda x: jnp.tile(x, [n_states] + [1] * len(x.shape[1:])),
         theta["lstm_init_state"])

-    return NNAdamState(
+    return NNAdamState(  # pytype: disable=wrong-arg-types  # jnp-type
         params=params,
         rolling_features=rolling.init(params),
         iteration=jnp.asarray(0, dtype=jnp.int32),
         state=model_state,
         lstm_hidden_state=lstm_hidden_state,
-        per_layer_lr=jax.tree_util.tree_map(lambda x: theta["per_layer_lr"],
-                                            params),
+        per_layer_lr=jax.tree_util.tree_map(
+            lambda x: theta["per_layer_lr"], params
+        ),
         per_layer_beta1=jax.tree_util.tree_map(
-            lambda x: theta["per_layer_beta1"], params),
+            lambda x: theta["per_layer_beta1"], params
+        ),
         per_layer_beta2=jax.tree_util.tree_map(
-            lambda x: theta["per_layer_beta2"], params),
+            lambda x: theta["per_layer_beta2"], params
+        ),
         per_layer_epsilon=jax.tree_util.tree_map(
-            lambda x: theta["per_layer_epsilon"], params),
+            lambda x: theta["per_layer_epsilon"], params
+        ),
     )

   def lstm_features_for_tensor(self, p: jnp.ndarray, g: jnp.ndarray,
diff --git a/learned_optimization/optimizers/optax_opts.py b/learned_optimization/optimizers/optax_opts.py
index e8c1edd..1f90ae6 100644
--- a/learned_optimization/optimizers/optax_opts.py
+++ b/learned_optimization/optimizers/optax_opts.py
@@ -384,11 +384,13 @@ def __init__(self,

   # SM3 doesn't support scalars, so we have to reshape the params and grads.
-  def init(self,
-           params: Any,
-           model_state: Optional[Any] = None,
-           num_steps: Optional[int] = None,
-           key: chex.PRNGKey = None) -> SM3OptState:
+  def init(  # type: ignore  # jnp-type
+      self,
+      params: Any,
+      model_state: Optional[Any] = None,
+      num_steps: Optional[int] = None,
+      key: chex.PRNGKey = None,
+  ) -> SM3OptState:
     should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)  # pylint: disable=g-explicit-length-test
     params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
     out = super().init(params, model_state, num_steps, key)
diff --git a/learned_optimization/outer_train.py b/learned_optimization/outer_train.py
index 119f402..d28586d 100644
--- a/learned_optimization/outer_train.py
+++ b/learned_optimization/outer_train.py
@@ -235,7 +235,7 @@ def metrics_and_info_from_gradients(
     max_stale = current_step - onp.min(steps)
     metrics["max_staleness"] = max_stale

-  return metrics, worker_ids, applied_inner_steps
+  return metrics, worker_ids, applied_inner_steps  # pytype: disable=bad-return-type  # jnp-type


 def maybe_resample_gradient_estimators(
diff --git a/learned_optimization/outer_trainers/full_es.py b/learned_optimization/outer_trainers/full_es.py
index bd0b64d..2a0df60 100644
--- a/learned_optimization/outer_trainers/full_es.py
+++ b/learned_optimization/outer_trainers/full_es.py
@@ -308,7 +308,7 @@ def single_vec_batch(theta, state, key_data):

   es_grad = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), vec_es_grad)

-  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad
+  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad  # pytype: disable=bad-return-type  # jnp-type


 @gin.configurable
diff --git a/learned_optimization/outer_trainers/truncated_es.py b/learned_optimization/outer_trainers/truncated_es.py
index 8154472..e00804c 100644
--- a/learned_optimization/outer_trainers/truncated_es.py
+++ b/learned_optimization/outer_trainers/truncated_es.py
@@ -85,7 +85,7 @@ def flat_first(x):

   es_grad = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), vec_es_grad)

-  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad, p_ys, delta_loss
+  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad, p_ys, delta_loss  # pytype: disable=bad-return-type  # jnp-type


 @gin.configurable
diff --git a/learned_optimization/outer_trainers/truncated_pes.py b/learned_optimization/outer_trainers/truncated_pes.py
index 1572ce9..da94a1e 100644
--- a/learned_optimization/outer_trainers/truncated_pes.py
+++ b/learned_optimization/outer_trainers/truncated_pes.py
@@ -130,8 +130,15 @@ def _switch_one_accum(a, b):
   pos_loss = jnp.sum(p_ys.loss * p_ys.mask, axis=0) / jnp.sum(p_ys.mask, axis=0)
   neg_loss = jnp.sum(n_ys.loss * n_ys.mask, axis=0) / jnp.sum(n_ys.mask, axis=0)

-  return jnp.mean(
-      (pos_loss + neg_loss) / 2.0), es_grad, new_accumulator, p_ys, delta_losses
+  return (
+      jnp.mean(  # pytype: disable=bad-return-type  # jnp-type
+          (pos_loss + neg_loss) / 2.0
+      ),
+      es_grad,
+      new_accumulator,
+      p_ys,
+      delta_losses,
+  )


 @gin.configurable