diff --git a/learned_optimization/jax_utils.py b/learned_optimization/jax_utils.py index 301d090..11a0ca2 100644 --- a/learned_optimization/jax_utils.py +++ b/learned_optimization/jax_utils.py @@ -46,9 +46,12 @@ def body_fn(_, operand): def in_jit() -> bool: """Returns true if tracing jit.""" - return "DynamicJaxprTrace" in str( - jax.core.thread_local_state.trace_state.trace_stack - ) + if jax.__version_info__ <= (0, 4, 33): + return "DynamicJaxprTrace" in str( + jax.core.thread_local_state.trace_state.trace_stack + ) + + return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE() Carry = TypeVar("Carry")