[doc][c10d] fixup fsdp tutorial #1297
Conversation
✅ Deploy Preview for pytorch-examples-preview canceled.
Looks like running the Python example failed?
Force-pushed 2cfc3b8 to 0834097
Unrelated to my change, but I fixed it anyway. We needed to update to a newer Python version in CI. See the additional diff I made to
Force-pushed 752efec to 0834097
```
@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
        if rank == 0:
            ck = checkpoint.keys()
            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

        dist_cp.load_state_dict(
```
The DCP usage is pretty outdated. Should we also update them?
I will update this in a subsequent change, if that's OK with you?
This change is already too large because I am fixing up the Python
tests that broke. The breakage is unrelated to this change.
Sure
Force-pushed e0fba21 to 7dcd080
@fduwjj - CI is green now.
Force-pushed 7dcd080 to 162b0dc
This change will be rebased on #1299 to fix the failing Python examples.
Summary: Fix up the FSDP tutorial to get it functional again.

1. Add missing import for `load_dataset`.
2. Use `checkpoint` instead of `_shard.checkpoint` to get rid of a warning.
3. Add `nlp` to requirements.txt.
4. Get rid of `load_metric`, as this function does not exist in the new `datasets` module.
5. Add `legacy=False` to get rid of tokenizer warnings.

Test Plan: Ran the tutorial as follows and ensured that it ran successfully:

```
torchrun --nnodes=1 --nproc_per_node=2 T5_training.py
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793]
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] *****************************************
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] *****************************************
dict_keys(['train', 'validation', 'test'])
Size of train dataset: (157252, 3)
Size of Validation dataset: (5599, 3)
dict_keys(['train', 'validation', 'test'])
Size of train dataset: (157252, 3)
Size of Validation dataset: (5599, 3)
bFloat16 enabled for mixed precision - using bfSixteen policy
```
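The second fix in the list amounts to swapping the private `_shard` module path for the public one; a minimal sketch of what that change looks like in the tutorial code (assuming a current torch release):

```python
# Deprecated import that triggers the warning mentioned above:
#   import torch.distributed._shard.checkpoint as dist_cp
# Public replacement used after this change:
import torch.distributed.checkpoint as dist_cp

# Existing call sites such as dist_cp.load_state_dict(...) keep working,
# since the public module exposes the same checkpoint functions.
print(dist_cp.__name__)  # "torch.distributed.checkpoint"
```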
Force-pushed 162b0dc to cb00288