upstream norm (#774)
Summary: Pull Request resolved: #774

Reviewed By: chenyang78

Differential Revision: D46847587

Pulled By: ipiszy

fbshipit-source-id: af159f8b0f76112c803bc56589d4e6b41742ecc2
fsx950223 authored and facebook-github-bot committed Jun 21, 2023
1 parent 373000f commit f5896d0
Showing 4 changed files with 50 additions and 34 deletions.
9 changes: 4 additions & 5 deletions python/aitemplate/backend/rocm/normalization/groupnorm.py
@@ -29,7 +29,7 @@
 
 EXTRA_HEADERS = jinja2.Template(
     """
-#include "include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
     """
 )
 
@@ -127,16 +127,15 @@
     static_cast<ck::half_t *>(gamma),
     static_cast<ck::half_t *>(beta),
     static_cast<ck::half_t *>(output),
+    nullptr,
+    nullptr,
     YElementOp{}
 );
 if(!device_instance.IsSupportedArgument(argument_ptr.get()))
 {
-    throw std::runtime_error(
-        "wrong! device_layernorm with the specified compilation parameters does "
-        "not support this Groupnorm problem");
+    LOG(FATAL) << "wrong! " << device_instance.GetTypeString() << " with the specified compilation parameters does not support this Groupnorm problem.";
 };
 std::string instance_name = device_instance.GetTypeString();
 auto invoker_ptr = device_instance.MakeInvokerPointer();
 invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false});
 return;
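
Note: the two new `nullptr` arguments correspond to the optional saved-mean / saved-inverse-std output pointers exposed by CK's unified `DeviceNormalization` interface (the old `device_layernorm` interface had no such slots); this kernel does not consume them. For orientation only, a minimal NumPy sketch of what those two statistics are in a group norm — layout and names are illustrative assumptions, not part of this commit:

import numpy as np

def groupnorm_ref(x, gamma, beta, num_groups, eps=1e-5):
    # x: (N, H, W, C); NHWC layout assumed, per-(sample, group) statistics.
    n, h, w, c = x.shape
    g = x.reshape(n, h, w, num_groups, c // num_groups)
    mean = g.mean(axis=(1, 2, 4), keepdims=True)  # the "save mean" a CK kernel could return
    inv_std = 1.0 / np.sqrt(g.var(axis=(1, 2, 4), keepdims=True) + eps)  # the "save inv-std"
    y = (g - mean) * inv_std
    return y.reshape(n, h, w, c) * gamma + beta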
58 changes: 38 additions & 20 deletions python/aitemplate/backend/rocm/normalization/layernorm.py
@@ -29,7 +29,7 @@
 
 EXTRA_HEADERS = jinja2.Template(
     """
-#include "include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
     """
 )
 
@@ -65,34 +65,45 @@
 EXEC_TEMPLATE = jinja2.Template(
     """
 std::vector<ck::index_t> i_inStrides;
+std::vector<ck::index_t> i_outStrides;
+{% if input_strides is defined %}
+i_inStrides.push_back({{input_strides[-2]}});
+i_inStrides.push_back({{input_strides[-1]}});
+{% else %}
 i_inStrides.push_back(N);
 i_inStrides.push_back(1);
+{% endif %}
+{% if output_strides is defined %}
+i_outStrides.push_back({{output_strides[-2]}});
+i_outStrides.push_back({{output_strides[-1]}});
+{% else %}
+i_outStrides.push_back(N);
+i_outStrides.push_back(1);
+{% endif %}
 auto device_instance = {{instance}}{};
 auto argument_ptr = device_instance.MakeArgumentPointer(
     {M, N},
     i_inStrides,
     std::vector<ck::index_t>{0, 1},
     std::vector<ck::index_t>{0, 1},
-    i_inStrides,
+    i_outStrides,
     {1},
     {{eps}},
-    static_cast<ck::half_t *>(input),
+    static_cast<ck::half_t *>(input) + {{ input_offset if input_offset is defined else 0 }},
     static_cast<ck::half_t *>(gamma),
     static_cast<ck::half_t *>(beta),
-    static_cast<ck::half_t *>(output),
+    static_cast<ck::half_t *>(output) + {{ output_offset if output_offset is defined else 0 }},
+    nullptr,
+    nullptr,
     ck::tensor_operation::element_wise::PassThrough{}
 );
 if(!device_instance.IsSupportedArgument(argument_ptr.get()))
 {
-    throw std::runtime_error(
-        "wrong! device_layernorm with the specified compilation parameters does "
-        "not support this Softmax problem");
+    LOG(FATAL) << "wrong! " << device_instance.GetTypeString() << " with the specified compilation parameters does not support this Layernorm problem.";
 };
 std::string instance_name = device_instance.GetTypeString();
 auto invoker_ptr = device_instance.MakeInvokerPointer();
 invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false});
 return;
@@ -238,6 +249,16 @@ def gen_function(
     """
     rank = func_attrs["inputs"][0]._rank()
     eps = func_attrs.get("eps", "1e-5")
+    input_accessor = func_attrs["input_accessors"][0]
+    output_accessor = func_attrs["output_accessors"][0]
+    input_strides = []
+    output_strides = []
+    for i, _ in enumerate(input_accessor.original_shapes):
+        input_strides.append(input_accessor.stride(i))
+        output_strides.append(output_accessor.stride(i))
+
+    input_offset = input_accessor.offset
+    output_offset = output_accessor.offset
 
     exec_path = func_attrs["exec_path"]
     op_instance = func_attrs["op_instance"]
@@ -267,7 +288,14 @@ def gen_function(
     for key, _ in instances.items():
         fname = "f" + sha1(key.encode()).hexdigest()
         program = exec_template.render(
-            instance=fname, dtype="void", reduce_dims=rank - 1, eps=eps
+            instance=fname,
+            dtype="void",
+            reduce_dims=rank - 1,
+            eps=eps,
+            input_strides=input_strides,
+            output_strides=output_strides,
+            input_offset=input_offset,
+            output_offset=output_offset,
         )
         exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program)
         exec_paths += exec_inst
@@ -349,23 +377,13 @@ def layernorm_gen_function_call(func_attrs, indent=" "):
     ), f"LayerNorm only supports input with rank >= 2, current rank: {len(shapes)}"
 
     input_dim_names = [shape._attrs["name"] for shape in shapes]
-    x = func_attrs["inputs"][0]
-    xshape = x._attrs["shape"]
-
-    elem_cnt = 1
-    for shape in xshape:
-        elem_cnt *= shape._attrs["values"][0]
-    instance_size = xshape[-1]._attrs["values"][0]
-    instance_num = elem_cnt // instance_size
-
     return FUNC_CALL_TEMPLATE.render(
         func_name=func_attrs["name"],
         input=input_name,
         gamma=gamma_name,
         beta=beta_name,
         output=output_name,
-        M=instance_num,
-        N=instance_size,
         input_dim_names=input_dim_names,
         indent=indent,
    )
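
Note: the new `{% if ... is defined %}` guards keep existing call sites working — a template rendered without stride or offset parameters falls back to the dense row-major path, and the pointer arithmetic degrades to `+ 0`. A standalone sketch of that behavior (the stride values are illustrative only):

import jinja2

# Reduced copy of the stride-selection branch added to EXEC_TEMPLATE above.
snippet = jinja2.Template(
    """
{% if input_strides is defined %}
i_inStrides.push_back({{input_strides[-2]}});
i_inStrides.push_back({{input_strides[-1]}});
{% else %}
i_inStrides.push_back(N);
i_inStrides.push_back(1);
{% endif %}
"""
)

# Rendered with strides taken from a tensor accessor: emits the concrete values.
print(snippet.render(input_strides=[2048, 1]))
# Rendered without them: falls back to the contiguous (N, 1) default.
print(snippet.render())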
8 changes: 5 additions & 3 deletions python/aitemplate/backend/rocm/normalization/norm_common.py
@@ -184,8 +184,9 @@
     {{func_call}}
   }
   timer.End();
-  std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
-  std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
+  std::cout << "OP:" << "{{op_name}}" << ",";
+  std::cout << "TIME:" << timer.GetElapsedTime() << ",";
+  std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl;
 }
 """
 )
@@ -199,6 +200,7 @@
 #include <stdlib.h>
 #include <random>
 #include <rocrand/rocrand.h>
+#include "logging.h"
 #include "include/ck/utility/print.hpp"
 #include "library/include/ck/library/utility/device_memory.hpp"
 #include "library/include/ck/library/utility/host_tensor.hpp"
@@ -339,7 +341,6 @@ def gen_profiler(
     op_instance = func_attrs["op_instance"]
     file_pairs = []
     for op_name, op in op_instance.items():
-
         config = emit_instance(op)
         config_name = extract_config_name(config)
         instances = INSTANCE_TEMPLATE.render(
@@ -381,6 +382,7 @@ def gen_profiler(
         args_parse=args_parse,
         tensor_decl=tensor_decl,
         func_call=func_call,
+        op_name=op_name,
     )
 
     prefix = os.path.join(workdir, "profiler", op_type)
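
Note: the profiler now prints a single comma-separated `OP:...,TIME:...,WS:...` record per instance instead of separate bare TIME/WS lines, so the op name survives when several instances are benchmarked by one binary. A hypothetical parser for the new record format — the function and regex are illustrative, not part of AITemplate:

import re

# Matches the line emitted by the benchmark template above, e.g.
# "OP:some_norm_instance,TIME:0.123,WS:0"
RECORD = re.compile(r"OP:(?P<op>[^,]+),TIME:(?P<time>[^,]+),WS:(?P<ws>\d+)")

def parse_profiler_record(line):
    m = RECORD.search(line)
    if m is None:
        return None
    return m.group("op"), float(m.group("time")), int(m.group("ws"))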
9 changes: 3 additions & 6 deletions python/aitemplate/backend/rocm/normalization/softmax.py
@@ -62,20 +62,17 @@
 auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
     i_inStrides,
     reduceDims,
-    &alpha,
-    &beta,
+    alpha,
+    beta,
     static_cast<ck::half_t *>(input),
     static_cast<ck::half_t *>(output),
     ck::tensor_operation::element_wise::PassThrough{},
     ck::tensor_operation::element_wise::PassThrough{}
 );
 if(!device_instance.IsSupportedArgument(argument_ptr.get()))
 {
-    throw std::runtime_error(
-        "wrong! device_softmax with the specified compilation parameters does "
-        "not support this Softmax problem");
+    LOG(FATAL) << "wrong! " << device_instance.GetTypeString() << " with the specified compilation parameters does not support this Softmax problem.";
 };
 std::string instance_name = device_instance.GetTypeString();
 auto invoker_ptr = device_instance.MakeInvokerPointer();
 invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false});
 return;
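
Note: the newer CK `DeviceSoftmax` interface takes the `alpha`/`beta` scaling factors by value rather than through pointers, hence the dropped `&`. Assuming the usual BLAS-style blend y = alpha * softmax(x) + beta * y_prior (an assumption about CK's semantics, not stated in this commit), a NumPy reference would be:

import numpy as np

def softmax_ref(x, alpha=1.0, beta=0.0, y_prior=None, axis=-1):
    # Numerically stable softmax along `axis`, blended BLAS-style (assumed semantics).
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    s = e / e.sum(axis=axis, keepdims=True)
    return alpha * s + (beta * y_prior if y_prior is not None else 0.0)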
