Introduce "no_tf32" param for targets, to disable kernels using numerically less accurate tf32 (#874)

Summary:
Pull Request resolved: #874

Addressing GitHub issue #872: "Option for choosing fp32 gemm backend implementation"

As reported by GitHub user zhekunz2, small numerical discrepancies between PyTorch's and AITemplate's GEMM results could be observed on GPUs >= SM80 (A100 and above), where GEMM kernels using TF32 may be selected.

Most of the time these kernels are a good choice due to their performance and relatively good accuracy, but sometimes full fp32 accuracy is required.
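
For context, the discrepancy is easy to reproduce in plain PyTorch on an SM80+ GPU. A small illustrative sketch (not part of this diff; magnitudes vary by device and matrix size):
```python
import torch

a = torch.rand(1024, 1024, device="cuda")
b = torch.rand(1024, 1024, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True   # tensor-core TF32 path
c_tf32 = a @ b
torch.backends.cuda.matmul.allow_tf32 = False  # strict fp32 path
c_fp32 = a @ b

# TF32 keeps the fp32 exponent range but only ~10 mantissa bits,
# so a small elementwise difference is expected.
print((c_tf32 - c_fp32).abs().max())
```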

This diff therefore introduces a "no_tf32" option that can be passed to detect_target, which prevents the selection of CUTLASS GEMM kernels that use TF32.

Example usage, as in the new unit test below (slightly modified code from the initial report):
```python
    def test_rrr_no_tf32(self):
        # Test accuracy with tf32 disabled
        # this test uses a smaller numerical tolerance level
        # than the others
        allow_tf32_bak = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        try:
            test_dtype = torch.float32
            test_dtype_str = "float32"
            A = torch.rand((64, 64), dtype=test_dtype).cuda()
            B = torch.rand((64, 64), dtype=test_dtype).cuda()
            result_cuda = torch.matmul(A, B)

            target = detect_target(no_tf32=True)  # Disable tf32 for accuracy
            A_ait = Tensor(
                shape=[64, 64], dtype=test_dtype_str, name="input_0", is_input=True
            )
            B_ait = Tensor(
                shape=[64, 64], dtype=test_dtype_str, name="input_1", is_input=True
            )
            OP = ops.gemm_rrr()
            Y = OP(A_ait, B_ait)
            Y._attrs["name"] = "output_0"
            Y._attrs["is_output"] = True
            module = compile_model(Y, target, "./tmp", "gemm_rrr_no_tf32")
            inputs = {
                "input_0": A.clone().detach().cuda(),
                "input_1": B.clone().detach().cuda(),
            }
            result_ait = torch.empty([64, 64], dtype=test_dtype, device="cuda")
            module.run_with_tensors(inputs, [result_ait])
            torch.testing.assert_close(result_cuda, result_ait)
        finally:
            torch.backends.cuda.matmul.allow_tf32 = allow_tf32_bak
```

Reviewed By: chenyang78

Differential Revision: D48034389

fbshipit-source-id: b2909743711c4fab66d5f613d0873047b082db03
kadeng authored and facebook-github-bot committed Aug 10, 2023
1 parent edfa797 commit e0e00e2
Showing 3 changed files with 95 additions and 0 deletions.
```diff
@@ -211,6 +211,18 @@ def default_fproc(
     ):
         return ret
     acc_type = cutlass_lib.library.DataType.f32
+
+    if (
+        "no_tf32" in Target.current()._kwargs
+        and data_type == "float"
+        and Target.current()._kwargs["no_tf32"]
+    ):
+        if (
+            op.tile_description.math_instruction.element_a
+            == cutlass_lib.library.DataType.tf32
+        ):
+            return ret
+
     # check target use fp16 acc
     if "use_fp16_acc" in Target.current()._kwargs and data_type == "cutlass::half_t":
         if Target.current()._kwargs["use_fp16_acc"]:
```
python/aitemplate/backend/cuda/gemm_universal/common.py (12 additions, 0 deletions):
```diff
@@ -1319,6 +1319,18 @@ def default_fproc(
     ):
         return ret
     acc_type = cutlass_lib.library.DataType.f32
+
+    if (
+        "no_tf32" in Target.current()._kwargs
+        and data_type == "float"
+        and Target.current()._kwargs["no_tf32"]
+    ):
+        if (
+            op.tile_description.math_instruction.element_a
+            == cutlass_lib.library.DataType.tf32
+        ):
+            return ret
+
     # check target use fp16 acc
     if "use_fp16_acc" in Target.current()._kwargs and data_type == "cutlass::half_t":
         if Target.current()._kwargs["use_fp16_acc"]:
```
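
The same twelve lines are added in both files. For clarity, a minimal standalone restatement of the filter predicate (the helper name and the explicit kwargs argument are mine; the real code reads Target.current()._kwargs in place, and cutlass_lib is the CUTLASS generator module the diff already references):
```python
import cutlass_lib  # the CUTLASS op-generator module referenced in the diff


def is_filtered_tf32_op(op, data_type: str, target_kwargs: dict) -> bool:
    # A candidate CUTLASS GEMM op is rejected when the target was created
    # with no_tf32=True, the GEMM dtype is fp32 ("float"), and the op's
    # math instruction reads operand A as tf32. Rejected ops never reach
    # the profiler, so a TF32 kernel can no longer win selection.
    return (
        bool(target_kwargs.get("no_tf32"))
        and data_type == "float"
        and op.tile_description.math_instruction.element_a
        == cutlass_lib.library.DataType.tf32
    )
```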
tests/unittest/ops/test_gemm_no_tf32.py (new file, 71 additions):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import filter_test_cases_by_test_env


@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.")
@unittest.skipIf(
    detect_target().name() == "cuda" and int(detect_target()._arch) < 80,
    "Not supported by CUDA < SM80.",
)
class GEMMNoTF32TestCase(unittest.TestCase):
    def test_rrr_no_tf32(self):
        # Test accuracy with tf32 disabled
        # this test uses a smaller numerical tolerance level
        # than the others
        allow_tf32_bak = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        try:
            test_dtype = torch.float32
            test_dtype_str = "float32"
            A = torch.rand((64, 64), dtype=test_dtype).cuda()
            B = torch.rand((64, 64), dtype=test_dtype).cuda()
            result_cuda = torch.matmul(A, B)

            target = detect_target(no_tf32=True)  # Disable tf32 for accuracy
            A_ait = Tensor(
                shape=[64, 64], dtype=test_dtype_str, name="input_0", is_input=True
            )
            B_ait = Tensor(
                shape=[64, 64], dtype=test_dtype_str, name="input_1", is_input=True
            )
            OP = ops.gemm_rrr()
            Y = OP(A_ait, B_ait)
            Y._attrs["name"] = "output_0"
            Y._attrs["is_output"] = True
            module = compile_model(Y, target, "./tmp", "gemm_rrr_no_tf32")
            inputs = {
                "input_0": A.clone().detach().cuda(),
                "input_1": B.clone().detach().cuda(),
            }
            result_ait = torch.empty([64, 64], dtype=test_dtype, device="cuda")
            module.run_with_tensors(inputs, [result_ait])
            torch.testing.assert_close(result_cuda, result_ait)
        finally:
            torch.backends.cuda.matmul.allow_tf32 = allow_tf32_bak


filter_test_cases_by_test_env(GEMMNoTF32TestCase)


if __name__ == "__main__":
    unittest.main()
```