FP8 Triton matmul code silently requires contiguous tensors #2713

Open
rationalism opened this issue Jun 11, 2024 · 4 comments

Comments

@rationalism

Hello! Thank you very much for this FP8 rowwise matmul code; it's been extremely helpful. However, there is a subtle bug (or hidden requirement) when, e.g., calling this code here:

This works great, but only if the second matrix is contiguous in transposed format. For example, with M, N, K equal to (4,096, 2,048, 1,024), the second matrix must be contiguous in the shape (2,048, 1,024), i.e. (N, K). If it's not contiguous, the matmul will still finish, but the results will be numerically nonsensical.
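To make the failure mode concrete, here is a minimal pure-Python sketch, assuming the kernel indexes the second operand as a row-major (N, K) buffer without consulting its strides. The helper names are hypothetical and are not FBGEMM or Triton APIs. A transposed view has the same number of elements, so nothing goes out of bounds and no error is raised; the values are simply read in the wrong order. Copying the view into row-major order first, which is what `Tensor.contiguous()` does in PyTorch, restores the correct result.

```python
# Pure-Python emulation of strided storage; all names are hypothetical,
# not FBGEMM or Triton APIs.


def make_storage(rows, cols, fn):
    """Flat row-major buffer for a (rows, cols) matrix with entries fn(r, c)."""
    return [fn(r, c) for r in range(rows) for c in range(cols)]


def contiguous(storage, shape, strides):
    """Copy a strided (rows, cols) view into a fresh row-major buffer."""
    rows, cols = shape
    s0, s1 = strides
    return [storage[r * s0 + c * s1] for r in range(rows) for c in range(cols)]


def matmul_bt_assume_contiguous(a, b_storage, M, N, K):
    """A (M, K) @ B^T for B of shape (N, K), always indexing B with strides (K, 1).

    Like a kernel that ignores the real strides, it never raises: a
    non-contiguous B is simply read in the wrong order.
    """
    return [
        [sum(a[m * K + k] * b_storage[n * K + k] for k in range(K)) for n in range(N)]
        for m in range(M)
    ]


def b_val(n, k):
    """Logical value of B[n][k]."""
    return float(10 * n + k + 1)


M, N, K = 2, 3, 4
a = make_storage(M, K, lambda m, k: float(m + k))

# B stored contiguously as (N, K): strides (K, 1).
b_contig = make_storage(N, K, b_val)
# B as the transpose of a contiguous (K, N) buffer: view strides (1, N).
b_t_storage = make_storage(K, N, lambda k, n: b_val(n, k))

ref = matmul_bt_assume_contiguous(a, b_contig, M, N, K)
bad = matmul_bt_assume_contiguous(a, b_t_storage, M, N, K)  # finishes, wrong values
fixed = matmul_bt_assume_contiguous(a, contiguous(b_t_storage, (N, K), (1, N)), M, N, K)

print(ref[0][0], bad[0][0], fixed == ref)  # prints: 20.0 59.0 True
```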

@q10
Contributor

q10 commented Jun 11, 2024

CC @choutim

@sryap
Contributor

sryap commented Jun 20, 2024

Hello @rationalism, thank you for your questions.

These issues, triton-lang/triton#3952 and pytorch/pytorch#125437, should be related.

@rationalism
Author

@q10 @sryap Tri Dao just released a paper on Flash Attention 3, which also has to deal with contiguous-layout FP8 matmul issues. Might be helpful?

https://tridao.me/publications/flash3/flash3.pdf

jwfromm added a commit to jwfromm/FBGEMM that referenced this issue Jul 31, 2024
Summary:
This diff fixes an issue where our triton fp8 quantize functions didn't properly handle non-contiguous inputs. Specifically, they wrote to the output tensor using the same strides as the input, even though the output is always allocated as contiguous. This resulted in the output being unintentionally transposed in some cases.

As a result, non-contiguous inputs would run fine but produce silently transposed outputs. It was reported on GitHub here: pytorch#2713

Adding explicit output strides to the kernel resolves the issue.

I also found a small issue with D59248142 where scaling wouldn't be applied when the number of elements was smaller than the block size. This caused fp8_gemm_test to fail. I resolved it by extending the check for when to scale.

Reviewed By: jianyuh

Differential Revision: D60535956
@jwfromm
Contributor

jwfromm commented Jul 31, 2024

I think this issue should be resolved in #2919. The quantization kernel in triton was writing output using the same strides as the input but returning a contiguous tensor. This effectively transposed the output tensor. After the fix, it should always return a contiguous output in the proper layout.
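The mechanism described in this comment and in the commit message can be sketched as a pure-Python emulation, not the Triton kernel itself. The function name and the `out_strides` parameter are hypothetical; the row-wise scale (row max-abs divided by 448, the float8_e4m3fn maximum) is an assumed scaling scheme, and FP8 rounding is left out. Reusing the input strides to write into a contiguously allocated output fills the buffer in transposed order without any error, while passing explicit output strides produces the correct layout.

```python
# Pure-Python emulation of the quantize bug and the fix; names are hypothetical.
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def quantize_rowwise(x_storage, shape, in_strides, out_strides=None):
    """Row-wise scale a (rows, cols) strided input into a new contiguous buffer.

    out_strides=None reproduces the bug: the output is written with the
    input's strides. Passing (cols, 1) mirrors the fix of explicit output
    strides. The scaling scheme is an assumption and FP8 rounding is
    omitted; only the indexing matters here.
    """
    rows, cols = shape
    si0, si1 = in_strides
    so0, so1 = out_strides if out_strides is not None else in_strides
    out = [0.0] * (rows * cols)  # always allocated contiguous (row-major)
    scales = []
    for r in range(rows):
        row = [x_storage[r * si0 + c * si1] for c in range(cols)]
        amax = max(abs(v) for v in row)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        scales.append(scale)
        for c in range(cols):
            out[r * so0 + c * so1] = row[c] / scale
    return out, scales


x = [[1.0, -2.0, 4.0], [8.0, 3.0, -1.0]]  # logical (2, 3) input
x_contig = [x[r][c] for r in range(2) for c in range(3)]  # strides (3, 1)
x_t_storage = [x[r][c] for c in range(3) for r in range(2)]  # transposed view: strides (1, 2)

ref, _ = quantize_rowwise(x_contig, (2, 3), (3, 1), (3, 1))
buggy, _ = quantize_rowwise(x_t_storage, (2, 3), (1, 2))  # output written with (1, 2)
fixed, _ = quantize_rowwise(x_t_storage, (2, 3), (1, 2), (3, 1))

# The buggy buffer holds the (3, 2) transpose of the correct result.
print(fixed == ref, buggy == [ref[r * 3 + c] for c in range(3) for r in range(2)])  # prints: True True
```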

facebook-github-bot pushed a commit that referenced this issue Aug 1, 2024
Summary:
Pull Request resolved: #2919

fbshipit-source-id: 0c449e921e2703f2275e24028238f83fec1c0427