performance benchmark #5

Open
cliangyu opened this issue May 19, 2024 · 3 comments
Comments

@cliangyu

Hi, have you benchmarked fa2 with jax? How much speedup can you get?

@nshepperd
Owner

I did a small comparison benchmark with fa2, naive mha and the pallas mha included with jax 0.4.28. On my desktop with a 3090 and float16:

B     T    H    C    TFlop/s (flash)    TFlop/s (naive)    TFlop/s (pallas)
---  ----  ---  ---  -----------------  -----------------  ------------------
 32  1024    4   32            53.8639            8.24277             51.5742
 32  1024    4   64            57.5379           14.9229              53.5826
 32  1024    4  128            58.8711           26.7008              54.7979
 32  1024    8   32            61.1098            8.6376              51.487
 32  1024    8   64            64.0922           15.6025              52.9087
 32  1024    8  128            65.6265           27.115               51.4726
 32  1024   16   32            63.9256            8.74247             52.2687
 32  1024   16   64            63.1579           15.7274              55.8753
 32  1024   16  128            65.6789           27.7762              57.1428
 32  1024   32   32            63.1234            8.63395             52.5392
 32  1024   32   64            66.5244           15.4328              56.243
 32  1024   32  128            65.5683           28.3536              59.0203
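
For anyone reading the TFlop/s columns, here is a minimal sketch of how such a figure can be derived from a measured wall-clock time. It is not the benchmark code used above. It assumes that B is the batch size, T the sequence length, H the number of heads and C the per-head dimension, and that only the non-causal forward pass is counted (two matmuls, QK^T and PV, each costing 2*T*T*C FLOPs per head). The thread does not state which counting convention was used, and the helper names are hypothetical.

```python
def attention_flops(B, T, H, C):
    """FLOPs for one non-causal attention forward pass (assumed convention).

    QK^T costs 2*T*T*C per head and PV costs another 2*T*T*C per head;
    softmax and scaling are ignored, as is common in attention benchmarks.
    """
    return 4 * B * H * T * T * C


def tflops_per_s(B, T, H, C, seconds):
    """Convert a measured time per call (in seconds) to TFlop/s."""
    return attention_flops(B, T, H, C) / seconds / 1e12


def implied_seconds(B, T, H, C, tflops):
    """Invert the conversion: time per call implied by a TFlop/s figure."""
    return attention_flops(B, T, H, C) / (tflops * 1e12)


if __name__ == "__main__":
    # First row of the table: B=32, T=1024, H=4, C=32 at 53.8639 TFlop/s.
    ms = implied_seconds(32, 1024, 4, 32, 53.8639) * 1e3
    print(f"{ms:.3f} ms per call")
```

Under these assumptions, the first row of the table corresponds to roughly 0.32 ms per call. If the original numbers include the backward pass or use causal masking, the FLOP count and the implied times change accordingly.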

@imoneoi

imoneoi commented Jul 25, 2024

Hi @nshepperd, could you please share the benchmarking code?

@Xynonners

any VRAM figures as well?
