From c64b70646c62d3f2bdac8b5fb3f66c69cf99aa22 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Thu, 12 Sep 2024 16:13:24 -0700 Subject: [PATCH] Add schema compatibility test Summary: To ensure that changes to the ops are forward and backward compatible with the stable release, we add unit tests to test schema compatibility. **Usage**: ``` check_schema_compatibility_from_op_name( namespace: Callable, op_name: str ref_schema_str: str, ) check_schema_compatibility( op: Callable, ref_schema_str: str, ) ``` e.g., ``` check_schema_compatibility_from_op_name( torch.ops.fbgemm, "merge_pooled_embeddings", "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor" ) check_schema_compatibility( fbgemm_gpu.sparse_ops.merge_pooled_embeddings, "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", ) ``` Differential Revision: D61766648 --- fbgemm_gpu/example.json | 8 + fbgemm_gpu/test/release/example.json | 12 + fbgemm_gpu/test/release/stable_ops.json | 30 ++ .../test/release/stable_release_test.py | 187 ++++++++++++ fbgemm_gpu/test/release/utils.py | 285 ++++++++++++++++++ 5 files changed, 522 insertions(+) create mode 100644 fbgemm_gpu/example.json create mode 100644 fbgemm_gpu/test/release/example.json create mode 100644 fbgemm_gpu/test/release/stable_ops.json create mode 100755 fbgemm_gpu/test/release/stable_release_test.py create mode 100644 fbgemm_gpu/test/release/utils.py diff --git a/fbgemm_gpu/example.json b/fbgemm_gpu/example.json new file mode 100644 index 0000000000..89039e8412 --- /dev/null +++ b/fbgemm_gpu/example.json @@ -0,0 +1,8 @@ +{ + "_description": "This is a dict containing example schemas. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", + "_version": 1, + "data": { + "mx4_to_fp32": + "mx4_to_fp32(Tensor tensor, SymInt group_size=32, bool use_triton=True, SymInt ebits=2, SymInt mbits=1) -> Tensor" + } + } diff --git a/fbgemm_gpu/test/release/example.json b/fbgemm_gpu/test/release/example.json new file mode 100644 index 0000000000..50324b9a30 --- /dev/null +++ b/fbgemm_gpu/test/release/example.json @@ -0,0 +1,12 @@ +{ + "_description": "This is a dict containing example schemas. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", + "_version": 1, + "data": { + "mx4_to_fp32": + "mx4_to_fp32(Tensor tensor, int group_size=32, bool use_triton=True, int ebits=2, int mbits=1) -> Tensor", + "merge_pooled_embeddings": + "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", + "dummy_func": + "dummy_func(str var1, int var2) -> ()" + } +} diff --git a/fbgemm_gpu/test/release/stable_ops.json b/fbgemm_gpu/test/release/stable_ops.json new file mode 100644 index 0000000000..d5fe76a247 --- /dev/null +++ b/fbgemm_gpu/test/release/stable_ops.json @@ -0,0 +1,30 @@ +{ + "_description": "This is a dict containing schema of FBGEMM_GPU ops that are marked as stable. The schema of future releases need to be backward and forward compatible. 
For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", + "_version": 1, + "data": { + "torch.ops.fbgemm.jagged_to_padded_dense": + "fbgemm::jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor", + "torch.ops.fbgemm.merge_pooled_embeddings": + "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor", + "torch.ops.fbgemm.permute_pooled_embs_auto_grad": + "fbgemm::permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor", + "torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf": + "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor", + "torch.ops.fbgemm.permute_2D_sparse_data": + "fbgemm::permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + "torch.ops.fbgemm.permute_1D_sparse_data": + "fbgemm::permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + "torch.ops.fbgemm.expand_into_jagged_permute": + "fbgemm::expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor", + "torch.ops.fbgemm.block_bucketize_sparse_features": + "fbgemm::block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)", + "torch.ops.fbgemm.asynchronous_complete_cumsum": + "fbgemm::asynchronous_complete_cumsum(Tensor t_in) -> Tensor", + "torch.ops.fbgemm.offsets_range": + "fbgemm::offsets_range(Tensor offsets, SymInt range_size) -> Tensor", + "torch.ops.fbgemm.segment_sum_csr": + "fbgemm::segment_sum_csr(SymInt batch_size, Tensor csr_seg, Tensor values) -> Tensor", + "torch.ops.fbgemm.keyed_jagged_index_select_dim1": + "fbgemm::keyed_jagged_index_select_dim1(Tensor values, Tensor lengths, Tensor offsets, Tensor indices, SymInt batch_size, Tensor? weights=None, SymInt? selected_lengths_sum=None) -> Tensor[]" + } +} diff --git a/fbgemm_gpu/test/release/stable_release_test.py b/fbgemm_gpu/test/release/stable_release_test.py new file mode 100755 index 0000000000..466bba2eab --- /dev/null +++ b/fbgemm_gpu/test/release/stable_release_test.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import json +import os +import unittest +from typing import Callable, Tuple + +# pyre-fixme[21] +import fbgemm_gpu +import fbgemm_gpu.permute_pooled_embedding_modules +import fbgemm_gpu.sparse_ops +import torch +from torch._C import FunctionSchema, parse_schema +from torch._utils_internal import get_file_path_2 + +from . import utils + +# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if open_source:
+    # pyre-ignore[21]
+    from test_utils import TestSuite
+
+else:
+    # pyre-fixme[21]
+    from fbgemm_gpu.test.test_utils import TestSuite
+
+
+def _check_schema_compatibility(
+    schema: FunctionSchema,
+    ref_schema_str: str,
+) -> None:
+    """
+    Check if the schema is forward and backward compatible with the reference schema.
+    This function raises an AssertionError if the schemas are not compatible.
+
+    Args:
+        schema (FunctionSchema): The schema object to check.
+        ref_schema_str (str): The reference schema in string format.
+    Returns:
+        None
+    """
+    assert isinstance(schema, FunctionSchema)
+    ref_schema = parse_schema(ref_schema_str)
+    # pyre-fixme[16]
+    fwd_compatible = schema.check_forward_compatible_with(ref_schema)
+    # pyre-fixme[16]
+    bwd_compatible = schema.is_backward_compatible_with(ref_schema)
+    msg = ""
+    if not fwd_compatible:
+        msg += f"Schema of {schema} is not forward compatible with {ref_schema}\n"
+    # pyre-fixme[16]
+    if not bwd_compatible:
+        msg += f"Schema of {schema} is not backward compatible with {ref_schema}"
+    assert fwd_compatible and bwd_compatible, msg
+
+
+def check_schema_compatibility(
+    op: Callable,
+    ref_schema: str,
+) -> None:
+    """
+    Check if the schema of the given operator is forward and backward compatible with the reference schema.
+    This works with Python functions whose schemas do NOT have positional-only args, varargs, or varkwargs.
+    For ops registered via torch.ops.fbgemm and ops with *args and **kwargs, please use check_schema_compatibility_from_op_name.
+
+    Args:
+        op (Callable): The operator to check.
+        ref_schema (str): The reference schema in string format.
+    Returns:
+        None
+    """
+    op_schema = utils.infer_schema(op, mutates_args={})
+    # pyre-fixme[16]
+    op_name = op.__name__
+    # Create the schema string
+    schema_str = f"{op_name}{op_schema}"
+    # Create a FunctionSchema
+    functional_schema = parse_schema(schema_str)
+
+    # Compare against the reference (stable) schema
+    return _check_schema_compatibility(functional_schema, ref_schema)
+
+
+def check_schema_compatibility_from_op_name(
+    namespace: Callable,
+    op_name: str,
+    ref_schema_str: str,
+) -> None:
+    """
+    Check if the schema of the given operator is forward and backward compatible with the reference schema.
+    Use this function to check registered ops (via torch.ops.fbgemm).
+    This function raises an AssertionError if the schemas are not compatible.
+
+    Args:
+        namespace (Callable): The namespace of the operator, e.g., torch.ops.fbgemm.
+        op_name (str): The name of the operator.
+        ref_schema_str (str): The reference schema in string format.
+    Returns:
+        None
+    """
+    op = getattr(namespace, op_name)
+    schema = op._schemas[""]
+
+    return _check_schema_compatibility(schema, ref_schema_str)
+
+
+class StableRelease(TestSuite):  # pyre-ignore[11]
+    def test_stable_schema(self) -> None:
+        """
+        Test the schema compatibility of the operators against the stable schemas.
+        This is to ensure that any changes to the ops' schemas do not break compatibility with the stable versions.
+        This test will fail if the current op schema is not forward or backward compatible with the stable schema.
+ """ + + # Load stable ops from file into dict + stable_dict_file = open( + get_file_path_2("", os.path.dirname(__file__), "stable_ops.json") + ) + stable_op_dict = json.load(stable_dict_file)["data"] + stable_dict_file.close() + # Get all op names + stable_op_names = set(stable_op_dict.keys()) + + # Check compatibility for all ops that are marked stable + for full_op_name in stable_op_names: + # Test the schema given the op name + ref_schema_str = stable_op_dict[full_op_name] + op_name = full_op_name.split(".")[3] + + check_schema_compatibility_from_op_name( + torch.ops.fbgemm, op_name, ref_schema_str + ) + + def test_example_ops(self) -> None: + """ + Test examples for schema compatibility. + """ + + # Load example ops to dict + stable_dict_file = open( + get_file_path_2("", os.path.dirname(__file__), "example.json") + ) + op_dict = json.load(stable_dict_file)["data"] + stable_dict_file.close() + + # Example op 1 + # Expect to pass + check_schema_compatibility( + fbgemm_gpu.sparse_ops.merge_pooled_embeddings, + op_dict["merge_pooled_embeddings"], + ) + + # Example op 2 + # stable schema is: dummy_func(str var1, int var2) -> ()" + def dummy_func(var1: str, var2: int, var3: torch.Tensor) -> None: + print("") + + # Expect to fail + with self.assertRaises(AssertionError): # pyre-fixme[16] + check_schema_compatibility( + dummy_func, + op_dict["dummy_func"], + ) + + # Example op 3 + # stable schema is: dummy_func(str var1, int var2) -> ()" + def dummy_func(var1: str, var2: int, var3: str = "default") -> None: + print("") + + # Expect to pass + check_schema_compatibility( + dummy_func, + op_dict["dummy_func"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py new file mode 100644 index 0000000000..9dce300f4e --- /dev/null +++ b/fbgemm_gpu/test/release/utils.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import typing +from typing import List, Optional, Sequence, Union # noqa: F401 + +import torch +from torch import device, dtype, Tensor, types + + +# Temporary work around for `infer_schema` + +# `get_supported_param_types` and `SUPPORTED_RETURN_TYPES` are modified from torch/_library/infer_schema.py +# as `torch.library.infer_schema` infers any `int` to be `SymInt` in the schema and does not +# support `str` as return type, which may not reflect the actual signature of the function. +# Other modifications are to address linter warning. +# The rest of the code is copied from `torch/_library/infer_schema.py` +# TO DO: clean up and remove this when we implement our own + + +def infer_schema( + prototype_function: typing.Callable, + /, + *, + mutates_args, + op_name: Optional[str] = None, +) -> str: + r"""Parses the schema of a given function with type hints. The schema is inferred from the + function's type hints, and can be used to define a new operator. + + We make the following assumptions: + + * None of the outputs alias any of the inputs or each other. + * | String type annotations "device, dtype, Tensor, types" without library specification are + | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + | without library specification are assumed to be typing.*. + * | Only the args listed in ``mutates_args`` are being mutated. 
If ``mutates_args`` is "unknown", + | it assumes that all inputs to the operator are being mutates. + + Callers (e.g. the custom ops API) are responsible for checking these assumptions. + + Args: + prototype_function: The function from which to infer a schema for from its type annotations. + op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the + name is not included in the inferred schema. Note that the input schema to + ``torch.library.Library.define`` requires a operator name. + mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function. + + Returns: + The inferred schema. + + Example: + >>> def foo_impl(x: torch.Tensor) -> torch.Tensor: + >>> return x.sin() + >>> + >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) + foo(Tensor x) -> Tensor + >>> + >>> infer_schema(foo_impl, mutates_args={}) + (Tensor x) -> Tensor + """ + UNKNOWN_MUTATES = "unknown" + sig = inspect.signature(prototype_function) + + def error_fn(what): + raise ValueError( + f"infer_schema(func): {what} " f"Got func with signature {sig})" + ) + + def convert_type_string(annotation_type: str): + try: + return eval(annotation_type) + except Exception as e: + error_fn( + f"Unsupported type annotation {annotation_type}. It is not a type." + ) + + params = [] + seen_args = set() + saw_kwarg_only_arg = False + for idx, (name, param) in enumerate(sig.parameters.items()): + if not supported_param(param): + error_fn("We do not support positional-only args, varargs, or varkwargs.") + + if param.kind == inspect.Parameter.KEYWORD_ONLY: + # The first time we see a kwarg-only arg, add "*" to the schema. + if not saw_kwarg_only_arg: + params.append("*") + saw_kwarg_only_arg = True + + if param.annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.") + + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type = param.annotation + if type(annotation_type) == str: + annotation_type = convert_type_string(annotation_type) + + if annotation_type not in SUPPORTED_PARAM_TYPES.keys(): + if annotation_type.__origin__ is tuple: + list_type = tuple_to_list(annotation_type) + example_type_str = "\n\n" + # Only suggest the list type if this type is supported. + if list_type in SUPPORTED_PARAM_TYPES.keys(): + example_type_str = f"For example, {list_type}.\n\n" + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. " + f"{example_type_str}" + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + else: + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] + if type(mutates_args) == str: + if mutates_args != UNKNOWN_MUTATES: + raise ValueError( + "mutates_args must either be a sequence of the names of " + "the arguments that are mutated or the string 'unknown'. 
" + ) + if schema_type.startswith("Tensor"): + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + elif name in mutates_args: + if not schema_type.startswith("Tensor"): + error_fn( + f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" + ) + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + seen_args.add(name) + if param.default is inspect.Parameter.empty: + params.append(f"{schema_type} {name}") + else: + default_repr = None + if param.default is None or isinstance(param.default, (int, float, bool)): + default_repr = str(param.default) + elif isinstance(param.default, (str, torch.device)): + default_repr = f'"{param.default}"' + elif isinstance(param.default, torch.dtype): + dtype_repr = str(param.default) + torch_dot = "torch." + assert dtype_repr.startswith(torch_dot) + default_repr = dtype_repr[len(torch_dot) :] + else: + error_fn( + f"Parameter {name} has an unsupported default value type {type(param.default)}. " + f"Please file an issue on GitHub so we can prioritize this." + ) + params.append(f"{schema_type} {name}={default_repr}") + if mutates_args != UNKNOWN_MUTATES: + mutates_args_not_seen = set(mutates_args) - seen_args + if len(mutates_args_not_seen) > 0: + error_fn( + f"{mutates_args_not_seen} in mutates_args were not found in " + f"the custom op's signature. " + f"mutates_args should contain the names of all args that the " + f"custom op mutates, or just the string 'unknown' if you don't know." + ) + return_annotation = sig.return_annotation + if type(return_annotation) == str: + return_annotation = convert_type_string(return_annotation) + ret = parse_return(return_annotation, error_fn) + if op_name is not None: + return f"{op_name}({', '.join(params)}) -> {ret}" + return f"({', '.join(params)}) -> {ret}" + + +def derived_types( + base_type, cpp_type, list_base, optional_base_list, optional_list_base +): + result = [ + (base_type, cpp_type), + (typing.Optional[base_type], f"{cpp_type}?"), + ] + + def derived_seq_types(typ): + return [ + typing.Sequence[typ], # type: ignore[valid-type] + typing.List[typ], # type: ignore[valid-type] + ] + + if list_base: + for seq_typ in derived_seq_types(base_type): + result.append((seq_typ, f"{cpp_type}[]")) # type: ignore[valid-type] + if optional_base_list: + for seq_typ in derived_seq_types(typing.Optional[base_type]): + result.append((seq_typ, f"{cpp_type}?[]")) # type: ignore[valid-type] + if optional_list_base: + for seq_typ in derived_seq_types(base_type): # type: ignore[valid-type] + result.append((typing.Optional[seq_typ], f"{cpp_type}[]?")) # type: ignore[valid-type] + return result + + +# Modified support param types and return types from torch/_library/infer_schema.py +def get_supported_param_types(): + data = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? 
variant + (Tensor, "Tensor", True, True, False), + (int, "int", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (types.Number, "Scalar", True, False, False), + (dtype, "ScalarType", False, False, False), + (device, "Device", False, False, False), + ] + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + Tensor: "Tensor", + typing.List[Tensor]: "Tensor[]", + int: "int", + float: "float", + bool: "bool", + str: "str", + types.Number: "Scalar", +} + + +def parse_return(annotation, error_fn): + if annotation is None: + return "()" + + if annotation is inspect.Parameter.empty: + error_fn("No return type annotation was provided. Please add one.") + + origin = typing.get_origin(annotation) + if origin is not tuple: + if annotation not in SUPPORTED_RETURN_TYPES.keys(): + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + return SUPPORTED_RETURN_TYPES[annotation] + + args = typing.get_args(annotation) + for arg in args: + if arg not in SUPPORTED_RETURN_TYPES: + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + + return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")" + + +SUPPORTED_PARAM_TYPES = get_supported_param_types() + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.List]: + """ + Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is typing.Tuple type. + """ + type_args = getattr(tuple_type, "__args__", None) + # Account for different python versions, e.g. python 3.8 would give () + # but python 3.12 would give None. + if tuple_type is typing.Tuple or type_args == () or type_args is None: + # Handle the case of an empty tuple type + return typing.List[typing.Any] + elif len(type_args) == 1: + # General case: create a List with the same type arguments + return typing.List[type_args[0]] # type: ignore[valid-type] + elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type] + return typing.List[type_args[0]] # type: ignore[valid-type] + else: + return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc]
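
For quick reference, here is a minimal sketch of how the two helpers added in `stable_release_test.py` are called, mirroring the usage examples in the summary above. It assumes the helpers are in scope (e.g., when run from within that test module); the reference schema strings are the ones shown in the summary.

```python
# Minimal usage sketch (assumed to run inside stable_release_test.py, where
# check_schema_compatibility and check_schema_compatibility_from_op_name are
# defined and fbgemm_gpu.sparse_ops has been imported to register the ops).
import torch
import fbgemm_gpu.sparse_ops

# Registered op: look up the schema under torch.ops.fbgemm by name and
# compare it against the stable reference schema string.
check_schema_compatibility_from_op_name(
    torch.ops.fbgemm,
    "merge_pooled_embeddings",
    "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor",
)

# Plain Python function: the schema is inferred from its type annotations
# via utils.infer_schema and then compared the same way.
check_schema_compatibility(
    fbgemm_gpu.sparse_ops.merge_pooled_embeddings,
    "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor",
)

# Both helpers raise an AssertionError if the current schema is not
# forward- and backward-compatible with the reference schema.
```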