-
Notifications
You must be signed in to change notification settings - Fork 479
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add schema compatibility test (#3130)
Summary: X-link: facebookresearch/FBGEMM#217 To ensure that changes to the ops are forward and backward compatible with the stable release, we add unit tests to test schema compatibility. **Usage**: ``` check_schema_compatibility_from_op_name( namespace: Callable, op_name: str ref_schema_str: str, ) check_schema_compatibility( op: Callable, ref_schema_str: str, ) ``` e.g., ``` check_schema_compatibility_from_op_name( torch.ops.fbgemm, "merge_pooled_embeddings", "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor" ) check_schema_compatibility( fbgemm_gpu.sparse_ops.merge_pooled_embeddings, "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", ) ``` Differential Revision: D61766648
- Loading branch information
1 parent
c01bbb8
commit 8beab8d
Showing
5 changed files
with
479 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"_description": "This is a dict containing example schemas. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", | ||
"_version": 1, | ||
"data": { | ||
"mx4_to_fp32": | ||
"mx4_to_fp32(Tensor tensor, int group_size=32, bool use_triton=True, int ebits=2, int mbits=1) -> Tensor", | ||
"merge_pooled_embeddings": | ||
"merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", | ||
"dummy_func": | ||
"dummy_func(str var1, int var2) -> ()" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
{ | ||
"_description": "This is a dict containing schema of FBGEMM_GPU ops that are marked as stable. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", | ||
"_version": 1, | ||
"data": { | ||
"torch.ops.fbgemm.jagged_to_padded_dense": | ||
"fbgemm::jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor", | ||
"torch.ops.fbgemm.merge_pooled_embeddings": | ||
"fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor", | ||
"torch.ops.fbgemm.permute_pooled_embs_auto_grad": | ||
"fbgemm::permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor", | ||
"torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf": | ||
"fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor", | ||
"torch.ops.fbgemm.permute_2D_sparse_data": | ||
"fbgemm::permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", | ||
"torch.ops.fbgemm.permute_1D_sparse_data": | ||
"fbgemm::permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", | ||
"torch.ops.fbgemm.expand_into_jagged_permute": | ||
"fbgemm::expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor", | ||
"torch.ops.fbgemm.block_bucketize_sparse_features": | ||
"fbgemm::block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)", | ||
"torch.ops.fbgemm.asynchronous_complete_cumsum": | ||
"fbgemm::asynchronous_complete_cumsum(Tensor t_in) -> Tensor", | ||
"torch.ops.fbgemm.offsets_range": | ||
"fbgemm::offsets_range(Tensor offsets, SymInt range_size) -> Tensor", | ||
"torch.ops.fbgemm.segment_sum_csr": | ||
"fbgemm::segment_sum_csr(SymInt batch_size, Tensor csr_seg, Tensor values) -> Tensor", | ||
"torch.ops.fbgemm.keyed_jagged_index_select_dim1": | ||
"fbgemm::keyed_jagged_index_select_dim1(Tensor values, Tensor lengths, Tensor offsets, Tensor indices, SymInt batch_size, Tensor? weights=None, SymInt? selected_lengths_sum=None) -> Tensor[]" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
|
||
import json | ||
import os | ||
import unittest | ||
from typing import Callable | ||
|
||
import fbgemm_gpu | ||
import fbgemm_gpu.permute_pooled_embedding_modules | ||
import fbgemm_gpu.sparse_ops | ||
|
||
import torch | ||
from torch._C import FunctionSchema, parse_schema | ||
from torch._utils_internal import get_file_path_2 | ||
|
||
from .utils import infer_schema | ||
|
||
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. | ||
open_source: bool = getattr(fbgemm_gpu, "open_source", False) | ||
|
||
if open_source: | ||
from test_utils import TestSuite # pyre-fixme[21] | ||
|
||
else: | ||
# pyre-fixme[21] | ||
from fbgemm_gpu.test.test_utils import TestSuite | ||
|
||
|
||
def _check_schema_compatibility( | ||
schema: FunctionSchema, | ||
ref_schema_str: str, | ||
) -> None: | ||
""" | ||
Check if the schema is forward and backward compatible with the reference schema. | ||
This function will raise an Exception error if the schema is not compatible. | ||
Args: | ||
schema (FunctionSchema): The schema object to check. | ||
ref_schema_str (str): The reference schema in string format. | ||
Returns: | ||
None | ||
""" | ||
assert isinstance(schema, FunctionSchema) | ||
ref_schema = parse_schema(ref_schema_str) | ||
# pyre-fixme[16] | ||
fwd_compatible = schema.check_forward_compatible_with(ref_schema) | ||
# pyre-fixme[16] | ||
bwd_compatible = schema.is_backward_compatible_with(ref_schema) | ||
msg = "" | ||
if not fwd_compatible: | ||
msg += f"Schema of {schema} is not forward compatible with {ref_schema}\n" | ||
# pyre-fixme[16] | ||
if not bwd_compatible: | ||
msg += f"Schema of {schema} is not backward compatible with {ref_schema}" | ||
assert fwd_compatible and bwd_compatible, msg | ||
|
||
|
||
def check_schema_compatibility( | ||
op: Callable, | ||
ref_schema: str, | ||
) -> None: | ||
""" | ||
Check if the schema of the given operator is forward and backward compatible with the reference schema. | ||
This works with python functions whose schema do NOT have positional-only args, varargs, or varkwargs | ||
For ops registered via torch.ops.fbgemm and ops with *args and **kwargs, please use check_schema_compatibility_from_op_name. | ||
Args: | ||
op (Callable): The operator to check. | ||
ref_schema (str): The reference schema in string format. | ||
Returns: | ||
None | ||
""" | ||
op_schema = infer_schema(op, mutates_args={}) | ||
# pyre-fixme[16] | ||
op_name = op.__name__ | ||
# Create schema string | ||
schema_str = f"{op_name}{op_schema}" | ||
# Create FunctionalSchema | ||
functional_schema = parse_schema(schema_str) | ||
|
||
# Get stable schema to compare against | ||
return _check_schema_compatibility(functional_schema, ref_schema) | ||
|
||
|
||
def check_schema_compatibility_from_op_name( | ||
namespace: Callable, | ||
op_name: str, | ||
ref_schema_str: str, | ||
) -> None: | ||
""" | ||
Check if the schema of the given operator is forward and backward compatible with the reference schema. | ||
Use this function to check registered ops (via torch.ops.fbgemm). | ||
This function will raise an Exception error if the schema is not compatible. | ||
Args: | ||
namespace (Callable): The namespace of the operator e.g., torch.ops.fbgemm. | ||
op_name (str): The name of the operator. | ||
ref_schema_str (str): The reference schema in string format. | ||
Returns: | ||
None | ||
""" | ||
op = getattr(namespace, op_name) | ||
schema = op._schemas[""] | ||
|
||
return _check_schema_compatibility(schema, ref_schema_str) | ||
|
||
|
||
class StableRelease(TestSuite): # pyre-ignore[11] | ||
def test_stable_schema(self) -> None: | ||
""" | ||
Test the schema compatibility of the operators against stable schema. | ||
This is to ensure that any changes to the ops' schema do not break compatibility of the stable versions. | ||
This test will fail if the current op schema is not forward or backward compatible with the stable schema. | ||
""" | ||
|
||
# Load stable ops from file into dict | ||
stable_dict_file = open( | ||
get_file_path_2("", os.path.dirname(__file__), "stable_ops.json") | ||
) | ||
stable_op_dict = json.load(stable_dict_file)["data"] | ||
stable_dict_file.close() | ||
# Get all op names | ||
stable_op_names = set(stable_op_dict.keys()) | ||
|
||
# Check compatibility for all ops that are marked stable | ||
for full_op_name in stable_op_names: | ||
# Test the schema given the op name | ||
ref_schema_str = stable_op_dict[full_op_name] | ||
op_name = full_op_name.split(".")[3] | ||
|
||
check_schema_compatibility_from_op_name( | ||
torch.ops.fbgemm, op_name, ref_schema_str | ||
) | ||
|
||
def test_example_ops(self) -> None: | ||
""" | ||
Test examples for schema compatibility. | ||
""" | ||
|
||
# Load example ops to dict | ||
stable_dict_file = open( | ||
get_file_path_2("", os.path.dirname(__file__), "example.json") | ||
) | ||
op_dict = json.load(stable_dict_file)["data"] | ||
stable_dict_file.close() | ||
|
||
# Example op 1 | ||
# Expect to pass | ||
check_schema_compatibility( | ||
fbgemm_gpu.sparse_ops.merge_pooled_embeddings, | ||
op_dict["merge_pooled_embeddings"], | ||
) | ||
|
||
# Example op 2 | ||
# stable schema is: dummy_func(str var1, int var2) -> ()" | ||
def dummy_func(var1: str, var2: int, var3: torch.Tensor) -> None: | ||
pass | ||
|
||
# Expect to fail | ||
with self.assertRaises(AssertionError): # pyre-fixme[16] | ||
check_schema_compatibility( | ||
dummy_func, | ||
op_dict["dummy_func"], | ||
) | ||
|
||
# Example op 3 | ||
# stable schema is: dummy_func(str var1, int var2) -> ()" | ||
def dummy_func(var1: str, var2: int, var3: str = "default") -> None: | ||
pass | ||
|
||
# Expect to pass | ||
check_schema_compatibility( | ||
dummy_func, | ||
op_dict["dummy_func"], | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.