From a308b4e97459b07c1b356642b2a8b4206c6d6de1 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 17 Sep 2024 00:15:14 -0700
Subject: [PATCH] Update DDP tutorial for the correct order of set_device
 (#1285)

---
 distributed/ddp-tutorial-series/multigpu.py          | 4 ++--
 distributed/ddp-tutorial-series/multigpu_torchrun.py | 4 ++--
 distributed/ddp-tutorial-series/multinode.py         | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/distributed/ddp-tutorial-series/multigpu.py b/distributed/ddp-tutorial-series/multigpu.py
index 029731b5d2..7e11633305 100644
--- a/distributed/ddp-tutorial-series/multigpu.py
+++ b/distributed/ddp-tutorial-series/multigpu.py
@@ -18,8 +18,8 @@ def ddp_setup(rank, world_size):
     """
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "12355"
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
+    init_process_group(backend="nccl", rank=rank, world_size=world_size)
 
 class Trainer:
     def __init__(
@@ -99,6 +99,6 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
     parser.add_argument('save_every', type=int, help='How often to save a snapshot')
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
-    
+
     world_size = torch.cuda.device_count()
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
diff --git a/distributed/ddp-tutorial-series/multigpu_torchrun.py b/distributed/ddp-tutorial-series/multigpu_torchrun.py
index 66d8187346..32d6254d2d 100644
--- a/distributed/ddp-tutorial-series/multigpu_torchrun.py
+++ b/distributed/ddp-tutorial-series/multigpu_torchrun.py
@@ -11,8 +11,8 @@
 
 
 def ddp_setup():
-    init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    init_process_group(backend="nccl")
 
 class Trainer:
     def __init__(
@@ -107,5 +107,5 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
     parser.add_argument('save_every', type=int, help='How often to save a snapshot')
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
-    
+
     main(args.save_every, args.total_epochs, args.batch_size)
diff --git a/distributed/ddp-tutorial-series/multinode.py b/distributed/ddp-tutorial-series/multinode.py
index e80636bcc4..72670171b5 100644
--- a/distributed/ddp-tutorial-series/multinode.py
+++ b/distributed/ddp-tutorial-series/multinode.py
@@ -11,8 +11,8 @@
 
 
 def ddp_setup():
-    init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    init_process_group(backend="nccl")
 
 class Trainer:
     def __init__(
@@ -108,5 +108,5 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
     parser.add_argument('save_every', type=int, help='How often to save a snapshot')
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
-    
+
     main(args.save_every, args.total_epochs, args.batch_size)
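For context, a minimal sketch (not part of the patch) of the setup order the patch establishes in the torchrun-based scripts. Binding the process to its GPU with `torch.cuda.set_device` before calling `init_process_group` helps ensure that any CUDA context created during NCCL initialization lands on that rank's own device, rather than every local rank defaulting to `cuda:0`. The training-loop placeholder is illustrative; it assumes a `torchrun` launch, which sets `LOCAL_RANK` in the environment.

```python
import os

import torch
from torch.distributed import init_process_group, destroy_process_group


def ddp_setup():
    # Bind this process to its GPU *before* initializing the process group,
    # so CUDA state created during NCCL init is placed on the right device
    # instead of every rank allocating on cuda:0.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    init_process_group(backend="nccl")


if __name__ == "__main__":
    ddp_setup()
    # ... build the model, wrap it in DistributedDataParallel, train ...
    destroy_process_group()
```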