Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SoftMarginLoss #3226

Open
wants to merge 28 commits into
base: develop
Choose a base branch
from
Open

Conversation

littlecutebird
Copy link
Collaborator

@littlecutebird littlecutebird commented Aug 30, 2024

  • Added SoftMarginLoss operation for both forward and backward. Compared to ROCm, it is better for all cases.
  • New API is guarded by MIOPEN_BETA_API macro. Added 2 kernels: SoftMarginLossForward5d, SoftMarginLossBackward5d
  • Added driver test and gtest for SoftMarginLoss.
  • Compared to ROCm:

Unreduced:

type Forward Backward
float32 2.50 3.30
float16 2.46 3.12
bfloat16 2.51 3.30
fp32 forward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 312830 213670 1.46
[32 80 870] [69600 1 80] FALSE 126889 71484 1.78
[32 80 870] [69600 870 1] TRUE 89579 57439 1.56
[4 182403 91] [16598673 91 1] TRUE 2231121 1555260 1.43
[1534680] [1] TRUE 64798 41155 1.57
[16 1 512 512] [262144 262144 512 1] TRUE 156073 103217 1.51
[2 3 160 160] [6528000 2176000 13600 85] FALSE 33934 19786 1.72
[2 3 80 80] [1632000 544000 6800 85] FALSE 25663 9564 2.68
[32756 80] [85 1] FALSE 109723 66630 1.65
[64 3 80 80] [1632000 544000 6800 85] FALSE 153257 119786 1.28
[64 3 40 40] [408000 136000 3400 85] FALSE 49854 32924 1.51
[22311 80] [85 1] FALSE 79085 47395 1.67
[64 3 20 20] [102000 34000 1700 85] FALSE 28703 12035 2.38
[8 4] [4 1] TRUE 17823 10631 1.68
[56 4] [4 1] TRUE 15104 11306 1.34
[131 4] [4 1] TRUE 20512 9511 2.16
[10000] [1] TRUE 21167 7306 2.90
[200 50] [50 1] TRUE 23648 7288 3.24
[20 50 10] [500 10 1] TRUE 20128 7235 2.78
[4 25 4 25] [2500 100 25 1] TRUE 23567 7324 3.22
[12 3 4 5 6] [360 120 30 6 1] TRUE 21584 8142 2.65
[10000] [3] FALSE 18880 7235 2.61
[200 50] [1 200] FALSE 28879 7253 3.98
[200 50] [505 1] FALSE 29087 7324 3.97
[20 50 10] [1 20 1000] FALSE 28447 7608 3.74
[20 50 10] [7575 15 1] FALSE 27359 7271 3.76
[4 25 4 25] [1 16 4 400] FALSE 27007 7324 3.69
[4 25 4 25] [5859 217 31 1] FALSE 29456 7093 4.15
[12 3 4 5 6] [360 120 6 24 1] FALSE 24496 8017 3.06
[12 3 4 5 6] [5760 960 120 12 1] FALSE 31439 8391 3.75
fp32 backward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 553089 232674 2.38
[32 80 870] [69600 1 80] FALSE 299391 73208 4.09
[32 80 870] [69600 870 1] TRUE 161303 61492 2.62
[4 182403 91] [16598673 91 1] TRUE 4006075 1703260 2.35
[1534680] [1] TRUE 116011 43999 2.64
[16 1 512 512] [262144 262144 512 1] TRUE 273716 111466 2.46
[2 3 160 160] [6528000 2176000 13600 85] FALSE 46222 18808 2.46
[2 3 80 80] [1632000 544000 6800 85] FALSE 31038 9493 3.27
[32756 80] [85 1] FALSE 190648 71661 2.66
[64 3 80 80] [1632000 544000 6800 85] FALSE 196184 123004 1.59
[64 3 40 40] [408000 136000 3400 85] FALSE 64541 33279 1.94
[22311 80] [85 1] FALSE 139178 50595 2.75
[64 3 20 20] [102000 34000 1700 85] FALSE 36431 12337 2.95
[8 4] [4 1] TRUE 27519 12622 2.18
[56 4] [4 1] TRUE 23840 11360 2.10
[131 4] [4 1] TRUE 27472 10435 2.63
[10000] [1] TRUE 27423 7786 3.52
[200 50] [50 1] TRUE 28080 8302 3.38
[20 50 10] [500 10 1] TRUE 28000 7999 3.50
[4 25 4 25] [2500 100 25 1] TRUE 27343 8337 3.28
[12 3 4 5 6] [360 120 30 6 1] TRUE 27247 8888 3.07
[10000] [3] FALSE 28352 7999 3.54
[200 50] [1 200] FALSE 42255 8373 5.05
[200 50] [505 1] FALSE 35823 7751 4.62
[20 50 10] [1 20 1000] FALSE 44127 8213 5.37
[20 50 10] [7575 15 1] FALSE 32879 8160 4.03
[4 25 4 25] [1 16 4 400] FALSE 42688 8284 5.15
[4 25 4 25] [5859 217 31 1] FALSE 33519 8124 4.13
[12 3 4 5 6] [360 120 6 24 1] FALSE 46415 8871 5.23
[12 3 4 5 6] [5760 960 120 12 1] FALSE 35791 8640 4.14
fp16 forward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 245058 214701 1.14
[32 80 870] [69600 1 80] FALSE 111946 67626 1.66
[32 80 870] [69600 870 1] TRUE 71116 57030 1.25
[4 182403 91] [16598673 91 1] TRUE 1704623 1565680 1.09
[1534680] [1] TRUE 51181 40746 1.26
[16 1 512 512] [262144 262144 512 1] TRUE 122427 103182 1.19
[2 3 160 160] [6528000 2176000 13600 85] FALSE 29743 17724 1.68
[2 3 80 80] [1632000 544000 6800 85] FALSE 23279 9422 2.47
[32756 80] [85 1] FALSE 87740 66346 1.32
[64 3 80 80] [1632000 544000 6800 85] FALSE 120027 92443 1.30
[64 3 40 40] [408000 136000 3400 85] FALSE 41742 28586 1.46
[22311 80] [85 1] FALSE 63229 47021 1.34
[64 3 20 20] [102000 34000 1700 85] FALSE 22895 11395 2.01
[8 4] [4 1] TRUE 18463 9226 2.00
[56 4] [4 1] TRUE 16768 10080 1.66
[131 4] [4 1] TRUE 22543 9173 2.46
[10000] [1] TRUE 23711 7182 3.30
[200 50] [50 1] TRUE 24336 7182 3.39
[20 50 10] [500 10 1] TRUE 23520 7360 3.20
[4 25 4 25] [2500 100 25 1] TRUE 25071 7199 3.48
[12 3 4 5 6] [360 120 30 6 1] TRUE 24208 8071 3.00
[10000] [3] FALSE 20176 7235 2.79
[200 50] [1 200] FALSE 26015 7395 3.52
[200 50] [505 1] FALSE 26991 7235 3.73
[20 50 10] [1 20 1000] FALSE 26127 7235 3.61
[20 50 10] [7575 15 1] FALSE 26064 7057 3.69
[4 25 4 25] [1 16 4 400] FALSE 27296 7288 3.75
[4 25 4 25] [5859 217 31 1] FALSE 26624 7146 3.73
[12 3 4 5 6] [360 120 6 24 1] FALSE 26720 7840 3.41
[12 3 4 5 6] [5760 960 120 12 1] FALSE 32159 8320 3.87
fp16 backward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 337901 233368 1.45
[32 80 870] [69600 1 80] FALSE 249874 69777 3.58
[32 80 870] [69600 870 1] TRUE 103578 61635 1.68
[4 182403 91] [16598673 91 1] TRUE 2320300 1712220 1.36
[1534680] [1] TRUE 71181 43484 1.64
[16 1 512 512] [262144 262144 512 1] TRUE 175272 111946 1.57
[2 3 160 160] [6528000 2176000 13600 85] FALSE 42190 16746 2.52
[2 3 80 80] [1632000 544000 6800 85] FALSE 31822 9600 3.31
[32756 80] [85 1] FALSE 125770 71519 1.76
[64 3 80 80] [1632000 544000 6800 85] FALSE 139818 94755 1.48
[64 3 40 40] [408000 136000 3400 85] FALSE 55549 28142 1.97
[22311 80] [85 1] FALSE 85789 50524 1.70
[64 3 20 20] [102000 34000 1700 85] FALSE 35278 11413 3.09
[8 4] [4 1] TRUE 28079 12924 2.17
[56 4] [4 1] TRUE 23776 12195 1.95
[131 4] [4 1] TRUE 28975 10453 2.77
[10000] [1] TRUE 30447 8088 3.76
[200 50] [50 1] TRUE 29456 8231 3.58
[20 50 10] [500 10 1] TRUE 29728 8177 3.64
[4 25 4 25] [2500 100 25 1] TRUE 29664 8106 3.66
[12 3 4 5 6] [360 120 30 6 1] TRUE 30063 8871 3.39
[10000] [3] FALSE 29616 8160 3.63
[200 50] [1 200] FALSE 42367 8284 5.11
[200 50] [505 1] FALSE 35807 8071 4.44
[20 50 10] [1 20 1000] FALSE 43918 8213 5.35
[20 50 10] [7575 15 1] FALSE 34831 7964 4.37
[4 25 4 25] [1 16 4 400] FALSE 43135 8302 5.20
[4 25 4 25] [5859 217 31 1] FALSE 33759 8017 4.21
[12 3 4 5 6] [360 120 6 24 1] FALSE 45439 8515 5.34
[12 3 4 5 6] [5760 960 120 12 1] FALSE 36959 9155 4.04
bfp16 forward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 270257 219466 1.23
[32 80 870] [69600 1 80] FALSE 116921 67590 1.73
[32 80 870] [69600 870 1] TRUE 76108 59324 1.28
[4 182403 91] [16598673 91 1] TRUE 1896068 1595560 1.19
[1534680] [1] TRUE 54781 41937 1.31
[16 1 512 512] [262144 262144 512 1] TRUE 132810 105831 1.25
[2 3 160 160] [6528000 2176000 13600 85] FALSE 30495 17475 1.75
[2 3 80 80] [1632000 544000 6800 85] FALSE 23135 9333 2.48
[32756 80] [85 1] FALSE 94156 68177 1.38
[64 3 80 80] [1632000 544000 6800 85] FALSE 122075 91715 1.33
[64 3 40 40] [408000 136000 3400 85] FALSE 41950 28320 1.48
[22311 80] [85 1] FALSE 67629 48124 1.41
[64 3 20 20] [102000 34000 1700 85] FALSE 24447 11431 2.14
[8 4] [4 1] TRUE 20224 11200 1.81
[56 4] [4 1] TRUE 16368 10542 1.55
[131 4] [4 1] TRUE 24799 10115 2.45
[10000] [1] TRUE 25344 7555 3.35
[200 50] [50 1] TRUE 26704 7199 3.71
[20 50 10] [500 10 1] TRUE 23472 7288 3.22
[4 25 4 25] [2500 100 25 1] TRUE 24736 7768 3.18
[12 3 4 5 6] [360 120 30 6 1] TRUE 23519 8533 2.76
[10000] [3] FALSE 21263 7235 2.94
[200 50] [1 200] FALSE 28783 7448 3.86
[200 50] [505 1] FALSE 28639 7342 3.90
[20 50 10] [1 20 1000] FALSE 28143 7235 3.89
[20 50 10] [7575 15 1] FALSE 27856 7342 3.79
[4 25 4 25] [1 16 4 400] FALSE 27952 7431 3.76
[4 25 4 25] [5859 217 31 1] FALSE 27439 7484 3.67
[12 3 4 5 6] [360 120 6 24 1] FALSE 26911 8213 3.28
[12 3 4 5 6] [5760 960 120 12 1] FALSE 33871 8160 4.15
bfp16 backward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 387226 240141 1.61
[32 80 870] [69600 1 80] FALSE 256274 71395 3.59
[32 80 870] [69600 870 1] TRUE 112138 64017 1.75
[4 182403 91] [16598673 91 1] TRUE 2719733 1754320 1.55
[1534680] [1] TRUE 79645 45191 1.76
[16 1 512 512] [262144 262144 512 1] TRUE 197064 115057 1.71
[2 3 160 160] [6528000 2176000 13600 85] FALSE 46255 18737 2.47
[2 3 80 80] [1632000 544000 6800 85] FALSE 35343 10453 3.38
[32756 80] [85 1] FALSE 139306 73848 1.89
[64 3 80 80] [1632000 544000 6800 85] FALSE 148474 96479 1.54
[64 3 40 40] [408000 136000 3400 85] FALSE 59357 29706 2.00
[22311 80] [85 1] FALSE 98076 51875 1.89
[64 3 20 20] [102000 34000 1700 85] FALSE 39998 13386 2.99
[8 4] [4 1] TRUE 30128 11751 2.56
[56 4] [4 1] TRUE 25055 12373 2.02
[131 4] [4 1] TRUE 32559 10506 3.10
[10000] [1] TRUE 34687 8568 4.05
[200 50] [50 1] TRUE 34544 8408 4.11
[20 50 10] [500 10 1] TRUE 33648 8639 3.89
[4 25 4 25] [2500 100 25 1] TRUE 34335 8764 3.92
[12 3 4 5 6] [360 120 30 6 1] TRUE 33983 9173 3.70
[10000] [3] FALSE 31600 8533 3.70
[200 50] [1 200] FALSE 48271 8302 5.81
[200 50] [505 1] FALSE 40383 8533 4.73
[20 50 10] [1 20 1000] FALSE 47374 8604 5.51
[20 50 10] [7575 15 1] FALSE 38431 8444 4.55
[4 25 4 25] [1 16 4 400] FALSE 46383 8515 5.45
[4 25 4 25] [5859 217 31 1] FALSE 36944 8444 4.38
[12 3 4 5 6] [360 120 6 24 1] FALSE 49263 9475 5.20
[12 3 4 5 6] [5760 960 120 12 1] FALSE 40592 9475 4.28

Reduced:

type Forward Backward
float32 3.26 2.88
float16 3.04 2.66
bfloat16 3.10 2.84
fp32 forward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 369111 301989 1.22
[32 80 870] [69600 1 80] FALSE 161182 106168 1.52
[32 80 870] [69600 870 1] TRUE 122654 91341 1.34
[4 182403 91] [16598673 91 1] TRUE 2476240 2107970 1.17
[1534680] [1] TRUE 170494 69972 2.44
[16 1 512 512] [262144 262144 512 1] TRUE 200662 152301 1.32
[2 3 160 160] [6528000 2176000 13600 85] FALSE 105519 37333 2.83
[2 3 80 80] [1632000 544000 6800 85] FALSE 103921 24408 4.26
[32756 80] [85 1] FALSE 148782 105724 1.41
[64 3 80 80] [1632000 544000 6800 85] FALSE 178350 145973 1.22
[64 3 40 40] [408000 136000 3400 85] FALSE 89535 51928 1.72
[22311 80] [85 1] FALSE 109711 78826 1.39
[64 3 20 20] [102000 34000 1700 85] FALSE 71823 30346 2.37
[8 4] [4 1] TRUE 99951 20497 4.88
[56 4] [4 1] TRUE 95135 20444 4.65
[131 4] [4 1] TRUE 101231 26737 3.79
[10000] [1] TRUE 120670 24017 5.02
[200 50] [50 1] TRUE 107343 23093 4.65
[20 50 10] [500 10 1] TRUE 107951 23466 4.60
[4 25 4 25] [2500 100 25 1] TRUE 116206 22630 5.14
[12 3 4 5 6] [360 120 30 6 1] TRUE 99311 24639 4.03
[10000] [3] FALSE 114110 22915 4.98
[200 50] [1 200] FALSE 101343 22986 4.41
[200 50] [505 1] FALSE 107390 23128 4.64
[20 50 10] [1 20 1000] FALSE 107598 26453 4.07
[20 50 10] [7575 15 1] FALSE 95119 22435 4.24
[4 25 4 25] [1 16 4 400] FALSE 110078 25937 4.24
[4 25 4 25] [5859 217 31 1] FALSE 85247 24053 3.54
[12 3 4 5 6] [360 120 6 24 1] FALSE 85183 32764 2.60
[12 3 4 5 6] [5760 960 120 12 1] FALSE 99375 24888 3.99
fp32 backward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 528898 248549 2.13
[32 80 870] [69600 1 80] FALSE 252066 75057 3.36
[32 80 870] [69600 870 1] TRUE 153719 65421 2.35
[4 182403 91] [16598673 91 1] TRUE 3838133 1819290 2.11
[1534680] [1] TRUE 111195 46666 2.38
[16 1 512 512] [262144 262144 512 1] TRUE 262661 118595 2.21
[2 3 160 160] [6528000 2176000 13600 85] FALSE 45726 18862 2.42
[2 3 80 80] [1632000 544000 6800 85] FALSE 30014 9599 3.13
[32756 80] [85 1] FALSE 184584 76497 2.41
[64 3 80 80] [1632000 544000 6800 85] FALSE 192808 123839 1.56
[64 3 40 40] [408000 136000 3400 85] FALSE 63757 33439 1.91
[22311 80] [85 1] FALSE 132763 53919 2.46
[64 3 20 20] [102000 34000 1700 85] FALSE 34639 12373 2.80
[8 4] [4 1] TRUE 27759 12533 2.21
[56 4] [4 1] TRUE 24015 12871 1.87
[131 4] [4 1] TRUE 24752 10168 2.43
[10000] [1] TRUE 24687 8515 2.90
[200 50] [50 1] TRUE 26800 8586 3.12
[20 50 10] [500 10 1] TRUE 24624 8586 2.87
[4 25 4 25] [2500 100 25 1] TRUE 24831 8231 3.02
[12 3 4 5 6] [360 120 30 6 1] TRUE 26736 9315 2.87
[10000] [3] FALSE 28815 8622 3.34
[200 50] [1 200] FALSE 36079 8782 4.11
[200 50] [505 1] FALSE 32527 8924 3.64
[20 50 10] [1 20 1000] FALSE 37264 8320 4.48
[20 50 10] [7575 15 1] FALSE 30239 8284 3.65
[4 25 4 25] [1 16 4 400] FALSE 36687 8675 4.23
[4 25 4 25] [5859 217 31 1] FALSE 28655 8391 3.41
[12 3 4 5 6] [360 120 6 24 1] FALSE 37087 9653 3.84
[12 3 4 5 6] [5760 960 120 12 1] FALSE 32399 10079 3.21
fp16 forward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 291221 302291 0.96
[32 80 870] [69600 1 80] FALSE 141342 103821 1.36
[32 80 870] [69600 870 1] TRUE 133998 91519 1.46
[4 182403 91] [16598673 91 1] TRUE 1853128 2113120 0.88
[1534680] [1] TRUE 119247 69279 1.72
[16 1 512 512] [262144 262144 512 1] TRUE 158849 152177 1.04
[2 3 160 160] [6528000 2176000 13600 85] FALSE 79967 35786 2.23
[2 3 80 80] [1632000 544000 6800 85] FALSE 93090 26382 3.53
[32756 80] [85 1] FALSE 119166 103875 1.15
[64 3 80 80] [1632000 544000 6800 85] FALSE 146478 126950 1.15
[64 3 40 40] [408000 136000 3400 85] FALSE 105903 47875 2.21
[22311 80] [85 1] FALSE 90287 77706 1.16
[64 3 20 20] [102000 34000 1700 85] FALSE 64847 29884 2.17
[8 4] [4 1] TRUE 108239 20462 5.29
[56 4] [4 1] TRUE 96959 20888 4.64
[131 4] [4 1] TRUE 101678 26186 3.88
[10000] [1] TRUE 108911 25920 4.20
[200 50] [50 1] TRUE 122207 26417 4.63
[20 50 10] [500 10 1] TRUE 98911 23004 4.30
[4 25 4 25] [2500 100 25 1] TRUE 111774 30986 3.61
[12 3 4 5 6] [360 120 30 6 1] TRUE 95423 25564 3.73
[10000] [3] FALSE 111503 23128 4.82
[200 50] [1 200] FALSE 110559 23342 4.74
[200 50] [505 1] FALSE 98271 28551 3.44
[20 50 10] [1 20 1000] FALSE 104975 22933 4.58
[20 50 10] [7575 15 1] FALSE 100159 23893 4.19
[4 25 4 25] [1 16 4 400] FALSE 106479 26168 4.07
[4 25 4 25] [5859 217 31 1] FALSE 84959 28106 3.02
[12 3 4 5 6] [360 120 6 24 1] FALSE 99134 24568 4.04
[12 3 4 5 6] [5760 960 120 12 1] FALSE 92783 32141 2.89
fp16 backward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 341916 250932 1.36
[32 80 870] [69600 1 80] FALSE 201044 72781 2.76
[32 80 870] [69600 870 1] TRUE 102250 65830 1.55
[4 182403 91] [16598673 91 1] TRUE 2344090 1837820 1.28
[1534680] [1] TRUE 72333 46844 1.54
[16 1 512 512] [262144 262144 512 1] TRUE 174856 120053 1.46
[2 3 160 160] [6528000 2176000 13600 85] FALSE 41646 18595 2.24
[2 3 80 80] [1632000 544000 6800 85] FALSE 30831 10666 2.89
[32756 80] [85 1] FALSE 125979 76852 1.64
[64 3 80 80] [1632000 544000 6800 85] FALSE 140906 96675 1.46
[64 3 40 40] [408000 136000 3400 85] FALSE 54558 29848 1.83
[22311 80] [85 1] FALSE 87149 53990 1.61
[64 3 20 20] [102000 34000 1700 85] FALSE 34863 13546 2.57
[8 4] [4 1] TRUE 25327 12391 2.04
[56 4] [4 1] TRUE 24303 13208 1.84
[131 4] [4 1] TRUE 27407 10560 2.60
[10000] [1] TRUE 28512 8444 3.38
[200 50] [50 1] TRUE 27120 8675 3.13
[20 50 10] [500 10 1] TRUE 27184 8533 3.19
[4 25 4 25] [2500 100 25 1] TRUE 27823 9511 2.93
[12 3 4 5 6] [360 120 30 6 1] TRUE 28655 9173 3.12
[10000] [3] FALSE 28463 8640 3.29
[200 50] [1 200] FALSE 36799 9262 3.97
[200 50] [505 1] FALSE 31359 8408 3.73
[20 50 10] [1 20 1000] FALSE 36351 8817 4.12
[20 50 10] [7575 15 1] FALSE 31007 8657 3.58
[4 25 4 25] [1 16 4 400] FALSE 36143 9102 3.97
[4 25 4 25] [5859 217 31 1] FALSE 29152 8942 3.26
[12 3 4 5 6] [360 120 6 24 1] FALSE 37519 9404 3.99
[12 3 4 5 6] [5760 960 120 12 1] FALSE 33167 9724 3.41
bfp16 forward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 315541 302344 1.04
[32 80 870] [69600 1 80] FALSE 147134 104071 1.41
[32 80 870] [69600 870 1] TRUE 111966 93119 1.20
[4 182403 91] [16598673 91 1] TRUE 2060933 2116840 0.97
[1534680] [1] TRUE 104399 69652 1.50
[16 1 512 512] [262144 262144 512 1] TRUE 170487 152390 1.12
[2 3 160 160] [6528000 2176000 13600 85] FALSE 76767 35786 2.15
[2 3 80 80] [1632000 544000 6800 85] FALSE 97528 24871 3.92
[32756 80] [85 1] FALSE 128750 103768 1.24
[64 3 80 80] [1632000 544000 6800 85] FALSE 148734 126470 1.18
[64 3 40 40] [408000 136000 3400 85] FALSE 84079 52924 1.59
[22311 80] [85 1] FALSE 97359 77706 1.25
[64 3 20 20] [102000 34000 1700 85] FALSE 70831 34382 2.06
[8 4] [4 1] TRUE 102815 19822 5.19
[56 4] [4 1] TRUE 104142 19786 5.26
[131 4] [4 1] TRUE 121294 26364 4.60
[10000] [1] TRUE 101311 23679 4.28
[200 50] [50 1] TRUE 123118 33208 3.71
[20 50 10] [500 10 1] TRUE 99215 24159 4.11
[4 25 4 25] [2500 100 25 1] TRUE 110846 25173 4.40
[12 3 4 5 6] [360 120 30 6 1] TRUE 111215 35839 3.10
[10000] [3] FALSE 106463 25937 4.10
[200 50] [1 200] FALSE 111246 24017 4.63
[200 50] [505 1] FALSE 101679 23342 4.36
[20 50 10] [1 20 1000] FALSE 98543 23182 4.25
[20 50 10] [7575 15 1] FALSE 112414 23200 4.85
[4 25 4 25] [1 16 4 400] FALSE 81775 25920 3.15
[4 25 4 25] [5859 217 31 1] FALSE 83263 23857 3.49
[12 3 4 5 6] [360 120 6 24 1] FALSE 120350 24213 4.97
[12 3 4 5 6] [5760 960 120 12 1] FALSE 97231 25528 3.81
bfp16 backward
input_size stride_size cont ROCm MIOpen Improvement
[256 4 8732] [34928 8732 1] TRUE 386314 255394 1.51
[32 80 870] [69600 1 80] FALSE 209828 73315 2.86
[32 80 870] [69600 870 1] TRUE 113049 67448 1.68
[4 182403 91] [16598673 91 1] TRUE 2712646 1870170 1.45
[1534680] [1] TRUE 80653 47750 1.69
[16 1 512 512] [262144 262144 512 1] TRUE 196600 122062 1.61
[2 3 160 160] [6528000 2176000 13600 85] FALSE 44174 18808 2.35
[2 3 80 80] [1632000 544000 6800 85] FALSE 31758 10791 2.94
[32756 80] [85 1] FALSE 137530 78310 1.76
[64 3 80 80] [1632000 544000 6800 85] FALSE 150010 96212 1.56
[64 3 40 40] [408000 136000 3400 85] FALSE 58142 29777 1.95
[22311 80] [85 1] FALSE 99308 54790 1.81
[64 3 20 20] [102000 34000 1700 85] FALSE 37006 13671 2.71
[8 4] [4 1] TRUE 28847 12355 2.33
[56 4] [4 1] TRUE 25599 12853 1.99
[131 4] [4 1] TRUE 31311 11715 2.67
[10000] [1] TRUE 29471 8604 3.43
[200 50] [50 1] TRUE 29776 8408 3.54
[20 50 10] [500 10 1] TRUE 29488 8586 3.43
[4 25 4 25] [2500 100 25 1] TRUE 29567 8711 3.39
[12 3 4 5 6] [360 120 30 6 1] TRUE 31199 9706 3.21
[10000] [3] FALSE 29391 8640 3.40
[200 50] [1 200] FALSE 39599 9226 4.29
[200 50] [505 1] FALSE 33711 8640 3.90
[20 50 10] [1 20 1000] FALSE 38655 9013 4.29
[20 50 10] [7575 15 1] FALSE 33711 8515 3.96
[4 25 4 25] [1 16 4 400] FALSE 38272 9208 4.16
[4 25 4 25] [5859 217 31 1] FALSE 30335 8888 3.41
[12 3 4 5 6] [360 120 6 24 1] FALSE 41151 9777 4.21
[12 3 4 5 6] [5760 960 120 12 1] FALSE 36864 9671 3.81

@littlecutebird
Copy link
Collaborator Author

@CAHEK7 this PR has a very similar code structure to this MultiMarginLoss PR you have reviewed before.

{
// If reduction = None, O is DTYPE*
case 0: static_cast<DTYPE*>(O)[O_tv.get_tensor_view_idx(idx)] = CVT_ACCUM2FLOAT(loss); break;
// If reduction != Sum, O is FLOAT_ACCUM* and then all elements will be sum up in the next
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the cmt should be If reduction = Sum, and the same with case 2.

Copy link
Contributor

@iq136boy iq136boy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR is blocked by the following errors caused by the CI issue we recently have. After the CI issue is resolved, this PR need to restart to run through the CI.

[2024-09-24T07:18:43.885Z] Exception occurred: org.kohsuke.github.HttpException: {"message":"API rate limit exceeded for user ID 49319081. If you reach out to GitHub Support for help, please include the request ID

@junliume
Copy link
Collaborator

@littlecutebird the PR has some build issues, could you follow up with the fix? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants