Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634095268
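
For reference, a minimal before/after sketch of the migration (the pytree and the doubling function are made up for illustration; the `jax.tree.*` aliases require JAX 0.4.25 or newer):

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros((3,))}  # example pytree

# Deprecated top-level alias that this change removes from the code base:
# doubled = jax.tree_map(lambda x: x * 2, params)

# Replacements: the long-form API in jax.tree_util, or the shorter alias
# added in JAX 0.4.25.
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
doubled = jax.tree.map(lambda x: x * 2, params)
```

The two replacement calls are equivalent; `jax.tree.map` is simply the shorter spelling used throughout this commit.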
Jake VanderPlas authored and pax authors committed May 15, 2024
1 parent d903d68 commit 3f4cbb4
Showing 15 changed files with 191 additions and 141 deletions.
2 changes: 1 addition & 1 deletion praxis/layers/adapters_test.py
@@ -87,7 +87,7 @@ def test_residual_adapter_tf_equivalent(self):
theta = initial_vars['params'].copy()
theta['layer_norm'] = theta['norm']
del theta['norm']
-theta = jax.tree_map(np.array, theta)
+theta = jax.tree.map(np.array, theta)
theta = py_utils.NestedMap.FromNestedDict(theta)
theta.down_w = tf.convert_to_tensor(theta.down_w)
theta.up_w = tf.convert_to_tensor(theta.up_w)
4 changes: 2 additions & 2 deletions praxis/layers/attentions.py
@@ -2247,12 +2247,12 @@ def _vmap_on_broadcast_prefixes(

# Wraps fn with slicing on args_to_slice and broadcast_args_to_slice.
def _sliced_fn(layer, args, args_to_slice, broadcast_args_to_slice, states):
-sliced = jax.tree_map(
+sliced = jax.tree.map(
lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
args_to_slice,
args_time_dims,
)
-broadcast_sliced = jax.tree_map(
+broadcast_sliced = jax.tree.map(
lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
broadcast_args_to_slice,
broadcast_args_time_dims,
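An aside on the hunk above: `jax.tree.map` also accepts several pytrees with matching structure and calls the function on corresponding leaves, which is what the `(x, d)` lambda over `args_to_slice` and `args_time_dims` relies on. A small sketch with made-up dicts standing in for those arguments:

```python
import jax

args = {'q': 1.0, 'k': 2.0}     # stand-in for args_to_slice
time_dims = {'q': 10, 'k': 20}  # stand-in for args_time_dims

# Corresponding leaves of the two trees are passed to the lambda together.
sliced = jax.tree.map(lambda x, d: (x, d), args, time_dims)
# {'q': (1.0, 10), 'k': (2.0, 20)}
```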
4 changes: 2 additions & 2 deletions praxis/layers/convolutions_test.py
@@ -141,7 +141,7 @@ def test_causal_conv2d_layer(self):

prng_key = jax.random.PRNGKey(seed=123)
initial_vars = conv_layer.init(prng_key, inputs)
-initial_vars = jax.tree_map(jnp.ones_like, initial_vars)
+initial_vars = jax.tree.map(jnp.ones_like, initial_vars)

# Test odd length sequence.
output = conv_layer.apply(initial_vars, inputs)
@@ -156,7 +156,7 @@ def test_causal_conv2d_layer(self):

prng_key = jax.random.PRNGKey(seed=123)
initial_vars = conv_layer.init(prng_key, inputs)
-initial_vars = jax.tree_map(jnp.ones_like, initial_vars)
+initial_vars = jax.tree.map(jnp.ones_like, initial_vars)

output = conv_layer.apply(initial_vars, inputs)
np_output = np.array(
6 changes: 3 additions & 3 deletions praxis/layers/flax_adapter_test.py
@@ -149,14 +149,14 @@ def test_mix_layer(self):
def assert_learnable(x):
assert not x.collections

-jax.tree_map(assert_learnable, init_var_meta['params'])
+jax.tree.map(assert_learnable, init_var_meta['params'])

def assert_non_learnable(x):
assert WeightHParamsCollection.NON_TRAINABLE in x.collections
assert WeightHParamsCollection.REQUIRES_MEAN_SYNC in x.collections

-jax.tree_map(assert_non_learnable, init_var_meta['batch_stats'])
-jax.tree_map(assert_non_learnable, init_var_meta['non_trainable'])
+jax.tree.map(assert_non_learnable, init_var_meta['batch_stats'])
+jax.tree.map(assert_non_learnable, init_var_meta['non_trainable'])
init_vars = test_layer.init(prng_key, input_x)
_ = test_layer.apply(init_vars, input_x, mutable=True)
_ = test_layer.apply(
20 changes: 10 additions & 10 deletions praxis/layers/frnn.py
@@ -78,7 +78,7 @@ def reset_mask(


def _sum_aux_loss(tree):
-return jax.tree_map(jnp.sum, tree)
+return jax.tree.map(jnp.sum, tree)


class FRnn(base_layer.BaseLayer):
@@ -143,7 +143,7 @@ def __call__(
state: Final state.
"""
# Make a copy of the input structure to avoid side-effect.
-inputs = jax.tree_map(lambda x: x, inputs)
+inputs = jax.tree.map(lambda x: x, inputs)
assert hasattr(inputs, 'act')
assert hasattr(inputs, 'padding')
assert isinstance(self.cell, rnn_cell.BaseRnnCell)
@@ -159,7 +159,7 @@ def __call__(
inputs.reset_mask = jnp.ones_like(inputs.padding, dtype=self.fprop_dtype)

if self.reverse:
-inputs = jax.tree_map(lambda x: jnp.flip(x, axis=[1]), inputs)
+inputs = jax.tree.map(lambda x: jnp.flip(x, axis=[1]), inputs)

if not state0:
batch_size = inputs.padding.shape[0]
@@ -176,7 +176,7 @@ def body_fn(sub, state0, inputs):
if self.is_initializing():
# inputs has shape [b, t, dim] or [b, t, 1]
# sliced_inputs has shape [b, dim] or [b, 1].
-sliced_inputs = jax.tree_map(lambda x: x[:, 1], inputs)
+sliced_inputs = jax.tree.map(lambda x: x[:, 1], inputs)
_ = body_fn(self.cell, state0, sliced_inputs)

# NON_TRAINABLE variables are carried over from one iteration to another.
@@ -248,7 +248,7 @@ def init_states(self, batch_size: int) -> list[NestedMap]:
def extend_step(
self, inputs: NestedMap, state: list[NestedMap]
) -> tuple[list[NestedMap], JTensor]:
-inputs = jax.tree_map(lambda x: x, inputs)
+inputs = jax.tree.map(lambda x: x, inputs)
new_states = []
for i in range(self.num_layers):
new_state, act_i = self.frnn[i].extend_step(inputs, state[i])
@@ -275,7 +275,7 @@ def __call__(
act: A tensor of [batch, time, dims]. The output.
state: Final state.
"""
-inputs = jax.tree_map(lambda x: x, inputs)
+inputs = jax.tree.map(lambda x: x, inputs)

if not state0:
batch_size = inputs.padding.shape[0]
@@ -375,7 +375,7 @@ def __call__(
state: Final state - a list of NestedMap of fwd and bwd states.
"""
# This is to create a copy.
-inputs = jax.tree_map(lambda x: x, inputs)
+inputs = jax.tree.map(lambda x: x, inputs)

if not state0:
batch_size = inputs.padding.shape[0]
@@ -428,7 +428,7 @@ def __call__(
state: Final state.
"""
# Make a copy of the input structure to avoid side-effect.
-inputs = jax.tree_map(lambda x: x, inputs)
+inputs = jax.tree.map(lambda x: x, inputs)
assert hasattr(inputs, 'act')
assert hasattr(inputs, 'padding')
assert isinstance(self.cell, rnn_cell.BaseRnnCell)
@@ -444,7 +444,7 @@ def __call__(
inputs.reset_mask = jnp.ones_like(inputs.padding, dtype=self.fprop_dtype)

if self.reverse:
-inputs = jax.tree_map(lambda x: jnp.flip(x, axis=[1]), inputs)
+inputs = jax.tree.map(lambda x: jnp.flip(x, axis=[1]), inputs)

if not state0:
batch_size = inputs.padding.shape[0]
@@ -466,7 +466,7 @@ def body_fn(sub, state0, inputs):
if self.is_initializing():
# inputs has shape [b, t, dim] or [b, t, 1]
# sliced_inputs has shape [b, dim] or [b, 1].
-sliced_inputs = jax.tree_map(lambda x: x[:, 1], inputs)
+sliced_inputs = jax.tree.map(lambda x: x[:, 1], inputs)
# `body_fn` is sufficient to trigger PARAMS initialization.
_ = body_fn(self.cell, state0, sliced_inputs)

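A note on the `jax.tree.map(lambda x: x, inputs)` idiom that recurs in frnn.py above: mapping the identity function rebuilds the pytree containers (here a NestedMap) around the same leaves, giving a structural copy that can be mutated without touching the caller's object, which is what the "Make a copy of the input structure to avoid side-effect" comment refers to. A minimal sketch with a plain dict standing in for the NestedMap:

```python
import jax

inputs = {'act': [1.0, 2.0], 'padding': [0.0, 1.0]}

# Identity map: new containers, same leaf values.
copy = jax.tree.map(lambda x: x, inputs)
copy['reset_mask'] = [1.0, 1.0]

assert 'reset_mask' not in inputs  # the original structure is unchanged
```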
4 changes: 2 additions & 2 deletions praxis/layers/frnn_test.py
@@ -206,7 +206,7 @@ def test_frnn_lstm_cell(self, jax_cell_class, output_nonlinearity):

rnn_theta = {'params': theta['params']['cell']}
ys = []
-cell_state = jax.tree_map(lambda x: x, state0)
+cell_state = jax.tree.map(lambda x: x, state0)
for t in range(act_in.shape[1]):
with base_layer.JaxContext.new_context():
inputs_t = NestedMap(act=act_in[:, t], padding=padding[:, t])
@@ -414,7 +414,7 @@ def test_frnn_reset_cell_state(

rnn_theta = {'params': theta['params']['cell']}
ys = []
-cell_state = jax.tree_map(lambda x: x, state0)
+cell_state = jax.tree.map(lambda x: x, state0)
for t in range(act_in.shape[1]):
with base_layer.JaxContext.new_context():
inputs_t = NestedMap(
14 changes: 9 additions & 5 deletions praxis/layers/multi_query_attention.py
@@ -1319,12 +1319,16 @@ def _vmap_on_broadcast_prefixes(self, fn: attentions.FnOnDecodeStateChunk,

# Wraps fn with slicing on args_to_slice and broadcast_args_to_slice.
def _sliced_fn(layer, args, args_to_slice, broadcast_args_to_slice, states):
-sliced = jax.tree_map(
-lambda x, d: self._slice_decode_chunk(x, chunk_id, d), args_to_slice,
-args_time_dims)
-broadcast_sliced = jax.tree_map(
+sliced = jax.tree.map(
lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
-broadcast_args_to_slice, broadcast_args_time_dims)
+args_to_slice,
+args_time_dims,
+)
+broadcast_sliced = jax.tree.map(
+lambda x, d: self._slice_decode_chunk(x, chunk_id, d),
+broadcast_args_to_slice,
+broadcast_args_time_dims,
+)
return fn(layer, args, sliced, broadcast_sliced, states)

broadcast_dim_sizes = self.get_decode_state(
(Diffs for the remaining 8 changed files are not shown.)
