diff --git a/learned_optimization/circular_buffer.py b/learned_optimization/circular_buffer.py index 4710f5a..371f686 100644 --- a/learned_optimization/circular_buffer.py +++ b/learned_optimization/circular_buffer.py @@ -23,6 +23,7 @@ from typing import Generic, Tuple, TypeVar import jax +from jax import tree_util import jax.numpy as jnp CircularBufferState = collections.namedtuple("CircularBufferState", @@ -54,7 +55,7 @@ def build_one(x): tiled = jnp.tile(expanded, [self.size] + [1] * len(x.shape)) return jnp.asarray(tiled, dtype=x.dtype) - empty_buffer = jax.tree_map(build_one, self.abstract_value) + empty_buffer = tree_util.tree_map(build_one, self.abstract_value) return CircularBufferState( idx=jnp.asarray(0, jnp.int64), values=(empty_buffer, @@ -71,7 +72,8 @@ def do_update(src, to_set): else: return src.at[idx, :].set(to_set) - new_jax_array = jax.tree_map(do_update, state.values, (value, state.idx)) + new_jax_array = tree_util.tree_map(do_update, state.values, + (value, state.idx)) return CircularBufferState(idx=state.idx + 1, values=new_jax_array) def _reorder(self, vals, idx): @@ -100,8 +102,8 @@ def stack_reorder(self, state: CircularBufferState) -> Tuple[T, jnp.ndarray]: # candidate = jnp.clip((state.values[1] - state.idx + self.size), -1, self.size) mask = self._reorder(jnp.where(candidate == -1, 0, 1), state.idx) - return jax.tree_map(lambda x: self._reorder(x, state.idx), - state.values[0]), mask + return tree_util.tree_map(lambda x: self._reorder(x, state.idx), + state.values[0]), mask @functools.partial(jax.jit, static_argnums=(0,)) def gather_from_present( @@ -109,4 +111,4 @@ def gather_from_present( """Get the values from for each idx in the past.""" offset = (idxs % self.size) idx = (state.idx + offset) % self.size - return jax.tree_map(lambda x: x[idx], state.values[0]) + return tree_util.tree_map(lambda x: x[idx], state.values[0])