
Add Layernorm kernel #641

Merged: 2 commits merged into main_perf from main_perf-layernorm on Sep 24, 2024

Conversation

rahulbatra85 (Author)

No description provided.

brunomazzottiamd (Collaborator) left a comment

@rahulbatra85, I can't see anything wrong with your PR. I just have some questions and minor code cleanup suggestions. Feel free to ignore them if you see fit.

Please add a short note about this new kernel to the python/perf-kernels/README.md file.

@@ -128,8 +128,10 @@ jobs:
pytest -vvv ./python/perf-kernels/flash-attention.py
pytest -vvvv ./python/perf-kernels/softmax.py
pytest -vvv ./python/perf-kernels/rmsnorm.py
pytest -vvv ./python/perf-kernels/layernorm.py
Collaborator

What do you think about running all tests with just one pytest invocation? According to https://docs.pytest.org/en/stable/how-to/usage.html, it's possible to do something like pytest -vvvv ./python/perf-kernels. That way, we'd edit .github/workflows/amd_perf_kernel_Integration_tests.yml less often, and new tests would run by default. Do you see any drawback?

Maybe it's worth asking @micmelesse's opinion on this.
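For illustration only, a hypothetical version of that CI step with a single invocation (assuming pytest collects every test module under the directory):

    # Hypothetical consolidated CI step; pytest discovers all perf-kernel tests here.
    pytest -vvv ./python/perf-kernels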

rahulbatra85 (Author)

Yeah, that's a Michael question.

Collaborator

Let's wait for Michael's opinion!

micmelesse (Collaborator), Sep 18, 2024

That is fine. I think some of the tests are broken, but it may be worth it to see the state of things.

python/perf-kernels/layernorm.py
y = x_hat * w + b
# Write output
tl.store(Y + cols, y, mask=mask)

Collaborator

Just an idea: we have three for loops that do masked loads. Do you foresee any benefit in peeling the last iteration of each loop, so that all iterations except the last do unmasked loads? I think Shucai and Xiaohu got some performance improvements doing this with GEMMs. I'm not sure whether the idea would be beneficial for layer norm.
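For illustration only, a minimal sketch of what peeling could look like for the normalization/store loop quoted above. The pointer and variable names (X, Y, W, B, mean, rstd) are assumptions based on the quoted snippet, not the PR's actual code:

    # Hypothetical peeled normalization/store loop (illustrative sketch only).
    num_full_blocks = n_cols // BLOCK_SIZE
    for blk in range(0, num_full_blocks):
        cols = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols).to(tl.float32)   # unmasked load
        w = tl.load(W + cols).to(tl.float32)
        b = tl.load(B + cols).to(tl.float32)
        y = (x - mean) * rstd * w + b
        tl.store(Y + cols, y)                  # unmasked store
    # Peeled tail iteration: only the last, possibly partial, block pays for the mask.
    cols = num_full_blocks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.).to(tl.float32)
    b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
    y = (x - mean) * rstd * w + b
    tl.store(Y + cols, y, mask=mask)

If n_cols is an exact multiple of BLOCK_SIZE, the tail mask is all false and the final store writes nothing, so the sketch stays correct either way.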

rahulbatra85 (Author)

OK, yeah, I didn't think of that. I'll try this out.

Collaborator

Please let me know if this helped at all.

python/perf-kernels/layernorm.py

@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
                     BLOCK_SIZE: tl.constexpr):
scxiao, Sep 23, 2024

Can you add an input use_mask: tl.constexpr to the kernel? Then, in the implementation, the reads from global memory can look like this:

    loop_num = tl.cdiv(n_cols, BLOCK_SIZE)
    if use_mask:
        loop_num -= 1
    #calculate mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for b in range(0, loop_num):
        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32)
        _mean += x_block
    if use_mask:
        col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets, mask=col_offsets < n_cols, other=0.).to(tl.float32)
        _mean += x_block
    mean = tl.sum(_mean, axis=0) / n_cols

    #same for the variance calculation.
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for b in range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets, mask=col_offsets < n_cols, other=0.).to(tl.float32)
        x_block = tl.where(col_offsets < n_cols, x_block - mean, 0.)
        _var += x_block * x_block
    var = tl.sum(_var, axis=0) / n_cols
    rstd = tl.rsqrt(var + eps)

In this way, we do not need a mask in most of the iterations, which makes the input loads more efficient.
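For completeness, a hypothetical sketch of the same peeling applied to the variance loop from the snippet above (illustrative only; it reuses the snippet's loop_num, use_mask, and x_ptr_start names):

    # Hypothetical peeled variance loop, mirroring the mean computation above.
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for b in range(0, loop_num):
        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32)  # unmasked load
        x_block = x_block - mean
        _var += x_block * x_block
    if use_mask:
        # Peeled last iteration: only the tail block needs a mask.
        col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets, mask=col_offsets < n_cols, other=0.).to(tl.float32)
        x_block = tl.where(col_offsets < n_cols, x_block - mean, 0.)
        _var += x_block * x_block
    var = tl.sum(_var, axis=0) / n_cols
    rstd = tl.rsqrt(var + eps)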

rahulbatra85 (Author)

OK, I'll try this out.


rahulbatra85 merged commit e13fc4c into main_perf on Sep 24, 2024
4 checks passed
rahulbatra85 deleted the main_perf-layernorm branch on September 24, 2024 at 18:35