Commit
Added functorch to functional_autograd_benchmark
Description:
- Following pytorch/functorch#497, add an option to run the benchmarks with functorch and compare against the original functional autograd results.

Running the benchmark produces the table below:

<details>

<summary> Table </summary>

```
| model | task | mean (s) | var (s^2) |
| -- | -- | -- | -- |
| resnet18 | vjp | 0.03826599195599556 | 4.3332115637895186e-06 |
| resnet18 | functorch vjp | 0.037201929837465286 | 6.139693198292662e-09 |
| resnet18 | vhp | 0.2202976644039154 | 2.8687209052691287e-08 |
| resnet18 | functorch vhp | 0.22117868065834045 | 4.108771278765744e-08 |
| resnet18 | jvp | 0.18679651618003845 | 1.832455254202614e-08 |
| resnet18 | functorch jvp | 0.05305683612823486 | 1.6690266946284282e-08 |
| fcn_resnet | vjp | 0.6071907877922058 | 7.436695454998699e-07 |
| fcn_resnet | functorch vjp | 0.6115708947181702 | 1.121692207561864e-06 |
| fcn_resnet | vhp | 3.419469118118286 | 0.020633839070796967 |
| fcn_resnet | jvp | 2.5421929359436035 | 3.1765587209520163e-06 |
| fcn_resnet | functorch jvp | 0.7628333568572998 | 1.4555752159139956e-07 |
| detr | vjp | 0.19494840502738953 | 1.9122715457342565e-05 |
| detr | vhp | 1.1664292812347412 | 0.000948643428273499 |
| detr | jvp | 0.9990308880805969 | 1.0214127541985363e-05 |
| ppl_simple_reg | vjp | 0.0007535457843914628 | 6.024204690646684e-09 |
| ppl_simple_reg | functorch vjp | 0.0016954183811321855 | 1.160151974488599e-08 |
| ppl_simple_reg | vhp | 0.0011888503795489669 | 5.93119386937957e-10 |
| ppl_simple_reg | functorch vhp | 0.0026826143730431795 | 1.6787025103326414e-08 |
| ppl_simple_reg | jvp | 0.001067900680936873 | 7.409912128331086e-10 |
| ppl_simple_reg | functorch jvp | 0.002065300941467285 | 9.710328185974504e-08 |
| ppl_simple_reg | hvp | 0.001212477684020996 | 1.974137298077494e-09 |
| ppl_simple_reg | functorch hvp | 0.00482442369684577 | 2.327668653379078e-07 |
| ppl_simple_reg | jacobian | 0.0009108781814575195 | 3.489469158068914e-09 |
| ppl_simple_reg | functorch jacobian | 0.0019866942893713713 | 1.938326299466553e-08 |
| ppl_simple_reg | hessian | 0.005053090862929821 | 3.370298600202659e-07 |
| ppl_simple_reg | functorch hessian | 0.006374978926032782 | 7.556796077778927e-08 |
| ppl_simple_reg | hessian_fwdrev | 0.0036706924438476562 | 1.996075527088692e-09 |
| ppl_simple_reg | functorch hessian_fwdrev | 0.0058908225037157536 | 7.548283775804521e-08 |
| ppl_simple_reg | hessian_revrev | 0.0015769004821777344 | 1.5754418214442012e-08 |
| ppl_simple_reg | functorch hessian_revrev | 0.0041002752259373665 | 6.713568723171193e-08 |
| ppl_simple_reg | jacfwd | 0.0018048763740807772 | 2.7375660849315864e-08 |
| ppl_simple_reg | functorch jacfwd | 0.002047991845756769 | 2.432247070416338e-09 |
| ppl_simple_reg | jacrev | 0.0009733677143231034 | 1.0078769818733235e-08 |
| ppl_simple_reg | functorch jacrev | 0.0021971464157104492 | 1.2729884701911942e-08 |
| ppl_robust_reg | vjp | 0.005820560269057751 | 8.582588151284654e-08 |
| ppl_robust_reg | functorch vjp | 0.00796132069081068 | 9.663100541956737e-09 |
| ppl_robust_reg | vhp | 0.009825301356613636 | 2.0081762386325863e-07 |
| ppl_robust_reg | functorch vhp | 0.014890861697494984 | 4.558066279969353e-07 |
| ppl_robust_reg | jvp | 0.008297419175505638 | 2.9454400873873965e-07 |
| ppl_robust_reg | functorch jvp | 0.008052706718444824 | 7.120377176761394e-08 |
| ppl_robust_reg | hvp | 0.015414690598845482 | 7.42123745567369e-07 |
| ppl_robust_reg | functorch hvp | 0.02699306048452854 | 1.4650488537881756e-06 |
| ppl_robust_reg | jacobian | 0.006207776255905628 | 1.7068457225377642e-07 |
| ppl_robust_reg | functorch jacobian | 0.009173822589218616 | 1.2214455580306094e-07 |
| ppl_robust_reg | hessian | 0.04670915752649307 | 1.4299343092716299e-05 |
| ppl_robust_reg | functorch hessian | 0.02337808534502983 | 3.0397418413485866e-06 |
| ppl_robust_reg | hessian_fwdrev | 0.024229884147644043 | 2.0425247839739313e-06 |
| ppl_robust_reg | functorch hessian_fwdrev | 0.022021746262907982 | 3.512146236062108e-07 |
| ppl_robust_reg | hessian_revrev | 0.012355780228972435 | 7.090877147675201e-07 |
| ppl_robust_reg | functorch hessian_revrev | 0.013960313983261585 | 6.326549737423193e-07 |
| ppl_robust_reg | jacfwd | 0.008112502284348011 | 2.88503088086145e-08 |
| ppl_robust_reg | functorch jacfwd | 0.008947920985519886 | 4.2070990247111695e-08 |
| ppl_robust_reg | jacrev | 0.00635871896520257 | 1.3403841592207755e-07 |
| ppl_robust_reg | functorch jacrev | 0.009123563766479492 | 2.677554675756255e-07 |
| wav2letter | vjp | 0.02078995667397976 | 2.1110793113621185e-06 |
| wav2letter | functorch vjp | 0.019202351570129395 | 9.210506135559626e-09 |
| wav2letter | vhp | 0.05997290462255478 | 8.558587616391833e-09 |
| wav2letter | functorch vhp | 0.06035261228680611 | 1.6448565842708263e-09 |
| wav2letter | jvp | 0.04507789760828018 | 1.5771547401399744e-09 |
| wav2letter | functorch jvp | 0.013057494536042213 | 3.804750292601966e-09 |
| deepspeech | vjp | 0.3648746609687805 | 1.5359055396402255e-05 |
| transformer | vjp | 0.05496881157159805 | 1.242562319703211e-08 |
| transformer | functorch vjp | 0.057835936546325684 | 2.6113376350167528e-08 |
| transformer | vhp | 0.18313491344451904 | 7.226336151688884e-08 |
| transformer | jvp | 0.13924935460090637 | 1.6989159234981344e-07 |
| multiheadattn | vjp | 0.0014708995586261153 | 3.710916729460223e-08 |
| multiheadattn | functorch vjp | 0.002404856728389859 | 2.1910574687922235e-08 |
| multiheadattn | vhp | 0.003382015274837613 | 5.3098595742540056e-08 |
| multiheadattn | functorch vhp | 0.005340623669326305 | 5.897558708056749e-08 |
| multiheadattn | jvp | 0.0027526854537427425 | 3.508620949332908e-08 |
| multiheadattn | functorch jvp | 0.0022981404326856136 | 1.327894807445773e-07 |
```

</details>

<details>

<summary> Stdout </summary>

```
Found functorch: 0.2.0a0+386a541
Results for model resnet18 on task vjp: 0.03826599195599556s (var: 4.3332115637895186e-06)
Results for model resnet18 on task vjp using Functorch: 0.037201929837465286s (var: 6.139693198292662e-09)
Results for model resnet18 on task vhp: 0.2202976644039154s (var: 2.8687209052691287e-08)
Results for model resnet18 on task vhp using Functorch: 0.22117868065834045s (var: 4.108771278765744e-08)
Results for model resnet18 on task jvp: 0.18679651618003845s (var: 1.832455254202614e-08)
Results for model resnet18 on task jvp using Functorch: 0.05305683612823486s (var: 1.6690266946284282e-08)
Results for model fcn_resnet on task vjp: 0.6071907877922058s (var: 7.436695454998699e-07)
Results for model fcn_resnet on task vjp using Functorch: 0.6115708947181702s (var: 1.121692207561864e-06)
Results for model fcn_resnet on task vhp: 3.419469118118286s (var: 0.020633839070796967)
Failed model using Functorch: fcn_resnet, task: vhp, Error message: CUDA out of memory. Tried to allocate 114.00 MiB (GPU 0; 47.46 GiB total capacity; 45.62 GiB already allocated; 5.31 MiB free; 46.02 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Results for model fcn_resnet on task jvp: 2.5421929359436035s (var: 3.1765587209520163e-06)
Results for model fcn_resnet on task jvp using Functorch: 0.7628333568572998s (var: 1.4555752159139956e-07)
Results for model detr on task vjp: 0.19494840502738953s (var: 1.9122715457342565e-05)
Failed model using Functorch: detr, task: vjp, Error message: Cannot access data pointer of Tensor that doesn't have storage
Results for model detr on task vhp: 1.1664292812347412s (var: 0.000948643428273499)
Failed model using Functorch: detr, task: vhp, Error message: Cannot access data pointer of Tensor that doesn't have storage
Results for model detr on task jvp: 0.9990308880805969s (var: 1.0214127541985363e-05)
Failed model using Functorch: detr, task: jvp, Error message: Trying to use forward AD with _cdist_forward that does not support it because it has not been implemented yet. Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation.
Results for model ppl_simple_reg on task vjp: 0.0007535457843914628s (var: 6.024204690646684e-09)
Results for model ppl_simple_reg on task vjp using Functorch: 0.0016954183811321855s (var: 1.160151974488599e-08)
Results for model ppl_simple_reg on task vhp: 0.0011888503795489669s (var: 5.93119386937957e-10)
Results for model ppl_simple_reg on task vhp using Functorch: 0.0026826143730431795s (var: 1.6787025103326414e-08)
Results for model ppl_simple_reg on task jvp: 0.001067900680936873s (var: 7.409912128331086e-10)
Results for model ppl_simple_reg on task jvp using Functorch: 0.002065300941467285s (var: 9.710328185974504e-08)
Results for model ppl_simple_reg on task hvp: 0.001212477684020996s (var: 1.974137298077494e-09)
Results for model ppl_simple_reg on task hvp using Functorch: 0.00482442369684577s (var: 2.327668653379078e-07)
Results for model ppl_simple_reg on task jacobian: 0.0009108781814575195s (var: 3.489469158068914e-09)
Results for model ppl_simple_reg on task jacobian using Functorch: 0.0019866942893713713s (var: 1.938326299466553e-08)
Results for model ppl_simple_reg on task hessian: 0.005053090862929821s (var: 3.370298600202659e-07)
Results for model ppl_simple_reg on task hessian using Functorch: 0.006374978926032782s (var: 7.556796077778927e-08)
Results for model ppl_simple_reg on task hessian_fwdrev: 0.0036706924438476562s (var: 1.996075527088692e-09)
Results for model ppl_simple_reg on task hessian_fwdrev using Functorch: 0.0058908225037157536s (var: 7.548283775804521e-08)
Results for model ppl_simple_reg on task hessian_revrev: 0.0015769004821777344s (var: 1.5754418214442012e-08)
Results for model ppl_simple_reg on task hessian_revrev using Functorch: 0.0041002752259373665s (var: 6.713568723171193e-08)
Results for model ppl_simple_reg on task jacfwd: 0.0018048763740807772s (var: 2.7375660849315864e-08)
Results for model ppl_simple_reg on task jacfwd using Functorch: 0.002047991845756769s (var: 2.432247070416338e-09)
Results for model ppl_simple_reg on task jacrev: 0.0009733677143231034s (var: 1.0078769818733235e-08)
Results for model ppl_simple_reg on task jacrev using Functorch: 0.0021971464157104492s (var: 1.2729884701911942e-08)
Results for model ppl_robust_reg on task vjp: 0.005820560269057751s (var: 8.582588151284654e-08)
Results for model ppl_robust_reg on task vjp using Functorch: 0.00796132069081068s (var: 9.663100541956737e-09)
Results for model ppl_robust_reg on task vhp: 0.009825301356613636s (var: 2.0081762386325863e-07)
Results for model ppl_robust_reg on task vhp using Functorch: 0.014890861697494984s (var: 4.558066279969353e-07)
Results for model ppl_robust_reg on task jvp: 0.008297419175505638s (var: 2.9454400873873965e-07)
Results for model ppl_robust_reg on task jvp using Functorch: 0.008052706718444824s (var: 7.120377176761394e-08)
Results for model ppl_robust_reg on task hvp: 0.015414690598845482s (var: 7.42123745567369e-07)
Results for model ppl_robust_reg on task hvp using Functorch: 0.02699306048452854s (var: 1.4650488537881756e-06)
Results for model ppl_robust_reg on task jacobian: 0.006207776255905628s (var: 1.7068457225377642e-07)
Results for model ppl_robust_reg on task jacobian using Functorch: 0.009173822589218616s (var: 1.2214455580306094e-07)
Results for model ppl_robust_reg on task hessian: 0.04670915752649307s (var: 1.4299343092716299e-05)
Results for model ppl_robust_reg on task hessian using Functorch: 0.02337808534502983s (var: 3.0397418413485866e-06)
Results for model ppl_robust_reg on task hessian_fwdrev: 0.024229884147644043s (var: 2.0425247839739313e-06)
Results for model ppl_robust_reg on task hessian_fwdrev using Functorch: 0.022021746262907982s (var: 3.512146236062108e-07)
Results for model ppl_robust_reg on task hessian_revrev: 0.012355780228972435s (var: 7.090877147675201e-07)
Results for model ppl_robust_reg on task hessian_revrev using Functorch: 0.013960313983261585s (var: 6.326549737423193e-07)
Results for model ppl_robust_reg on task jacfwd: 0.008112502284348011s (var: 2.88503088086145e-08)
Results for model ppl_robust_reg on task jacfwd using Functorch: 0.008947920985519886s (var: 4.2070990247111695e-08)
Results for model ppl_robust_reg on task jacrev: 0.00635871896520257s (var: 1.3403841592207755e-07)
Results for model ppl_robust_reg on task jacrev using Functorch: 0.009123563766479492s (var: 2.677554675756255e-07)
Results for model wav2letter on task vjp: 0.02078995667397976s (var: 2.1110793113621185e-06)
Results for model wav2letter on task vjp using Functorch: 0.019202351570129395s (var: 9.210506135559626e-09)
Results for model wav2letter on task vhp: 0.05997290462255478s (var: 8.558587616391833e-09)
Results for model wav2letter on task vhp using Functorch: 0.06035261228680611s (var: 1.6448565842708263e-09)
Results for model wav2letter on task jvp: 0.04507789760828018s (var: 1.5771547401399744e-09)
Results for model wav2letter on task jvp using Functorch: 0.013057494536042213s (var: 3.804750292601966e-09)
Results for model deepspeech on task vjp: 0.3648746609687805s (var: 1.5359055396402255e-05)
Failed model using Functorch: deepspeech, task: vjp, Error message: Cannot access storage of TensorWrapper
Results for model transformer on task vjp: 0.05496881157159805s (var: 1.242562319703211e-08)
Results for model transformer on task vjp using Functorch: 0.057835936546325684s (var: 2.6113376350167528e-08)
Results for model transformer on task vhp: 0.18313491344451904s (var: 7.226336151688884e-08)
Failed model using Functorch: transformer, task: vhp, Error message: bad optional access
Results for model transformer on task jvp: 0.13924935460090637s (var: 1.6989159234981344e-07)
Failed model using Functorch: transformer, task: jvp, Error message: Trying to use forward AD with embedding that does not support it because it has not been implemented yet. Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation.
Results for model multiheadattn on task vjp: 0.0014708995586261153s (var: 3.710916729460223e-08)
Results for model multiheadattn on task vjp using Functorch: 0.002404856728389859s (var: 2.1910574687922235e-08)
Results for model multiheadattn on task vhp: 0.003382015274837613s (var: 5.3098595742540056e-08)
Results for model multiheadattn on task vhp using Functorch: 0.005340623669326305s (var: 5.897558708056749e-08)
Results for model multiheadattn on task jvp: 0.0027526854537427425s (var: 3.508620949332908e-08)
Results for model multiheadattn on task jvp using Functorch: 0.0022981404326856136s (var: 1.327894807445773e-07)
```

</details>

All functorch errors have been reported in the functorch repository.

cc @zou3519

Pull Request resolved: #75689
Approved by: https://github.com/zou3519
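
To read the comparison at a glance, the benchmark's stdout can be reduced to a per-task ratio between the two implementations. The following is a minimal sketch that assumes only the `Results for model ... on task ...` line format shown in the Stdout block; `parse_results` and `speedups` are hypothetical helper names and are not part of the benchmark script.

```python
import re

# Matches result lines in the format printed by the benchmark, e.g.
#   Results for model resnet18 on task jvp using Functorch: 0.0530s (var: 1.6e-08)
# "Failed model using Functorch: ..." lines do not match and are skipped.
LINE_RE = re.compile(
    r"Results for model (?P<model>\S+) on task (?P<task>\S+)"
    r"(?P<functorch> using Functorch)?: "
    r"(?P<mean>[0-9.eE+-]+)s \(var: (?P<var>[0-9.eE+-]+)\)"
)


def parse_results(stdout_text):
    """Return {(model, task): {"autograd": mean_s, "functorch": mean_s}}."""
    results = {}
    for match in LINE_RE.finditer(stdout_text):
        key = (match["model"], match["task"])
        backend = "functorch" if match["functorch"] else "autograd"
        results.setdefault(key, {})[backend] = float(match["mean"])
    return results


def speedups(results):
    """Ratio autograd mean / functorch mean, only where both runs succeeded."""
    return {
        key: times["autograd"] / times["functorch"]
        for key, times in results.items()
        if "autograd" in times and "functorch" in times
    }


# Four lines copied from the Stdout block above.
sample = """\
Results for model resnet18 on task jvp: 0.18679651618003845s (var: 1.832455254202614e-08)
Results for model resnet18 on task jvp using Functorch: 0.05305683612823486s (var: 1.6690266946284282e-08)
Results for model detr on task vjp: 0.19494840502738953s (var: 1.9122715457342565e-05)
Failed model using Functorch: detr, task: vjp, Error message: Cannot access data pointer of Tensor that doesn't have storage
"""

for (model, task), ratio in speedups(parse_results(sample)).items():
    print(f"{model} {task}: {ratio:.2f}x")
```

A ratio above 1 means functorch was faster for that model and task; tasks where the functorch run failed (such as `detr` `vjp` in the sample) are left out rather than reported as zero.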