From f1c3109bf503435cf0bae37e744510493aae621d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 02:39:05 -0700 Subject: [PATCH] Removed `mesh_utils._bounds_from_last_device` which was only used in tests PiperOrigin-RevId: 691342846 --- jax/_src/mesh_utils.py | 10 ---------- tests/mesh_utils_test.py | 25 ------------------------- 2 files changed, 35 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index bb6152167658..c37bbba4d836 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -572,16 +572,6 @@ def _generate_logical_mesh( return logical_mesh -def _bounds_from_last_device(last_device) -> Sequence[int]: - """Gets the bound from the given last device.""" - # Must be passed the device at the highest-coordinate corner of the - # relevant mesh, which is a requirement we know is satisfied by the last - # device in jax.devices(). - assert hasattr(last_device, 'coords'), 'Only TPU supported' - x, y, z = last_device.coords - return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 - - def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: r"""Rearrange TPU devices in a slice into a physical mesh. diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 42522d7f4b1b..66f1fc9f6cfb 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -205,31 +205,6 @@ def mock_2x2x2_v5e_devices(one_device_per_chip=True): class MeshUtilsTest(test_util.JaxTestCase): - @parameterized.named_parameters( - ('1x1', mock_1x1_devices, (1, 1, 1, 2)), - ('2x2', mock_2x2_devices, (2, 2, 1, 2)), - ('4x4', mock_4x4_devices, (4, 4, 1, 2)), - ('8x8', mock_8x8_devices, (8, 8, 1, 2)), - ) - def test_bounds_from_last_device_2d(self, devices, expected_bounds): - self.assertEqual( - mesh_utils._bounds_from_last_device(devices()[-1]), - expected_bounds) - - @parameterized.named_parameters( - ('1x2x1_t', mock_1x2x1_devices, True, (1, 2, 1, 1)), - ('1x2x1_f', mock_1x2x1_devices, False, (1, 2, 1, 2)), - ('2x2x1_t', mock_2x2x1_devices, True, (2, 2, 1, 1)), - ('2x2x1_f', mock_2x2x1_devices, False, (2, 2, 1, 2)), - ('8x8x16_t', mock_8x8x16_devices, True, (8, 8, 16, 1)), - ('8x8x16_f', mock_8x8x16_devices, False, (8, 8, 16, 2)), - ) - def test_bounds_from_last_device_3d(self, devices, one_device_per_chip, - expected_bounds): - self.assertEqual( - mesh_utils._bounds_from_last_device(devices(one_device_per_chip)[-1]), - expected_bounds) - @parameterized.named_parameters( ('1x2x1_t', (1, 2, 1), True), ('4x4x4_t', (4, 4, 4), True),