Skip to content

Commit

Permalink
[Pallas] Fix lowering tests for reduction ops
Browse files Browse the repository at this point in the history
Remove unnecessary skip statements. Also added tests for bf16 types.

PiperOrigin-RevId: 694130207
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Nov 7, 2024
1 parent de06584 commit 1a544b6
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,35 +1793,37 @@ def reduce(x_ref, y_ref):
for axis in [0, 1, (1,), (0, 1)]
for dtype in [
"float16",
"bfloat16",
"float32",
"float64",
"int32",
"int64",
"uint32",
"uint64",
]
if isinstance(axis, int) or "arg" not in op_name
])
def test_array_reduce(self, op, dtype, axis):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")
if not isinstance(axis, int):
self.skipTest("TODO: tuple axes are not yet supported")

if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

if jtu.test_device_matches(["tpu"]):
self.skipTest("Unimplemented primitive: broadcast_to")

if jtu.test_device_matches(["tpu"]) and dtype == "float16":
self.skipTest("float16 is not supported on TPU")

# Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
# `index_type` to be i32
if (
jax.config.x64_enabled
and jtu.test_device_matches(["gpu"])
and op in {jnp.argmin, jnp.argmax}
and op in (jnp.argmin, jnp.argmax)
):
self.skipTest("Not supported on GPU in 64-bit mode")

# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

m, n = 32, 8

def make_x(key):
Expand Down

0 comments on commit 1a544b6

Please sign in to comment.