From 84d525a3519bdfc11256c68844ed3aad0541ed38 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 28 Apr 2023 10:16:30 +0000 Subject: [PATCH 01/58] Update SAR to work with DGL >= 1.0 --- examples/partition_arxiv_products.py | 1 - examples/partition_mag.py | 1 - examples/train_homogeneous_graph_basic.py | 3 ++- sar/core/graphshard.py | 8 ++++---- sar/core/sampling.py | 2 +- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/partition_arxiv_products.py b/examples/partition_arxiv_products.py index a751a6f..c77a19b 100644 --- a/examples/partition_arxiv_products.py +++ b/examples/partition_arxiv_products.py @@ -85,7 +85,6 @@ def _idx_to_mask(idx_tensor): args.num_partitions, args.partition_out_path, num_hops=1, - reshuffle=True, balance_ntypes=train_mask, balance_edges=True) diff --git a/examples/partition_mag.py b/examples/partition_mag.py index ce76831..0e24512 100644 --- a/examples/partition_mag.py +++ b/examples/partition_mag.py @@ -72,7 +72,6 @@ def idx_to_mask(idx_tensor): args.num_partitions, args.partition_out_path, num_hops=1, - reshuffle=True, balance_edges=True) diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py index 8f39f25..14ee691 100644 --- a/examples/train_homogeneous_graph_basic.py +++ b/examples/train_homogeneous_graph_basic.py @@ -110,7 +110,8 @@ def main(): # Obtain the ip address of the master through the network file system master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) sar.initialize_comms(args.rank, - args.world_size, master_ip_address, + args.world_size, + master_ip_address, args.backend) # Load DGL partition data diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index f819cff..07437ee 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -33,7 +33,7 @@ from contextlib import contextmanager import torch import dgl # type:ignore -from dgl import DGLHeteroGraph +from dgl import DGLGraph from dgl.function.base import TargetCode # type:ignore import dgl.function as fn # type: ignore from torch import Tensor @@ -94,7 +94,7 @@ def __init__(self, num_dst_nodes=self.unique_tgt_nodes.size( 0) ) - self._graph_reverse: Optional[DGLHeteroGraph] = None + self._graph_reverse: Optional[DGLGraph] = None self._shard_info: Optional[ShardInfo] = None self.graph.edata.update(shard_edges_features.edge_features) @@ -108,7 +108,7 @@ def shard_info(self) -> Optional[ShardInfo]: return self._shard_info @property - def graph_reverse(self) -> DGLHeteroGraph: + def graph_reverse(self) -> DGLGraph: if self._graph_reverse is None: edges_src, edges_tgt = self.graph.all_edges() self._graph_reverse = dgl.create_block((edges_tgt, edges_src), @@ -169,7 +169,7 @@ def __len__(self): class GraphShardManager: """ Manages the local graph partition and exposes a subset of the interface - of dgl.heterograph.DGLHeteroGraph. Most importantly, it implements a + of dgl.heterograph.DGLGraph. Most importantly, it implements a distributed version of the ``update_all`` and ``apply_edges`` functions which are extensively used by GNN layers to exchange messages. By default, both ``update_all`` and ``apply_edges`` use sequential aggregation and diff --git a/sar/core/sampling.py b/sar/core/sampling.py index cbecd0d..5c0e127 100644 --- a/sar/core/sampling.py +++ b/sar/core/sampling.py @@ -25,7 +25,7 @@ import torch import dgl # type: ignore from dgl.heterograph import DGLBlock # type: ignore -from dgl.heterograph import DGLHeteroGraph as DGLGraph # type: ignore +from dgl.heterograph import DGLGraph # type: ignore from dgl.sampling import sample_neighbors # type: ignore import dgl.partition # type:ignore From d889af38547e0b962305b4c04e90b0083acea860 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 28 Apr 2023 10:28:19 +0000 Subject: [PATCH 02/58] Improve SAR initialization --- sar/comm.py | 58 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/sar/comm.py b/sar/comm.py index 7421353..92942ca 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -67,8 +67,9 @@ def get_socket_name() -> str: logger.info(f'getting socket name for ib adapter: {ib_adapters[0]}') sock_name = ib_adapters[0].nice_name else: + acceptable_prefixes = ['eth', 'ens', 'enp'] eth_adapters = [ - x for x in adaps if 'eth' in x.nice_name or 'enp' in x.nice_name] + x for x in adaps if any(prefix in x.nice_name for prefix in acceptable_prefixes)] logger.info( f'getting socket name for ethernet adapter: {eth_adapters[0]}') sock_name = eth_adapters[0].nice_name @@ -91,8 +92,9 @@ def dump_ip_address(ip_file: str) -> str: logger.info(f'found infinity band adapter: {ib_adapters[0]}') host_ip = ib_adapters[0].ips[0].ip else: + acceptable_prefixes = ['eth', 'ens', 'enp'] eth_adapters = [ - x for x in adaps if 'eth' in x.nice_name or 'enp' in x.nice_name] + x for x in adaps if any(prefix in x.nice_name for prefix in acceptable_prefixes)] logger.info(f'using ethernet adapter: {eth_adapters}') host_ip = eth_adapters[0].ips[0].ip with open(ip_file, 'w', encoding='utf-8') as f_handle: @@ -156,7 +158,7 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, """ assert backend in ['ccl', 'nccl', - 'mpi'], 'backend must be ccl,nccl, or mpi' + 'mpi'], 'backend must be ccl, nccl, or mpi' if _comm_device is None: if backend == 'nccl': _comm_device = torch.device('cuda') @@ -168,34 +170,44 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, if backend == 'ccl': # pylint: disable=unused-import - import torch_ccl # type: ignore + try: + import oneccl_bindings_for_pytorch # type: ignore + except: + try: + import torch_ccl # type: ignore + except: + raise "None of the oneccl_bindings_for_pytorch and torch_ccl package has been found" - os.environ['MASTER_ADDR'] = master_ip_address - os.environ['MASTER_PORT'] = str(master_port_number) + if not dist.is_initialized(): + os.environ['MASTER_ADDR'] = master_ip_address + os.environ['MASTER_PORT'] = str(master_port_number) - sock_name = get_socket_name() - os.environ['TP_SOCKET_IFNAME'] = sock_name - os.environ['GLOO_SOCKET_IFNAME'] = sock_name - os.environ['CCL_SOCKET_IFNAME'] = sock_name - os.environ['NCCL_SOCKET_IFNAME'] = sock_name + sock_name = get_socket_name() + os.environ['TP_SOCKET_IFNAME'] = sock_name + os.environ['GLOO_SOCKET_IFNAME'] = sock_name + os.environ['CCL_SOCKET_IFNAME'] = sock_name + os.environ['NCCL_SOCKET_IFNAME'] = sock_name - os.environ['FI_VERBS_IFACE'] = sock_name - os.environ['FI_mlx_IFACE'] = sock_name + os.environ['FI_VERBS_IFACE'] = sock_name + os.environ['FI_mlx_IFACE'] = sock_name - os.environ['MPI_COMM_WORLD'] = str(_world_size) - os.environ['MPI_COMM_RANK'] = str(_rank) + os.environ['MPI_COMM_WORLD'] = str(_world_size) + os.environ['MPI_COMM_RANK'] = str(_rank) - os.environ['OMPI_COMM_WORLD'] = str(_world_size) - os.environ['OMPI_COMM_RANK'] = str(_rank) + os.environ['OMPI_COMM_WORLD'] = str(_world_size) + os.environ['OMPI_COMM_RANK'] = str(_rank) - os.environ['IMPI_COMM_WORLD'] = str(_world_size) - os.environ['IMPI_COMM_RANK'] = str(_rank) + os.environ['IMPI_COMM_WORLD'] = str(_world_size) + os.environ['IMPI_COMM_RANK'] = str(_rank) - os.environ['I_MPI_COMM_WORLD'] = str(_world_size) - os.environ['I_MPI_COMM_RANK'] = str(_rank) + os.environ['I_MPI_COMM_WORLD'] = str(_world_size) + os.environ['I_MPI_COMM_RANK'] = str(_rank) - dist.init_process_group( - backend=backend, rank=_rank, world_size=_world_size) + dist.init_process_group( + backend=backend, rank=_rank, world_size=_world_size) + else: + assert dist.get_backend() in ['ccl', 'nccl', + 'mpi'], 'backend must be ccl, nccl, or mpi' _CommData.rank = _rank _CommData.world_size = _world_size From 0286a48ecb49bf09b5a86e61e38968ca1739e0e3 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 28 Apr 2023 13:12:43 +0000 Subject: [PATCH 03/58] Add converting DistGraph to SAR with example --- examples/train_distdgl_with_sar_inference.py | 500 +++++++++++++++++++ sar/__init__.py | 4 +- sar/construct_shard_manager.py | 6 + sar/data_loading.py | 94 +++- 4 files changed, 590 insertions(+), 14 deletions(-) create mode 100755 examples/train_distdgl_with_sar_inference.py diff --git a/examples/train_distdgl_with_sar_inference.py b/examples/train_distdgl_with_sar_inference.py new file mode 100755 index 0000000..8c5b42a --- /dev/null +++ b/examples/train_distdgl_with_sar_inference.py @@ -0,0 +1,500 @@ +import argparse +import socket +import time +from contextlib import contextmanager + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tqdm +import dgl +import dgl.nn.pytorch as dglnn +import os +import sar + +def load_subtensor(g, seeds, input_nodes, device, load_feat=True): + """ + Copys features and labels of a set of nodes onto GPU. + """ + batch_inputs = ( + g.ndata["features"][input_nodes].to(device) if load_feat else None + ) + batch_labels = g.ndata["labels"][seeds].to(device) + return batch_inputs, batch_labels + + +class DistSAGE(nn.Module): + def __init__( + self, in_feats, n_hidden, n_classes, n_layers, activation, dropout + ): + super().__init__() + self.n_layers = n_layers + self.n_hidden = n_hidden + self.n_classes = n_classes + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean")) + for i in range(1, n_layers - 1): + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean")) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean")) + self.dropout = nn.Dropout(dropout) + self.activation = activation + + def forward(self, blocks, x): + h = x + for i, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if i != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + return h + + def full_graph_inference(self, graph, features): + h = features + for i, layer in enumerate(self.layers): + h = layer(graph, h) + if i != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + return h + + def inference(self, g, x, batch_size, device): + """ + Inference with the GraphSAGE model on full neighbors (i.e. without + neighbor sampling). + + g : the entire graph. + x : the input of entire node set. + + Distributed layer-wise inference. + """ + # During inference with sampling, multi-layer blocks are very + # inefficient because lots of computations in the first few layers + # are repeated. Therefore, we compute the representation of all nodes + # layer by layer. The nodes on each layer are of course splitted in + # batches. + # TODO: can we standardize this? + nodes = dgl.distributed.node_split( + np.arange(g.num_nodes()), + g.get_partition_book(), + force_even=True, + ) + y = dgl.distributed.DistTensor( + (g.num_nodes(), self.n_hidden), + th.float32, + "h", + persistent=True, + ) + for i, layer in enumerate(self.layers): + if i == len(self.layers) - 1: + y = dgl.distributed.DistTensor( + (g.num_nodes(), self.n_classes), + th.float32, + "h_last", + persistent=True, + ) + print( + f"|V|={g.num_nodes()}, eval batch size: {batch_size}" + ) + + sampler = dgl.dataloading.NeighborSampler([-1]) + dataloader = dgl.dataloading.DistNodeDataLoader( + g, + nodes, + sampler, + batch_size=batch_size, + shuffle=False, + drop_last=False, + ) + + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): + block = blocks[0].to(device) + h = x[input_nodes].to(device) + h_dst = h[: block.number_of_dst_nodes()] + h = layer(block, (h, h_dst)) + if i != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + + y[output_nodes] = h.cpu() + + x = y + g.barrier() + return y + + @contextmanager + def join(self): + """dummy join for standalone""" + yield + + +def compute_acc(pred, labels): + """ + Compute the accuracy of prediction given the labels. + """ + labels = labels.long() + return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) + + +def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device): + """ + Evaluate the model on the validation set specified by ``val_nid``. + g : The entire graph. + inputs : The features of all the nodes. + labels : The labels of all the nodes. + val_nid : the node Ids for validation. + batch_size : Number of nodes to compute at the same time. + device : The GPU device to evaluate on. + """ + model.eval() + with th.no_grad(): + pred = model.inference(g, inputs, batch_size, device) + model.train() + return compute_acc(pred[val_nid], labels[val_nid]), compute_acc( + pred[test_nid], labels[test_nid] + ) + + +def sar_evaluate(fgm, g, model, device): + """ + Evaluate the model on the validation and test sets with SAR. + fgm : SAR's GraphShardManager to perform full graph inference + g : DistGraph - DistDGL graph + model : Trained model on which perform inference run + val_nid : the node Ids for validation. + batch_size : Number of nodes to compute at the same time. + device : The device to evaluate on + """ + # Obtain validation and test masks + # These are stored in local_partition of DistGraph as DistTensor + # Map to original NID in case of node reordering in DistDGL + local_part = g.local_partition + orig_n_ids = local_part.ndata[dgl.NID][local_part.ndata['inner_node'].bool().nonzero().view(-1)] + + features = g.ndata['features'][orig_n_ids] + labels = g.ndata['labels'][orig_n_ids] + test_mask = g.ndata['test_mask'][orig_n_ids].nonzero( + as_tuple=False).view(-1) + val_mask = g.ndata['val_mask'][orig_n_ids].nonzero( + as_tuple=False).view(-1) + + model.eval() + with th.no_grad(): + logits = model.full_graph_inference(fgm, features) + model.train() + # Calculate accuracy for validation and test + results = [] + for mask in [val_mask, test_mask]: + n_correct = (logits[mask].argmax(1) == + labels[mask]).float().sum() + results.extend([n_correct, mask.numel()]) + + acc_vec = th.FloatTensor(results) + # Sum the n_correct, and number of mask elements across all workers + sar.comm.all_reduce(acc_vec, op=th.distributed.ReduceOp.SUM, move_to_comm_device=True) + (val_acc, test_acc) = \ + (acc_vec[0] / acc_vec[1], + acc_vec[2] / acc_vec[3]) + return val_acc, test_acc + + +def run(args, device, data): + # Unpack data + train_nid, val_nid, test_nid, in_feats, n_classes, g = data + full_graph_manager = sar.convert_dist_graph(g).to(device) + shuffle = True + # prefetch_node_feats/prefetch_labels are not supported for DistGraph yet. + sampler = dgl.dataloading.NeighborSampler( + [int(fanout) for fanout in args.fan_out.split(",")] + ) + dataloader = dgl.dataloading.DistNodeDataLoader( + g, + train_nid, + sampler, + batch_size=args.batch_size, + shuffle=shuffle, + drop_last=False, + ) + # Define model and optimizer + model = DistSAGE( + in_feats, + args.num_hidden, + n_classes, + args.num_layers, + F.relu, + args.dropout, + ) + model = model.to(device) + if not args.standalone: + if args.num_gpus == -1: + model = th.nn.parallel.DistributedDataParallel(model) + else: + model = th.nn.parallel.DistributedDataParallel( + model, device_ids=[device], output_device=device + ) + loss_fcn = nn.CrossEntropyLoss() + loss_fcn = loss_fcn.to(device) + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + # Training loop + iter_tput = [] + epoch = 0 + for epoch in range(args.num_epochs): + tic = time.time() + + sample_time = 0 + forward_time = 0 + backward_time = 0 + update_time = 0 + num_seeds = 0 + num_inputs = 0 + start = time.time() + # Loop over the dataloader to sample the computation dependency graph + # as a list of blocks. + step_time = [] + + with model.join(): + for step, (input_nodes, seeds, blocks) in enumerate(dataloader): + tic_step = time.time() + sample_time += tic_step - start + # fetch features/labels + batch_inputs, batch_labels = load_subtensor( + g, seeds, input_nodes, "cpu" + ) + batch_labels = batch_labels.long() + num_seeds += len(blocks[-1].dstdata[dgl.NID]) + num_inputs += len(blocks[0].srcdata[dgl.NID]) + # move to target device + blocks = [block.to(device) for block in blocks] + batch_inputs = batch_inputs.to(device) + batch_labels = batch_labels.to(device) + # Compute loss and prediction + start = time.time() + batch_pred = model(blocks, batch_inputs) + loss = loss_fcn(batch_pred, batch_labels) + forward_end = time.time() + optimizer.zero_grad() + loss.backward() + compute_end = time.time() + forward_time += forward_end - start + backward_time += compute_end - forward_end + + optimizer.step() + update_time += time.time() - compute_end + + step_t = time.time() - tic_step + step_time.append(step_t) + iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t) + if step % args.log_every == 0: + acc = compute_acc(batch_pred, batch_labels) + gpu_mem_alloc = ( + th.cuda.max_memory_allocated() / 1000000 + if th.cuda.is_available() + else 0 + ) + print( + "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | " + "Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU " + "{:.1f} MB | time {:.3f} s".format( + g.rank(), + epoch, + step, + loss.item(), + acc.item(), + np.mean(iter_tput[3:]), + gpu_mem_alloc, + np.sum(step_time[-args.log_every:]), + ) + ) + start = time.time() + + toc = time.time() + print( + "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, " + "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, " + "#inputs: {}".format( + g.rank(), + toc - tic, + sample_time, + forward_time, + backward_time, + update_time, + num_seeds, + num_inputs, + ) + ) + epoch += 1 + + if epoch % args.eval_every == 0 and epoch != 0: + start = time.time() + # DistDNN evaluation + val_acc, test_acc = evaluate( + model if args.standalone else model.module, + g, + g.ndata["features"], + g.ndata["labels"], + val_nid, + test_nid, + args.batch_size_eval, + device, + ) + print( + "DistDGL: Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format + ( + g.rank(), val_acc, test_acc, time.time() - start + ) + ) + # SAR Evaluation + start = time.time() + val_acc, test_acc = sar_evaluate(full_graph_manager, g, model.module, device) + print( + "SAR: Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format + ( + g.rank(), val_acc, test_acc, time.time() - start + ) + ) + + +def main(args): + print(socket.gethostname(), "Initializing DGL dist") + dgl.distributed.initialize(args.ip_config, net_type=args.net_type) + if not args.standalone: + print(socket.gethostname(), "Initializing DGL process group") + master_ip_address = os.getenv("MASTER_ADDR") + if args.backend == 'ccl': + import oneccl_bindings_for_pytorch + th.distributed.init_process_group(backend=args.backend) + print(socket.gethostname(), "Initializing DistGraph") + g = dgl.distributed.DistGraph( + args.graph_name, + part_config=args.part_config + ) + sar.initialize_comms(g.rank(), + g.get_partition_book().num_partitions(), + master_ip_address, + args.backend + ) + + print(socket.gethostname(), "rank:", g.rank()) + + pb = g.get_partition_book() + if "trainer_id" in g.ndata: + train_nid = dgl.distributed.node_split( + g.ndata["train_mask"], + pb, + force_even=True, + node_trainer_ids=g.ndata["trainer_id"], + ) + val_nid = dgl.distributed.node_split( + g.ndata["val_mask"], + pb, + force_even=True, + node_trainer_ids=g.ndata["trainer_id"], + ) + test_nid = dgl.distributed.node_split( + g.ndata["test_mask"], + pb, + force_even=True, + node_trainer_ids=g.ndata["trainer_id"], + ) + else: + train_nid = dgl.distributed.node_split( + g.ndata["train_mask"], pb, force_even=True + ) + val_nid = dgl.distributed.node_split( + g.ndata["val_mask"], pb, force_even=True + ) + test_nid = dgl.distributed.node_split( + g.ndata["test_mask"], pb, force_even=True + ) + local_nid = pb.partid2nids(pb.partid).detach().numpy() + print( + "part {}, train: {} (local: {}), val: {} (local: {}), test: {} " + "(local: {})".format( + g.rank(), + len(train_nid), + len(np.intersect1d(train_nid.numpy(), local_nid)), + len(val_nid), + len(np.intersect1d(val_nid.numpy(), local_nid)), + len(test_nid), + len(np.intersect1d(test_nid.numpy(), local_nid)), + ) + ) + del local_nid + if args.num_gpus == -1: + device = th.device("cpu") + else: + dev_id = g.rank() % args.num_gpus + device = th.device("cuda:" + str(dev_id)) + n_classes = args.n_classes + if n_classes == 0: + labels = g.ndata["labels"][np.arange(g.num_nodes())] + n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))])) + del labels + print("#labels:", n_classes) + + # Pack data + in_feats = g.ndata["features"].shape[1] + data = train_nid, val_nid, test_nid, in_feats, n_classes, g + run(args, device, data) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GCN") + parser.add_argument("--graph_name", type=str, help="graph name") + parser.add_argument("--id", type=int, help="the partition id") + parser.add_argument( + "--ip_config", type=str, help="The file for IP configuration" + ) + parser.add_argument( + "--part_config", type=str, help="The path to the partition config file" + ) + parser.add_argument( + "--n_classes", type=int, default=0, help="the number of classes" + ) + parser.add_argument( + "--backend", + type=str, + default="ccl", + help="pytorch distributed backend", + ) + parser.add_argument( + "--num_gpus", + type=int, + default=-1, + help="the number of GPU device. Use -1 for CPU training", + ) + parser.add_argument("--num_epochs", type=int, default=20) + parser.add_argument("--num_hidden", type=int, default=16) + parser.add_argument("--num_layers", type=int, default=2) + parser.add_argument("--fan_out", type=str, default="10,25") + parser.add_argument("--batch_size", type=int, default=1000) + parser.add_argument("--batch_size_eval", type=int, default=100000) + parser.add_argument("--log_every", type=int, default=20) + parser.add_argument("--eval_every", type=int, default=5) + parser.add_argument("--lr", type=float, default=0.003) + parser.add_argument("--dropout", type=float, default=0.5) + parser.add_argument( + "--local_rank", type=int, help="get rank of the process" + ) + parser.add_argument( + "--standalone", action="store_true", help="run in the standalone mode" + ) + parser.add_argument( + "--pad-data", + default=False, + action="store_true", + help="Pad train nid to the same length across machine, to ensure num " + "of batches to be the same.", + ) + parser.add_argument( + "--net_type", + type=str, + default="socket", + help="backend net type, 'socket' or 'tensorpipe'", + ) + args = parser.parse_args() + + print(args) + main(args) diff --git a/sar/__init__.py b/sar/__init__.py index e519173..048f2e5 100644 --- a/sar/__init__.py +++ b/sar/__init__.py @@ -27,7 +27,7 @@ nfs_ip_init, sync_params, gather_grads from .core import GraphShardManager, message_has_parameters, DistributedBlock,\ DistNeighborSampler, DataLoader -from .construct_shard_manager import construct_mfgs, construct_full_graph +from .construct_shard_manager import construct_mfgs, construct_full_graph, convert_dist_graph from .data_loading import load_dgl_partition_data, suffix_key_lookup from .distributed_bn import DistributedBN1D from .config import Config @@ -38,7 +38,7 @@ __all__ = ['initialize_comms', 'rank', 'world_size', 'nfs_ip_init', 'comm_device', 'DistributedBN1D', - 'construct_mfgs', 'construct_full_graph', 'GraphShardManager', + 'construct_mfgs', 'construct_full_graph', 'convert_dist_graph', 'GraphShardManager', 'load_dgl_partition_data', 'suffix_key_lookup', 'Config', 'edge_softmax', 'message_has_parameters', 'DistributedBlock', 'DistNeighborSampler', 'DataLoader', 'logging_setup', 'logger', 'sync_params', 'gather_grads', 'patch_dgl'] diff --git a/sar/construct_shard_manager.py b/sar/construct_shard_manager.py index 03e6acd..eaf3721 100644 --- a/sar/construct_shard_manager.py +++ b/sar/construct_shard_manager.py @@ -22,10 +22,12 @@ from typing import List, Tuple, Dict import torch from torch import Tensor +import dgl # type: ignore import numpy as np # type: ignore from .comm import exchange_tensors, rank from .common_tuples import ShardEdgesAndFeatures, PartitionData from .core import GraphShardManager, GraphShard +from .data_loading import _mask_features_dict, _get_type_ordered_edges, load_dgl_partition_data_from_graph def map_to_contiguous_range(active_indices: Tensor, sampled_indices: Tensor) -> Tensor: @@ -203,3 +205,7 @@ def construct_full_graph(partition_data: PartitionData) -> GraphShardManager: partition_data.node_ranges[rank()][0]) return GraphShardManager(graph_shard_list, seed_nodes, seed_nodes) + +def convert_dist_graph(dist_graph: dgl.distributed.DistGraph) -> GraphShardManager: + partition_data = load_dgl_partition_data_from_graph(dist_graph, dist_graph.device) + return construct_full_graph(partition_data) \ No newline at end of file diff --git a/sar/data_loading.py b/sar/data_loading.py index 8007652..1c892a1 100644 --- a/sar/data_loading.py +++ b/sar/data_loading.py @@ -87,27 +87,37 @@ def _get_type_ordered_edges(edge_mask: Tensor, edge_types: Tensor, return torch.cat(reordered_edge_mask) -def load_dgl_partition_data(partition_json_file: str, - own_partition_idx: int, device: torch.device) -> PartitionData: +def create_partition_data(graph: dgl.DGLGraph, + own_partition_idx: int, + node_features: Dict[str, torch.Tensor], + edge_features: Dict[str, Tensor], + partition_book: dgl.distributed.GraphPartitionBook, + node_type_list: List[str], + edge_type_list: List[str], + device: torch.device) -> PartitionData: """ - Loads partition data created by DGL's ``partition_graph`` function + Creates SAR's PartitionData object basing on graph partition and features. - :param partition_json_file: Path to the .json file containing partitioning data - :type partition_json_file: str - :param own_partition_idx: The index of the partition to load. This is typically the\ + :param graph: The graph partition structure for specific ``own_partition_idx`` + :type graph: dgl.DGLGraph + :param own_partition_idx: The index of the partition to create. This is typically the\ worker/machine rank :type own_partition_idx: int + :param node_features: Dictionary containing node features for graph partition + :type node_features: Dict[str, Tensor] + :param edge_features: Dictionary containing edge features for graph partition + :type edge_features: Dict[(str, str, str), Tensor] + :param partition_book: The graph partition information + :type partition_book: dgl.distributed.GraphPartitionBook + :param node_type_list: List of node types + :type node_type_list: List[str] + :param edge_type_list: List of edge types + :type edge_type_list: List[str] :param device: Device on which to place the loaded partition data :type device: torch.device :returns: The loaded partition data - """ - (graph, node_features, - edge_features, partition_book, _, - node_type_list, edge_type_list) = load_partition(partition_json_file, own_partition_idx) - is_heterogeneous = (len(edge_type_list) > 1) - # Delete redundant edge features with keys {relation name}/reltype. graph.edata[dgl.ETYPE ] already contains # the edge type in a heterogeneous graph if is_heterogeneous: @@ -165,3 +175,63 @@ def load_dgl_partition_data(partition_json_file: str, node_type_list, edge_type_list ) + + +def load_dgl_partition_data(partition_json_file: str, + own_partition_idx: int, device: torch.device) -> PartitionData: + """ + Loads partition data created by DGL's ``partition_graph`` function + + :param partition_json_file: Path to the .json file containing partitioning data + :type partition_json_file: str + :param own_partition_idx: The index of the partition to load. This is typically the\ + worker/machine rank + :type own_partition_idx: int + :param device: Device on which to place the loaded partition data + :type device: torch.device + :returns: The loaded partition data + + """ + (graph, node_features, + edge_features, partition_book, _, + node_type_list, edge_type_list) = load_partition(partition_json_file, own_partition_idx) + + return create_partition_data(graph, own_partition_idx, + node_features, edge_features, + partition_book, node_type_list, + edge_type_list, device) + +def load_dgl_partition_data_from_graph(graph: dgl.distributed.DistGraph, + device: torch.device) -> PartitionData: + """ + Loads partition data from DistGraph object + + :param graph: The distributed graph + :type graph: dgl.distributed.DistGraph + :param device: Device on which to place the loaded partition data + :type device: torch.device + :returns: The loaded partition data + + """ + own_partition_idx = graph.rank() + local_g = graph.local_partition + + assert dgl.NID in local_g.ndata + assert dgl.EID in local_g.edata + + # get originalmapping for node and edge ids + orig_n_ids = local_g.ndata[dgl.NID][local_g.ndata['inner_node'].bool().nonzero().view(-1)] + orig_e_ids = local_g.edata[dgl.EID][local_g.edata['inner_edge'].bool().nonzero().view(-1)] + + # fetch local features from DistTensor + node_features = {key : torch.Tensor(graph.ndata[key][orig_n_ids]) for key in list(graph.ndata.keys())} + edge_features = {key : torch.Tensor(graph.edata[key][orig_e_ids]) for key in list(graph.edata.keys())} + + partition_book = graph.get_partition_book() + node_type_list = local_g.ntypes + edge_type_list = [local_g.to_canonical_etype(etype) for etype in graph.etypes] + + return create_partition_data(local_g, own_partition_idx, + node_features, edge_features, + partition_book, node_type_list, + edge_type_list, device) \ No newline at end of file From 5bdfbaac3d5a4fc9d2846ee5750045832ed6b08d Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 28 Apr 2023 09:25:45 +0000 Subject: [PATCH 04/58] Update documentation --- docs/source/data_loading.rst | 61 +++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/docs/source/data_loading.rst b/docs/source/data_loading.rst index d2b4daf..978be86 100644 --- a/docs/source/data_loading.rst +++ b/docs/source/data_loading.rst @@ -3,15 +3,15 @@ Data loading and graph construction ========================================================== -After partitioning the graph using DGL's `partition_graph `_ function, SAR can load the graph data using :func:`sar.load_dgl_partition_data`. This yields a :class:`sar.common_tuples.PartitionData` object. The ``PartitionData`` object can then be used to construct various types of graph-like objects that can be passed to GNN models. You can construct graph objects to use for distributed full-batch training or graph objects to use for distributed training as follows: +After partitioning the graph using DGL's `partition_graph `_ function, SAR can load the graph data using :func:`sar.load_dgl_partition_data`. This yields a :class:`sar.common_tuples.PartitionData` object. The ``PartitionData`` object can then be used to construct various types of graph-like objects that can be passed to GNN models. You can construct graph objects to use for distributed full-batch training or graph objects to use for distributed training as follows: .. contents:: :local: :depth: 3 - + Full-batch training --------------------------------------------------------------------------------------- - + Constructing the full graph for sequential aggregation and rematerialization ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Construct a single distributed graph object of type :class:`sar.core.GraphShardManager`:: @@ -20,7 +20,7 @@ Construct a single distributed graph object of type :class:`sar.core.GraphShardM .. -The ``GraphShardManager`` object encapsulates N DGL graph objects (where N is the number of workers). Each graph object represents the edges incoming from one partition (including the local partition). ``GraphShardManager`` implements the ``update_all`` and ``apply_edges`` methods in addition to several other methods from the standard ``dgl.heterograph.DGLHeterograph`` API. The ``update_all`` and ``apply_edges`` methods implement the sequential aggregation and rematerialization scheme to realize the distributed forward and backward passes. ``GraphShardManager`` can usually be passed to GNN layers instead of ``dgl.heterograph.DGLHeterograph``. See the :ref:`the distributed graph limitations section` for some exceptions. +The ``GraphShardManager`` object encapsulates N DGL graph objects (where N is the number of workers). Each graph object represents the edges incoming from one partition (including the local partition). ``GraphShardManager`` implements the ``update_all`` and ``apply_edges`` methods in addition to several other methods from the standard ``dgl.heterograph.DGLGraph`` API. The ``update_all`` and ``apply_edges`` methods implement the sequential aggregation and rematerialization scheme to realize the distributed forward and backward passes. ``GraphShardManager`` can usually be passed to GNN layers instead of ``dgl.heterograph.DGLGraph``. See the :ref:`the distributed graph limitations section` for some exceptions. Constructing Message Flow Graphs (MFGs) for sequential aggregation and rematerialization ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -86,7 +86,7 @@ For sampling-based training, use the dataloader provided by SAR: :func:`sar.Data :: - shard_manager = sar.construct_full_graph(partition_data) + shard_manager = sar.construct_full_graph(partition_data) neighbor_sampler = sar.DistNeighborSampler( [15, 10, 5], #Fanout for every layer @@ -103,11 +103,48 @@ For sampling-based training, use the dataloader provided by SAR: :func:`sar.Data for blocks in dataloader: output = gnn_model(blocks) ... - -.. +.. + + +Full-graph inference +--------------------------------------------------------------------------------------- +SAR might also be utilized just for model evaluation. It is preferable to evaluate the model on the entire graph while performing mini-batch distributed training with the DGL package. To accomplish this, SAR can turn a `DistGraph `_ object into a GraphShardManager object, allowing for distributed full-graph inference. The procedure is simple since no further steps are required because the model parameters are already synchronized during inference. You can use :func:`sar.convert_dist_graph` in the following way to perform full-graph inference: +:: + + class GNNModel(nn.Module): + def __init__(n_layers: int): + super().__init__() + self.convs = nn.ModuleList([ + dgl.nn.SAGEConv(100, 100) + for _ in range(n_layers) + ]) + + # forward function prepared for mini-batch training + def forward(blocks: List[DGLBlock], features: torch.Tensor): + h = features + for idx, (layer, block) in enumerate(zip(self.convs, blocks)): + h = self.convs[idx](blocks[idx], h) + return h + + # implement inference function for full-graph input + def full_graph_inference(graph: sar.GraphShardManager, featues: torch.Tensor): + h = features + for idx, layer in enumerate(self.convs): + h = layer(graph, h) + return h + + # model wrapped in pytorch DistributedDataParallel + gnn_model = th.nn.parallel.DistributedDataParallel(GNNModel(3)) - + # Convert DistGraph into GraphShardManager + gsm = sar.convert_dist_graph(g).to(device) + + # Access to model through DistributedDataParallel module field + model_out = gnn_model.module.full_graph_inference(train_blocks, local_node_features) +.. + + Relevant methods --------------------------------------------------------------------------------------- @@ -117,11 +154,11 @@ Relevant methods .. autosummary:: :toctree: Data loading and graph construction :template: distneighborsampler - - load_dgl_partition_data + + load_dgl_partition_data construct_full_graph - construct_mfgs + construct_mfgs + convert_dist_graph DataLoader DistNeighborSampler - From 18001e3899881a155f2e452fa7662e50a3e4d755 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 2 May 2023 12:31:14 +0000 Subject: [PATCH 05/58] Remove 'reshuffle' parameter from docs --- docs/source/quick_start.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index f411463..5d265b0 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -11,7 +11,7 @@ Follow the following steps to enable distributed training in your DGL code: Partition the graph ---------------------------------- -Partition the graph using DGL's `partition_graph `_ function. See `here `_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` and ``reshuffle = True`` in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``. +Partition the graph using DGL's `partition_graph `_ function. See `here `_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``. An example of partitioning the ogbn-arxiv graph in two parts: :: @@ -44,7 +44,7 @@ An example of partitioning the ogbn-arxiv graph in two parts: :: graph.ndata[name] = val dgl.distributed.partition_graph( - graph, 'arxiv', 2, './test_partition_data/', num_hops=1, reshuffle=True) + graph, 'arxiv', 2, './test_partition_data/', num_hops=1) .. From 5b18cfdf34c76bceddd044dbde5e6276f27dc1f7 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 2 May 2023 14:07:51 +0000 Subject: [PATCH 06/58] Add .gitignore file --- .gitignore | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100755 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..1ffa5f3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,55 @@ +# Python cache +__pycache__/ +*.pyc + +# Jupyter notebook checkpoints +.ipynb_checkpoints/ + +# Compiled Python files +*.pyc +*.pyo +*.pyd +__pycache__/ + +# Build directories +build/ +dist/ +*.egg-info/ + + +# Package distribution +*.egg +*.egg-info + +# IDE and editor files +.vscode/ +.idea/ +*.iml +*.iws +*.ipr + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Datasets and partitions +dataset/ +datasets/ +partition_data/ \ No newline at end of file From 29931386758c967d564ab255c0271d43d669c560 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 2 May 2023 14:10:01 +0000 Subject: [PATCH 07/58] Allow SAR wheel building --- examples/sar | 1 - requirements.txt | 9 ++++----- sar/__init__.py | 1 + setup.py | 19 +++++++++++++++++++ 4 files changed, 24 insertions(+), 6 deletions(-) delete mode 120000 examples/sar create mode 100644 setup.py diff --git a/examples/sar b/examples/sar deleted file mode 120000 index 2e70ff7..0000000 --- a/examples/sar +++ /dev/null @@ -1 +0,0 @@ -../sar \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b107e5d..310d901 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ --find-links=https://data.dgl.ai/wheels/repo.html -dgl==0.8.0 -numpy==1.22.0 -ogb==1.3.1 -torch==1.13.1 -ifaddr==0.1.7 +dgl>=1.0.0 +numpy>=1.22.0 +torch>=1.10.0 +ifaddr>=0.1.7 \ No newline at end of file diff --git a/sar/__init__.py b/sar/__init__.py index e519173..65a41f8 100644 --- a/sar/__init__.py +++ b/sar/__init__.py @@ -22,6 +22,7 @@ ''' Top-level SAR package ''' +from . import core from .comm import initialize_comms, rank, world_size, comm_device,\ nfs_ip_init, sync_params, gather_grads diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..90f64d2 --- /dev/null +++ b/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup, find_packages + +setup( + name='SAR', + version='0.1.0', + install_requires=[ + 'dgl>=1.0.0', + 'numpy>=1.22.0', + 'torch>=1.10.0', + 'ifaddr>=0.1.7' + ], + packages=find_packages(), + author='Hesham Mostafa', + author_email='hesham.mostafa@intel.com', + description='A Python library for distributed training of Graph Neural Networks (GNNs) on large graphs, ' + 'supporting both full-batch and sampling-based training, and utilizing a sequential aggregation' + 'and rematerialization technique for linear memory scaling.', + url='https://github.com/IntelLabs/SAR/', +) \ No newline at end of file From ccb7cd66dd839dfe2ea9b27deb8c17641ca5f193 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 5 May 2023 07:44:23 +0000 Subject: [PATCH 08/58] Add pytests - test_patch_dgl --- sar/__init__.py | 4 ++-- tests/pytest.ini | 5 +++++ tests/test_patch_dgl.py | 19 +++++++++++++++++ tests/utils.py | 45 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 tests/pytest.ini create mode 100644 tests/test_patch_dgl.py create mode 100644 tests/utils.py diff --git a/sar/__init__.py b/sar/__init__.py index e519173..23478c0 100644 --- a/sar/__init__.py +++ b/sar/__init__.py @@ -32,7 +32,7 @@ from .distributed_bn import DistributedBN1D from .config import Config from .edge_softmax import edge_softmax -from .patch_dgl import patch_dgl +from .patch_dgl import patch_dgl, patched_edge_softmax from .logging_setup import logging_setup, logger @@ -41,4 +41,4 @@ 'construct_mfgs', 'construct_full_graph', 'GraphShardManager', 'load_dgl_partition_data', 'suffix_key_lookup', 'Config', 'edge_softmax', 'message_has_parameters', 'DistributedBlock', 'DistNeighborSampler', 'DataLoader', - 'logging_setup', 'logger', 'sync_params', 'gather_grads', 'patch_dgl'] + 'logging_setup', 'logger', 'sync_params', 'gather_grads', 'patch_dgl', 'patched_edge_softmax'] diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..c357cc1 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +testpaths = tests/ +python_files = test_*.py +python_classes = Test* +python_functions = test_* \ No newline at end of file diff --git a/tests/test_patch_dgl.py b/tests/test_patch_dgl.py new file mode 100644 index 0000000..12082e8 --- /dev/null +++ b/tests/test_patch_dgl.py @@ -0,0 +1,19 @@ +from utils import * + +@sar_test +def test_patch_dgl(): + import dgl + original_gat_edge_softmax = dgl.nn.pytorch.conv.gatconv.edge_softmax + original_dotgat_edge_softmax = dgl.nn.pytorch.conv.dotgatconv.edge_softmax + original_agnn_edge_softmax = dgl.nn.pytorch.conv.agnnconv.edge_softmax + + import sar + sar.patch_dgl() + + assert original_gat_edge_softmax == dgl.nn.functional.edge_softmax + assert original_dotgat_edge_softmax == dgl.nn.functional.edge_softmax + assert original_agnn_edge_softmax == dgl.nn.functional.edge_softmax + + assert dgl.nn.pytorch.conv.gatconv.edge_softmax == sar.patched_edge_softmax + assert dgl.nn.pytorch.conv.dotgatconv.edge_softmax == sar.patched_edge_softmax + assert dgl.nn.pytorch.conv.agnnconv.edge_softmax == sar.patched_edge_softmax \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..4192a84 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,45 @@ + +import multiprocessing as mp +import traceback +import pytest + +def sar_test(func): + """ + A decorator function that wraps all SAR tests with the primary objective + of facilitating module imports in tests without affecting other tests. + + :param func: The function that serves as the entry point to the test. + :type func: function + :returns: A function that encapsulates the pytest function. + """ + def test_wrapper(*args, **kwargs): + """ + The wrapping process involves defining another nested function, which is then invoked by a newly spawned process. + function spawns a new process and uses the "join" method to wait for the results. + Upon completion of the process, error and result handling are performed. + """ + def process_wrapper(func, mp_dict, *args, **kwargs): + try: + result = func(*args, **kwargs) + mp_dict["result"] = result + except Exception as e: + mp_dict['traceback'] = str(traceback.format_exc()) + mp_dict["exception"] = e + + manager = mp.Manager() + mp_dict = manager.dict() + + mp_args = (func, mp_dict) + args + p = mp.Process(target=process_wrapper, args=mp_args, **kwargs) + p.start() + p.join() + + if 'exception' in mp_dict: + print(mp_dict['exception'].args) + msg = mp_dict.get('traceback', ()) + for e_arg in mp_dict['exception'].args: + msg += str(e_arg) + pytest.fail(str(msg), pytrace=False) + + return mp_dict["result"] + return test_wrapper From 525720d1adf39eb1db7063a8f3495082b5945fef Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 5 May 2023 14:16:53 +0000 Subject: [PATCH 09/58] Add test_sar_full_graph inference test --- sar/__init__.py | 4 +- tests/models.py | 16 ++++++ tests/test_patch_dgl.py | 3 +- tests/test_sar.py | 106 ++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 12 +++-- 5 files changed, 133 insertions(+), 8 deletions(-) create mode 100644 tests/models.py create mode 100644 tests/test_sar.py diff --git a/sar/__init__.py b/sar/__init__.py index 23478c0..87b2d69 100644 --- a/sar/__init__.py +++ b/sar/__init__.py @@ -32,7 +32,7 @@ from .distributed_bn import DistributedBN1D from .config import Config from .edge_softmax import edge_softmax -from .patch_dgl import patch_dgl, patched_edge_softmax +from .patch_dgl import patch_dgl, patched_edge_softmax, RelGraphConv from .logging_setup import logging_setup, logger @@ -41,4 +41,4 @@ 'construct_mfgs', 'construct_full_graph', 'GraphShardManager', 'load_dgl_partition_data', 'suffix_key_lookup', 'Config', 'edge_softmax', 'message_has_parameters', 'DistributedBlock', 'DistNeighborSampler', 'DataLoader', - 'logging_setup', 'logger', 'sync_params', 'gather_grads', 'patch_dgl', 'patched_edge_softmax'] + 'logging_setup', 'logger', 'RelGraphConv', 'sync_params', 'gather_grads', 'patch_dgl', 'patched_edge_softmax'] diff --git a/tests/models.py b/tests/models.py new file mode 100644 index 0000000..e864a08 --- /dev/null +++ b/tests/models.py @@ -0,0 +1,16 @@ +import dgl +import torch.nn.functional as F +from torch import nn + +class GNNModel(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, out_dim: int): + super().__init__() + + self.convs = nn.ModuleList([ + dgl.nn.GraphConv(in_dim, out_dim, weight=False, bias=False), + ]) + + def forward(self, graph, features): + for idx, conv in enumerate(self.convs): + features = conv(graph, features) + return features \ No newline at end of file diff --git a/tests/test_patch_dgl.py b/tests/test_patch_dgl.py index 12082e8..b96ffc6 100644 --- a/tests/test_patch_dgl.py +++ b/tests/test_patch_dgl.py @@ -16,4 +16,5 @@ def test_patch_dgl(): assert dgl.nn.pytorch.conv.gatconv.edge_softmax == sar.patched_edge_softmax assert dgl.nn.pytorch.conv.dotgatconv.edge_softmax == sar.patched_edge_softmax - assert dgl.nn.pytorch.conv.agnnconv.edge_softmax == sar.patched_edge_softmax \ No newline at end of file + assert dgl.nn.pytorch.conv.RelGraphConv == sar.RelGraphConv + assert dgl.nn.RelGraphConv == sar.RelGraphConv \ No newline at end of file diff --git a/tests/test_sar.py b/tests/test_sar.py new file mode 100644 index 0000000..255192f --- /dev/null +++ b/tests/test_sar.py @@ -0,0 +1,106 @@ +from utils import * +import os +import scipy +import tempfile +import logging +import numpy as np +# Do not import DGL and SAR - these modules should be independent in each process + + + +def sar_process(mp_dict, rank, world_size, tmp_dir): + """ + This function should be an entry point to the 'independent' process. + It has to simulate behaviour of SAR which will be spawned across different + machines independently from other instances. Each process have individual memory space + so it is suitable environment for testing SAR + """ + import dgl + import torch + import sar + from models import GNNModel + + try: + if rank == 0: + # partitioning takes place offline, however + # for testing random graph is needed - only master node should do this + # random graph partitions will be then placed in temporary directory + graph = dgl.rand_graph(1000, 2500) + graph = dgl.add_self_loop(graph) + graph.ndata.clear() + graph.ndata['features'] = torch.rand((graph.num_nodes(), 1)) + + dgl.distributed.partition_graph( + graph, + 'random_graph', + 2, + tmp_dir, + num_hops=1, + balance_edges=True) + + part_file = os.path.join(tmp_dir, 'random_graph.json') + ip_file = os.path.join(tmp_dir, 'ip_file') + + master_ip_address = sar.nfs_ip_init(rank, ip_file) + + sar.initialize_comms(rank, + world_size, + master_ip_address, + 'ccl') + + torch.distributed.barrier() # wait for rank 0 to finish graph creation + + partition_data = sar.load_dgl_partition_data( + part_file, rank, 'cpu') + + full_graph_manager = sar.construct_full_graph(partition_data).to('cpu') + features = sar.suffix_key_lookup(partition_data.node_features, 'features') + del partition_data + + model = GNNModel(features.size(1), 32, features.size(1)).to('cpu') + sar.sync_params(model) + + logits = model(full_graph_manager, features) + + # put calculated results in multiprocessing dictionary + mp_dict[f"result_{rank}"] = logits.detach() + + + if rank == 0: + # only rank 0 is runned within parent process + # return used model and generated graph to caller + return model, graph + + except Exception as e: + mp_dict['traceback'] = str(traceback.format_exc()) + mp_dict['exception'] = e + return None, None + + +def test_sar_full_graph(): + with tempfile.TemporaryDirectory() as tmpdir: + manager = mp.Manager() + mp_dict = manager.dict() + + p = mp.Process(target=sar_process, args=(mp_dict, 1, 2, tmpdir)) + p.daemon = True + p.start() + + model, graph = sar_process(mp_dict, 0, 2, tmpdir) + p.join() + + if 'exception' in mp_dict: + handle_mp_exception(mp_dict) + + out = model(graph, graph.ndata['features']).numpy() + + # compare mean of all values instead of each node feature individually + # TODO: reorder SAR calculated logits to original NID mapping + full_graph_mean = out.mean() + + r0_logits = mp_dict["result_0"].numpy() + r1_logits = mp_dict["result_1"].numpy() + sar_logits_mean = np.concatenate((r0_logits, r1_logits)).mean() + + rtol = sar_logits_mean / 1000 + assert full_graph_mean == pytest.approx(sar_logits_mean, rtol) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index 4192a84..562b970 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,12 @@ import traceback import pytest +def handle_mp_exception(mp_dict): + msg = mp_dict.get('traceback', "") + for e_arg in mp_dict['exception'].args: + msg += str(e_arg) + pytest.fail(str(msg), pytrace=False) + def sar_test(func): """ A decorator function that wraps all SAR tests with the primary objective @@ -35,11 +41,7 @@ def process_wrapper(func, mp_dict, *args, **kwargs): p.join() if 'exception' in mp_dict: - print(mp_dict['exception'].args) - msg = mp_dict.get('traceback', ()) - for e_arg in mp_dict['exception'].args: - msg += str(e_arg) - pytest.fail(str(msg), pytrace=False) + handle_mp_exception(mp_dict) return mp_dict["result"] return test_wrapper From f7dceb02715ea341f03da8ac786d5cfd00b54b82 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 8 May 2023 09:09:43 +0000 Subject: [PATCH 10/58] Cleanup; Test 2 and 4 partitions --- tests/test_patch_dgl.py | 2 ++ tests/test_sar.py | 40 ++++++++++++++++++++++++---------------- tests/utils.py | 5 ++++- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/tests/test_patch_dgl.py b/tests/test_patch_dgl.py index b96ffc6..f3c0ba5 100644 --- a/tests/test_patch_dgl.py +++ b/tests/test_patch_dgl.py @@ -1,4 +1,6 @@ from utils import * +# Do not import DGL and SAR - these modules should be +# independently loaded inside each process @sar_test def test_patch_dgl(): diff --git a/tests/test_sar.py b/tests/test_sar.py index 255192f..d3ede9f 100644 --- a/tests/test_sar.py +++ b/tests/test_sar.py @@ -1,13 +1,11 @@ from utils import * import os -import scipy import tempfile -import logging -import numpy as np -# Do not import DGL and SAR - these modules should be independent in each process +import numpy as np +# Do not import DGL and SAR - these modules should be +# independently loaded inside each process - def sar_process(mp_dict, rank, world_size, tmp_dir): """ This function should be an entry point to the 'independent' process. @@ -33,7 +31,7 @@ def sar_process(mp_dict, rank, world_size, tmp_dir): dgl.distributed.partition_graph( graph, 'random_graph', - 2, + world_size, tmp_dir, num_hops=1, balance_edges=True) @@ -65,7 +63,6 @@ def sar_process(mp_dict, rank, world_size, tmp_dir): # put calculated results in multiprocessing dictionary mp_dict[f"result_{rank}"] = logits.detach() - if rank == 0: # only rank 0 is runned within parent process # return used model and generated graph to caller @@ -77,17 +74,25 @@ def sar_process(mp_dict, rank, world_size, tmp_dir): return None, None -def test_sar_full_graph(): +@pytest.mark.parametrize('world_size', [2, 4]) +@sar_test +def test_sar_full_graph(world_size): + print(world_size) with tempfile.TemporaryDirectory() as tmpdir: manager = mp.Manager() mp_dict = manager.dict() - p = mp.Process(target=sar_process, args=(mp_dict, 1, 2, tmpdir)) - p.daemon = True - p.start() + processes = [] + for rank in range(1, world_size): + p = mp.Process(target=sar_process, args=(mp_dict, rank, world_size, tmpdir)) + p.daemon = True + p.start() + processes.append(p) - model, graph = sar_process(mp_dict, 0, 2, tmpdir) - p.join() + model, graph = sar_process(mp_dict, 0, world_size, tmpdir) + + for p in processes: + p.join() if 'exception' in mp_dict: handle_mp_exception(mp_dict) @@ -98,9 +103,12 @@ def test_sar_full_graph(): # TODO: reorder SAR calculated logits to original NID mapping full_graph_mean = out.mean() - r0_logits = mp_dict["result_0"].numpy() - r1_logits = mp_dict["result_1"].numpy() - sar_logits_mean = np.concatenate((r0_logits, r1_logits)).mean() + sar_logits = mp_dict["result_0"].numpy() + for rank in range(1, world_size): + rank_logits = mp_dict[f"result_{rank}"].numpy() + sar_logits = np.concatenate((sar_logits, rank_logits)) + + sar_logits_mean = sar_logits.mean() rtol = sar_logits_mean / 1000 assert full_graph_mean == pytest.approx(sar_logits_mean, rtol) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index 562b970..f711858 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,11 +2,13 @@ import multiprocessing as mp import traceback import pytest +import functools def handle_mp_exception(mp_dict): msg = mp_dict.get('traceback', "") for e_arg in mp_dict['exception'].args: msg += str(e_arg) + print(str(msg), flush=True) pytest.fail(str(msg), pytrace=False) def sar_test(func): @@ -18,6 +20,7 @@ def sar_test(func): :type func: function :returns: A function that encapsulates the pytest function. """ + @functools.wraps(func) def test_wrapper(*args, **kwargs): """ The wrapping process involves defining another nested function, which is then invoked by a newly spawned process. @@ -36,7 +39,7 @@ def process_wrapper(func, mp_dict, *args, **kwargs): mp_dict = manager.dict() mp_args = (func, mp_dict) + args - p = mp.Process(target=process_wrapper, args=mp_args, **kwargs) + p = mp.Process(target=process_wrapper, args=mp_args, kwargs=kwargs) p.start() p.join() From 498d0d51eae980bea34f8e64fc22906cd1bf3848 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 8 May 2023 09:20:23 +0000 Subject: [PATCH 11/58] Add github workflow --- .github/workflows/sar_test.yaml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/sar_test.yaml diff --git a/.github/workflows/sar_test.yaml b/.github/workflows/sar_test.yaml new file mode 100644 index 0000000..1b259df --- /dev/null +++ b/.github/workflows/sar_test.yaml @@ -0,0 +1,32 @@ +name: SAR tests + +on: + pull_request: + branches: [main, convert] + workflow_dispatch: + +jobs: + sar_tests: + runs-on: ubuntu-latest + steps: + - name: Pull SAR + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.8' + + - name: Install requirements + run: | + python -m pip install --upgrade pip + python -m pip install pytest + python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu + python setup.py install + + - name: Run pytest + run: | + set +e + python -m pytest tests/ -sv \ No newline at end of file From 9662b0e5a5fb9c08ada2ae8b43d933cc25ba7cb4 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 8 May 2023 13:20:38 +0000 Subject: [PATCH 12/58] Add convert DistGraph test --- tests/test_sar.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/test_sar.py b/tests/test_sar.py index d3ede9f..ea1d155 100644 --- a/tests/test_sar.py +++ b/tests/test_sar.py @@ -111,4 +111,42 @@ def test_sar_full_graph(world_size): sar_logits_mean = sar_logits.mean() rtol = sar_logits_mean / 1000 - assert full_graph_mean == pytest.approx(sar_logits_mean, rtol) \ No newline at end of file + assert full_graph_mean == pytest.approx(sar_logits_mean, rtol) + + + + +@sar_test +def test_convert_dist_graph(): + with tempfile.TemporaryDirectory() as tmpdir: + import dgl + import torch + import sar + graph_name = 'random_graph' + part_file = os.path.join(tmpdir, 'random_graph.json') + ip_file = os.path.join(tmpdir, 'ip_file') + g = dgl.rand_graph(1000, 2500) + g = dgl.add_self_loop(g) + g.ndata.clear() + g.ndata['features'] = torch.rand((g.num_nodes(), 1)) + dgl.distributed.partition_graph( + g, + 'random_graph', + 1, + tmpdir, + num_hops=1, + balance_edges=True) + + master_ip_address = sar.nfs_ip_init(0, ip_file) + sar.initialize_comms(0, 1, master_ip_address, 'ccl') + + dgl.distributed.initialize("kv_ip_config.txt") + dist_g = dgl.distributed.DistGraph( + graph_name, part_config=part_file) + + sar_g = sar.convert_dist_graph(dist_g) + print(sar_g.graph_shards[0].graph.ndata) + assert len(sar_g.graph_shards) == dist_g.get_partition_book().num_partitions() + assert dist_g.num_edges() == sar_g.num_edges() + # this check fails (1000 != 2000) + #assert dist_g.num_nodes() == sar_g.num_nodes() \ No newline at end of file From 5812bd77ffacd9d832c50358f0b6a30e7cc9c470 Mon Sep 17 00:00:00 2001 From: bgawrych Date: Tue, 13 Jun 2023 08:20:30 +0200 Subject: [PATCH 13/58] remove convert branch from workflow --- .github/workflows/sar_test.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/sar_test.yaml b/.github/workflows/sar_test.yaml index 1b259df..87fda95 100644 --- a/.github/workflows/sar_test.yaml +++ b/.github/workflows/sar_test.yaml @@ -2,7 +2,7 @@ name: SAR tests on: pull_request: - branches: [main, convert] + branches: [main] workflow_dispatch: jobs: @@ -29,4 +29,4 @@ jobs: - name: Run pytest run: | set +e - python -m pytest tests/ -sv \ No newline at end of file + python -m pytest tests/ -sv From 044caf429f12a6f5158b2d713add730674e10922 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 30 May 2023 15:58:38 +0200 Subject: [PATCH 14/58] Add support for ndata --- docs/source/shards.rst | 2 +- sar/core/graphshard.py | 80 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/docs/source/shards.rst b/docs/source/shards.rst index ef90621..9fbcd53 100644 --- a/docs/source/shards.rst +++ b/docs/source/shards.rst @@ -14,7 +14,7 @@ In the distributed implementation of the sequential backward pass in ``update_a Limitations of the distributed graph objects ------------------------------------------------------------------------------------ -Keep in mind that the distributed graph class :class:`sar.core.GraphShardManager` does not implement all the functionality of DGL's native graph class. For example, it does not impelement the ``successors`` and ``predecessors`` methods. It supports primarily the methods of DGL's native graphs that are relevant to GNNs such as ``update_all``, ``apply_edges``, and ``local_scope``. It also supports setting graph node and edge features through the dictionaries ``srcdata``, ``dstdata``, and ``edata``. Note that :class:`sar.core.GraphShardManager` does not support the ``ndata`` member dictionary. +Keep in mind that the distributed graph class :class:`sar.core.GraphShardManager` does not implement all the functionality of DGL's native graph class. For example, it does not impelement the ``successors`` and ``predecessors`` methods. It supports primarily the methods of DGL's native graphs that are relevant to GNNs such as ``update_all``, ``apply_edges``, and ``local_scope``. It also supports setting graph node and edge features through the dictionaries ``srcdata``, ``dstdata``, and ``edata``. To remain compatible with DGLGraph :class:`sar.core.GraphShardManager` provides also access to the ``ndata`` member, which works as alias to ``srcdata``, however it is not accessible when working with MFGs. :class:`sar.core.GraphShardManager` also supports the ``in_degrees`` and ``out_degrees`` members and supports querying the number of nodes and edges in the graph. diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index f819cff..ac9bce2 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -123,6 +123,42 @@ def to(self, device: torch.device): if self._graph_reverse is not None: self._graph_reverse = self._graph_reverse.to(device) +class GSMDataStore(): + """ + A straightforward class designed to manage dictionaries for srdata, dstdata, and edata. + It enables the creation of chained data and allows for rewinding all data simultaneously. + """ + def __init__(self, src_is_tgt: bool, num_src_nodes: int, num_dst_nodes: int, num_edges: int): + self.src_is_tgt = src_is_tgt + + if src_is_tgt: + assert num_src_nodes == num_dst_nodes, "Number of source nodes must be equal to the number of target nodes" + + if src_is_tgt: + self.srcdata = self.dstdata = self.ndata = ChainedDataView(num_src_nodes) + else: + self.srcdata = ChainedDataView(num_src_nodes) + self.dstdata = ChainedDataView(num_dst_nodes) + + self.edata = ChainedDataView(num_edges) + + def chain_data(self): + if self.src_is_tgt: + self.srcdata = self.dstdata = ChainedDataView(self.srcdata.acceptable_size, self.srcdata) + else: + self.srcdata = ChainedDataView(self.srcdata.acceptable_size, self.srcdata) + self.dstdata = ChainedDataView(self.dstdata.acceptable_size, self.dstdata) + + self.edata = ChainedDataView(self.edata.acceptable_size, self.edata) + + def rewind_data(self): + if self.src_is_tgt: + self.srcdata = self.dstdata = self.srcdata.rewind() + else: + self.srcdata = self.srcdata.rewind() + self.dstdata = self.dstdata.rewind() + + self.edata.rewind() class ChainedDataView(MutableMapping): """A dictionary that chains to children dictionary on missed __getitem__ calls""" @@ -195,6 +231,10 @@ def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, loca super().__init__() self.graph_shards = graph_shards + # source nodes and target nodes are all the same + # srcdata, dstdata and ndata should be also the same + self.src_is_tgt = local_src_seeds is local_tgt_seeds + assert all(self.tgt_node_range == x.tgt_range for x in self.graph_shards[1:]) @@ -227,17 +267,16 @@ def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, loca self.in_degrees_cache: Dict[Optional[str], Tensor] = {} self.out_degrees_cache: Dict[Optional[str], Tensor] = {} - self.dstdata = ChainedDataView(self.num_dst_nodes()) - self.srcdata = ChainedDataView(self.num_src_nodes()) - self.edata = ChainedDataView(self.num_edges()) + self.datastore = GSMDataStore(self.src_is_tgt, self.num_src_nodes(), + self.num_dst_nodes(), self.num_edges()) self._sampling_graph = None - @ property + @property def tgt_node_range(self) -> Tuple[int, int]: return self.graph_shards[0].tgt_range - @ property + @property def local_src_node_range(self) -> Tuple[int, int]: return self.graph_shards[rank()].src_range @@ -292,19 +331,30 @@ def update_boundary_nodes_indices(self) -> List[Tensor]: ind.sub_(self.tgt_node_range[0]) return indices_required_from_me - @ contextmanager + @contextmanager def local_scope(self): - self.dstdata = ChainedDataView( - self.dstdata.acceptable_size, self.dstdata) - self.srcdata = ChainedDataView( - self.srcdata.acceptable_size, self.srcdata) - self.edata = ChainedDataView(self.edata.acceptable_size, self.edata) + self.datastore.chain_data() yield - self.dstdata = self.dstdata.rewind() - self.srcdata = self.srcdata.rewind() - self.edata = self.edata.rewind() + self.datastore.rewind_data() + + @property + def srcdata(self): + return self.datastore.srcdata + + @property + def dstdata(self): + return self.datastore.dstdata - @ property + @property + def edata(self): + return self.datastore.edata + + @property + def ndata(self): + assert self.src_is_tgt, "ndata shouldn't be used with MFGs" + return self.datastore.srcdata + + @property def is_block(self): return True From 2d63cf66ceb657161d604e078726390b1d987e0b Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Wed, 14 Jun 2023 13:37:34 +0200 Subject: [PATCH 15/58] apply review comments --- docs/source/data_loading.rst | 2 +- docs/source/quick_start.rst | 6 +++--- examples/train_distdgl_with_sar_inference.py | 14 +++++--------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/docs/source/data_loading.rst b/docs/source/data_loading.rst index 978be86..f6d3cfa 100644 --- a/docs/source/data_loading.rst +++ b/docs/source/data_loading.rst @@ -141,7 +141,7 @@ SAR might also be utilized just for model evaluation. It is preferable to evalua gsm = sar.convert_dist_graph(g).to(device) # Access to model through DistributedDataParallel module field - model_out = gnn_model.module.full_graph_inference(train_blocks, local_node_features) + model_out = gnn_model.module.full_graph_inference(gsm, local_node_features) .. diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index 5d265b0..2dfef57 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -11,7 +11,7 @@ Follow the following steps to enable distributed training in your DGL code: Partition the graph ---------------------------------- -Partition the graph using DGL's `partition_graph `_ function. See `here `_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``. +Partition the graph using DGL's `partition_graph `_ function. See `here `_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` and ``reshuffle = True`` (in DGL < 1.0) in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``. An example of partitioning the ogbn-arxiv graph in two parts: :: @@ -44,8 +44,8 @@ An example of partitioning the ogbn-arxiv graph in two parts: :: graph.ndata[name] = val dgl.distributed.partition_graph( - graph, 'arxiv', 2, './test_partition_data/', num_hops=1) - + graph, 'arxiv', 2, './test_partition_data/', num_hops=1) # use reshuffle=True in DGL < 1.0 + .. Note that we add the labels, and the train/test/validation masks as node features so that they get split into multiple parts alongside the graph. diff --git a/examples/train_distdgl_with_sar_inference.py b/examples/train_distdgl_with_sar_inference.py index 8c5b42a..0678103 100755 --- a/examples/train_distdgl_with_sar_inference.py +++ b/examples/train_distdgl_with_sar_inference.py @@ -16,7 +16,7 @@ def load_subtensor(g, seeds, input_nodes, device, load_feat=True): """ - Copys features and labels of a set of nodes onto GPU. + Copies features and labels of a set of nodes onto GPU. """ batch_inputs = ( g.ndata["features"][input_nodes].to(device) if load_feat else None @@ -427,11 +427,10 @@ def main(args): else: dev_id = g.rank() % args.num_gpus device = th.device("cuda:" + str(dev_id)) - n_classes = args.n_classes - if n_classes == 0: - labels = g.ndata["labels"][np.arange(g.num_nodes())] - n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))])) - del labels + + labels = g.ndata["labels"][np.arange(g.num_nodes())] + n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))])) + del labels print("#labels:", n_classes) # Pack data @@ -450,9 +449,6 @@ def main(args): parser.add_argument( "--part_config", type=str, help="The path to the partition config file" ) - parser.add_argument( - "--n_classes", type=int, default=0, help="the number of classes" - ) parser.add_argument( "--backend", type=str, From 584ec9308be14a1711ba9007f527d7dd15a7bc9f Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 15 Jun 2023 08:30:35 +0200 Subject: [PATCH 16/58] Remove GSMDataStore class --- sar/core/graphshard.py | 76 +++++++++++++----------------------------- 1 file changed, 23 insertions(+), 53 deletions(-) diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index ac9bce2..cf0ef3b 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -123,42 +123,6 @@ def to(self, device: torch.device): if self._graph_reverse is not None: self._graph_reverse = self._graph_reverse.to(device) -class GSMDataStore(): - """ - A straightforward class designed to manage dictionaries for srdata, dstdata, and edata. - It enables the creation of chained data and allows for rewinding all data simultaneously. - """ - def __init__(self, src_is_tgt: bool, num_src_nodes: int, num_dst_nodes: int, num_edges: int): - self.src_is_tgt = src_is_tgt - - if src_is_tgt: - assert num_src_nodes == num_dst_nodes, "Number of source nodes must be equal to the number of target nodes" - - if src_is_tgt: - self.srcdata = self.dstdata = self.ndata = ChainedDataView(num_src_nodes) - else: - self.srcdata = ChainedDataView(num_src_nodes) - self.dstdata = ChainedDataView(num_dst_nodes) - - self.edata = ChainedDataView(num_edges) - - def chain_data(self): - if self.src_is_tgt: - self.srcdata = self.dstdata = ChainedDataView(self.srcdata.acceptable_size, self.srcdata) - else: - self.srcdata = ChainedDataView(self.srcdata.acceptable_size, self.srcdata) - self.dstdata = ChainedDataView(self.dstdata.acceptable_size, self.dstdata) - - self.edata = ChainedDataView(self.edata.acceptable_size, self.edata) - - def rewind_data(self): - if self.src_is_tgt: - self.srcdata = self.dstdata = self.srcdata.rewind() - else: - self.srcdata = self.srcdata.rewind() - self.dstdata = self.dstdata.rewind() - - self.edata.rewind() class ChainedDataView(MutableMapping): """A dictionary that chains to children dictionary on missed __getitem__ calls""" @@ -267,8 +231,14 @@ def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, loca self.in_degrees_cache: Dict[Optional[str], Tensor] = {} self.out_degrees_cache: Dict[Optional[str], Tensor] = {} - self.datastore = GSMDataStore(self.src_is_tgt, self.num_src_nodes(), - self.num_dst_nodes(), self.num_edges()) + self.srcdata = ChainedDataView(self.num_src_nodes()) + self.edata = ChainedDataView(self.num_edges()) + + if self.src_is_tgt: + assert self.num_src_nodes() == self.num_dst_nodes() + self.dstdata = self.srcdata + else: + self.dstdata = ChainedDataView(self.num_dst_nodes()) self._sampling_graph = None @@ -333,26 +303,26 @@ def update_boundary_nodes_indices(self) -> List[Tensor]: @contextmanager def local_scope(self): - self.datastore.chain_data() + self.srcdata = ChainedDataView( + self.srcdata.acceptable_size, self.srcdata) + self.edata = ChainedDataView(self.edata.acceptable_size, self.edata) + if self.src_is_tgt: + self.dstdata = self.srcdata + else: + self.dstdata = ChainedDataView( + self.dstdata.acceptable_size, self.dstdata) yield - self.datastore.rewind_data() - - @property - def srcdata(self): - return self.datastore.srcdata - - @property - def dstdata(self): - return self.datastore.dstdata - - @property - def edata(self): - return self.datastore.edata + self.srcdata = self.srcdata.rewind() + self.edata = self.edata.rewind() + if self.src_is_tgt: + self.dstdata = self.srcdata + else: + self.dstdata = self.dstdata.rewind() @property def ndata(self): assert self.src_is_tgt, "ndata shouldn't be used with MFGs" - return self.datastore.srcdata + return self.srcdata @property def is_block(self): From 10bd37898e2d02e626c46bd731b1e00dde49606c Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Fri, 16 Jun 2023 06:55:19 +0200 Subject: [PATCH 17/58] Added example script --- examples/correct_and_smooth.py | 506 +++++++++++++++++++++++++++++++++ 1 file changed, 506 insertions(+) create mode 100644 examples/correct_and_smooth.py diff --git a/examples/correct_and_smooth.py b/examples/correct_and_smooth.py new file mode 100644 index 0000000..cfcbb29 --- /dev/null +++ b/examples/correct_and_smooth.py @@ -0,0 +1,506 @@ +# Copyright (c) 2022 Intel Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from argparse import ArgumentParser +import time +import json +import copy +import torch +import torch.nn.functional as F +from torch import nn +import torch.distributed as dist +import dgl # type: ignore +from dgl import function as fn # type: ignore +from dgl.heterograph import DGLBlock # type: ignore +from ogb.nodeproppred import DglNodePropPredDataset + +import sar + + +parser = ArgumentParser(description="CorrectAndSmooth example") + +parser.add_argument('--partitioning-json-file', default='', type=str, + help='Path to the .json file containing partitioning information' +) +parser.add_argument('--ip-file', default='./ip_file', type=str, + help='File with ip-address. Worker 0 creates this file and all others read it' + ) +parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'], + help='Communication backend to use' + ) +parser.add_argument('--cpu-run', action='store_true', + help='Run on CPUs if set, otherwise run on GPUs' + ) +parser.add_argument('--rank', default=0, type=int, + help='Rank of the current worker ' + ) +parser.add_argument('--world-size', default=2, type=int, + help='Number of workers ' + ) +parser.add_argument('--model', default="mlp", type=str, choices=['mlp', 'linear'], + help='Model type' + ) +parser.add_argument('--num-layers', default=3, type=int, + help='Number of layers in the model' + ) +parser.add_argument('--hidden-layer-dim', default=256, type=int, + help='Dimension of GNN hidden layer' + ) +parser.add_argument('--dropout', default=0.4, type=float, + help='Dropout rate for layers in the model' + ) +parser.add_argument('--lr', default=1e-2, type=float, + help='learning rate' + ) +parser.add_argument('--epochs', default=300, type=int, + help='Number of training epochs' + ) +parser.add_argument('--num-correction-layers', default=50, type=int, + help='The number of correct propagations' + ) +parser.add_argument('--correction-alpha', default=0.979, type=float, + help='The coefficient of correction' + ) +parser.add_argument('--correction-adj', default="DAD", type=str, + help='DAD: D^-0.5 * A * D^-0.5 | DA: D^-1 * A | AD: A * D^-1' + ) +parser.add_argument('--num-smoothing-layers', default=50, type=int, + help='The number of smooth propagations' + ) +parser.add_argument('--smoothing-alpha', default=0.756, type=float, + help='The coefficient of smoothing' + ) +parser.add_argument('--smoothing-adj', default="DAD", type=str, + help='DAD: D^-0.5 * A * D^-0.5 | DA: D^-1 * A | AD: A * D^-1' + ) +parser.add_argument('--autoscale', action="store_true", + help='Automatically determine the scaling factor for "sigma"' + ) +parser.add_argument('--scale', default=20.0, type=float, + help='The scaling factor for "sigma", in case autoscale is set to False' + ) + + +class MLPLinear(nn.Module): + def __init__(self, in_dim, out_dim): + super(MLPLinear, self).__init__() + self.linear = nn.Linear(in_dim, out_dim) + self.reset_parameters() + + def reset_parameters(self): + self.linear.reset_parameters() + + def forward(self, x): + return F.log_softmax(self.linear(x), dim=-1) + + +class MLP(nn.Module): + def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.0): + super(MLP, self).__init__() + assert num_layers >= 2 + + self.linears = nn.ModuleList() + self.bns = nn.ModuleList() + self.linears.append(nn.Linear(in_dim, hid_dim)) + self.bns.append(nn.BatchNorm1d(hid_dim)) + + for _ in range(num_layers - 2): + self.linears.append(nn.Linear(hid_dim, hid_dim)) + self.bns.append(nn.BatchNorm1d(hid_dim)) + + self.linears.append(nn.Linear(hid_dim, out_dim)) + self.dropout = dropout + self.reset_parameters() + + def reset_parameters(self): + for layer in self.linears: + layer.reset_parameters() + for layer in self.bns: + layer.reset_parameters() + + def forward(self, x): + for linear, bn in zip(self.linears[:-1], self.bns): + x = linear(x) + x = F.relu(x, inplace=True) + x = bn(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.linears[-1](x) + return F.log_softmax(x, dim=-1) + + +class LabelPropagation(nn.Module): + r""" + + Description + ----------- + Introduced in `Learning from Labeled and Unlabeled Data with Label Propagation `_ + + .. math:: + \mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A} + \mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y}, + + where unlabeled data is inferred by labeled data via propagation. + + Parameters + ---------- + num_layers: int + The number of propagations. + alpha: float + The :math:`\alpha` coefficient. + adj: str + 'DAD': D^-0.5 * A * D^-0.5 + 'DA': D^-1 * A + 'AD': A * D^-1 + """ + + def __init__(self, num_layers, alpha, adj="DAD"): + super(LabelPropagation, self).__init__() + + self.num_layers = num_layers + self.alpha = alpha + self.adj = adj + + @torch.no_grad() + def forward( + self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0) + ): + with g.local_scope(): + if labels.dtype == torch.long: + labels = F.one_hot(labels.view(-1)).to(torch.float32) + + y = labels + if mask is not None: + y = torch.zeros_like(labels) + y[mask] = labels[mask] + + last = (1 - self.alpha) * y + degs = g.in_degrees().float().clamp(min=1) + norm = ( + torch.pow(degs, -0.5 if self.adj == "DAD" else -1) + .to(labels.device) + .unsqueeze(1) + ) + + for _ in range(self.num_layers): + # Assume the graphs to be undirected + if self.adj in ["DAD", "AD"]: + y = norm * y + + g.srcdata["h"] = y + g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h")) + y = self.alpha * g.dstdata["h"] + + if self.adj in ["DAD", "DA"]: + y = y * norm + + y = post_step(last + y) + + return y + + +class CorrectAndSmooth(nn.Module): + r""" + + Description + ----------- + Introduced in `Combining Label Propagation and Simple Models Out-performs Graph Neural Networks `_ + + Parameters + ---------- + num_correction_layers: int + The number of correct propagations. + correction_alpha: float + The coefficient of correction. + correction_adj: str + 'DAD': D^-0.5 * A * D^-0.5 + 'DA': D^-1 * A + 'AD': A * D^-1 + num_smoothing_layers: int + The number of smooth propagations. + smoothing_alpha: float + The coefficient of smoothing. + smoothing_adj: str + 'DAD': D^-0.5 * A * D^-0.5 + 'DA': D^-1 * A + 'AD': A * D^-1 + autoscale: bool, optional + If set to True, will automatically determine the scaling factor :math:`\sigma`. Default is True. + scale: float, optional + The scaling factor :math:`\sigma`, in case :obj:`autoscale = False`. Default is 1. + """ + + def __init__( + self, + num_correction_layers, + correction_alpha, + correction_adj, + num_smoothing_layers, + smoothing_alpha, + smoothing_adj, + autoscale=True, + scale=1.0, + ): + super(CorrectAndSmooth, self).__init__() + + self.autoscale = autoscale + self.scale = scale + + self.prop1 = LabelPropagation( + num_correction_layers, correction_alpha, correction_adj + ) + self.prop2 = LabelPropagation( + num_smoothing_layers, smoothing_alpha, smoothing_adj + ) + + def correct(self, g, y_soft, y_true, mask): + with g.local_scope(): + assert abs(float(y_soft.sum()) / y_soft.size(0) - 1.0) < 1e-2 + numel = ( + int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) + ) + assert y_true.size(0) == numel + + if y_true.dtype == torch.long: + y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to( + y_soft.dtype + ) + + error = torch.zeros_like(y_soft) + error[mask] = y_true - y_soft[mask] + + if self.autoscale: + smoothed_error = self.prop1( + g, error, post_step=lambda x: x.clamp_(-1.0, 1.0) + ) + sigma = error[mask].abs().sum() / numel + scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True) + scale[scale.isinf() | (scale > 1000)] = 1.0 + + result = y_soft + scale * smoothed_error + result[result.isnan()] = y_soft[result.isnan()] + return result + else: + + def fix_input(x): + x[mask] = error[mask] + return x + + smoothed_error = self.prop1(g, error, post_step=fix_input) + + result = y_soft + self.scale * smoothed_error + result[result.isnan()] = y_soft[result.isnan()] + return result + + def smooth(self, g, y_soft, y_true, mask): + with g.local_scope(): + numel = ( + int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) + ) + assert y_true.size(0) == numel + + if y_true.dtype == torch.long: + y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to( + y_soft.dtype + ) + + y_soft[mask] = y_true + return self.prop2(g, y_soft) + + +def evaluate(logits, labels, masks): + results = [] + for indices_name in ['train_indices', 'val_indices', 'test_indices']: + n_correct = (logits[masks[indices_name]].argmax(1) == + labels[masks[indices_name]]).float().sum() + results.extend([n_correct, masks[indices_name].numel()]) + + acc_vec = torch.FloatTensor(results) + # Sum the n_correct, and number of mask elements across all workers + sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True) + (train_acc, val_acc, test_acc) = \ + (acc_vec[0] / acc_vec[1], + acc_vec[2] / acc_vec[3], + acc_vec[4] / acc_vec[5]) + return train_acc, val_acc, test_acc + + +def main(): + args = parser.parse_args() + print(args) + use_gpu = torch.cuda.is_available() and not args.cpu_run + device = torch.device('cuda' if use_gpu else 'cpu') + + # Obtain the ip address of the master through the network file system + master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) + sar.initialize_comms(args.rank, + args.world_size, master_ip_address, + args.backend) + + with open(args.partitioning_json_file, 'r') as f: + data = json.load(f) + dataset_name = data["graph_name"] + + # Load DGL partition data + partition_data = sar.load_dgl_partition_data( + args.partitioning_json_file, args.rank, device) + + # Obtain train,validation, and test masks + # These are stored as node features. Partitioning may prepend + # the node type to the mask names. So we use the convenience function + # suffix_key_lookup to look up the mask name while ignoring the + # arbitrary node type + masks = {} + for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'], + ['train_indices', 'val_indices', 'test_indices']): + boolean_mask = sar.suffix_key_lookup(partition_data.node_features, + mask_name) + masks[indices_name] = boolean_mask.nonzero( + as_tuple=False).view(-1).to(device) + + labels = sar.suffix_key_lookup(partition_data.node_features, + 'labels').long().to(device) + + # Obtain the number of classes by finding the max label across all workers + num_labels = labels.max() + 1 + sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True) + num_labels = num_labels.item() + + features = sar.suffix_key_lookup(partition_data.node_features, 'features').to(device) + full_graph_manager = sar.construct_full_graph(partition_data).to(device) + + + if dataset_name == "ogbn-arxiv": + # in order to perform dataset standarization, we have to calculate dataset's mean + # and standard deviation. It is not possible without communication between workers + local_means = features.mean(0) + workers_means = sar.comm.exchange_tensors([local_means] * args.world_size) + workers_means = torch.stack(workers_means, dim=0) + + local_feature_size = features.shape[0] + workers_feature_sizes = sar.comm.exchange_tensors([torch.tensor([local_feature_size])] * args.world_size) + workers_feature_sizes = torch.stack(workers_feature_sizes, dim=0) + + global_features_sum = torch.mul(workers_means, workers_feature_sizes).sum(dim=0) + global_feature_size = workers_feature_sizes.sum() + global_means = global_features_sum / global_feature_size + + local_std_numerator = torch.pow(features - global_means, 2).sum(dim=0) + workers_std_numerators = sar.comm.exchange_tensors([local_std_numerator] * args.world_size) + workers_std_numerators = torch.stack(workers_std_numerators, dim=0) + global_stds = torch.sqrt(workers_std_numerators.sum(dim=0) / global_feature_size) + + features = (features - global_means) / global_stds + + # We do not need the partition data anymore + del partition_data + + # load model + if args.model == "mlp": + model = MLP( + features.size(1), args.hidden_layer_dim, num_labels, args.num_layers, args.dropout + ).to(device) + elif args.model == "linear": + model = MLPLinear(features.size(1), num_labels).to(device) + else: + raise NotImplementedError(f"Model {args.model} is not supported.") + + print('model', model) + sar.sync_params(model) + + # Obtain the number of labeled nodes in the training + # This will be needed to properly obtain a cross entropy loss + # normalized by the number of training examples + n_train_points = torch.LongTensor([masks['train_indices'].numel()]) + sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True) + n_train_points = n_train_points.item() + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + best_acc = 0 + best_model = copy.deepcopy(model) + + # training + print("---------- Training ----------") + for epoch in range(args.epochs): + t_1 = time.time() + model.train() + + logits = model(features) + loss = F.nll_loss(logits[masks['train_indices']], + labels[masks['train_indices']], reduction='sum')/n_train_points + optimizer.zero_grad() + loss.backward() + # Do not forget to gather the parameter gradients from all workers + sar.gather_grads(model) + optimizer.step() + train_time = time.time() - t_1 + + model.eval() + with torch.no_grad(): + logits = model(features) + train_acc, val_acc, _ = evaluate(logits, labels, masks) + + result_message = ( + f"iteration [{epoch}/{args.epochs}] | " + ) + result_message += ', '.join([ + f"train loss={loss:.4f}, " + f"Accuracy: " + f"train={train_acc:.4f} " + f"valid={val_acc:.4f} " + f" | train time = {train_time} " + f" |" + ]) + print(result_message, flush=True) + + if val_acc > best_acc: + best_acc = val_acc + best_model = copy.deepcopy(model) + + # testing & saving model + print("---------- Testing ----------") + best_model.eval() + logits = best_model(features) + _, _, test_acc = evaluate(logits, labels, masks) + print(f"Test acc: {test_acc:.4f}") + + print("---------- Correct & Smoothing ----------") + y_soft = model(features).exp() + + cs = CorrectAndSmooth( + num_correction_layers=args.num_correction_layers, + correction_alpha=args.correction_alpha, + correction_adj=args.correction_adj, + num_smoothing_layers=args.num_smoothing_layers, + smoothing_alpha=args.smoothing_alpha, + smoothing_adj=args.smoothing_adj, + autoscale=args.autoscale, + scale=args.scale, + ) + + y_soft = cs.correct(full_graph_manager, y_soft, labels[masks['train_indices']], masks['train_indices']) + y_soft = cs.smooth(full_graph_manager, y_soft, labels[masks['train_indices']], masks['train_indices']) + _, _, test_acc = evaluate(y_soft, labels, masks) + print(f"Test acc: {test_acc:.4f}") + + +if __name__ == '__main__': + main() From 209ac3480301d1e97238b8ade39a48b9f7ae1eb4 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Fri, 16 Jun 2023 08:16:57 +0200 Subject: [PATCH 18/58] Remove unnecessary code --- sar/core/graphshard.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index f819cff..a6716a6 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -82,13 +82,7 @@ def __init__(self, self.unique_tgt_nodes, unique_tgt_nodes_inverse = \ torch.unique(shard_edges_features.edges[1], return_inverse=True) - edges_src_nodes = torch.arange(self.unique_src_nodes.size(0))[ - unique_src_nodes_inverse] - - edges_tgt_nodes = torch.arange(self.unique_tgt_nodes.size(0))[ - unique_tgt_nodes_inverse] - - self.graph = dgl.create_block((edges_src_nodes, edges_tgt_nodes), + self.graph = dgl.create_block((unique_src_nodes_inverse, unique_tgt_nodes_inverse), num_src_nodes=self.unique_src_nodes.size( 0), num_dst_nodes=self.unique_tgt_nodes.size( From 9bc9cc1c322400f220e59b5c6ee4b8a635ac9d14 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 20 Jun 2023 09:25:38 +0200 Subject: [PATCH 19/58] Mentioned C&S in README.md --- examples/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/README.md b/examples/README.md index f60bad1..1af72c8 100644 --- a/examples/README.md +++ b/examples/README.md @@ -29,3 +29,16 @@ python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/pa python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 ``` + +## Correct and Smooth +Example taken from [DGL implemenetation](https://github.com/dmlc/dgl/tree/master/examples/pytorch/correct_and_smooth) of C&S. Code is adjusted to perform distributed training with SAR. For instance, you can run the example with following commands: + +* **Plain MLP + C&S** +```shell +python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale +``` + +* **Plain Linear + C&S** +```shell +python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale +``` From 5d69d147f1b3f1a07c6e67402c95557f047a7bd1 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 20 Jun 2023 10:35:33 +0200 Subject: [PATCH 20/58] Refactored features normalization code --- examples/correct_and_smooth.py | 63 ++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/examples/correct_and_smooth.py b/examples/correct_and_smooth.py index cfcbb29..9525b29 100644 --- a/examples/correct_and_smooth.py +++ b/examples/correct_and_smooth.py @@ -342,6 +342,41 @@ def evaluate(logits, labels, masks): return train_acc, val_acc, test_acc +def data_normalization(features, world_size): + """ + Perform features normzalization by subtracting theur means and dividing them by their standard deviations. + Each position in features vector is normzalized independently. To calculate means and stds over whole + dataset, workers must communicate with each other. + + + :param features: dataset's features + :type features: Tensor + :param world_size: Number of workers. The same as the number of graph partitions + :type world_size: int + + :returns: Normalized Tensor of features + """ + local_means = features.mean(0) + workers_means = sar.comm.exchange_tensors([local_means] * world_size) + workers_means = torch.stack(workers_means, dim=0) + + local_feature_size = features.shape[0] + workers_feature_sizes = sar.comm.exchange_tensors([torch.tensor([local_feature_size])] * world_size) + workers_feature_sizes = torch.stack(workers_feature_sizes, dim=0) + + global_features_sum = torch.mul(workers_means, workers_feature_sizes).sum(dim=0) + global_feature_size = workers_feature_sizes.sum() + global_means = global_features_sum / global_feature_size + + local_std_numerator = torch.pow(features - global_means, 2).sum(dim=0) + workers_std_numerators = sar.comm.exchange_tensors([local_std_numerator] * world_size) + workers_std_numerators = torch.stack(workers_std_numerators, dim=0) + global_stds = torch.sqrt(workers_std_numerators.sum(dim=0) / global_feature_size) + + features = (features - global_means) / global_stds + return features + + def main(): args = parser.parse_args() print(args) @@ -354,10 +389,6 @@ def main(): args.world_size, master_ip_address, args.backend) - with open(args.partitioning_json_file, 'r') as f: - data = json.load(f) - dataset_name = data["graph_name"] - # Load DGL partition data partition_data = sar.load_dgl_partition_data( args.partitioning_json_file, args.rank, device) @@ -386,28 +417,8 @@ def main(): features = sar.suffix_key_lookup(partition_data.node_features, 'features').to(device) full_graph_manager = sar.construct_full_graph(partition_data).to(device) - - if dataset_name == "ogbn-arxiv": - # in order to perform dataset standarization, we have to calculate dataset's mean - # and standard deviation. It is not possible without communication between workers - local_means = features.mean(0) - workers_means = sar.comm.exchange_tensors([local_means] * args.world_size) - workers_means = torch.stack(workers_means, dim=0) - - local_feature_size = features.shape[0] - workers_feature_sizes = sar.comm.exchange_tensors([torch.tensor([local_feature_size])] * args.world_size) - workers_feature_sizes = torch.stack(workers_feature_sizes, dim=0) - - global_features_sum = torch.mul(workers_means, workers_feature_sizes).sum(dim=0) - global_feature_size = workers_feature_sizes.sum() - global_means = global_features_sum / global_feature_size - - local_std_numerator = torch.pow(features - global_means, 2).sum(dim=0) - workers_std_numerators = sar.comm.exchange_tensors([local_std_numerator] * args.world_size) - workers_std_numerators = torch.stack(workers_std_numerators, dim=0) - global_stds = torch.sqrt(workers_std_numerators.sum(dim=0) / global_feature_size) - - features = (features - global_means) / global_stds + if "ogbn-arxiv" in args.partitioning_json_file: + features = data_normalization(features, args.world_size) # We do not need the partition data anymore del partition_data From 02e1bf706f188e3be74a7e9ae3623f189b20913f Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 20 Jun 2023 10:37:07 +0200 Subject: [PATCH 21/58] Delete MIT license notice --- examples/correct_and_smooth.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/examples/correct_and_smooth.py b/examples/correct_and_smooth.py index 9525b29..fdafc1c 100644 --- a/examples/correct_and_smooth.py +++ b/examples/correct_and_smooth.py @@ -1,24 +1,3 @@ -# Copyright (c) 2022 Intel Corporation - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - - from argparse import ArgumentParser import time import json From cc8f3f2357db41cfffbf7f7cfac33979fc1b1806 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 20 Jun 2023 10:48:27 +0200 Subject: [PATCH 22/58] Added docstring for evaluate function --- examples/correct_and_smooth.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/correct_and_smooth.py b/examples/correct_and_smooth.py index fdafc1c..ef42eb6 100644 --- a/examples/correct_and_smooth.py +++ b/examples/correct_and_smooth.py @@ -305,6 +305,20 @@ def smooth(self, g, y_soft, y_true, mask): def evaluate(logits, labels, masks): + """ + Calculating accuracy metric over train, validation and test indices (in a distributed way). + + :param logits: Predictions of the model + :type logits: Tensor + :param labels: Ground truth labels + :type labels: Tensor + :param masks: Dictionary of Tensors, that contain indices for train, validation and test sets + :type masks: Dictionary + + :returns: Tuple of accuracy metrics: train, validation, test + """ + import pdb + pdb.set_trace() results = [] for indices_name in ['train_indices', 'val_indices', 'test_indices']: n_correct = (logits[masks[indices_name]].argmax(1) == @@ -328,7 +342,7 @@ def data_normalization(features, world_size): dataset, workers must communicate with each other. - :param features: dataset's features + :param features: Dataset's features :type features: Tensor :param world_size: Number of workers. The same as the number of graph partitions :type world_size: int From b885131d33037bf951bf9b1038dc328df025a270 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Wed, 21 Jun 2023 11:44:20 +0200 Subject: [PATCH 23/58] Env support for a socket and updated docs --- docs/source/comm.rst | 2 ++ sar/comm.py | 84 +++++++++++++++++++++----------------------- 2 files changed, 42 insertions(+), 44 deletions(-) diff --git a/docs/source/comm.rst b/docs/source/comm.rst index dc074b4..b628d64 100644 --- a/docs/source/comm.rst +++ b/docs/source/comm.rst @@ -24,6 +24,8 @@ In an environment with a networked file system, initializing ``torch.distributed :func:`sar.nfs_ip_init` communicates the master's ip address to the workers through the file system. In the absence of a networked file system, you should develop your own mechanism to communicate the master's ip address. +You can specify the name of the socket that will be used for communication with ``SOCKET_NAME`` environment variable (if not specified, the first available socket will be selected). + Relevant methods diff --git a/sar/comm.py b/sar/comm.py index 7421353..c3f8ba1 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -40,10 +40,30 @@ logger.setLevel(logging.DEBUG) +def get_socket() -> ifaddr._shared.Adapter: + """ + Gets the socket on the current host. If preffered socket is not specified using SOCKET_NAME + environment variable, the function returns the first available socket from `ifaddr.get_adapters()` + + :returns: Preffered or the first available socket + """ + adaps = ifaddr.get_adapters() + preferred_socket = os.environ.get("SOCKET_NAME") + if preferred_socket is not None: + adaps = list(filter(lambda x: x.nice_name == preferred_socket, adaps)) + if not adaps: + raise ValueError(f'Socket with given name: "{preferred_socket}" was not found.') + else: + adaps = list(filter(lambda x: x.nice_name != "lo", adaps)) + return adaps[0] + + def get_ip_address(ip_file: str) -> str: - ''' + """ Reads ip address from ip_file. Blocks until the file is created - ''' + + :returns: IP address + """ while True: while not os.path.isfile(ip_file): logger.info('waiting for ip file to be created') @@ -56,45 +76,18 @@ def get_ip_address(ip_file: str) -> str: return ip_addr -def get_socket_name() -> str: - ''' - Gets the socket name on the current host. Prefers Infiniband sockets - if multiple sockets exist - ''' - adaps = ifaddr.get_adapters() - ib_adapters = [x for x in adaps if 'eib' in x.nice_name] - if ib_adapters: - logger.info(f'getting socket name for ib adapter: {ib_adapters[0]}') - sock_name = ib_adapters[0].nice_name - else: - eth_adapters = [ - x for x in adaps if 'eth' in x.nice_name or 'enp' in x.nice_name] - logger.info( - f'getting socket name for ethernet adapter: {eth_adapters[0]}') - sock_name = eth_adapters[0].nice_name - return sock_name - - def dump_ip_address(ip_file: str) -> str: - """Dumps the ip address of the current host to a file - Prioritizes finding an infiniband adapter and dumping its address. + """ + Dumps the ip address of the current host to a file :param ip_file: File name where the ip address of the local host will be dumped :type ip_file: str + :returns: A string containing the ip address of the local host - """ - - adaps = ifaddr.get_adapters() - ib_adapters = [x for x in adaps if 'eib' in x.nice_name] - if ib_adapters: - logger.info(f'found infinity band adapter: {ib_adapters[0]}') - host_ip = ib_adapters[0].ips[0].ip - else: - eth_adapters = [ - x for x in adaps if 'eth' in x.nice_name or 'enp' in x.nice_name] - logger.info(f'using ethernet adapter: {eth_adapters}') - host_ip = eth_adapters[0].ips[0].ip + adap = get_socket() + print(f"SOCKET: {adap}") + host_ip = adap.ips[0].ip with open(ip_file, 'w', encoding='utf-8') as f_handle: f_handle.write(host_ip) logger.info(f'wrote ip {host_ip} to file {ip_file}') @@ -121,10 +114,12 @@ def nfs_ip_init(_rank: int, ip_file: str) -> str: :param _rank: Rank of the current machine :type _rank: int - :param ip_file: Path to the ip file that will be used to communicate the ip address between workers. The master will write its ip address to this file. Other workers will block until this file is created, and then read the ip address from it. + :param ip_file: Path to the ip file that will be used to communicate the ip address between workers.\ + The master will write its ip address to this file. Other workers will block until\ + this file is created, and then read the ip address from it. :type ip_file: str + :returns: A string with the ip address of the master machine/worker - """ if _rank == 0: master_ip = dump_ip_address(ip_file) @@ -173,14 +168,15 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, os.environ['MASTER_ADDR'] = master_ip_address os.environ['MASTER_PORT'] = str(master_port_number) - sock_name = get_socket_name() - os.environ['TP_SOCKET_IFNAME'] = sock_name - os.environ['GLOO_SOCKET_IFNAME'] = sock_name - os.environ['CCL_SOCKET_IFNAME'] = sock_name - os.environ['NCCL_SOCKET_IFNAME'] = sock_name + socket = get_socket() + print(f"SOCKET NAME: {socket.nice_name}") + os.environ['TP_SOCKET_IFNAME'] = socket.nice_name + os.environ['GLOO_SOCKET_IFNAME'] = socket.nice_name + os.environ['CCL_SOCKET_IFNAME'] = socket.nice_name + os.environ['NCCL_SOCKET_IFNAME'] = socket.nice_name - os.environ['FI_VERBS_IFACE'] = sock_name - os.environ['FI_mlx_IFACE'] = sock_name + os.environ['FI_VERBS_IFACE'] = socket.nice_name + os.environ['FI_mlx_IFACE'] = socket.nice_name os.environ['MPI_COMM_WORLD'] = str(_world_size) os.environ['MPI_COMM_RANK'] = str(_rank) From 2d7cecbc88e72b8ef26fc6bc85aa2c76f43eacac Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Wed, 21 Jun 2023 11:45:52 +0200 Subject: [PATCH 24/58] Removed debug --- sar/comm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sar/comm.py b/sar/comm.py index c3f8ba1..e1544ee 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -86,7 +86,6 @@ def dump_ip_address(ip_file: str) -> str: :returns: A string containing the ip address of the local host """ adap = get_socket() - print(f"SOCKET: {adap}") host_ip = adap.ips[0].ip with open(ip_file, 'w', encoding='utf-8') as f_handle: f_handle.write(host_ip) @@ -169,7 +168,6 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, os.environ['MASTER_PORT'] = str(master_port_number) socket = get_socket() - print(f"SOCKET NAME: {socket.nice_name}") os.environ['TP_SOCKET_IFNAME'] = socket.nice_name os.environ['GLOO_SOCKET_IFNAME'] = socket.nice_name os.environ['CCL_SOCKET_IFNAME'] = socket.nice_name From 3020bc23714e8aee08c453dfe4e41fd2a06f94a0 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 29 Jun 2023 12:17:52 +0200 Subject: [PATCH 25/58] Fix raising error --- sar/comm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sar/comm.py b/sar/comm.py index 92942ca..e9d38d5 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -176,7 +176,7 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, try: import torch_ccl # type: ignore except: - raise "None of the oneccl_bindings_for_pytorch and torch_ccl package has been found" + raise ImportError("None of the oneccl_bindings_for_pytorch and torch_ccl package has been found") if not dist.is_initialized(): os.environ['MASTER_ADDR'] = master_ip_address From 5663d257aa076e3dd0e5b47e38c1de834e8563a9 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 29 Jun 2023 13:21:40 +0200 Subject: [PATCH 26/58] Add maintainers --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 90f64d2..697ba43 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,8 @@ packages=find_packages(), author='Hesham Mostafa', author_email='hesham.mostafa@intel.com', + maintainer='Bartlomiej Gawrych, Kacper Pietkun', + maintainer_email='gawrych.bartlomiej@gmail.com, kacper.pietkun@intel.com', description='A Python library for distributed training of Graph Neural Networks (GNNs) on large graphs, ' 'supporting both full-batch and sampling-based training, and utilizing a sequential aggregation' 'and rematerialization technique for linear memory scaling.', From b4d2a68757c41f8dbe46dfed84934ae9e114548a Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 29 Jun 2023 16:19:07 +0200 Subject: [PATCH 27/58] Apply review comments --- tests/models.py | 2 +- tests/test_patch_dgl.py | 4 ++++ tests/test_sar.py | 29 ++++++++++++++++++----------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/tests/models.py b/tests/models.py index e864a08..462bf1f 100644 --- a/tests/models.py +++ b/tests/models.py @@ -7,7 +7,7 @@ def __init__(self, in_dim: int, hidden_dim: int, out_dim: int): super().__init__() self.convs = nn.ModuleList([ - dgl.nn.GraphConv(in_dim, out_dim, weight=False, bias=False), + dgl.nn.GraphConv(in_dim, out_dim, weight=True, bias=False), ]) def forward(self, graph, features): diff --git a/tests/test_patch_dgl.py b/tests/test_patch_dgl.py index f3c0ba5..f44df3f 100644 --- a/tests/test_patch_dgl.py +++ b/tests/test_patch_dgl.py @@ -4,6 +4,10 @@ @sar_test def test_patch_dgl(): + """ + Import DGL library and SAR and check whether `patch_dgl` function + overrides edge_softmax function in specific GNN layers implementation. + """ import dgl original_gat_edge_softmax = dgl.nn.pytorch.conv.gatconv.edge_softmax original_dotgat_edge_softmax = dgl.nn.pytorch.conv.dotgatconv.edge_softmax diff --git a/tests/test_sar.py b/tests/test_sar.py index ea1d155..371582c 100644 --- a/tests/test_sar.py +++ b/tests/test_sar.py @@ -11,7 +11,7 @@ def sar_process(mp_dict, rank, world_size, tmp_dir): This function should be an entry point to the 'independent' process. It has to simulate behaviour of SAR which will be spawned across different machines independently from other instances. Each process have individual memory space - so it is suitable environment for testing SAR + so it is suitable environment for testing SAR. """ import dgl import torch @@ -46,8 +46,6 @@ def sar_process(mp_dict, rank, world_size, tmp_dir): master_ip_address, 'ccl') - torch.distributed.barrier() # wait for rank 0 to finish graph creation - partition_data = sar.load_dgl_partition_data( part_file, rank, 'cpu') @@ -77,6 +75,12 @@ def sar_process(mp_dict, rank, world_size, tmp_dir): @pytest.mark.parametrize('world_size', [2, 4]) @sar_test def test_sar_full_graph(world_size): + """ + Partition graph into `world_size` partitions and run `world_size` + processes which perform full graph inference using SAR algorithm. + Test is comparing mean of concatenated results from all processes + with mean of native DGL full graph inference result. + """ print(world_size) with tempfile.TemporaryDirectory() as tmpdir: manager = mp.Manager() @@ -97,7 +101,7 @@ def test_sar_full_graph(world_size): if 'exception' in mp_dict: handle_mp_exception(mp_dict) - out = model(graph, graph.ndata['features']).numpy() + out = model(graph, graph.ndata['features']).detach().numpy() # compare mean of all values instead of each node feature individually # TODO: reorder SAR calculated logits to original NID mapping @@ -113,11 +117,13 @@ def test_sar_full_graph(world_size): rtol = sar_logits_mean / 1000 assert full_graph_mean == pytest.approx(sar_logits_mean, rtol) - - - @sar_test def test_convert_dist_graph(): + """ + Create DGL's DistGraph object with random graph partitioned into + one part (only way to test DistGraph locally). Then perform converting + DistGraph into SAR GraphShardManager and check relevant properties. + """ with tempfile.TemporaryDirectory() as tmpdir: import dgl import torch @@ -137,16 +143,17 @@ def test_convert_dist_graph(): num_hops=1, balance_edges=True) - master_ip_address = sar.nfs_ip_init(0, ip_file) - sar.initialize_comms(0, 1, master_ip_address, 'ccl') + master_ip_address = sar.nfs_ip_init(_rank=0, ip_file=ip_file) + sar.initialize_comms(_rank=0, _world_size=1, + master_ip_address=master_ip_address, backend='ccl') dgl.distributed.initialize("kv_ip_config.txt") dist_g = dgl.distributed.DistGraph( graph_name, part_config=part_file) - + sar_g = sar.convert_dist_graph(dist_g) print(sar_g.graph_shards[0].graph.ndata) assert len(sar_g.graph_shards) == dist_g.get_partition_book().num_partitions() assert dist_g.num_edges() == sar_g.num_edges() # this check fails (1000 != 2000) - #assert dist_g.num_nodes() == sar_g.num_nodes() \ No newline at end of file + #assert dist_g.num_nodes() == sar_g.num_nodes() From acb54ad724adc2e18e4745aa0f0021925b909042 Mon Sep 17 00:00:00 2001 From: Kacper Pietkun Date: Mon, 3 Jul 2023 09:02:51 +0200 Subject: [PATCH 28/58] Update docs/source/comm.rst Co-authored-by: bgawrych --- docs/source/comm.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/comm.rst b/docs/source/comm.rst index b628d64..e77b1ad 100644 --- a/docs/source/comm.rst +++ b/docs/source/comm.rst @@ -24,7 +24,7 @@ In an environment with a networked file system, initializing ``torch.distributed :func:`sar.nfs_ip_init` communicates the master's ip address to the workers through the file system. In the absence of a networked file system, you should develop your own mechanism to communicate the master's ip address. -You can specify the name of the socket that will be used for communication with ``SOCKET_NAME`` environment variable (if not specified, the first available socket will be selected). +You can specify the name of the socket that will be used for communication with `SAR_SOCKET_NAME` environment variable (if not specified, the first available socket will be selected). From dc174e4ea5e07f10fe8f87d9b9fed1dbab35e126 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Mon, 3 Jul 2023 09:15:30 +0200 Subject: [PATCH 29/58] Change env variable name to SAR_SOCKET_NAME --- sar/comm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sar/comm.py b/sar/comm.py index e1544ee..d90efd2 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -42,13 +42,13 @@ def get_socket() -> ifaddr._shared.Adapter: """ - Gets the socket on the current host. If preffered socket is not specified using SOCKET_NAME + Gets the socket on the current host. If preffered socket is not specified using SAR_SOCKET_NAME environment variable, the function returns the first available socket from `ifaddr.get_adapters()` :returns: Preffered or the first available socket """ adaps = ifaddr.get_adapters() - preferred_socket = os.environ.get("SOCKET_NAME") + preferred_socket = os.environ.get("SAR_SOCKET_NAME") if preferred_socket is not None: adaps = list(filter(lambda x: x.nice_name == preferred_socket, adaps)) if not adaps: From 87590fbc7e90e1da943762e9e7dadab9b2ffbbaa Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Mon, 3 Jul 2023 09:42:40 +0200 Subject: [PATCH 30/58] Added named tuple for socket information --- sar/comm.py | 9 +++++---- sar/common_tuples.py | 5 +++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sar/comm.py b/sar/comm.py index d90efd2..662a3df 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -34,13 +34,14 @@ import torch.distributed as dist from torch import Tensor from .config import Config +from .common_tuples import SocketInfo logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) logger.setLevel(logging.DEBUG) -def get_socket() -> ifaddr._shared.Adapter: +def get_socket() -> SocketInfo: """ Gets the socket on the current host. If preffered socket is not specified using SAR_SOCKET_NAME environment variable, the function returns the first available socket from `ifaddr.get_adapters()` @@ -55,7 +56,7 @@ def get_socket() -> ifaddr._shared.Adapter: raise ValueError(f'Socket with given name: "{preferred_socket}" was not found.') else: adaps = list(filter(lambda x: x.nice_name != "lo", adaps)) - return adaps[0] + return SocketInfo(adaps[0].nice_name, adaps[0].ips[0].ip) def get_ip_address(ip_file: str) -> str: @@ -85,8 +86,8 @@ def dump_ip_address(ip_file: str) -> str: :returns: A string containing the ip address of the local host """ - adap = get_socket() - host_ip = adap.ips[0].ip + scoket = get_socket() + host_ip = scoket.ip_addr with open(ip_file, 'w', encoding='utf-8') as f_handle: f_handle.write(host_ip) logger.info(f'wrote ip {host_ip} to file {ip_file}') diff --git a/sar/common_tuples.py b/sar/common_tuples.py index b4acd47..2a71892 100644 --- a/sar/common_tuples.py +++ b/sar/common_tuples.py @@ -121,3 +121,8 @@ class ShardInfo(NamedTuple): src_node_range: Tuple[int, int] tgt_node_range: Tuple[int, int] edge_range: Tuple[int, int] + + +class SocketInfo(NamedTuple): + nice_name: str + ip_addr: str From 3df3c17c3f8cbe317cd264b3da2394053206939c Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 4 Jul 2023 11:53:20 +0200 Subject: [PATCH 31/58] Modify example to use MeanOp, VarOp and DistributedBN1D --- examples/correct_and_smooth.py | 36 ++++++++++------------------------ sar/distributed_bn.py | 7 +++++++ 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/examples/correct_and_smooth.py b/examples/correct_and_smooth.py index ef42eb6..3b33261 100644 --- a/examples/correct_and_smooth.py +++ b/examples/correct_and_smooth.py @@ -12,6 +12,7 @@ from ogb.nodeproppred import DglNodePropPredDataset import sar +from sar.distributed_bn import MeanOp, VarOp, DistributedBN1D parser = ArgumentParser(description="CorrectAndSmooth example") @@ -99,11 +100,11 @@ def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.0): self.linears = nn.ModuleList() self.bns = nn.ModuleList() self.linears.append(nn.Linear(in_dim, hid_dim)) - self.bns.append(nn.BatchNorm1d(hid_dim)) + self.bns.append(DistributedBN1D(hid_dim)) for _ in range(num_layers - 2): self.linears.append(nn.Linear(hid_dim, hid_dim)) - self.bns.append(nn.BatchNorm1d(hid_dim)) + self.bns.append(DistributedBN1D(hid_dim)) self.linears.append(nn.Linear(hid_dim, out_dim)) self.dropout = dropout @@ -317,8 +318,6 @@ def evaluate(logits, labels, masks): :returns: Tuple of accuracy metrics: train, validation, test """ - import pdb - pdb.set_trace() results = [] for indices_name in ['train_indices', 'val_indices', 'test_indices']: n_correct = (logits[masks[indices_name]].argmax(1) == @@ -335,38 +334,23 @@ def evaluate(logits, labels, masks): return train_acc, val_acc, test_acc -def data_normalization(features, world_size): +def data_normalization(features, eps=1.0e-5): """ Perform features normzalization by subtracting theur means and dividing them by their standard deviations. Each position in features vector is normzalized independently. To calculate means and stds over whole dataset, workers must communicate with each other. - :param features: Dataset's features :type features: Tensor - :param world_size: Number of workers. The same as the number of graph partitions - :type world_size: int + :param eps: a value added to the variance for numerical stability + :type eps: float :returns: Normalized Tensor of features """ - local_means = features.mean(0) - workers_means = sar.comm.exchange_tensors([local_means] * world_size) - workers_means = torch.stack(workers_means, dim=0) - - local_feature_size = features.shape[0] - workers_feature_sizes = sar.comm.exchange_tensors([torch.tensor([local_feature_size])] * world_size) - workers_feature_sizes = torch.stack(workers_feature_sizes, dim=0) - - global_features_sum = torch.mul(workers_means, workers_feature_sizes).sum(dim=0) - global_feature_size = workers_feature_sizes.sum() - global_means = global_features_sum / global_feature_size - - local_std_numerator = torch.pow(features - global_means, 2).sum(dim=0) - workers_std_numerators = sar.comm.exchange_tensors([local_std_numerator] * world_size) - workers_std_numerators = torch.stack(workers_std_numerators, dim=0) - global_stds = torch.sqrt(workers_std_numerators.sum(dim=0) / global_feature_size) - - features = (features - global_means) / global_stds + mean = MeanOp.apply(features) + var = VarOp.apply(features) + std = torch.sqrt(var - mean**2 + eps) + features = (features - mean) / std return features diff --git a/sar/distributed_bn.py b/sar/distributed_bn.py index e298a27..74a6834 100644 --- a/sar/distributed_bn.py +++ b/sar/distributed_bn.py @@ -24,6 +24,7 @@ import torch.distributed as dist from torch import nn from torch.nn import Parameter +from torch.nn import init from .comm import all_reduce, comm_device, is_initialized @@ -51,6 +52,7 @@ def __init__(self, n_feats: int, eps: float = 1.0e-5, affine: bool = True, distr self.n_feats = n_feats self.weight: Optional[Parameter] self.bias: Optional[Parameter] + self.affine = affine if affine: self.weight = Parameter(torch.ones(n_feats)) self.bias = Parameter(torch.zeros(n_feats)) @@ -83,6 +85,11 @@ def forward(self, inp): else: result = normalized_x return result + + def reset_parameters(self): + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) From 6aebceaa1b162083fc76dc05e37d894a74e3bfb7 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 4 Jul 2023 12:26:08 +0200 Subject: [PATCH 32/58] Review changes --- sar/comm.py | 16 ++++++++-------- sar/common_tuples.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sar/comm.py b/sar/comm.py index 662a3df..4f93d1c 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -86,8 +86,8 @@ def dump_ip_address(ip_file: str) -> str: :returns: A string containing the ip address of the local host """ - scoket = get_socket() - host_ip = scoket.ip_addr + socket = get_socket() + host_ip = socket.ip_addr with open(ip_file, 'w', encoding='utf-8') as f_handle: f_handle.write(host_ip) logger.info(f'wrote ip {host_ip} to file {ip_file}') @@ -169,13 +169,13 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, os.environ['MASTER_PORT'] = str(master_port_number) socket = get_socket() - os.environ['TP_SOCKET_IFNAME'] = socket.nice_name - os.environ['GLOO_SOCKET_IFNAME'] = socket.nice_name - os.environ['CCL_SOCKET_IFNAME'] = socket.nice_name - os.environ['NCCL_SOCKET_IFNAME'] = socket.nice_name + os.environ['TP_SOCKET_IFNAME'] = socket.name + os.environ['GLOO_SOCKET_IFNAME'] = socket.name + os.environ['CCL_SOCKET_IFNAME'] = socket.name + os.environ['NCCL_SOCKET_IFNAME'] = socket.name - os.environ['FI_VERBS_IFACE'] = socket.nice_name - os.environ['FI_mlx_IFACE'] = socket.nice_name + os.environ['FI_VERBS_IFACE'] = socket.name + os.environ['FI_mlx_IFACE'] = socket.name os.environ['MPI_COMM_WORLD'] = str(_world_size) os.environ['MPI_COMM_RANK'] = str(_rank) diff --git a/sar/common_tuples.py b/sar/common_tuples.py index 2a71892..b19535b 100644 --- a/sar/common_tuples.py +++ b/sar/common_tuples.py @@ -124,5 +124,5 @@ class ShardInfo(NamedTuple): class SocketInfo(NamedTuple): - nice_name: str + name: str ip_addr: str From cae4095fdd5dff3d87fee41801afb18929af2ee2 Mon Sep 17 00:00:00 2001 From: bgawrych Date: Tue, 4 Jul 2023 13:49:28 +0200 Subject: [PATCH 33/58] Add minibatch training with SAR inference command in examples --- examples/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/README.md b/examples/README.md index f60bad1..eb86326 100644 --- a/examples/README.md +++ b/examples/README.md @@ -29,3 +29,16 @@ python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/pa python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 ``` + +## Distributed Mini-Batch Training with Full-Graph inference +The script ``train_distdgl_with_sar_inference.py`` showcases how SAR can be effectively combined with native DGL distributed training. In this particular example, the training process utilizes a sampling approach, while the evaluation phase leverages the SAR library to perform computations on the entire graph. +```shell +python /home/ubuntu/workspace/dgl/tools/launch.py \ + --workspace /home/ubuntu/workspace/SAR/examples \ + --num_trainers 1 \ + --num_samplers 2 \ + --num_servers 1 \ + --part_config partition_data/ogbn-products.json \ + --ip_config ip_config.txt \ + "/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 part_config partition_data/ogbn-products.json" +``` From c7ebd3858359f7b3d7be7b69f1cec6ea02bd542e Mon Sep 17 00:00:00 2001 From: bgawrych Date: Tue, 4 Jul 2023 13:58:36 +0200 Subject: [PATCH 34/58] Update Python version in Github Actions --- .github/workflows/sar_test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/sar_test.yaml b/.github/workflows/sar_test.yaml index 87fda95..97e9332 100644 --- a/.github/workflows/sar_test.yaml +++ b/.github/workflows/sar_test.yaml @@ -17,7 +17,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.10' - name: Install requirements run: | From 9cc5e4840008484c142b967511f2d86c1e8c3f82 Mon Sep 17 00:00:00 2001 From: bgawrych Date: Tue, 4 Jul 2023 15:12:31 +0200 Subject: [PATCH 35/58] Add missing dashes --- examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index eb86326..c06218b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -40,5 +40,5 @@ python /home/ubuntu/workspace/dgl/tools/launch.py \ --num_servers 1 \ --part_config partition_data/ogbn-products.json \ --ip_config ip_config.txt \ - "/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 part_config partition_data/ogbn-products.json" + "/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 ------ -----part_config partition_data/ogbn-products.json" ``` From 160a4e67ab4a62068ae6bf2ee0401ee0adc7064c Mon Sep 17 00:00:00 2001 From: bgawrych Date: Tue, 4 Jul 2023 15:22:01 +0200 Subject: [PATCH 36/58] Fix dashes --- examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index c06218b..1a1c8f0 100644 --- a/examples/README.md +++ b/examples/README.md @@ -40,5 +40,5 @@ python /home/ubuntu/workspace/dgl/tools/launch.py \ --num_servers 1 \ --part_config partition_data/ogbn-products.json \ --ip_config ip_config.txt \ - "/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 ------ -----part_config partition_data/ogbn-products.json" + "/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 --part_config partition_data/ogbn-products.json" ``` From 5af0078b8a711d2d63a2d2754b5d42d9b882de34 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Wed, 12 Jul 2023 09:02:24 +0200 Subject: [PATCH 37/58] Testing comm.py --- tests/base_utils.py | 65 ++++++++++++++++ tests/models.py | 2 +- tests/{utils.py => multiprocessing_utils.py} | 39 +++++++++- tests/test_comm.py | 79 ++++++++++++++++++++ tests/test_patch_dgl.py | 3 +- tests/test_sar.py | 7 +- 6 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 tests/base_utils.py rename tests/{utils.py => multiprocessing_utils.py} (59%) create mode 100644 tests/test_comm.py diff --git a/tests/base_utils.py b/tests/base_utils.py new file mode 100644 index 0000000..570da93 --- /dev/null +++ b/tests/base_utils.py @@ -0,0 +1,65 @@ + +import os +import sar +import torch +import torch.distributed as dist +import dgl + + +def initialize_worker(rank, world_size, tmp_dir): + """ + Boilerplate code for setting up connection between workers + + :param rank: Rank of the current machine + :type rank: int + :param world_size: Number of workers. The same as the number of graph partitions + :type world_size: int + :param tmp_dir: Path to the directory where ip file will be created + :type tmp_dir: str + """ + torch.seed() + ip_file = os.path.join(tmp_dir, 'ip_file') + master_ip_address = sar.nfs_ip_init(rank, ip_file) + sar.initialize_comms(rank, world_size, master_ip_address, 'ccl') + + +def get_random_graph(): + """ + Generates small homogenous graph with features and labels + + :returns: dgl graph + """ + graph = dgl.rand_graph(1000, 2500) + graph = dgl.add_self_loop(graph) + graph.ndata.clear() + graph.ndata['features'] = torch.rand((graph.num_nodes(), 10)) + graph.ndata['labels'] = torch.randint(0, 10, (graph.num_nodes(),)) + return graph + + +def load_partition_data(rank, graph_name, tmp_dir): + """ + Boilerplate code for loading partition data + + :param rank: Rank of the current machine + :type rank: int + :param graph_name: Name of the partitioned graph + :type graph_name: str + :param tmp_dir: Path to the directory where partition data is located + :type tmp_dir: str + :returns: Tuple consisting of GraphShardManager object, partition features and labels + """ + partition_file = os.path.join(tmp_dir, f'{graph_name}.json') + partition_data = sar.load_dgl_partition_data(partition_file, rank, "cpu") + full_graph_manager = sar.construct_full_graph(partition_data).to('cpu') + features = sar.suffix_key_lookup(partition_data.node_features, 'features') + labels = sar.suffix_key_lookup(partition_data.node_features, 'labels') + return full_graph_manager, features, labels + + +def synchronize_processes(): + """ + Function that simulates dist.barrier (using all_reduce because there is an issue with dist.barrier() in ccl) + """ + dummy_tensor = torch.tensor(1) + dist.all_reduce(dummy_tensor, dist.ReduceOp.MAX) diff --git a/tests/models.py b/tests/models.py index 462bf1f..00475aa 100644 --- a/tests/models.py +++ b/tests/models.py @@ -3,7 +3,7 @@ from torch import nn class GNNModel(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, out_dim: int): + def __init__(self, in_dim: int, out_dim: int): super().__init__() self.convs = nn.ModuleList([ diff --git a/tests/utils.py b/tests/multiprocessing_utils.py similarity index 59% rename from tests/utils.py rename to tests/multiprocessing_utils.py index f711858..5e5f5e9 100644 --- a/tests/utils.py +++ b/tests/multiprocessing_utils.py @@ -1,15 +1,52 @@ - import multiprocessing as mp import traceback import pytest import functools +import tempfile + def handle_mp_exception(mp_dict): + """ + Used to handle exceptions that occurred in child processes + + :param mp_dict: Dictionary that is shared between different processes + :type mp_dict: multiprocessing.managers.DictProxy + """ msg = mp_dict.get('traceback', "") for e_arg in mp_dict['exception'].args: msg += str(e_arg) print(str(msg), flush=True) pytest.fail(str(msg), pytrace=False) + + +def run_workers(func, world_size): + """ + Starts `world_size` number of processes, where each of them + behaves as a separate worker and invokes function specified + by the parameter. + + :param func: The function that will be invoked by each process + :type func: function + :returns: mp_dict which can be used by workers to return + results from `func` + """ + manager = mp.Manager() + mp_dict = manager.dict() + processes = [] + with tempfile.TemporaryDirectory() as tmp_dir: + for rank in range(1, world_size): + p = mp.Process(target=func, args=(mp_dict, rank, world_size, tmp_dir)) + p.daemon = True + p.start() + processes.append(p) + func(mp_dict, 0, world_size, tmp_dir) + + for p in processes: + p.join() + if 'exception' in mp_dict: + handle_mp_exception(mp_dict) + return mp_dict + def sar_test(func): """ diff --git a/tests/test_comm.py b/tests/test_comm.py new file mode 100644 index 0000000..b6dbd6d --- /dev/null +++ b/tests/test_comm.py @@ -0,0 +1,79 @@ +from copy import deepcopy +import traceback +from multiprocessing_utils import * +# Do not import DGL and SAR - these modules should be +# independently loaded inside each process + + +@pytest.mark.parametrize('world_size', [2, 4, 8]) +@sar_test +def test_sync_params(world_size): + """ + Checks whether model's parameters are the same across all + workers after calling sync_params function. Parameters of worker 0 + should be copied to all workers, so its parameters before and after + sync_params should be the same + """ + import torch + def sync_params(mp_dict, rank, world_size, tmp_dir): + import sar + from tests.base_utils import initialize_worker + from models import GNNModel + try: + initialize_worker(rank, world_size, tmp_dir) + model = GNNModel(16, 4) + if rank == 0: + mp_dict[f"result_{rank}"] = deepcopy(model.state_dict()) + sar.sync_params(model) + if rank != 0: + mp_dict[f"result_{rank}"] = model.state_dict() + except Exception as e: + mp_dict["traceback"] = str(traceback.format_exc()) + mp_dict["exception"] = e + + mp_dict = run_workers(sync_params, world_size) + for rank in range(1, world_size): + for key in mp_dict[f"result_0"].keys(): + assert torch.all(torch.eq(mp_dict[f"result_0"][key], mp_dict[f"result_{rank}"][key])) + + +@pytest.mark.parametrize('world_size', [2, 4, 8]) +@sar_test +def test_gather_grads(world_size): + """ + Checks whether parameter's gradients are the same across all + workers after calling gather_grads function + """ + import torch + def gather_grads(mp_dict, rank, world_size, tmp_dir): + import sar + import dgl + import torch.nn.functional as F + from models import GNNModel + from base_utils import initialize_worker, get_random_graph, synchronize_processes,\ + load_partition_data + try: + initialize_worker(rank, world_size, tmp_dir) + graph_name = 'dummy_graph' + if rank == 0: + g = get_random_graph() + dgl.distributed.partition_graph(g, graph_name, world_size, + tmp_dir, num_hops=1, + balance_edges=True) + synchronize_processes() + fgm, feat, labels = load_partition_data(rank, graph_name, tmp_dir) + model = GNNModel(feat.shape[1], labels.max()+1) + sar.sync_params(model) + sar_logits = model(fgm, feat) + sar_loss = F.cross_entropy(sar_logits, labels) + sar_loss.backward() + sar.gather_grads(model) + mp_dict[f"result_{rank}"] = [torch.tensor(x.grad) for x in model.parameters()] + except Exception as e: + mp_dict["traceback"] = str(traceback.format_exc()) + mp_dict["exception"] = e + + mp_dict = run_workers(gather_grads, world_size) + for rank in range(1, world_size): + for i in range(len(mp_dict["result_0"])): + assert torch.all(torch.eq(mp_dict["result_0"][i], mp_dict[f"result_{rank}"][i])) diff --git a/tests/test_patch_dgl.py b/tests/test_patch_dgl.py index f44df3f..e1f9f03 100644 --- a/tests/test_patch_dgl.py +++ b/tests/test_patch_dgl.py @@ -1,7 +1,8 @@ -from utils import * +from multiprocessing_utils import * # Do not import DGL and SAR - these modules should be # independently loaded inside each process + @sar_test def test_patch_dgl(): """ diff --git a/tests/test_sar.py b/tests/test_sar.py index 371582c..c2bcdcf 100644 --- a/tests/test_sar.py +++ b/tests/test_sar.py @@ -1,11 +1,12 @@ -from utils import * +from multiprocessing_utils import * import os import tempfile - +import traceback import numpy as np # Do not import DGL and SAR - these modules should be # independently loaded inside each process + def sar_process(mp_dict, rank, world_size, tmp_dir): """ This function should be an entry point to the 'independent' process. @@ -53,7 +54,7 @@ def sar_process(mp_dict, rank, world_size, tmp_dir): features = sar.suffix_key_lookup(partition_data.node_features, 'features') del partition_data - model = GNNModel(features.size(1), 32, features.size(1)).to('cpu') + model = GNNModel(features.size(1), features.size(1)).to('cpu') sar.sync_params(model) logits = model(full_graph_manager, features) From fc98747a2df8ba44033d4239485d73e75f3ba76d Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Thu, 13 Jul 2023 11:48:07 +0200 Subject: [PATCH 38/58] Handling args and kwargs in test functions --- tests/multiprocessing_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/multiprocessing_utils.py b/tests/multiprocessing_utils.py index 5e5f5e9..d4bdca1 100644 --- a/tests/multiprocessing_utils.py +++ b/tests/multiprocessing_utils.py @@ -19,13 +19,15 @@ def handle_mp_exception(mp_dict): pytest.fail(str(msg), pytrace=False) -def run_workers(func, world_size): +def run_workers(func, world_size, *args, **kwargs): """ Starts `world_size` number of processes, where each of them behaves as a separate worker and invokes function specified by the parameter. - :param func: The function that will be invoked by each process + :param func: The function that will be invoked by each process. It should take four + parameters: mp_dict - shared dictionary between different processes, rank - of the current machine, + world_size - number of workers, tmp_dir - path to the working directory (additionaly one can pass args and kwargs) :type func: function :returns: mp_dict which can be used by workers to return results from `func` @@ -35,12 +37,13 @@ def run_workers(func, world_size): processes = [] with tempfile.TemporaryDirectory() as tmp_dir: for rank in range(1, world_size): - p = mp.Process(target=func, args=(mp_dict, rank, world_size, tmp_dir)) + my_args = (mp_dict, rank, world_size, tmp_dir) + args + p = mp.Process(target=func, args=my_args, kwargs=kwargs) p.daemon = True p.start() processes.append(p) - func(mp_dict, 0, world_size, tmp_dir) - + func(mp_dict, 0, world_size, tmp_dir, *args, **kwargs) + for p in processes: p.join() if 'exception' in mp_dict: From eb76e1f8f9d8755a248c7563ea7dcd8bc86718c5 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Thu, 13 Jul 2023 11:48:27 +0200 Subject: [PATCH 39/58] Added information for users --- tests/base_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/base_utils.py b/tests/base_utils.py index 570da93..03d64c3 100644 --- a/tests/base_utils.py +++ b/tests/base_utils.py @@ -1,9 +1,11 @@ - import os import sar import torch import torch.distributed as dist import dgl +# IMPORTANT - This module should be imported independently +# only by the child processes - i.e. separate workers + def initialize_worker(rank, world_size, tmp_dir): From fa682ae26dc974ba3917d4794dbde9d794ea187c Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Thu, 13 Jul 2023 14:23:15 +0200 Subject: [PATCH 40/58] synchronization of sigma among workers --- examples/correct_and_smooth.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/correct_and_smooth.py b/examples/correct_and_smooth.py index 3b33261..743e3e4 100644 --- a/examples/correct_and_smooth.py +++ b/examples/correct_and_smooth.py @@ -257,7 +257,9 @@ def correct(self, g, y_soft, y_true, mask): int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) ) assert y_true.size(0) == numel - + numel = torch.tensor(numel) + sar.comm.all_reduce(numel, dist.ReduceOp.SUM, move_to_comm_device=True) + if y_true.dtype == torch.long: y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)).to( y_soft.dtype @@ -270,7 +272,9 @@ def correct(self, g, y_soft, y_true, mask): smoothed_error = self.prop1( g, error, post_step=lambda x: x.clamp_(-1.0, 1.0) ) - sigma = error[mask].abs().sum() / numel + error_sum = error[mask].abs().sum() + sar.comm.all_reduce(error_sum, dist.ReduceOp.SUM, move_to_comm_device=True) + sigma = error_sum / numel scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True) scale[scale.isinf() | (scale > 1000)] = 1.0 From 0f51aaa26cfe7193a05ad758aaff56bd120a3cad Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 14 Jul 2023 15:13:53 +0200 Subject: [PATCH 41/58] Fix pytest rtol error --- tests/test_sar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sar.py b/tests/test_sar.py index 371582c..f80f151 100644 --- a/tests/test_sar.py +++ b/tests/test_sar.py @@ -114,7 +114,7 @@ def test_sar_full_graph(world_size): sar_logits_mean = sar_logits.mean() - rtol = sar_logits_mean / 1000 + rtol = abs(sar_logits_mean) / 1000 assert full_graph_mean == pytest.approx(sar_logits_mean, rtol) @sar_test From ceafe30dcd6738c29ab6d395e278dce192b51581 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Thu, 6 Jul 2023 10:57:11 +0200 Subject: [PATCH 42/58] gloo all_to_all with isends gloo all_to_all with scatter gloo all_to_all with isends --- docs/source/comm.rst | 2 +- docs/source/quick_start.rst | 2 +- sar/comm.py | 35 ++++++++++++++++++++++++----------- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/docs/source/comm.rst b/docs/source/comm.rst index e77b1ad..7266102 100644 --- a/docs/source/comm.rst +++ b/docs/source/comm.rst @@ -4,7 +4,7 @@ SAR's communication routines ============================= -SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. Currently, the only backends in `torch.distributed `_ that support ``all_to_all`` are ``nccl``, ``ccl``, or ``mpi``. Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs. +SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. SAR supports four backends, which are ``ccl``, ``nccl``, ``mpi`` and ``gloo``. Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs. The ``ccl`` backend uses `Intel's OneCCL `_ library. You can install the PyTorch bindings for OneCCL `here `_ . ``ccl`` is the preferred backend when training on CPUs. diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index 2dfef57..4a6cc8a 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -66,7 +66,7 @@ Initialize the communication through a call to :func:`sar.initialize_comms` , sp .. -``backend_name`` can be ``nccl``, ``ccl``, or ``mpi``. +``backend_name`` can be ``ccl``, ``nccl``, ``mpi`` or ``gloo``. diff --git a/sar/comm.py b/sar/comm.py index 56b07df..49d928d 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -140,7 +140,7 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, :type _world_size: int :param master_ip_address: IP address of the master worker (worker with rank 0) :type master_ip_address: str - :param backend: Backend to use. Can be ccl, nccl, or mpi + :param backend: Backend to use. Can be ccl, nccl, mpi or gloo :type backend: str :param _comm_device: The device on which the tensors should be on in order to transmit them\ through the backend. If not provided, the device is infered based on the backend type @@ -150,8 +150,8 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, """ - assert backend in ['ccl', 'nccl', - 'mpi'], 'backend must be ccl, nccl, or mpi' + assert backend in ['ccl', 'nccl', 'mpi', 'gloo'],\ + 'backend must be ccl, nccl, mpi or gloo' if _comm_device is None: if backend == 'nccl': _comm_device = torch.device('cuda') @@ -199,8 +199,8 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, dist.init_process_group( backend=backend, rank=_rank, world_size=_world_size) else: - assert dist.get_backend() in ['ccl', 'nccl', - 'mpi'], 'backend must be ccl, nccl, or mpi' + assert dist.get_backend() in ['ccl', 'nccl', 'mpi', 'gloo'],\ + 'backend must be ccl, nccl, mpi or gloo' _CommData.rank = _rank _CommData.world_size = _world_size @@ -316,9 +316,7 @@ def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor]): if Config.max_collective_size == 0: - #print('all to all', recv_tensors, send_tensors, flush=True) - dist.all_to_all(recv_tensors, send_tensors) - #print('all to all complete', recv_tensors, send_tensors, flush=True) + all_to_all_gloo_support(recv_tensors, send_tensors) else: max_n_elems = Config.max_collective_size total_elems = sum(r_tensor.numel() for r_tensor in recv_tensors) + \ @@ -332,7 +330,24 @@ def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch s_tensor in send_tensors] recv_tensors_slices = [_get_tensor_slice(r_tensor, n_rounds, round_idx) for r_tensor in recv_tensors] - dist.all_to_all(recv_tensors_slices, send_tensors_slices) + all_to_all_gloo_support(recv_tensors_slices, send_tensors_slices) + + +def all_to_all_gloo_support(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor]): + if backend() == 'gloo': + send_requests = [] + for i in range(world_size()): + if i == rank(): + recv_tensors[i].copy_(send_tensors[i]) + else: + send_request = dist.isend(send_tensors[i], i) + send_requests.append(send_request) + for i in range(world_size()): + if i != rank(): + dist.recv(recv_tensors[i], i) + dist.barrier() + else: + dist.all_to_all(recv_tensors, send_tensors) def _get_tensor_slice(tens: Tensor, n_splits: int, split_idx: int) -> Tensor: @@ -429,8 +444,6 @@ def exchange_tensors(tensors: List[torch.Tensor], recv_sizes: Optional[List[int] comm_device()) for _ in range(len(tensors))] all_to_all(all_their_sizes, all_my_sizes) - #print('all my sizes', all_my_sizes) - #print('all their sizes', all_their_sizes) all_their_sizes_i = [cast(int, x.item()) for x in all_their_sizes] else: From 1186b49e213a923c70bdbd962585c7ce436f7a35 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Fri, 7 Jul 2023 13:22:10 +0200 Subject: [PATCH 43/58] Update examples - gloo backend --- examples/train_heterogeneous_graph.py | 2 +- examples/train_homogeneous_graph_advanced.py | 2 +- examples/train_homogeneous_graph_basic.py | 2 +- examples/train_homogeneous_sampling_basic.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/train_heterogeneous_graph.py b/examples/train_heterogeneous_graph.py index 4f56e25..54ab082 100644 --- a/examples/train_heterogeneous_graph.py +++ b/examples/train_heterogeneous_graph.py @@ -47,7 +47,7 @@ help='File with ip-address. Worker 0 creates this file and all others read it ') -parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'], +parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'], help='Communication backend to use ' ) diff --git a/examples/train_homogeneous_graph_advanced.py b/examples/train_homogeneous_graph_advanced.py index 64551fe..5fb7210 100644 --- a/examples/train_homogeneous_graph_advanced.py +++ b/examples/train_homogeneous_graph_advanced.py @@ -52,7 +52,7 @@ help='SAR log level ') -parser.add_argument('--backend', default='ccl', type=str, choices=['ccl', 'nccl', 'mpi'], +parser.add_argument('--backend', default='ccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'], help='Communication backend to use ') parser.add_argument( diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py index 14ee691..9b7ed9d 100644 --- a/examples/train_homogeneous_graph_basic.py +++ b/examples/train_homogeneous_graph_basic.py @@ -47,7 +47,7 @@ help='File with ip-address. Worker 0 creates this file and all others read it ') -parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'], +parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'], help='Communication backend to use ' ) diff --git a/examples/train_homogeneous_sampling_basic.py b/examples/train_homogeneous_sampling_basic.py index 500fdcc..d9de07d 100644 --- a/examples/train_homogeneous_sampling_basic.py +++ b/examples/train_homogeneous_sampling_basic.py @@ -50,7 +50,7 @@ help='File with ip-address. Worker 0 creates this file and all others read it ') -parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'], +parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'], help='Communication backend to use ' ) From 7b982d58471d51798416b5b1c2d1d3d263dbff4a Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 4 Jul 2023 13:52:38 +0200 Subject: [PATCH 44/58] Add predict-then-propagate example --- examples/train_dist_appnp_with_sar.py | 227 ++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 examples/train_dist_appnp_with_sar.py diff --git a/examples/train_dist_appnp_with_sar.py b/examples/train_dist_appnp_with_sar.py new file mode 100644 index 0000000..b5efcb9 --- /dev/null +++ b/examples/train_dist_appnp_with_sar.py @@ -0,0 +1,227 @@ +from argparse import ArgumentParser + +import dgl # type: ignore +from dgl.nn.pytorch.conv import APPNPConv + +import sar + +import time +import torch +import torch.nn.functional as F +from torch import nn +import torch.distributed as dist + +parser = ArgumentParser(description="APPNP example") + +parser.add_argument("--partitioning-json-file", type=str, default="", + help="Path to the .json file containing partitioning information") + +parser.add_argument("--ip-file", type=str, default="./ip_file", + help="File with ip-address. Worker 0 creates this file and all others read it") + +parser.add_argument("--backend", type=str, default="nccl", + choices=["ccl", "nccl", "mpi"], + help="Communication backend to use") + +parser.add_argument("--cpu-run", action="store_true", + help="Run on CPUs if set, otherwise run on GPUs") + +parser.add_argument("--train-iters", type=int, default=100, + help="number of training iterations") + +parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + +parser.add_argument("--rank", type=int, default=0, + help="Rank of the current worker") + +parser.add_argument("--world-size", type=int, default=2, + help="Number of workers") + +parser.add_argument("--hidden-layer-dim", type=int, default=[64], nargs="+", + help="Dimension of GNN hidden layer") + +parser.add_argument("--k", type=int, default=10, + help="Number of propagation steps") + +parser.add_argument("--alpha", type=float, default=0.1, + help="Teleport Probability") + +parser.add_argument("--in-drop", type=float, default=0.5, + help="input feature dropout") + +parser.add_argument("--edge-drop", type=float, default=0.5, + help="edge propagation dropout") + +class APPNP(nn.Module): + def __init__( + self, + g, + in_feats, + hiddens, + n_classes, + activation, + feat_drop, + edge_drop, + alpha, + k, + ): + super(APPNP, self).__init__() + self.g = g + self.layers = nn.ModuleList() + # input layer + self.layers.append(nn.Linear(in_feats, hiddens[0])) + # hidden layers + for i in range(1, len(hiddens)): + self.layers.append(nn.Linear(hiddens[i - 1], hiddens[i])) + # output layer + self.layers.append(nn.Linear(hiddens[-1], n_classes)) + self.activation = activation + if feat_drop: + self.feat_drop = nn.Dropout(feat_drop) + else: + self.feat_drop = lambda x: x + self.propagate = APPNPConv(k, alpha, edge_drop) + self.reset_parameters() + + def reset_parameters(self): + for layer in self.layers: + layer.reset_parameters() + + def forward(self, features): + # prediction step + h = features + h = self.feat_drop(h) + h = self.activation(self.layers[0](h)) + for layer in self.layers[1:-1]: + h = self.activation(layer(h)) + h = self.layers[-1](self.feat_drop(h)) + # propagation step + h = self.propagate(self.g, h) + return h + +def evaluate(model, features, labels, masks): + model.eval() + train_mask, val_mask, test_mask = masks['train_indices'], masks['val_indices'], masks['test_indices'] + with torch.no_grad(): + logits = model(features) + results = [] + for mask in [train_mask, val_mask, test_mask]: + n_correct = (logits[mask].argmax(1) == + labels[mask]).float().sum() + results.extend([n_correct, mask.numel()]) + + acc_vec = torch.FloatTensor(results) + # Sum the n_correct, and number of mask elements across all workers + sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True) + (train_acc, val_acc, test_acc) = \ + (acc_vec[0] / acc_vec[1], + acc_vec[2] / acc_vec[3], + acc_vec[4] / acc_vec[5]) + + return train_acc, val_acc, test_acc + +def main(): + args = parser.parse_args() + print('args', args) + + use_gpu = torch.cuda.is_available() and not args.cpu_run + device = torch.device('cuda' if use_gpu else 'cpu') + + # Obtain the ip address of the master through the network file system + master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) + sar.initialize_comms(args.rank, + args.world_size, + master_ip_address, + args.backend) + + # Load DGL partition data + partition_data = sar.load_dgl_partition_data( + args.partitioning_json_file, args.rank, device) + + # Obtain train,validation, and test masks + # These are stored as node features. Partitioning may prepend + # the node type to the mask names. So we use the convenience function + # suffix_key_lookup to look up the mask name while ignoring the + # arbitrary node type + masks = {} + for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'], + ['train_indices', 'val_indices', 'test_indices']): + boolean_mask = sar.suffix_key_lookup(partition_data.node_features, + mask_name) + masks[indices_name] = boolean_mask.nonzero( + as_tuple=False).view(-1).to(device) + + labels = sar.suffix_key_lookup(partition_data.node_features, + 'label').long().to(device) + + # Obtain the number of classes by finding the max label across all workers + num_labels = labels.max() + 1 + sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True) + num_labels = num_labels.item() + + features = sar.suffix_key_lookup(partition_data.node_features, 'feat').to(device) + full_graph_manager = sar.construct_full_graph(partition_data).to(device) + + #We do not need the partition data anymore + del partition_data + + gnn_model = APPNP( + full_graph_manager, + features.size(1), + args.hidden_layer_dim, + num_labels, + F.relu, + args.in_drop, + args.edge_drop, + args.alpha, + args.k) + + gnn_model.reset_parameters() + print('model', gnn_model) + + # Synchronize the model parmeters across all workers + sar.sync_params(gnn_model) + + # Obtain the number of labeled nodes in the training + # This will be needed to properly obtain a cross entropy loss + # normalized by the number of training examples + n_train_points = torch.LongTensor([masks['train_indices'].numel()]) + sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True) + n_train_points = n_train_points.item() + + optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr, weight_decay=5e-4) + for train_iter_idx in range(args.train_iters): + # Train + gnn_model.train() + t_1 = time.time() + logits = gnn_model(features) + loss = F.cross_entropy(logits[masks['train_indices']], + labels[masks['train_indices']], reduction='sum') / n_train_points + + optimizer.zero_grad() + loss.backward() + # Do not forget to gather the parameter gradients from all workers + sar.gather_grads(gnn_model) + optimizer.step() + train_time = time.time() - t_1 + + if (train_iter_idx + 1) % 10 == 0: + train_acc, val_acc, test_acc = evaluate(gnn_model, features, labels, masks) + + result_message = ( + f"iteration [{train_iter_idx + 1}/{args.train_iters}] | " + ) + result_message += ', '.join([ + f"train loss={loss:.4f}, " + f"Accuracy: " + f"train={train_acc:.4f} " + f"valid={val_acc:.4f} " + f"test={test_acc:.4f} " + f" | train time = {train_time} " + f" |" + ]) + print(result_message, flush=True) + +if __name__ == '__main__': + main() From a994a12f43936003e5daefe488335c8eb661a225 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Wed, 31 May 2023 12:17:24 +0200 Subject: [PATCH 45/58] Unify partitioning script / Add cora, citeseer, pubmed --- examples/README.md | 7 +- examples/partition_arxiv_products.py | 93 -------------------- examples/partition_graph.py | 126 +++++++++++++++++++++++++++ examples/partition_mag.py | 79 ----------------- 4 files changed, 131 insertions(+), 174 deletions(-) delete mode 100644 examples/partition_arxiv_products.py create mode 100644 examples/partition_graph.py delete mode 100644 examples/partition_mag.py diff --git a/examples/README.md b/examples/README.md index 1a1c8f0..573a8d6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,8 +1,11 @@ ## Graph partitioning -``partition_arxiv_products.py`` partitions the ogbn-arxiv and ogbn-products graphs from the [Open Graph Benchmarks](https://ogb.stanford.edu/) using DGL's metis-based partitioning. The general technique there can be used to partition arbitrary homogeneous graphs. Note that all node-related information must be included in the graph's ``ndata`` dictionary so that they are correctly partitioned with the graph. Similarly, edge-related information must be included in the graph's ``edata`` dictionary +The ``partition_graph.py`` script can be used to partition both homogeneous and heterogeneous graphs. It utilizes DGL's metis-based partitioning algorithm to divide the graphs into smaller partitions. Note that all node-related information must be included in the graph's ``ndata`` dictionary so that they are correctly partitioned with the graph. +Similarly, edge-related information must be included in the graph's ``edata`` dictionary -``partition_mag.py`` partitions the [ogbn-mag](https://ogb.stanford.edu/docs/nodeprop/#ogbn-mag) heterogeneous graph. Again, all node-related information are included in the graph's ``ndata`` for the relevant node types +### Supported datasets: +- ogbn-products, ogbn-arxiv, ogb-mag from [Open Graph Benchmarks](https://ogb.stanford.edu/) +- cora, citeseer, pubmed ## Full-batch Training diff --git a/examples/partition_arxiv_products.py b/examples/partition_arxiv_products.py deleted file mode 100644 index c77a19b..0000000 --- a/examples/partition_arxiv_products.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2022 Intel Corporation - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from argparse import ArgumentParser -import dgl # type:ignore -import torch -from ogb.nodeproppred import DglNodePropPredDataset # type:ignore - - -parser = ArgumentParser(description="Graph partitioning for ogbn-arxiv and ogbn-products") - -parser.add_argument( - "--dataset-root", - type=str, - default="./datasets/", - help="The OGB datasets folder " -) - -parser.add_argument( - "--dataset-name", - type=str, - default="ogbn-arxiv", - choices=['ogbn-arxiv', 'ogbn-products'], - help="Dataset name. ogbn-arxiv or ogbn-products " -) - -parser.add_argument( - "--partition-out-path", - type=str, - default="./partition_data/", - help="Path to the output directory for the partition data " -) - - -parser.add_argument( - '--num-partitions', - default=2, - type=int, - help='Number of graph partitions to generate') - - -def main(): - args = parser.parse_args() - dataset = DglNodePropPredDataset(name=args.dataset_name, - root=args.dataset_root) - graph = dataset[0][0] - graph = dgl.to_bidirected(graph, copy_ndata=True) - graph = dgl.add_self_loop(graph) - - labels = dataset[0][1].view(-1) - split_idx = dataset.get_idx_split() - - def _idx_to_mask(idx_tensor): - mask = torch.BoolTensor(graph.number_of_nodes()).fill_(False) - mask[idx_tensor] = True - return mask - - train_mask, val_mask, test_mask = map( - _idx_to_mask, [split_idx['train'], split_idx['valid'], split_idx['test']]) - features = graph.ndata['feat'] - graph.ndata.clear() - for name, val in zip(['train_mask', 'val_mask', 'test_mask', 'labels', 'features'], - [train_mask, val_mask, test_mask, labels, features]): - graph.ndata[name] = val - - dgl.distributed.partition_graph( - graph, args.dataset_name, - args.num_partitions, - args.partition_out_path, - num_hops=1, - balance_ntypes=train_mask, - balance_edges=True) - - -if __name__ == '__main__': - main() diff --git a/examples/partition_graph.py b/examples/partition_graph.py new file mode 100644 index 0000000..94198c0 --- /dev/null +++ b/examples/partition_graph.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022 Intel Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from argparse import ArgumentParser +import dgl # type:ignore +import torch +from ogb.nodeproppred import DglNodePropPredDataset # type:ignore +from dgl.data import ( + CiteseerGraphDataset, + CoraGraphDataset, + PubmedGraphDataset, +) + +SUPPORTED_DATASETS = { + "cora": CoraGraphDataset, + "citeseer": CiteseerGraphDataset, + "pubmed": PubmedGraphDataset, + "ogbn-products": DglNodePropPredDataset, + "ogbn-arxiv": DglNodePropPredDataset, + "ogbn-mag": DglNodePropPredDataset, +} + +parser = ArgumentParser(description="Graph partitioning for ogbn-arxiv and ogbn-products") + +parser.add_argument("--dataset-root", type=str, default="./datasets/", + help="The OGB datasets folder") + +parser.add_argument("--dataset-name", type=str, default="ogbn-arxiv", + choices=["ogbn-arxiv", "ogbn-products", "ogbn-mag", + "cora", "citeseer", "pubmed"], + help="Dataset name. ogbn-arxiv or ogbn-products") + +parser.add_argument("--partition-out-path", type=str, default="./partition_data/", + help="Path to the output directory for the partition data") + +parser.add_argument("--num-partitions", type=int, default=2, + help="Number of graph partitions to generate") + +def get_dataset(args): + dataset_name = args.dataset_name + if dataset_name in ["cora", "citeseer", "pubmed"]: + return SUPPORTED_DATASETS[dataset_name](args.dataset_root) + else: + return SUPPORTED_DATASETS[dataset_name](dataset_name, args.dataset_root) + +def prepare_features(args, dataset, graph): + if args.dataset_name in ["cora", "citeseer", "pubmed"]: + assert all([x in graph.ndata.keys() for x in ["train_mask", "val_mask", "test_mask"]]) + return + + split_idx = dataset.get_idx_split() + ntype = "paper" if args.dataset_name == "ogbn-mag" else None + + def idx_to_mask(idx_tensor): + mask = torch.BoolTensor(graph.number_of_nodes(ntype)).fill_(False) + if ntype: + mask[idx_tensor[ntype]] = True + else: + mask[idx_tensor] = True + return mask + + train_mask, val_mask, test_mask = map( + idx_to_mask, [split_idx["train"], split_idx["valid"], split_idx["test"]]) + + if "feat" in graph.ndata.keys(): + features = graph.ndata["feat"] + else: + features = graph.ndata["features"] + + graph.ndata.clear() + + labels = dataset[0][1] + if ntype: + features = features[ntype] + labels = labels[ntype] + labels = labels.view(-1) + + for name, val in zip(["train_mask", "val_mask", "test_mask", "labels", "features"], + [train_mask, val_mask, test_mask, labels, features]): + graph.ndata[name] = {ntype: val} if ntype else val + +def main(): + args = parser.parse_args() + dataset = get_dataset(args) + dataset_name = args.dataset_name + if dataset_name.startswith("ogbn"): + graph = dataset[0][0] + else: + graph = dataset[0] + + if dataset_name != "ogbn-mag": + graph = dgl.remove_self_loop(graph) + graph = dgl.to_bidirected(graph, copy_ndata=True) + graph = dgl.add_self_loop(graph) + + prepare_features(args, dataset, graph) + balance_ntypes = graph.ndata["train_mask"] \ + if dataset_name in ["ogbn-products", "ogbn-arxiv"] else None + dgl.distributed.partition_graph( + graph, args.dataset_name, + args.num_partitions, + args.partition_out_path, + num_hops=1, + balance_ntypes=balance_ntypes, + balance_edges=True) + + +if __name__ == "__main__": + main() diff --git a/examples/partition_mag.py b/examples/partition_mag.py deleted file mode 100644 index 0e24512..0000000 --- a/examples/partition_mag.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2022 Intel Corporation - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from argparse import ArgumentParser -import dgl # type:ignore -import torch -from ogb.nodeproppred import DglNodePropPredDataset # type:ignore - - -parser = ArgumentParser(description="Graph partitioning for ogbn-mag") - -parser.add_argument( - "--dataset-root", - type=str, - default="./datasets/", - help="The OGB datasets folder " -) - -parser.add_argument( - "--partition-out-path", - type=str, - default="./partition_data/", - help="Path to the output directory for the partition data " -) - - -parser.add_argument( - '--num-partitions', - default=2, - type=int, - help='Number of graph partitions to generate') - - -def main(): - args = parser.parse_args() - dataset = DglNodePropPredDataset(name='ogbn-mag', - root=args.dataset_root) - graph = dataset[0][0] - labels = dataset[0][1]['paper'].view(-1) - split_idx = dataset.get_idx_split() - - def idx_to_mask(idx_tensor): - mask = torch.BoolTensor(graph.number_of_nodes('paper')).fill_(False) - mask[idx_tensor] = True - return mask - train_mask, val_mask, test_mask = map( - idx_to_mask, [split_idx['train']['paper'], split_idx['valid']['paper'], split_idx['test']['paper']]) - features = graph.ndata['feat']['paper'] - for name, val in zip(['train_mask', 'val_mask', 'test_mask', 'labels', 'features'], - [train_mask, val_mask, test_mask, labels, features]): - graph.ndata[name] = {'paper': val} - - dgl.distributed.partition_graph( - graph, 'ogbn-mag', - args.num_partitions, - args.partition_out_path, - num_hops=1, - balance_edges=True) - - -if __name__ == '__main__': - main() From 6b2ff7f2c1f31e8e2d54afb7cddb498842751097 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 10 Jul 2023 09:16:20 +0200 Subject: [PATCH 46/58] Trigger CI --- examples/train_dist_appnp_with_sar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/train_dist_appnp_with_sar.py b/examples/train_dist_appnp_with_sar.py index b5efcb9..5787385 100644 --- a/examples/train_dist_appnp_with_sar.py +++ b/examples/train_dist_appnp_with_sar.py @@ -163,7 +163,7 @@ def main(): features = sar.suffix_key_lookup(partition_data.node_features, 'feat').to(device) full_graph_manager = sar.construct_full_graph(partition_data).to(device) - #We do not need the partition data anymore + # We do not need the partition data anymore del partition_data gnn_model = APPNP( From 25ba5fa2f28829248f1789e92d496d50511aa148 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 18 Jul 2023 10:13:36 +0200 Subject: [PATCH 47/58] Update Readme with more details --- examples/README.md | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/examples/README.md b/examples/README.md index c4a0172..e0733bb 100644 --- a/examples/README.md +++ b/examples/README.md @@ -44,13 +44,28 @@ python /home/ubuntu/workspace/dgl/tools/launch.py \ ``` ## Correct and Smooth -Example taken from [DGL implemenetation](https://github.com/dmlc/dgl/tree/master/examples/pytorch/correct_and_smooth) of C&S. Code is adjusted to perform distributed training with SAR. For instance, you can run the example with following commands: +Example taken from [DGL implemenetation](https://github.com/dmlc/dgl/tree/master/examples/pytorch/correct_and_smooth) of C&S. Code is adjusted to perform distributed training with SAR. Introduced modifications change the way data normalization is performed - workers need to communicate with each other to calculate mean and standard deviation for the entire dataset (not just their partition). Moreover, workers need to be synchronized with each other to calculate sigma value required during "correct" phase. + +For instance, you can run the example with following commands (2 machines scenario): * **Plain MLP + C&S** -```shell -python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale -``` + * Rank 0 machine: + ```shell + python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale + ``` + + * Rank 1 machine: + ```shell + python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale + ``` * **Plain Linear + C&S** -```shell -python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale + * Rank 0 machine: + ```shell + python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale + ``` + + * Rank 1 machine: + ```shell + python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale + ``` From 31a1af3e1ce99c2c46085521b0c37e944441a1a6 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 18 Jul 2023 10:39:08 +0200 Subject: [PATCH 48/58] Update initialization docs --- docs/source/comm.rst | 8 ++++---- docs/source/quick_start.rst | 4 ++-- sar/comm.py | 13 +++++++------ 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/docs/source/comm.rst b/docs/source/comm.rst index e77b1ad..022ad71 100644 --- a/docs/source/comm.rst +++ b/docs/source/comm.rst @@ -16,12 +16,12 @@ In an environment with a networked file system, initializing ``torch.distributed comm_device = torch.device('cuda') else: comm_device = torch.device('cpu') - - master_ip_address = sar.nfs_ip_init(rank,path_to_ip_file) - sar.initialize_comms(rank,world_size, master_ip_address,backend_name,comm_device) -.. + master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file) + sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device) +.. +:func:`sar.initialize_comms` tries to initialize the torch.distributed process group, but only if it has not been initialized. User can initialize process group on his own before calling :func:`sar.initialize_comms`. :func:`sar.nfs_ip_init` communicates the master's ip address to the workers through the file system. In the absence of a networked file system, you should develop your own mechanism to communicate the master's ip address. You can specify the name of the socket that will be used for communication with `SAR_SOCKET_NAME` environment variable (if not specified, the first available socket will be selected). diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index 2dfef57..11c6015 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -61,8 +61,8 @@ Initialize the communication through a call to :func:`sar.initialize_comms` , sp comm_device = torch.device('cuda') else: comm_device = torch.device('cpu') - master_ip_address = sar.nfs_ip_init(rank,path_to_ip_file) - sar.initialize_comms(rank,world_size, master_ip_address,backend_name,comm_device) + master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file) + sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device) .. diff --git a/sar/comm.py b/sar/comm.py index 56b07df..f69c978 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -158,9 +158,6 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, else: _comm_device = torch.device('cpu') -# if is_initialized(): - # return - if backend == 'ccl': # pylint: disable=unused-import try: @@ -195,9 +192,13 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, os.environ['I_MPI_COMM_WORLD'] = str(_world_size) os.environ['I_MPI_COMM_RANK'] = str(_rank) - - dist.init_process_group( - backend=backend, rank=_rank, world_size=_world_size) + try: + dist.init_process_group( + backend=backend, rank=_rank, world_size=_world_size) + except: + logger.error("SAR was unable to initialize torch.distributed process group. " + "You can try to do it manually before calling sar.initialize_comms") + raise else: assert dist.get_backend() in ['ccl', 'nccl', 'mpi'], 'backend must be ccl, nccl, or mpi' From e54b91681d812dfbba8d00471bb1c80823095b07 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 18 Jul 2023 13:27:31 +0200 Subject: [PATCH 49/58] Added test for all_to_all function --- tests/base_utils.py | 4 ++-- tests/test_comm.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/tests/base_utils.py b/tests/base_utils.py index 03d64c3..834270d 100644 --- a/tests/base_utils.py +++ b/tests/base_utils.py @@ -8,7 +8,7 @@ -def initialize_worker(rank, world_size, tmp_dir): +def initialize_worker(rank, world_size, tmp_dir, backend="ccl"): """ Boilerplate code for setting up connection between workers @@ -22,7 +22,7 @@ def initialize_worker(rank, world_size, tmp_dir): torch.seed() ip_file = os.path.join(tmp_dir, 'ip_file') master_ip_address = sar.nfs_ip_init(rank, ip_file) - sar.initialize_comms(rank, world_size, master_ip_address, 'ccl') + sar.initialize_comms(rank, world_size, master_ip_address, backend) def get_random_graph(): diff --git a/tests/test_comm.py b/tests/test_comm.py index b6dbd6d..bf7c24e 100644 --- a/tests/test_comm.py +++ b/tests/test_comm.py @@ -77,3 +77,32 @@ def gather_grads(mp_dict, rank, world_size, tmp_dir): for rank in range(1, world_size): for i in range(len(mp_dict["result_0"])): assert torch.all(torch.eq(mp_dict["result_0"][i], mp_dict[f"result_{rank}"][i])) + + +@pytest.mark.parametrize("backend", ["ccl", "gloo"]) +@pytest.mark.parametrize("world_size", [2, 4, 8]) +@sar_test +def test_all_to_all(world_size, backend): + """ + Checks whether all_to_all operation works as expected. Test is + designed is such a way, that after calling all_to_all, each worker + should receive a list of tensors with values equal to their rank + """ + import torch + def all_to_all(mp_dict, rank, world_size, tmp_dir, **kwargs): + import sar + from base_utils import initialize_worker + try: + initialize_worker(rank, world_size, tmp_dir, backend=kwargs["backend"]) + send_tensors_list = [torch.tensor([x] * world_size) for x in range(world_size)] + recv_tensors_list = [torch.tensor([0] * world_size) for _ in range(world_size)] + sar.comm.all_to_all(recv_tensors_list, send_tensors_list) + mp_dict[f"result_{rank}"] = recv_tensors_list + except Exception as e: + mp_dict["traceback"] = str(traceback.format_exc()) + mp_dict["exception"] = e + + mp_dict = run_workers(all_to_all, world_size, backend=backend) + for rank in range(world_size): + for tensor in mp_dict[f"result_{rank}"]: + assert torch.all(torch.eq(tensor, torch.tensor([rank]*world_size))) From eb392a78b447c5b375d13843c03f6e41fdd459e5 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 18 Jul 2023 14:37:14 +0200 Subject: [PATCH 50/58] Added backend param for all comm.py tests --- tests/test_comm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_comm.py b/tests/test_comm.py index bf7c24e..9dbfc99 100644 --- a/tests/test_comm.py +++ b/tests/test_comm.py @@ -5,9 +5,10 @@ # independently loaded inside each process +@pytest.mark.parametrize("backend", ["ccl", "gloo"]) @pytest.mark.parametrize('world_size', [2, 4, 8]) @sar_test -def test_sync_params(world_size): +def test_sync_params(world_size, backend): """ Checks whether model's parameters are the same across all workers after calling sync_params function. Parameters of worker 0 @@ -15,12 +16,12 @@ def test_sync_params(world_size): sync_params should be the same """ import torch - def sync_params(mp_dict, rank, world_size, tmp_dir): + def sync_params(mp_dict, rank, world_size, tmp_dir, **kwargs): import sar from tests.base_utils import initialize_worker from models import GNNModel try: - initialize_worker(rank, world_size, tmp_dir) + initialize_worker(rank, world_size, tmp_dir, backend=kwargs["backend"]) model = GNNModel(16, 4) if rank == 0: mp_dict[f"result_{rank}"] = deepcopy(model.state_dict()) @@ -31,21 +32,22 @@ def sync_params(mp_dict, rank, world_size, tmp_dir): mp_dict["traceback"] = str(traceback.format_exc()) mp_dict["exception"] = e - mp_dict = run_workers(sync_params, world_size) + mp_dict = run_workers(sync_params, world_size, backend=backend) for rank in range(1, world_size): for key in mp_dict[f"result_0"].keys(): assert torch.all(torch.eq(mp_dict[f"result_0"][key], mp_dict[f"result_{rank}"][key])) +@pytest.mark.parametrize("backend", ["ccl", "gloo"]) @pytest.mark.parametrize('world_size', [2, 4, 8]) @sar_test -def test_gather_grads(world_size): +def test_gather_grads(world_size, backend): """ Checks whether parameter's gradients are the same across all workers after calling gather_grads function """ import torch - def gather_grads(mp_dict, rank, world_size, tmp_dir): + def gather_grads(mp_dict, rank, world_size, tmp_dir, **kwargs): import sar import dgl import torch.nn.functional as F @@ -53,7 +55,7 @@ def gather_grads(mp_dict, rank, world_size, tmp_dir): from base_utils import initialize_worker, get_random_graph, synchronize_processes,\ load_partition_data try: - initialize_worker(rank, world_size, tmp_dir) + initialize_worker(rank, world_size, tmp_dir, backend=kwargs["backend"]) graph_name = 'dummy_graph' if rank == 0: g = get_random_graph() @@ -73,7 +75,7 @@ def gather_grads(mp_dict, rank, world_size, tmp_dir): mp_dict["traceback"] = str(traceback.format_exc()) mp_dict["exception"] = e - mp_dict = run_workers(gather_grads, world_size) + mp_dict = run_workers(gather_grads, world_size, backend=backend) for rank in range(1, world_size): for i in range(len(mp_dict["result_0"])): assert torch.all(torch.eq(mp_dict["result_0"][i], mp_dict[f"result_{rank}"][i])) From 80f9a753635aef95ebca525b75f956d617f39ca4 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 18 Jul 2023 14:49:41 +0200 Subject: [PATCH 51/58] Added docstrings for all_to_all functions --- sar/comm.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/sar/comm.py b/sar/comm.py index 49d928d..d3c1bf7 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -294,7 +294,8 @@ def all_to_all(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, move_to_comm_device: bool = False): # pylint: disable=invalid-name - """ wrapper around dist.all_reduce + """ + Wrapper around dist.all_reduce :param red_tensor: reduction tensor :type red_tensor: torch.Tensor @@ -302,8 +303,6 @@ def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, :type op: dist.ReduceOp :param move_to_comm_device: Move to comm device or not :type move_to_comm_device: bool - - """ if move_to_comm_device: @@ -315,6 +314,16 @@ def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor]): + """ + All_to_all wrapper which breaks down the collective call into multiple + torch.distributed.all_to_all calls so that the size of the data in each + call is below Config.max_collective_size + + :param recv_tensors: List of tensors to receive from other workers + :type recv_tensors: List[torch.Tensor] + :param send_tensors: List of tensor to send to other workers + :type send_tensors: List[torch.Tensor] + """ if Config.max_collective_size == 0: all_to_all_gloo_support(recv_tensors, send_tensors) else: @@ -334,6 +343,16 @@ def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch def all_to_all_gloo_support(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor]): + """ + Since gloo backend doesn't support all_to_all function, SAR implements it + with multiple asynchronous sends (torch.dist.isend). For every other backend + torch.dist.all_to_all is used. + + :param recv_tensors: List of tensors to receive from other workers + :type recv_tensors: List[torch.Tensor] + :param send_tensors: List of tensor to send to other workers + :type send_tensors: List[torch.Tensor] + """ if backend() == 'gloo': send_requests = [] for i in range(world_size()): From 720018663640bcdcdc73a0db60342f1502c2eab3 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 18 Jul 2023 14:59:32 +0200 Subject: [PATCH 52/58] Update docs with additional information about gloo backend --- docs/source/comm.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/comm.rst b/docs/source/comm.rst index 7266102..f6b72f2 100644 --- a/docs/source/comm.rst +++ b/docs/source/comm.rst @@ -4,7 +4,7 @@ SAR's communication routines ============================= -SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. SAR supports four backends, which are ``ccl``, ``nccl``, ``mpi`` and ``gloo``. Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs. +SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. SAR supports four backends, which are ``ccl``, ``nccl``, ``mpi`` and ``gloo``. (Note: Using ``gloo`` backend may not be as optimal as using other backends, because it doesn't support ``all_to_all`` routine - SAR must use its own implementation, which uses multiple asynchronous sends (torch.dist.isend) between workers). Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs. The ``ccl`` backend uses `Intel's OneCCL `_ library. You can install the PyTorch bindings for OneCCL `here `_ . ``ccl`` is the preferred backend when training on CPUs. From 06b55b43a1222bc303e91cd8e8091d74d92d34c3 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 17 Jul 2023 15:14:52 +0200 Subject: [PATCH 53/58] Sign example --- examples/SIGN/README.md | 18 ++ examples/SIGN/train_sign_with_sar.py | 305 +++++++++++++++++++++++++++ examples/partition_graph.py | 15 +- 3 files changed, 332 insertions(+), 6 deletions(-) create mode 100644 examples/SIGN/README.md create mode 100644 examples/SIGN/train_sign_with_sar.py diff --git a/examples/SIGN/README.md b/examples/SIGN/README.md new file mode 100644 index 0000000..2306a27 --- /dev/null +++ b/examples/SIGN/README.md @@ -0,0 +1,18 @@ +## SIGN: Scalable Inception Graph Neural Networks + +Original script: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sign + +Provided `train_sign_with_sar.py` script is an example how to intergrate SAR to preprocess graph data for training. + +### Results +Obtained results for two partitions: +- ogbn-products: 0.7832 +- reddit: 0.9639 + +### Run command: + +``` +python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 0 --world-size 2 + +python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 0 --world-size 2 +``` \ No newline at end of file diff --git a/examples/SIGN/train_sign_with_sar.py b/examples/SIGN/train_sign_with_sar.py new file mode 100644 index 0000000..c8870dd --- /dev/null +++ b/examples/SIGN/train_sign_with_sar.py @@ -0,0 +1,305 @@ +import argparse +import os +import time + +import dgl +import dgl.function as fn + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import sar + +def load_dataset(filename, rank, device): + partition_data = sar.load_dgl_partition_data(filename, rank, device) + # Obtain train,validation, and test masks + # These are stored as node features. Partitioning may prepend + # the node type to the mask names. So we use the convenience function + # suffix_key_lookup to look up the mask name while ignoring the + # arbitrary node type + masks = {} + for mask_name, indices_name in zip(["train_mask", "val_mask", "test_mask"], + ["train_indices", "val_indices", "test_indices"]): + boolean_mask = sar.suffix_key_lookup(partition_data.node_features, + mask_name) + masks[indices_name] = boolean_mask.nonzero( + as_tuple=False).view(-1).to(device) + print(partition_data.node_features.keys()) + + label_name, feature_name = ('feat', 'label') if 'reddit' in filename \ + else ('features', 'labels') + labels = sar.suffix_key_lookup(partition_data.node_features, + label_name).long().to(device) + + # Obtain the number of classes by finding the max label across all workers + n_classes = labels.max() + 1 + sar.comm.all_reduce(n_classes, torch.distributed.ReduceOp.MAX, move_to_comm_device=True) + n_classes = n_classes.item() + + features = sar.suffix_key_lookup(partition_data.node_features, feature_name).to(device) + full_graph_manager = sar.construct_full_graph(partition_data).to(device) + + full_graph_manager.ndata["feat"] = features + full_graph_manager.ndata["label"] = labels + return full_graph_manager, n_classes, \ + masks["train_indices"], masks["val_indices"], masks["test_indices"], + +class FeedForwardNet(nn.Module): + def __init__(self, in_feats, hidden, out_feats, n_layers, dropout): + super(FeedForwardNet, self).__init__() + self.layers = nn.ModuleList() + self.n_layers = n_layers + if n_layers == 1: + self.layers.append(nn.Linear(in_feats, out_feats)) + else: + self.layers.append(nn.Linear(in_feats, hidden)) + for _ in range(n_layers - 2): + self.layers.append(nn.Linear(hidden, hidden)) + self.layers.append(nn.Linear(hidden, out_feats)) + if self.n_layers > 1: + self.prelu = nn.PReLU() + self.dropout = nn.Dropout(dropout) + self.reset_parameters() + + def reset_parameters(self): + gain = nn.init.calculate_gain("relu") + for layer in self.layers: + nn.init.xavier_uniform_(layer.weight, gain=gain) + nn.init.zeros_(layer.bias) + + def forward(self, x): + for layer_id, layer in enumerate(self.layers): + x = layer(x) + if layer_id < self.n_layers - 1: + x = self.dropout(self.prelu(x)) + return x + + +class Model(nn.Module): + def __init__(self, in_feats, hidden, out_feats, R, n_layers, dropout): + super(Model, self).__init__() + self.dropout = nn.Dropout(dropout) + self.prelu = nn.PReLU() + self.inception_ffs = nn.ModuleList() + for hop in range(R + 1): + self.inception_ffs.append( + FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout) + ) + # self.linear = nn.Linear(hidden * (R + 1), out_feats) + self.project = FeedForwardNet( + (R + 1) * hidden, hidden, out_feats, n_layers, dropout + ) + + def forward(self, feats): + hidden = [] + for feat, ff in zip(feats, self.inception_ffs): + hidden.append(ff(feat)) + out = self.project(self.dropout(self.prelu(torch.cat(hidden, dim=-1)))) + return out + + +def calc_weight(g): + """ + Compute row_normalized(D^(-1/2)AD^(-1/2)) + """ + with g.local_scope(): + # compute D^(-0.5)*D(-1/2), assuming A is Identity + g.ndata["in_deg"] = g.in_degrees().float().pow(-0.5) + g.ndata["out_deg"] = g.out_degrees().float().pow(-0.5) + g.apply_edges(fn.u_mul_v("out_deg", "in_deg", "weight")) + # row-normalize weight + g.update_all(fn.copy_e("weight", "msg"), fn.sum("msg", "norm")) + g.apply_edges(fn.e_div_v("weight", "norm", "weight")) + return g.edata["weight"] + + +def preprocess(g, features, args): + """ + Pre-compute the average of n-th hop neighbors + """ + with torch.no_grad(): + g.edata["weight"] = calc_weight(g) + g.ndata["feat_0"] = features + for hop in range(1, args.R + 1): + g.update_all( + fn.u_mul_e(f"feat_{hop-1}", "weight", "msg"), + fn.sum("msg", f"feat_{hop}"), + ) + res = [] + for hop in range(args.R + 1): + res.append(g.ndata.pop(f"feat_{hop}")) + return res + + +def prepare_data(device, args): + data = load_dataset(args.partitioning_json_file, args.rank, device) + g, n_classes, train_nid, val_nid, test_nid = data + g = g.to(device) + in_feats = g.ndata["feat"].shape[1] + feats = preprocess(g, g.ndata["feat"], args) + labels = g.ndata["label"] + # move to device + train_nid = train_nid.to(device) + val_nid = val_nid.to(device) + test_nid = test_nid.to(device) + train_feats = [x[train_nid] for x in feats] + train_labels = labels[train_nid] + return ( + feats, + labels, + train_feats, + train_labels, + in_feats, + n_classes, + train_nid, + val_nid, + test_nid, + ) + +def evaluate(args, model, feats, labels, train, val, test): + with torch.no_grad(): + batch_size = args.eval_batch_size + if batch_size <= 0: + pred = model(feats) + else: + pred = [] + num_nodes = labels.shape[0] + n_batch = (num_nodes + batch_size - 1) // batch_size + for i in range(n_batch): + batch_start = i * batch_size + batch_end = min((i + 1) * batch_size, num_nodes) + batch_feats = [feat[batch_start:batch_end] for feat in feats] + pred.append(model(batch_feats)) + pred = torch.cat(pred) + + pred = torch.argmax(pred, dim=1) + correct = (pred == labels).float() + + # Sum the n_correct, and number of mask elements across all workers + results = [] + for mask in [train, val, test]: + n_correct = correct[mask].sum() + results.extend([n_correct, mask.numel()]) + + acc_vec = torch.FloatTensor(results) + # Sum the n_correct, and number of mask elements across all workers + sar.comm.all_reduce(acc_vec, op=torch.distributed.ReduceOp.SUM, move_to_comm_device=True) + (train_acc, val_acc, test_acc) = \ + (acc_vec[0] / acc_vec[1], + acc_vec[2] / acc_vec[3], + acc_vec[4] / acc_vec[5],) + + return train_acc, val_acc, test_acc + + +def main(args): + if args.gpu < 0: + device = "cpu" + else: + device = "cuda:{}".format(args.gpu) + + master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) + sar.initialize_comms(args.rank, + args.world_size, master_ip_address, + args.backend) + + data = prepare_data(device, args) + ( + feats, + labels, + train_feats, + train_labels, + in_size, + num_classes, + train_nid, + val_nid, + test_nid, + ) = data + + model = Model( + in_size, + args.num_hidden, + num_classes, + args.R, + args.ff_layer, + args.dropout, + ) + model = model.to(device) + if args.gpu == -1: + model = torch.nn.parallel.DistributedDataParallel(model) + else: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], output_device=device + ) + loss_fcn = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + sar.sync_params(model) + + best_epoch = 0 + best_val = 0 + best_test = 0 + + for epoch in range(1, args.num_epochs + 1): + with model.join(): + start = time.time() + model.train() + loss = loss_fcn(model(train_feats), train_labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if epoch % args.eval_every == 0: + model.eval() + acc = evaluate( + args, model, feats, labels, train_nid, val_nid, test_nid + ) + end = time.time() + log = "Epoch {}, Times(s): {:.4f}".format(epoch, end - start) + log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}".format( + *acc + ) + print(log) + if acc[1] > best_val: + best_val = acc[1] + best_epoch = epoch + best_test = acc[2] + + print( + "Best Epoch {}, Val {:.4f}, Test {:.4f}".format( + best_epoch, best_val, best_test + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="SIGN") + parser.add_argument("--partitioning-json-file", default="", type=str, + help="Path to the .json file containing partitioning information") + parser.add_argument("--ip-file", type=str, default="./ip_file", + help="File with ip-address. " + "Worker 0 creates this file and all others read it") + parser.add_argument("--backend", type=str, default="ccl", + choices=["ccl", "nccl", "mpi"], + help="Communication backend to use") + parser.add_argument("--rank", type=int, default=0, + help="Rank of the current worker") + parser.add_argument("--world-size", default=2, type=int, + help="Number of workers ") + parser.add_argument("--num-epochs", type=int, default=1000) + parser.add_argument("--num-hidden", type=int, default=256) + parser.add_argument("--R", type=int, default=3, help="number of hops") + parser.add_argument("--lr", type=float, default=0.003) + parser.add_argument("--dropout", type=float, default=0.5) + parser.add_argument("--gpu", type=int, default=-1) + parser.add_argument("--weight-decay", type=float, default=0) + parser.add_argument("--eval-every", type=int, default=10) + parser.add_argument("--eval-batch-size", type=int, default=250000, + help="evaluation batch size, -1 for full batch") + parser.add_argument("--ff-layer", type=int, default=2, help="number of feed-forward layers") + args = parser.parse_args() + + print(args) + main(args) diff --git a/examples/partition_graph.py b/examples/partition_graph.py index 94198c0..1697c29 100644 --- a/examples/partition_graph.py +++ b/examples/partition_graph.py @@ -32,20 +32,21 @@ "cora": CoraGraphDataset, "citeseer": CiteseerGraphDataset, "pubmed": PubmedGraphDataset, + 'reddit': RedditDataset, "ogbn-products": DglNodePropPredDataset, "ogbn-arxiv": DglNodePropPredDataset, "ogbn-mag": DglNodePropPredDataset, } -parser = ArgumentParser(description="Graph partitioning for ogbn-arxiv and ogbn-products") +parser = ArgumentParser(description="Graph partitioning for common graph datasets") parser.add_argument("--dataset-root", type=str, default="./datasets/", help="The OGB datasets folder") parser.add_argument("--dataset-name", type=str, default="ogbn-arxiv", - choices=["ogbn-arxiv", "ogbn-products", "ogbn-mag", - "cora", "citeseer", "pubmed"], - help="Dataset name. ogbn-arxiv or ogbn-products") + choices=['ogbn-arxiv', 'ogbn-products', 'ogbn-mag', + 'cora', 'citeseer', 'pubmed', 'reddit'], + help="Dataset name") parser.add_argument("--partition-out-path", type=str, default="./partition_data/", help="Path to the output directory for the partition data") @@ -57,12 +58,14 @@ def get_dataset(args): dataset_name = args.dataset_name if dataset_name in ["cora", "citeseer", "pubmed"]: return SUPPORTED_DATASETS[dataset_name](args.dataset_root) + elif dataset_name == 'reddit': + return SUPPORTED_DATASETS[dataset_name](self_loop=True, raw_dir=args.dataset_root) else: return SUPPORTED_DATASETS[dataset_name](dataset_name, args.dataset_root) def prepare_features(args, dataset, graph): - if args.dataset_name in ["cora", "citeseer", "pubmed"]: - assert all([x in graph.ndata.keys() for x in ["train_mask", "val_mask", "test_mask"]]) + if args.dataset_name in ['cora', 'citeseer', 'pubmed', 'reddit']: + assert all([x in graph.ndata.keys() for x in ['train_mask', 'val_mask', 'test_mask']]) return split_idx = dataset.get_idx_split() From 50bdbc48ef75528ca502d52cf10db5f57dc028d9 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 18 Jul 2023 15:00:50 +0200 Subject: [PATCH 54/58] stylistic changes --- tests/test_comm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_comm.py b/tests/test_comm.py index 9dbfc99..a39b58b 100644 --- a/tests/test_comm.py +++ b/tests/test_comm.py @@ -107,4 +107,4 @@ def all_to_all(mp_dict, rank, world_size, tmp_dir, **kwargs): mp_dict = run_workers(all_to_all, world_size, backend=backend) for rank in range(world_size): for tensor in mp_dict[f"result_{rank}"]: - assert torch.all(torch.eq(tensor, torch.tensor([rank]*world_size))) + assert torch.all(torch.eq(tensor, torch.tensor([rank] * world_size))) From 23f0d62de959eed615535567f8365e84f0bfb5b8 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Thu, 20 Jul 2023 11:52:07 +0200 Subject: [PATCH 55/58] Enable gloo in new examples --- examples/correct_and_smooth.py | 2 +- examples/train_dist_appnp_with_sar.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/correct_and_smooth.py b/examples/correct_and_smooth.py index 743e3e4..e7dba0b 100644 --- a/examples/correct_and_smooth.py +++ b/examples/correct_and_smooth.py @@ -23,7 +23,7 @@ parser.add_argument('--ip-file', default='./ip_file', type=str, help='File with ip-address. Worker 0 creates this file and all others read it' ) -parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'], +parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi', 'gloo'], help='Communication backend to use' ) parser.add_argument('--cpu-run', action='store_true', diff --git a/examples/train_dist_appnp_with_sar.py b/examples/train_dist_appnp_with_sar.py index 5787385..0f075ab 100644 --- a/examples/train_dist_appnp_with_sar.py +++ b/examples/train_dist_appnp_with_sar.py @@ -20,7 +20,7 @@ help="File with ip-address. Worker 0 creates this file and all others read it") parser.add_argument("--backend", type=str, default="nccl", - choices=["ccl", "nccl", "mpi"], + choices=["ccl", "nccl", "mpi", "gloo"], help="Communication backend to use") parser.add_argument("--cpu-run", action="store_true", From c450efeacb42c82b7278a4579059d3ab0fb72fe0 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 25 Jul 2023 12:29:24 +0200 Subject: [PATCH 56/58] exchange_single_tensor with isend/recv path --- sar/comm.py | 9 ++++----- tests/test_comm.py | 28 +++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/sar/comm.py b/sar/comm.py index 9b1e68f..581ee37 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -396,18 +396,17 @@ def exchange_single_tensor(recv_idx: int, send_idx: int, :type recv_tensor: Tensor :param send_tensor: Tensor to send to the remote worker :type send_tensor: Tensor - - """ - ''' - ''' logger.debug( f'{rank()} : exchange_single_tensor on device {send_tensor.device} : {recv_idx}, {send_idx},{recv_tensor.size()},{send_tensor.size()}') dtype = send_tensor.dtype if send_idx == recv_idx == rank(): recv_tensor.copy_(send_tensor) + elif backend() == 'gloo': + send_request = dist.isend(send_tensor.to(comm_device()), send_idx) + dist.recv(recv_tensor.to(comm_device()), recv_idx) + dist.barrier() else: - send_tensors_list = [torch.Tensor([1.0]).to(dtype).to(comm_device()) for _ in range(world_size())] diff --git a/tests/test_comm.py b/tests/test_comm.py index a39b58b..e465f66 100644 --- a/tests/test_comm.py +++ b/tests/test_comm.py @@ -97,7 +97,7 @@ def all_to_all(mp_dict, rank, world_size, tmp_dir, **kwargs): try: initialize_worker(rank, world_size, tmp_dir, backend=kwargs["backend"]) send_tensors_list = [torch.tensor([x] * world_size) for x in range(world_size)] - recv_tensors_list = [torch.tensor([0] * world_size) for _ in range(world_size)] + recv_tensors_list = [torch.tensor([-1] * world_size) for _ in range(world_size)] sar.comm.all_to_all(recv_tensors_list, send_tensors_list) mp_dict[f"result_{rank}"] = recv_tensors_list except Exception as e: @@ -108,3 +108,29 @@ def all_to_all(mp_dict, rank, world_size, tmp_dir, **kwargs): for rank in range(world_size): for tensor in mp_dict[f"result_{rank}"]: assert torch.all(torch.eq(tensor, torch.tensor([rank] * world_size))) + + +@pytest.mark.parametrize("backend", ["ccl", "gloo"]) +@pytest.mark.parametrize("world_size", [2, 4, 8]) +@sar_test +def test_exchange_single_tensor(world_size, backend): + def exchange_single_tensor(mp_dict, rank, world_size, tmp_dir, **kwargs): + import torch + import sar + from base_utils import initialize_worker + try: + initialize_worker(rank, world_size, tmp_dir, backend=kwargs["backend"]) + send_idx = rank + recv_idx = rank + for _ in range(world_size): + send_tensor = torch.tensor([send_idx] * world_size) + recv_tensor = torch.tensor([-1] * world_size) + sar.comm.exchange_single_tensor(recv_idx, send_idx, recv_tensor, send_tensor) + assert torch.all(torch.eq(recv_tensor, torch.tensor([rank] * world_size))) + send_idx = (send_idx + 1) % world_size + recv_idx = (recv_idx - 1) % world_size + except Exception as e: + mp_dict["traceback"] = str(traceback.format_exc()) + mp_dict["exception"] = e + + mp_dict = run_workers(exchange_single_tensor, world_size, backend=backend) From 19cc356347242dae49cb3793fc062894e277713b Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 25 Jul 2023 16:22:51 +0200 Subject: [PATCH 57/58] readme typo fix --- examples/SIGN/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/SIGN/README.md b/examples/SIGN/README.md index 2306a27..96be547 100644 --- a/examples/SIGN/README.md +++ b/examples/SIGN/README.md @@ -14,5 +14,5 @@ Obtained results for two partitions: ``` python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 0 --world-size 2 -python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 0 --world-size 2 +python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 1 --world-size 2 ``` \ No newline at end of file From a04953bf3f0e803239ec70e8ca948d45994ee174 Mon Sep 17 00:00:00 2001 From: "Pietkun, Kacper" Date: Tue, 25 Jul 2023 16:23:38 +0200 Subject: [PATCH 58/58] Using gather grads instead of DistributedDataParallel --- examples/SIGN/train_sign_with_sar.py | 57 ++++++++++++---------------- examples/partition_graph.py | 1 + 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/examples/SIGN/train_sign_with_sar.py b/examples/SIGN/train_sign_with_sar.py index c8870dd..a65383d 100644 --- a/examples/SIGN/train_sign_with_sar.py +++ b/examples/SIGN/train_sign_with_sar.py @@ -27,7 +27,7 @@ def load_dataset(filename, rank, device): as_tuple=False).view(-1).to(device) print(partition_data.node_features.keys()) - label_name, feature_name = ('feat', 'label') if 'reddit' in filename \ + feature_name, label_name = ('feat', 'label') if 'reddit' in filename \ else ('features', 'labels') labels = sar.suffix_key_lookup(partition_data.node_features, label_name).long().to(device) @@ -224,14 +224,7 @@ def main(args): args.R, args.ff_layer, args.dropout, - ) - model = model.to(device) - if args.gpu == -1: - model = torch.nn.parallel.DistributedDataParallel(model) - else: - model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[device], output_device=device - ) + ).to(device) loss_fcn = nn.CrossEntropyLoss() optimizer = torch.optim.Adam( model.parameters(), lr=args.lr, weight_decay=args.weight_decay @@ -243,29 +236,29 @@ def main(args): best_test = 0 for epoch in range(1, args.num_epochs + 1): - with model.join(): - start = time.time() - model.train() - loss = loss_fcn(model(train_feats), train_labels) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if epoch % args.eval_every == 0: - model.eval() - acc = evaluate( - args, model, feats, labels, train_nid, val_nid, test_nid - ) - end = time.time() - log = "Epoch {}, Times(s): {:.4f}".format(epoch, end - start) - log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}".format( - *acc - ) - print(log) - if acc[1] > best_val: - best_val = acc[1] - best_epoch = epoch - best_test = acc[2] + start = time.time() + model.train() + loss = loss_fcn(model(train_feats), train_labels) + optimizer.zero_grad() + loss.backward() + sar.gather_grads(model) + optimizer.step() + + if epoch % args.eval_every == 0: + model.eval() + acc = evaluate( + args, model, feats, labels, train_nid, val_nid, test_nid + ) + end = time.time() + log = "Epoch {}, Times(s): {:.4f}".format(epoch, end - start) + log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}".format( + *acc + ) + print(log) + if acc[1] > best_val: + best_val = acc[1] + best_epoch = epoch + best_test = acc[2] print( "Best Epoch {}, Val {:.4f}, Test {:.4f}".format( diff --git a/examples/partition_graph.py b/examples/partition_graph.py index 1697c29..32e0f01 100644 --- a/examples/partition_graph.py +++ b/examples/partition_graph.py @@ -26,6 +26,7 @@ CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset, + RedditDataset ) SUPPORTED_DATASETS = {