Skip to content

Commit

Permalink
Implement PReLU backward (#3152)
Browse files Browse the repository at this point in the history
  • Loading branch information
long10024070 authored Aug 21, 2024
1 parent 39936d8 commit 6feaec9
Show file tree
Hide file tree
Showing 30 changed files with 2,532 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ The MIOpen API library is structured as follows:
* :doc:`Getitem <../doxygen/html/group__getitem>` (experimental)
* :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental)
* :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental)
* :doc:`ReLU <../doxygen/html/group___re_l_u>` (experimental)
1 change: 1 addition & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ add_executable(MIOpenDriver
dm_layernorm.cpp
dm_lrn.cpp
dm_pool.cpp
dm_prelu.cpp
dm_reduce.cpp
dm_reduceextreme.cpp
dm_reducecalculation.cpp
Expand Down
40 changes: 40 additions & 0 deletions driver/dm_prelu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "registry_driver_maker.hpp"
#include "prelu_driver.hpp"

static Driver* makeDriver(const std::string& base_arg)
{
if(base_arg == "prelu")
return new PReLUDriver<float, float>();
if(base_arg == "prelufp16")
return new PReLUDriver<float16, float>();
if(base_arg == "prelubfp16")
return new PReLUDriver<bfloat16, float>();
return nullptr;
}

REGISTER_DRIVER_MAKER(makeDriver);
6 changes: 4 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
"groupnorm[bfp16|fp16], cat[bfp16|fp16], addlayernorm[bfp16|fp16], "
"t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], "
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16]\n");
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], "
"prelu[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand Down Expand Up @@ -207,7 +208,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "transformersadamwfp16" && arg != "transformersampadamw" && arg != "getitem" &&
arg != "getitemfp16" && arg != "getitembfp16" && arg != "reducecalculation" &&
arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "rope" &&
arg != "ropefp16" && arg != "ropebfp16" && arg != "--version")
arg != "ropefp16" && arg != "ropebfp16" && arg != "prelu" && arg != "prelufp16" &&
arg != "prelubfp16" && arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Usage();
Expand Down
104 changes: 104 additions & 0 deletions driver/mloPReLUHost.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#pragma once

#include <../test/ford.hpp>

#include <miopen/tensor.hpp>
#include <miopen/tensor_view_utils.hpp>
#include <miopen/prelu/utils.hpp>

template <typename Tgpu, typename Tcheck>
int32_t mloPReLUBackwardRunHost(const miopenTensorDescriptor_t inputDesc,
const miopenTensorDescriptor_t weightDesc,
const miopenTensorDescriptor_t doutputDesc,
const miopenTensorDescriptor_t dinputDesc,
const Tgpu* input,
const Tgpu* weight,
const Tgpu* doutput,
Tcheck* dinput_host,
Tcheck* dweight_host)
{
auto input_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(inputDesc));
auto doutput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(doutputDesc));
auto dinput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(dinputDesc));

auto input_sz = miopen::deref(inputDesc).GetElementSize();
auto weight_sz = miopen::deref(weightDesc).GetElementSize();
auto weight_grad_collector = std::vector<float>(input_sz);

par_ford(input_sz)([&](int gid) {
auto tensor_layout = tensor_layout_t<5>(input_tv, gid);
float input_v = static_cast<float>(input[input_tv.get_tensor_view_idx(tensor_layout)]);
float grad_v = static_cast<float>(doutput[doutput_tv.get_tensor_view_idx(tensor_layout)]);

if(dinput_host)
{
float weight_v;
if(weight_sz == 1)
weight_v = static_cast<float>(weight[0]);
else
weight_v = static_cast<float>(weight[tensor_layout.layout[1]]);
float input_grad_v = input_v > 0 ? grad_v : weight_v * grad_v;
dinput_host[dinput_tv.get_tensor_view_idx(tensor_layout)] =
static_cast<Tcheck>(input_grad_v);
}
if(dweight_host)
{
weight_grad_collector[gid] = input_v > 0 ? 0 : input_v * grad_v;
}
});

if(dweight_host)
{
if(weight_sz == 1)
{
double sum = 0;
for(int i = 0; i < input_sz; ++i)
sum += static_cast<double>(weight_grad_collector[i]);
dweight_host[0] = static_cast<Tcheck>(sum);
}
else
{
size_t inner_size = std::accumulate(
&input_tv.size[2], &input_tv.size[4], 1ul, std::multiplies<size_t>());
size_t outer_size = inner_size * input_tv.size[1];
par_ford(weight_sz)([&](int i) {
double sum = 0;
ford(input_tv.size[0])([&](int j) {
ford(inner_size)([&](int k) {
sum += static_cast<double>(
weight_grad_collector[j * outer_size + i * inner_size + k]);
});
});
dweight_host[i] = static_cast<Tcheck>(sum);
});
}
}

return miopenStatusSuccess;
}
Loading

0 comments on commit 6feaec9

Please sign in to comment.