Merge branch 'pr-staging-branch' into single-node-merge-branches

seanmcpherson authored Aug 8, 2023
2 parents 503579f + a190ce4 commit 5716356
Showing 37 changed files with 2,664 additions and 345 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/sar_test.yaml
@@ -0,0 +1,32 @@
name: SAR tests

on:
  pull_request:
    branches: [main]
  workflow_dispatch:

jobs:
  sar_tests:
    runs-on: ubuntu-latest
    steps:
      - name: Pull SAR
        uses: actions/checkout@v3
        with:
          fetch-depth: 0

      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
          python -m pip install pytest
          python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu
          python setup.py install
      - name: Run pytest
        run: |
          set +e
          python -m pytest tests/ -sv
55 changes: 55 additions & 0 deletions .gitignore
@@ -0,0 +1,55 @@
# Python cache
__pycache__/
*.pyc

# Jupyter notebook checkpoints
.ipynb_checkpoints/

# Compiled Python files
*.pyc
*.pyo
*.pyd
__pycache__/

# Build directories
build/
dist/
*.egg-info/


# Package distribution
*.egg
*.egg-info

# IDE and editor files
.vscode/
.idea/
*.iml
*.iws
*.ipr

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Datasets and partitions
dataset/
datasets/
partition_data/
12 changes: 7 additions & 5 deletions docs/source/comm.rst
@@ -4,7 +4,7 @@

SAR's communication routines
=============================
SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. Currently, the only backends in `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ that support ``all_to_all`` are ``nccl``, ``ccl``, or ``mpi``. Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs.
SAR uses only two types of collective communication calls: ``all_to_all`` and ``all_reduce``. This choice was made to improve scalability by avoiding any point-to-point communication. SAR supports four backends: ``ccl``, ``nccl``, ``mpi``, and ``gloo``. (Note: the ``gloo`` backend may not perform as well as the others because it does not support the ``all_to_all`` routine; SAR must fall back on its own implementation, which issues multiple asynchronous sends (``torch.distributed.isend``) between workers.) Nvidia's ``nccl`` is already included in the PyTorch distribution and it is the natural choice when training on GPUs.
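
For illustration only, a rough sketch (not SAR's actual code) of how ``all_to_all`` can be emulated with asynchronous point-to-point calls when the backend lacks it; tensor shapes are assumed to be agreed upon beforehand::

    import torch.distributed as dist

    def all_to_all_fallback(recv_tensors, send_tensors, rank):
        # Local copy for our own rank, point-to-point transfers for everyone else
        recv_tensors[rank].copy_(send_tensors[rank])
        requests = []
        for peer in range(len(send_tensors)):
            if peer == rank:
                continue
            requests.append(dist.isend(send_tensors[peer], dst=peer))
            requests.append(dist.irecv(recv_tensors[peer], src=peer))
        for req in requests:
            req.wait()

..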

The ``ccl`` backend uses `Intel's OneCCL <https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html>`_ library. You can install the PyTorch bindings for OneCCL `here <https://github.com/intel/torch-ccl>`_ . ``ccl`` is the preferred backend when training on CPUs.

@@ -16,14 +16,16 @@ In an environment with a networked file system, initializing ``torch.distributed
    comm_device = torch.device('cuda')
else:
    comm_device = torch.device('cpu')
master_ip_address = sar.nfs_ip_init(rank,path_to_ip_file)
sar.initialize_comms(rank,world_size, master_ip_address,backend_name,comm_device)

..
master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file)
sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device)

..
:func:`sar.initialize_comms` tries to initialize the ``torch.distributed`` process group, but only if it has not already been initialized. Users can also initialize the process group themselves before calling :func:`sar.initialize_comms`.
:func:`sar.nfs_ip_init` communicates the master's IP address to the workers through the file system. In the absence of a networked file system, you should develop your own mechanism to communicate the master's IP address.

You can specify the name of the socket that will be used for communication with the ``SAR_SOCKET_NAME`` environment variable (if not specified, the first available socket will be selected).
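
For example, a minimal sketch (``'eth0'`` is a placeholder interface name; adapt it to your cluster)::

    import os
    os.environ['SAR_SOCKET_NAME'] = 'eth0'  # optional; otherwise the first available socket is used

    master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file)
    sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device)

..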



Relevant methods
61 changes: 49 additions & 12 deletions docs/source/data_loading.rst
@@ -3,15 +3,15 @@

Data loading and graph construction
==========================================================
After partitioning the graph using DGL's `partition_graph <https://docs.dgl.ai/en/0.6.x/generated/dgl.distributed.partition.partition_graph.html>`_ function, SAR can load the graph data using :func:`sar.load_dgl_partition_data`. This yields a :class:`sar.common_tuples.PartitionData` object. The ``PartitionData`` object can then be used to construct various types of graph-like objects that can be passed to GNN models. You can construct graph objects for distributed full-batch training or for distributed sampling-based training as follows:

.. contents:: :local:
:depth: 3


Full-batch training
---------------------------------------------------------------------------------------

Constructing the full graph for sequential aggregation and rematerialization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Construct a single distributed graph object of type :class:`sar.core.GraphShardManager`::
@@ -20,7 +20,7 @@ Construct a single distributed graph object of type :class:`sar.core.GraphShardM

..
The ``GraphShardManager`` object encapsulates N DGL graph objects (where N is the number of workers). Each graph object represents the edges incoming from one partition (including the local partition). ``GraphShardManager`` implements the ``update_all`` and ``apply_edges`` methods in addition to several other methods from the standard ``dgl.heterograph.DGLHeterograph`` API. The ``update_all`` and ``apply_edges`` methods implement the sequential aggregation and rematerialization scheme to realize the distributed forward and backward passes. ``GraphShardManager`` can usually be passed to GNN layers instead of ``dgl.heterograph.DGLHeterograph``. See the :ref:`the distributed graph limitations section<shard-limitations>` for some exceptions.
The ``GraphShardManager`` object encapsulates N DGL graph objects (where N is the number of workers). Each graph object represents the edges incoming from one partition (including the local partition). ``GraphShardManager`` implements the ``update_all`` and ``apply_edges`` methods in addition to several other methods from the standard ``dgl.heterograph.DGLGraph`` API. The ``update_all`` and ``apply_edges`` methods implement the sequential aggregation and rematerialization scheme to realize the distributed forward and backward passes. ``GraphShardManager`` can usually be passed to GNN layers instead of ``dgl.heterograph.DGLGraph``. See the :ref:`the distributed graph limitations section<shard-limitations>` for some exceptions.
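
As an illustrative sketch (not from the original docs), the distributed graph can be driven either through a standard DGL layer or through the message-passing API directly; here ``shard_manager`` is assumed to come from :func:`sar.construct_full_graph` and ``features`` to hold the local partition's node features::

    gnn_layer = dgl.nn.SAGEConv(128, 128, aggregator_type='mean')
    h = gnn_layer(shard_manager, features)  # SAR runs the distributed forward/backward pass

    with shard_manager.local_scope():
        shard_manager.srcdata['h'] = features
        shard_manager.update_all(dgl.function.copy_u('h', 'm'),
                                 dgl.function.mean('m', 'h_mean'))
        h_mean = shard_manager.dstdata['h_mean']

..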

Constructing Message Flow Graphs (MFGs) for sequential aggregation and rematerialization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -86,7 +86,7 @@ For sampling-based training, use the dataloader provided by SAR: :func:`sar.Data

::

    shard_manager = sar.construct_full_graph(partition_data)

    neighbor_sampler = sar.DistNeighborSampler(
        [15, 10, 5],  # Fanout for every layer
@@ -103,11 +103,48 @@ For sampling-based training, use the dataloader provided by SAR: :func:`sar.Data
    for blocks in dataloader:
        output = gnn_model(blocks)
        ...
..

..

Full-graph inference
---------------------------------------------------------------------------------------
SAR can also be used solely for model evaluation. When performing mini-batch distributed training with the DGL package, it is often preferable to evaluate the model on the entire graph. To accomplish this, SAR can turn a `DistGraph <https://docs.dgl.ai/api/python/dgl.distributed.html#dgl.distributed.DistGraph>`_ object into a ``GraphShardManager`` object, allowing for distributed full-graph inference. The procedure is simple: no further steps are required because the model parameters are already synchronized during inference. You can use :func:`sar.convert_dist_graph` in the following way to perform full-graph inference:
::

    class GNNModel(nn.Module):
        def __init__(self, n_layers: int):
            super().__init__()
            self.convs = nn.ModuleList([
                dgl.nn.SAGEConv(100, 100, aggregator_type='mean')  # aggregator_type is required by SAGEConv
                for _ in range(n_layers)
            ])

        # forward function prepared for mini-batch training
        def forward(self, blocks: List[DGLBlock], features: torch.Tensor):
            h = features
            for layer, block in zip(self.convs, blocks):
                h = layer(block, h)
            return h

        # inference function for full-graph input
        def full_graph_inference(self, graph: sar.GraphShardManager, features: torch.Tensor):
            h = features
            for layer in self.convs:
                h = layer(graph, h)
            return h

    # model wrapped in PyTorch DistributedDataParallel
    gnn_model = torch.nn.parallel.DistributedDataParallel(GNNModel(3))

    # Convert DistGraph into GraphShardManager
    gsm = sar.convert_dist_graph(g).to(device)

    # Access the model through DistributedDataParallel's module field
    model_out = gnn_model.module.full_graph_inference(gsm, local_node_features)
..

Relevant methods
---------------------------------------------------------------------------------------

@@ -117,11 +154,11 @@ Relevant methods
.. autosummary::
    :toctree: Data loading and graph construction
    :template: distneighborsampler

    load_dgl_partition_data
    construct_full_graph
    construct_mfgs
    convert_dist_graph
    DataLoader
    DistNeighborSampler

12 changes: 6 additions & 6 deletions docs/source/quick_start.rst
@@ -11,7 +11,7 @@ Follow the following steps to enable distributed training in your DGL code:

Partition the graph
----------------------------------
Partition the graph using DGL's `partition_graph <https://docs.dgl.ai/en/0.6.x/generated/dgl.distributed.partition.partition_graph.html>`_ function. See `here <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/dist/partition_graph.py>`_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` and ``reshuffle = True`` in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``.
Partition the graph using DGL's `partition_graph <https://docs.dgl.ai/en/0.6.x/generated/dgl.distributed.partition.partition_graph.html>`_ function. See `here <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/dist/partition_graph.py>`_ for an example. The number of partitions should be the same as the number of training machines/workers that will be used. SAR requires consecutive node indices in each partition, and requires that the partition information include the one-hop neighborhoods of all nodes in the partition. Setting ``num_hops = 1`` and ``reshuffle = True`` (in DGL < 1.0) in the call to ``partition_graph`` takes care of these requirements. ``partition_graph`` yields a directory structure with the partition information and a .json file ``graph_name.json``.


An example of partitioning the ogbn-arxiv graph in two parts: ::
@@ -44,8 +44,8 @@ An example of partitioning the ogbn-arxiv graph in two parts: ::
graph.ndata[name] = val

dgl.distributed.partition_graph(
    graph, 'arxiv', 2, './test_partition_data/', num_hops=1, reshuffle=True)
    graph, 'arxiv', 2, './test_partition_data/', num_hops=1)  # use reshuffle=True in DGL < 1.0

..
Note that we add the labels and the train/test/validation masks as node features so that they get split into multiple parts alongside the graph.
@@ -61,12 +61,12 @@ Initialize the communication through a call to :func:`sar.initialize_comms` , sp
    comm_device = torch.device('cuda')
else:
    comm_device = torch.device('cpu')
master_ip_address = sar.nfs_ip_init(rank,path_to_ip_file)
sar.initialize_comms(rank,world_size, master_ip_address,backend_name,comm_device)
master_ip_address = sar.nfs_ip_init(rank, path_to_ip_file)
sar.initialize_comms(rank, world_size, master_ip_address, backend_name, comm_device)
..
``backend_name`` can be ``nccl``, ``ccl``, or ``mpi``.
``backend_name`` can be ``ccl``, ``nccl``, ``mpi`` or ``gloo``.



2 changes: 1 addition & 1 deletion docs/source/shards.rst
@@ -14,7 +14,7 @@ In the distributed implementation of the sequential backward pass in ``update_a

Limitations of the distributed graph objects
------------------------------------------------------------------------------------
Keep in mind that the distributed graph class :class:`sar.core.GraphShardManager` does not implement all the functionality of DGL's native graph class. For example, it does not implement the ``successors`` and ``predecessors`` methods. It supports primarily the methods of DGL's native graphs that are relevant to GNNs such as ``update_all``, ``apply_edges``, and ``local_scope``. It also supports setting graph node and edge features through the dictionaries ``srcdata``, ``dstdata``, and ``edata``. Note that :class:`sar.core.GraphShardManager` does not support the ``ndata`` member dictionary.
Keep in mind that the distributed graph class :class:`sar.core.GraphShardManager` does not implement all the functionality of DGL's native graph class. For example, it does not implement the ``successors`` and ``predecessors`` methods. It supports primarily the methods of DGL's native graphs that are relevant to GNNs such as ``update_all``, ``apply_edges``, and ``local_scope``. It also supports setting graph node and edge features through the dictionaries ``srcdata``, ``dstdata``, and ``edata``. To remain compatible with DGLGraph, :class:`sar.core.GraphShardManager` also provides access to the ``ndata`` member, which works as an alias for ``srcdata``; however, it is not accessible when working with MFGs.

:class:`sar.core.GraphShardManager` also supports the ``in_degrees`` and ``out_degrees`` members and supports querying the number of nodes and edges in the graph.
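
As an illustrative sketch (method names are assumed to mirror the DGLGraph API), a ``GraphShardManager`` named ``gsm`` could be used as follows::

    gsm.ndata['feat'] = features      # alias for gsm.srcdata['feat'] (not available when working with MFGs)
    in_deg = gsm.in_degrees()         # in-degrees of the nodes owned by this partition
    print(gsm.num_nodes(), gsm.num_edges())

..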

47 changes: 45 additions & 2 deletions examples/README.md
@@ -1,8 +1,11 @@
## Graph partitioning

``partition_arxiv_products.py`` partitions the ogbn-arxiv and ogbn-products graphs from the [Open Graph Benchmarks](https://ogb.stanford.edu/) using DGL's metis-based partitioning. The general technique there can be used to partition arbitrary homogeneous graphs. Note that all node-related information must be included in the graph's ``ndata`` dictionary so that they are correctly partitioned with the graph. Similarly, edge-related information must be included in the graph's ``edata`` dictionary
The ``partition_graph.py`` script can be used to partition both homogeneous and heterogeneous graphs. It utilizes DGL's metis-based partitioning algorithm to divide the graphs into smaller partitions. Note that all node-related information must be included in the graph's ``ndata`` dictionary so that they are correctly partitioned with the graph.
Similarly, edge-related information must be included in the graph's ``edata`` dictionary.

``partition_mag.py`` partitions the [ogbn-mag](https://ogb.stanford.edu/docs/nodeprop/#ogbn-mag) heterogeneous graph. Again, all node-related information are included in the graph's ``ndata`` for the relevant node types
### Supported datasets:
- ogbn-products, ogbn-arxiv, ogb-mag from [Open Graph Benchmarks](https://ogb.stanford.edu/)
- cora, citeseer, pubmed

## Full-batch Training

@@ -29,3 +32,43 @@ python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/pa
python3 train_homogeneous_sampling_basic.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2

```

## Distributed Mini-Batch Training with Full-Graph inference
The script ``train_distdgl_with_sar_inference.py`` showcases how SAR can be effectively combined with native DGL distributed training. In this particular example, the training process utilizes a sampling approach, while the evaluation phase leverages the SAR library to perform computations on the entire graph.
```shell
python /home/ubuntu/workspace/dgl/tools/launch.py \
--workspace /home/ubuntu/workspace/SAR/examples \
--num_trainers 1 \
--num_samplers 2 \
--num_servers 1 \
--part_config partition_data/ogbn-products.json \
--ip_config ip_config.txt \
"/home/ubuntu/miniconda3/bin/python train_distdgl_with_sar_inference.py --graph_name ogbn-products --ip_config ip_config.txt --num_epochs 2 --batch_size 1000 --part_config partition_data/ogbn-products.json"
```

## Correct and Smooth
Example taken from the [DGL implementation](https://github.com/dmlc/dgl/tree/master/examples/pytorch/correct_and_smooth) of C&S. The code is adjusted to perform distributed training with SAR. The introduced modifications change the way data normalization is performed: workers need to communicate with each other to calculate the mean and standard deviation for the entire dataset (not just their own partition). Moreover, workers need to be synchronized with each other to calculate the sigma value required during the "correct" phase.
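
A minimal sketch of the idea behind the global normalization (not the example's actual code); `features` is assumed to hold this worker's partition of the node features:

```python
import torch
import torch.distributed as dist

def global_mean_std(features: torch.Tensor):
    # Accumulate sum, sum of squares, and element count across all workers
    total = features.sum(dim=0)
    total_sq = (features ** 2).sum(dim=0)
    count = torch.tensor([features.shape[0]], dtype=features.dtype)
    for t in (total, total_sq, count):
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    mean = total / count
    std = (total_sq / count - mean ** 2).clamp_min(1e-12).sqrt()
    return mean, std
```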

For instance, you can run the example with following commands (2 machines scenario):

* **Plain MLP + C&S**
* Rank 0 machine:
```shell
python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale
```

* Rank 1 machine:
```shell
python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --dropout 0.5 --correction-adj DA --smoothing-adj AD --autoscale
```

* **Plain Linear + C&S**
* Rank 0 machine:
```shell
python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 0 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale
```

* Rank 1 machine:
```shell
python correct_and_smooth.py --partitioning-json-file /path/to/partitioning/graph_name.json --ip-file /path/to/ip_file --rank 1 --world-size 2 --model linear --dropout 0.5 --epochs 1000 --correction-alpha 0.87 --smoothing-alpha 0.81 --correction-adj AD --autoscale
```
18 changes: 18 additions & 0 deletions examples/SIGN/README.md
@@ -0,0 +1,18 @@
## SIGN: Scalable Inception Graph Neural Networks

Original script: https://github.com/dmlc/dgl/tree/master/examples/pytorch/sign

The provided `train_sign_with_sar.py` script is an example of how to integrate SAR to preprocess graph data for training.

### Results
Obtained results for two partitions:
- ogbn-products: 0.7832
- reddit: 0.9639

### Run command:

```
python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 0 --world-size 2
python train_sign_with_sar.py --partitioning-json-file partition_data/reddit.json --ip-file ip_file --backend ccl --rank 1 --world-size 2
```