Skip to content

Latest commit

 

History

History

flash-attn

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

FlashAttention

0x00 说明

包含以下内容:

  • flash_attn_1_fwd_f32_kernel
  • flash_attn_2_fwd_f16_mma_m16n8k16_kernel (ldmatrix + MMA)
  • PyTorch bindings

本仓库中的 FlashAttention 实现仅用于学习 CUDA 编程;如需最优性能,请使用 FlashAttention 官方版本:flash-attention。

运行测试

# 仅针对 Ada 架构进行编译和测试;若不指定架构,默认会编译所有架构(Volta、Ampere、Ada、Hopper 等),耗时较长
export TORCH_CUDA_ARCH_LIST=Ada 
python3 flash_attn.py

日志如下:

----------------------------------------------------------------------------------------------------
                         B: batch_size, H: n_head, N: seq_len, D: head_dim
----------------------------------------------------------------------------------------------------
                                        B=8, H=8, N=256, D=64
          out_FA1f32: ['0.01037013  ', '-0.09995531 ', '0.09193697  '], time:9.288564ms
   out_f32_th(naive): ['0.01037012  ', '-0.09995528 ', '0.09193695  '], time:0.086453ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.01031494  ', '-0.09997559 ', '0.09197998  '], time:0.047593ms
   out_f16_th(naive): ['0.01040649  ', '-0.10003662 ', '0.09197998  '], time:0.053408ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=8, N=256, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.15332031  ', '0.15917969  ', '0.07592773  '], time:0.091217ms
   out_f16_th(naive): ['0.15368652  ', '0.15905762  ', '0.07580566  '], time:0.052757ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=8, N=512, D=64
          out_FA1f32: ['0.01696955  ', '-0.05399467 ', '-0.03177956 '], time:37.062004ms
   out_f32_th(naive): ['0.01696953  ', '-0.05399465 ', '-0.03177955 '], time:0.471001ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.01699829  ', '-0.0539856  ', '-0.0317688  '], time:0.168507ms
   out_f16_th(naive): ['0.01699829  ', '-0.0539856  ', '-0.03173828 '], time:0.132778ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=8, N=512, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.06872559  ', '-0.07714844 ', '0.04348755  '], time:0.326455ms
   out_f16_th(naive): ['0.06872559  ', '-0.07720947 ', '0.04345703  '], time:0.152197ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=8, N=1024, D=64
          out_FA1f32: ['-0.04256601 ', '0.0555016   ', '0.05054659  '], time:148.082373ms
   out_f32_th(naive): ['-0.04256602 ', '0.05550159  ', '0.05054657  '], time:2.673364ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.0425415  ', '0.05551147  ', '0.05053711  '], time:0.633800ms
   out_f16_th(naive): ['-0.0425415  ', '0.05545044  ', '0.05053711  '], time:1.276960ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=8, N=1024, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.00053024 ', '0.04940796  ', '-0.01649475 '], time:1.235073ms
   out_f16_th(naive): ['-0.00051165 ', '0.04946899  ', '-0.01644897 '], time:1.371036ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=16, N=256, D=64
          out_FA1f32: ['0.06706338  ', '-0.01847678 ', '-0.02532079 '], time:9.592953ms
   out_f32_th(naive): ['0.0670634   ', '-0.01847675 ', '-0.02532081 '], time:0.150659ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.06719971  ', '-0.01847839 ', '-0.02529907 '], time:0.060866ms
   out_f16_th(naive): ['0.06713867  ', '-0.01846313 ', '-0.0252533  '], time:0.063777ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=16, N=256, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.05142212 ', '0.03041077  ', '-0.08868408 '], time:0.132723ms
   out_f16_th(naive): ['-0.05151367 ', '0.03018188  ', '-0.08911133 '], time:0.079043ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=16, N=512, D=64
          out_FA1f32: ['-0.03446965 ', '0.05762016  ', '0.07836776  '], time:38.253429ms
   out_f32_th(naive): ['-0.03446964 ', '0.05762014  ', '0.07836778  '], time:1.357274ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.03445435 ', '0.05758667  ', '0.07836914  '], time:0.218937ms
   out_f16_th(naive): ['-0.03445435 ', '0.05758667  ', '0.07830811  '], time:0.500908ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=16, N=512, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.00230026 ', '-0.05194092 ', '0.0164032   '], time:0.493281ms
   out_f16_th(naive): ['-0.00205803 ', '-0.05209351 ', '0.01664734  '], time:0.568807ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=16, N=1024, D=64
          out_FA1f32: ['0.02074369  ', '-0.01090947 ', '-0.01393144 '], time:152.446897ms
   out_f32_th(naive): ['0.02074368  ', '-0.01090949 ', '-0.01393143 '], time:5.296123ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.02073669  ', '-0.01097107 ', '-0.01395416 '], time:0.834603ms
   out_f16_th(naive): ['0.02073669  ', '-0.01092529 ', '-0.01390839 '], time:2.576745ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=8, H=16, N=1024, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.08306885  ', '0.03659058  ', '0.04852295  '], time:1.907628ms
   out_f16_th(naive): ['0.08319092  ', '0.03668213  ', '0.04858398  '], time:2.696407ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=8, N=256, D=64
          out_FA1f32: ['0.09634054  ', '-0.02606717 ', '0.13369624  '], time:9.618666ms
   out_f32_th(naive): ['0.09634058  ', '-0.02606717 ', '0.13369617  '], time:0.147052ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.09649658  ', '-0.02606201 ', '0.13366699  '], time:0.060964ms
   out_f16_th(naive): ['0.09631348  ', '-0.02613831 ', '0.13366699  '], time:0.063334ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=8, N=256, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.0680542  ', '0.18212891  ', '0.09741211  '], time:0.132513ms
   out_f16_th(naive): ['-0.0680542  ', '0.18212891  ', '0.09747314  '], time:0.079212ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=8, N=512, D=64
          out_FA1f32: ['0.06110233  ', '-0.03080001 ', '0.06487844  '], time:38.171313ms
   out_f32_th(naive): ['0.06110234  ', '-0.0308     ', '0.06487839  '], time:1.358862ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.06112671  ', '-0.03077698 ', '0.06488037  '], time:0.218849ms
   out_f16_th(naive): ['0.06109619  ', '-0.03079224 ', '0.06488037  '], time:0.497117ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=8, N=512, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.00991058 ', '-0.18884277 ', '-0.04980469 '], time:0.493472ms
   out_f16_th(naive): ['-0.0098877  ', '-0.18884277 ', '-0.04980469 '], time:0.573759ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=8, N=1024, D=64
          out_FA1f32: ['-0.01831236 ', '-0.07696866 ', '-0.04614653 '], time:152.500360ms
   out_f32_th(naive): ['-0.01831233 ', '-0.07696865 ', '-0.04614652 '], time:5.295737ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.01831055 ', '-0.07696533 ', '-0.04614258 '], time:0.834262ms
   out_f16_th(naive): ['-0.01826477 ', '-0.0769043  ', '-0.04614258 '], time:2.576706ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=8, N=1024, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.04501343  ', '0.07751465  ', '-0.01131439 '], time:1.907537ms
   out_f16_th(naive): ['0.04501343  ', '0.07745361  ', '-0.01132965 '], time:2.697947ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=16, N=256, D=64
          out_FA1f32: ['0.05493443  ', '0.03093347  ', '-0.05244123 '], time:12.086096ms
   out_f32_th(naive): ['0.05493441  ', '0.03093351  ', '-0.05244119 '], time:0.518868ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.05496216  ', '0.03089905  ', '-0.05227661 '], time:0.083928ms
   out_f16_th(naive): ['0.05487061  ', '0.03102112  ', '-0.05239868 '], time:0.133991ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=16, N=256, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.03808594 ', '-0.19189453 ', '0.00264549  '], time:0.192747ms
   out_f16_th(naive): ['-0.03778076 ', '-0.19189453 ', '0.00281334  '], time:0.178058ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=16, N=512, D=64
          out_FA1f32: ['0.02739076  ', '0.01203587  ', '0.09457675  '], time:48.142586ms
   out_f32_th(naive): ['0.02739077  ', '0.01203588  ', '0.09457672  '], time:2.749476ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['0.02740479  ', '0.01203918  ', '0.09454346  '], time:0.291946ms
   out_f16_th(naive): ['0.02740479  ', '0.01203156  ', '0.09460449  '], time:1.350477ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=16, N=512, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.06494141 ', '-0.06427002 ', '-0.04528809 '], time:0.690589ms
   out_f16_th(naive): ['-0.06500244 ', '-0.06427002 ', '-0.04519653 '], time:1.470513ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=16, N=1024, D=64
          out_FA1f32: ['-0.02254915 ', '0.00821745  ', '0.09361463  '], time:196.162612ms
   out_f32_th(naive): ['-0.02254917 ', '0.00821746  ', '0.09361461  '], time:10.451190ms
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.02252197 ', '0.00821686  ', '0.09368896  '], time:1.106799ms
   out_f16_th(naive): ['-0.02255249 ', '0.00818634  ', '0.09368896  '], time:5.125363ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                        B=16, H=16, N=1024, D=128
----------------------------------------------------------------------------------------------------
       out_FA2MMAf16: ['-0.07330322 ', '-0.06152344 ', '0.00090456  '], time:3.174434ms
   out_f16_th(naive): ['-0.07336426 ', '-0.06149292 ', '0.00105381  '], time:5.335908ms
----------------------------------------------------------------------------------------------------