Skip to content

Commit

Permalink
gpu: generic: add simple SYCL reduction implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
sgeor255 committed Dec 12, 2024
1 parent 012a338 commit fc994fd
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 0 deletions.
57 changes: 57 additions & 0 deletions src/gpu/generic/sycl/simple_reduction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "simple_reduction.hpp"

#include "gpu/generic/sycl/engine.hpp"
#include "gpu/generic/sycl/simple_reduction_kernels.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {
namespace sycl {

status_t simple_reduction_t::pd_t::init_conf() {
conf_.alg = desc()->alg_kind;
conf_.src_md = xpu::sycl::md_t(src_md());
conf_.dst_md = xpu::sycl::md_t(dst_md());
conf_.p = desc()->p;
conf_.eps = desc()->eps;

auto src_wrap = memory_desc_wrapper(src_md());
auto dst_wrap = memory_desc_wrapper(dst_md());
dst_nelems_ = dst_wrap.nelems();

const auto ndims = dst_wrap.ndims();
for (int d = 0; d < xpu::sycl::md_t::max_dims; d++) {
conf_.reduce_dims[d] = dim_t {1};
if (d < ndims) {
if (src_wrap.dims()[d] != dst_wrap.dims()[d]) {
conf_.reduce_dims[d] = src_wrap.dims()[d];
conf_.reduce_size *= conf_.reduce_dims[d];
}
}
}

conf_.post_ops = sycl_post_ops_t(attr(), dst_wrap);

return status::success;
}

status_t simple_reduction_t::init(impl::engine_t *engine) {
const auto kid = ::sycl::get_kernel_id<reduction_kernel_fwd_t>();
CHECK(create_kernel(engine, kid, &kernel_));

return status::success;
}

status_t simple_reduction_t::execute(const exec_ctx_t &ctx) const {
return parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
reduction_kernel_fwd_t reduction_kernel(pd()->conf_, cgh, ctx);
cgh.parallel_for(::sycl::range<1>(pd()->dst_nelems_), reduction_kernel);
});
}

} // namespace sycl
} // namespace generic
} // namespace gpu
} // namespace impl
} // namespace dnnl
84 changes: 84 additions & 0 deletions src/gpu/generic/sycl/simple_reduction.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_GENERIC_SYCL_SIMPLE_REDUCTION_HPP
#define GPU_GENERIC_SYCL_SIMPLE_REDUCTION_HPP

#include "common/primitive_desc_iterator.hpp"
#include "common/reorder.hpp"
#include "common/reorder_pd.hpp"
#include "gpu/generic/sycl/sycl_gpu_primitive.hpp"
#include "gpu/generic/sycl/sycl_io_helper.hpp"
#include "gpu/generic/sycl/sycl_post_ops.hpp"
#include "gpu/generic/sycl/sycl_primitive_conf.hpp"
#include "gpu/generic/sycl/sycl_utils.hpp"
#include "gpu/gpu_reduction_pd.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {
namespace sycl {

struct simple_reduction_t : public gpu::generic::sycl::primitive_t {
using gpu::generic::sycl::primitive_t::primitive_t;

struct pd_t : public gpu_reduction_pd_t {
using gpu_reduction_pd_t::gpu_reduction_pd_t;

DECLARE_COMMON_PD_T("dpcpp:ref:any", simple_reduction_t);

status_t init(impl::engine_t *engine) {
using sm = primitive_attr_t::skip_mask_t;

memory_desc_wrapper src_wrap(src_md());
memory_desc_wrapper dst_wrap(dst_md());

bool ok = set_default_params() == status::success
&& attr()->has_default_values(sm::post_ops)
&& sycl_post_ops_t::post_ops_ok(attr())
&& attr_.set_default_formats(dst_md()) == status::success
&& src_wrap.is_plain() && dst_wrap.is_plain()
&& src_wrap.ndims() == dst_wrap.ndims()
&& md_dims_in_range(src_md()) && md_dims_in_range(dst_md());
if (!ok) return status::unimplemented;

return init_conf();
}

sycl_simple_reduction_conf_t conf_;
dim_t dst_nelems_;

private:
status_t init_conf();
};

status_t init(impl::engine_t *engine) override;
status_t execute(const exec_ctx_t &ctx) const override;

private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
kernel_t kernel_;
std::shared_ptr<impl::primitive_t> reorder_p_;
};

} // namespace sycl
} // namespace generic
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif
132 changes: 132 additions & 0 deletions src/gpu/generic/sycl/simple_reduction_kernels.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@

#ifndef GPU_GENERIC_SYCL_SIMPLE_REDUCTION_KERNELS_HPP
#define GPU_GENERIC_SYCL_SIMPLE_REDUCTION_KERNELS_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/primitive_exec_types.hpp"
#include "common/utils.hpp"
#include "gpu/generic/sycl/sycl_io_helper.hpp"
#include "gpu/generic/sycl/sycl_math_utils.hpp"
#include "gpu/generic/sycl/sycl_primitive_conf.hpp"
#include "xpu/sycl/memory_storage_base.hpp"
#include "xpu/sycl/types.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {
namespace sycl {

struct Reducer {
dnnl_alg_kind_t alg_;
float p_, eps_;

Reducer(dnnl_alg_kind_t alg, float p, float eps)
: alg_(alg), p_(p), eps_(eps) {}

float identity() const {
if (alg_ == dnnl_reduction_min) {
return std::numeric_limits<float>::max();
} else if (alg_ == dnnl_reduction_max) {
return std::numeric_limits<float>::lowest();
} else if (alg_ == dnnl_reduction_mul) {
return 1.f;
}

return 0.f;
}

float reduce(float lhs, float rhs) const {
if (alg_ == dnnl_reduction_sum || alg_ == dnnl_reduction_mean) {
return lhs + rhs;
} else if (alg_ == dnnl_reduction_min) {
return ::sycl::min(lhs, rhs);
} else if (alg_ == dnnl_reduction_max) {
return ::sycl::max(lhs, rhs);
} else if (alg_ == dnnl_reduction_mul) {
return lhs * rhs;
} else if (alg_ == dnnl_reduction_norm_lp_max
|| alg_ == dnnl_reduction_norm_lp_sum
|| alg_ == dnnl_reduction_norm_lp_power_p_max
|| alg_ == dnnl_reduction_norm_lp_power_p_sum) {
return lhs + ::sycl::pow(::sycl::fabs(rhs), p_);
}

return ::sycl::nan(0U);
}

float finalize(float val, int size) const {
if (alg_ == dnnl_reduction_mean) {
return val / size;
} else if (alg_ == dnnl_reduction_norm_lp_max) {
return ::sycl::rootn(::sycl::max(val, eps_), p_);
} else if (alg_ == dnnl_reduction_norm_lp_sum) {
return ::sycl::rootn(val + eps_, p_);
} else if (alg_ == dnnl_reduction_norm_lp_power_p_max) {
return ::sycl::max(val, eps_);
} else if (alg_ == dnnl_reduction_norm_lp_power_p_sum) {
return val + eps_;
}

return val;
}
};

struct reduction_kernel_fwd_t {
sycl_simple_reduction_conf_t conf_;
xpu::sycl::in_memory_arg_t src_;
xpu::sycl::out_memory_arg_t dst_;
post_op_input_args po_args_;

reduction_kernel_fwd_t(const sycl_simple_reduction_conf_t &conf,
::sycl::handler &cgh, const exec_ctx_t &ctx)
: conf_(conf)
, src_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC))
, dst_(CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST))
, po_args_(cgh, ctx, conf_.post_ops) {}

void operator()(::sycl::item<1> item) const {
Reducer reducer(conf_.alg, conf_.p, conf_.eps);

memory_tensor_t<::sycl::access_mode::read> src(src_, conf_.src_md);
memory_tensor_t<::sycl::access_mode::write> dst(dst_, conf_.dst_md);
const int id = item.get_linear_id();

const auto &dst_md = conf_.dst_md;
dims_t pos;
int l_offset = id;
for (int i = 0; i < dst_md.ndims(); i++) {
const int d = dst_md.ndims() - 1 - i;
const dim_t cur_dim = dst_md.dims()[d];
pos[d] = l_offset % cur_dim;
l_offset = l_offset / cur_dim;
}

float acc = reducer.identity();
for (off_t d0 = 0; d0 < conf_.reduce_dims[0]; d0++)
for (off_t d1 = 0; d1 < conf_.reduce_dims[1]; d1++)
for (off_t d2 = 0; d2 < conf_.reduce_dims[2]; d2++)
for (off_t d3 = 0; d3 < conf_.reduce_dims[3]; d3++)
for (off_t d4 = 0; d4 < conf_.reduce_dims[4]; d4++)
for (off_t d5 = 0; d5 < conf_.reduce_dims[5];
d5++) {
dims_t src_off = {pos[0] + d0, pos[1] + d1,
pos[2] + d2, pos[3] + d3, pos[4] + d4,
pos[5] + d5};
const float val = src.load_md(src_off);
acc = reducer.reduce(acc, val);
}

float result = reducer.finalize(acc, conf_.reduce_size);
result = conf_.post_ops.apply(result, dst.load_md(pos), po_args_, pos);
dst.store_md(result, pos);
}
};

} // namespace sycl
} // namespace generic
} // namespace gpu
} // namespace impl
} // namespace dnnl
#endif
12 changes: 12 additions & 0 deletions src/gpu/generic/sycl/sycl_primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,17 @@ struct sycl_pooling_bwd_conf_t : public sycl_pooling_base_conf_t {
xpu::sycl::md_t diff_dst_md;
};

struct sycl_simple_reduction_conf_t {
dnnl_alg_kind_t alg = dnnl_alg_kind_undef;
xpu::sycl::md_t src_md;
xpu::sycl::md_t dst_md;
float p;
float eps;
sycl_post_ops_t post_ops;
dim_t reduce_dims[xpu::sycl::md_t::max_dims];
int reduce_size = 1;
};

CHECK_SYCL_KERNEL_ARG_TYPE(sycl_binary_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_prelu_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_shuffle_conf_t);
Expand All @@ -431,6 +442,7 @@ CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_bwd_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_fwd_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_data_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_convolution_bwd_weights_conf_t);
CHECK_SYCL_KERNEL_ARG_TYPE(sycl_simple_reduction_conf_t);

} // namespace sycl
} // namespace generic
Expand Down
5 changes: 5 additions & 0 deletions src/gpu/gpu_reduction_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
#include "gpu/amd/miopen_reduction.hpp"
#endif

#ifdef GENERIC_SYCL_KERNELS_ENABLED
#include "gpu/generic/sycl/simple_reduction.hpp"
#endif

namespace dnnl {
namespace impl {
namespace gpu {
Expand All @@ -51,6 +55,7 @@ constexpr impl_list_item_t impl_list[] = REG_REDUCTION_P({
GPU_INSTANCE_INTEL(intel::ocl::reusable_ref_reduction_t)
GPU_INSTANCE_NVIDIA(nvidia::cudnn_reduction_t)
GPU_INSTANCE_AMD(amd::miopen_reduction_t)
GPU_INSTANCE_GENERIC_SYCL(generic::sycl::simple_reduction_t)
nullptr,
});
// clang-format on
Expand Down

0 comments on commit fc994fd

Please sign in to comment.