
Commit 60c128d

Ignore incorrect type annotations related to jax dtypes
PiperOrigin-RevId: 571144007
Jake VanderPlas authored and learned_optimization authors committed Oct 5, 2023
1 parent 463ab9a commit 60c128d
Showing 6 changed files with 29 additions and 16 deletions.
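All six files get the same kind of fix: values produced by jax.numpy (for example jnp.asarray(0, dtype=jnp.int32) or the result of jnp.mean) are jax Arrays at runtime, while the surrounding annotations declare plain Python types, so pytype reports errors such as wrong-arg-types or bad-return-type. Rather than changing the annotations, each site is silenced with a suppression comment tagged # jnp-type. A minimal, hypothetical illustration of the bad-return-type case (not code from this repository):

import jax.numpy as jnp

# Hypothetical example: the annotation promises a Python int, but jnp.asarray
# returns a jax Array, so a type checker can flag bad-return-type even though
# the code behaves as intended at runtime.
def initial_iteration() -> int:
  return jnp.asarray(0, dtype=jnp.int32)  # pytype: disable=bad-return-type  # jnp-type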
16 changes: 10 additions & 6 deletions learned_optimization/learned_optimizers/nn_adam.py
@@ -269,20 +269,24 @@ def init(
         lambda x: jnp.tile(x, [n_states] + [1] * len(x.shape[1:])),
         theta["lstm_init_state"])

-    return NNAdamState(
+    return NNAdamState(  # pytype: disable=wrong-arg-types  # jnp-type
         params=params,
         rolling_features=rolling.init(params),
         iteration=jnp.asarray(0, dtype=jnp.int32),
         state=model_state,
         lstm_hidden_state=lstm_hidden_state,
-        per_layer_lr=jax.tree_util.tree_map(lambda x: theta["per_layer_lr"],
-                                             params),
+        per_layer_lr=jax.tree_util.tree_map(
+            lambda x: theta["per_layer_lr"], params
+        ),
         per_layer_beta1=jax.tree_util.tree_map(
-            lambda x: theta["per_layer_beta1"], params),
+            lambda x: theta["per_layer_beta1"], params
+        ),
         per_layer_beta2=jax.tree_util.tree_map(
-            lambda x: theta["per_layer_beta2"], params),
+            lambda x: theta["per_layer_beta2"], params
+        ),
         per_layer_epsilon=jax.tree_util.tree_map(
-            lambda x: theta["per_layer_epsilon"], params),
+            lambda x: theta["per_layer_epsilon"], params
+        ),
     )

   def lstm_features_for_tensor(self, p: jnp.ndarray, g: jnp.ndarray,
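The wrong-arg-types suppression above covers the NNAdamState constructor call: fields such as iteration are filled with jnp values, which pytype may not accept for the declared field types (the field annotations themselves are not visible in this diff). A hypothetical sketch of that kind of mismatch:

from typing import NamedTuple

import jax.numpy as jnp


# Hypothetical container standing in for NNAdamState; the real class and its
# field annotations are not part of this diff.
class State(NamedTuple):
  iteration: int


# Passing a jnp scalar where an int is declared can trigger wrong-arg-types.
state = State(iteration=jnp.asarray(0, dtype=jnp.int32))  # pytype: disable=wrong-arg-types  # jnp-type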
12 changes: 7 additions & 5 deletions learned_optimization/optimizers/optax_opts.py
@@ -384,11 +384,13 @@ def __init__(self,

   # SM3 doesn't support scalars, so we have to reshape the params and grads.

-  def init(self,
-           params: Any,
-           model_state: Optional[Any] = None,
-           num_steps: Optional[int] = None,
-           key: chex.PRNGKey = None) -> SM3OptState:
+  def init(  # type: ignore # jnp-type
+      self,
+      params: Any,
+      model_state: Optional[Any] = None,
+      num_steps: Optional[int] = None,
+      key: chex.PRNGKey = None,
+  ) -> SM3OptState:
     should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)  # pylint: disable=g-explicit-length-test
     params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
     out = super().init(params, model_state, num_steps, key)
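The unchanged context lines above show the scalar workaround described by the comment: rank-0 leaves are detected with tree_map and given a leading axis before SM3 sees them. A rough standalone sketch of that pattern, with expand_scalar as a stand-in for the _expand_scalar helper referenced in the diff:

import jax
import jax.numpy as jnp


def expand_scalar(x, reshape):
  # Stand-in for _expand_scalar: give rank-0 arrays a single leading axis.
  return jnp.reshape(x, (1,)) if reshape else x


params = {"w": jnp.ones((3, 2)), "b": jnp.asarray(0.5)}
should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)
params = jax.tree_util.tree_map(expand_scalar, params, should_reshape)
# params["b"] now has shape (1,); params["w"] is unchanged.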
2 changes: 1 addition & 1 deletion learned_optimization/outer_train.py
@@ -235,7 +235,7 @@ def metrics_and_info_from_gradients(
   max_stale = current_step - onp.min(steps)
   metrics["max_staleness"] = max_stale

-  return metrics, worker_ids, applied_inner_steps
+  return metrics, worker_ids, applied_inner_steps  # pytype: disable=bad-return-type  # jnp-type


 def maybe_resample_gradient_estimators(
2 changes: 1 addition & 1 deletion learned_optimization/outer_trainers/full_es.py
@@ -308,7 +308,7 @@ def single_vec_batch(theta, state, key_data):

   es_grad = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), vec_es_grad)

-  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad
+  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad  # pytype: disable=bad-return-type  # jnp-type


 @gin.configurable
2 changes: 1 addition & 1 deletion learned_optimization/outer_trainers/truncated_es.py
@@ -85,7 +85,7 @@ def flat_first(x):

   es_grad = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), vec_es_grad)

-  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad, p_ys, delta_loss
+  return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad, p_ys, delta_loss  # pytype: disable=bad-return-type  # jnp-type


 @gin.configurable
11 changes: 9 additions & 2 deletions learned_optimization/outer_trainers/truncated_pes.py
@@ -130,8 +130,15 @@ def _switch_one_accum(a, b):
   pos_loss = jnp.sum(p_ys.loss * p_ys.mask, axis=0) / jnp.sum(p_ys.mask, axis=0)
   neg_loss = jnp.sum(n_ys.loss * n_ys.mask, axis=0) / jnp.sum(n_ys.mask, axis=0)

-  return jnp.mean(
-      (pos_loss + neg_loss) / 2.0), es_grad, new_accumulator, p_ys, delta_losses
+  return (
+      jnp.mean(  # pytype: disable=bad-return-type  # jnp-type
+          (pos_loss + neg_loss) / 2.0
+      ),
+      es_grad,
+      new_accumulator,
+      p_ys,
+      delta_losses,
+  )


 @gin.configurable
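The pos_loss and neg_loss lines above take a per-task mean over only the unroll steps marked valid by the mask, and the antithetic pair is then averaged. A small self-contained sketch of that computation with made-up shapes (the real loss and mask arrays come from the truncated unrolls); the final jnp.mean yields a jax scalar Array, which is presumably why the return needs the bad-return-type suppression:

import jax.numpy as jnp

# Toy data: loss has shape [unroll_steps, num_tasks]; mask marks valid steps.
loss_pos = jnp.arange(6.0).reshape(3, 2)
loss_neg = loss_pos + 1.0
mask = jnp.array([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0]])

# Masked mean over the unroll axis, one value per task.
pos_loss = jnp.sum(loss_pos * mask, axis=0) / jnp.sum(mask, axis=0)
neg_loss = jnp.sum(loss_neg * mask, axis=0) / jnp.sum(mask, axis=0)

# Average the antithetic pair, then reduce over tasks to a single scalar.
mean_loss = jnp.mean((pos_loss + neg_loss) / 2.0)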
