Implement MultiMarginLoss forward #3166

Open
wants to merge 45 commits into
base: develop

littlecutebird
Collaborator

@littlecutebird littlecutebird commented Jul 29, 2024

  • Add the MultiMarginLoss forward operation and kernel. Only forward is implemented; backward is generally not faster than ROCm.
  • For an input tensor of shape (N, C), MIOpen is faster when C is small enough. The threshold depends on the tensor data type and on whether the tensor is contiguous. See the IsImprovementOverROCm function in src/solver/multimarginloss/forward_reduced_multimarginloss.cpp for the exact values.
  • The new API is guarded by the MIOPEN_BETA_API macro.
  • Added a driver test and gtest for MultiMarginLoss.
  • Compared to ROCm:

Unreduced:

dtype     Forward (mean Improvement, ROCm time / MIOpen time)
float32   7.26
float16   15.00
bfloat16  15.55
float32
input_size  num_class  contiguity  ROCm  MIOpen  Improvement (ROCm / MIOpen)
1234567 18 uncont 2736031 1414000 1.93
1234567 18 cont 2483891 864593 2.87
845729 13 uncont 1819759 595572 3.06
845729 13 cont 1681666 231331 7.27
1974532 24 uncont 4458704 2948920 1.51
1974532 24 cont 4000088 2408740 1.66
763492 6 uncont 1588180 99931 15.89
763492 6 cont 1516405 50579 29.98
1293850 20 uncont 2887276 1645220 1.75
1293850 20 cont 2607842 1193670 2.18
987654 7 uncont 2050747 170387 12.04
987654 7 cont 1957373 81139 24.12
345678 10 uncont 745427 155471 4.79
345678 10 cont 693732 57797 12.00
1654783 21 uncont 3721822 2198620 1.69
1654783 21 cont 3342149 1672970 2.00
234567 9 uncont 510567 85388 5.98
234567 9 cont 473656 37512 12.63
1892345 12 uncont 3999049 1161310 3.44
1892345 12 cont 3732094 430144 8.68
574839 15 uncont 1266874 505133 2.51
574839 15 cont 1148460 232219 4.95
1495832 25 uncont 3432324 2366530 1.45
1495832 25 cont 3035611 1970630 1.54
934750 8 uncont 1943294 225179 8.63
934750 8 cont 1841760 111452 16.53
847293 22 uncont 1933358 1147350 1.69
847293 22 cont 1723602 902120 1.91
1639204 19 uncont 3641152 1995590 1.82
1639204 19 cont 3296039 1329190 2.48
215678 14 uncont 486488 157266 3.09
215678 14 cont 451624 81797 5.52
1274835 11 uncont 2687633 688444 3.90
1274835 11 cont 2518452 234317 10.75
1346789 5 uncont 2752544 115274 23.88
1346789 5 cont 2644498 62348 42.42
1765432 17 uncont 3891244 1926250 2.02
1765432 17 cont 3532131 1045040 3.38
263748 16 uncont 583974 263651 2.21
263748 16 cont 536647 145835 3.68
1498765 23 uncont 3401829 2145170 1.59
1498765 23 cont 3034908 1736040 1.75
1728394 18 uncont 3813535 1997810 1.91
1728394 18 cont 3472084 1224080 2.84
459283 20 uncont 1043054 580671 1.80
459283 20 cont 934768 391938 2.38
1583642 7 uncont 3274808 266466 12.29
1583642 7 cont 3113979 124248 25.06
983472 13 uncont 2114547 697656 3.03
983472 13 cont 1954917 266058 7.35
1862345 24 uncont 4195471 2796450 1.50
1862345 24 cont 3790247 2269900 1.67
712345 6 uncont 1478424 93231 15.86
712345 6 cont 1412536 47922 29.48
1456789 15 uncont 3158216 1332710 2.37
1456789 15 cont 2898056 565862 5.12
298376 10 uncont 654681 126881 5.16
298376 10 cont 600040 54179 11.08
1872345 21 uncont 4204781 2499140 1.68
1872345 21 cont 3782457 1895210 2.00
float16
input_size  num_class  contiguity  ROCm  MIOpen  Improvement (ROCm / MIOpen)
1234567 18 uncont 2770542 867544 3.19
1234567 18 cont 2589537 280683 9.23
845729 13 uncont 1885278 235029 8.02
845729 13 cont 1782992 103576 17.21
1974532 24 uncont 4522927 2413650 1.87
1974532 24 cont 4147142 885873 4.68
763492 6 uncont 1650466 52854 31.23
763492 6 cont 1592787 30187 52.76
1293850 20 uncont 2921260 1197670 2.44
1293850 20 cont 2711519 347849 7.80
987654 7 uncont 2135050 82917 25.75
987654 7 cont 2058300 42348 48.60
345678 10 uncont 756323 59201 12.78
345678 10 cont 727251 31201 23.31
1654783 21 uncont 3748893 1676220 2.24
1654783 21 cont 3471346 528511 6.57
234567 9 uncont 521591 38969 13.38
234567 9 cont 495352 22382 22.13
1892345 12 uncont 4175606 434607 9.61
1892345 12 cont 3958634 172716 22.92
574839 15 uncont 1287945 235313 5.47
574839 15 cont 1214155 92553 13.12
1495832 25 uncont 3432836 1990650 1.72
1495832 25 cont 3151945 731859 4.31
934750 8 uncont 2018860 112145 18.00
934750 8 cont 1944830 67859 28.66
847293 22 uncont 1950558 908040 2.15
847293 22 cont 1783041 305892 5.83
1639204 19 uncont 3685872 1346360 2.74
1639204 19 cont 3436612 418375 8.21
215678 14 uncont 493448 82579 5.98
215678 14 cont 463240 37547 12.34
1274835 11 uncont 2807439 240308 11.68
1274835 11 cont 2668594 103327 25.83
1346789 5 uncont 2884926 66010 43.70
1346789 5 cont 2804704 40303 69.59
1765432 17 uncont 3940028 1051990 3.75
1765432 17 cont 3700896 347048 10.66
263748 16 uncont 595366 149711 3.98
263748 16 cont 559302 68535 8.16
1498765 23 uncont 3427829 1748010 1.96
1498765 23 cont 3149082 601632 5.23
1728394 18 uncont 3889197 1231500 3.16
1728394 18 cont 3618386 389307 9.29
459283 20 uncont 1056206 396311 2.67
459283 20 cont 971055 134030 7.25
1583642 7 uncont 3412214 124603 27.38
1583642 7 cont 3295832 61430 53.65
983472 13 uncont 2179909 271071 8.04
983472 13 cont 2065448 115965 17.81
1862345 24 uncont 4255988 2268800 1.88
1862345 24 cont 3912746 839845 4.66
712345 6 uncont 1548953 50393 30.74
712345 6 cont 1482665 28529 51.97
1456789 15 uncont 3237951 574324 5.64
1456789 15 cont 3057486 219740 13.91
298376 10 uncont 654425 55033 11.89
298376 10 cont 629128 28369 22.18
1872345 21 uncont 4235705 1905840 2.22
1872345 21 cont 3923990 594220 6.60
bfloat16
input_size  num_class  contiguity  ROCm  MIOpen  Improvement (ROCm / MIOpen)
1234567 18 uncont 2867228 869197 3.30
1234567 18 cont 2684832 279723 9.60
845729 13 uncont 1948669 235402 8.28
845729 13 cont 1840799 103470 17.79
1974532 24 uncont 4648364 2424210 1.92
1974532 24 cont 4301426 881268 4.88
763492 6 uncont 1717569 52854 32.50
763492 6 cont 1655650 29849 55.47
1293850 20 uncont 3018138 1200160 2.51
1293850 20 cont 2813822 346160 8.13
987654 7 uncont 2210089 82455 26.80
987654 7 cont 2131146 42152 50.56
345678 10 uncont 781282 58988 13.24
345678 10 cont 752674 31343 24.01
1654783 21 uncont 3869212 1678700 2.30
1654783 21 cont 3595072 524672 6.85
234567 9 uncont 530918 39538 13.43
234567 9 cont 512471 21938 23.36
1892345 12 uncont 4305588 435158 9.89
1892345 12 cont 4101240 171738 23.88
574839 15 uncont 1332040 235188 5.66
574839 15 cont 1257514 92055 13.66
1495832 25 uncont 3554850 1989120 1.79
1495832 25 cont 3268662 729565 4.48
934750 8 uncont 2087515 112092 18.62
934750 8 cont 2013965 67148 29.99
847293 22 uncont 2001453 908360 2.20
847293 22 cont 1851424 303759 6.10
1639204 19 uncont 3813230 1343980 2.84
1639204 19 cont 3559714 416739 8.54
215678 14 uncont 502679 82793 6.07
215678 14 cont 475848 37387 12.73
1274835 11 uncont 2896046 239881 12.07
1274835 11 cont 2762464 102811 26.87
1346789 5 uncont 2983292 66952 44.56
1346789 5 cont 2900686 40178 72.20
1765432 17 uncont 4078969 1053790 3.87
1765432 17 cont 3834782 345804 11.09
263748 16 uncont 616565 149319 4.13
263748 16 cont 578374 69086 8.37
1498765 23 uncont 3532211 1741270 2.03
1498765 23 cont 3268280 598450 5.46
1728394 18 uncont 3999691 1226660 3.26
1728394 18 cont 3757055 386942 9.71
459283 20 uncont 1082045 396507 2.73
459283 20 cont 1004479 133508 7.52
1583642 7 uncont 3539592 124710 28.38
1583642 7 cont 3414588 61733 55.31
983472 13 uncont 2257385 270485 8.35
983472 13 cont 2137531 115930 18.44
1862345 24 uncont 4394361 2272480 1.93
1862345 24 cont 4056669 833342 4.87
712345 6 uncont 1604379 50091 32.03
712345 6 cont 1536123 28831 53.28
1456789 15 uncont 3349318 574432 5.83
1456789 15 cont 3166837 218567 14.49
298376 10 uncont 678089 55033 12.32
298376 10 cont 653993 28405 23.02
1872345 21 uncont 4372007 1903840 2.30
1872345 21 cont 4071267 591590 6.88
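
The Improvement column is the ROCm time divided by the MIOpen time, and the per-type summary figures appear to be the arithmetic mean of that column (the float32 unreduced rows average to about 7.26). A minimal sketch of that calculation follows; `BenchRow`, `Improvement`, and `MeanImprovement` are illustrative names that do not come from the PR, and the time unit is whatever the benchmark reported.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One benchmark row: the two timing columns from the tables above.
struct BenchRow
{
    double rocm;   // ROCm time (unit as reported by the benchmark)
    double miopen; // MIOpen time (same unit)
};

// Improvement as tabulated: ROCm time divided by MIOpen time.
double Improvement(const BenchRow& row) { return row.rocm / row.miopen; }

// Arithmetic mean of the per-row improvements.
double MeanImprovement(const std::vector<BenchRow>& rows)
{
    assert(!rows.empty());
    double sum = 0.0;
    for(const auto& row : rows)
        sum += Improvement(row);
    return sum / static_cast<double>(rows.size());
}
```

For example, the first float32 row (ROCm 2736031, MIOpen 1414000) gives roughly 1.93, matching the table.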

Reduced:

dtype     Forward (mean Improvement, ROCm time / MIOpen time)
float32   6.01
float16   11.03
bfloat16  11.27
float32
input_size  num_class  contiguity  ROCm  MIOpen  Improvement (ROCm / MIOpen)
1234567 18 uncont 2756526 1443450 1.91
1234567 18 cont 2505107 889518 2.82
845729 13 uncont 1842239 620141 2.97
845729 13 cont 1703170 254637 6.69
1974532 24 uncont 4508367 2983960 1.51
1974532 24 cont 5639611 2444760 2.31
763492 6 uncont 1604643 124501 12.89
763492 6 cont 1527509 73922 20.66
1293850 20 uncont 2908540 1676670 1.73
1293850 20 cont 2632593 1222590 2.15
987654 7 uncont 2076539 195578 10.62
987654 7 cont 1978717 106527 18.57
345678 10 uncont 762306 174937 4.36
345678 10 cont 715635 78011 9.17
1654783 21 uncont 3749277 2231310 1.68
1654783 21 cont 3366740 1706730 1.97
234567 9 uncont 523559 104394 5.02
234567 9 cont 494791 58401 8.47
1892345 12 uncont 4024937 1192300 3.38
1892345 12 cont 3757646 464918 8.08
574839 15 uncont 1288873 526680 2.45
574839 15 cont 1169340 255456 4.58
1495832 25 uncont 3451571 2394160 1.44
1495832 25 cont 3064090 2004670 1.53
934750 8 uncont 1967774 249855 7.88
934750 8 cont 1867775 136021 13.73
847293 22 uncont 1953470 1171270 1.67
847293 22 cont 1748225 925410 1.89
1639204 19 uncont 4908330 2027870 2.42
1639204 19 cont 3318390 1365700 2.43
215678 14 uncont 512855 175969 2.91
215678 14 cont 462248 100838 4.58
1274835 11 uncont 2716928 717333 3.79
1274835 11 cont 2543524 262246 9.70
1346789 5 uncont 2782336 143542 19.38
1346789 5 cont 2674498 92376 28.95
1765432 17 uncont 3913532 1954870 2.00
1765432 17 cont 3556162 1074200 3.31
263748 16 uncont 607126 283882 2.14
263748 16 cont 554662 166386 3.33
1498765 23 uncont 3426581 2175950 1.57
1498765 23 cont 3060379 1770620 1.73
1728394 18 uncont 3847342 2027710 1.90
1728394 18 cont 3492244 1251890 2.79
459283 20 uncont 1061486 600619 1.77
459283 20 cont 959712 411708 2.33
1583642 7 uncont 3293495 295724 11.14
1583642 7 cont 3140634 154803 20.29
983472 13 uncont 2136548 723804 2.95
983472 13 cont 1971223 290375 6.79
1862345 24 uncont 4223018 2832500 1.49
1862345 24 cont 3799152 2296010 1.65
712345 6 uncont 1506617 116553 12.93
712345 6 cont 1434248 70870 20.24
1456789 15 uncont 3181691 1360640 2.34
1456789 15 cont 2920026 593930 4.92
298376 10 uncont 674761 145990 4.62
298376 10 cont 620520 73519 8.44
1872345 21 uncont 4232803 2532360 1.67
1872345 21 cont 3804127 1926750 1.97
float16
input_size  num_class  contiguity  ROCm  MIOpen  Improvement (ROCm / MIOpen)
1234567 18 uncont 2803902 910034 3.08
1234567 18 cont 2616913 311138 8.41
845729 13 uncont 1910797 264114 7.23
845729 13 cont 1799488 129675 13.88
1974532 24 uncont 4523518 2466790 1.83
1974532 24 cont 4176581 918567 4.55
763492 6 uncont 1677234 77993 21.50
763492 6 cont 1616275 55788 28.97
1293850 20 uncont 2944891 1240300 2.37
1293850 20 cont 2737663 378997 7.22
987654 7 uncont 2154537 110190 19.55
987654 7 cont 2093962 70099 29.87
345678 10 uncont 1600659 79806 20.06
345678 10 cont 747826 55628 13.44
1654783 21 uncont 3765949 1726430 2.18
1654783 21 cont 3491234 561241 6.22
234567 9 uncont 544966 58952 9.24
234567 9 cont 519223 42098 12.33
1892345 12 uncont 4203350 479106 8.77
1892345 12 cont 3983866 208112 19.14
574839 15 uncont 1311177 262922 4.99
574839 15 cont 1233226 116447 10.59
1495832 25 uncont 3463507 2032090 1.70
1495832 25 cont 3179048 771308 4.12
934750 8 uncont 2040972 138421 14.74
934750 8 cont 1968478 93264 21.11
847293 22 uncont 1964830 937926 2.09
847293 22 cont 1807104 331013 5.46
1639204 19 uncont 3707055 1389720 2.67
1639204 19 cont 3461204 452313 7.65
215678 14 uncont 519975 102687 5.06
215678 14 cont 478888 57530 8.32
1274835 11 uncont 2829615 270300 10.47
1274835 11 cont 2693041 134599 20.01
1346789 5 uncont 2908573 98189 29.62
1346789 5 cont 2828319 71503 39.56
1765432 17 uncont 3969851 1102840 3.60
1765432 17 cont 3724528 380951 9.78
263748 16 uncont 618869 170440 3.63
263748 16 cont 581542 87326 6.66
1498765 23 uncont 3450501 1789890 1.93
1498765 23 cont 3176249 632602 5.02
1728394 18 uncont 3895085 1271530 3.06
1728394 18 cont 3645953 421788 8.64
459283 20 uncont 1066110 422481 2.52
459283 20 cont 997359 155853 6.40
1583642 7 uncont 3441654 156652 21.97
1583642 7 cont 3322346 94705 35.08
983472 13 uncont 2201607 301698 7.30
983472 13 cont 2091002 142877 14.63
1862345 24 uncont 4281854 2316950 1.85
1862345 24 cont 3945155 872784 4.52
712345 6 uncont 1575546 74798 21.06
712345 6 cont 1505194 53361 28.21
1456789 15 uncont 3261474 616275 5.29
1456789 15 cont 3079921 254829 12.09
298376 10 uncont 682713 74977 9.11
298376 10 cont 649161 50393 12.88
1872345 21 uncont 4257535 1948970 2.18
1872345 21 cont 3945419 632118 6.24
bfloat16
input_size  num_class  contiguity  ROCm  MIOpen  Improvement (ROCm / MIOpen)
1234567 18 uncont 2896940 904132 3.20
1234567 18 cont 3536000 310764 11.38
845729 13 uncont 1974732 263599 7.49
845729 13 cont 1862318 129958 14.33
1974532 24 uncont 4684028 2470090 1.90
1974532 24 cont 4335346 907260 4.78
763492 6 uncont 1728321 78028 22.15
763492 6 cont 1667858 56410 29.57
1293850 20 uncont 3048665 1240510 2.46
1293850 20 cont 2836814 377877 7.51
987654 7 uncont 2233544 110883 20.14
987654 7 cont 2154634 69797 30.87
345678 10 uncont 804562 79859 10.07
345678 10 cont 775618 53174 14.59
1654783 21 uncont 3895259 1724530 2.26
1654783 21 cont 3619216 556655 6.50
234567 9 uncont 556422 58810 9.46
234567 9 cont 534103 41938 12.74
1892345 12 uncont 4329507 475603 9.10
1892345 12 cont 4129607 207632 19.89
574839 15 uncont 1353752 261571 5.18
574839 15 cont 1280041 116252 11.01
1495832 25 uncont 3573377 2033420 1.76
1495832 25 cont 3297878 769246 4.29
934750 8 uncont 2110139 138403 15.25
934750 8 cont 2039388 93104 21.90
847293 22 uncont 2035949 939988 2.17
847293 22 cont 1869919 330053 5.67
1639204 19 uncont 3835373 1392320 2.75
1639204 19 cont 3586082 452527 7.92
215678 14 uncont 525175 103238 5.09
215678 14 cont 497559 57708 8.62
1274835 11 uncont 2919757 269980 10.81
1274835 11 cont 2785568 133621 20.85
1346789 5 uncont 3017932 98384 30.68
1346789 5 cont 2925197 70970 41.22
1765432 17 uncont 4109449 1105090 3.72
1765432 17 cont 3863501 382801 10.09
263748 16 uncont 639525 167897 3.81
263748 16 cont 601670 87326 6.89
1498765 23 uncont 3562003 1791140 1.99
1498765 23 cont 3289848 630042 5.22
1728394 18 uncont 4028171 1272920 3.16
1728394 18 cont 3780239 419850 9.00
459283 20 uncont 1105581 423828 2.61
459283 20 cont 1026015 154803 6.63
1583642 7 uncont 3558443 156759 22.70
1583642 7 cont 3438494 94457 36.40
983472 13 uncont 2275354 301467 7.55
983472 13 cont 2164461 143215 15.11
1862345 24 uncont 4421699 2320580 1.91
1862345 24 cont 4086630 865622 4.72
712345 6 uncont 1618908 81304 19.91
712345 6 cont 1562475 54250 28.80
1456789 15 uncont 3370569 615778 5.47
1456789 15 cont 3189081 250599 12.73
298376 10 uncont 712394 75386 9.45
298376 10 cont 672761 48847 13.77
1872345 21 uncont 4410621 1950660 2.26
1872345 21 cont 4093784 629239 6.51
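
For readers unfamiliar with the operation being benchmarked, the forward computation can be written as a plain CPU loop. This is only a sketch, assuming the operation follows the PyTorch definition of multi-margin loss, with class weights omitted and a contiguous (N, C) input. The function names and the `Reduction` enum are hypothetical and are not taken from the MIOpen sources.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference, not MIOpen code: unreduced multi-margin loss for a
// contiguous (N, C) input. p is 1 or 2; class weights are omitted.
std::vector<float> MultiMarginLossUnreducedRef(const std::vector<float>& input,
                                               const std::vector<std::size_t>& target,
                                               std::size_t N,
                                               std::size_t C,
                                               int p,
                                               float margin)
{
    assert(input.size() == N * C && target.size() == N);
    assert(p == 1 || p == 2);
    std::vector<float> loss(N, 0.0f);
    for(std::size_t n = 0; n < N; ++n)
    {
        const std::size_t y = target[n];
        assert(y < C);
        float sum = 0.0f;
        for(std::size_t c = 0; c < C; ++c)
        {
            if(c == y)
                continue; // the target class does not contribute
            const float t = std::max(0.0f, margin - input[n * C + y] + input[n * C + c]);
            sum += (p == 1) ? t : t * t;
        }
        loss[n] = sum / static_cast<float>(C); // normalized by the class count
    }
    return loss;
}

enum class Reduction { Sum, Mean };

// Reduced variant: collapse the per-sample losses to a single value.
float MultiMarginLossReducedRef(const std::vector<float>& loss, Reduction reduction)
{
    assert(!loss.empty());
    double acc = 0.0;
    for(float v : loss)
        acc += v;
    if(reduction == Reduction::Mean)
        acc /= static_cast<double>(loss.size());
    return static_cast<float>(acc);
}
```

Each sample loops over all C classes, so the work per sample grows linearly with C, which fits the trend in the tables where the advantage narrows as the class count increases.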

@littlecutebird
Collaborator Author

@CAHEK7 Please see this comment: #3146 (comment). I have fixed everything as you recommended.

@littlecutebird littlecutebird marked this pull request as ready for review July 31, 2024 06:42
@littlecutebird littlecutebird requested a review from a team as a code owner August 21, 2024 09:11
@littlecutebird
Collaborator Author

@CAHEK7 I have updated the code to incorporate your recommendations:

  1. Optimized global memory access in the kernel, which improves performance. This changes the benchmark results in the PR description and the condition in IsImprovementOverROCm (basically, C can be a bit larger).
  2. Merged duplicated code in gtest, driver, src, and kernel.
  3. Avoided calling CVT_ACCUM2FLOAT after each kernel and avoided using divisor in problem_description.
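
The gate mentioned in item 1 can be pictured as a bound on C chosen by data type and contiguity. This is a rough sketch only: the enum, struct, and function names are hypothetical, the limits are supplied by the caller rather than copied from the solver, and the `<=` comparison is an assumption.

```cpp
#include <cassert>
#include <cstddef>

enum class DType { Float32, Float16, BFloat16 };

// Hypothetical limits on the class count C for one data type.
struct ClassLimits
{
    std::size_t contiguous;
    std::size_t non_contiguous;
};

// Hypothetical table of limits; the real values live in the solver source.
struct GateLimits
{
    ClassLimits fp32;
    ClassLimits fp16;
    ClassLimits bf16;
};

const ClassLimits& SelectLimits(const GateLimits& limits, DType type)
{
    switch(type)
    {
    case DType::Float32: return limits.fp32;
    case DType::Float16: return limits.fp16;
    case DType::BFloat16: return limits.bf16;
    }
    return limits.fp32; // unreachable for valid enum values
}

// True when C is small enough for the given data type and layout.
bool PassesGate(DType type, bool is_contiguous, std::size_t C, const GateLimits& limits)
{
    const ClassLimits& l = SelectLimits(limits, type);
    const std::size_t max_c = is_contiguous ? l.contiguous : l.non_contiguous;
    return C <= max_c;
}
```

Faster memory access in the kernel would correspond to raising the entries in such a table, which is the sense in which C "can be a bit larger".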

@littlecutebird littlecutebird changed the title Implement MultiMarginLoss Implement MultiMarginLoss forward Aug 27, 2024
@littlecutebird
Collaborator Author

@CAHEK7 I also merged the solvers into one file to avoid duplicated code. Please check whether this satisfies your recommendation. Thank you.

@atamazov
Contributor

[off-topic] @littlecutebird Can you please provide an email address on your GitHub profile page? Thanks.

@littlecutebird
Collaborator Author

> [off-topic] @littlecutebird Can you please provide an email address on your GitHub profile page? Thanks.

I have updated it. Please check.

Contributor

@iq136boy iq136boy left a comment

Blocked by a CI failure. CI needs to be restarted once the issue is fixed.

[2024-09-23T13:54:42.218Z] Exception occurred: org.kohsuke.github.HttpException: {"message":"API rate limit exceeded for user ID 49319081. If you reach out to GitHub Support for help, please include the request ID

@junliume
Copy link
Collaborator

@littlecutebird there are some CI issues, and the code needs cleanup (resolving conflicts, etc.). Thanks!

[2024-09-27T04:26:25.937Z] /home/jenkins/workspace/S_MIOpen_nl-impl_MultiMarginLoss/src/solver/multimarginloss/forward_multimarginloss.cpp:95:29: error: use of undeclared identifier 'AlignUp'
[2024-09-27T04:26:25.937Z]    95 |         size_t xgridsize  = AlignUp(xgrid, xlocalsize);
[2024-09-27T04:26:25.937Z]       |                             ^
[2024-09-27T04:26:25.937Z] /home/jenkins/workspace/S_MIOpen_nl-impl_MultiMarginLoss/src/solver/multimarginloss/forward_multimarginloss.cpp:142:33: error: use of undeclared identifier 'AlignUp'
[2024-09-27T04:26:25.937Z]   142 |             size_t xgridsize  = AlignUp(_size, xlocalsize);
[2024-09-27T04:26:25.937Z]       |                                 ^
[2024-09-27T04:26:25.937Z] /home/jenkins/workspace/S_MIOpen_nl-impl_MultiMarginLoss/src/solver/multimarginloss/forward_multimarginloss.cpp:168:29: error: use of undeclared identifier 'AlignUp'
[2024-09-27T04:26:25.937Z]   168 |         size_t xgridsize  = AlignUp(_size, xlocalsize);
[2024-09-27T04:26:25.937Z]       |                             ^
[2024-09-27T04:26:25.937Z] /home/jenkins/workspace/S_MIOpen_nl-impl_MultiMarginLoss/src/solver/multimarginloss/forward_multimarginloss.cpp:260:35: error: use of undeclared identifier 'MultiBufferWorkspaceTraits'
[2024-09-27T04:26:25.937Z]   260 |                 auto wt         = MultiBufferWorkspaceTraits{size * data_size,
[2024-09-27T04:26:25.937Z]       |                                   ^
[2024-09-27T04:26:25.937Z] /home/jenkins/workspace/S_MIOpen_nl-impl_MultiMarginLoss/src/solver/multimarginloss/forward_multimarginloss.cpp:260:61: error: expected ';' at end of declaration
[2024-09-27T04:26:25.937Z]   260 |                 auto wt         = MultiBufferWorkspaceTraits{size * data_size,
[2024-09-27T04:26:25.937Z]       |                                                             ^
[2024-09-27T04:26:25.937Z]       |                                                             ;
[2024-09-27T04:26:25.937Z] /home/jenkins/workspace/S_MIOpen_nl-impl_MultiMarginLoss/src/solver/multimarginloss/forward_multimarginloss.cpp:310:12: error: use of undeclared identifier 'MultiBufferWorkspaceTraits'
[2024-09-27T04:26:25.937Z]   310 |     return MultiBufferWorkspaceTraits{
[2024-09-27T04:26:25.937Z]       |            ^
[2024-09-27T04:26:25.937Z] /home/jenkins/workspace/S_MIOpen_nl-impl_MultiMarginLoss/src/solver/multimarginloss/forward_multimarginloss.cpp:310:38: error: expected ';' after return statement
[2024-09-27T04:26:25.937Z]   310 |     return MultiBufferWorkspaceTraits{
[2024-09-27T04:26:25.937Z]       |                                      ^
[2024-09-27T04:26:25.937Z]       |                                      ;
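
The errors above say that the declarations of `AlignUp` and `MultiBufferWorkspaceTraits` are not visible at those points in forward_multimarginloss.cpp. For context, a call like `AlignUp(xgrid, xlocalsize)` conventionally rounds the work-item count up to a multiple of the work-group size. A minimal sketch under that assumption follows; `AlignUpSketch` and `GridSizeFor` are hypothetical stand-ins, not MIOpen's own declarations.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in: round value up to the next multiple of alignment.
// Assumes alignment is non-zero.
constexpr std::size_t AlignUpSketch(std::size_t value, std::size_t alignment)
{
    return ((value + alignment - 1) / alignment) * alignment;
}

// Global size for a 1-D launch: the smallest multiple of the local size that
// covers every work item.
std::size_t GridSizeFor(std::size_t work_items, std::size_t local_size)
{
    assert(local_size > 0);
    return AlignUpSketch(work_items, local_size);
}
```

With a local size of 256, 1000 work items round up to a grid of 1024, so the kernel needs a bounds check for the 24 surplus threads.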

6 participants