Skip to content

Commit

Permalink
Support None in PmapSharding as a replacement for device_put_replicated.
Browse files Browse the repository at this point in the history
eg:
`jax.device_put(x, PmapSharding.default(x.shape, None, jax.local_devices()))`
PiperOrigin-RevId: 689956669
  • Loading branch information
pschuh authored and Google-ML-Automation committed Oct 26, 2024
1 parent 47bacfa commit 6b06557
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
14 changes: 8 additions & 6 deletions jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def is_equivalent_to(self: PmapSharding, other: PmapSharding, # type: ignore

# TODO(yashkatariya): Expose `sharded_dim_size` in the API if required.
@classmethod
def default(cls, shape: Shape, sharded_dim: int = 0,
def default(cls, shape: Shape, sharded_dim: int | None = 0,
devices: Sequence[xc.Device] | None = None) -> PmapSharding:
"""Creates a :class:`PmapSharding` which matches the default placement
used by :func:`jax.pmap`.
Expand All @@ -547,6 +547,13 @@ def default(cls, shape: Shape, sharded_dim: int = 0,
device order used by pmap is used, which is the order of
:func:`jax.local_devices`.
"""
if sharded_dim is None:
if devices is None:
raise ValueError("One of sharded_dim or devices must be set.")
nrep = len(devices)
return cls(np.array(devices),
sharding_specs.pmap_sharding_spec(nrep, nrep, shape, None))

# The dtype doesn't matter here. It's only used for creating the
# sharding_spec.
sharding_spec = sharding_specs.create_pmap_sharding_spec(
Expand All @@ -565,11 +572,6 @@ def default(cls, shape: Shape, sharded_dim: int = 0,
raise NotImplementedError(
'Multiple chunks in Chunked dimension not supported.')

if num_ways_sharded is None:
raise NotImplementedError(
'`None` to sharded_dim is not supported. Please file a jax '
'issue if you need this feature.')

if devices is None:
pmap_devices: np.ndarray = np.array(
xla_bridge.local_devices()[:num_ways_sharded])
Expand Down
8 changes: 8 additions & 0 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,14 @@ def test_default_pmap_sharding_with_devices(self):
ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order)
self.assertEqual(ps._device_assignment, new_order)

def test_default_pmap_sharding_replicated(self):
  """`PmapSharding.default(sharded_dim=None)` matches pmap's replicated output."""
  devices = jax.local_devices()
  inp = np.zeros((len(devices), 8), dtype=np.float32)
  # out_axes=None makes pmap produce a fully replicated output array.
  replicated = jax.pmap(lambda a: a, in_axes=0, out_axes=None)(inp)
  expected = jax.sharding.PmapSharding.default(
      shape=(8,), sharded_dim=None, devices=devices)
  self.assertEqual(replicated.sharding, expected)

def test_mesh_repr(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
mesh_repr = repr(mesh)
Expand Down

0 comments on commit 6b06557

Please sign in to comment.