From c01ed4dcabc1c04f5d6f8e972543bae0e6971f73 Mon Sep 17 00:00:00 2001
From: Mu-Chu Lee
Date: Mon, 31 Jul 2023 20:51:33 -0700
Subject: [PATCH] Make AIT use fp32 accumulation for reduce_3d kernels by
 default. (#862)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/862

Use fp32 accumulation by default; only use fp16 if use_fp16_acc == True.

Reviewed By: chenyang78

Differential Revision: D47928197

fbshipit-source-id: 0498edc19fa617ec608e5fd263db104b4c529f2c
---
 .../backend/cuda/reduce/reduce_3d.py         |  9 +++++-
 python/aitemplate/backend/cuda/reduce/var.py | 30 +++++++++++--------
 2 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/python/aitemplate/backend/cuda/reduce/reduce_3d.py b/python/aitemplate/backend/cuda/reduce/reduce_3d.py
index c8728b9b1..a259d3974 100644
--- a/python/aitemplate/backend/cuda/reduce/reduce_3d.py
+++ b/python/aitemplate/backend/cuda/reduce/reduce_3d.py
@@ -28,6 +28,7 @@
 
 from aitemplate.backend.common import tensor_accessor_codegen
 from aitemplate.backend.cuda.reduce import reduce_small_axis
+from aitemplate.backend.target import Target
 
 
 DEFAULT_PROLOGUE_TEMPLATE = jinja2.Template(
@@ -830,7 +831,13 @@ def gen_function(
     output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"])
     if accumulation_type is None:
         # follow pytorch's semantics
-        acc_type = output_type
+        if (
+            Target.current()._kwargs.get("use_fp16_acc", False)
+            and y._attrs["dtype"] == "float16"
+        ):
+            acc_type = output_type
+        else:
+            acc_type = "float"
     else:
         acc_type = accumulation_type
 
diff --git a/python/aitemplate/backend/cuda/reduce/var.py b/python/aitemplate/backend/cuda/reduce/var.py
index 754b07cf8..80b5dc336 100644
--- a/python/aitemplate/backend/cuda/reduce/var.py
+++ b/python/aitemplate/backend/cuda/reduce/var.py
@@ -24,6 +24,7 @@
 from aitemplate.backend import registry
 from aitemplate.backend.backend_spec import CUDASpec
 from aitemplate.backend.cuda.reduce import reduce_3d
+from aitemplate.backend.target import Target
 
 
 EXTRA_CODE_TEMPLATE = jinja2.Template(
@@ -148,17 +149,17 @@
 }  // namespace arch
 
 
 template <typename ElementT, bool BesselCorrection>
-struct NumericConverter<WelfordData<ElementT, BesselCorrection>,
+struct NumericConverter<WelfordData<{{acc_type}}, BesselCorrection>,
                         ElementT,
                         FloatRoundStyle::round_to_nearest> {
-  using result_type = WelfordData<ElementT, BesselCorrection>;
+  using result_type = WelfordData<{{acc_type}}, BesselCorrection>;
   using source_type = ElementT;
   static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;
 
   CUTLASS_HOST_DEVICE
   static result_type convert(source_type const & s) {
-    return WelfordData<ElementT, BesselCorrection>(-1, static_cast<ElementT>(s), ElementT(0));
+    return WelfordData<{{acc_type}}, BesselCorrection>(-1, static_cast<{{acc_type}}>(s), {{acc_type}}(0));
   }
 
   CUTLASS_HOST_DEVICE
@@ -169,11 +170,11 @@
 
 template <typename ElementT, bool BesselCorrection>
 struct NumericConverter<ElementT,
-                        WelfordData<ElementT, BesselCorrection>,
+                        WelfordData<{{acc_type}}, BesselCorrection>,
                         FloatRoundStyle::round_to_nearest> {
 
   using result_type = ElementT;
-  using source_type = WelfordData<ElementT, BesselCorrection>;
+  using source_type = WelfordData<{{acc_type}}, BesselCorrection>;
   static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;
 
   CUTLASS_HOST_DEVICE
@@ -183,14 +184,14 @@
       if (s.count <= 1) {
         return ElementT(nanf("Not a Number"));
       } else {
-        return s.m2 / ElementT((int)(s.count - 1));
+        return ElementT(s.m2) / ElementT((int)(s.count - 1));
       }
     } else {
       // sample variance
       if (s.count <= 0) {
         return ElementT(nanf("Not a Number"));
       } else {
-        return s.m2 / ElementT((int)(s.count));
+        return ElementT(s.m2) / ElementT((int)(s.count));
       }
     }
   }
@@ -294,17 +295,20 @@ def var_gen_function(func_attrs) -> str:
     """
     bessel = "true" if func_attrs["unbiased"] else "false"
     backend_spec = CUDASpec()
-    elem_output_type = backend_spec.dtype_to_lib_type(
-        func_attrs["outputs"][0]._attrs["dtype"]
-    )
-    acc_type = f"WelfordData<{elem_output_type}, {bessel}>"
+    output_type = func_attrs["outputs"][0]._attrs["dtype"]
+    elem_output_type = backend_spec.dtype_to_lib_type(output_type)
+
+    acc_type = "float"
+    if Target.current()._kwargs.get("use_fp16_acc", False) and output_type == "float16":
+        acc_type = elem_output_type
+    welford_type = f"WelfordData<{acc_type}, {bessel}>"
     return reduce_3d.gen_function(
         func_attrs,
         "cutlass::welford_op",
         reduce_3d.DEFAULT_PROLOGUE_TEMPLATE,
         reduce_3d.DEFAULT_EPILOGUE_SCALAR_TEMPLATE,
-        EXTRA_CODE_TEMPLATE.render(),
-        accumulation_type=acc_type,
+        EXTRA_CODE_TEMPLATE.render(acc_type=acc_type),
+        accumulation_type=welford_type,
     )
 
 
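Usage note: after this patch, reduce_3d-based kernels (reduce, var, etc.) accumulate
in fp32 even when the output dtype is float16, unless the target was constructed with
use_fp16_acc=True. Below is a minimal sketch of how the flag is typically supplied;
it assumes the usual detect_target()/compile_model() entry points, which forward their
kwargs to the Target that the codegen later queries via
Target.current()._kwargs.get("use_fp16_acc", False). The tensor `y` is a hypothetical
model output, not part of this diff:

    # Sketch only: target kwargs are what reduce_3d/var codegen reads back.
    from aitemplate.testing import detect_target

    # Default target: float16 reductions now accumulate in fp32 ("float").
    target = detect_target()

    # Opt back into fp16 accumulation for float16 outputs.
    target_fp16_acc = detect_target(use_fp16_acc=True)

    # Compile a graph ending in a reduce/var op against either target, e.g.:
    #   from aitemplate.compiler import compile_model
    #   module = compile_model(y, target_fp16_acc, "./tmp", "var_fp16_acc")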