From 9384f4bb61d44b3617f02e7063c96852161f393b Mon Sep 17 00:00:00 2001 From: Youqing Xiaozhua <843213558@qq.com> Date: Thu, 19 Oct 2023 18:42:52 +1100 Subject: [PATCH] [Fix] --local-rank for PyTorch >= 2.0.0 (#2051) Co-authored-by: rangoliu --- tools/test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/test.py b/tools/test.py index 9eb605b5ae..1d2d278508 100644 --- a/tools/test.py +++ b/tools/test.py @@ -35,7 +35,10 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank)