diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py
index 9b7ed9d..14551a3 100644
--- a/examples/train_homogeneous_graph_basic.py
+++ b/examples/train_homogeneous_graph_basic.py
@@ -113,7 +113,8 @@ def main():
                          args.world_size, master_ip_address,
                          args.backend)
-
+    sar.start_comm_thread()
+
     # Load DGL partition data
     partition_data = sar.load_dgl_partition_data(
         args.partitioning_json_file, args.rank, device)
diff --git a/examples/train_homogeneous_graph_basic_single-node.py b/examples/train_homogeneous_graph_basic_single-node.py
index f8c57dc..dca186f 100644
--- a/examples/train_homogeneous_graph_basic_single-node.py
+++ b/examples/train_homogeneous_graph_basic_single-node.py
@@ -47,6 +47,8 @@
 parser.add_argument('--ip-file', default='./ip_file', type=str,
                     help='File with ip-address. Worker 0 creates this file and all others read it ')
+parser.add_argument('--shared-file', default='./shared_file', type=str,
+                    help='Path to a file required by torch.distributed for inter-process communication')
 parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'],
                     help='Communication backend to use '
@@ -243,7 +245,7 @@ def run(args, rank, lock, barrier):
     master_ip_address = sar.nfs_ip_init(rank, args.ip_file)
     sar.initialize_comms(rank,
                          args.world_size, master_ip_address,
-                         args.backend)
+                         args.backend, args.shared_file, barrier)
 
     lock.acquire()
     print("Node {} Lock Acquired".format(rank))
diff --git a/sar/comm.py b/sar/comm.py
index 6e23fb5..c427e2e 100644
--- a/sar/comm.py
+++ b/sar/comm.py
@@ -35,6 +35,7 @@
 from torch import Tensor
 from .config import Config
 from .common_tuples import SocketInfo
+from multiprocessing.synchronize import Barrier
 
 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())
@@ -129,7 +130,8 @@ def nfs_ip_init(_rank: int, ip_file: str) -> str:
 
 
 def initialize_comms(_rank: int, _world_size: int, master_ip_address: str,
-                     backend: str, _comm_device: Optional[torch.device] = None,
+                     backend: str, shared_file: str = None, barrier: Barrier = None,
+                     _comm_device: Optional[torch.device] = None,
                      master_port_number: int = 12345):
     """ Initialize Pytorch's communication library
 
@@ -142,6 +144,10 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str,
     :type master_ip_address: str
     :param backend: Backend to use. Can be ccl, nccl, mpi or gloo
     :type backend: str
+    :param shared_file: Path to a file required by torch.distributed for inter-process communication
+    :type shared_file: str
+    :param barrier: Barrier used to synchronize the workers before re-creating the shared file
+    :type barrier: Barrier
     :param _comm_device: The device on which the tensors should be on in order to transmit them\
     through the backend. If not provided, the device is infered based on the backend type
     :type _comm_device: torch.device
@@ -193,8 +199,31 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str,
         os.environ['I_MPI_COMM_WORLD'] = str(_world_size)
         os.environ['I_MPI_COMM_RANK'] = str(_rank)
     try:
-        dist.init_process_group(
-            backend=backend, rank=_rank, world_size=_world_size)
+        if shared_file is None and barrier is None:
+            dist.init_process_group(
+                backend=backend, rank=_rank, world_size=_world_size)
+        elif shared_file is not None and barrier is not None:
+            if not os.path.isabs(shared_file):
+                current_dir = os.getcwd()
+                shared_file = os.path.join(current_dir, shared_file)
+
+            try:
+                os.remove(shared_file)
+            except FileNotFoundError:
+                pass
+
+            barrier.wait()
+            prefix = "file://"
+            shared_file = prefix + shared_file
+            dist.init_process_group(
+                backend=backend, rank=_rank, world_size=_world_size,
+                init_method=shared_file)
+        else:
+            logger.error("sar.initialize_comms: shared_file and barrier should "
+                         "either both be None or both have a value. Received "
+                         "shared_file {}, barrier {}".format(shared_file, barrier))
+            raise ValueError("shared_file and barrier must be passed together")
+
     except:
         logger.error("SAR was unable to initialize torch.distributed process group. "
                      "You can try to do it manually before calling sar.initialize_comms")
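
For reference, below is a minimal single-node launcher sketch (not part of the patch) showing how the new shared_file/barrier code path in sar.initialize_comms could be driven. The world size, the 'gloo' backend, and the './ip_file' and './shared_file' paths are illustrative placeholders rather than values taken from the examples above, and the launcher is deliberately simpler than the lock/barrier machinery used in train_homogeneous_graph_basic_single-node.py.

# Illustrative sketch only -- not part of the patch. Assumes the SAR package
# from this repository is importable and that './ip_file' and './shared_file'
# are writable by every worker process.
import multiprocessing as mp

import sar


def worker(rank, world_size, barrier):
    # Rank 0 writes its IP address to the ip file; the other ranks read it back.
    master_ip_address = sar.nfs_ip_init(rank, './ip_file')
    # shared_file and barrier must be passed together: every rank first tries to
    # delete a stale shared file, then waits on the barrier so that no rank lets
    # torch.distributed re-create the file before the cleanup has finished.
    sar.initialize_comms(rank, world_size, master_ip_address,
                         'gloo', './shared_file', barrier)


if __name__ == '__main__':
    world_size = 2  # illustrative value
    ctx = mp.get_context('spawn')
    barrier = ctx.Barrier(world_size)
    workers = [ctx.Process(target=worker, args=(rank, world_size, barrier))
               for rank in range(world_size)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()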