From f5896d0f417b8a462cffc35a7c5d664a9eadfd50 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Tue, 20 Jun 2023 17:03:48 -0700 Subject: [PATCH] upstream norm (#774) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/774 Reviewed By: chenyang78 Differential Revision: D46847587 Pulled By: ipiszy fbshipit-source-id: af159f8b0f76112c803bc56589d4e6b41742ecc2 --- .../backend/rocm/normalization/groupnorm.py | 9 ++- .../backend/rocm/normalization/layernorm.py | 58 ++++++++++++------- .../backend/rocm/normalization/norm_common.py | 8 ++- .../backend/rocm/normalization/softmax.py | 9 +-- 4 files changed, 50 insertions(+), 34 deletions(-) diff --git a/python/aitemplate/backend/rocm/normalization/groupnorm.py b/python/aitemplate/backend/rocm/normalization/groupnorm.py index ab8dced5e..a059fac29 100644 --- a/python/aitemplate/backend/rocm/normalization/groupnorm.py +++ b/python/aitemplate/backend/rocm/normalization/groupnorm.py @@ -29,7 +29,7 @@ EXTRA_HEADERS = jinja2.Template( """ -#include "include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" """ ) @@ -127,16 +127,15 @@ static_cast(gamma), static_cast(beta), static_cast(output), + nullptr, + nullptr, YElementOp{} ); if(!device_instance.IsSupportedArgument(argument_ptr.get())) { - throw std::runtime_error( - "wrong! device_layernorm with the specified compilation parameters does " - "not support this Groupnorm problem"); + LOG(FATAL) << "wrong! " << device_instance.GetTypeString() << " with the specified compilation parameters does not support this Groupnorm problem."; }; - std::string instance_name = device_instance.GetTypeString(); auto invoker_ptr = device_instance.MakeInvokerPointer(); invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false}); return; diff --git a/python/aitemplate/backend/rocm/normalization/layernorm.py b/python/aitemplate/backend/rocm/normalization/layernorm.py index 93d2216aa..af3efcf24 100644 --- a/python/aitemplate/backend/rocm/normalization/layernorm.py +++ b/python/aitemplate/backend/rocm/normalization/layernorm.py @@ -29,7 +29,7 @@ EXTRA_HEADERS = jinja2.Template( """ -#include "include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" """ ) @@ -65,10 +65,22 @@ EXEC_TEMPLATE = jinja2.Template( """ std::vector i_inStrides; - + std::vector i_outStrides; + {% if input_strides is defined %} + i_inStrides.push_back({{input_strides[-2]}}); + i_inStrides.push_back({{input_strides[-1]}}); + {% else %} i_inStrides.push_back(N); i_inStrides.push_back(1); + {% endif %} + {% if output_strides is defined %} + i_outStrides.push_back({{output_strides[-2]}}); + i_outStrides.push_back({{output_strides[-1]}}); + {% else %} + i_outStrides.push_back(N); + i_outStrides.push_back(1); + {% endif %} auto device_instance = {{instance}}{}; auto argument_ptr = device_instance.MakeArgumentPointer( @@ -76,23 +88,22 @@ i_inStrides, std::vector{0, 1}, std::vector{0, 1}, - i_inStrides, + i_outStrides, {1}, {{eps}}, - static_cast(input), + static_cast(input) + {{ input_offset if input_offset is defined else 0 }}, static_cast(gamma), static_cast(beta), - static_cast(output), + static_cast(output) + {{ output_offset if output_offset is defined else 0 }}, + nullptr, + nullptr, ck::tensor_operation::element_wise::PassThrough{} ); if(!device_instance.IsSupportedArgument(argument_ptr.get())) { - throw std::runtime_error( - "wrong! device_layernorm with the specified compilation parameters does " - "not support this Softmax problem"); + LOG(FATAL) << "wrong! " << device_instance.GetTypeString() << " with the specified compilation parameters does not support this Layernorm problem."; }; - std::string instance_name = device_instance.GetTypeString(); auto invoker_ptr = device_instance.MakeInvokerPointer(); invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false}); return; @@ -238,6 +249,16 @@ def gen_function( """ rank = func_attrs["inputs"][0]._rank() eps = func_attrs.get("eps", "1e-5") + input_accessor = func_attrs["input_accessors"][0] + output_accessor = func_attrs["output_accessors"][0] + input_strides = [] + output_strides = [] + for i, _ in enumerate(input_accessor.original_shapes): + input_strides.append(input_accessor.stride(i)) + output_strides.append(output_accessor.stride(i)) + + input_offset = input_accessor.offset + output_offset = output_accessor.offset exec_path = func_attrs["exec_path"] op_instance = func_attrs["op_instance"] @@ -267,7 +288,14 @@ def gen_function( for key, _ in instances.items(): fname = "f" + sha1(key.encode()).hexdigest() program = exec_template.render( - instance=fname, dtype="void", reduce_dims=rank - 1, eps=eps + instance=fname, + dtype="void", + reduce_dims=rank - 1, + eps=eps, + input_strides=input_strides, + output_strides=output_strides, + input_offset=input_offset, + output_offset=output_offset, ) exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program) exec_paths += exec_inst @@ -349,14 +377,6 @@ def layernorm_gen_function_call(func_attrs, indent=" "): ), f"LayerNorm only supports input with rank >= 2, current rank: {len(shapes)}" input_dim_names = [shape._attrs["name"] for shape in shapes] - x = func_attrs["inputs"][0] - xshape = x._attrs["shape"] - - elem_cnt = 1 - for shape in xshape: - elem_cnt *= shape._attrs["values"][0] - instance_size = xshape[-1]._attrs["values"][0] - instance_num = elem_cnt // instance_size return FUNC_CALL_TEMPLATE.render( func_name=func_attrs["name"], @@ -364,8 +384,6 @@ def layernorm_gen_function_call(func_attrs, indent=" "): gamma=gamma_name, beta=beta_name, output=output_name, - M=instance_num, - N=instance_size, input_dim_names=input_dim_names, indent=indent, ) diff --git a/python/aitemplate/backend/rocm/normalization/norm_common.py b/python/aitemplate/backend/rocm/normalization/norm_common.py index 4f0da20e9..328fa47ec 100644 --- a/python/aitemplate/backend/rocm/normalization/norm_common.py +++ b/python/aitemplate/backend/rocm/normalization/norm_common.py @@ -184,8 +184,9 @@ {{func_call}} } timer.End(); - std::cout << "WS:" < #include #include +#include "logging.h" #include "include/ck/utility/print.hpp" #include "library/include/ck/library/utility/device_memory.hpp" #include "library/include/ck/library/utility/host_tensor.hpp" @@ -339,7 +341,6 @@ def gen_profiler( op_instance = func_attrs["op_instance"] file_pairs = [] for op_name, op in op_instance.items(): - config = emit_instance(op) config_name = extract_config_name(config) instances = INSTANCE_TEMPLATE.render( @@ -381,6 +382,7 @@ def gen_profiler( args_parse=args_parse, tensor_decl=tensor_decl, func_call=func_call, + op_name=op_name, ) prefix = os.path.join(workdir, "profiler", op_type) diff --git a/python/aitemplate/backend/rocm/normalization/softmax.py b/python/aitemplate/backend/rocm/normalization/softmax.py index 11a0aa85c..bc10c5e09 100644 --- a/python/aitemplate/backend/rocm/normalization/softmax.py +++ b/python/aitemplate/backend/rocm/normalization/softmax.py @@ -62,8 +62,8 @@ auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths, i_inStrides, reduceDims, - &alpha, - &beta, + alpha, + beta, static_cast(input), static_cast(output), ck::tensor_operation::element_wise::PassThrough{}, @@ -71,11 +71,8 @@ ); if(!device_instance.IsSupportedArgument(argument_ptr.get())) { - throw std::runtime_error( - "wrong! device_softmax with the specified compilation parameters does " - "not support this Softmax problem"); + LOG(FATAL) << "wrong! " << device_instance.GetTypeString() << " with the specified compilation parameters does not support this Softmax problem."; }; - std::string instance_name = device_instance.GetTypeString(); auto invoker_ptr = device_instance.MakeInvokerPointer(); invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false}); return;