using update from Kacper to initialize_comms and modifying to enable backwards compatibility with distributed training

seanmcpherson committed Aug 9, 2023
1 parent 6004901 commit 708612f
Showing 3 changed files with 37 additions and 5 deletions.
3 changes: 2 additions & 1 deletion examples/train_homogeneous_graph_basic.py
@@ -113,7 +113,8 @@ def main():
args.world_size,
master_ip_address,
args.backend)

sar.start_comm_thread()

# Load DGL partition data
partition_data = sar.load_dgl_partition_data(
args.partitioning_json_file, args.rank, device)
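For context, this first example keeps the original multi-node initialization call and adds a call to sar.start_comm_thread(). A minimal sketch of the resulting call sequence (a hedged reconstruction, not the full example; args is the example's parsed argument namespace):

master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)

# No shared_file or barrier is passed, so torch.distributed is initialized
# the original way, from the master IP address (the backwards-compatible path).
sar.initialize_comms(args.rank,
                     args.world_size,
                     master_ip_address,
                     args.backend)

# Added by this commit: start SAR's communication thread.
sar.start_comm_thread()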
4 changes: 3 additions & 1 deletion examples/train_homogeneous_graph_basic_single-node.py
@@ -47,6 +47,8 @@
parser.add_argument('--ip-file', default='./ip_file', type=str,
help='File with the master IP address. Worker 0 creates this file and all others read it')

parser.add_argument('--shared-file', default='./shared_file', type=str,
help='Path to a file required by torch.dist for inter-process communication')

parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'],
help='Communication backend to use '
@@ -243,7 +245,7 @@ def run(args, rank, lock, barrier):
master_ip_address = sar.nfs_ip_init(rank, args.ip_file)
sar.initialize_comms(rank,
args.world_size, master_ip_address,
args.backend)
args.backend, args.shared_file, barrier)

lock.acquire()
print("Node {} Lock Acquired".format(rank))
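The run(args, rank, lock, barrier) signature above expects a lock and a barrier created by the parent process. A hedged sketch of how a driver could construct them and launch one worker per rank (the launch helper and the spawning details are assumptions, not part of this diff):

import multiprocessing as mp

def launch(args):
    # One lock and one barrier shared by all worker processes; the barrier is
    # sized to world_size so sar.initialize_comms can synchronize every rank.
    lock = mp.Lock()
    barrier = mp.Barrier(args.world_size)

    workers = []
    for rank in range(args.world_size):
        p = mp.Process(target=run, args=(args, rank, lock, barrier))
        p.start()
        workers.append(p)
    for p in workers:
        p.join()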
35 changes: 32 additions & 3 deletions sar/comm.py
@@ -35,6 +35,7 @@
from torch import Tensor
from .config import Config
from .common_tuples import SocketInfo
from multiprocessing import Barrier

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
@@ -129,7 +130,8 @@ def nfs_ip_init(_rank: int, ip_file: str) -> str:


def initialize_comms(_rank: int, _world_size: int, master_ip_address: str,
backend: str, _comm_device: Optional[torch.device] = None,
backend: str, shared_file: Optional[str] = None, barrier: Optional[Barrier] = None,
_comm_device: Optional[torch.device] = None,
master_port_number: int = 12345):
"""
Initialize Pytorch's communication library
@@ -142,6 +144,10 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str,
:type master_ip_address: str
:param backend: Backend to use. Can be ccl, nccl, mpi or gloo
:type backend: str
:param shared_file: Path to a file required by torch.dist for inter-process communication
:type shared_file: str
:param barrier: Barrier for synchronizing processes
:type barrier: Barrier
:param _comm_device: The device on which the tensors should reside in order to transmit them\
through the backend. If not provided, the device is inferred based on the backend type
:type _comm_device: torch.device
@@ -193,8 +199,31 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str,
os.environ['I_MPI_COMM_WORLD'] = str(_world_size)
os.environ['I_MPI_COMM_RANK'] = str(_rank)
try:
dist.init_process_group(
backend=backend, rank=_rank, world_size=_world_size)
if shared_file is None and barrier is None:
dist.init_process_group(
backend=backend, rank=_rank, world_size=_world_size)
elif shared_file is not None and barrier is not None:
if not os.path.isabs(shared_file):
current_dir = os.getcwd()
shared_file = os.path.join(current_dir, shared_file)

try:
os.remove(shared_file)
except FileNotFoundError:
pass  # no stale rendezvous file to remove

barrier.wait()  # wait until every rank has removed the stale file before any rank re-creates it
prefix = "file://"
shared_file = prefix + shared_file
dist.init_process_group(
backend=backend, rank=_rank, world_size=_world_size,
init_method=shared_file)
else:
logger.error("sar.initialize_comms: shared_file and barrier should "
"either both be None or both have a value. Received "
"shared_file {}, barrier {}".format(shared_file, barrier))
raise ValueError("shared_file and barrier must either both be None or both be provided")

except:
logger.error("SAR was unable to initialize torch.distributed process group. "
"You can try to do it manually before calling sar.initialize_comms")
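Taken together, the updated initialize_comms supports two initialization styles. A hedged usage sketch (rank, world_size, master_ip_address, backend, and the shared barrier are placeholders supplied by the caller):

# 1) Backwards-compatible path: no shared_file/barrier, torch.distributed is
#    initialized as before from the master IP address and port.
sar.initialize_comms(rank, world_size, master_ip_address, backend)

# 2) New path for single-node, multi-process runs: file:// rendezvous.
#    Every rank must pass the same shared_file path and the same Barrier
#    object (created once in the parent process, sized to world_size); the
#    barrier keeps any rank from re-creating the rendezvous file while
#    another rank is still deleting a stale copy.
sar.initialize_comms(rank, world_size, master_ip_address, backend,
                     shared_file='./shared_file', barrier=barrier)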
