diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 265d36d62b50..6e7402c20a15 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -417,10 +417,11 @@ def _lower_mosaic_module_to_asm( needs_layout_passes = not device_type # We'll mutate the module, so clone it with module.context as ctx, module.operation.location as _: - module = ir.Module.parse( - module.operation.get_asm(binary=True, enable_debug_info=True) - ) if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: + module = ir.Module.parse( + module.operation.get_asm(binary=True, enable_debug_info=True) + ) + module_op = module.operation some_tpu = jax.devices(backend)[0] device_kind = some_tpu.device_kind if not device_kind.startswith("TPU v"): @@ -435,15 +436,17 @@ def _lower_mosaic_module_to_asm( ) needs_hlo_passes = False needs_layout_passes = False + else: + module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True try: pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})") - pipeline.run(module.operation) + pipeline.run(module_op) finally: ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects bytecode_buffer = io.BytesIO() - module.operation.write_bytecode(bytecode_buffer, desired_version=0) + module_op.write_bytecode(bytecode_buffer, desired_version=0) asm = bytecode_buffer.getvalue() return asm, ( has_communication,