Skip to content

Commit

Permalink
Dedent your yields!
Browse files Browse the repository at this point in the history
Fixes a surprising interaction between the generator system in linear_util.py
and the try/finally python context managers we use for managing tracing context.
The `finally` block wasn't always being called until garbage collection, so the
context stack pushes/pops weren't always correctly nested. Dedenting the yield
fixes this particular bug but long-term we should get rid of linear_util
altogether.

PiperOrigin-RevId: 695898528
  • Loading branch information
dougalm authored and Google-ML-Automation committed Nov 12, 2024
1 parent c32db46 commit d47e254
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def batch_subtrace(tag, axis_data, in_dims, *in_vals):
outs = yield in_tracers, {}
out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
segment_lens, out_dims = indirectify_ragged_axes(out_dims)
yield (*segment_lens, *out_vals), out_dims
yield (*segment_lens, *out_vals), out_dims

def indirectify_ragged_axes(dims):
if not any(type(d) is RaggedAxis for d in dims):
Expand Down Expand Up @@ -803,7 +803,7 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
out_axes, in_vals, out_vals)
yield out_vals, new_out_axes
yield out_vals, new_out_axes

@lu.transformation_with_aux
def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes,
Expand Down

0 comments on commit d47e254

Please sign in to comment.