Hey,

First of all, thanks for the really nice codebase! I just wanted to start a discussion on the MoE implementation and on adding load-balancing losses.
I see that you deliberately put MoEs in the examples instead of in src. It would be very simple to include these features in the main modeling code, and I think it's worth doing so. In particular, the example setup requires starting runs from within the examples folder (or wherever the MoE code lives); our team starts training runs outside of nanotron and therefore needs the imports to come from the pip-installed nanotron. I've implemented the changes for that here: https://github.com/swiss-ai/nanotron/tree/moe
There are also some other minor fixes, for example a correct cost estimation for pipeline splitting, actually using the configured activation function, and the correct SwiGLU for expert parallelism (last time I checked, the example code used only 2 instead of the 3 weight matrices when there is one expert per rank).
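For concreteness, here is a minimal sketch of the three-matrix SwiGLU I mean (module and attribute names are illustrative, not the ones used in the example code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """SwiGLU MLP with all three weight matrices: down(silu(gate(x)) * up(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)  # w1
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)    # w3
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)  # w2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Elementwise gate on the up projection, then project back down
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```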
More importantly, I added load-balancing losses for expert balance. In my experience, they can make an important difference, especially at larger scale, for instance for GPU utilization or training stability (see e.g. "[...] a lot of weird errors turn out to be symptoms of expert load imbalance", link).
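For context, the auxiliary loss I have in mind is the standard Switch-Transformer-style load-balancing term; the sketch below is illustrative (my own function and tensor names), not the exact MegaBlocks code:

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int = 1) -> torch.Tensor:
    """Switch-Transformer-style auxiliary loss.

    router_logits: (num_tokens, num_experts) raw router outputs.
    Encourages the fraction of tokens assigned to each expert to match the
    router's mean probability mass on that expert.
    """
    probs = torch.softmax(router_logits, dim=-1)               # (tokens, experts)
    _, selected = torch.topk(probs, top_k, dim=-1)             # (tokens, top_k)
    mask = F.one_hot(selected, num_experts).float()            # (tokens, top_k, experts)
    tokens_per_expert = mask.sum(dim=(0, 1)) / mask.sum()      # fraction of assignments per expert
    router_prob_per_expert = probs.mean(dim=0)                 # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
```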
How I implemented the losses is maybe suboptimal -- it's very similar to how MegaBlocks did it originally, where the balancing losses are local to each pipeline rank. They are still registered so they are tracked for the backward pass in the pipeline engine (commit here). However, with multiple pipeline-parallel ranks, the logging on the last rank does not see the previous ranks' losses, so e.g. the wandb logs show lower loss values than the true total. I guess the only way to log correctly would be to pass the losses through the network (just like the inputs, via TensorPointers etc.).
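A pragmatic alternative for the logging side only (a sketch under my own naming; `pp_group` and the helper function are hypothetical, not nanotron's actual API) would be to all-reduce the detached scalar across the pipeline-parallel group before reporting it, so the last rank logs the sum of every rank's local balancing loss:

```python
import torch
import torch.distributed as dist

def aggregate_aux_loss_for_logging(local_aux_loss: torch.Tensor, pp_group: dist.ProcessGroup) -> torch.Tensor:
    """Sum the locally accumulated load-balancing loss over all pipeline ranks.

    Only meant for logging: the backward pass still uses the local loss on each rank,
    but the value reported (e.g. to wandb) on the last rank now includes every rank's share.
    """
    aggregated = local_aux_loss.detach().clone()
    dist.all_reduce(aggregated, op=dist.ReduceOp.SUM, group=pp_group)
    return aggregated
```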
I would be very happy to hear your thoughts, in particular on the load-balancing implementation. If desired, I could also open a PR and we could continue the discussion there :)
Thanks a lot for your comment! I'm currently working on PR #192 for this :) I've also fixed the logging issue there. Waiting for more input.