From 90ba6e48ce139013791e231f4104a9b6e430cf6b Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Mon, 26 Feb 2024 13:41:56 -0800
Subject: [PATCH] Add arg checks to `input_combine_with_length`

Summary:
- Add arg checks to `input_combine_with_length`

Reviewed By: sryap

Differential Revision: D54138774

fbshipit-source-id: 26dce6bc8939124594aad0770601915100b11f21
---
 .../include/fbgemm_gpu/sparse_ops_utils.h     |   8 ++
 .../input_combine_ops/input_combine_cpu.cpp   |  22 ++-
 .../input_combine_ops/input_combine_gpu.cpp   |  30 +++-
 fbgemm_gpu/test/combine/__init__.py           |   6 +
 fbgemm_gpu/test/combine/bad_inputs_test.py    |  79 ++++++++++
 fbgemm_gpu/test/combine/common.py             | 126 ++++++++++++++++
 fbgemm_gpu/test/combine/failures_dict.json    |  85 +++++++++++
 .../test/{ => combine}/input_combine_test.py  | 135 +----------------
 fbgemm_gpu/test/failures_dict.json            |  79 ----------
 9 files changed, 350 insertions(+), 220 deletions(-)
 create mode 100644 fbgemm_gpu/test/combine/__init__.py
 create mode 100644 fbgemm_gpu/test/combine/bad_inputs_test.py
 create mode 100644 fbgemm_gpu/test/combine/common.py
 create mode 100644 fbgemm_gpu/test/combine/failures_dict.json
 rename fbgemm_gpu/test/{ => combine}/input_combine_test.py (64%)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
index b6449134f..b564d77be 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
@@ -216,6 +216,14 @@ inline bool torch_tensor_empty_or_on_cpu_check(
       " and ",                                      \
       (y).numel())
 
+#define TENSOR_NUMEL_IS_GT(ten, num)                \
+  TORCH_CHECK(                                      \
+      (ten).numel() > (num),                        \
+      "Tensor '" #ten "' must have more than " #num \
+      " element(s). "                               \
+      "Found ",                                     \
+      (ten).numel())
+
 template <typename... Tensors>
 std::string tensor_on_same_gpu_if_not_optional_check(
     const std::string& var_names_str,
diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp b/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
index d4a64ce61..eff9d48f2 100644
--- a/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
+++ b/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
@@ -6,10 +6,6 @@
  * LICENSE file in the root directory of this source tree.
 */
 
-#include "fbgemm_gpu/dispatch_macros.h"
-#include "fbgemm_gpu/input_combine.h"
-#include "fbgemm_gpu/sparse_ops_utils.h"
-
 #include
 #include
 #include
@@ -21,6 +17,10 @@
 #include
 #include
 
+#include "fbgemm_gpu/dispatch_macros.h"
+#include "fbgemm_gpu/input_combine.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+
 using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
@@ -204,9 +204,23 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cpu(
     const std::vector<Tensor>& indices_list,
     const std::vector<Tensor>& lengths_list,
     const std::vector<Tensor>& per_sample_weights) {
+  // The list sizes should be non-zero and match
   TORCH_CHECK_GT(indices_list.size(), 0);
   TORCH_CHECK_EQ(lengths_list.size(), indices_list.size());
   TORCH_CHECK_EQ(per_sample_weights.size(), indices_list.size());
+
+  // Either the corresponding weights are provided for all indices tensors, or
+  // none are provided for any of the indices tensors
+  {
+    const auto nonempty_weights = static_cast<size_t>(std::count_if(
+        per_sample_weights.begin(),
+        per_sample_weights.end(),
+        [](const auto& ten) { return ten.numel() > 0; }));
+    TORCH_CHECK(
+        nonempty_weights == 0 || nonempty_weights == indices_list.size(),
+        "Either all weights tensors should be empty, or all should be non-empty");
+  }
+
   int64_t total_indices = 0;
   int64_t total_lengths = 0;
   bool need_weights = false;
diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp b/fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp
index 4aa90d01a..9b8225785 100644
--- a/fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp
+++ b/fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp
@@ -6,12 +6,11 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include "fbgemm_gpu/input_combine.h"
-#include "fbgemm_gpu/sparse_ops_utils.h"
-
 #include
 #include
 #include
+#include "fbgemm_gpu/input_combine.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
 
 using Tensor = at::Tensor;
@@ -62,10 +61,24 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_gpu(
     const std::vector<Tensor>& indices_list,
     const std::vector<Tensor>& lengths_list,
     const std::vector<Tensor>& per_sample_weights) {
+  // The list sizes should be non-zero and match
+  TORCH_CHECK_GT(indices_list.size(), 0);
+  TORCH_CHECK_EQ(lengths_list.size(), indices_list.size());
+  TORCH_CHECK_EQ(per_sample_weights.size(), indices_list.size());
+
+  // Either the corresponding weights are provided for all indices tensors, or
+  // none are provided for any of the indices tensors
+  {
+    const auto nonempty_weights = static_cast<size_t>(std::count_if(
+        per_sample_weights.begin(),
+        per_sample_weights.end(),
+        [](const auto& ten) { return ten.numel() > 0; }));
+    TORCH_CHECK(
+        nonempty_weights == 0 || nonempty_weights == indices_list.size(),
+        "Either all weights tensors should be empty, or all should be non-empty");
+  }
+
   const auto num_lists = indices_list.size();
-  TORCH_CHECK_GT(num_lists, 0);
-  TORCH_CHECK_EQ(lengths_list.size(), num_lists);
-  TORCH_CHECK_EQ(per_sample_weights.size(), num_lists);
   const bool need_weights = std::any_of(
       per_sample_weights.begin(), per_sample_weights.end(), [](const auto& x) {
         return x.numel() > 0;
@@ -143,14 +156,19 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_gpu(
     }
     const auto& indices = indices_list[i];
     const auto& lengths = lengths_list[i];
+
+    // Tensors are contiguous, on same device, and have same dtype
    TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(indices);
     TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(lengths);
     TENSORS_ON_SAME_DEVICE(indices, indices_0);
     TENSORS_ON_SAME_DEVICE(lengths, indices_0);
     TORCH_CHECK(indices.dtype() == c10::kInt || indices.dtype() == c10::kLong);
     TORCH_CHECK(lengths.dtype() == c10::kInt || lengths.dtype() == c10::kLong);
+    // Dimensions must be 1
     TENSOR_NDIM_EQUALS(indices, 1);
     TENSOR_NDIM_EQUALS(lengths, 1);
+    // Indices must be non-empty
+    TENSOR_NUMEL_IS_GT(indices, 0);
 
     const auto indices_numel = indices.numel();
     const auto lengths_numel = lengths.numel();
diff --git a/fbgemm_gpu/test/combine/__init__.py b/fbgemm_gpu/test/combine/__init__.py
new file mode 100644
index 000000000..a9fdb3b99
--- /dev/null
+++ b/fbgemm_gpu/test/combine/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fbgemm_gpu/test/combine/bad_inputs_test.py b/fbgemm_gpu/test/combine/bad_inputs_test.py
new file mode 100644
index 000000000..88c7b3b5f
--- /dev/null
+++ b/fbgemm_gpu/test/combine/bad_inputs_test.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from hypothesis import given, settings
+
+from .common import open_source
+
+if open_source:
+    # pyre-ignore[21]
+    from test_utils import cpu_and_maybe_gpu, optests
+else:
+    from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, optests
+
+
+@optests.generate_opcheck_tests()
+class BadInputsTest(unittest.TestCase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @given(device=cpu_and_maybe_gpu())
+    @settings(deadline=None)
+    def test_tbe_input_combine_with_length_bad_args(self, device: torch.device) -> None:
+        arg0_list = [
+            [88, 55],
+            [80, 29],
+            [2, 85],
+            [39, 51],
+            [84, 35],
+            [12, 6],
+            [94, 43],
+            [98, 59],
+            [19, 68],
+            [97, 89],
+        ]
+        arg0 = [torch.tensor(t, dtype=torch.int32, device=device) for t in arg0_list]
+
+        arg1_list = [
+            [1, 2],
+            [1, 2],
+            [1, 2],
+            [1, 2],
+            [1, 2],
+            [1, 2],
+            [1, 2],
+            [1, 2],
+            [1, 2],
+            [1, 2],
+        ]
+        arg1 = [torch.tensor(t, dtype=torch.int32, device=device) for t in arg1_list]
+
+        arg2_list = [
+            [],
+            [],
+            [],
+            [],
+            [3.0, 3.0],
+            [],
+            [],
+            [3.0, 3.0],
+            [3.0, 3.0],
+            [],
+        ]
+        arg2 = [torch.tensor(t, dtype=torch.float, device=device) for t in arg2_list]
+
+        with self.assertRaises(RuntimeError):
+            torch.ops.fbgemm.tbe_input_combine_with_length(
+                arg0,
+                arg1,
+                arg2,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fbgemm_gpu/test/combine/common.py b/fbgemm_gpu/test/combine/common.py
new file mode 100644
index 000000000..996d1d689
--- /dev/null
+++ b/fbgemm_gpu/test/combine/common.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-ignore-all-errors[56]
+
+from typing import List, Optional, Tuple
+
+import fbgemm_gpu
+import torch
+from fbgemm_gpu import sparse_ops  # noqa: F401
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if not open_source:
+    if torch.version.hip:
+        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip")
+    else:
+        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine")
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")
+
+
+class TBEInputPrepareReference(torch.nn.Module):
+    def __init__(self, include_last_offsets: List[bool]) -> None:
+        super().__init__()
+        self.include_last_offsets = include_last_offsets
+
+    def forward(  # noqa C901
+        self,
+        indices_list: List[torch.Tensor],
+        offsets_list: List[torch.Tensor],
+        per_sample_weights_list: List[torch.Tensor],
+        batch_size: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        size = 0
+        assert len(indices_list) > 0
+        assert len(indices_list) == len(offsets_list)
+        assert len(indices_list) == len(per_sample_weights_list)
+        assert len(indices_list) == len(self.include_last_offsets)
+        for i in range(len(self.include_last_offsets)):
+            size += indices_list[i].size(0)
+            assert indices_list[i].dim() == 1
+            assert offsets_list[i].dim() == 1
+            if per_sample_weights_list[i].numel() > 0:
+                assert per_sample_weights_list[i].dim() == 1
+                assert indices_list[i].numel() == per_sample_weights_list[i].numel()
+        combined_indices = torch.empty(
+            size,
+            dtype=torch.int32,
+            device=indices_list[0].device,
+        )
+        torch.cat(indices_list, out=combined_indices)
+        offsets_starts = torch.zeros(
+            [len(offsets_list) + 1],
+            dtype=offsets_list[0].dtype,
+            device=offsets_list[0].device,
+        )
+        offsets_accs = torch.zeros(
+            [len(offsets_list) + 1],
+            dtype=offsets_list[0].dtype,
+            device=offsets_list[0].device,
+        )
+
+        for i, include_last_offset in enumerate(self.include_last_offsets):
+            if include_last_offset:
+                offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0) - 1
+            else:
+                offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0)
+            offsets_accs[i + 1] = offsets_accs[i] + indices_list[i].size(0)
+
+        assert offsets_accs[-1] == combined_indices.size(0)
+        combined_offsets_size: List[int] = (
+            [int(offsets_starts[-1].item()) + 1]
+            if batch_size is None
+            else [batch_size * len(offsets_list) + 1]
+        )
+        combined_offsets = torch.zeros(
+            combined_offsets_size,
+            dtype=torch.int32,
+            device=offsets_list[0].device,
+        )
+        if batch_size is None:
+            for i in range(len(self.include_last_offsets)):
+                combined_offsets[offsets_starts[i] : offsets_starts[i + 1]] = (
+                    offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
+                    + offsets_accs[i]
+                )
+        else:
+            for i in range(len(self.include_last_offsets)):
+                cur_start = batch_size * i
+                combined_offsets[
+                    cur_start : cur_start + offsets_starts[i + 1] - offsets_starts[i]
+                ] = (
+                    offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
+                    + offsets_accs[i]
+                )
+                cur_start = cur_start + offsets_starts[i + 1] - offsets_starts[i]
+                for j in range(batch_size - offsets_starts[i + 1] + offsets_starts[i]):
+                    combined_offsets[cur_start + j] = (
+                        indices_list[i].numel() + offsets_accs[i]
+                    )
+        combined_offsets[-1] = offsets_accs[-1]
+        per_sample_weights: Optional[torch.Tensor] = None
+        for i in range(len(self.include_last_offsets)):
+            if per_sample_weights_list[i].size(0) > 0:
+                per_sample_weights = torch.ones(
+                    combined_indices.size(0),
+                    dtype=per_sample_weights_list[i].dtype,
+                    device=per_sample_weights_list[i].device,
+                )
+                break
+        if per_sample_weights is not None:
+            for i in range(len(self.include_last_offsets)):
+                if per_sample_weights_list[i].size(0) > 0:
+                    # fmt: off
+                    per_sample_weights[offsets_accs[i] : offsets_accs[i + 1]] = (
+                        per_sample_weights_list[i][:]
+                    )
+                    # fmt: on
+
+        # indices and offsets are required to be int32 for TBE
+        return combined_indices, combined_offsets, per_sample_weights
diff --git a/fbgemm_gpu/test/combine/failures_dict.json b/fbgemm_gpu/test/combine/failures_dict.json
new file mode 100644
index 000000000..3146d2f8c
--- /dev/null
+++ b/fbgemm_gpu/test/combine/failures_dict.json
@@ -0,0 +1,85 @@
+{
+    "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit",
+    "_version": 1,
+    "data": {
+        "fbgemm::padding_fused_tbe_input_combine": {
+            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix": {
+                "comment": "",
+                "status": "xfail"
+            }
+        },
+        "fbgemm::padding_fused_tbe_input_combine_with_length": {
+            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32_with_length": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64_with_length": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix_with_length": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32_with_length": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64_with_length": {
+                "comment": "",
+                "status": "xfail"
+            },
+            "InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix_with_length": {
+                "comment": "",
+                "status": "xfail"
+            }
+        },
+        "fbgemm::tbe_input_combine": {},
+        "fbgemm::tbe_input_combine_with_length": {
+            "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int32_with_length": {
+                "comment": "",
+                "status": "xsuccess"
+            },
+            "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int64_with_length": {
+                "comment": "",
+                "status": "xsuccess"
+            },
+            "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_mix_with_length": {
+                "comment": "",
+                "status": "xsuccess"
+            },
+            "InputCombineTest.test_faketensor__test_input_combine_int32_with_length": {
+                "comment": "",
+                "status": "xsuccess"
+            },
+            "InputCombineTest.test_faketensor__test_input_combine_int64_with_length": {
+                "comment": "",
+                "status": "xsuccess"
+            },
+            "InputCombineTest.test_faketensor__test_input_combine_mix_with_length": {
+                "comment": "",
+                "status": "xsuccess"
+            }
+        }
+    }
+}
diff --git a/fbgemm_gpu/test/input_combine_test.py b/fbgemm_gpu/test/combine/input_combine_test.py
similarity index 64%
rename from fbgemm_gpu/test/input_combine_test.py
rename to fbgemm_gpu/test/combine/input_combine_test.py
index 883472714..597e949b2 100644
--- a/fbgemm_gpu/test/input_combine_test.py
+++ b/fbgemm_gpu/test/combine/input_combine_test.py
@@ -5,150 +5,23 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import unittest
-from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
-from fbgemm_gpu import sparse_ops  # noqa: F401
 from hypothesis import given, settings
 
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
+from .common import open_source, TBEInputPrepareReference
 
+if open_source:
     # pyre-ignore[21]
     from test_utils import cpu_and_maybe_gpu, optests
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine")
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")
+else:
     from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, optests
 
 DEFAULT_DEVICE = torch.device("cpu")
 
 
-class TBEInputPrepareReference(torch.nn.Module):
-    def __init__(self, include_last_offsets: List[bool]) -> None:
-        super().__init__()
-        self.include_last_offsets = include_last_offsets
-
-    def forward(  # noqa C901
-        self,
-        indices_list: List[torch.Tensor],
-        offsets_list: List[torch.Tensor],
-        per_sample_weights_list: List[torch.Tensor],
-        batch_size: Optional[int] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        size = 0
-        assert len(indices_list) > 0
-        assert len(indices_list) == len(offsets_list)
-        assert len(indices_list) == len(per_sample_weights_list)
-        assert len(indices_list) == len(self.include_last_offsets)
-        for i in range(len(self.include_last_offsets)):
-            size += indices_list[i].size(0)
-            assert indices_list[i].dim() == 1
-            assert offsets_list[i].dim() == 1
-            if per_sample_weights_list[i].numel() > 0:
-                assert per_sample_weights_list[i].dim() == 1
-                assert indices_list[i].numel() == per_sample_weights_list[i].numel()
-        combined_indices = torch.empty(
-            size,
-            dtype=torch.int32,
-            device=indices_list[0].device,
-        )
-        torch.cat(indices_list, out=combined_indices)
-        offsets_starts = torch.zeros(
-            [len(offsets_list) + 1],
-            dtype=offsets_list[0].dtype,
-            device=offsets_list[0].device,
-        )
-        offsets_accs = torch.zeros(
-            [len(offsets_list) + 1],
-            dtype=offsets_list[0].dtype,
-            device=offsets_list[0].device,
-        )
-
-        for i, include_last_offset in enumerate(self.include_last_offsets):
-            if include_last_offset:
-                offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0) - 1
-            else:
-                offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0)
-            offsets_accs[i + 1] = offsets_accs[i] + indices_list[i].size(0)
-
-        assert offsets_accs[-1] == combined_indices.size(0)
-        combined_offsets_size: List[int] = (
-            [int(offsets_starts[-1].item()) + 1]
-            if batch_size is None
-            else [batch_size * len(offsets_list) + 1]
-        )
-        combined_offsets = torch.zeros(
-            combined_offsets_size,
-            dtype=torch.int32,
-            device=offsets_list[0].device,
-        )
-        if batch_size is None:
-            for i in range(len(self.include_last_offsets)):
-                combined_offsets[offsets_starts[i] : offsets_starts[i + 1]] = (
-                    offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
-                    + offsets_accs[i]
-                )
-        else:
-            for i in range(len(self.include_last_offsets)):
-                cur_start = batch_size * i
-                combined_offsets[
-                    cur_start : cur_start + offsets_starts[i + 1] - offsets_starts[i]
-                ] = (
-                    offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
-                    + offsets_accs[i]
-                )
-                cur_start = cur_start + offsets_starts[i + 1] - offsets_starts[i]
-                for j in range(batch_size - offsets_starts[i + 1] + offsets_starts[i]):
-                    combined_offsets[cur_start + j] = (
-                        indices_list[i].numel() + offsets_accs[i]
-                    )
-        combined_offsets[-1] = offsets_accs[-1]
-        per_sample_weights: Optional[torch.Tensor] = None
-        for i in range(len(self.include_last_offsets)):
-            if per_sample_weights_list[i].size(0) > 0:
-                per_sample_weights = torch.ones(
-                    combined_indices.size(0),
-                    dtype=per_sample_weights_list[i].dtype,
-                    device=per_sample_weights_list[i].device,
-                )
-                break
-        if per_sample_weights is not None:
-            for i in range(len(self.include_last_offsets)):
-                if per_sample_weights_list[i].size(0) > 0:
-                    # fmt: off
-                    per_sample_weights[offsets_accs[i] : offsets_accs[i + 1]] = (
-                        per_sample_weights_list[i][:]
-                    )
-                    # fmt: on
-
-        # indices and offsets are required to be int32 for TBE
-        return combined_indices, combined_offsets, per_sample_weights
-
-
-# e.g. "test_faketensor__test_cumsum": [unittest.expectedFailure]
-# Please avoid putting tests here, you should put operator-specific
-# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json
-# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-additional_decorators: Dict[str, List[Callable]] = {
-    "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
-        # This operator has been grandfathered in. We need to fix this test failure.
-        unittest.expectedFailure,
-    ],
-    "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add_jagged_output": [
-        # This operator has been grandfathered in. We need to fix this test failure.
-        unittest.expectedFailure,
-    ],
-}
-
-
-@optests.generate_opcheck_tests(additional_decorators=additional_decorators)
+@optests.generate_opcheck_tests()
 class InputCombineTest(unittest.TestCase):
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json
index 39860aa73..52d46322a 100644
--- a/fbgemm_gpu/test/failures_dict.json
+++ b/fbgemm_gpu/test/failures_dict.json
@@ -289,58 +289,6 @@
                 "status": "xsuccess"
             }
         },
-        "fbgemm::padding_fused_tbe_input_combine": {
-            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix": {
-                "comment": "",
-                "status": "xfail"
-            }
-        },
-        "fbgemm::padding_fused_tbe_input_combine_with_length": {
-            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32_with_length": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64_with_length": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix_with_length": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32_with_length": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64_with_length": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix_with_length": {
-                "comment": "",
-                "status": "xfail"
-            }
-        },
         "fbgemm::permute102_baddbmm_permute102": {
             "SparseOpsTest.test_aot_dispatch_dynamic__test_permute102_baddbmm_permute102": {
                 "comment": "",
                 "status": "xfail"
@@ -468,33 +416,6 @@
                 "comment": "",
                 "status": "xfail"
             }
-        },
-        "fbgemm::tbe_input_combine": {},
-        "fbgemm::tbe_input_combine_with_length": {
-            "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int32_with_length": {
-                "comment": "",
-                "status": "xsuccess"
-            },
-            "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int64_with_length": {
-                "comment": "",
-                "status": "xsuccess"
-            },
-            "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_mix_with_length": {
-                "comment": "",
-                "status": "xsuccess"
-            },
-            "InputCombineTest.test_faketensor__test_input_combine_int32_with_length": {
-                "comment": "",
-                "status": "xsuccess"
-            },
-            "InputCombineTest.test_faketensor__test_input_combine_int64_with_length": {
-                "comment": "",
-                "status": "xsuccess"
-            },
-            "InputCombineTest.test_faketensor__test_input_combine_mix_with_length": {
-                "comment": "",
-                "status": "xsuccess"
-            }
-        }
+        }
     }
 }
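The failure mode exercised by the new bad_inputs_test.py reduces to the
following minimal sketch (assumptions: fbgemm_gpu is installed so that
importing it registers the fbgemm operators; the tensor values are
illustrative, borrowed from the test above):

    import torch
    import fbgemm_gpu  # noqa: F401  (loads the fbgemm operator library)

    # Indices and lengths for two embedding tables.
    indices = [
        torch.tensor([88, 55], dtype=torch.int32),
        torch.tensor([80, 29], dtype=torch.int32),
    ]
    lengths = [
        torch.tensor([1, 2], dtype=torch.int32),
        torch.tensor([1, 2], dtype=torch.int32),
    ]
    # Per-sample weights supplied for only one of the two tables; the
    # all-or-nothing check added in this patch rejects the mix.
    weights = [
        torch.tensor([], dtype=torch.float),
        torch.tensor([3.0, 3.0], dtype=torch.float),
    ]

    # Raises RuntimeError: "Either all weights tensors should be empty,
    # or all should be non-empty"
    torch.ops.fbgemm.tbe_input_combine_with_length(indices, lengths, weights)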