Skip to content

Commit

Permalink
[Pallas:TPU] Use arith.divui for uint32 div.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691939453
  • Loading branch information
WindQAQ authored and Google-ML-Automation committed Oct 31, 2024
1 parent 48f24b6 commit 7af7a60
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 17 deletions.
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,7 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y):
def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
(aval_out,) = ctx.avals_out
if jnp.issubdtype(aval_out.dtype, jnp.integer):
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
return arith.divsi(x, y)
if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
return arith.divui(x, y)
Expand Down
16 changes: 0 additions & 16 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,14 +1088,6 @@ def test_binary(self, f, dtype):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

# TODO(ayx): Fix these operations on TPU
if (
jtu.test_device_matches(["tpu"])
and f in (jnp.floor_divide, jnp.subtract)
and dtype == "uint32"
):
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1
)
Expand All @@ -1121,14 +1113,6 @@ def test_binary_scalar(self, f, dtype):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

# TODO(ayx): Fix these operations on TPU
if (
jtu.test_device_matches(["tpu"])
and f in (jnp.floor_divide, jnp.subtract)
and dtype == "uint32"
):
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call,
in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
Expand Down

0 comments on commit 7af7a60

Please sign in to comment.