diff --git a/fbgemm_gpu/test/release/__init__.py b/fbgemm_gpu/test/release/__init__.py
new file mode 100644
index 000000000..a9fdb3b99
--- /dev/null
+++ b/fbgemm_gpu/test/release/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fbgemm_gpu/test/release/example.json b/fbgemm_gpu/test/release/example.json
new file mode 100644
index 000000000..50324b9a3
--- /dev/null
+++ b/fbgemm_gpu/test/release/example.json
@@ -0,0 +1,12 @@
+{
+    "_description": "This is a dict containing example schemas. The schemas of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1",
+    "_version": 1,
+    "data": {
+        "mx4_to_fp32":
+            "mx4_to_fp32(Tensor tensor, int group_size=32, bool use_triton=True, int ebits=2, int mbits=1) -> Tensor",
+        "merge_pooled_embeddings":
+            "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor",
+        "dummy_func":
+            "dummy_func(str var1, int var2) -> ()"
+    }
+}
diff --git a/fbgemm_gpu/test/release/stable_ops.json b/fbgemm_gpu/test/release/stable_ops.json
new file mode 100644
index 000000000..d5fe76a24
--- /dev/null
+++ b/fbgemm_gpu/test/release/stable_ops.json
@@ -0,0 +1,30 @@
+{
+    "_description": "This is a dict containing the schemas of FBGEMM_GPU ops that are marked as stable. The schemas of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1",
+    "_version": 1,
+    "data": {
+        "torch.ops.fbgemm.jagged_to_padded_dense":
+            "fbgemm::jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor",
+        "torch.ops.fbgemm.merge_pooled_embeddings":
+            "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor",
+        "torch.ops.fbgemm.permute_pooled_embs_auto_grad":
+            "fbgemm::permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor",
+        "torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf":
+            "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor",
+        "torch.ops.fbgemm.permute_2D_sparse_data":
+            "fbgemm::permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)",
+        "torch.ops.fbgemm.permute_1D_sparse_data":
+            "fbgemm::permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)",
+        "torch.ops.fbgemm.expand_into_jagged_permute":
+            "fbgemm::expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor",
+        "torch.ops.fbgemm.block_bucketize_sparse_features":
+            "fbgemm::block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)",
+        "torch.ops.fbgemm.asynchronous_complete_cumsum":
+            "fbgemm::asynchronous_complete_cumsum(Tensor t_in) -> Tensor",
+        "torch.ops.fbgemm.offsets_range":
+            "fbgemm::offsets_range(Tensor offsets, SymInt range_size) -> Tensor",
+        "torch.ops.fbgemm.segment_sum_csr":
+            "fbgemm::segment_sum_csr(SymInt batch_size, Tensor csr_seg, Tensor values) -> Tensor",
+        "torch.ops.fbgemm.keyed_jagged_index_select_dim1":
+            "fbgemm::keyed_jagged_index_select_dim1(Tensor values, Tensor lengths, Tensor offsets, Tensor indices, SymInt batch_size, Tensor? weights=None, SymInt? selected_lengths_sum=None) -> Tensor[]"
+    }
+}
diff --git a/fbgemm_gpu/test/release/stable_release_test.py b/fbgemm_gpu/test/release/stable_release_test.py
new file mode 100755
index 000000000..7055fd566
--- /dev/null
+++ b/fbgemm_gpu/test/release/stable_release_test.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import json
+import os
+import unittest
+from typing import Callable
+
+import fbgemm_gpu
+import fbgemm_gpu.permute_pooled_embedding_modules
+import fbgemm_gpu.sparse_ops
+
+import torch
+from torch._C import FunctionSchema, parse_schema
+from torch._utils_internal import get_file_path_2
+
+from .utils import infer_schema
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if open_source:
+    from test_utils import TestSuite  # pyre-fixme[21]
+
+else:
+    # pyre-fixme[21]
+    from fbgemm_gpu.test.test_utils import TestSuite
+
+
+def _check_schema_compatibility(
+    schema: FunctionSchema,
+    ref_schema_str: str,
+) -> None:
+    """
+    Check if the schema is forward and backward compatible with the reference schema.
+    This function will raise an exception if the schema is not compatible.
+
+    Args:
+        schema (FunctionSchema): The schema object to check.
+        ref_schema_str (str): The reference schema in string format.
+    Returns:
+        None
+    """
+    assert isinstance(schema, FunctionSchema)
+    ref_schema = parse_schema(ref_schema_str)
+    # pyre-fixme[16]
+    fwd_compatible = schema.check_forward_compatible_with(ref_schema)
+    # pyre-fixme[16]
+    bwd_compatible = schema.is_backward_compatible_with(ref_schema)
+    msg = ""
+    if not fwd_compatible:
+        msg += f"Schema of {schema} is not forward compatible with {ref_schema}\n"
+    # pyre-fixme[16]
+    if not bwd_compatible:
+        msg += f"Schema of {schema} is not backward compatible with {ref_schema}"
+    assert fwd_compatible and bwd_compatible, msg
+
+
+def check_schema_compatibility(
+    op: Callable,
+    ref_schema: str,
+) -> None:
+    """
+    Check if the schema of the given operator is forward and backward compatible with the reference schema.
+    This works with Python functions whose schemas do NOT have positional-only args, varargs, or varkwargs.
+    For ops registered via torch.ops.fbgemm and ops with *args and **kwargs, please use check_schema_compatibility_from_op_name.
+
+    Args:
+        op (Callable): The operator to check.
+        ref_schema (str): The reference schema in string format.
+    Returns:
+        None
+    """
+    op_schema = infer_schema(op, mutates_args={})
+    # pyre-fixme[16]
+    op_name = op.__name__
+    # Create the schema string
+    schema_str = f"{op_name}{op_schema}"
+    # Create a FunctionSchema from it
+    functional_schema = parse_schema(schema_str)
+
+    # Compare against the stable reference schema
+    return _check_schema_compatibility(functional_schema, ref_schema)
+
+
+def check_schema_compatibility_from_op_name(
+    namespace: Callable,
+    op_name: str,
+    ref_schema_str: str,
+) -> None:
+    """
+    Check if the schema of the given operator is forward and backward compatible with the reference schema.
+    Use this function to check registered ops (via torch.ops.fbgemm).
+    This function will raise an exception if the schema is not compatible.
+
+    Args:
+        namespace (Callable): The namespace of the operator, e.g., torch.ops.fbgemm.
+        op_name (str): The name of the operator.
+        ref_schema_str (str): The reference schema in string format.
+    Returns:
+        None
+    """
+    op = getattr(namespace, op_name)
+    schema = op._schemas[""]
+
+    return _check_schema_compatibility(schema, ref_schema_str)
+
+
+class StableRelease(TestSuite):  # pyre-ignore[11]
+    def test_stable_schema(self) -> None:
+        """
+        Test the schema compatibility of the operators against the stable schemas.
+        This is to ensure that any changes to the ops' schemas do not break compatibility with the stable versions.
+        This test will fail if the current op schema is not forward or backward compatible with the stable schema.
+        """
+
+        # Load stable ops from file into a dict
+        stable_dict_file = open(
+            get_file_path_2("", os.path.dirname(__file__), "stable_ops.json")
+        )
+        stable_op_dict = json.load(stable_dict_file)["data"]
+        stable_dict_file.close()
+        # Get all op names
+        stable_op_names = set(stable_op_dict.keys())
+
+        # Check compatibility for all ops that are marked stable
+        for full_op_name in stable_op_names:
+            # Test the schema given the op name
+            ref_schema_str = stable_op_dict[full_op_name]
+            op_name = full_op_name.split(".")[3]
+
+            check_schema_compatibility_from_op_name(
+                torch.ops.fbgemm, op_name, ref_schema_str
+            )
+
+    def test_example_ops(self) -> None:
+        """
+        Test examples for schema compatibility.
+        """
+
+        # Load example ops into a dict
+        stable_dict_file = open(
+            get_file_path_2("", os.path.dirname(__file__), "example.json")
+        )
+        op_dict = json.load(stable_dict_file)["data"]
+        stable_dict_file.close()
+
+        # Example op 1
+        # Expected to pass
+        check_schema_compatibility(
+            fbgemm_gpu.sparse_ops.merge_pooled_embeddings,
+            op_dict["merge_pooled_embeddings"],
+        )
+
+        # Example op 2
+        # The stable schema is: dummy_func(str var1, int var2) -> ()
+        def dummy_func(var1: str, var2: int, var3: torch.Tensor) -> None:
+            pass
+
+        # Expected to fail
+        with self.assertRaises(AssertionError):  # pyre-fixme[16]
+            check_schema_compatibility(
+                dummy_func,
+                op_dict["dummy_func"],
+            )
+
+        # Example op 3
+        # The stable schema is: dummy_func(str var1, int var2) -> ()
+        def dummy_func(var1: str, var2: int, var3: str = "default") -> None:
+            pass
+
+        # Expected to pass
+        check_schema_compatibility(
+            dummy_func,
+            op_dict["dummy_func"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py
new file mode 100644
index 000000000..005cf38b5
--- /dev/null
+++ b/fbgemm_gpu/test/release/utils.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import inspect
+import typing
+from typing import Iterable, List, Optional, Sequence, Union  # noqa: F401
+
+import torch
+from torch import device, dtype, Tensor, types
+
+from torch._library.infer_schema import (
+    derived_types,
+    parse_return,
+    supported_param,
+    SUPPORTED_PARAM_TYPES,
+    tuple_to_list,
+)
+
+# Temporary workaround for `infer_schema`
+
+# `get_supported_param_types` and `SUPPORTED_RETURN_TYPES` are modified from torch/_library/infer_schema.py
+# because `torch.library.infer_schema` infers any `int` to be `SymInt` in the schema and does not
+# support `str` as a return type, which may not reflect the actual signature of the function.
+# Other modifications are to address linter warnings.
+# The rest of the code is copied from `torch/_library/infer_schema.py`.
+# TODO: clean up and remove this when we implement our own
+
+
+def error_fn(what: str, sig: Optional[inspect.Signature] = None):
+    raise ValueError(f"infer_schema(func): {what} " f"Got func with signature {sig})")
+
+
+def convert_type_string(annotation_type: str):
+    try:
+        return eval(annotation_type)
+    except Exception:
+        error_fn(f"Unsupported type annotation {annotation_type}. It is not a type.")
+
+
+# Modified supported param types and return types from torch/_library/infer_schema.py
+def get_supported_param_types():
+    data = [
+        # (python type, schema type, type[] variant, type?[] variant, type[]? variant)
+        (Tensor, "Tensor", True, True, False),
+        (int, "int", True, False, True),
+        (float, "float", True, False, True),
+        (bool, "bool", True, False, True),
+        (str, "str", False, False, False),
+        (types.Number, "Scalar", True, False, False),
+        (dtype, "ScalarType", False, False, False),
+        (device, "Device", False, False, False),
+    ]
+    result = []
+    for line in data:
+        result.extend(derived_types(*line))
+    return dict(result)
+
+
+SUPPORTED_RETURN_TYPES = {
+    Tensor: "Tensor",
+    typing.List[Tensor]: "Tensor[]",
+    int: "int",
+    float: "float",
+    bool: "bool",
+    str: "str",
+    types.Number: "Scalar",
+}
+
+
+def check_param_annotation(name: str, annotation: type, sig: inspect.Signature):
+    if annotation is inspect.Parameter.empty:
+        error_fn(f"Parameter {name} must have a type annotation.", sig)
+
+    # The annotation might have been converted to a string;
+    # convert it back to the actual type.
+    annotation_type = annotation
+    if isinstance(annotation_type, str):
+        annotation_type = convert_type_string(annotation_type)
+
+    if annotation_type not in SUPPORTED_PARAM_TYPES.keys():
+        if annotation_type.__origin__ is tuple:
+            list_type = tuple_to_list(annotation_type)
+            example_type_str = "\n\n"
+            # Only suggest the list type if this type is supported.
+            if list_type in SUPPORTED_PARAM_TYPES.keys():
+                example_type_str = f"For example, {list_type}.\n\n"
+            error_fn(
+                f"Parameter {name} has unsupported type {annotation}. "
+                f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. "
+                f"{example_type_str}"
+                f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}.",
+                sig,
+            )
+        else:
+            error_fn(
+                f"Parameter {name} has unsupported type {annotation}. "
+                f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}.",
+                sig,
+            )
+    return annotation_type
+
+
+def get_schema_type(
+    schema_type: str,
+    mutates_args: Union[str, Iterable[str]],
+    name: str,
+    sig: inspect.Signature,
+    idx: int,
+):
+    if isinstance(mutates_args, str):
+        if mutates_args != "unknown":
+            raise ValueError(
+                "mutates_args must either be a sequence of the names of "
+                "the arguments that are mutated or the string 'unknown'. "
+            )
+        if schema_type.startswith("Tensor"):
+            schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
+    elif name in mutates_args:
+        if not schema_type.startswith("Tensor"):
+            error_fn(
+                f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
+            )
+        schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
+    return schema_type
+
+
+def check_mutates_args(
+    mutates_args: Union[str, Iterable[str]], sig: inspect.Signature, seen_args: set
+):
+    if mutates_args != "unknown":
+        mutates_args_not_seen = set(mutates_args) - seen_args
+        if len(mutates_args_not_seen) > 0:
+            error_fn(
+                f"{mutates_args_not_seen} in mutates_args were not found in "
+                f"the custom op's signature. "
+                f"mutates_args should contain the names of all args that the "
+                f"custom op mutates, or just the string 'unknown' if you don't know.",
+                sig,
+            )
+
+
+def get_return_annonation(
+    return_annotation: type,
+):
+    if isinstance(return_annotation, str):
+        return_annotation = convert_type_string(return_annotation)
+    return parse_return(return_annotation, error_fn)
+
+
+def infer_schema(
+    prototype_function: typing.Callable,
+    /,
+    *,
+    mutates_args,
+    op_name: Optional[str] = None,
+) -> str:
+    r"""
+    This is modified from torch._library.infer_schema.infer_schema.
+
+    Parses the schema of a given function with type hints. The schema is inferred from the
+    function's type hints, and can be used to define a new operator.
+
+    We make the following assumptions:
+
+    * None of the outputs alias any of the inputs or each other.
+    * | String type annotations "device, dtype, Tensor, types" without library specification are
+      | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
+      | without library specification are assumed to be typing.*.
+    * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
+      | it is assumed that all inputs to the operator are being mutated.
+
+    Callers (e.g. the custom ops API) are responsible for checking these assumptions.
+
+    Args:
+        prototype_function: The function from which to infer the schema, based on its type annotations.
+        op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the
+            name is not included in the inferred schema. Note that the input schema to
+            ``torch.library.Library.define`` requires an operator name.
+        mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.
+
+    Returns:
+        The inferred schema.
+
+    Example:
+        >>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
+        >>>     return x.sin()
+        >>>
+        >>> infer_schema(foo_impl, op_name="foo", mutates_args={})
+        foo(Tensor x) -> Tensor
+        >>>
+        >>> infer_schema(foo_impl, mutates_args={})
+        (Tensor x) -> Tensor
+    """
+    sig = inspect.signature(prototype_function)
+
+    params = []
+    seen_args = set()
+    saw_kwarg_only_arg = False
+    for idx, (name, param) in enumerate(sig.parameters.items()):
+        if not supported_param(param):
+            error_fn(
+                "We do not support positional-only args, varargs, or varkwargs.", sig
+            )
+
+        if param.kind == inspect.Parameter.KEYWORD_ONLY:
+            # The first time we see a kwarg-only arg, add "*" to the schema.
+            if not saw_kwarg_only_arg:
+                params.append("*")
+                saw_kwarg_only_arg = True
+
+        annotation_type = check_param_annotation(name, param.annotation, sig)
+
+        schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
+        schema_type = get_schema_type(schema_type, mutates_args, name, sig, idx)
+
+        seen_args.add(name)
+        if param.default is inspect.Parameter.empty:
+            params.append(f"{schema_type} {name}")
+        else:
+            default_repr = None
+            if param.default is None or isinstance(param.default, (int, float, bool)):
+                default_repr = str(param.default)
+            elif isinstance(param.default, (str, torch.device)):
+                default_repr = f'"{param.default}"'
+            elif isinstance(param.default, torch.dtype):
+                dtype_repr = str(param.default)
+                torch_dot = "torch."
+                assert dtype_repr.startswith(torch_dot)
+                default_repr = dtype_repr[len(torch_dot) :]
+            else:
+                error_fn(
+                    f"Parameter {name} has an unsupported default value type {type(param.default)}. "
+                    f"Please file an issue on GitHub so we can prioritize this.",
+                    sig,
+                )
+            params.append(f"{schema_type} {name}={default_repr}")
+    check_mutates_args(mutates_args, sig, seen_args)
+
+    ret = get_return_annonation(sig.return_annotation)
+    if op_name is not None:
+        return f"{op_name}({', '.join(params)}) -> {ret}"
+    return f"({', '.join(params)}) -> {ret}"
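
For context, a minimal standalone sketch (outside the files above) of the compatibility check that _check_schema_compatibility performs, assuming a PyTorch build that exposes torch._C.parse_schema and the is_backward_compatible_with / check_forward_compatible_with methods used by the test; the schema strings mirror the dummy_func example in example.json, and the expected outcomes follow the assertions in test_example_ops:

    # Standalone sketch of the schema-compatibility check performed above.
    # Assumes the installed PyTorch exposes parse_schema and the compatibility
    # methods used by the test; schema strings mirror the dummy_func example.
    from torch._C import parse_schema

    stable = parse_schema("dummy_func(str var1, int var2) -> ()")

    # Adding a trailing argument WITH a default is the case the test expects to pass.
    updated = parse_schema('dummy_func(str var1, int var2, str var3="default") -> ()')
    assert updated.is_backward_compatible_with(stable)
    assert updated.check_forward_compatible_with(stable)

    # Adding a required argument WITHOUT a default is the case the test expects to fail.
    broken = parse_schema("dummy_func(str var1, int var2, Tensor var3) -> ()")
    assert not (
        broken.is_backward_compatible_with(stable)
        and broken.check_forward_compatible_with(stable)
    )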