
[MV-1195] Implement Pdist Backward #62

Open · wants to merge 7 commits into base: develop

Conversation

@anhskrttt (Collaborator) commented Nov 14, 2024

Summary of Changes

  • Add PdistBackward operation and kernel.
  • Add driver test and gtest for PdistBackward.

Additional Notes

  • ROCm's pdist does not yet support dtype=[fp16, bfp16]. This PR adds support for those two types.
  • MIOpen performance is better for contiguous inputs.
  • For input_dtype=fp16, the result may underflow or overflow when (1) input.dims[0] is large, i.e. there are many points in the distance calculation, or (2) the value of p is large.

Benchmark Results

Average improvement over ROCm

| type     | bwd  |
| -------- | ---- |
| float    | 1.37 |
| float16  | -    |
| bfloat16 | -    |

Detailed Benchmark

float32
| op_name | dtype   | size       | direction | ROCm     | MIOpen   | MIOpen vs ROCm |
| ------- | ------- | ---------- | --------- | -------- | -------- | -------------- |
| PDist   | float32 | [32 65536] | bwd       | 2062221  | 828990   | 2.49           |
| PDist   | float32 | [512 512]  | bwd       | 3415568  | 2494930  | 1.37           |
| PDist   | float32 | [1024 512] | bwd       | 19289647 | 10698600 | 1.80           |
| PDist   | float32 | [2048 512] | bwd       | 82933944 | 46841100 | 1.77           |
| PDist   | float32 | [128 144]  | bwd       | 59216    | 86755    | 0.68           |
| PDist   | float32 | [16 255]   | bwd       | 27984    | 28000    | 1.00           |
| PDist   | float32 | [128 512]  | bwd       | 170863   | 150363   | 1.14           |
| PDist   | float32 | [128 128]  | bwd       | 57615    | 78168    | 0.74           |

@anhskrttt anhskrttt marked this pull request as ready for review November 17, 2024 10:45
```cpp
{
    pdist_backward_contiguous<INPUT_TYPE>(
        input, output, grad, input_grad, p, n2, n2_squared_minus_1, N, NO, M);
}
```
Comment (Collaborator Author): add a newline at the end of the file.

Comment (Collaborator) @hieule88:
[Partial Review] I found no kernel implemented for the non-contiguous case; please add one and re-benchmark the performance. Also, why didn't you implement the forward case? I found a pdist (forward) kernel in the CL code. Before making any further modifications, please change the base branch to develop-moreh to get the most up-to-date code. I will continue reviewing after you add the non-contiguous and forward cases.

Comment on lines +48 to +55
```cpp
for(size_t i = 0; i < input_numel; i++)
{
    dinputHost[i] = 0;
}

for(int i = 0; i < N; ++i)
{
    for(int j = i + 1; j < N; ++j)
```
Comment (Collaborator): use par_for for better perf.

Comment (Collaborator): use uint64_t for i, j.

```cpp
double grad_k   = static_cast<double>(doutput[k]);
double output_k = static_cast<double>(output[k]);

for(int m = 0; m < M; ++m)
```
Comment (Collaborator): this may cause CI/CD to fail because M is size_t while m is int (comparison between different types).

Comment on lines +63 to +71
```cpp
double input_first  = static_cast<double>(input[i * M + m]);
double input_second = static_cast<double>(input[j * M + m]);
double diff         = input_first - input_second;

Tcheck res =
    static_cast<Tcheck>(miopen::solver::pdist::backward(diff, grad_k, output_k, p));

dinputHost[i * M + m] += res;
dinputHost[j * M + m] -= res;
```
Comment (Collaborator): use tensor_view instead: tv.get_tensor_view_idx({i, m}).

```cpp
{
    for(int j = i + 1; j < N; ++j)
    {
        long k = j + N * i - i * (i + 1) / 2 - i - 1;
```
Comment (Collaborator): can k be negative? If not, use uint64_t.

Comment (Collaborator): if it can be negative, check k before using it as an index.

Comment on lines +28 to +32
```cpp
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>
```
Comment (Collaborator): are those libraries necessary?

Comment (Collaborator): follow this order (AMD's convention):

```cpp
#include "InputFlags.hpp"
#include "driver.hpp"
#include "mloMarginRakningLossHost.hpp"
#include "random.hpp"
#include "tensor_driver.hpp"
#include "timer.hpp"

#include <../test/tensor_holder.hpp>
#include <../test/verify.hpp>

#include <miopen/env.hpp>
#include <miopen/handle.hpp>
#include <miopen/miopen.h>
#include <miopen/tensor.hpp>
#include
```

Comment on lines +35 to +40
```cpp
#include "miopen/buffer_info.hpp"
#include "miopen/errors.hpp"
#include "miopen/execution_context.hpp"
#include "miopen/invoke_params.hpp"
#include "miopen/tensor.hpp"
#include "miopen/tensor_view_utils.hpp"
```
Comment (Collaborator): same as above: follow AMD's include-order convention.

```cpp
#include <miopen/pdist.hpp>
#include <miopen/pdist/solvers.hpp>
#include <miopen/pdist/invoke_params.hpp>
#include "miopen/pdist/problem_description.hpp"
```
Comment (Collaborator): same as above: follow AMD's include-order convention.

Comment on lines +71 to +74
```cpp
if(!problem.IsAllContiguous())
{
    return false;
}
```
Comment (Collaborator): I found no kernel implemented for the non-contiguous case; please add one and re-benchmark the performance. Also, I see you treat problem.IsAllContiguous() == false as "not supported", which is not a clear meaning. If you benchmark a correct non-contiguous implementation, this check should not THROW; it should just return false, and it belongs in the IsOverRocm function.

Comment on lines +99 to +101
```cpp
auto input_dtype  = miopen::GetDataType(problem.GetdInputDesc().GetType());
auto output_dtype = miopen::GetDataType(problem.GetdOutputDesc().GetType());
auto dinput_dtype = miopen::GetDataType(problem.GetdInputDesc().GetType());
```
Comment (Collaborator): you have IsSameType, so this is redundant.

Comment on lines +172 to +173
```cpp
{"INPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype},
{"OUTPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : output_dtype},
```
Comment (Collaborator): redundant.
