
[T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) #1201

Merged
merged 20 commits into pytorch:main on Nov 22, 2023

Conversation

lessw2020
Contributor

This PR updates the three distributed examples for Tensor Parallel, Sequence Parallel, and 2D (FSDP + TP) with the following main changes:
(note - internal reference - task [T170073014] Rewrite TensorParallel/SequenceParallel examples using our new UX)

1 - Move to torchrun launching (see the run_*.sh files), with world-topology introspection during setup instead of mp.spawn.
2 - Move device mesh creation to the new API, init_device_mesh.
3 - Use custom parallelization plans (ColwiseParallel and RowwiseParallel) rather than the previous prebuilt PairwiseParallel() and SequenceParallel().
4 - For the 2D example, use a more relevant SwiGLU MLP model to showcase applying 2D parallelism in a more sophisticated, Llama-style setting.
5 - Add more interactive feedback for the user (start, per-iteration, and completion messages).
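To make the described UX concrete, here is a hedged sketch of what points 1-3 look like in code. This is not the PR's actual code: ToyMLP and run_tensor_parallel are hypothetical names, and the import paths shown (torch.distributed.device_mesh, torch.distributed.tensor.parallel) are from recent torch releases, whereas the PR itself imported init_device_mesh from torch.distributed._tensor.device_mesh.

```python
# Hedged sketch of the new tensor-parallel UX, not the PR's exact code.
import os

import torch
import torch.nn as nn


class ToyMLP(nn.Module):
    """Minimal stand-in model: shard in_proj colwise and out_proj rowwise."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim)
        self.out_proj = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


def run_tensor_parallel():
    # Under torchrun, WORLD_SIZE/RANK are already set in the environment,
    # so the script introspects its topology instead of calling mp.spawn.
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh("cuda", (world_size,))
    model = ToyMLP().cuda()
    # A custom parallelization plan per submodule, replacing the old
    # prebuilt PairwiseParallel()/SequenceParallel() styles.
    model = parallelize_module(model, mesh, {
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    })
    return model(torch.rand(8, 16, device="cuda"))
```

Pairing ColwiseParallel on the first linear with RowwiseParallel on the second is the standard Megatron-style layout: the colwise shard's output feeds the rowwise shard directly, so only one collective is needed per MLP forward.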


netlify bot commented Nov 21, 2023

Deploy Preview for pytorch-examples-preview canceled.

🔨 Latest commit: 5f4a5d3
🔍 Latest deploy log: https://app.netlify.com/sites/pytorch-examples-preview/deploys/655e7ed39bbef400093f38f9

@lessw2020
Contributor Author

  • Test failures are related to being unable to import init_device_mesh, no GPUs being available, and lastly the need to modify these tests to launch via the .sh files associated with each example (to run torchscript):
Traceback (most recent call last):
  File "tensor_parallelism/tensor_parallel_example.py", line 5, in <module>
    from torch.distributed._tensor.device_mesh import init_device_mesh
ImportError: cannot import name 'init_device_mesh' from 'torch.distributed._tensor.device_mesh' (/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/torch/distributed/_tensor/device_mesh.py)
tensor parallel example failed
Traceback (most recent call last):
  File "tensor_parallelism/sequence_parallel_example.py", line 5, in <module>
    from torch.distributed._tensor.device_mesh import init_device_mesh
ImportError: cannot import name 'init_device_mesh' from 'torch.distributed._tensor.device_mesh' (/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/torch/distributed/_tensor/device_mesh.py)
sequence parallel example failed
Traceback (most recent call last):
  File "tensor_parallelism/two_d_parallel_example.py", line 18, in <module>
    from torch.distributed._tensor.device_mesh import init_device_mesh
ImportError: cannot import name 'init_device_mesh' from 'torch.distributed._tensor.device_mesh' (/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/torch/distributed/_tensor/device_mesh.py)
2D parallel example failed
Requires at least 8 GPUs to run, but got 0.
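For context on the launch change the failures point at: the run_*.sh scripts presumably wrap a torchrun invocation along these lines (the script name, node count, and process count below are illustrative assumptions, not taken from this PR):

```shell
# Hypothetical launch command; torchrun exports RANK/WORLD_SIZE/LOCAL_RANK,
# which the rewritten examples read instead of spawning workers via mp.spawn.
torchrun --nnodes=1 --nproc_per_node=4 tensor_parallel_example.py
```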

Contributor

@wanchaol wanchaol left a comment


First pass: I think we can do things a lot simpler. Please see inline comments.

Files with review comments (resolved):
- distributed/tensor_parallelism/original.py
- distributed/tensor_parallelism/run_sequence_parallel.sh
- distributed/tensor_parallelism/tensor_parallel_example.py
- distributed/tensor_parallelism/two_d_parallel_example.py
# while for SP, input can be different across all ranks.
# We will use dp_rank for setting the random seed
# to mimic the behavior of the dataloader.
dp_rank = dist.get_rank(dp_pg)
Contributor


I see, this needs to be consolidated into a device mesh API, cc @wz337

distributed/tensor_parallelism/two_d_parallel_example.py (review comment, resolved)
Contributor

@wanchaol wanchaol left a comment


Looks much better! Wondering what the reason is to keep original.py; also left some inline comments about imports, etc.

Files with review comments (resolved):
- distributed/tensor_parallelism/original.py
- distributed/tensor_parallelism/two_d_parallel_example.py
Contributor

@wanchaol wanchaol left a comment


This looks great! Only have some nits for logging, thanks for addressing the comments!

Files with review comments (resolved):
- distributed/tensor_parallelism/fsdp_tp_example.py
- distributed/tensor_parallelism/tensor_parallel_example.py
@msaroufim msaroufim merged commit c4dc481 into pytorch:main Nov 22, 2023
6 checks passed
5 participants