diff --git a/python/aitemplate/backend/cuda/upsample/upsampling2d.py b/python/aitemplate/backend/cuda/upsample/upsampling2d.py index ebb30dab9..d25912eab 100644 --- a/python/aitemplate/backend/cuda/upsample/upsampling2d.py +++ b/python/aitemplate/backend/cuda/upsample/upsampling2d.py @@ -24,9 +24,12 @@ Header_Files = """ +#include #include #include #include "cutlass/util/host_tensor.h" + +using bfloat16 = __nv_bfloat16; """ diff --git a/tests/unittest/ops/test_upsamping2d.py b/tests/unittest/ops/test_upsampling2d.py similarity index 69% rename from tests/unittest/ops/test_upsamping2d.py rename to tests/unittest/ops/test_upsampling2d.py index 2c4e88660..702fd6bd5 100644 --- a/tests/unittest/ops/test_upsamping2d.py +++ b/tests/unittest/ops/test_upsampling2d.py @@ -19,7 +19,12 @@ from aitemplate.compiler import compile_model from aitemplate.frontend import IntVar, nn, Tensor from aitemplate.testing import detect_target -from aitemplate.testing.test_utils import get_random_torch_tensor +from aitemplate.testing.test_utils import ( + filter_test_cases_by_params, + get_random_torch_tensor, + TestEnv, +) +from parameterized import parameterized _DEFAULT_BATCH_SIZE = [1, 3] @@ -60,38 +65,29 @@ def _test_single_op( y_transpose = torch.permute(y, (0, 3, 1, 2)) self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2)) - def test_bilinear_upsample_fp16(self): - self._test_single_op( - scale_factor=3.5, - mode="bilinear", - test_name="bilinear_upsampling2d_fp16", - dtype="float16", - ) - - def test_nearest_upsample_fp16(self): - self._test_single_op( - scale_factor=2.0, - mode="nearest", - test_name="nearest_upsampling2d_fp16", - dtype="float16", + @parameterized.expand( + **filter_test_cases_by_params( + { + TestEnv.CUDA_LESS_THAN_SM80: [("float16"), ("float32")], + TestEnv.CUDA_SM80: [("bfloat16")], + TestEnv.ROCM: [("float16")], + } ) - - @unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm") - def test_bilinear_upsample_fp32(self): - self._test_single_op( - scale_factor=3.5, - mode="bilinear", - test_name="bilinear_upsampling2d_fp32", - dtype="float32", - ) - - @unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm") - def test_nearest_upsample_fp32(self): + ) + def test_upsampling2d_constructor(self, ait_dtype): + # Currently upsampling2d bilinear does not support bfloat16. + if ait_dtype != "bfloat16": + self._test_single_op( + scale_factor=3.5, + mode="bilinear", + test_name=f"bilinear_upsampling2d_{ait_dtype}", + dtype=ait_dtype, + ) self._test_single_op( scale_factor=2.0, mode="nearest", - test_name="nearest_upsampling2d_fp32", - dtype="float32", + test_name=f"nearest_upsampling2d_{ait_dtype}", + dtype=ait_dtype, ) diff --git a/tests/unittest/ops/test_upsamping2d_add.py b/tests/unittest/ops/test_upsampling2d_add.py similarity index 100% rename from tests/unittest/ops/test_upsamping2d_add.py rename to tests/unittest/ops/test_upsampling2d_add.py