Skip to content

Commit

Permalink
Removed mesh_utils._bounds_from_last_device which was only used in …
Browse files Browse the repository at this point in the history
…tests

PiperOrigin-RevId: 691342846
  • Loading branch information
superbobry authored and Google-ML-Automation committed Oct 30, 2024
1 parent bdf2ca1 commit f1c3109
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 35 deletions.
10 changes: 0 additions & 10 deletions jax/_src/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,16 +572,6 @@ def _generate_logical_mesh(
return logical_mesh


def _bounds_from_last_device(last_device) -> Sequence[int]:
"""Gets the bound from the given last device."""
# Must be passed the device at the highest-coordinate corner of the
# relevant mesh, which is a requirement we know is satisfied by the last
# device in jax.devices().
assert hasattr(last_device, 'coords'), 'Only TPU supported'
x, y, z = last_device.coords
return x + 1, y + 1, z + 1, last_device.core_on_chip + 1


def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray:
r"""Rearrange TPU devices in a slice into a physical mesh.
Expand Down
25 changes: 0 additions & 25 deletions tests/mesh_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,31 +205,6 @@ def mock_2x2x2_v5e_devices(one_device_per_chip=True):

class MeshUtilsTest(test_util.JaxTestCase):

@parameterized.named_parameters(
('1x1', mock_1x1_devices, (1, 1, 1, 2)),
('2x2', mock_2x2_devices, (2, 2, 1, 2)),
('4x4', mock_4x4_devices, (4, 4, 1, 2)),
('8x8', mock_8x8_devices, (8, 8, 1, 2)),
)
def test_bounds_from_last_device_2d(self, devices, expected_bounds):
self.assertEqual(
mesh_utils._bounds_from_last_device(devices()[-1]),
expected_bounds)

@parameterized.named_parameters(
('1x2x1_t', mock_1x2x1_devices, True, (1, 2, 1, 1)),
('1x2x1_f', mock_1x2x1_devices, False, (1, 2, 1, 2)),
('2x2x1_t', mock_2x2x1_devices, True, (2, 2, 1, 1)),
('2x2x1_f', mock_2x2x1_devices, False, (2, 2, 1, 2)),
('8x8x16_t', mock_8x8x16_devices, True, (8, 8, 16, 1)),
('8x8x16_f', mock_8x8x16_devices, False, (8, 8, 16, 2)),
)
def test_bounds_from_last_device_3d(self, devices, one_device_per_chip,
expected_bounds):
self.assertEqual(
mesh_utils._bounds_from_last_device(devices(one_device_per_chip)[-1]),
expected_bounds)

@parameterized.named_parameters(
('1x2x1_t', (1, 2, 1), True),
('4x4x4_t', (4, 4, 4), True),
Expand Down

0 comments on commit f1c3109

Please sign in to comment.