Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add memory limiting functionality for save, in addition to restore. #956

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Memory-based rate limiting support during save.

### Changed
- Rolled back change in previous release to improve TensorStore I/O efficiency.
  This change caused some unexpected failures on certain storage systems.

## [0.5.16] - 2024-06-11

Expand Down
8 changes: 4 additions & 4 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,9 @@
_CHECKPOINT_FILE = 'checkpoint'


def get_byte_limiter(concurrent_gb: int):
def get_byte_limiter(concurrent_bytes: int):
async def _create_byte_limiter():
# Wrap creation in async function to avoid issues on python<=3.9.
concurrent_bytes = concurrent_gb * 10**9
# Construction must take place here so that it is within the same async
# method, to prevent errors resulting from different event loops, and
# cannot be created below this level because there must be a single object
Expand Down Expand Up @@ -366,7 +365,7 @@ def __init__(
if aggregate_filename is None:
aggregate_filename = _CHECKPOINT_FILE
self._aggregate_filename = aggregate_filename
self._concurrent_gb = concurrent_gb
self._concurrent_bytes = concurrent_gb * 10**9
self._use_ocdbt = use_ocdbt
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
Expand Down Expand Up @@ -531,6 +530,7 @@ def _maybe_set_default_save_args(value, args_):
item if save_args is None else save_args,
is_leaf=tree_utils.is_empty_or_leaf,
)
# byte_limiter = get_byte_limiter(self._concurrent_gb)
param_infos, all_params_aggregated = self._get_param_infos(
item, directory, save_args, ocdbt_target_data_file_size
)
Expand Down Expand Up @@ -764,7 +764,7 @@ class TrainState:
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
byte_limiter = get_byte_limiter(self._concurrent_bytes)
structure, use_zarr3_metadata = self._get_internal_metadata(directory)
# `checkpoint_restore_args` has a structure relative to the checkpoint,
# while `restore_args` remains structured relative to the output.
Expand Down
4 changes: 2 additions & 2 deletions checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def __init__(
if aggregate_filename is None:
aggregate_filename = _CHECKPOINT_FILE
self._aggregate_filename = aggregate_filename
self._concurrent_gb = concurrent_gb
self._concurrent_bytes = concurrent_gb * 10**9
self._use_ocdbt = use_ocdbt
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
Expand Down Expand Up @@ -646,7 +646,7 @@ class TrainState:
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
byte_limiter = get_byte_limiter(self._concurrent_bytes)
structure, use_zarr3_metadata = self._handler_impl._get_internal_metadata( # pylint: disable=protected-access
directory
)
Expand Down
62 changes: 61 additions & 1 deletion checkpoint/orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def create_sharded_array(arr, mesh, mesh_axes):
return jax.make_array_from_callback(
arr.shape,
jax.sharding.NamedSharding(mesh, mesh_axes),
lambda idx: arr[idx],
lambda idx: np.asarray(arr[idx], dtype=arr.dtype),
)


Expand Down Expand Up @@ -355,6 +355,18 @@ def set_tensorstore_driver_for_test():
serialization._DEFAULT_DRIVER = 'file' # pylint: disable=protected-access


class PyTreeCheckpointHandler(
    pytree_checkpoint_handler.PyTreeCheckpointHandler
):
  """Test wrapper that makes `save` synchronous and self-finalizing.

  The production handler defers finalization; this subclass performs it
  inline so tests see a fully committed checkpoint when `save` returns.
  """

  def save(self, directory, *args, **kwargs):
    # Delegate the actual write to the production handler.
    super().save(directory, *args, **kwargs)
    # NOTE(review): sync_global_processes is presumably a cross-process
    # barrier — all processes must finish saving before finalization starts.
    sync_global_processes('PyTreeCheckpointHandler:save')
    # Only one process performs finalization; others wait at the barrier
    # below until it completes.
    if multihost.process_index() == 0:
      self.finalize(directory)
    sync_global_processes('PyTreeCheckpointHandler:finalize')


class ErrorCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
"""Wrapper for PyTreeCheckpointHandler that has an error during save."""

Expand Down Expand Up @@ -472,3 +484,51 @@ def test_foo(self):
if name.startswith('test'):
setattr(cls, name, _get_wrapper_function(func))
return cls


def concurrent_gb_test_setup(limit_bytes: int):
  """Setup for tests exercising concurrent_gb setting.

  Builds a PyTreeCheckpointHandler whose internal byte limiter is overridden
  to `limit_bytes`, plus a small sharded pytree of four 8-byte arrays and the
  matching restore args.

  Args:
    limit_bytes: byte budget installed on the handler's private
      `_concurrent_bytes` attribute, replacing the 1 GB implied by
      `concurrent_gb=1`.

  Returns:
    A `(handler, tree, restore_args)` tuple.
  """
  # NOTE(review): this flips a process-global jax config flag and never
  # resets it — assumed acceptable for the test process; confirm.
  jax.config.update('jax_enable_x64', True)
  handler = PyTreeCheckpointHandler(concurrent_gb=1)
  handler._handler_impl._concurrent_bytes = (  # pylint: disable=protected-access
      limit_bytes  # override so we can use a small number of bytes.
  )

  # Single-axis mesh over all devices; arrays are fully replicated (no
  # partitioning) via the all-None PartitionSpec.
  mesh = jax.sharding.Mesh(
      jax.devices(),
      ('x',),
  )
  pspec = jax.sharding.PartitionSpec(
      None,
  )

  def _create_sharded_array(arr):
    # Shard a host array onto the mesh with the replicated spec above.
    return create_sharded_array(arr, mesh, pspec)

  # 4 arrays; each is a single int64 element, i.e. one 8-byte chunk.
  tree = jax.tree.map(
      _create_sharded_array,
      {
          'a': np.arange(1, dtype=np.int64),
          'b': np.arange(1, dtype=np.int64),
          'c': np.arange(1, dtype=np.int64),
          'd': np.arange(1, dtype=np.int64),
      },
  )
  # Restore args request the same replicated sharding for every leaf.
  restore_args = jax.tree.map(
      lambda _: type_handlers.ArrayRestoreArgs(
          sharding=jax.sharding.NamedSharding(mesh, pspec)
      ),
      tree,
  )
  return handler, tree, restore_args


def assert_every_n_is_x_apart(testclass, values, n, x):
  """Checks spacing between consecutive size-`n` groups of sorted values.

  Sorts `values`, partitions the result into consecutive groups of `n`
  elements, and asserts (via `testclass`) that the first element of each
  group is at least `x` greater than the last element of the previous group.
  `len(values)` must be an exact multiple of `n`.
  """
  ordered = sorted(values)
  assert len(ordered) % n == 0
  for boundary in range(n, len(ordered), n):
    prev_group_end = ordered[boundary - 1]
    next_group_start = ordered[boundary]
    testclass.assertGreaterEqual(next_group_start, prev_group_end + x)
31 changes: 11 additions & 20 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,26 +1303,17 @@ async def serialize(
logging.debug('args = %s', arg)
logging.debug('replica_id = %s', replica_id)

if jax.__version_info__ > (0, 4, 25):
synchronous_ops += [
serialization.async_serialize(
value,
tspec,
commit_future=futures,
context=ts_context,
primary_host=self._primary_host,
replica_id=replica_id,
)
]
else:
synchronous_ops += [
serialization.async_serialize(
value,
tspec,
commit_future=futures,
context=ts_context,
)
]
synchronous_ops += [
serialization.async_serialize(
value,
tspec,
commit_future=futures,
context=ts_context,
primary_host=self._primary_host,
replica_id=replica_id,
byte_limiter=info.byte_limiter,
)
]

if value.sharding is not None:
if info.parent_dir is None:
Expand Down
Loading