[T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) #1201

Merged: 20 commits, Nov 22, 2023
Changes from 12 commits
Commits (20)
21a5fcf  update requirements.txt (lessw2020, Nov 15, 2023)
f962b60  add torchrun support, move to init_device_mesh (lessw2020, Nov 15, 2023)
bc3c1dd  update twod fully working (lessw2020, Nov 16, 2023)
11a3bb2  ensure proper dp group seeding for synth data (lessw2020, Nov 16, 2023)
9cebdf0  swiglu model added (lessw2020, Nov 16, 2023)
2447883  sequential running of custom, auto, seq parallel models (lessw2020, Nov 16, 2023)
a388c20  streamline to 2D TP only for two_d_parallel example (lessw2020, Nov 17, 2023)
842c3f0  sequence parallel working...needs init_device_mesh update (lessw2020, Nov 18, 2023)
3aa1c53  seq parallel now using init_device_mesh (lessw2020, Nov 21, 2023)
b54e2ec  tp and sp examples all working and updated (lessw2020, Nov 21, 2023)
4889e3b  updates from code review (lessw2020, Nov 21, 2023)
b215178  remove utils.py. Sample models created in example files (lessw2020, Nov 22, 2023)
242c328  remove originals.py, leftover imports, various updates from code revi… (lessw2020, Nov 22, 2023)
2f4a083  code linting via ruff (lessw2020, Nov 22, 2023)
742966b  code formatting via ruff (lessw2020, Nov 22, 2023)
7da71bc  move rank_log to utils.py, update example files (lessw2020, Nov 22, 2023)
836f798  move logging imports and config to log_utils, update examples with ne… (lessw2020, Nov 22, 2023)
2de0144  add gpu verification, update run_python_examples.sh (lessw2020, Nov 22, 2023)
77fe3d8  update min gpu = 4 for fsdp+tp (lessw2020, Nov 22, 2023)
5f4a5d3  move gpu check to top of examples, but before import init_device_mesh… (lessw2020, Nov 22, 2023)
distributed/tensor_parallelism/original.py (new file: 127 additions, 0 deletions)
@@ -0,0 +1,127 @@
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
)
from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp

from utils import cleanup, setup, ToyModel

try:
    from torch.distributed.tensor.parallel import SequenceParallel

    SP_AVAILABLE = True
except BaseException:
    # Mark SP as unavailable so the check in __main__ does not raise a NameError.
    SP_AVAILABLE = False


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an E2E working flow from forward, backward
and optimization.

We enabled Fully Sharded Data Parallel + Tensor Parallel in
separate parallel dimensions:
    Data Parallel across hosts
    Tensor Parallel within each host

We use a simple diagram to illustrate below:

======================================================================
------------       ------------       ------------       ------------
|  Host 1  |       |  Host 2  |       |          |       |  Host N  |
|  8 GPUs  |       |  8 GPUs  |       |          |       |  8 GPUs  |
|          |       |          |       |   ...    |       |          |
|   (TP)   |       |   (TP)   |       |          |       |   (TP)   |
|[0,1,..,7]|       |[8,9..,15]|       |          |       |[8N-8,8N-7|
|          |       |          |       |          |       | .., 8N-1]|
|          |       |          |       |          |       |          |
------------       ------------       ------------       ------------
FSDP:
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================

More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""


def demo_2d(rank, args):
    """
    Main body of the demo of a basic version of tensor parallel by using
    PyTorch native APIs.
    """
    print(f"Running basic Megatron style TP example on rank {rank}.")
    setup(rank, args.world_size)
    assert (
        args.world_size % args.tp_size == 0
    ), "World size needs to be divisible by TP size"

    # create a sharding plan based on the given world_size.
    device_mesh = DeviceMesh(
        "cuda", torch.arange(0, args.world_size).view(-1, args.tp_size)
    )

    # create model and move it to GPU with id rank
    model = ToyModel().cuda(rank)
    # Create a optimizer for the parallelized module.
    LR = 0.25
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # Parallelize the module based on the given Parallel Style.
    parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel()
    model = parallelize_module(model, device_mesh, parallel_style, tp_mesh_dim=1)

    # We need to register hooks for TP + FSDP integration.
    assert (
        enable_2d_with_fsdp()
    ), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0"
    dp_pg = device_mesh.get_dim_groups()[0]
    model = FSDP(model, process_group=dp_pg)

    # Perform a num of iterations of forward/backward
    # and optimizations for the sharded module.
    for i in range(args.iter_nums):
        # For TP, input needs to be same across all TP ranks.
        # while for SP, input can be different across all ranks.
        # Setting the random seed is to mimic the behavior of dataloader.
        dp_rank = (
            rank
            if args.run_seq_parallel
            else dist.get_rank(dp_pg)
        )
        torch.manual_seed(i + dp_rank)
        inp = torch.rand(20, 10).cuda(rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    parser = argparse.ArgumentParser()
    # This is passed in via cmd
    parser.add_argument("--world_size", type=int, default=n_gpus)
    parser.add_argument("--iter_nums", type=int, default=10)
    parser.add_argument("--run_seq_parallel", type=bool, default=False)
    parser.add_argument("--tp_size", type=int, default=2)
    args = parser.parse_args()
    # The main entry point is called directly without using subprocess
    if n_gpus < 4:
        print("Requires at least 4 GPUs to run.")
    elif not SP_AVAILABLE:
        print(
            "PyTorch doesn't have Sequence Parallelism available,"
            " need nightly build."
        )
    else:
        mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True)
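
For contrast with the older enable_2d_with_fsdp()/PairwiseParallel flow preserved above, the sketch below shows the init_device_mesh-based 2D (FSDP + TP) composition that this PR's rewritten examples move toward. It is a minimal sketch, not the PR's fsdp+tp example itself: it assumes a torchrun launch and a nightly build where init_device_mesh accepts mesh_dim_names, DeviceMesh supports name-based slicing (mesh_2d["dp"]), and FSDP accepts a device_mesh argument; tp_size and the toy model are illustrative.

import os

import torch
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


tp_size = 2                                 # illustrative TP degree
world_size = int(os.environ["WORLD_SIZE"])  # provided by torchrun
dp_size = world_size // tp_size

# One call builds the 2D mesh; no manual DeviceMesh reshaping or
# enable_2d_with_fsdp() hook registration is needed in this flow.
mesh_2d = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # assumes name-based mesh slicing is available
dp_mesh = mesh_2d["dp"]

model = ToyModel().to("cuda")
# Tensor-parallelize the projections within each TP group...
model = parallelize_module(
    model, tp_mesh, {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()}
)
# ...then shard across the data-parallel mesh dimension with FSDP.
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)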
distributed/tensor_parallelism/requirements.txt (3 additions, 3 deletions)
@@ -1,6 +1,6 @@
 # Python dependencies required for running the example
 
 --pre
---extra-index-url https://download.pytorch.org/whl/nightly/cu113
---extra-index-url https://download.pytorch.org/whl/nightly/cu116
-torch >= 1.14.0.dev0; sys_platform == "linux"
+--extra-index-url https://download.pytorch.org/whl/nightly/cu118
+--extra-index-url https://download.pytorch.org/whl/nightly/cu121
+torch >= 2.2.0.dev0; sys_platform == "linux"
distributed/tensor_parallelism/run_example.sh (new file: 13 additions, 0 deletions)
@@ -0,0 +1,13 @@

# To run samples:
# bash run_example.sh file_to_run.py num_gpus
# where file_to_run = example to launch. Default = 'two_d_parallel_example.py'
# num_gpus = number of local gpus to use (must be at least 2). Default = 4

# samples to run include:
# sequence_parallel_example.py
# tensor_parallel_example.py
# two_d_parallel_example.py

echo "Launching ${1:-two_d_parallel_example.py} with ${2:-4} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-two_d_parallel_example.py}
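
Because the script above launches each example with torchrun rather than mp.spawn, the rewritten examples pick up their rank and world size from the environment instead of a --world_size flag. A minimal sketch of that pattern (torchrun defines RANK, LOCAL_RANK and WORLD_SIZE; everything else here is illustrative):

import os

import torch

# torchrun exports these for every worker process it launches
rank = int(os.environ["RANK"])              # global rank across all workers
local_rank = int(os.environ["LOCAL_RANK"])  # index of this worker's GPU on the node
world_size = int(os.environ["WORLD_SIZE"])  # total number of workers

# pin the process to its own GPU before creating any CUDA tensors
torch.cuda.set_device(local_rank)
print(f"worker {rank}/{world_size} using cuda:{local_rank}")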
distributed/tensor_parallelism/sequence_parallel_example.py (78 additions, 54 deletions)
@@ -1,11 +1,17 @@
import argparse

import os
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from utils import cleanup, setup, ToyModel

try:
    from torch.distributed.tensor.parallel import (
@@ -33,51 +39,69 @@
"""


def demo_sp(rank, args):
    """
    Main body of the demo of a basic version of sequence parallel by using
    PyTorch native APIs.
    """
    print(f"Running SP example on rank {rank}.")
    setup(rank, args.world_size)

    # create a sharding plan based on the given world_size.
    device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size))

    # create model and move it to GPU with id rank
    model = ToyModel().cuda(rank)
    # Create a optimizer for the parallelized module.
    LR = 0.25
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # Parallelize the module based on the given Parallel Style.
    model = parallelize_module(model, device_mesh, SequenceParallel())

    # Perform a num of iterations of forward/backward
    # and optimizations for the sharded module.
    for _ in range(args.iter_nums):
        # For SP, input can be different across all ranks.
        inp = torch.rand(20, 10).cuda(rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    parser = argparse.ArgumentParser()
    # This is passed in via cmd
    parser.add_argument("--world_size", type=int, default=n_gpus)
    parser.add_argument("--iter_nums", type=int, default=10)
    args = parser.parse_args()
    # The main entry point is called directly without using subprocess
    if n_gpus < 2:
        print("Requires at least 2 GPUs to run.")
    elif not SP_AVAILABLE:
        print(
            "PyTorch doesn't have Sequence Parallelism available,"
            " need nightly build."
        )
    else:
        mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True)
class ToyModel(nn.Module):
    """ MLP based model """

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


"""
Main body of the demo of a basic version of sequence parallel by using
PyTorch native APIs.
"""

_rank = int(os.environ["RANK"])


def rank_print(msg):
    """helper function to print only on global rank 0"""
    if _rank == 0:
        print(f"{msg}")

print(f"Running basic Megatron style Sequence Parallel example on rank {_rank}.")

# create a device mesh based on the given world_size.
_device = "cuda"
device_mesh = init_device_mesh(device_type=_device, mesh_shape=(int(os.environ["WORLD_SIZE"]),))

rank_print(f"Device Mesh created: {device_mesh=}")


# create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
model = ToyModel().to(_device)

# Custom parallelization plan for the model
sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)


# Create an optimizer for the parallelized module.
lr = 0.25
optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True)


# Perform a number of iterations of forward/backward
# and optimization for the sharded module.
num_iters = 10
rank_print("Sequence Parallel training starting...")

for i in range(num_iters):
    # For SP, input can be different across all ranks.
    inp = torch.rand(20, 10, device=_device)
    output = sp_model(inp)
    output.sum().backward()
    optimizer.step()
    rank_print(f"Sequence Parallel iter {i} completed")

rank_print(f"Sequence Parallel training completed!")