Skip to content

Commit

Permalink
[mosaic] Use .clone() to duplicate a module, rather than printing and…
Browse files Browse the repository at this point in the history
… parsing it.

PiperOrigin-RevId: 689708462
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Oct 25, 2024
1 parent 9088add commit bb5fbec
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,11 @@ def _lower_mosaic_module_to_asm(
needs_layout_passes = not device_type
# We'll mutate the module, so clone it
with module.context as ctx, module.operation.location as _:
module = ir.Module.parse(
module.operation.get_asm(binary=True, enable_debug_info=True)
)
if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
module = ir.Module.parse(
module.operation.get_asm(binary=True, enable_debug_info=True)
)
module_op = module.operation
some_tpu = jax.devices(backend)[0]
device_kind = some_tpu.device_kind
if not device_kind.startswith("TPU v"):
Expand All @@ -435,15 +436,17 @@ def _lower_mosaic_module_to_asm(
)
needs_hlo_passes = False
needs_layout_passes = False
else:
module_op = module.operation.clone()
prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
ctx.allow_unregistered_dialects = True
try:
pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})")
pipeline.run(module.operation)
pipeline.run(module_op)
finally:
ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects
bytecode_buffer = io.BytesIO()
module.operation.write_bytecode(bytecode_buffer, desired_version=0)
module_op.write_bytecode(bytecode_buffer, desired_version=0)
asm = bytecode_buffer.getvalue()
return asm, (
has_communication,
Expand Down

0 comments on commit bb5fbec

Please sign in to comment.