Skip to content

Commit

Permalink
Sort devices by their implicit order instead of explicitly by id. IDs…
Browse files Browse the repository at this point in the history
… may be randomly generated, so it's better to rely on the implicit order, which is currently based on (process index, id).

PiperOrigin-RevId: 650294623
  • Loading branch information
Orbax Authors committed Jul 10, 2024
1 parent 1e06498 commit a9f8a3b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def fully_replicated_host_local_array_to_global_array(
# pmap-produced Array has a "scrambled" device order.
dbs = sorted(
[shard.data for shard in arr.addressable_shards],
key=lambda x: list(x.devices())[0].id,
key=lambda x: list(x.devices())[0],
)
return jax.make_array_from_single_device_arrays(
global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs
Expand Down

0 comments on commit a9f8a3b

Please sign in to comment.