Skip to content

Commit

Permalink
Add bf16 support to upsampling2d nearest (#750)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #750

Reviewed By: terrychenism, aakhundov

Differential Revision: D46504544

fbshipit-source-id: 662bc7d84db27969c972d50a7793c47ec3547ebb
  • Loading branch information
henryhu6 authored and facebook-github-bot committed Jun 15, 2023
1 parent 9ee885c commit e98d2dd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
3 changes: 3 additions & 0 deletions python/aitemplate/backend/cuda/upsample/upsampling2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@


Header_Files = """
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "cutlass/util/host_tensor.h"
using bfloat16 = __nv_bfloat16;
"""


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from aitemplate.compiler import compile_model
from aitemplate.frontend import IntVar, nn, Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import get_random_torch_tensor
from aitemplate.testing.test_utils import (
filter_test_cases_by_params,
get_random_torch_tensor,
TestEnv,
)
from parameterized import parameterized


_DEFAULT_BATCH_SIZE = [1, 3]
Expand Down Expand Up @@ -60,38 +65,29 @@ def _test_single_op(
y_transpose = torch.permute(y, (0, 3, 1, 2))
self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))

def test_bilinear_upsample_fp16(self):
self._test_single_op(
scale_factor=3.5,
mode="bilinear",
test_name="bilinear_upsampling2d_fp16",
dtype="float16",
)

def test_nearest_upsample_fp16(self):
self._test_single_op(
scale_factor=2.0,
mode="nearest",
test_name="nearest_upsampling2d_fp16",
dtype="float16",
@parameterized.expand(
**filter_test_cases_by_params(
{
TestEnv.CUDA_LESS_THAN_SM80: [("float16"), ("float32")],
TestEnv.CUDA_SM80: [("bfloat16")],
TestEnv.ROCM: [("float16")],
}
)

@unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm")
def test_bilinear_upsample_fp32(self):
self._test_single_op(
scale_factor=3.5,
mode="bilinear",
test_name="bilinear_upsampling2d_fp32",
dtype="float32",
)

@unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm")
def test_nearest_upsample_fp32(self):
)
def test_upsampling2d_constructor(self, ait_dtype):
# Currently upsampling2d bilinear does not support bfloat16.
if ait_dtype != "bfloat16":
self._test_single_op(
scale_factor=3.5,
mode="bilinear",
test_name=f"bilinear_upsampling2d_{ait_dtype}",
dtype=ait_dtype,
)
self._test_single_op(
scale_factor=2.0,
mode="nearest",
test_name="nearest_upsampling2d_fp32",
dtype="float32",
test_name=f"nearest_upsampling2d_{ait_dtype}",
dtype=ait_dtype,
)


Expand Down
File renamed without changes.

0 comments on commit e98d2dd

Please sign in to comment.