Implement SigmoidFocalLoss operation #3143

Open: wants to merge 26 commits into base: develop

BuiChiTrung (Collaborator) commented Jul 24, 2024

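This PR implements the torchvision.ops.sigmoid_focal_loss operation. There are no constraints on its use: MIOpen is faster than ROCm in nearly all benchmarked cases (the detailed tables below show a few non-contiguous [25 1000 1000] float16 and bfloat16 cases with a ratio slightly below 1).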

  • Added SigmoidFocalLoss operation with forward and backward kernels.
  • Added driver test and gtest for 4 kernels.
  • Compared with ROCm.
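For reference, the per-element computation behind this operation can be sketched on the host as follows. This is a minimal sketch based on the publicly documented torchvision formula, not the kernel code from this PR. The function name `sigmoid_focal_loss_elem` is hypothetical, and treating a negative `alpha` as "no class weighting" follows the torchvision convention.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical helper (not part of the PR): sigmoid focal loss for a single
// logit x and target t in [0, 1], following the torchvision reference formula.
inline float sigmoid_focal_loss_elem(float x, float t, float alpha, float gamma)
{
    assert(gamma >= 0.0f); // the focusing parameter is expected to be non-negative

    // Sigmoid probability of the positive class.
    const float p = 1.0f / (1.0f + std::exp(-x));

    // Binary cross-entropy with logits, written in its numerically stable form.
    const float ce = std::max(x, 0.0f) - x * t + std::log1p(std::exp(-std::fabs(x)));

    // Probability assigned to the true class.
    const float p_t = p * t + (1.0f - p) * (1.0f - t);

    // The modulating factor down-weights well-classified elements.
    float loss = ce * std::pow(1.0f - p_t, gamma);

    // Optional class-balancing weight; a negative alpha disables it.
    if(alpha >= 0.0f)
    {
        const float alpha_t = alpha * t + (1.0f - alpha) * (1.0f - t);
        loss = alpha_t * loss;
    }
    return loss;
}
```

With reduction `none` this value would be written per element; with `sum` or `mean` the per-element values are accumulated, and for `mean` divided by the element count.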

Average improvement over ROCm (improvement = ROCm / MIOpen, so values above 1 favor MIOpen)

Reduced kernels:

type fwd bwd
float32 2.74 4.94
float16 2.41 4.40
bfloat16 2.58 4.70

Unreduced kernels:

type fwd bwd
float32 5.09 3.86
float16 4.83 3.46
bfloat16 5.10 3.92

Detailed benchmark

Float32
dtype size is_contiguous reduction direction ROCm MIOpen improvement
float32 [20 30] TRUE none fwd 81551 13582 6.004343985
float32 [20 30] TRUE none bwd 77295 18933 4.08255427
float32 [20 30] TRUE sum fwd 91327 34649 2.635775924
float32 [20 30] TRUE sum bwd 77695 17742 4.379156803
float32 [20 30] TRUE mean fwd 90847 35912 2.529711517
float32 [20 30] TRUE mean bwd 80416 19164 4.196201211
float32 [20 30] FALSE none fwd 75296 13211 5.699492847
float32 [20 30] FALSE none bwd 75040 18883 3.973944818
float32 [20 30] FALSE sum fwd 88623 35545 2.493262062
float32 [20 30] FALSE sum bwd 75359 18208 4.138785149
float32 [20 30] FALSE mean fwd 87808 35313 2.486563022
float32 [20 30] FALSE mean bwd 135519 19630 6.903667855
float32 [5 10 10] TRUE none fwd 76575 12784 5.989909262
float32 [5 10 10] TRUE none bwd 77695 18225 4.263100137
float32 [5 10 10] TRUE sum fwd 83167 35562 2.338647995
float32 [5 10 10] TRUE sum bwd 77231 17141 4.505629777
float32 [5 10 10] TRUE mean fwd 88910 37323 2.38217721
float32 [5 10 10] TRUE mean bwd 83183 19292 4.311787269
float32 [5 10 10] FALSE none fwd 80991 13585 5.961796099
float32 [5 10 10] FALSE none bwd 90192 18723 4.817176734
float32 [5 10 10] FALSE sum fwd 95487 37251 2.563340581
float32 [5 10 10] FALSE sum bwd 81727 17319 4.718921416
float32 [5 10 10] FALSE mean fwd 85215 36825 2.314052953
float32 [5 10 10] FALSE mean bwd 144479 17229 8.385803007
float32 [2 5 10 10] TRUE none fwd 87183 13549 6.434644623
float32 [2 5 10 10] TRUE none bwd 80527 19488 4.132132594
float32 [2 5 10 10] TRUE sum fwd 87247 34940 2.497052089
float32 [2 5 10 10] TRUE sum bwd 75136 19790 3.796664982
float32 [2 5 10 10] TRUE mean fwd 84383 37340 2.259855383
float32 [2 5 10 10] TRUE mean bwd 75648 20217 3.741801454
float32 [2 5 10 10] FALSE none fwd 84047 14011 5.998643923
float32 [2 5 10 10] FALSE none bwd 79391 20003 3.968954657
float32 [2 5 10 10] FALSE sum fwd 94831 40594 2.33608415
float32 [2 5 10 10] FALSE sum bwd 79824 20412 3.9106408
float32 [2 5 10 10] FALSE mean fwd 90783 34122 2.660541586
float32 [2 5 10 10] FALSE mean bwd 150190 20430 7.351443955
float32 [25 300] TRUE none fwd 79807 9566 8.3427765
float32 [25 300] TRUE none bwd 79967 12215 6.546623005
float32 [25 300] TRUE sum fwd 86399 35882 2.407864668
float32 [25 300] TRUE sum bwd 80415 12535 6.415237335
float32 [25 300] TRUE mean fwd 88624 34495 2.56918394
float32 [25 300] TRUE mean bwd 81519 13246 6.154235241
float32 [25 300] FALSE none fwd 79567 9637 8.256407596
float32 [25 300] FALSE none bwd 79184 12286 6.445059417
float32 [25 300] FALSE sum fwd 85311 32699 2.608978868
float32 [25 300] FALSE sum bwd 79583 12962 6.139716093
float32 [25 300] FALSE mean fwd 92607 31614 2.929303473
float32 [25 300] FALSE mean bwd 136607 13229 10.32632852
float32 [25 100 100] TRUE none fwd 106431 16411 6.485345195
float32 [25 100 100] TRUE none bwd 122607 24288 5.048048419
float32 [25 100 100] TRUE sum fwd 121695 41287 2.947537966
float32 [25 100 100] TRUE sum bwd 119615 24235 4.935630287
float32 [25 100 100] TRUE mean fwd 122863 43492 2.824956314
float32 [25 100 100] TRUE mean bwd 123439 24484 5.041619017
float32 [25 100 100] FALSE none fwd 106975 23168 4.617360152
float32 [25 100 100] FALSE none bwd 120671 28591 4.220593893
float32 [25 100 100] FALSE sum fwd 123199 46870 2.628525709
float32 [25 100 100] FALSE sum bwd 118879 27951 4.253121534
float32 [25 100 100] FALSE mean fwd 126463 48292 2.618715315
float32 [25 100 100] FALSE mean bwd 231566 28218 8.206322206
float32 [100 20 20 20] TRUE none fwd 200398 41411 4.839245611
float32 [100 20 20 20] TRUE none bwd 241470 65771 3.671374922
float32 [100 20 20 20] TRUE sum fwd 214286 69522 3.082276114
float32 [100 20 20 20] TRUE sum bwd 238925 65931 3.623864343
float32 [100 20 20 20] TRUE mean fwd 214414 69736 3.074652977
float32 [100 20 20 20] TRUE mean bwd 247566 65700 3.768127854
float32 [100 20 20 20] FALSE none fwd 200734 51990 3.861011733
float32 [100 20 20 20] FALSE none bwd 240942 69647 3.459474206
float32 [100 20 20 20] FALSE sum fwd 216590 83871 2.582418238
float32 [100 20 20 20] FALSE sum bwd 241663 69238 3.490323233
float32 [100 20 20 20] FALSE mean fwd 217742 81684 2.665662798
float32 [100 20 20 20] FALSE mean bwd 549323 69131 7.946116793
float32 [100 10 10 10 10] TRUE none fwd 249582 50337 4.958221587
float32 [100 10 10 10 10] TRUE none bwd 291789 80564 3.621828608
float32 [100 10 10 10 10] TRUE sum fwd 265982 78181 3.402130953
float32 [100 10 10 10 10] TRUE sum bwd 286189 79888 3.582377829
float32 [100 10 10 10 10] TRUE mean fwd 265854 79052 3.363026868
float32 [100 10 10 10 10] TRUE mean bwd 293070 80101 3.658755821
float32 [100 10 10 10 10] FALSE none fwd 248734 62356 3.988934505
float32 [100 10 10 10 10] FALSE none bwd 290846 84724 3.43286436
float32 [100 10 10 10 10] FALSE sum fwd 263806 96121 2.744519928
float32 [100 10 10 10 10] FALSE sum bwd 283677 83924 3.380165388
float32 [100 10 10 10 10] FALSE mean fwd 267854 94539 2.833264579
float32 [100 10 10 10 10] FALSE mean bwd 636091 83977 7.574585899
float32 [2000 3000] TRUE none fwd 1285317 278904 4.608456673
float32 [2000 3000] TRUE none bwd 1456660 457011 3.187363105
float32 [2000 3000] TRUE sum fwd 1331780 349973 3.805379272
float32 [2000 3000] TRUE sum bwd 1433027 452334 3.168072707
float32 [2000 3000] TRUE mean fwd 1334900 347252 3.844182323
float32 [2000 3000] TRUE mean bwd 1487507 451605 3.293823142
float32 [2000 3000] FALSE none fwd 1304789 328934 3.966719767
float32 [2000 3000] FALSE none bwd 1452675 469327 3.095229978
float32 [2000 3000] FALSE sum fwd 1336884 401940 3.326078519
float32 [2000 3000] FALSE sum bwd 1421475 463442 3.067212294
float32 [2000 3000] FALSE mean fwd 1338228 399895 3.346448443
float32 [2000 3000] FALSE mean bwd 3341635 462854 7.219630812
float32 [25 1000 1000] TRUE none fwd 5138451 1145640 4.485223107
float32 [25 1000 1000] TRUE none bwd 5819901 1889040 3.080877589
float32 [25 1000 1000] TRUE sum fwd 5243106 1367510 3.834053133
float32 [25 1000 1000] TRUE sum bwd 5678942 4201130 1.351765358
float32 [25 1000 1000] TRUE mean fwd 5244754 1366410 3.838345738
float32 [25 1000 1000] TRUE mean bwd 5918332 1862170 3.178191035
float32 [25 1000 1000] FALSE none fwd 5152050 4529510 1.137440915
float32 [25 1000 1000] FALSE none bwd 5811853 4624380 1.256785342
float32 [25 1000 1000] FALSE sum fwd 5250178 4761350 1.102665841
float32 [25 1000 1000] FALSE sum bwd 5668574 4549310 1.246029398
float32 [25 1000 1000] FALSE mean fwd 5241474 4763050 1.100444883
float32 [25 1000 1000] FALSE mean bwd 35102137 4549250 7.716027257
float32 [10 100 100 100] TRUE none fwd 2106141 460996 4.568675216
float32 [10 100 100 100] TRUE none bwd 2372939 758370 3.128999037
float32 [10 100 100 100] TRUE sum fwd 2164653 561128 3.857681313
float32 [10 100 100 100] TRUE sum bwd 2299532 747898 3.074659914
float32 [10 100 100 100] TRUE mean fwd 2161885 562141 3.84580559
float32 [10 100 100 100] TRUE mean bwd 2419099 748982 3.229849315
float32 [10 100 100 100] FALSE none fwd 2117037 1324770 1.598041169
float32 [10 100 100 100] FALSE none bwd 2369579 1379190 1.718094679
float32 [10 100 100 100] FALSE sum fwd 2168365 1431280 1.514983092
float32 [10 100 100 100] FALSE sum bwd 2303644 1356110 1.698714706
float32 [10 100 100 100] FALSE mean fwd 2168701 1433990 1.51235434
float32 [10 100 100 100] FALSE mean bwd 14042708 1356220 10.35429945
Float16
dtype size is_contiguous reduction direction ROCm MIOpen improvement
float16 [20 30] TRUE none fwd 90351 12266 7.365970977
float16 [20 30] TRUE none bwd 85887 19147 4.48566355
float16 [20 30] TRUE sum fwd 96192 36338 2.647146238
float16 [20 30] TRUE sum bwd 83983 20498 4.097131427
float16 [20 30] TRUE mean fwd 95183 37405 2.544659805
float16 [20 30] TRUE mean bwd 83263 19929 4.177981836
float16 [20 30] FALSE none fwd 89791 12802 7.013825965
float16 [20 30] FALSE none bwd 85296 21017 4.058428891
float16 [20 30] FALSE sum fwd 98367 33802 2.910094077
float16 [20 30] FALSE sum bwd 80767 20217 3.995004204
float16 [20 30] FALSE mean fwd 96735 34247 2.824626975
float16 [20 30] FALSE mean bwd 135951 20110 6.760367976
float16 [5 10 10] TRUE none fwd 84943 14634 5.804496378
float16 [5 10 10] TRUE none bwd 85359 18492 4.615996106
float16 [5 10 10] TRUE sum fwd 93119 37785 2.464443562
float16 [5 10 10] TRUE sum bwd 80015 20021 3.996553619
float16 [5 10 10] TRUE mean fwd 91296 36327 2.513172021
float16 [5 10 10] TRUE mean bwd 81935 20164 4.063429875
float16 [5 10 10] FALSE none fwd 87808 13567 6.472175131
float16 [5 10 10] FALSE none bwd 83792 18581 4.509552769
float16 [5 10 10] FALSE sum fwd 95183 41946 2.269179421
float16 [5 10 10] FALSE sum bwd 80559 19097 4.218411269
float16 [5 10 10] FALSE mean fwd 94063 37020 2.5408698
float16 [5 10 10] FALSE mean bwd 143743 20181 7.122689659
float16 [2 5 10 10] TRUE none fwd 96864 14242 6.801291953
float16 [2 5 10 10] TRUE none bwd 88655 19594 4.524599367
float16 [2 5 10 10] TRUE sum fwd 106239 34104 3.115147783
float16 [2 5 10 10] TRUE sum bwd 89391 20893 4.278514335
float16 [2 5 10 10] TRUE mean fwd 105759 34370 3.077073029
float16 [2 5 10 10] TRUE mean bwd 91807 17923 5.122300954
float16 [2 5 10 10] FALSE none fwd 95023 14100 6.739219858
float16 [2 5 10 10] FALSE none bwd 91327 19790 4.614805457
float16 [2 5 10 10] FALSE sum fwd 102991 35046 2.938737659
float16 [2 5 10 10] FALSE sum bwd 90591 19772 4.581782318
float16 [2 5 10 10] FALSE mean fwd 103663 34672 2.989818874
float16 [2 5 10 10] FALSE mean bwd 146798 20430 7.185413607
float16 [25 300] TRUE none fwd 90751 9370 9.685272145
float16 [25 300] TRUE none bwd 91759 11735 7.819258628
float16 [25 300] TRUE sum fwd 100719 31934 3.153973821
float16 [25 300] TRUE sum bwd 90767 12784 7.100046934
float16 [25 300] TRUE mean fwd 97984 34121 2.871662612
float16 [25 300] TRUE mean bwd 93279 12766 7.306830644
float16 [25 300] FALSE none fwd 90191 9566 9.428287686
float16 [25 300] FALSE none bwd 91775 12429 7.383940784
float16 [25 300] FALSE sum fwd 100335 34299 2.925303945
float16 [25 300] FALSE sum bwd 89935 12891 6.976572803
float16 [25 300] FALSE mean fwd 99679 32006 3.114384803
float16 [25 300] FALSE mean bwd 136015 12891 10.55115972
float16 [25 100 100] TRUE none fwd 95295 16305 5.844526219
float16 [25 100 100] TRUE none bwd 105183 23933 4.394894079
float16 [25 100 100] TRUE sum fwd 109311 45839 2.38467244
float16 [25 100 100] TRUE sum bwd 103535 24217 4.275302473
float16 [25 100 100] TRUE mean fwd 112318 38975 2.881796023
float16 [25 100 100] TRUE mean bwd 107391 24306 4.41829178
float16 [25 100 100] FALSE none fwd 93999 20323 4.625252177
float16 [25 100 100] FALSE none bwd 104559 28147 3.714747575
float16 [25 100 100] FALSE sum fwd 111007 45554 2.436822233
float16 [25 100 100] FALSE sum bwd 103519 28911 3.580609457
float16 [25 100 100] FALSE mean fwd 111871 44523 2.51265638
float16 [25 100 100] FALSE mean bwd 213710 28787 7.423837149
float16 [100 20 20 20] TRUE none fwd 149903 40967 3.659115874
float16 [100 20 20 20] TRUE none bwd 163150 65735 2.481935042
float16 [100 20 20 20] TRUE sum fwd 167295 69113 2.420601045
float16 [100 20 20 20] TRUE sum bwd 164862 65664 2.510690789
float16 [100 20 20 20] TRUE mean fwd 160399 67264 2.384618815
float16 [100 20 20 20] TRUE mean bwd 167806 65735 2.552764889
float16 [100 20 20 20] FALSE none fwd 149711 50870 2.943011598
float16 [100 20 20 20] FALSE none bwd 163486 69931 2.337818707
float16 [100 20 20 20] FALSE sum fwd 163727 79106 2.069716583
float16 [100 20 20 20] FALSE sum bwd 167295 70055 2.388052245
float16 [100 20 20 20] FALSE mean fwd 164383 79586 2.06547634
float16 [100 20 20 20] FALSE mean bwd 471164 69931 6.737555591
float16 [100 10 10 10 10] TRUE none fwd 196527 50479 3.893242735
float16 [100 10 10 10 10] TRUE none bwd 211629 80492 2.629192963
float16 [100 10 10 10 10] TRUE sum fwd 205598 82413 2.494727774
float16 [100 10 10 10 10] TRUE sum bwd 209983 80155 2.619711808
float16 [100 10 10 10 10] TRUE mean fwd 209262 79159 2.643565482
float16 [100 10 10 10 10] TRUE mean bwd 215262 80048 2.689161503
float16 [100 10 10 10 10] FALSE none fwd 192430 62445 3.081591801
float16 [100 10 10 10 10] FALSE none bwd 206718 85239 2.425157498
float16 [100 10 10 10 10] FALSE sum fwd 207790 93134 2.231086392
float16 [100 10 10 10 10] FALSE sum bwd 207903 85150 2.441608925
float16 [100 10 10 10 10] FALSE mean fwd 209119 98681 2.119141476
float16 [100 10 10 10 10] FALSE mean bwd 558955 85133 6.565667837
float16 [2000 3000] TRUE none fwd 831897 279953 2.971559512
float16 [2000 3000] TRUE none bwd 863288 459143 1.880215968
float16 [2000 3000] TRUE sum fwd 860440 346806 2.481041274
float16 [2000 3000] TRUE sum bwd 863785 453293 1.905577629
float16 [2000 3000] TRUE mean fwd 863736 346254 2.494515587
float16 [2000 3000] TRUE mean bwd 890952 453328 1.965358416
float16 [2000 3000] FALSE none fwd 831721 322728 2.577157854
float16 [2000 3000] FALSE none bwd 859304 471246 1.823472242
float16 [2000 3000] FALSE sum fwd 862328 396551 2.174570232
float16 [2000 3000] FALSE sum bwd 866328 465929 1.859356254
float16 [2000 3000] FALSE mean fwd 864360 392852 2.200217894
float16 [2000 3000] FALSE mean bwd 2779080 466480 5.95755445
float16 [25 1000 1000] TRUE none fwd 3194948 1149220 2.780101286
float16 [25 1000 1000] TRUE none bwd 3291667 1894540 1.737449196
float16 [25 1000 1000] TRUE sum fwd 3247044 1366720 2.37579314
float16 [25 1000 1000] TRUE sum bwd 3301987 1870760 1.765051102
float16 [25 1000 1000] TRUE mean fwd 3251459 1371580 2.370593768
float16 [25 1000 1000] TRUE mean bwd 3396194 1871180 1.815001229
float16 [25 1000 1000] FALSE none fwd 3189828 3500840 0.91116075
float16 [25 1000 1000] FALSE none bwd 3271347 3584900 0.912535078
float16 [25 1000 1000] FALSE sum fwd 3241283 3733260 0.868217858
float16 [25 1000 1000] FALSE sum bwd 3314339 3548660 0.93396916
float16 [25 1000 1000] FALSE mean fwd 3246595 3704060 0.876496331
float16 [25 1000 1000] FALSE mean bwd 26391238 3550170 7.433795565
float16 [10 100 100 100] TRUE none fwd 1324852 462684 2.863405694
float16 [10 100 100 100] TRUE none bwd 1353908 760004 1.781448519
float16 [10 100 100 100] TRUE sum fwd 1360420 562407 2.418924373
float16 [10 100 100 100] TRUE sum bwd 1376628 751470 1.83191345
float16 [10 100 100 100] TRUE mean fwd 1359620 561357 2.422023775
float16 [10 100 100 100] TRUE mean bwd 1423780 750829 1.896277315
float16 [10 100 100 100] FALSE none fwd 1327300 1130510 1.174071879
float16 [10 100 100 100] FALSE none bwd 1375716 1240250 1.109224753
float16 [10 100 100 100] FALSE sum fwd 1363524 1235410 1.103701605
float16 [10 100 100 100] FALSE sum bwd 1381716 1209700 1.142197239
float16 [10 100 100 100] FALSE mean fwd 1370484 1234720 1.109955294
float16 [10 100 100 100] FALSE mean bwd 9185998 1210570 7.588159297
BFloat16
dtype size is_contiguous reduction direction ROCm MIOpen improvement
bfloat16 [20 30] TRUE none fwd 96207 14631 6.575558745
bfloat16 [20 30] TRUE none bwd 97151 19395 5.009074504
bfloat16 [20 30] TRUE sum fwd 102928 35538 2.896280038
bfloat16 [20 30] TRUE sum bwd 91967 19573 4.69866653
bfloat16 [20 30] TRUE mean fwd 99871 34827 2.867631435
bfloat16 [20 30] TRUE mean bwd 93791 20000 4.68955
bfloat16 [20 30] FALSE none fwd 96815 14065 6.883398507
bfloat16 [20 30] FALSE none bwd 97199 20555 4.728727803
bfloat16 [20 30] FALSE sum fwd 101455 35082 2.891938886
bfloat16 [20 30] FALSE sum bwd 89583 20413 4.388526919
bfloat16 [20 30] FALSE mean fwd 101567 34638 2.932242046
bfloat16 [20 30] FALSE mean bwd 146815 21533 6.8181396
bfloat16 [5 10 10] TRUE none fwd 92959 14118 6.584431223
bfloat16 [5 10 10] TRUE none bwd 96559 19950 4.840050125
bfloat16 [5 10 10] TRUE sum fwd 98318 37554 2.618043351
bfloat16 [5 10 10] TRUE sum bwd 90863 22155 4.101241255
bfloat16 [5 10 10] TRUE mean fwd 97791 36487 2.680160057
bfloat16 [5 10 10] TRUE mean bwd 93711 19346 4.843947069
bfloat16 [5 10 10] FALSE none fwd 97967 13514 7.249297025
bfloat16 [5 10 10] FALSE none bwd 98847 17478 5.655509784
bfloat16 [5 10 10] FALSE sum fwd 103247 38940 2.65143811
bfloat16 [5 10 10] FALSE sum bwd 90495 20412 4.433421517
bfloat16 [5 10 10] FALSE mean fwd 102095 38194 2.673063832
bfloat16 [5 10 10] FALSE mean bwd 150063 21764 6.895010108
bfloat16 [2 5 10 10] TRUE none fwd 107023 14829 7.217142086
bfloat16 [2 5 10 10] TRUE none bwd 108798 19754 5.507644021
bfloat16 [2 5 10 10] TRUE sum fwd 114687 37109 3.090544073
bfloat16 [2 5 10 10] TRUE sum bwd 103999 20430 5.090504161
bfloat16 [2 5 10 10] TRUE mean fwd 115327 35491 3.249471697
bfloat16 [2 5 10 10] TRUE mean bwd 107359 20661 5.196215091
bfloat16 [2 5 10 10] FALSE none fwd 106335 14171 7.503704749
bfloat16 [2 5 10 10] FALSE none bwd 111647 20359 5.483913748
bfloat16 [2 5 10 10] FALSE sum fwd 112079 35615 3.14696055
bfloat16 [2 5 10 10] FALSE sum bwd 103039 20306 5.074313011
bfloat16 [2 5 10 10] FALSE mean fwd 113407 35171 3.224446277
bfloat16 [2 5 10 10] FALSE mean bwd 159327 20519 7.764852088
bfloat16 [25 300] TRUE none fwd 97855 9530 10.26810073
bfloat16 [25 300] TRUE none bwd 105823 12197 8.676149873
bfloat16 [25 300] TRUE sum fwd 106463 33019 3.224295103
bfloat16 [25 300] TRUE sum bwd 101424 12819 7.912005617
bfloat16 [25 300] TRUE mean fwd 105791 33979 3.113422997
bfloat16 [25 300] TRUE mean bwd 107375 13229 8.11663769
bfloat16 [25 300] FALSE none fwd 97375 9832 9.903885273
bfloat16 [25 300] FALSE none bwd 107343 12179 8.813777814
bfloat16 [25 300] FALSE sum fwd 113568 31703 3.582247737
bfloat16 [25 300] FALSE sum bwd 106911 13353 8.00651539
bfloat16 [25 300] FALSE mean fwd 111935 44363 2.523161193
bfloat16 [25 300] FALSE mean bwd 146111 13478 10.84070337
bfloat16 [25 100 100] TRUE none fwd 99871 16500 6.052787879
bfloat16 [25 100 100] TRUE none bwd 114143 24324 4.692608124
bfloat16 [25 100 100] TRUE sum fwd 117183 42780 2.739200561
bfloat16 [25 100 100] TRUE sum bwd 111823 24235 4.614111822
bfloat16 [25 100 100] TRUE mean fwd 117839 40647 2.899082343
bfloat16 [25 100 100] TRUE mean bwd 119551 24324 4.914939977
bfloat16 [25 100 100] FALSE none fwd 100527 20323 4.946464597
bfloat16 [25 100 100] FALSE none bwd 114286 28004 4.081059849
bfloat16 [25 100 100] FALSE sum fwd 116926 44896 2.604374555
bfloat16 [25 100 100] FALSE sum bwd 112799 28929 3.899166926
bfloat16 [25 100 100] FALSE mean fwd 122095 45856 2.662574145
bfloat16 [25 100 100] FALSE mean bwd 218494 28769 7.594772151
bfloat16 [100 20 20 20] TRUE none fwd 158623 41731 3.801083128
bfloat16 [100 20 20 20] TRUE none bwd 174159 66179 2.631635413
bfloat16 [100 20 20 20] TRUE sum fwd 177903 70447 2.525345295
bfloat16 [100 20 20 20] TRUE sum bwd 176190 66073 2.666596038
bfloat16 [100 20 20 20] TRUE mean fwd 173326 69309 2.500771906
bfloat16 [100 20 20 20] TRUE mean bwd 181471 66126 2.744321447
bfloat16 [100 20 20 20] FALSE none fwd 159102 50799 3.131990787
bfloat16 [100 20 20 20] FALSE none bwd 175567 69558 2.524037494
bfloat16 [100 20 20 20] FALSE sum fwd 175678 78714 2.231852021
bfloat16 [100 20 20 20] FALSE sum bwd 177918 70269 2.53195577
bfloat16 [100 20 20 20] FALSE mean fwd 180158 79141 2.276418039
bfloat16 [100 20 20 20] FALSE mean bwd 481308 70198 6.856434656
bfloat16 [100 10 10 10 10] TRUE none fwd 207758 51154 4.061422372
bfloat16 [100 10 10 10 10] TRUE none bwd 226318 81079 2.791326977
bfloat16 [100 10 10 10 10] TRUE sum fwd 222222 78536 2.829555872
bfloat16 [100 10 10 10 10] TRUE sum bwd 225198 80421 2.800238744
bfloat16 [100 10 10 10 10] TRUE mean fwd 223310 79212 2.819143564
bfloat16 [100 10 10 10 10] TRUE mean bwd 235326 80474 2.924248826
bfloat16 [100 10 10 10 10] FALSE none fwd 209134 62000 3.373129032
bfloat16 [100 10 10 10 10] FALSE none bwd 226510 84795 2.671265994
bfloat16 [100 10 10 10 10] FALSE sum fwd 224717 91516 2.455494121
bfloat16 [100 10 10 10 10] FALSE sum bwd 228014 85399 2.669984426
bfloat16 [100 10 10 10 10] FALSE mean fwd 232046 91729 2.529690719
bfloat16 [100 10 10 10 10] FALSE mean bwd 577627 85328 6.769489499
bfloat16 [2000 3000] TRUE none fwd 903896 282298 3.201921374
bfloat16 [2000 3000] TRUE none bwd 973383 461968 2.107035552
bfloat16 [2000 3000] TRUE sum fwd 939544 351623 2.672020886
bfloat16 [2000 3000] TRUE sum bwd 957976 456100 2.100363955
bfloat16 [2000 3000] TRUE mean fwd 940568 350253 2.685395985
bfloat16 [2000 3000] TRUE mean bwd 996151 456384 2.1827036
bfloat16 [2000 3000] FALSE none fwd 920776 322158 2.858150349
bfloat16 [2000 3000] FALSE none bwd 961591 472595 2.034704134
bfloat16 [2000 3000] FALSE sum fwd 940984 392744 2.395922026
bfloat16 [2000 3000] FALSE sum bwd 957048 467652 2.046496112
bfloat16 [2000 3000] FALSE mean fwd 937320 393651 2.381093913
bfloat16 [2000 3000] FALSE mean bwd 2838903 468060 6.065254455
bfloat16 [25 1000 1000] TRUE none fwd 3488673 1160440 3.00633639
bfloat16 [25 1000 1000] TRUE none bwd 3706511 1906140 1.944511421
bfloat16 [25 1000 1000] TRUE sum fwd 3560448 1380640 2.578838799
bfloat16 [25 1000 1000] TRUE sum bwd 3698463 1883710 1.963392985
bfloat16 [25 1000 1000] TRUE mean fwd 3566721 1381190 2.582353623
bfloat16 [25 1000 1000] TRUE mean bwd 3855262 1883640 2.046708501
bfloat16 [25 1000 1000] FALSE none fwd 3495441 3502480 0.997990281
bfloat16 [25 1000 1000] FALSE none bwd 3705151 3596210 1.030293281
bfloat16 [25 1000 1000] FALSE sum fwd 3544400 3716430 0.953710954
bfloat16 [25 1000 1000] FALSE sum bwd 3719167 3544240 1.049355292
bfloat16 [25 1000 1000] FALSE mean fwd 3573984 3727210 0.958889894
bfloat16 [25 1000 1000] FALSE mean bwd 26678548 3542150 7.531738633
bfloat16 [10 100 100 100] TRUE none fwd 1467459 467714 3.137513523
bfloat16 [10 100 100 100] TRUE none bwd 1542386 764625 2.017179663
bfloat16 [10 100 100 100] TRUE sum fwd 1480723 567615 2.608674894
bfloat16 [10 100 100 100] TRUE sum bwd 1528754 756446 2.020969111
bfloat16 [10 100 100 100] TRUE mean fwd 1493827 568237 2.628880203
bfloat16 [10 100 100 100] TRUE mean bwd 1591858 756339 2.104688506
bfloat16 [10 100 100 100] FALSE none fwd 1450499 1137090 1.275623741
bfloat16 [10 100 100 100] FALSE none bwd 1534466 1223550 1.254109763
bfloat16 [10 100 100 100] FALSE sum fwd 1482163 1237170 1.198026949
bfloat16 [10 100 100 100] FALSE sum bwd 1526034 1209820 1.261372766
bfloat16 [10 100 100 100] FALSE mean fwd 1484323 1235710 1.201190409
bfloat16 [10 100 100 100] FALSE mean bwd 9208158 1210750 7.605333884

@BuiChiTrung BuiChiTrung self-assigned this Jul 24, 2024
@BuiChiTrung BuiChiTrung marked this pull request as ready for review July 25, 2024 08:49
@BuiChiTrung BuiChiTrung requested review from JehandadKhan and junliume and removed request for JehandadKhan and junliume July 25, 2024 08:51
CAHEK7 (Contributor) commented Jul 25, 2024

It seems that part of this PR is also included in #3146.
May I ask you to collaborate with @littlecutebird and fix all the common parts of both PRs?

BuiChiTrung (Collaborator, Author) commented Jul 31, 2024

I have updated my code following the comments in #3146. Please take another look at my PR.

BuiChiTrung (Collaborator, Author) commented

@junliume, can you please take a look at the Windows build? I added MIOPEN_INTERNALS_EXPORT, but it still fails.

BuiChiTrung (Collaborator, Author) commented

@iq136boy, can you help me check the log from the Windows build stage, please?

apwojcik (Collaborator) commented Aug 6, 2024

> @iq136boy, can you help me check the log from the Windows build stage, please?

[ 95%] Linking CXX executable ..\..\bin\test_sigmoid_focal_loss.exe
lld-link: error: undefined symbol: __declspec(dllimport) enum miopenStatus_t __cdecl miopen::SigmoidFocalLossForward(struct miopen::Handle &, void *, unsigned __int64, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void *, float, float, enum miopenLossReductionMode_t)
lld-link: error: undefined symbol: __declspec(dllimport) enum miopenStatus_t __cdecl miopen::SigmoidFocalLossBackward(struct miopen::Handle &, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void *, struct miopen::TensorDescriptor const &, void *, float, float, enum miopenLossReductionMode_t)
lld-link: error: undefined symbol: __declspec(dllimport) unsigned __int64 __cdecl miopen::GetSigmoidFocalLossForwardWorkspaceSize(struct miopen::Handle &, struct miopen::TensorDescriptor const &, struct miopen::TensorDescriptor const &, struct miopen::TensorDescriptor const &, enum miopenLossReductionMode_t)
>>> referenced by C:\home\jenkins\agent\workspace\UIF2_MIOpen_PR-3143\MIOpen\test\gtest\sigmoid_focal_loss.hpp:299
>>>               CMakeFiles/test_sigmoid_focal_loss.dir/sigmoid_focal_loss.cpp.obj:(protected: virtual void __cdecl SigmoidFocalLossFwdTest<float>::SetUp(void))
>>> referenced by C:\home\jenkins\agent\workspace\UIF2_MIOpen_PR-3143\MIOpen\test\gtest\sigmoid_focal_loss.hpp:299
>>>               CMakeFiles/test_sigmoid_focal_loss.dir/sigmoid_focal_loss.cpp.obj:(protected: virtual void __cdecl SigmoidFocalLossFwdTest<class half_float::half>::SetUp(void))
>>> referenced by C:\home\jenkins\agent\workspace\UIF2_MIOpen_PR-3143\MIOpen\test\gtest\sigmoid_focal_loss.hpp:299
>>>               CMakeFiles/test_sigmoid_focal_loss.dir/sigmoid_focal_loss.cpp.obj:(protected: virtual void __cdecl SigmoidFocalLossFwdTest<class bfloat16>::SetUp(void))

CAHEK7 (Contributor) left a comment

May I ask you to merge the reduced and unreduced kernels and the CPU code, as I suggested in #3166?

loss = alphaT * loss;
}

workspaceHost[id] = static_cast<Tcheck>(loss / divisor);

CAHEK7 (Contributor) commented

You don't need an external workspace for host functions - you are free to allocate memory right in this function (which is usually not possible or even does not make much sense on GPU).

But actually you don't need a workspace for CPU reduction, it's not a parallel algorithm without external synchronization.

Suggested change
- workspaceHost[id] = static_cast<Tcheck>(loss / divisor);
+ outputHost[0] += static_cast<Tcheck>(loss / divisor);

Just don't forget to initialize outputHost[0] with 0.0f

BuiChiTrung (Collaborator, Author) commented

I declare workspaceHost in the driver because this allows us to compare the CPU workspace with the GPU workspace before executing the reduction.
Additionally, outputHost[0] += static_cast<Tcheck>(loss / divisor) causes verification to fail because we use a reduction algorithm on the GPU. To avoid the verification failure, I have to mimic the same reduction algorithm in the CPU check function.

CAHEK7 (Contributor) commented Aug 9, 2024

In short: CPU verification code must be as generic as possible and must not mimic any particular GPU implementation.

@BuiChiTrung

> I declare workspaceHost in the driver because this allows us to compare the CPU workspace with the GPU workspace before executing the reduction.

Workspace content is not specified and therefore cannot be compared. There can be multiple implementations of the same algorithm, so this check is not possible in the general case; and since this CPU implementation serves that general case, it must not contain algorithm-specific features.

> Additionally, outputHost[0] += static_cast<Tcheck>(loss / divisor) causes verification to fail because we use a reduction algorithm on the GPU. To avoid the verification failure, I have to mimic the same reduction algorithm in the CPU check function.

Mathematically the result must be the same. Because of FP rounding and other floating-point effects, the results may differ by a larger error, but verification must still work. If the error is too large, then the algorithm is not applicable and cannot be used.

BuiChiTrung (Collaborator, Author) commented

I have updated the CPU verification method to use naive accumulation.
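The naive accumulation discussed in this thread can be sketched roughly as below. This is only an illustration under stated assumptions, not the code from the PR: the function name `reduce_loss_naive` and the `elem_loss` vector are hypothetical, and `divisor` is assumed to be 1 for `sum` and the element count for `mean`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference (not the PR's code): reduce per-element
// losses by plain sequential accumulation, with no workspace and no attempt
// to reproduce the GPU reduction order.
template <typename Tcheck>
Tcheck reduce_loss_naive(const std::vector<float>& elem_loss, float divisor)
{
    assert(divisor != 0.0f);

    // The accumulator must start at zero, as noted in the review.
    Tcheck out = static_cast<Tcheck>(0.0f);

    for(std::size_t id = 0; id < elem_loss.size(); ++id)
    {
        // Same per-element contribution as in the suggested change.
        out += static_cast<Tcheck>(elem_loss[id] / divisor);
    }
    return out;
}
```

Because the summation order differs from a parallel GPU reduction, the two results should be compared with a floating-point tolerance rather than for exact equality.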

long10024070 (Collaborator) commented

@iq136boy For this PR only, could you and your colleagues comment on the documentation problems? It is great to learn by example. Please point out the parts that you believe are important to have or to mention in the documentation.

iq136boy (Contributor) left a comment

This PR is blocked by the following CI issue. CI needs to be restarted for the PR after the issue is fixed.

[2024-09-23T14:07:43.339Z] Exception occurred: org.kohsuke.github.HttpException: {"message":"API rate limit exceeded for user ID 49319081. If you reach out to GitHub Support for help, please include the request ID
