包含以下内容:
- flash_attn_1_fwd_f32_kernel
- flash_attn_2_fwd_f16_mma_m16n8k16_kernel (ldmatrix + MMA)
- PyTorch bindings
本仓库中的 FlashAttention 实现仅用于学习 CUDA 编程；若追求最优性能，请使用 FlashAttention 官方版本：flash-attention
# Restrict the CUDA extension build to the Ada architecture only. If this
# variable is left unset, every architecture (Volta, Ampere, Ada, Hopper, ...)
# is compiled by default, which makes the build take much longer.
TORCH_CUDA_ARCH_LIST="Ada"
export TORCH_CUDA_ARCH_LIST
python3 flash_attn.py
日志如下:
----------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim
----------------------------------------------------------------------------------------------------
B=8, H=8, N=256, D=64
out_FA1f32: ['0.01037013 ', '-0.09995531 ', '0.09193697 '], time:9.288564ms
out_f32_th(naive): ['0.01037012 ', '-0.09995528 ', '0.09193695 '], time:0.086453ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.01031494 ', '-0.09997559 ', '0.09197998 '], time:0.047593ms
out_f16_th(naive): ['0.01040649 ', '-0.10003662 ', '0.09197998 '], time:0.053408ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=8, N=256, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.15332031 ', '0.15917969 ', '0.07592773 '], time:0.091217ms
out_f16_th(naive): ['0.15368652 ', '0.15905762 ', '0.07580566 '], time:0.052757ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=8, N=512, D=64
out_FA1f32: ['0.01696955 ', '-0.05399467 ', '-0.03177956 '], time:37.062004ms
out_f32_th(naive): ['0.01696953 ', '-0.05399465 ', '-0.03177955 '], time:0.471001ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.01699829 ', '-0.0539856 ', '-0.0317688 '], time:0.168507ms
out_f16_th(naive): ['0.01699829 ', '-0.0539856 ', '-0.03173828 '], time:0.132778ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=8, N=512, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.06872559 ', '-0.07714844 ', '0.04348755 '], time:0.326455ms
out_f16_th(naive): ['0.06872559 ', '-0.07720947 ', '0.04345703 '], time:0.152197ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=8, N=1024, D=64
out_FA1f32: ['-0.04256601 ', '0.0555016 ', '0.05054659 '], time:148.082373ms
out_f32_th(naive): ['-0.04256602 ', '0.05550159 ', '0.05054657 '], time:2.673364ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.0425415 ', '0.05551147 ', '0.05053711 '], time:0.633800ms
out_f16_th(naive): ['-0.0425415 ', '0.05545044 ', '0.05053711 '], time:1.276960ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=8, N=1024, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.00053024 ', '0.04940796 ', '-0.01649475 '], time:1.235073ms
out_f16_th(naive): ['-0.00051165 ', '0.04946899 ', '-0.01644897 '], time:1.371036ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=16, N=256, D=64
out_FA1f32: ['0.06706338 ', '-0.01847678 ', '-0.02532079 '], time:9.592953ms
out_f32_th(naive): ['0.0670634 ', '-0.01847675 ', '-0.02532081 '], time:0.150659ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.06719971 ', '-0.01847839 ', '-0.02529907 '], time:0.060866ms
out_f16_th(naive): ['0.06713867 ', '-0.01846313 ', '-0.0252533 '], time:0.063777ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=16, N=256, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.05142212 ', '0.03041077 ', '-0.08868408 '], time:0.132723ms
out_f16_th(naive): ['-0.05151367 ', '0.03018188 ', '-0.08911133 '], time:0.079043ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=16, N=512, D=64
out_FA1f32: ['-0.03446965 ', '0.05762016 ', '0.07836776 '], time:38.253429ms
out_f32_th(naive): ['-0.03446964 ', '0.05762014 ', '0.07836778 '], time:1.357274ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.03445435 ', '0.05758667 ', '0.07836914 '], time:0.218937ms
out_f16_th(naive): ['-0.03445435 ', '0.05758667 ', '0.07830811 '], time:0.500908ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=16, N=512, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.00230026 ', '-0.05194092 ', '0.0164032 '], time:0.493281ms
out_f16_th(naive): ['-0.00205803 ', '-0.05209351 ', '0.01664734 '], time:0.568807ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=16, N=1024, D=64
out_FA1f32: ['0.02074369 ', '-0.01090947 ', '-0.01393144 '], time:152.446897ms
out_f32_th(naive): ['0.02074368 ', '-0.01090949 ', '-0.01393143 '], time:5.296123ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.02073669 ', '-0.01097107 ', '-0.01395416 '], time:0.834603ms
out_f16_th(naive): ['0.02073669 ', '-0.01092529 ', '-0.01390839 '], time:2.576745ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=8, H=16, N=1024, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.08306885 ', '0.03659058 ', '0.04852295 '], time:1.907628ms
out_f16_th(naive): ['0.08319092 ', '0.03668213 ', '0.04858398 '], time:2.696407ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=8, N=256, D=64
out_FA1f32: ['0.09634054 ', '-0.02606717 ', '0.13369624 '], time:9.618666ms
out_f32_th(naive): ['0.09634058 ', '-0.02606717 ', '0.13369617 '], time:0.147052ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.09649658 ', '-0.02606201 ', '0.13366699 '], time:0.060964ms
out_f16_th(naive): ['0.09631348 ', '-0.02613831 ', '0.13366699 '], time:0.063334ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=8, N=256, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.0680542 ', '0.18212891 ', '0.09741211 '], time:0.132513ms
out_f16_th(naive): ['-0.0680542 ', '0.18212891 ', '0.09747314 '], time:0.079212ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=8, N=512, D=64
out_FA1f32: ['0.06110233 ', '-0.03080001 ', '0.06487844 '], time:38.171313ms
out_f32_th(naive): ['0.06110234 ', '-0.0308 ', '0.06487839 '], time:1.358862ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.06112671 ', '-0.03077698 ', '0.06488037 '], time:0.218849ms
out_f16_th(naive): ['0.06109619 ', '-0.03079224 ', '0.06488037 '], time:0.497117ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=8, N=512, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.00991058 ', '-0.18884277 ', '-0.04980469 '], time:0.493472ms
out_f16_th(naive): ['-0.0098877 ', '-0.18884277 ', '-0.04980469 '], time:0.573759ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=8, N=1024, D=64
out_FA1f32: ['-0.01831236 ', '-0.07696866 ', '-0.04614653 '], time:152.500360ms
out_f32_th(naive): ['-0.01831233 ', '-0.07696865 ', '-0.04614652 '], time:5.295737ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.01831055 ', '-0.07696533 ', '-0.04614258 '], time:0.834262ms
out_f16_th(naive): ['-0.01826477 ', '-0.0769043 ', '-0.04614258 '], time:2.576706ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=8, N=1024, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.04501343 ', '0.07751465 ', '-0.01131439 '], time:1.907537ms
out_f16_th(naive): ['0.04501343 ', '0.07745361 ', '-0.01132965 '], time:2.697947ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=16, N=256, D=64
out_FA1f32: ['0.05493443 ', '0.03093347 ', '-0.05244123 '], time:12.086096ms
out_f32_th(naive): ['0.05493441 ', '0.03093351 ', '-0.05244119 '], time:0.518868ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.05496216 ', '0.03089905 ', '-0.05227661 '], time:0.083928ms
out_f16_th(naive): ['0.05487061 ', '0.03102112 ', '-0.05239868 '], time:0.133991ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=16, N=256, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.03808594 ', '-0.19189453 ', '0.00264549 '], time:0.192747ms
out_f16_th(naive): ['-0.03778076 ', '-0.19189453 ', '0.00281334 '], time:0.178058ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=16, N=512, D=64
out_FA1f32: ['0.02739076 ', '0.01203587 ', '0.09457675 '], time:48.142586ms
out_f32_th(naive): ['0.02739077 ', '0.01203588 ', '0.09457672 '], time:2.749476ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['0.02740479 ', '0.01203918 ', '0.09454346 '], time:0.291946ms
out_f16_th(naive): ['0.02740479 ', '0.01203156 ', '0.09460449 '], time:1.350477ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=16, N=512, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.06494141 ', '-0.06427002 ', '-0.04528809 '], time:0.690589ms
out_f16_th(naive): ['-0.06500244 ', '-0.06427002 ', '-0.04519653 '], time:1.470513ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=16, N=1024, D=64
out_FA1f32: ['-0.02254915 ', '0.00821745 ', '0.09361463 '], time:196.162612ms
out_f32_th(naive): ['-0.02254917 ', '0.00821746 ', '0.09361461 '], time:10.451190ms
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.02252197 ', '0.00821686 ', '0.09368896 '], time:1.106799ms
out_f16_th(naive): ['-0.02255249 ', '0.00818634 ', '0.09368896 '], time:5.125363ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
B=16, H=16, N=1024, D=128
----------------------------------------------------------------------------------------------------
out_FA2MMAf16: ['-0.07330322 ', '-0.06152344 ', '0.00090456 '], time:3.174434ms
out_f16_th(naive): ['-0.07336426 ', '-0.06149292 ', '0.00105381 '], time:5.335908ms
----------------------------------------------------------------------------------------------------