Commit
Merge pull request IntelLabs#3 from Kacper-Pietkun/partition_data_manager_fix

Partition Data Manager fix
seanmcpherson authored Aug 22, 2023
2 parents 708612f + 4b39994 commit 355e27b
Showing 4 changed files with 162 additions and 153 deletions.
145 changes: 7 additions & 138 deletions examples/train_homogeneous_graph_basic_single-node.py
@@ -105,116 +105,6 @@ def forward(self, graph: sar.GraphShardManager, features: torch.Tensor):

        return features

class PartitionDataManager:
    def __init__(self, idx, folder_name_prefix="partition_rank_"):
        self.idx = idx
        self.folder_name = f"{folder_name_prefix}{idx}"
        self._features = None
        self._masks = None
        self._labels = None
        self._partition_data = None

    @property
    def partition_data(self):
        if self._partition_data is None:
            raise ValueError("partition_data not set")
        return self._partition_data

    @partition_data.setter
    def partition_data(self, data):
        self._partition_data = data

    @partition_data.deleter
    def partition_data(self):
        del self._partition_data

    @property
    def features(self):
        #if self._features is None:
        #    raise ValueError("features not set")
        return self._features

    @features.setter
    def features(self, feats):
        self._features = feats

    @features.deleter
    def features(self):
        del self._features
        self._features = None

    @property
    def labels(self):
        if self._labels is None:
            raise ValueError("labels not set")
        return self._labels

    @labels.setter
    def labels(self, labels):
        self._labels = labels

    @property
    def masks(self):
        if self._masks is None:
            raise ValueError("masks not set")
        return self._masks

    @masks.setter
    def masks(self, masks):
        self._masks = masks

    def save(self):
        if not os.path.exists(self.folder_name):
            os.makedirs(self.folder_name)
        #if self._features is not None:
        #    torch.save(self._features, os.path.join(self.folder_name, "features.pt"))
        if self._masks is not None:
            torch.save(self._masks, os.path.join(self.folder_name, "masks.pt"))
        if self._labels is not None:
            torch.save(self._labels, os.path.join(self.folder_name, "labels.pt"))
        if self._partition_data is not None:
            torch.save(self._partition_data, os.path.join(self.folder_name, "partition_data.pt"))

    def save_tensor(self, tensor, tensor_name):
        if not os.path.exists(self.folder_name):
            os.makedirs(self.folder_name)
        if tensor is not None:
            torch.save(tensor, os.path.join(self.folder_name, tensor_name + ".pt"))

    def delete(self):
        #del self._features
        del self._masks
        del self._labels
        del self._partition_data

    def load(self):
        if not os.path.exists(self.folder_name):
            raise FileNotFoundError("No partition data saved")
        #if os.path.exists(os.path.join(self.folder_name, "features.pt")):
        #    self._features = torch.load(os.path.join(self.folder_name, "features.pt"))
        #else:
        #    print("features not loaded, no file saved")
        if os.path.exists(os.path.join(self.folder_name, "masks.pt")):
            self._masks = torch.load(os.path.join(self.folder_name, "masks.pt"))
        else:
            print("masks not loaded, no file saved")
        if os.path.exists(os.path.join(self.folder_name, "labels.pt")):
            self._labels = torch.load(os.path.join(self.folder_name, "labels.pt"))
        else:
            print("labels not loaded, no file saved")
        if os.path.exists(os.path.join(self.folder_name, "partition_data.pt")):
            self._partition_data = torch.load(os.path.join(self.folder_name, "partition_data.pt"))
        else:
            print("partition_data not loaded, no file saved")

    def load_tensor(self, tensor_name):
        if not os.path.exists(self.folder_name):
            raise FileNotFoundError("No partition data saved")
        if os.path.exists(os.path.join(self.folder_name, tensor_name + ".pt")):
            return torch.load(os.path.join(self.folder_name, tensor_name + ".pt"))
        else:
            return None


def main():
    args = parser.parse_args()
@@ -251,8 +141,11 @@ def run(args, rank, lock, barrier):
print("Node {} Lock Acquired".format(rank))

sar.start_comm_thread()

# Instantiate PartitionDataManager for managing data saving and loading from disk
partition_data_manager = sar.PartitionDataManager(rank, lock)

# Load DGL partition data
partition_data_manager = PartitionDataManager(rank)
partition_data_manager.partition_data = sar.load_dgl_partition_data(
args.partitioning_json_file, rank, device)

@@ -275,49 +168,27 @@

    # Obtain the number of classes by finding the max label across all workers
    num_labels = partition_data_manager.labels.max() + 1

    def precall_func():
        partition_data_manager.save()
        if isinstance(partition_data_manager.features, torch.Tensor):
            partition_data_manager.save_tensor(partition_data_manager.features, 'features')
            del partition_data_manager.features
        partition_data_manager.delete()
        lock.release()
        print("Node {} Lock Released".format(rank))

    def callback_func(handle):
        handle.wait()
        lock.acquire()
        print("Node {} Lock Acquired".format(rank))
        partition_data_manager.load()
        features = partition_data_manager.load_tensor('features')
        if features is not None:
            partition_data_manager.features = features


    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True,
                        precall_func=precall_func, callback_func=callback_func)
                        precall_func=partition_data_manager.precall_func, callback_func=partition_data_manager.callback_func)
    print("Node {} Num Labels {}".format(rank, num_labels))

    num_labels = num_labels.item()

    partition_data_manager.features = sar.suffix_key_lookup(partition_data_manager.partition_data.node_features, 'features').to(device)


    gnn_model = GNNModel(partition_data_manager.features.size(1),
                         args.hidden_layer_dim,
                         num_labels).to(device)
    print('model', gnn_model)

    # Synchronize the model parameters across all workers
    sar.sync_params(gnn_model, precall_func=precall_func, callback_func=callback_func)
    sar.sync_params(gnn_model, precall_func=partition_data_manager.precall_func, callback_func=partition_data_manager.callback_func)

    # Obtain the number of labeled nodes in the training set.
    # This will be needed to properly obtain a cross entropy loss
    # normalized by the number of training examples
    n_train_points = torch.LongTensor([partition_data_manager.masks['train_indices'].numel()])
    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True,
                        precall_func=precall_func, callback_func=callback_func)
                        precall_func=partition_data_manager.precall_func, callback_func=partition_data_manager.callback_func)
    n_train_points = n_train_points.item()

    full_graph_manager = sar.construct_full_graph(partition_data_manager.partition_data,
Expand All @@ -327,8 +198,6 @@ def callback_func(handle):
    del partition_data_manager.partition_data
    partition_data_manager.partition_data = None

    #full_graph_manager.partition_data_manager = partition_data_manager

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr)
    for train_iter_idx in range(args.train_iters):
        logger.debug(f'{rank} : starting training iteration {train_iter_idx}')
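With this change the example script no longer defines its own precall_func/callback_func closures; it passes the bound methods of sar.PartitionDataManager to every collective call. A minimal sketch of the resulting pattern, assuming rank, lock, args, and device are set up as in the example above and that the manager's labels have already been populated:

    import torch.distributed as dist
    import sar

    # One manager per worker; the shared lock ensures only one partition's
    # data is resident in memory at any given time
    partition_data_manager = sar.PartitionDataManager(rank, lock)
    partition_data_manager.partition_data = sar.load_dgl_partition_data(
        args.partitioning_json_file, rank, device)

    # precall_func saves state to disk and releases the lock before the
    # collective; callback_func waits on the handle, reacquires the lock,
    # and loads the state back once the collective completes
    num_labels = partition_data_manager.labels.max() + 1
    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True,
                        precall_func=partition_data_manager.precall_func,
                        callback_func=partition_data_manager.callback_func)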
2 changes: 1 addition & 1 deletion sar/__init__.py
@@ -29,7 +29,7 @@
from .core import GraphShardManager, message_has_parameters, DistributedBlock,\
    DistNeighborSampler, DataLoader, start_comm_thread
from .construct_shard_manager import construct_mfgs, construct_full_graph, convert_dist_graph
from .data_loading import load_dgl_partition_data, suffix_key_lookup
from .data_loading import load_dgl_partition_data, suffix_key_lookup, PartitionDataManager
from .distributed_bn import DistributedBN1D
from .config import Config
from .edge_softmax import edge_softmax
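With the added export, user scripts can reach the manager through the package root rather than importing from sar.data_loading directly; rank and lock here stand in for the caller's values:

    import sar

    # PartitionDataManager is now re-exported at the package level
    partition_data_manager = sar.PartitionDataManager(rank, lock)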
22 changes: 8 additions & 14 deletions sar/core/graphshard.py
@@ -350,27 +350,21 @@ def _pause_process(self):
        self.metric_dict['pre-pause'].append(p.memory_full_info().uss)
        self.partition_data_manager.save()
        self.partition_data_manager.delete()
        try:
            #TODO fix this so the first exception doesn't stop saving everything else.
            #only delete if successfully saved
            self.partition_data_manager.save_tensor(self.dstdata, 'dstdata')

        if hasattr(self, 'dstdata') and self.partition_data_manager.save_tensor(self.dstdata, 'dstdata'):
            del self.dstdata
            self.partition_data_manager.save_tensor(self.srcdata, 'srcdata')
        if hasattr(self, 'srcdata') and self.partition_data_manager.save_tensor(self.srcdata, 'srcdata'):
            del self.srcdata
            self.partition_data_manager.save_tensor(self.edata, 'edata')
        if hasattr(self, 'edata') and self.partition_data_manager.save_tensor(self.edata, 'edata'):
            del self.edata
            self.partition_data_manager.save_tensor(self.graph_shards, 'graph_shards')
        if hasattr(self, 'graph_shards') and self.partition_data_manager.save_tensor(self.graph_shards, 'graph_shards'):
            del self.graph_shards
            self.partition_data_manager.save_tensor(self.input_nodes, 'input_nodes')
        if hasattr(self, 'input_nodes') and self.partition_data_manager.save_tensor(self.input_nodes, 'input_nodes'):
            del self.input_nodes
            self.partition_data_manager.save_tensor(self.seeds, 'seeds')
        if hasattr(self, 'seeds') and self.partition_data_manager.save_tensor(self.seeds, 'seeds'):
            del self.seeds
            self.partition_data_manager.save_tensor(self.indices_required_from_me, 'indicies_required_from_me')
        if hasattr(self, 'indices_required_from_me') and self.partition_data_manager.save_tensor(self.indices_required_from_me, 'indicies_required_from_me'):
            del self.indices_required_from_me
        except Exception as e:
            logger.debug("_pause_process Exception: {}".format(e))



        for idx, tens in enumerate(self.pointer_list):
            #if idx >=4:
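The rewritten _pause_process treats save_tensor's result as a success flag: an attribute is deleted only if it exists and its tensor actually reached disk, so a failed save no longer loses data. A condensed sketch of that guard, assuming save_tensor returns True on success (the loop over attribute names is illustrative, not the code above):

    # Drop each attribute only after its tensor is safely on disk
    for name in ('dstdata', 'srcdata', 'edata', 'graph_shards',
                 'input_nodes', 'seeds'):
        if hasattr(self, name) and self.partition_data_manager.save_tensor(
                getattr(self, name), name):
            delattr(self, name)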
146 changes: 146 additions & 0 deletions sar/data_loading.py
@@ -18,12 +18,158 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
from typing import List, Tuple, Dict, Optional
import torch
from torch import Tensor
import dgl # type: ignore
from dgl.distributed.partition import load_partition # type: ignore
from .common_tuples import PartitionData, ShardEdgesAndFeatures
import logging


logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
logger.setLevel(logging.DEBUG)


class PartitionDataManager:
    """
    Manages saving partition data to disk and loading it back when SAR is
    running in single-node mode.

    :param idx: rank of the worker
    :type idx: int
    :param lock: lock used to synchronize processes, ensuring that only one partition is loaded at a time
    :type lock: Lock
    :param folder_name_prefix: prefix of the name of the folder where the data will be saved on disk
    :type folder_name_prefix: string
    """
    def __init__(self, idx, lock, folder_name_prefix="partition_rank_"):
        self.idx = idx
        self.lock = lock
        self.folder_name = f"{folder_name_prefix}{idx}"
        self._features = None
        self._masks = None
        self._labels = None
        self._partition_data = None

    @property
    def partition_data(self):
        if self._partition_data is None:
            raise ValueError("partition_data not set")
        return self._partition_data

    @partition_data.setter
    def partition_data(self, data):
        self._partition_data = data

    @partition_data.deleter
    def partition_data(self):
        del self._partition_data

    @property
    def features(self):
        # if self._features is None:
        #     raise ValueError("features not set")
        return self._features

    @features.setter
    def features(self, feats):
        self._features = feats

    @features.deleter
    def features(self):
        del self._features
        self._features = None

    @property
    def labels(self):
        if self._labels is None:
            raise ValueError("labels not set")
        return self._labels

    @labels.setter
    def labels(self, labels):
        self._labels = labels

    @property
    def masks(self):
        if self._masks is None:
            raise ValueError("masks not set")
        return self._masks

    @masks.setter
    def masks(self, masks):
        self._masks = masks

    def precall_func(self):
        self.save()
        if isinstance(self.features, torch.Tensor):
            self.save_tensor(self.features, 'features')
            del self.features
        self.delete()
        self.lock.release()
        print("Node {} Lock Released".format(self.idx))

    def callback_func(self, handle):
        handle.wait()
        self.lock.acquire()
        print("Node {} Lock Acquired".format(self.idx))
        self.load()
        features = self.load_tensor('features')
        if features is not None:
            self.features = features

    def save(self):
        try:
            if not os.path.exists(self.folder_name):
                os.makedirs(self.folder_name)
            if self._masks is not None:
                torch.save(self._masks, os.path.join(self.folder_name, "masks.pt"))
            if self._labels is not None:
                torch.save(self._labels, os.path.join(self.folder_name, "labels.pt"))
            if self._partition_data is not None:
                torch.save(self._partition_data, os.path.join(self.folder_name, "partition_data.pt"))
        except Exception as e:
            logger.debug("PartitionDataManager.save Exception: {}".format(e))
            return False
        return True

    def save_tensor(self, tensor, tensor_name):
        # Return a success flag so callers (e.g. _pause_process) can decide
        # whether it is safe to delete the in-memory copy
        try:
            if not os.path.exists(self.folder_name):
                os.makedirs(self.folder_name)
            if tensor is not None:
                torch.save(tensor, os.path.join(self.folder_name, tensor_name + ".pt"))
        except Exception as e:
            logger.debug("PartitionDataManager.save_tensor Exception: {}".format(e))
            return False
        return True

    def delete(self):
        del self._masks
        del self._labels
        del self._partition_data

    def load(self):
        if not os.path.exists(self.folder_name):
            raise FileNotFoundError("No partition data saved")
        if os.path.exists(os.path.join(self.folder_name, "masks.pt")):
            self._masks = torch.load(os.path.join(self.folder_name, "masks.pt"))
        else:
            print("masks not loaded, no file saved")
        if os.path.exists(os.path.join(self.folder_name, "labels.pt")):
            self._labels = torch.load(os.path.join(self.folder_name, "labels.pt"))
        else:
            print("labels not loaded, no file saved")
        if os.path.exists(os.path.join(self.folder_name, "partition_data.pt")):
            self._partition_data = torch.load(os.path.join(self.folder_name, "partition_data.pt"))
        else:
            print("partition_data not loaded, no file saved")

    def load_tensor(self, tensor_name):
        if not os.path.exists(self.folder_name):
            raise FileNotFoundError("No partition data saved")
        if os.path.exists(os.path.join(self.folder_name, tensor_name + ".pt")):
            return torch.load(os.path.join(self.folder_name, tensor_name + ".pt"))
        else:
            return None


def suffix_key_lookup(feature_dict: Dict[str, Tensor], key: str,
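For reference, a short round-trip of the manager's save/delete/load cycle outside any collective call; the tensors are hypothetical placeholders:

    import torch
    from multiprocessing import Lock
    import sar

    pdm = sar.PartitionDataManager(idx=0, lock=Lock())
    pdm.labels = torch.randint(0, 10, (100,))        # placeholder labels
    pdm.masks = {'train_indices': torch.arange(60)}  # placeholder masks

    pdm.save()    # writes masks.pt and labels.pt under partition_rank_0/
    pdm.delete()  # frees the in-memory copies
    pdm.load()    # restores them; prints a note for partition_data.pt,
                  # which was never saved
    assert pdm.labels.numel() == 100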
