Skip to content

Commit

Permalink
Provide utility functions to save and load the orbax.export ApplyFn_Map to and from disk.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648500630
  • Loading branch information
maxwillzq authored and Orbax Authors committed Jul 10, 2024
1 parent 1e06498 commit 6168e90
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
working-directory: checkpoint
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9"]
jax-version: ["newest"]
include:
- python-version: "3.9"
Expand Down Expand Up @@ -76,7 +76,7 @@ jobs:
strategy:
matrix:
python-version: ["3.9"]
jax-version: ["newest", "0.4.26"] # keep in sync with minimum version in export/pyproject.toml
jax-version: ["newest", "0.4.30"] # keep in sync with minimum version in export/pyproject.toml
steps:
- name: Cancel previous
uses: styfle/[email protected]
Expand Down
92 changes: 91 additions & 1 deletion export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@

import dataclasses
import os
from typing import Any, Callable, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
import jax
from jax import export as jax_export
from jax.experimental import jax2tf
import orbax.checkpoint as ocp
from orbax.export import dtensor_utils
from orbax.export import typing as orbax_export_typing
from orbax.export import utils as orbax_export_utils
import tensorflow as tf
from tensorflow.experimental import dtensor

Expand Down Expand Up @@ -333,6 +335,23 @@ def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
apply_fn_map = self._nontrackable_metadata.apply_fn_map
return _make_closures(params, apply_fn_map)

def to_jax_exported_map(
    self, model_inputs: PyTree, output_dir: Union[str, None] = None
) -> Mapping[str, jax_export.Exported]:
  """Exports every method of this JaxModule as a jax_export.Exported.

  Args:
    model_inputs: The model inputs used to export each method.
    output_dir: Optional directory. When it is not None, the resulting map is
      also written there via `orbax_export_utils.save_jax_exported_map`.

  Returns:
    A mapping from method key to jax_export.Exported.
  """
  exported_by_method = _jax_module_to_jax_exported_map(self, model_inputs)
  if output_dir is None:
    # Nothing to persist; hand back the in-memory map only.
    return exported_by_method
  orbax_export_utils.save_jax_exported_map(output_dir, exported_by_method)
  return exported_by_method


def _get_param_names(params: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
Expand Down Expand Up @@ -429,3 +448,74 @@ def _to_tf_variable(x, name, trainable, pspec):
return jax.tree_util.tree_map(
_to_tf_variable, params, names, trainable, pspecs
)


def _jax_module_to_jax_exported_map(
    j_module: JaxModule,
    model_inputs: PyTree,
) -> Mapping[str, jax_export.Exported]:
  """Converts the orbax.export JaxModule to a map of jax_export.Exported.

  Every entry of `j_module.apply_fn_map` is exported with `jax_export.export`
  and called with `j_module.model_params` and the model inputs. When the
  method has an input polymorphic shape, the inputs are first turned into
  symbolic argument specs.

  Args:
    j_module: The orbax.export JaxModule.
    model_inputs: The model inputs.

  Returns:
    A mapping from method key to jax_export.Exported.

  Raises:
    ValueError: If a method key of `apply_fn_map` is missing from
      `input_polymorphic_shape_map` or from `jax2tf_kwargs_map`.
  """
  apply_fn_map = j_module.apply_fn_map
  model_params = j_module.model_params
  input_polymorphic_shape_map = j_module.input_polymorphic_shape_map
  jax2tf_kwargs_map = j_module.jax2tf_kwargs_map

  jax_exported_map = {}
  for method_key, apply_fn in apply_fn_map.items():
    # Validate before indexing either map so that a missing key surfaces as
    # the descriptive ValueError below rather than a bare KeyError.
    if method_key not in input_polymorphic_shape_map:
      raise ValueError(
          f'Method key {method_key} not found in input_polymorphic_shape_map.'
      )
    if method_key not in jax2tf_kwargs_map:
      raise ValueError(
          f'Method key {method_key} not found in lowering_platforms_map.'
      )

    # A kwargs entry may be None or empty; treat both as "no options" for
    # the shape constraints as well as for the lowering platforms.
    jax2tf_kwargs = jax2tf_kwargs_map[method_key] or {}

    input_polymorphic_shape = input_polymorphic_shape_map[method_key]
    if input_polymorphic_shape is None:
      # No polymorphism requested: export against the concrete inputs.
      symbolic_model_inputs = model_inputs
    else:
      polymorphic_constraints: Sequence[str] = ()
      if 'polymorphic_constraints' in jax2tf_kwargs:
        polymorphic_constraints = jax2tf_kwargs['polymorphic_constraints']
      symbolic_model_inputs = jax_export.symbolic_args_specs(
          model_inputs,
          input_polymorphic_shape,
          constraints=polymorphic_constraints,
      )

    # None (absent or explicitly set) means no platform list is passed on.
    platforms = jax2tf_kwargs.get('native_serialization_platforms')
    lowering_platforms: Optional[Sequence[str]] = (
        None if platforms is None else tuple(platforms)
    )

    # Wrap callables that lack a `trace` attribute in jax.jit before export.
    if not hasattr(apply_fn, 'trace'):
      apply_fn = jax.jit(apply_fn)

    jax_exported_map[method_key] = jax_export.export(
        apply_fn, platforms=lowering_platforms
    )(model_params, symbolic_model_inputs)
  return jax_exported_map
63 changes: 63 additions & 0 deletions export/orbax/export/jax_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
"""Tests for jax_module."""

import collections
import os

from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
from orbax import export as obx_export
from orbax.export import utils as orbax_export_utils
import tensorflow as tf

DEFAULT_METHOD_KEY = obx_export.JaxModule.DEFAULT_METHOD_KEY
JaxModule = obx_export.JaxModule


def _register_custom_dict_to_jax(dict_cls):
Expand Down Expand Up @@ -419,6 +422,66 @@ def test_variable_update_error(self):
with self.assertRaisesRegex(ValueError, 'Incompatible type conversion'):
jax_module.update_variables({'w': np.zeros((4, 8), dtype=np.int32)})

def test_save_load_as_jax_exported_map(self):
  """Round-trips a JaxModule through save and load of its exported map."""

  def linear(params, x):
    return params['w'] @ x + params['b']

  key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(1234), 3)
  model_params = {
      'w': jax.random.normal(key_w, shape=(8, 8)),
      'b': jax.random.normal(key_b, shape=(8, 1)),
  }
  model_inputs = jax.random.normal(key_x, shape=(8, 1))
  lowering_platforms = ['cpu', 'tpu']

  j_module = JaxModule(
      model_params,
      linear,
      jax2tf_kwargs={'native_serialization_platforms': lowering_platforms},
  )

  # Export into a fresh sub-directory of a temp dir, then read it back.
  saved_dir = os.path.join(
      self.create_tempdir().full_path, 'jax_exported_map'
  )
  jax_exported_map = j_module.to_jax_exported_map(model_inputs, saved_dir)
  restored_jax_exported_map = orbax_export_utils.load_jax_exported_map(
      saved_dir
  )

  # The restored map has the same method keys as the export and the module.
  restored_keys = set(restored_jax_exported_map.keys())
  self.assertEqual(
      restored_keys,
      set(jax_exported_map.keys()),
      f'{restored_jax_exported_map.keys()} vs {jax_exported_map.keys()}',
  )
  self.assertEqual(
      restored_keys,
      set(j_module.apply_fn_map.keys()),
      f'{restored_jax_exported_map.keys()} vs {j_module.apply_fn_map.keys()}',
  )

  # Looked up only after the key checks so a missing key fails there first.
  restored_default = restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY]

  # Numerics and lowering platforms survive the round trip.
  chex.assert_trees_all_close(
      restored_default.call(model_params, model_inputs),
      linear(model_params, model_inputs),
  )
  chex.assert_equal(
      set(restored_default.platforms),
      set(lowering_platforms),
  )

  # Input tree structure, shapes and dtypes match the original call.
  args_kwargs = ((model_params, model_inputs), {})
  in_tree = jax.tree.structure(args_kwargs)
  in_avals = tuple(jax.tree.leaves(args_kwargs))
  chex.assert_equal(in_tree, restored_default.in_tree)
  chex.assert_trees_all_equal_shapes(in_avals, restored_default.in_avals)
  chex.assert_trees_all_equal_dtypes(in_avals, restored_default.in_avals)


if __name__ == '__main__':
tf.test.main()
5 changes: 1 addition & 4 deletions export/orbax/export/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@

from typing import Any, Callable, Mapping, Sequence, TypeVar, Union
import jaxtyping
from orbax.export import utils as orbax_export_utils
import tensorflow as tf


T = TypeVar('T')
Nested = Union[T, tuple[Any, ...], Sequence[Any], Mapping[str, Any]]
WarmupExample = Union[list[Mapping[str, Any]], Mapping[str, Any]]
NestedTfTrackable = Nested[tf.saved_model.experimental.TrackableResource]
NestedTfTensorSpec = Nested[
Union[tf.TensorSpec, orbax_export_utils.TensorSpecWithDefault]
]


PyTree = jaxtyping.PyTree

Expand Down
73 changes: 72 additions & 1 deletion export/orbax/export/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@
import dataclasses
import functools
import inspect
from typing import Any, Callable, Optional
import os
from typing import Any, Callable, Optional, Union
from absl import logging
import jax
from jax import export as jax_export
import jaxtyping
import tensorflow as tf

ConfigProto = Any
PyTree = jaxtyping.PyTree
SignatureDef = Any

_FILE_TYPE = 'jax_exported'


@dataclasses.dataclass
class TensorSpecWithDefault:
Expand Down Expand Up @@ -62,6 +67,11 @@ def __post_init__(self):
)


NestedTfTensorSpec = jaxtyping.PyTree[
Union[tf.TensorSpec, TensorSpecWithDefault]
]


def remove_signature_defaults(input_signature: PyTree) -> PyTree:
"""Removes TensorSpecWithDefault from an input_signature."""

Expand Down Expand Up @@ -329,3 +339,64 @@ def from_saved_model(
def signatures(self):
  """Returns a mapping from signature names to python callables."""
  return self._signatures


def save_jax_exported_to_disk(
    exp: jax_export.Exported,
    bin_file_path: str,
    vjp_order: int = 0,
) -> None:
  """Serializes a jax_export.Exported and writes it to a new file.

  Args:
    exp: The exported function to serialize.
    bin_file_path: Destination file path; it must not exist yet.
    vjp_order: Forwarded to `exp.serialize` as `vjp_order`.

  Raises:
    ValueError: If `bin_file_path` already exists.
  """
  already_there = tf.io.gfile.exists(bin_file_path)
  if already_there:
    raise ValueError(f'File {bin_file_path} already exists.')
  with tf.io.gfile.GFile(bin_file_path, 'wb') as out_file:
    serialized = exp.serialize(vjp_order=vjp_order)
    out_file.write(serialized)


def load_jax_exported_from_disk(bin_file_path: str) -> jax_export.Exported:
  """Reads a serialized jax_export.Exported from a file and deserializes it.

  Args:
    bin_file_path: Path of the file holding the serialized bytes.

  Returns:
    The deserialized jax_export.Exported.

  Raises:
    ValueError: If `bin_file_path` does not exist.
  """
  if not tf.io.gfile.exists(bin_file_path):
    raise ValueError(f'File {bin_file_path} does not exist.')
  with tf.io.gfile.GFile(bin_file_path, 'rb') as in_file:
    serialized = bytearray(in_file.read())
    return jax_export.deserialize(serialized)


def save_jax_exported_map(
    dir_path: str,
    jax_exported_map: Mapping[str, jax_export.Exported],
) -> None:
  """Saves the orbax.export JaxExported Map to disk.

  A new directory is created at `dir_path` and each entry is written to
  `<dir_path>/<method_key>.<_FILE_TYPE>`.

  Args:
    dir_path: The directory to create and write into; it must not exist yet.
    jax_exported_map: A mapping from method key to jax_export.Exported.

  Raises:
    ValueError: If `dir_path` already exists.
  """
  if tf.io.gfile.exists(dir_path):
    raise ValueError(f'Directory {dir_path} already exists.')

  tf.io.gfile.makedirs(dir_path)
  for method_key, jax_exported in jax_exported_map.items():
    # Join with dir_path exactly once. Joining the already-joined path again
    # yields `dir/dir/<file>` for a relative dir_path, which was never created.
    file_path = os.path.join(dir_path, f'{method_key}.{_FILE_TYPE}')
    save_jax_exported_to_disk(jax_exported, file_path)
  logging.info('Saved JaxExported Map to %s successfully.', dir_path)


def load_jax_exported_map(dir_path: str) -> Mapping[str, jax_export.Exported]:
  """Loads the orbax.export ApplyFn JaxExported Map from disk.

  Only directory entries ending in `.<_FILE_TYPE>` are loaded; the method key
  is the entry name with that suffix removed.

  Args:
    dir_path: The directory path to load the ApplyFn Map from.

  Returns:
    A map of method_key to JaxExported object.

  Raises:
    ValueError: If `dir_path` does not exist or holds no matching files.
  """
  if not tf.io.gfile.exists(dir_path):
    raise ValueError(f'Directory {dir_path} does not exist.')

  suffix = f'.{_FILE_TYPE}'
  loaded = {}
  for file_name in tf.io.gfile.listdir(dir_path):
    if file_name.endswith(suffix):
      method_key = file_name[: -len(suffix)]
      loaded[method_key] = load_jax_exported_from_disk(
          os.path.join(dir_path, file_name)
      )

  if not loaded:
    raise ValueError(f'No .{_FILE_TYPE} files found in {dir_path}.')
  logging.info('Loaded ApplyFn JaxExported Map from %s successfully.', dir_path)
  return loaded
2 changes: 1 addition & 1 deletion export/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
'absl-py',
'etils',
'orbax-checkpoint',
'jax >= 0.4.26',
'jax >= 0.4.30',
'jaxlib',
'numpy',
'dataclasses-json',
Expand Down

0 comments on commit 6168e90

Please sign in to comment.