Vote on new features in Discussions #694
Hi developers, firstly, thanks for the great work demonstrating the power of PyTorch's newly released features! I just have one question about the usage of FSDP2. To put it more clearly: in most use cases of training LLMs such as Llama2, the precision of the RMSNorm weights is kept in fp32 while the rest of the model is in bf16. From the profiling results, we found this approach (wrapping the fp32 RMSNorm layers as separate FSDP modules) leads to poorly overlapped communication. Apart from that, there are also some other use cases: the dtype of MoE gating layers is required to be fp32. So, is mixed precision within a single FSDP parameter group supported? Thanks! |
@zigzagcai RMSNorm only has activations in fp32, the weights are still bf16. |
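For example (a sketch of the pattern, not torchtitan's exact implementation), the norm can upcast activations to fp32 internally while its weight stays bf16:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Sketch: fp32 activations inside the norm, bf16 weight."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xf = x.float()  # upcast the activation only
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back down and apply the bf16 weight.
        return xf.type_as(x) * self.weight
```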
cc: @awgu |
It should be simple, but gradient accumulation is very useful for SFTing big models. |
Gradient accumulation is not that worthwhile with full sharding, since you need to all-gather the weights on each forward anyway. Though yeah, it could still make sense to have it. |
@samsja you can avoid the all-gather/reduce-scatter per microbatch with FSDP2 with hopefully intuitive APIs:
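A minimal sketch of the pattern, assuming FSDP2's `set_requires_gradient_sync` and `set_reshard_after_backward` module methods and placeholder `model`, `microbatches`, and `optimizer` objects:

```python
# Sketch: gradient accumulation without per-microbatch collectives.
# Assumes `model` was wrapped with FSDP2's fully_shard, which adds
# set_requires_gradient_sync / set_reshard_after_backward to the module.
for microbatch_idx, microbatch in enumerate(microbatches):
    is_last_microbatch = microbatch_idx == len(microbatches) - 1
    # Only reduce-scatter gradients on the last microbatch.
    model.set_requires_gradient_sync(is_last_microbatch)
    # Keep parameters unsharded between microbatches so the next
    # forward does not need to re-all-gather them.
    model.set_reshard_after_backward(is_last_microbatch)
    loss = model(microbatch).sum()
    loss.backward()
optimizer.step()
optimizer.zero_grad()
```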
This will use extra memory since unsharded parameters and gradients are held through forward and backward (roughly equivalent to ZeRO-1). |
Hmm nice, I did not know that. Yeah, so grad acc makes sense with ZeRO-1 but not ZeRO-2. |
@zigzagcai sorry for the delay -- I was out last week.
This is not well-supported (at least not simply). Part of this is an API design question trading off with performance. E.g., how would you specify which parameters in the parameter group are using fp32 vs. using bf16? (Let me know if you have ideas here.)
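Purely to illustrate the design question (a hypothetical sketch, not an existing FSDP2 API), one possible shape would extend the mixed-precision policy with per-parameter dtype overrides keyed by fully qualified name:

```python
import torch
from dataclasses import dataclass, field

# Hypothetical sketch only: FSDP2's real MixedPrecisionPolicy has no
# per-parameter overrides; this just illustrates one possible API shape.
@dataclass
class MixedPrecisionPolicyWithOverrides:
    param_dtype: torch.dtype = torch.bfloat16
    reduce_dtype: torch.dtype = torch.float32
    # Map from parameter FQN (or FQN suffix) to the dtype to keep it in.
    param_dtype_overrides: dict = field(default_factory=dict)

policy = MixedPrecisionPolicyWithOverrides(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    param_dtype_overrides={"norm.weight": torch.float32,
                           "gate.weight": torch.float32},
)
```

The performance trade-off is presumably that mixing dtypes within one group complicates flattening the parameters into a single all-gather buffer.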
The reason is that FSDP2's default prefetching algorithm only allows effectively one in-flight all-gather at a time in backward. This leads to poor overlap like you saw when the all-gather sizes flip between small and large, since we cannot overlap a transformer block's all-gather with just the RMSNorm backward, for example. FSDP2 exposes some manual APIs to configure the prefetching that can help here. I will need to find some time to put together an example; let me see if I can do it tomorrow or later this week. Mainly, you can use the manual prefetching APIs.
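Presumably this refers to FSDP2's explicit prefetch hooks; a sketch assuming the `FSDPModule.set_modules_to_backward_prefetch` method and a `model.layers` list of transformer blocks:

```python
# Sketch: manually widen backward prefetching so small norm all-gathers
# do not serialize behind large transformer-block all-gathers.
# Assumes each element of `model.layers` was wrapped with fully_shard.
layers = list(model.layers)
for i, layer in enumerate(layers):
    # Backward runs layers in reverse order, so during this layer's
    # backward, prefetch the all-gathers for the previous two layers
    # instead of relying on the default single-prefetch heuristic.
    prefetch = [layers[j] for j in range(i - 1, max(i - 3, -1), -1)]
    layer.set_modules_to_backward_prefetch(prefetch)
```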
Do you mean that the MoE router weight must be in fp32? I do want to clarify the use case somewhat (though the feature request is still valid). For the cases I have seen, using a bf16 RMSNorm weight and a bf16 router weight is sufficient. The computation kernel can upcast intermediates as needed, but that does not mean the weight itself needs to be stored in fp32. |
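As an illustration of that point (a sketch, not torchtitan's actual MoE code), the gate weight can stay in bf16 while only the intermediate is upcast:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Router(nn.Module):
    """Sketch of an MoE gate whose weight stays in bf16."""
    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False,
                              dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.gate(x)  # bf16 weight, bf16 matmul
        # Upcast only the intermediate for a numerically stable softmax;
        # the parameter itself never needs to be stored in fp32.
        return F.softmax(logits.float(), dim=-1)
```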
@aniltrkkn Thanks for the great suggestion! I also agree that gradient accumulation is quite important. I implemented one version myself; hope it helps a bit. Line 517: https://github.com/zyushun/Adam-mini/blob/main/examples/llama/train.py |
@tianyu-l Thanks for organizing the great discussion! I have one request, but I am not sure if it exists yet: is there demo code that transforms a saved checkpoint into the Hugging Face Transformers format? That would be quite useful for downstream evaluation or further SFT/RL. |
There is a conversion script that should be compatible with torchtitan here: https://github.com/PrimeIntellect-ai/prime/blob/main/scripts/export_dcp.py |
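For reference, a minimal sketch of the same idea using PyTorch's `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` to consolidate a DCP checkpoint before loading it into a Hugging Face model (the paths, model name, and state-dict key layout are placeholders):

```python
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from transformers import AutoModelForCausalLM

# Consolidate the sharded DCP checkpoint into a single torch.save file.
dcp_to_torch_save("checkpoint/step-1000", "consolidated.pt")

state_dict = torch.load("consolidated.pt", map_location="cpu")
# The "model" key and parameter names vary between torchtitan and HF
# Llama, so a remapping step (as in the linked export_dcp.py) is
# usually required before loading.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model.load_state_dict(state_dict.get("model", state_dict), strict=False)
```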
@zyushun I agree such a script is desirable but missing. |
@zigzagcai do you have a repro of your setup? On 8x H100s with Llama3-8B, I see some exposed collectives (e.g. the norm reduce-scatter), but they should not be detrimental to training throughput. |
Add tests/support for low-bit optimizers and FlashAttention-3. |
Hi torchtitanists,
Thank you for your interest in torchtitan!
We created #693 for the community to add feature requests and vote on them. We'll try to prioritize the most requested features. Please share what you'd like to see next!