Make AIT use fp32 accumulation for reduce_3d kernels by default. #862

Closed · wants to merge 1 commit
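This PR makes reduce_3d-based CUDA kernels (including var, which builds on reduce_3d) accumulate in fp32 by default; fp16 accumulation is now used only when the Target was created with use_fp16_acc=True and the output dtype is float16. The motivation is precision: a float16 accumulator silently drops small addends once the running sum grows. A minimal NumPy sketch of that failure mode (illustrative only; not part of the diff):

```python
# Why fp32 accumulation matters for fp16 reductions (illustration, not from the PR).
import numpy as np

x = np.full(100_000, 0.01, dtype=np.float16)  # true sum is ~1000

acc16 = np.float16(0.0)
for v in x:
    acc16 = np.float16(acc16 + v)  # stalls at 32.0: 0.01 < half an fp16 ulp there

acc32 = x.astype(np.float32).sum()  # fp32 accumulation recovers ~1000

print(acc16, acc32)  # 32.0 vs ~1000.2
```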

python/aitemplate/backend/cuda/reduce/reduce_3d.py (9 changes: 8 additions & 1 deletion)

@@ -28,6 +28,7 @@
 from aitemplate.backend.common import tensor_accessor_codegen

 from aitemplate.backend.cuda.reduce import reduce_small_axis
+from aitemplate.backend.target import Target


 DEFAULT_PROLOGUE_TEMPLATE = jinja2.Template(
@@ -830,7 +831,13 @@ def gen_function(
     output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"])
     if accumulation_type is None:
         # follow pytorch's semantics
-        acc_type = output_type
+        if (
+            Target.current()._kwargs.get("use_fp16_acc", False)
+            and y._attrs["dtype"] == "float16"
+        ):
+            acc_type = output_type
+        else:
+            acc_type = "float"
     else:
         acc_type = accumulation_type
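The new branch reads the use_fp16_acc flag off the active Target, so the behavior is controlled at target-creation time. A hedged sketch of typical usage (detect_target and its kwargs are assumed from common AITemplate examples; adjust to your setup):

```python
# Sketch: how use_fp16_acc typically reaches Target.current()._kwargs.
# detect_target and its kwargs are assumed from common AITemplate examples.
from aitemplate.testing import detect_target

target = detect_target(use_fp16_acc=True)  # opt back in to fp16 accumulation
# target = detect_target()                 # default after this PR: fp32 accumulation
```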

python/aitemplate/backend/cuda/reduce/var.py (30 changes: 17 additions & 13 deletions)

@@ -24,6 +24,7 @@
 from aitemplate.backend import registry
 from aitemplate.backend.backend_spec import CUDASpec
 from aitemplate.backend.cuda.reduce import reduce_3d
+from aitemplate.backend.target import Target


 EXTRA_CODE_TEMPLATE = jinja2.Template(
@@ -148,17 +149,17 @@
 } // namespace arch

 template <typename ElementT, bool BesselCorrection>
-struct NumericConverter<WelfordData<ElementT, BesselCorrection>,
+struct NumericConverter<WelfordData<{{acc_type}}, BesselCorrection>,
                         ElementT,
                         FloatRoundStyle::round_to_nearest> {
-  using result_type = WelfordData<ElementT, BesselCorrection>;
+  using result_type = WelfordData<{{acc_type}}, BesselCorrection>;
   using source_type = ElementT;
   static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;

   CUTLASS_HOST_DEVICE
   static result_type convert(source_type const & s) {
-    return WelfordData<ElementT, BesselCorrection>(-1, static_cast<ElementT>(s), ElementT(0));
+    return WelfordData<{{acc_type}}, BesselCorrection>(-1, static_cast<{{acc_type}}>(s), {{acc_type}}(0));
   }

   CUTLASS_HOST_DEVICE
@@ -169,11 +170,11 @@
 template <typename ElementT, bool BesselCorrection>
 struct NumericConverter<ElementT,
-                        WelfordData<ElementT, BesselCorrection>,
+                        WelfordData<{{acc_type}}, BesselCorrection>,
                         FloatRoundStyle::round_to_nearest> {
   using result_type = ElementT;
-  using source_type = WelfordData<ElementT, BesselCorrection>;
+  using source_type = WelfordData<{{acc_type}}, BesselCorrection>;
   static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;

   CUTLASS_HOST_DEVICE
@@ -183,14 +184,14 @@
       if (s.count <= 1) {
         return ElementT(nanf("Not a Number"));
       } else {
-        return s.m2 / ElementT((int)(s.count - 1));
+        return ElementT(s.m2) / ElementT((int)(s.count - 1));
       }
     } else {
       // sample variance
       if (s.count <= 0) {
         return ElementT(nanf("Not a Number"));
       } else {
-        return s.m2 / ElementT((int)(s.count));
+        return ElementT(s.m2) / ElementT((int)(s.count));
       }
     }
   }
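The two NumericConverter specializations above bridge ElementT and the Welford state: with this PR the (count, mean, m2) triple lives in {{acc_type}} (fp32 unless use_fp16_acc is set), and only the final m2 / (count - 1) or m2 / count result is cast back to ElementT. For reference, a minimal Python sketch of the standard Welford update that WelfordData models (a textbook formulation, not the CUTLASS source):

```python
# Textbook Welford online variance; models the (count, mean, m2) state carried
# by WelfordData. Reference formulation only -- not the CUTLASS implementation.
def welford_variance(xs, unbiased=True):
    count, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # running sum of squared deviations
    denom = count - 1 if unbiased else count
    return float("nan") if denom < 1 else m2 / denom  # NaN cases match the C++ above
```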
@@ -294,17 +295,20 @@ def var_gen_function(func_attrs) -> str:
"""
bessel = "true" if func_attrs["unbiased"] else "false"
backend_spec = CUDASpec()
elem_output_type = backend_spec.dtype_to_lib_type(
func_attrs["outputs"][0]._attrs["dtype"]
)
acc_type = f"WelfordData<{elem_output_type}, {bessel}>"
output_type = func_attrs["outputs"][0]._attrs["dtype"]
elem_output_type = backend_spec.dtype_to_lib_type(output_type)

acc_type = "float"
if Target.current()._kwargs.get("use_fp16_acc", False) and output_type == "float16":
acc_type = elem_output_type
welford_type = f"WelfordData<{acc_type}, {bessel}>"
return reduce_3d.gen_function(
func_attrs,
"cutlass::welford_op",
reduce_3d.DEFAULT_PROLOGUE_TEMPLATE,
reduce_3d.DEFAULT_EPILOGUE_SCALAR_TEMPLATE,
EXTRA_CODE_TEMPLATE.render(),
accumulation_type=acc_type,
EXTRA_CODE_TEMPLATE.render(acc_type=acc_type),
accumulation_type=welford_type,
)
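One subtlety in the render call: EXTRA_CODE_TEMPLATE.render(acc_type=acc_type) substitutes the chosen accumulator into every {{acc_type}} placeholder in the C++ template before compilation. A tiny standalone illustration of the jinja2 mechanism (the template string here is a made-up one-liner, not the real EXTRA_CODE_TEMPLATE):

```python
# Minimal jinja2 substitution demo; the template string is a hypothetical
# stand-in for EXTRA_CODE_TEMPLATE.
import jinja2

tmpl = jinja2.Template("using source_type = WelfordData<{{acc_type}}, {{bessel}}>;")
print(tmpl.render(acc_type="float", bessel="true"))
# -> using source_type = WelfordData<float, true>;
```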

