Add Layernorm kernel #641
Conversation
Force-pushed from b772444 to 674d526.
@rahulbatra85, I can't see anything wrong with your PR. I just have some questions and minor code-cleanup suggestions. Feel free to ignore them if you judge that appropriate.
Please drop a short line about this new kernel into the python/perf-kernels/README.md file.
@@ -128,8 +128,10 @@ jobs:
    pytest -vvv ./python/perf-kernels/flash-attention.py
    pytest -vvvv ./python/perf-kernels/softmax.py
    pytest -vvv ./python/perf-kernels/rmsnorm.py
    pytest -vvv ./python/perf-kernels/layernorm.py
What do you think about running all tests with just one pytest invocation? According to https://docs.pytest.org/en/stable/how-to/usage.html, it's possible to do something like pytest -vvvv ./python/perf-kernels. That way, we'd edit .github/workflows/amd_perf_kernel_Integration_tests.yml less often, and new tests would run by default (sketched below). Do you see any drawback?
Maybe it's worth asking @micmelesse's opinion on this.
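To illustrate, the consolidated workflow step might look roughly like this (a sketch only; the hunk position is illustrative, and settling on a single -vvvv verbosity level is an assumption):

@@ jobs:
-    pytest -vvv ./python/perf-kernels/flash-attention.py
-    pytest -vvvv ./python/perf-kernels/softmax.py
-    pytest -vvv ./python/perf-kernels/rmsnorm.py
-    pytest -vvv ./python/perf-kernels/layernorm.py
+    pytest -vvvv ./python/perf-kernels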
Yeah, that's a Michael question.
Let's wait for Michael's opinion!
That is fine. I think some of the tests are broken, but it may be worth it to see the state of things.
    y = x_hat * w + b
    # Write output
    tl.store(Y + cols, y, mask=mask)
Just an idea: we have three for loops that do masked loads. Do you foresee any benefit in peeling the last iteration of each loop, so that all iterations except the last do unmasked loads? I think Shucai and Xiaohu got some performance improvements doing this with GEMMs. I'm not sure whether the idea would be beneficial for layer norm.
ok, yeah, I didn't think of that. Will try this out
Please let me know if this helped at all.
Force-pushed from d88abbc to 13c01c4.
Force-pushed from 13c01c4 to 042aa91.
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
                     BLOCK_SIZE: tl.constexpr):
Can you add an input use_mask: tl.constexpr to the kernel? Then, in the implementation, the reads from global memory can look like this:
loop_num = tl.cdiv(n_cols, BLOCK_SIZE)
if use_mask:
    loop_num -= 1  # peel the last (partial) block off the main loop

# calculate mean (x_ptr_start is the current row's base pointer, defined earlier in the kernel)
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for b in range(0, loop_num):
    col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x_block = tl.load(x_ptr_start + col_offsets).to(tl.float32)  # unmasked load
    _mean += x_block
if use_mask:
    # peeled epilogue: only the last block needs a masked load
    col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x_block = tl.load(x_ptr_start + col_offsets, mask=col_offsets < n_cols, other=0.).to(tl.float32)
    _mean += x_block
mean = tl.sum(_mean, axis=0) / n_cols

# the same peeling applies to the variance calculation (shown here in its current, masked form)
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for b in range(0, n_cols, BLOCK_SIZE):
    col_offsets = b + tl.arange(0, BLOCK_SIZE)
    x_block = tl.load(x_ptr_start + col_offsets, mask=col_offsets < n_cols, other=0.).to(tl.float32)
    x_block = tl.where(col_offsets < n_cols, x_block - mean, 0.)
    _var += x_block * x_block
var = tl.sum(_var, axis=0) / n_cols
rstd = tl.rsqrt(var + eps)
In this way, we don't need a mask in most of the iterations, which can make loading the input more efficient.
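For completeness, here is a hypothetical sketch (not from the PR) of how the launch side could pick the new constexpr. It assumes a grid of one program per row and a fixed BLOCK_SIZE; with @triton.autotune the selected BLOCK_SIZE isn't known at launch time, so in that case one would conservatively pass use_mask=True.

# Hypothetical launch-side sketch: the tail only needs masking when
# n_cols is not a multiple of BLOCK_SIZE.
use_mask = (n_cols % BLOCK_SIZE) != 0
# one program per row (assumed grid); args follow the kernel signature above
layernorm_kernel[(n_rows, )](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps,
                             BLOCK_SIZE=BLOCK_SIZE, use_mask=use_mask)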
ok, will try this out.
I did this and the perf improves: https://github.com/ROCm/triton-internal/issues/126#issuecomment-2369175077
Force-pushed from e389075 to ccb3538.