Add schema compatibility test (#3130)
Summary:
X-link: facebookresearch/FBGEMM#217

Pull Request resolved: #3130

To ensure that changes to the ops remain forward and backward compatible with the stable release, we add unit tests that check schema compatibility.

**Usage**:
```
check_schema_compatibility_from_op_name(
  namespace: Callable,
  op_name: str,
  ref_schema_str: str,
)
check_schema_compatibility(
  op: Callable,
  ref_schema_str: str,
)
```

e.g.,
```
check_schema_compatibility_from_op_name(
  torch.ops.fbgemm,
  "merge_pooled_embeddings",
  "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor"
)
check_schema_compatibility(
  fbgemm_gpu.sparse_ops.merge_pooled_embeddings,
  "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor",
)
```
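
As an illustration (not part of this diff), here is a minimal sketch of the check these helpers perform under the hood, calling `torch._C.parse_schema` directly. The `dummy_func` schema strings mirror the entries in `example.json`; the variable names are illustrative only.

```
from torch._C import parse_schema

ref = parse_schema("dummy_func(str var1, int var2) -> ()")

# Adding an argument with a default value keeps the op backward compatible
# with the stable schema (this mirrors example op 3 in the test).
ok = parse_schema('dummy_func(str var1, int var2, str var3="default") -> ()')
print(ok.is_backward_compatible_with(ref))   # True

# Adding a required argument breaks backward compatibility
# (this mirrors example op 2 in the test, which is expected to fail).
bad = parse_schema("dummy_func(str var1, int var2, Tensor var3) -> ()")
print(bad.is_backward_compatible_with(ref))  # False

# The test additionally calls schema.check_forward_compatible_with(ref)
# and asserts that both directions hold.
```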

Reviewed By: q10

Differential Revision: D61766648

fbshipit-source-id: dac52b88834331a466e7165812def1a3fe4c0804
spcyppt authored and facebook-github-bot committed Sep 19, 2024
1 parent 904a1c6 commit 0377308
Showing 5 changed files with 479 additions and 0 deletions.
6 changes: 6 additions & 0 deletions fbgemm_gpu/test/release/__init__.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
12 changes: 12 additions & 0 deletions fbgemm_gpu/test/release/example.json
@@ -0,0 +1,12 @@
{
    "_description": "This is a dict containing example schemas. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1",
    "_version": 1,
    "data": {
        "mx4_to_fp32":
            "mx4_to_fp32(Tensor tensor, int group_size=32, bool use_triton=True, int ebits=2, int mbits=1) -> Tensor",
        "merge_pooled_embeddings":
            "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor",
        "dummy_func":
            "dummy_func(str var1, int var2) -> ()"
    }
}
30 changes: 30 additions & 0 deletions fbgemm_gpu/test/release/stable_ops.json
@@ -0,0 +1,30 @@
{
    "_description": "This is a dict containing schema of FBGEMM_GPU ops that are marked as stable. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1",
    "_version": 1,
    "data": {
        "torch.ops.fbgemm.jagged_to_padded_dense":
            "fbgemm::jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor",
        "torch.ops.fbgemm.merge_pooled_embeddings":
            "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor",
        "torch.ops.fbgemm.permute_pooled_embs_auto_grad":
            "fbgemm::permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor",
        "torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf":
            "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor",
        "torch.ops.fbgemm.permute_2D_sparse_data":
            "fbgemm::permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)",
        "torch.ops.fbgemm.permute_1D_sparse_data":
            "fbgemm::permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)",
        "torch.ops.fbgemm.expand_into_jagged_permute":
            "fbgemm::expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor",
        "torch.ops.fbgemm.block_bucketize_sparse_features":
            "fbgemm::block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)",
        "torch.ops.fbgemm.asynchronous_complete_cumsum":
            "fbgemm::asynchronous_complete_cumsum(Tensor t_in) -> Tensor",
        "torch.ops.fbgemm.offsets_range":
            "fbgemm::offsets_range(Tensor offsets, SymInt range_size) -> Tensor",
        "torch.ops.fbgemm.segment_sum_csr":
            "fbgemm::segment_sum_csr(SymInt batch_size, Tensor csr_seg, Tensor values) -> Tensor",
        "torch.ops.fbgemm.keyed_jagged_index_select_dim1":
            "fbgemm::keyed_jagged_index_select_dim1(Tensor values, Tensor lengths, Tensor offsets, Tensor indices, SymInt batch_size, Tensor? weights=None, SymInt? selected_lengths_sum=None) -> Tensor[]"
    }
}
186 changes: 186 additions & 0 deletions fbgemm_gpu/test/release/stable_release_test.py
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
import unittest
from typing import Callable

import fbgemm_gpu
import fbgemm_gpu.permute_pooled_embedding_modules
import fbgemm_gpu.sparse_ops

import torch
from torch._C import FunctionSchema, parse_schema
from torch._utils_internal import get_file_path_2

from .utils import infer_schema

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    from test_utils import TestSuite  # pyre-fixme[21]

else:
    # pyre-fixme[21]
    from fbgemm_gpu.test.test_utils import TestSuite


def _check_schema_compatibility(
    schema: FunctionSchema,
    ref_schema_str: str,
) -> None:
    """
    Check if the schema is forward and backward compatible with the reference schema.
    This function will raise an exception if the schema is not compatible.
    Args:
        schema (FunctionSchema): The schema object to check.
        ref_schema_str (str): The reference schema in string format.
    Returns:
        None
    """
    assert isinstance(schema, FunctionSchema)
    ref_schema = parse_schema(ref_schema_str)
    # pyre-fixme[16]
    fwd_compatible = schema.check_forward_compatible_with(ref_schema)
    # pyre-fixme[16]
    bwd_compatible = schema.is_backward_compatible_with(ref_schema)
    msg = ""
    if not fwd_compatible:
        msg += f"Schema of {schema} is not forward compatible with {ref_schema}\n"
    # pyre-fixme[16]
    if not bwd_compatible:
        msg += f"Schema of {schema} is not backward compatible with {ref_schema}"
    assert fwd_compatible and bwd_compatible, msg


def check_schema_compatibility(
    op: Callable,
    ref_schema: str,
) -> None:
    """
    Check if the schema of the given operator is forward and backward compatible with the reference schema.
    This works with Python functions whose schemas do NOT have positional-only args, varargs, or varkwargs.
    For ops registered via torch.ops.fbgemm and ops with *args and **kwargs, please use check_schema_compatibility_from_op_name.
    Args:
        op (Callable): The operator to check.
        ref_schema (str): The reference schema in string format.
    Returns:
        None
    """
    op_schema = infer_schema(op, mutates_args={})
    # pyre-fixme[16]
    op_name = op.__name__
    # Create schema string
    schema_str = f"{op_name}{op_schema}"
    # Create FunctionSchema
    functional_schema = parse_schema(schema_str)

    # Compare against the stable (reference) schema
    return _check_schema_compatibility(functional_schema, ref_schema)


def check_schema_compatibility_from_op_name(
    namespace: Callable,
    op_name: str,
    ref_schema_str: str,
) -> None:
    """
    Check if the schema of the given operator is forward and backward compatible with the reference schema.
    Use this function to check registered ops (via torch.ops.fbgemm).
    This function will raise an exception if the schema is not compatible.
    Args:
        namespace (Callable): The namespace of the operator, e.g., torch.ops.fbgemm.
        op_name (str): The name of the operator.
        ref_schema_str (str): The reference schema in string format.
    Returns:
        None
    """
    op = getattr(namespace, op_name)
    schema = op._schemas[""]

    return _check_schema_compatibility(schema, ref_schema_str)


class StableRelease(TestSuite):  # pyre-ignore[11]
    def test_stable_schema(self) -> None:
        """
        Test the schema compatibility of the operators against stable schema.
        This is to ensure that any changes to the ops' schema do not break compatibility of the stable versions.
        This test will fail if the current op schema is not forward or backward compatible with the stable schema.
        """

        # Load stable ops from file into dict
        stable_dict_file = open(
            get_file_path_2("", os.path.dirname(__file__), "stable_ops.json")
        )
        stable_op_dict = json.load(stable_dict_file)["data"]
        stable_dict_file.close()
        # Get all op names
        stable_op_names = set(stable_op_dict.keys())

        # Check compatibility for all ops that are marked stable
        for full_op_name in stable_op_names:
            # Test the schema given the op name
            ref_schema_str = stable_op_dict[full_op_name]
            op_name = full_op_name.split(".")[3]

            check_schema_compatibility_from_op_name(
                torch.ops.fbgemm, op_name, ref_schema_str
            )

    def test_example_ops(self) -> None:
        """
        Test examples for schema compatibility.
        """

        # Load example ops to dict
        stable_dict_file = open(
            get_file_path_2("", os.path.dirname(__file__), "example.json")
        )
        op_dict = json.load(stable_dict_file)["data"]
        stable_dict_file.close()

        # Example op 1
        # Expect to pass
        check_schema_compatibility(
            fbgemm_gpu.sparse_ops.merge_pooled_embeddings,
            op_dict["merge_pooled_embeddings"],
        )

        # Example op 2
        # stable schema is: "dummy_func(str var1, int var2) -> ()"
        def dummy_func(var1: str, var2: int, var3: torch.Tensor) -> None:
            pass

        # Expect to fail
        with self.assertRaises(AssertionError):  # pyre-fixme[16]
            check_schema_compatibility(
                dummy_func,
                op_dict["dummy_func"],
            )

        # Example op 3
        # stable schema is: "dummy_func(str var1, int var2) -> ()"
        def dummy_func(var1: str, var2: int, var3: str = "default") -> None:
            pass

        # Expect to pass
        check_schema_compatibility(
            dummy_func,
            op_dict["dummy_func"],
        )


if __name__ == "__main__":
    unittest.main()