diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py
index e04e26f5..bbc44862 100644
--- a/src/nanotron/parallel/context.py
+++ b/src/nanotron/parallel/context.py
@@ -125,7 +125,7 @@ def set_device(self):
         device_id = local_rank
         torch.cuda.set_device(torch.cuda.device(device_id))
 
-    def get_local_ranks(self, world_rank: int) -> Tuple[int, int, int]:
+    def get_local_ranks(self, world_rank: int) -> Tuple[int, int, int, int]:
         return tuple(i.item() for i in np.where(self.world_rank_matrix == world_rank))
 
     def destroy(self):
diff --git a/tests/test_distributed.py b/tests/test_distributed.py
index 7c0d2462..a2e13aa6 100644
--- a/tests/test_distributed.py
+++ b/tests/test_distributed.py
@@ -20,7 +20,7 @@ def _test_init_parallel_context(parallel_context: ParallelContext):
    world_rank = dist.get_rank(parallel_context.world_pg)
    ranks3d = parallel_context.get_local_ranks(world_rank)
 
-    assert isinstance(ranks3d, tuple) and len(ranks3d)
+    assert isinstance(ranks3d, tuple) and len(ranks3d) == 4
    assert isinstance(parallel_context.world_rank_matrix, np.ndarray)
    assert isinstance(parallel_context.world_ranks_to_pg, dict)
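
For context, a minimal standalone sketch (not part of the patch, and not nanotron's
actual code) of why the return annotation widens from a 3-tuple to a 4-tuple: if
world_rank_matrix gains a fourth parallelism axis, np.where yields one index array
per axis, so the comprehension produces four local ranks. The (2, 2, 2, 2) grid
shape and the axis order below are illustrative assumptions.

# Hedged sketch: 4-D rank grid -> get_local_ranks returns a 4-tuple.
import numpy as np

world_size = 16
# Hypothetical 4-D rank grid, one axis per parallelism dimension (assumed layout).
world_rank_matrix = np.arange(world_size).reshape(2, 2, 2, 2)

def get_local_ranks(world_rank: int) -> tuple:
    # np.where returns one index array per axis; with four axes this is a 4-tuple.
    return tuple(i.item() for i in np.where(world_rank_matrix == world_rank))

ranks = get_local_ranks(5)
print(ranks)  # (0, 1, 0, 1) for this assumed grid
assert isinstance(ranks, tuple) and len(ranks) == 4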