Skip to content

Commit

Permalink
🔀 Merge pull request #27 from alvarobartt/add-chex-assertions
Browse files Browse the repository at this point in the history
🧪 Add assertions over `pytrees` with `chex`
  • Loading branch information
alvarobartt authored Jan 5, 2023
2 parents 4bce723 + c72f1d4 commit ecf6a3e
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 27 deletions.
Empty file added tests/__init__.py
Empty file.
12 changes: 10 additions & 2 deletions tests/test_core_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from safejax.core.save import serialize
from safejax.typing import ParamsDictLike

from .utils import assert_over_trees


@pytest.mark.parametrize(
"params, deserialize_kwargs, expected_output_type",
Expand Down Expand Up @@ -45,6 +47,8 @@ def test_deserialize(
assert id(decoded_params) != id(params)
assert decoded_params.keys() == params.keys()

assert_over_trees(params=params, decoded_params=decoded_params)


@pytest.mark.parametrize(
"params, deserialize_kwargs, expected_output_type",
Expand All @@ -70,7 +74,7 @@ def test_deserialize(
)
@pytest.mark.usefixtures("safetensors_file")
def test_deserialize_from_file(
params: FrozenDict,
params: ParamsDictLike,
deserialize_kwargs: Dict[str, Any],
expected_output_type: Union[dict, FrozenDict, VarCollection],
safetensors_file: Path,
Expand All @@ -82,6 +86,8 @@ def test_deserialize_from_file(
assert id(decoded_params) != id(params)
assert decoded_params.keys() == params.keys()

assert_over_trees(params=params, decoded_params=decoded_params)


@pytest.mark.parametrize(
"params, deserialize_kwargs, expected_output_type",
Expand All @@ -107,7 +113,7 @@ def test_deserialize_from_file(
)
@pytest.mark.usefixtures("safetensors_file", "fs")
def test_deserialize_from_file_in_fs(
params: FrozenDict,
params: ParamsDictLike,
deserialize_kwargs: Dict[str, Any],
expected_output_type: Union[dict, FrozenDict, VarCollection],
safetensors_file: Path,
Expand All @@ -121,3 +127,5 @@ def test_deserialize_from_file_in_fs(
assert len(decoded_params) > 0
assert id(decoded_params) != id(params)
assert decoded_params.keys() == params.keys()

assert_over_trees(params=params, decoded_params=decoded_params)
43 changes: 18 additions & 25 deletions tests/test_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from safejax.flax import deserialize, serialize
from safejax.typing import FlaxParams
from safejax.utils import flatten_dict

from .utils import assert_over_trees


@pytest.mark.parametrize(
Expand All @@ -25,12 +26,17 @@
)
def test_partial_deserialize(params: FlaxParams) -> None:
encoded_params = serialize(params=params)
assert isinstance(encoded_params, bytes)
assert len(encoded_params) > 0

decoded_params = deserialize(path_or_buf=encoded_params)
assert isinstance(decoded_params, FrozenDict)
assert len(decoded_params) > 0
assert id(decoded_params) != id(params)
assert decoded_params.keys() == params.keys()

assert_over_trees(params=params, decoded_params=decoded_params)


@pytest.mark.parametrize(
"params",
Expand All @@ -44,12 +50,17 @@ def test_partial_deserialize_from_file(
params: FlaxParams, safetensors_file: Path
) -> None:
safetensors_file = serialize(params=params, filename=safetensors_file)
assert isinstance(safetensors_file, Path)
assert safetensors_file.exists()

decoded_params = deserialize(path_or_buf=safetensors_file)
assert isinstance(decoded_params, FrozenDict)
assert len(decoded_params) > 0
assert id(decoded_params) != id(params)
assert decoded_params.keys() == params.keys()

assert_over_trees(params=params, decoded_params=decoded_params)


@pytest.mark.parametrize(
"params",
Expand Down Expand Up @@ -83,14 +94,8 @@ def test_safejax_and_msgpack(
assert id(msgpack_decoded_params) != id(params)
assert msgpack_decoded_params.keys() == params.keys()

params = flatten_dict(params)
safetensors_decoded_params = flatten_dict(safetensors_decoded_params)
msgpack_decoded_params = flatten_dict(msgpack_decoded_params)
assert safetensors_decoded_params.keys() == msgpack_decoded_params.keys()
assert all(
safetensors_decoded_params[k].shape == msgpack_decoded_params[k].shape
for k in params.keys()
)
assert_over_trees(params=params, decoded_params=safetensors_decoded_params)
assert_over_trees(params=params, decoded_params=msgpack_decoded_params)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -123,14 +128,8 @@ def test_safejax_and_msgpack_bytes(params: FlaxParams) -> None:
assert id(msgpack_bytes_decoded_params) != id(params)
assert msgpack_bytes_decoded_params.keys() == params.keys()

params = flatten_dict(params)
safetensors_decoded_params = flatten_dict(safetensors_decoded_params)
msgpack_bytes_decoded_params = flatten_dict(msgpack_bytes_decoded_params)
assert safetensors_decoded_params.keys() == msgpack_bytes_decoded_params.keys()
assert all(
safetensors_decoded_params[k].shape == msgpack_bytes_decoded_params[k].shape
for k in params.keys()
)
assert_over_trees(params=params, decoded_params=safetensors_decoded_params)
assert_over_trees(params=params, decoded_params=msgpack_bytes_decoded_params)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -161,11 +160,5 @@ def test_safejax_and_state_dict(params: FlaxParams) -> None:
assert id(state_dict_decoded_params) != id(params)
assert state_dict_decoded_params.keys() == params.keys()

params = flatten_dict(params)
safetensors_decoded_params = flatten_dict(safetensors_decoded_params)
state_dict_decoded_params = flatten_dict(state_dict_decoded_params)
assert safetensors_decoded_params.keys() == state_dict_decoded_params.keys()
assert all(
safetensors_decoded_params[k].shape == state_dict_decoded_params[k].shape
for k in params.keys()
)
assert_over_trees(params=params, decoded_params=safetensors_decoded_params)
assert_over_trees(params=params, decoded_params=state_dict_decoded_params)
6 changes: 6 additions & 0 deletions tests/test_haiku.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from safejax.haiku import deserialize, serialize
from safejax.typing import HaikuParams

from .utils import assert_over_trees


@pytest.mark.parametrize(
"params",
Expand All @@ -23,6 +25,8 @@ def test_serialize_and_deserialize(params: HaikuParams) -> None:
assert id(decoded_params) != id(params)
assert decoded_params.keys() == params.keys()

assert_over_trees(params=params, decoded_params=decoded_params)


@pytest.mark.parametrize(
"params",
Expand All @@ -43,3 +47,5 @@ def test_serialize_and_deserialize_from_file(
assert len(decoded_params) > 0
assert id(decoded_params) != id(params)
assert decoded_params.keys() == params.keys()

assert_over_trees(params=params, decoded_params=decoded_params)
44 changes: 44 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import warnings

import chex
import jax
from flax.core.frozen_dict import FrozenDict, unfreeze
from objax.variable import VarCollection

from safejax.typing import ParamsDictLike


def assert_over_trees(params: ParamsDictLike, decoded_params: ParamsDictLike) -> None:
"""Assertions using `chex` to compare two trees of parameters.
Note:
This function does not support `objax.variable.VarCollection` objects yet,
so the assertions are just done over `jax`, `flax`, and `haiku` params.
Args:
params: a `ParamsDictLike` object with the original parameters.
decoded_params: a `ParamsDictLike` object with the decoded parameters using `safejax`.
Raises:
AssertionError: if the two trees are not equal on dtype, shape, structure, and values.
"""
if isinstance(params, VarCollection) or isinstance(decoded_params, VarCollection):
warnings.warn(
"This function does not support `objax.variable.VarCollection` objects yet."
)
else:
params = unfreeze(params) if isinstance(params, FrozenDict) else params
decoded_params = (
unfreeze(decoded_params)
if isinstance(decoded_params, FrozenDict)
else decoded_params
)
params_tree = jax.tree_util.tree_map(lambda x: x, params)
decoded_params_tree = jax.tree_util.tree_map(lambda x: x, decoded_params)

chex.assert_trees_all_close(
params_tree, decoded_params_tree
) # static and jittable static
chex.assert_trees_all_equal_dtypes(params_tree, decoded_params_tree)
chex.assert_trees_all_equal_shapes(params_tree, decoded_params_tree)
chex.assert_trees_all_equal_structs(params_tree, decoded_params_tree)

0 comments on commit ecf6a3e

Please sign in to comment.