add torchbench for Distributed Shampoo Optimizer v2 (pytorch#2616)
Summary: Pull Request resolved: pytorch#2616

- No optimizer has been integrated into TorchBench so far. Distributed Shampoo is quite complex and has a direct dependency on PyTorch, so adding it to TorchBench guardrails it against PyTorch 2.0 changes.
- This diff implements that integration, specifically enabling Distributed Shampoo on TorchBench in eager mode. A follow-up diff will add the py2 compile feature.
- Current integration design:
  - Pick the Ads DHEN CMF 5x model, since CMF is a major MC model.
  - Benchmark the optimizer stage alone rather than end to end. The optimizer step is computationally light compared with fwd and bwd, so an e2e benchmark would let the other stages (fwd, bwd) shadow the optimizer-step results and make the benchmark insensitive.
  - Build on top of the original ads_dhen_5x pipeline, skip the fwd and bwd stages, and set up the Shampoo config inside the model's `__init__`.
  - Distributed Shampoo performs a matrix root inverse computation. In production its schedule is controlled by precondition_frequency, and its share of the overall computation is trivial, so TorchBench skips it as well: the iteration count is advanced to bypass the first root-inverse computation, inside the _prepare_before_optimizer func.
  - The resulting torchbench flow: 1. initialize the ads_dhen_cmf 5x model on a local GPU, preload the data, and run fwd and bwd; 2. adjust Shampoo state variables (iteration step for preconditioning, etc.) to get the optimizer ready; 3. benchmark the optimizer with the torchbench pipeline and return the results.

05/16:
- Updated the diff for the Shampoo v2 implementation.

Reviewed By: xuzhao9

Differential Revision: D51192560

fbshipit-source-id: 247dceec1587a837aa9ca128252c47e9e0cf42b7
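The optimizer-stage-only benchmarking flow above can be sketched as follows. This is a minimal, dependency-free illustration, not the real Distributed Shampoo or TorchBench API: `ToyShampoo`, `prepare_before_optimizer`, and all their fields are hypothetical stand-ins showing the pattern of (1) populating gradients once, (2) bumping the iteration count to skip the first expensive root-inverse computation, and (3) timing only `step()`.

```python
import time

class ToyShampoo:
    """Hypothetical stand-in for an optimizer with a preconditioning schedule."""
    def __init__(self, params, lr=0.01, precondition_frequency=100):
        self.params = params                    # list of [value, grad] pairs
        self.lr = lr
        self.precondition_frequency = precondition_frequency
        self.step_count = 0

    def step(self):
        self.step_count += 1
        if self.step_count % self.precondition_frequency == 0:
            pass  # expensive matrix-root-inverse recompute would run here
        for p in self.params:
            p[0] -= self.lr * p[1]              # plain update with the gradient

def prepare_before_optimizer(opt):
    # Mirrors the trick in the commit: advance the iteration count so the
    # first (expensive) root-inverse computation is bypassed.
    opt.step_count = 1

# 1. "Model init + fwd/bwd" stage: here we just fabricate params with grads.
params = [[1.0, 0.5], [2.0, -0.25]]
opt = ToyShampoo(params)

# 2. Adjust optimizer state, then 3. benchmark only the step() stage.
prepare_before_optimizer(opt)
start = time.perf_counter()
opt.step()
elapsed = time.perf_counter() - start
print(f"optimizer step took {elapsed * 1e6:.1f} us")
```

The key design point this illustrates is isolation: because fwd and bwd are run only once during setup and never inside the timed region, the measurement stays sensitive to optimizer-step changes alone.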