From ebcb5993d546f60ea82574b231895d4ada51b963 Mon Sep 17 00:00:00 2001 From: seanmcpherson Date: Tue, 27 Jun 2023 16:16:28 -0700 Subject: [PATCH 1/9] preliminary working version of multiproc single-node --- examples/train_homogeneous_graph_basic.py | 232 +++++++++++++++++++--- sar/__init__.py | 2 +- sar/comm.py | 88 +++++--- sar/construct_shard_manager.py | 4 +- sar/core/__init__.py | 2 +- sar/core/graphshard.py | 35 +++- sar/core/sar_aggregation.py | 46 +++-- 7 files changed, 334 insertions(+), 75 deletions(-) diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py index 8f39f25..2dd7ccc 100644 --- a/examples/train_homogeneous_graph_basic.py +++ b/examples/train_homogeneous_graph_basic.py @@ -56,6 +56,9 @@ help="Run on CPUs if set, otherwise run on GPUs " ) +parser.add_argument('--log-level', default='INFO', type=str, + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='SAR log level ') parser.add_argument('--train-iters', default=100, type=int, help='number of training iterations ') @@ -99,98 +102,271 @@ def forward(self, graph: sar.GraphShardManager, features: torch.Tensor): return features +class PartitionDataManager: + def __init__(self, idx, folder_name_prefix="partition_rank_"): + self.idx = idx + self.folder_name = f"{folder_name_prefix}{idx}" + self._features = None + self._masks = None + self._labels = None + self._partition_data = None + + @property + def partition_data(self): + if self._partition_data is None: + raise ValueError("partition_data not set") + return self._partition_data + + @partition_data.setter + def partition_data(self, data): + self._partition_data = data + + @partition_data.deleter + def partition_data(self): + del self._partition_data + + @property + def features(self): + if self._features is None: + raise ValueError("features not set") + return self._features + + @features.setter + def features(self, feats): + self._features = feats + + @property + def labels(self): + if self._labels is None: + raise ValueError("labels not set") + return self._labels + + @labels.setter + def labels(self, labels): + self._labels = labels + + @property + def masks(self): + if self._masks is None: + raise ValueError("masks not set") + return self._masks + + @masks.setter + def masks(self, masks): + self._masks = masks + + def save(self): + if not os.path.exists(self.folder_name): + os.makedirs(self.folder_name) + if not self._features is None: + torch.save(self._features, os.path.join(self.folder_name, "features.pt")) + if not self._masks is None: + torch.save(self._masks, os.path.join(self.folder_name, "masks.pt")) + if not self._labels is None: + torch.save(self._labels, os.path.join(self.folder_name, "labels.pt")) + if not self._partition_data is None: + torch.save(self._partition_data, os.path.join(self.folder_name, "partition_data.pt")) + + def delete(self): + del self._features + del self._masks + del self._labels + del self._partition_data + + def load(self): + if not os.path.exists(self.folder_name): + raise FileNotFoundError("No partition data saved") + return + if os.path.exists(os.path.join(self.folder_name, "features.pt")): + self._features = torch.load(os.path.join(self.folder_name, "features.pt")) + else: + print("features not loaded, no file saved") + if os.path.exists(os.path.join(self.folder_name, "masks.pt")): + self._masks = torch.load(os.path.join(self.folder_name, "masks.pt")) + else: + print("masks not loaded, no file saved") + if os.path.exists(os.path.join(self.folder_name, "labels.pt")): + 
self._labels = torch.load(os.path.join(self.folder_name, "labels.pt")) + else: + print("labels not loaded, no file saved") + if os.path.exists(os.path.join(self.folder_name, "partition_data.pt")): + self._partition_data = torch.load(os.path.join(self.folder_name, "partition_data.pt")) + else: + print("partition_data not loaded, no file saved") + +class PartitionThreadManager: + + def __init__(self, num_partitions): + self.num_partitions = num_partitions + + def pause_thread(self): + ''' + for pointer in self.pointer_list: + torch.save(pointer, p_name) + pointer.resize_(0) + + graph_shard_manager.save() + barrier.wait() + self.resume_thread() + ''' + pass + + def resume_thread(self): + ''' + for pointer in self.pointer_list: + t = torch.load(p_name) + pointer.resize_(t.size()) + + graph_shard_manager.load() + ''' + pass + def main(): args = parser.parse_args() + + from multiprocessing import Process, Lock, Barrier + + lock = Lock() + barrier = Barrier(args.world_size) + + for rank_idx in range(args.world_size): + p = Process(target=run, args=(args,rank_idx, lock, barrier)) + p.start() + +def run(args, rank, lock, barrier): print('args', args) + print('rank', rank) + + #lock.acquire() + #print("Node {} Lock Acquired".format(rank)) use_gpu = torch.cuda.is_available() and not args.cpu_run device = torch.device('cuda' if use_gpu else 'cpu') + sar.logging_setup(logging.getLevelName(args.log_level), + rank, args.world_size) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + logger.setLevel(args.log_level) + # Obtain the ip address of the master through the network file system - master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file) - sar.initialize_comms(args.rank, + master_ip_address = sar.nfs_ip_init(rank, args.ip_file) + #lock.release() + #print("Node {} Lock Released".format(rank)) + sar.initialize_comms(rank, args.world_size, master_ip_address, args.backend) + #barrier.wait() + #print("Node {} Barrier Passed".format(rank)) + lock.acquire() + print("Node {} Lock Acquired".format(rank)) + + sar.start_comm_thread() # Load DGL partition data - partition_data = sar.load_dgl_partition_data( - args.partitioning_json_file, args.rank, device) + partition_data_manager = PartitionDataManager(rank) + partition_data_manager.partition_data = sar.load_dgl_partition_data( + args.partitioning_json_file, rank, device) + # Obtain train,validation, and test masks # These are stored as node features. Partitioning may prepend # the node type to the mask names. 
So we use the convenience function # suffix_key_lookup to look up the mask name while ignoring the # arbitrary node type - masks = {} + partition_data_manager.masks = {} for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'], ['train_indices', 'val_indices', 'test_indices']): - boolean_mask = sar.suffix_key_lookup(partition_data.node_features, + boolean_mask = sar.suffix_key_lookup(partition_data_manager.partition_data.node_features, mask_name) - masks[indices_name] = boolean_mask.nonzero( + partition_data_manager.masks[indices_name] = boolean_mask.nonzero( as_tuple=False).view(-1).to(device) - labels = sar.suffix_key_lookup(partition_data.node_features, + partition_data_manager.labels = sar.suffix_key_lookup(partition_data_manager.partition_data.node_features, 'labels').long().to(device) # Obtain the number of classes by finding the max label across all workers - num_labels = labels.max() + 1 - sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True) - num_labels = num_labels.item() + num_labels = partition_data_manager.labels.max() + 1 - features = sar.suffix_key_lookup(partition_data.node_features, 'features').to(device) - full_graph_manager = sar.construct_full_graph(partition_data).to(device) + def precall_func(): + partition_data_manager.save() + partition_data_manager.delete() + lock.release() + print("Node {} Lock Released".format(rank)) + + def callback_func(handle): + handle.wait() + lock.acquire() + print("Node {} Lock Acquired".format(rank)) + partition_data_manager.load() + + sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True, + precall_func=precall_func, callback_func=callback_func) + print("Node {} Num Labels {}".format(rank, num_labels)) + + num_labels = num_labels.item() + + partition_data_manager.features = sar.suffix_key_lookup(partition_data_manager.partition_data.node_features, 'features').to(device) + full_graph_manager = sar.construct_full_graph(partition_data_manager.partition_data, lock=lock).to(device) + #We do not need the partition data anymore - del partition_data + del partition_data_manager.partition_data + partition_data_manager.partition_data = None - gnn_model = GNNModel(features.size(1), + gnn_model = GNNModel(partition_data_manager.features.size(1), args.hidden_layer_dim, num_labels).to(device) print('model', gnn_model) # Synchronize the model parmeters across all workers - sar.sync_params(gnn_model) + sar.sync_params(gnn_model, precall_func=precall_func, callback_func=callback_func) # Obtain the number of labeled nodes in the training # This will be needed to properly obtain a cross entropy loss # normalized by the number of training examples - n_train_points = torch.LongTensor([masks['train_indices'].numel()]) - sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True) + n_train_points = torch.LongTensor([partition_data_manager.masks['train_indices'].numel()]) + sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True, + precall_func=precall_func, callback_func=callback_func) n_train_points = n_train_points.item() optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr) for train_iter_idx in range(args.train_iters): + logger.debug(f'{rank} : starting training iteration {train_iter_idx}') # Train t_1 = time.time() - logits = gnn_model(full_graph_manager, features) - loss = F.cross_entropy(logits[masks['train_indices']], - labels[masks['train_indices']], reduction='sum')/n_train_points + logits = gnn_model(full_graph_manager, 
partition_data_manager.features) + logger.debug(f'{rank} : training iteration complete {train_iter_idx}') + loss = F.cross_entropy(logits[partition_data_manager.masks['train_indices']], + partition_data_manager.labels[partition_data_manager.masks['train_indices']], reduction='sum')/n_train_points optimizer.zero_grad() loss.backward() # Do not forget to gather the parameter gradients from all workers - sar.gather_grads(gnn_model) + sar.gather_grads(gnn_model, + precall_func=full_graph_manager.pause_process, callback_func=full_graph_manager.resume_process) optimizer.step() train_time = time.time() - t_1 # Calculate accuracy for train/validation/test results = [] for indices_name in ['train_indices', 'val_indices', 'test_indices']: - n_correct = (logits[masks[indices_name]].argmax(1) == - labels[masks[indices_name]]).float().sum() - results.extend([n_correct, masks[indices_name].numel()]) + n_correct = (logits[partition_data_manager.masks[indices_name]].argmax(1) == + partition_data_manager.labels[partition_data_manager.masks[indices_name]]).float().sum() + results.extend([n_correct, partition_data_manager.masks[indices_name].numel()]) acc_vec = torch.FloatTensor(results) # Sum the n_correct, and number of mask elements across all workers - sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True) + sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True, + precall_func=full_graph_manager.pause_process, callback_func=full_graph_manager.resume_process) (train_acc, val_acc, test_acc) = \ (acc_vec[0] / acc_vec[1], acc_vec[2] / acc_vec[3], acc_vec[4] / acc_vec[5]) result_message = ( - f"iteration [{train_iter_idx}/{args.train_iters}] | " + f"iteration [{train_iter_idx + 1}/{args.train_iters}] | " ) result_message += ', '.join([ f"train loss={loss:.4f}, " @@ -202,7 +378,7 @@ def main(): f" |" ]) print(result_message, flush=True) - + lock.release() if __name__ == '__main__': main() diff --git a/sar/__init__.py b/sar/__init__.py index e519173..c435f4e 100644 --- a/sar/__init__.py +++ b/sar/__init__.py @@ -26,7 +26,7 @@ from .comm import initialize_comms, rank, world_size, comm_device,\ nfs_ip_init, sync_params, gather_grads from .core import GraphShardManager, message_has_parameters, DistributedBlock,\ - DistNeighborSampler, DataLoader + DistNeighborSampler, DataLoader, start_comm_thread from .construct_shard_manager import construct_mfgs, construct_full_graph from .data_loading import load_dgl_partition_data, suffix_key_lookup from .distributed_bn import DistributedBN1D diff --git a/sar/comm.py b/sar/comm.py index 7421353..e6ad39f 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -168,7 +168,8 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, if backend == 'ccl': # pylint: disable=unused-import - import torch_ccl # type: ignore + # import torch_ccl # type: ignore + import oneccl_bindings_for_pytorch os.environ['MASTER_ADDR'] = master_ip_address os.environ['MASTER_PORT'] = str(master_port_number) @@ -194,8 +195,10 @@ def initialize_comms(_rank: int, _world_size: int, master_ip_address: str, os.environ['I_MPI_COMM_WORLD'] = str(_world_size) os.environ['I_MPI_COMM_RANK'] = str(_rank) + print("init_process_group: ", backend, _rank, _world_size) dist.init_process_group( - backend=backend, rank=_rank, world_size=_world_size) + backend=backend, rank=_rank, world_size=_world_size, + init_method='file:///home/nervana/graph_neural_networks/shared_data/data_share') _CommData.rank = _rank _CommData.world_size = _world_size @@ -266,7 +269,7 @@ def 
comm_device() -> torch.device: def all_to_all(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor], - move_to_comm_device: bool = False) -> None: + move_to_comm_device: bool = False, precall_func = None, callback_func = None) -> None: ''' wrapper around dist.all_to_all ''' @@ -274,21 +277,24 @@ def all_to_all(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor == 0 else x for x in recv_tensors] send_tensors = [x.new(1, *x.size()[1:]) if x.numel() == 0 else x for x in send_tensors] - + if move_to_comm_device: recv_tensors_cd = [recv_tensor.to(comm_device()) for recv_tensor in recv_tensors] send_tensors_cd = [send_tensor.to(comm_device()) for send_tensor in send_tensors] - all_to_all_rounds(recv_tensors_cd, send_tensors_cd) + all_to_all_rounds(recv_tensors_cd, send_tensors_cd, + precall_func = precall_func, callback_func = callback_func) for recv_tensor, recv_tensor_cd in zip(recv_tensors, recv_tensors_cd): recv_tensor.copy_(recv_tensor_cd) else: - all_to_all_rounds(recv_tensors, send_tensors) + all_to_all_rounds(recv_tensors, send_tensors, + precall_func = precall_func, callback_func = callback_func) def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, - move_to_comm_device: bool = False): # pylint: disable=invalid-name + move_to_comm_device: bool = False, precall_func = None, + callback_func = None): # pylint: disable=invalid-name """ wrapper around dist.all_reduce :param red_tensor: reduction tensor @@ -300,26 +306,44 @@ def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, """ - if move_to_comm_device: red_tensor_cd = red_tensor.to(comm_device()) - dist.all_reduce(red_tensor_cd, op) + if precall_func: + precall_func() + handle = dist.all_reduce(red_tensor_cd, op, async_op=True) + if callback_func: + callback_func(handle) + else: + handle.wait() red_tensor.copy_(red_tensor_cd) else: - dist.all_reduce(red_tensor, op) - - -def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor]): + if precall_func: + precall_func() + handle = dist.all_reduce(red_tensor, op, async_op=True) + if callback_func: + callback_func(handle) + else: + handle.wait() + +def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor], + precall_func = None, callback_func = None): if Config.max_collective_size == 0: #print('all to all', recv_tensors, send_tensors, flush=True) - dist.all_to_all(recv_tensors, send_tensors) + if precall_func: + precall_func() + handle = dist.all_to_all(recv_tensors, send_tensors, async_op=True) #print('all to all complete', recv_tensors, send_tensors, flush=True) + if callback_func: + callback_func(handle) + else: + handle.wait() else: max_n_elems = Config.max_collective_size total_elems = sum(r_tensor.numel() for r_tensor in recv_tensors) + \ sum(s_tensor.numel() for s_tensor in send_tensors) n_rounds_t = torch.tensor(max(1, total_elems // max_n_elems)) - all_reduce(n_rounds_t, dist.ReduceOp.MAX, move_to_comm_device=True) + all_reduce(n_rounds_t, dist.ReduceOp.MAX, move_to_comm_device=True, + precall_func=precall_func, callback_func=callback_func) n_rounds = int(n_rounds_t.item()) logger.debug(f'all to all using {n_rounds}') for round_idx in range(n_rounds): @@ -327,7 +351,11 @@ def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch s_tensor in send_tensors] recv_tensors_slices = [_get_tensor_slice(r_tensor, n_rounds, round_idx) for r_tensor in recv_tensors] - dist.all_to_all(recv_tensors_slices, send_tensors_slices) + handle = dist.all_to_all(recv_tensors_slices, 
send_tensors_slices, async_op=True) + if callback_func: + callback_func(handle) + else: + handle.wait() def _get_tensor_slice(tens: Tensor, n_splits: int, split_idx: int) -> Tensor: @@ -343,7 +371,8 @@ def _get_tensor_slice(tens: Tensor, n_splits: int, split_idx: int) -> Tensor: def exchange_single_tensor(recv_idx: int, send_idx: int, - recv_tensor: Tensor, send_tensor: Tensor) -> None: + recv_tensor: Tensor, send_tensor: Tensor, + precall_func = None, callback_func = None) -> None: """ Sends send_tensor to worker send_idx and fills recv_tensor with data received from worker recv_idx. @@ -380,7 +409,8 @@ def exchange_single_tensor(recv_idx: int, send_idx: int, recv_tensors_list[recv_idx] = active_recv_tensor send_tensors_list[send_idx] = active_send_tensor - all_to_all(recv_tensors_list, send_tensors_list) + all_to_all(recv_tensors_list, send_tensors_list, + precall_func = precall_func, callback_func = callback_func) if active_recv_tensor is not recv_tensor and recv_tensor.size(0) > 0: recv_tensor.copy_(active_recv_tensor) @@ -389,7 +419,8 @@ def exchange_single_tensor(recv_idx: int, send_idx: int, f'{rank()} : done exchange_single_tensor : {recv_idx}, {send_idx},{recv_tensor.size()},{send_tensor.size()}') -def exchange_tensors(tensors: List[torch.Tensor], recv_sizes: Optional[List[int]] = None) -> List[torch.Tensor]: +def exchange_tensors(tensors: List[torch.Tensor], recv_sizes: Optional[List[int]] = None, + precall_func = None, callback_func = None) -> List[torch.Tensor]: """ tensors is a list of size WORLD_SIZE. tensors[i] is sent to worker i. Returns a list of tensors recv_tensors, where recv_tensors[i] is the tensor received from worker i. Optionally, you can provide recv_sizes to specify the @@ -407,7 +438,6 @@ def exchange_tensors(tensors: List[torch.Tensor], recv_sizes: Optional[List[int] received from worker i. """ - trailing_dimensions = tensors[0].size()[1:] dtype = tensors[0].dtype assert all(x.size()[ @@ -423,7 +453,8 @@ def exchange_tensors(tensors: List[torch.Tensor], recv_sizes: Optional[List[int] all_their_sizes = [torch.Tensor([-1]).long().to( comm_device()) for _ in range(len(tensors))] - all_to_all(all_their_sizes, all_my_sizes) + all_to_all(all_their_sizes, all_my_sizes, + precall_func = precall_func, callback_func = callback_func) #print('all my sizes', all_my_sizes) #print('all their sizes', all_their_sizes) @@ -436,12 +467,13 @@ def exchange_tensors(tensors: List[torch.Tensor], recv_sizes: Optional[List[int] recv_tensors = [torch.empty(x, *trailing_dimensions, dtype=dtype).to(comm_device()).fill_(-1) for x in all_their_sizes_aug] - all_to_all(recv_tensors, tensors_comm_device) + all_to_all(recv_tensors, tensors_comm_device, + precall_func = precall_func, callback_func = callback_func) return [x[:s].to(tensors[0].device) for s, x in zip(all_their_sizes_i, recv_tensors)] -def sync_params(model: torch.nn.Module): +def sync_params(model: torch.nn.Module, precall_func = None, callback_func = None): """Synchronize the model parameters across all workers. 
The model parameters of worker 0 (the master worker) are copied to all workers @@ -454,10 +486,11 @@ def sync_params(model: torch.nn.Module): for _, s_v in state_dict.items(): if rank() != 0: s_v.data.zero_() - all_reduce(s_v.data, op=dist.ReduceOp.SUM, move_to_comm_device=True) + all_reduce(s_v.data, op=dist.ReduceOp.SUM, move_to_comm_device=True, + precall_func=precall_func, callback_func=callback_func) -def gather_grads(model: torch.nn.Module): +def gather_grads(model: torch.nn.Module, precall_func = None, callback_func = None): """Sum the parameter gradients from all workers. This should be called before optimizer.step @@ -466,11 +499,10 @@ def gather_grads(model: torch.nn.Module): :type model: torch.nn.Module """ - for param in model.parameters(): if param.grad is not None: all_reduce(param.grad, op=dist.ReduceOp.SUM, - move_to_comm_device=True) + move_to_comm_device=True, precall_func = precall_func, callback_func = callback_func) class CommThread: @@ -511,4 +543,4 @@ def _fetch_tasks(self) -> None: self.result_queue.put(result) -comm_thread = CommThread() +#comm_thread = CommThread() diff --git a/sar/construct_shard_manager.py b/sar/construct_shard_manager.py index 03e6acd..d00c445 100644 --- a/sar/construct_shard_manager.py +++ b/sar/construct_shard_manager.py @@ -183,7 +183,7 @@ def construct_mfgs(partition_data: PartitionData, return graph_shard_manager_list[::-1] -def construct_full_graph(partition_data: PartitionData) -> GraphShardManager: +def construct_full_graph(partition_data: PartitionData ,lock = None) -> GraphShardManager: """ Constructs a GraphShardManager object from the partition data. The GraphShardManager object can serve as a drop-in replacemet to DGL's native graph in most GNN layers @@ -202,4 +202,4 @@ def construct_full_graph(partition_data: PartitionData) -> GraphShardManager: seed_nodes = torch.arange(partition_data.node_ranges[rank()][1] - partition_data.node_ranges[rank()][0]) return GraphShardManager(graph_shard_list, - seed_nodes, seed_nodes) + seed_nodes, seed_nodes, lock=lock) diff --git a/sar/core/__init__.py b/sar/core/__init__.py index 1c8cfc7..5f218fe 100644 --- a/sar/core/__init__.py +++ b/sar/core/__init__.py @@ -22,7 +22,7 @@ Modules for sharded data representation and management ''' from .graphshard import GraphShard, GraphShardManager -from .sar_aggregation import message_has_parameters +from .sar_aggregation import message_has_parameters, start_comm_thread from .full_partition_block import DistributedBlock from .sampling import DistNeighborSampler, DataLoader diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index f819cff..dbe965f 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -191,9 +191,16 @@ class GraphShardManager: """ - def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, local_tgt_seeds: Tensor) -> None: + def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, local_tgt_seeds: Tensor, lock = None) -> None: super().__init__() self.graph_shards = graph_shards + self.set_lock(lock) + if lock is None: + self.resume_process = None + self.pause_process = None + else: + self.resume_process = self._resume_process + self.pause_process = self._pause_process assert all(self.tgt_node_range == x.tgt_range for x in self.graph_shards[1:]) @@ -283,11 +290,31 @@ def sampling_graph(self): self._sampling_graph = sampling_graph return sampling_graph + def get_lock(self): + return self._lock + def set_lock(self, lock): + self._lock = lock + def acquire_lock(self): + 
self._lock.acquire() + logger.debug("Node {} Lock Acquired".format(rank())) + def release_lock(self): + self._lock.release() + logger.debug("Node {} Lock Released".format(rank())) + + def _pause_process(self): + logger.debug("Node {} pause_process called".format(rank())) + self.release_lock() + def _resume_process(self, handle = None): + logger.debug("Node {} resume_process called".format(rank())) + if handle: + handle.wait() + self.acquire_lock() + def update_boundary_nodes_indices(self) -> List[Tensor]: all_my_sources_indices = [ x.unique_src_nodes for x in self.graph_shards] - - indices_required_from_me = exchange_tensors(all_my_sources_indices) + indices_required_from_me = exchange_tensors(all_my_sources_indices, + precall_func = self.pause_process, callback_func = self.resume_process) for ind in indices_required_from_me: ind.sub_(self.tgt_node_range[0]) return indices_required_from_me @@ -496,6 +523,8 @@ def update_all(self, assert isinstance(reduce_func, dgl.function.reducer.SimpleReduceFunction), \ 'only simple reduce functions: sum, min, max, and mean are supported' + logger.debug('in update_all') + if reduce_func.name == 'mean': reduce_func = fn.sum(reduce_func.msg_field, # pylint: disable=no-member reduce_func.out_field) diff --git a/sar/core/sar_aggregation.py b/sar/core/sar_aggregation.py index 1ed44f8..1caaf70 100644 --- a/sar/core/sar_aggregation.py +++ b/sar/core/sar_aggregation.py @@ -33,7 +33,7 @@ from torch import Tensor from ..config import Config -from ..comm import exchange_single_tensor, rank, comm_thread, world_size +from ..comm import exchange_single_tensor, rank, CommThread, world_size from ..common_tuples import AggregationData, TensorPlace, ShardInfo if TYPE_CHECKING: @@ -43,6 +43,10 @@ logger.addHandler(logging.NullHandler()) logger.setLevel(logging.DEBUG) +def start_comm_thread(): + global comm_thread + comm_thread = CommThread() + def message_has_parameters(param_foo: Callable[[Any], Tuple[Tensor, ...]]): """A decorator for message functions that use learnable parameters. 
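
The comm changes in this patch all follow one pattern: each collective grows optional precall_func/callback_func hooks, the call is issued with async_op=True, and the callback receives the work handle so it can wait on it and then reclaim whatever the precall released. The following is a minimal, self-contained sketch of that wrapper; the hook names mirror sar/comm.py, while the lock pairing shown in the trailing comments is an illustrative assumption, not code from the patch.

import torch.distributed as dist

def all_reduce_with_hooks(tensor, op=dist.ReduceOp.SUM,
                          precall_func=None, callback_func=None):
    # Let the caller release shared resources (a multiprocessing.Lock,
    # offloaded tensors, ...) before this rank blocks on the collective.
    if precall_func:
        precall_func()
    handle = dist.all_reduce(tensor, op, async_op=True)
    if callback_func:
        # The callback owns the handle: it must call handle.wait() before
        # re-acquiring whatever the precall gave up.
        callback_func(handle)
    else:
        handle.wait()

# Illustrative pairing with the lock protocol used in the example script:
#   lock = multiprocessing.Lock()
#   precall  = lambda: lock.release()
#   callback = lambda handle: (handle.wait(), lock.acquire())
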
@@ -126,11 +130,13 @@ def remote_grad_sent(self): if self.n_grads_remaining == 0: self.backward_event.set() - def wait_for_all_grads(self): + def wait_for_all_grads(self, return_handle = False): if self.n_grads_remaining == 0: return - - self.backward_event.wait() + if return_handle: + return self.backward_event + else: + self.backward_event.wait() def update_grad(self, tensor_name: str, tensor: Tensor, indices: Tensor): self.grad_dict[tensor_name][indices] += tensor @@ -158,7 +164,9 @@ def exchange_grads(send_grad: Tensor, recv_grad = send_grad.new(indices_required_from_me.size(0), *send_grad.size()[1:]).zero_() - exchange_single_tensor(recv_idx, send_idx, recv_grad, send_grad) + exchange_single_tensor(recv_idx, send_idx, recv_grad, send_grad, + precall_func = graph_shard_manager.pause_process, + callback_func = graph_shard_manager.resume_process) backward_manager.update_grad( tensor_name, recv_grad, indices_required_from_me) backward_manager.remote_grad_sent() @@ -168,10 +176,15 @@ def grad_hook(grad: Tensor, graph_shard_manager: "GraphShardManager", backward_manager: BackwardManager, tensor_name: str, remote_idx: int) -> None: - comm_thread.submit_task(task_id=f'grad_{remote_idx}', task=functools.partial( - exchange_grads, send_grad=grad, graph_shard_manager=graph_shard_manager, - backward_manager=backward_manager, tensor_name=tensor_name, - send_idx=remote_idx)) + if graph_shard_manager.get_lock(): + exchange_grads(send_grad=grad, graph_shard_manager=graph_shard_manager, + backward_manager=backward_manager, tensor_name=tensor_name, + send_idx=remote_idx) + else: + comm_thread.submit_task(task_id=f'grad_{remote_idx}', task=functools.partial( + exchange_grads, send_grad=grad, graph_shard_manager=graph_shard_manager, + backward_manager=backward_manager, tensor_name=tensor_name, + send_idx=remote_idx)) def exchange_features(graph_shard_manager: "GraphShardManager", @@ -195,7 +208,9 @@ def exchange_features(graph_shard_manager: "GraphShardManager", recv_features = send_features.new( n_recv_nodes, *send_features.size()[1:]) exchange_single_tensor(recv_idx, send_idx, recv_features, - send_features) + send_features, + precall_func = graph_shard_manager.pause_process, + callback_func = graph_shard_manager.resume_process) logger.debug('recv features %s', recv_features.size()) if grad_enabled and detached_input_tensors[tensor_idx].requires_grad: @@ -423,7 +438,6 @@ def do_aggregation(aggregation_data: AggregationData, comm_round=pipeline_stage, grad_enabled=torch.is_grad_enabled()) ) pipeline_stage += 1 - with profiler.record_function("COMM_FETCH"): recv_idx, recv_dict = comm_thread.get_result() else: @@ -465,6 +479,7 @@ class SAROp(torch.autograd.Function): # pylint: disable = abstract-method def forward(ctx, aggregation_data: AggregationData, # type: ignore *all_input_tensors: Tensor) -> Tensor: # type: ignore + logger.debug('in sar_op.forward') logger.debug('aggregation_data %s', aggregation_data) # Do not pass the parameter tensors to aggregation routines. 
They @@ -517,6 +532,7 @@ def forward(ctx, aggregation_data: AggregationData, # type: ignore # pylint: disable = arguments-differ # type: ignore def backward(ctx, output_grad) -> Tuple[Optional[Tensor], ...]: + logger.debug('in sar_op.backwards') logger.debug('backward aggregation data %s', ctx.aggregation_data) aggregation_data = ctx.aggregation_data backward_manager = ctx.backward_manager @@ -533,7 +549,13 @@ def backward(ctx, output_grad) -> Tuple[Optional[Tensor], ...]: logger.debug('backward with successive rematerialization') t1 = time.time() - backward_manager.wait_for_all_grads() + if aggregation_data.graph_shard_manager.get_lock(): + aggregation_data.graph_shard_manager.pause_process() + handle = backward_manager.wait_for_all_grads(return_handle=True) + aggregation_data.graph_shard_manager.resume_process(handle) + else: + backward_manager.wait_for_all_grads() + logger.debug('backward event wait done in %s', time.time() - t1) input_grads = [] for tensor_idx, (tensor_place, tensor_name) in enumerate(aggregation_data.all_input_names): From daf369ec2a98bb38f6e2e1f18ea76081a762fdff Mon Sep 17 00:00:00 2001 From: seanmcpherson Date: Wed, 19 Jul 2023 17:04:26 -0700 Subject: [PATCH 2/9] merging running partitions in processes branch with saving intermediate tensors branch, still issue with saving and loading tensors that require grads --- examples/train_homogeneous_graph_basic.py | 112 +++++------ sar/__init__.py | 4 +- sar/comm.py | 3 +- sar/construct_shard_manager.py | 5 +- sar/core/graphshard.py | 221 +++++++++++++++++++++- sar/tensor_utils.py | 34 ++++ 6 files changed, 317 insertions(+), 62 deletions(-) create mode 100644 sar/tensor_utils.py diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py index 2dd7ccc..df805a6 100644 --- a/examples/train_homogeneous_graph_basic.py +++ b/examples/train_homogeneous_graph_basic.py @@ -31,6 +31,7 @@ import sar +from memory_profiler import profile parser = ArgumentParser( description="GNN training on node classification tasks in homogeneous graphs") @@ -127,14 +128,19 @@ def partition_data(self): @property def features(self): - if self._features is None: - raise ValueError("features not set") + #if self._features is None: + # raise ValueError("features not set") return self._features @features.setter def features(self, feats): self._features = feats + @features.deleter + def features(self): + del self._features + self._features = None + @property def labels(self): if self._labels is None: @@ -158,17 +164,23 @@ def masks(self, masks): def save(self): if not os.path.exists(self.folder_name): os.makedirs(self.folder_name) - if not self._features is None: - torch.save(self._features, os.path.join(self.folder_name, "features.pt")) - if not self._masks is None: + #if self._features is not None: + # torch.save(self._features, os.path.join(self.folder_name, "features.pt")) + if self._masks is not None: torch.save(self._masks, os.path.join(self.folder_name, "masks.pt")) - if not self._labels is None: + if self._labels is not None: torch.save(self._labels, os.path.join(self.folder_name, "labels.pt")) - if not self._partition_data is None: + if self._partition_data is not None: torch.save(self._partition_data, os.path.join(self.folder_name, "partition_data.pt")) + def save_tensor(self, tensor, tensor_name): + if not os.path.exists(self.folder_name): + os.makedirs(self.folder_name) + if tensor is not None: + torch.save(tensor, os.path.join(self.folder_name, tensor_name + ".pt")) + def delete(self): - del 
self._features + #del self._features del self._masks del self._labels del self._partition_data @@ -176,11 +188,10 @@ def delete(self): def load(self): if not os.path.exists(self.folder_name): raise FileNotFoundError("No partition data saved") - return - if os.path.exists(os.path.join(self.folder_name, "features.pt")): - self._features = torch.load(os.path.join(self.folder_name, "features.pt")) - else: - print("features not loaded, no file saved") + #if os.path.exists(os.path.join(self.folder_name, "features.pt")): + # self._features = torch.load(os.path.join(self.folder_name, "features.pt")) + #else: + # print("features not loaded, no file saved") if os.path.exists(os.path.join(self.folder_name, "masks.pt")): self._masks = torch.load(os.path.join(self.folder_name, "masks.pt")) else: @@ -194,32 +205,13 @@ def load(self): else: print("partition_data not loaded, no file saved") -class PartitionThreadManager: - - def __init__(self, num_partitions): - self.num_partitions = num_partitions - - def pause_thread(self): - ''' - for pointer in self.pointer_list: - torch.save(pointer, p_name) - pointer.resize_(0) - - graph_shard_manager.save() - barrier.wait() - self.resume_thread() - ''' - pass - - def resume_thread(self): - ''' - for pointer in self.pointer_list: - t = torch.load(p_name) - pointer.resize_(t.size()) - - graph_shard_manager.load() - ''' - pass + def load_tensor(self, tensor_name): + if not os.path.exists(self.folder_name): + raise FileNotFoundError("No partition data saved") + if os.path.exists(os.path.join(self.folder_name, tensor_name + ".pt")): + return torch.load(os.path.join(self.folder_name, tensor_name + ".pt")) + else: + return None def main(): @@ -238,9 +230,6 @@ def run(args, rank, lock, barrier): print('args', args) print('rank', rank) - #lock.acquire() - #print("Node {} Lock Acquired".format(rank)) - use_gpu = torch.cuda.is_available() and not args.cpu_run device = torch.device('cuda' if use_gpu else 'cpu') @@ -252,14 +241,10 @@ def run(args, rank, lock, barrier): # Obtain the ip address of the master through the network file system master_ip_address = sar.nfs_ip_init(rank, args.ip_file) - #lock.release() - #print("Node {} Lock Released".format(rank)) sar.initialize_comms(rank, args.world_size, master_ip_address, args.backend) - #barrier.wait() - #print("Node {} Barrier Passed".format(rank)) lock.acquire() print("Node {} Lock Acquired".format(rank)) @@ -291,6 +276,9 @@ def run(args, rank, lock, barrier): def precall_func(): partition_data_manager.save() + if isinstance(partition_data_manager.features, torch.Tensor): + partition_data_manager.save_tensor(partition_data_manager.features, 'features') + del partition_data_manager.features partition_data_manager.delete() lock.release() print("Node {} Lock Released".format(rank)) @@ -300,6 +288,9 @@ def callback_func(handle): lock.acquire() print("Node {} Lock Acquired".format(rank)) partition_data_manager.load() + features = partition_data_manager.load_tensor('features') + if features is not None: + partition_data_manager.features = features sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True, @@ -309,11 +300,7 @@ def callback_func(handle): num_labels = num_labels.item() partition_data_manager.features = sar.suffix_key_lookup(partition_data_manager.partition_data.node_features, 'features').to(device) - full_graph_manager = sar.construct_full_graph(partition_data_manager.partition_data, lock=lock).to(device) - #We do not need the partition data anymore - del partition_data_manager.partition_data - 
partition_data_manager.partition_data = None gnn_model = GNNModel(partition_data_manager.features.size(1), args.hidden_layer_dim, @@ -331,9 +318,19 @@ def callback_func(handle): precall_func=precall_func, callback_func=callback_func) n_train_points = n_train_points.item() + full_graph_manager = sar.construct_full_graph(partition_data_manager.partition_data, + partition_data_manager=partition_data_manager, lock=lock).to(device) + + #We do not need the partition data anymore + del partition_data_manager.partition_data + partition_data_manager.partition_data = None + + #full_graph_manager.partition_data_manager = partition_data_manager + optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr) for train_iter_idx in range(args.train_iters): logger.debug(f'{rank} : starting training iteration {train_iter_idx}') + partition_data_manager.features = sar.PointerTensor(partition_data_manager.features, pointer=full_graph_manager.pointer_list) # Train t_1 = time.time() logits = gnn_model(full_graph_manager, partition_data_manager.features) @@ -347,6 +344,8 @@ def callback_func(handle): sar.gather_grads(gnn_model, precall_func=full_graph_manager.pause_process, callback_func=full_graph_manager.resume_process) optimizer.step() + + train_time = time.time() - t_1 # Calculate accuracy for train/validation/test @@ -369,15 +368,20 @@ def callback_func(handle): f"iteration [{train_iter_idx + 1}/{args.train_iters}] | " ) result_message += ', '.join([ - f"train loss={loss:.4f}, " + f"train loss={loss}, " f"Accuracy: " - f"train={train_acc:.4f} " - f"valid={val_acc:.4f} " - f"test={test_acc:.4f} " + f"train={train_acc} " + f"valid={val_acc} " + f"test={test_acc} " f" | train time = {train_time} " f" |" ]) print(result_message, flush=True) + print("Pointer List Length: {}".format(len(full_graph_manager.pointer_list))) + full_graph_manager.pointer_list = [] + + full_graph_manager.print_metrics() + lock.release() if __name__ == '__main__': diff --git a/sar/__init__.py b/sar/__init__.py index c435f4e..dcf80c9 100644 --- a/sar/__init__.py +++ b/sar/__init__.py @@ -34,6 +34,7 @@ from .edge_softmax import edge_softmax from .patch_dgl import patch_dgl from .logging_setup import logging_setup, logger +from .tensor_utils import PointerTensor __all__ = ['initialize_comms', 'rank', 'world_size', 'nfs_ip_init', @@ -41,4 +42,5 @@ 'construct_mfgs', 'construct_full_graph', 'GraphShardManager', 'load_dgl_partition_data', 'suffix_key_lookup', 'Config', 'edge_softmax', 'message_has_parameters', 'DistributedBlock', 'DistNeighborSampler', 'DataLoader', - 'logging_setup', 'logger', 'sync_params', 'gather_grads', 'patch_dgl'] + 'logging_setup', 'logger', 'sync_params', 'gather_grads', 'patch_dgl', + 'PointerTensor'] diff --git a/sar/comm.py b/sar/comm.py index e6ad39f..b252dbc 100644 --- a/sar/comm.py +++ b/sar/comm.py @@ -291,7 +291,6 @@ def all_to_all(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor all_to_all_rounds(recv_tensors, send_tensors, precall_func = precall_func, callback_func = callback_func) - def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, move_to_comm_device: bool = False, precall_func = None, callback_func = None): # pylint: disable=invalid-name @@ -324,7 +323,7 @@ def all_reduce(red_tensor: torch.Tensor, op: dist.ReduceOp, callback_func(handle) else: handle.wait() - + def all_to_all_rounds(recv_tensors: List[torch.Tensor], send_tensors: List[torch.Tensor], precall_func = None, callback_func = None): if Config.max_collective_size == 0: diff --git a/sar/construct_shard_manager.py 
b/sar/construct_shard_manager.py index d00c445..cfd3bc4 100644 --- a/sar/construct_shard_manager.py +++ b/sar/construct_shard_manager.py @@ -183,7 +183,7 @@ def construct_mfgs(partition_data: PartitionData, return graph_shard_manager_list[::-1] -def construct_full_graph(partition_data: PartitionData ,lock = None) -> GraphShardManager: +def construct_full_graph(partition_data: PartitionData, partition_data_manager = None, lock = None) -> GraphShardManager: """ Constructs a GraphShardManager object from the partition data. The GraphShardManager object can serve as a drop-in replacemet to DGL's native graph in most GNN layers @@ -202,4 +202,5 @@ def construct_full_graph(partition_data: PartitionData ,lock = None) -> GraphSha seed_nodes = torch.arange(partition_data.node_ranges[rank()][1] - partition_data.node_ranges[rank()][0]) return GraphShardManager(graph_shard_list, - seed_nodes, seed_nodes, lock=lock) + seed_nodes, seed_nodes, + partition_data_manager=partition_data_manager,lock=lock) diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index dbe965f..83a7309 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -26,7 +26,7 @@ from typing import Tuple, Dict, List, Optional import inspect -import os +import os, gc, sys import itertools import logging from collections.abc import MutableMapping @@ -39,6 +39,8 @@ from torch import Tensor import torch.distributed as dist +#from memory_profiler import profile +import psutil from ..common_tuples import ShardEdgesAndFeatures, AggregationData, TensorPlace, ShardInfo from ..comm import exchange_tensors, rank, all_reduce @@ -49,6 +51,56 @@ logger.addHandler(logging.NullHandler()) logger.setLevel(logging.DEBUG) +import inspect +def get_size(obj, seen=None): + """Recursively finds size of objects in bytes""" + size = sys.getsizeof(obj) + if seen is None: + seen = set() + obj_id = id(obj) + if obj_id in seen: + return 0 + # Important mark as seen *before* entering recursion to gracefully handle + # self-referential objects + seen.add(obj_id) + if hasattr(obj, '__dict__'): + for cls in obj.__class__.__mro__: + if '__dict__' in cls.__dict__: + d = cls.__dict__['__dict__'] + if inspect.isgetsetdescriptor(d) or inspect.ismemberdescriptor(d): + size += get_size(obj.__dict__, seen) + break + if isinstance(obj, dict): + size += sum((get_size(v, seen) for v in obj.values())) + size += sum((get_size(k, seen) for k in obj.keys())) + elif isinstance(obj, torch.Tensor): + size += obj.size(dim=0) + elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): + try: + size += sum((get_size(i, seen) for i in obj)) + except TypeError: + logging.exception("Unable to get size of %r. This may lead to incorrect sizes. 
Please report this error.", obj) + if hasattr(obj, '__slots__'): # can have __slots__ with __dict__ + size += sum(get_size(getattr(obj, s), seen) for s in obj.__slots__ if hasattr(obj, s)) + + return size + + +def bytes2human(n): + # http://code.activestate.com/recipes/578019 + # >>> bytes2human(10000) + # '9.8K' + # >>> bytes2human(100001221) + # '95.4M' + symbols = ('K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y') + prefix = {} + for i, s in enumerate(symbols): + prefix[s] = 1 << (i + 1) * 10 + for s in reversed(symbols): + if abs(n) >= prefix[s]: + value = float(n) / prefix[s] + return '%.1f%s' % (value, s) + return "%sB" % n class GraphShard: """ @@ -191,9 +243,11 @@ class GraphShardManager: """ - def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, local_tgt_seeds: Tensor, lock = None) -> None: + def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, local_tgt_seeds: Tensor, + partition_data_manager=None, lock = None) -> None: super().__init__() self.graph_shards = graph_shards + self.pointer_list = [] self.set_lock(lock) if lock is None: self.resume_process = None @@ -201,6 +255,8 @@ def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, loca else: self.resume_process = self._resume_process self.pause_process = self._pause_process + self.partition_data_manager = partition_data_manager + self.metric_dict = None assert all(self.tgt_node_range == x.tgt_range for x in self.graph_shards[1:]) @@ -290,6 +346,18 @@ def sampling_graph(self): self._sampling_graph = sampling_graph return sampling_graph + @property + def partition_data_manager(self): + try: + partition_data_manager = self._partition_data_manager + return partition_data_manager + except: + raise ValueError("partition_data_manager not set") + + @partition_data_manager.setter + def partition_data_manager(self, partition_data_manager): + self._partition_data_manager = partition_data_manager + def get_lock(self): return self._lock def set_lock(self, lock): @@ -303,12 +371,155 @@ def release_lock(self): def _pause_process(self): logger.debug("Node {} pause_process called".format(rank())) + p = psutil.Process() + if self.partition_data_manager: + if self.metric_dict is None: + self.metric_dict = {'pre-pause': list(), + 'post-pause': list(), + 'pre-resume': list(), + 'post-resume': list()} + self.metric_dict['pre-pause'].append(p.memory_full_info().uss) + self.partition_data_manager.save() + self.partition_data_manager.delete() + try: + #TODO fix this so the first exception doesn't stop saving everything else. 
+ #only delete if successfully saved + self.partition_data_manager.save_tensor(self.dstdata, 'dstdata') + del self.dstdata + self.partition_data_manager.save_tensor(self.srcdata, 'srcdata') + del self.srcdata + self.partition_data_manager.save_tensor(self.edata, 'edata') + del self.edata + self.partition_data_manager.save_tensor(self.graph_shards, 'graph_shards') + del self.graph_shards + self.partition_data_manager.save_tensor(self.input_nodes, 'input_nodes') + del self.input_nodes + self.partition_data_manager.save_tensor(self.seeds, 'seeds') + del self.seeds + self.partition_data_manager.save_tensor(self.indices_required_from_me, 'indicies_required_from_me') + del self.indices_required_from_me + except Exception as e: + logger.debug("_pause_process Exception: {}".format(e)) + + for idx, tens in enumerate(self.pointer_list): + try: + fname = "rank{}_tensor{}.pt".format(rank(), idx) + self.partition_data_manager.save_tensor(tens, fname) + tens.resize_(0) + logger.info("pointer tensor idx: {} saved with .resize_".format(idx)) + except RuntimeError as e: + # TODO check with tens error to here. + logger.info("pointer tensor idx: {} saved with storage().resize_".format(idx)) + import ipdb + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + ipdb.set_trace() + finally: + sys.stdin = _stdin + tens.storage().resize_(0) + + gc.collect() + if self.metric_dict: + self.metric_dict['post-pause'].append(p.memory_full_info().uss) self.release_lock() + def _resume_process(self, handle = None): logger.debug("Node {} resume_process called".format(rank())) if handle: handle.wait() self.acquire_lock() + p = psutil.Process() + if self.metric_dict: + self.metric_dict['pre-resume'].append(p.memory_full_info().uss) + if self.partition_data_manager: + self.partition_data_manager.load() + try: + #only reset member variable if successfully loaded. 
+ dstdata = self.partition_data_manager.load_tensor('dstdata') + if dstdata is not None: + self.dstdata = dstdata + srcdata = self.partition_data_manager.load_tensor('srcdata') + if srcdata is not None: + self.srcdata = srcdata + edata = self.partition_data_manager.load_tensor('edata') + if edata is not None: + self.edata = edata + graph_shards = self.partition_data_manager.load_tensor('graph_shards') + if graph_shards is not None: + self.graph_shards = graph_shards + input_nodes = self.partition_data_manager.load_tensor('input_nodes') + if input_nodes is not None: + self.input_nodes = input_nodes + seeds = self.partition_data_manager.load_tensor('seeds') + if seeds is not None: + self.seeds = seeds + indices_required_from_me = self.partition_data_manager.load_tensor('indicies_required_from_me') + if indices_required_from_me is not None: + self.indices_required_from_me = indices_required_from_me + except Exception as e: + logger.debug("_resume_process Exception: {}".format(e)) + for idx, tens in enumerate(self.pointer_list): + fname = "rank{}_tensor{}.pt".format(rank(), idx) + try: + tmp_tens = self.partition_data_manager.load_tensor(fname) + except: + try: + del tmp_tens + tmp_tens = self.partition_data_manager.load_tensor(fname) + except Exception as e: + logger.info("Exception: {}".format(e)) + import ipdb + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + ipdb.set_trace() + finally: + sys.stdin = _stdin + try: + tens.resize_(tmp_tens.size()) + except: + import numpy as np + try: + tens.storage().resize_(int(np.prod(tmp_tens.size()))) + except Exception as e: + logger.info("Exception: {}".format(e)) + logger.info("tmp_tens: {}".format(tmp_tens)) + logger.info("tmp_tens: {} - func: {} - id: {}".format(tmp_tens, tmp_tens._func, id(tmp_tens))) + logger.info("tens: {} - func: {} - id: {}".format(tens, tens._func, id(tens))) + + try: + tens.copy_(tmp_tens) + except Exception as e: + logger.info("Exception: {}".format(e)) + try: + logger.info("tmp_tens: {} - func: {} - id: {}".format(tmp_tens, tmp_tens._func, id(tmp_tens))) + logger.info("tens: {} - func: {} - id: {}".format(tens, tens._func, id(tens))) + + except: + import ipdb + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + ipdb.set_trace() + #ipdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + #logger.debug("Memory used post-resume: {}".format(p.memory_full_info())) + if self.metric_dict: + self.metric_dict['post-resume'].append(p.memory_full_info().uss) + + def print_metrics(self): + if self.metric_dict: + for metric in ['pre-pause', 'post-pause', 'pre-resume', 'post-resume']: + logger.info("Avg. Memory {}: {}".format(metric, bytes2human(sum(self.metric_dict[metric])/len(self.metric_dict[metric])))) + + avg_diff = sum([pre - post for pre, post in zip(self.metric_dict['pre-pause'], self.metric_dict['post-pause'])])/len(self.metric_dict['pre-pause']) + logger.info("Avg. Memory Diff pre/post pause: {}".format(bytes2human(avg_diff))) + + avg_diff = sum([post - pre for pre, post in zip(self.metric_dict['pre-resume'], self.metric_dict['post-resume'])])/len(self.metric_dict['pre-resume']) + logger.info("Avg. 
Memory Diff pre/post resume: {}".format(bytes2human(avg_diff))) def update_boundary_nodes_indices(self) -> List[Tensor]: all_my_sources_indices = [ @@ -457,7 +668,7 @@ def out_degrees(self, vertices=dgl.ALL, etype=None) -> Tensor: out_degrees[shard.unique_src_nodes - shard.src_range[0] ] = shard.graph.out_degrees(etype=etype) all_reduce(out_degrees, op=dist.ReduceOp.SUM, - move_to_comm_device=True) + move_to_comm_device=True, precall_func=self.pause_process, callback_func=self.resume_process) if comm_round == rank(): out_degrees[out_degrees == 0] = 1 self.out_degrees_cache[etype] = out_degrees.to( @@ -523,6 +734,10 @@ def update_all(self, assert isinstance(reduce_func, dgl.function.reducer.SimpleReduceFunction), \ 'only simple reduce functions: sum, min, max, and mean are supported' + #import ipdb; ipdb.set_trace() + #logger.info("update_all - Pointer List: {}".format(self.pointer_list)) + #logger.info("update_all - Pointer: {}".format(self.pointer_list[0]._pointer)) + logger.debug('in update_all') if reduce_func.name == 'mean': diff --git a/sar/tensor_utils.py b/sar/tensor_utils.py new file mode 100644 index 0000000..b36d771 --- /dev/null +++ b/sar/tensor_utils.py @@ -0,0 +1,34 @@ +import torch + +INPLACE_FUNCTIONS = [ + torch.Tensor.resize_, + torch.Tensor.copy_, + torch.Tensor.storage +] + +class PointerTensor(torch.Tensor): + # Is data even needed? + def __init__(self, data, pointer=[], func="", **kwargs): + self._pointer = pointer + self._pointer.append(self) + self._func = func + + @staticmethod + def __new__(cls, x, pointer=[], *args, **kwargs): + return super().__new__(cls, x, *args, **kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + #if func is torch.Tensor.__repr__: # or func is torch.Tensor.__format__: + # args = [a.tensor() if hasattr(a, 'tensor') else a for a in args] + # return func(*args, **kwargs) + pointers = tuple(a._pointer for a in args if hasattr(a, '_pointer')) + if len(pointers) == 0: + pointers = [[]] + #import ipdb; ipdb.set_trace() + parent = super().__torch_function__(func, types, args, kwargs) + if func not in INPLACE_FUNCTIONS and not hasattr(parent, '_pointer'): + parent.__init__([], pointer=pointers[0], func=func) + return parent From 68bc78a5143af45f6a68c2d673c24980dc2da2ff Mon Sep 17 00:00:00 2001 From: seanmcpherson Date: Fri, 21 Jul 2023 10:34:27 -0700 Subject: [PATCH 3/9] improving the way tensors requiring grad updates are saved and loaded --- sar/core/graphshard.py | 73 +++++++++++++++--------------------------- sar/tensor_utils.py | 5 ++- 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index 83a7309..9588b50 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -401,15 +401,21 @@ def _pause_process(self): except Exception as e: logger.debug("_pause_process Exception: {}".format(e)) + + for idx, tens in enumerate(self.pointer_list): try: fname = "rank{}_tensor{}.pt".format(rank(), idx) - self.partition_data_manager.save_tensor(tens, fname) - tens.resize_(0) - logger.info("pointer tensor idx: {} saved with .resize_".format(idx)) + if tens.requires_grad: + tens_d = tens.detach() + self.partition_data_manager.save_tensor(tens_d, fname) + tens.storage().resize_(0) + else: + self.partition_data_manager.save_tensor(tens, fname) + tens.resize_(0) + except RuntimeError as e: # TODO check with tens error to here. 
- logger.info("pointer tensor idx: {} saved with storage().resize_".format(idx)) import ipdb _stdin = sys.stdin try: @@ -417,8 +423,7 @@ def _pause_process(self): ipdb.set_trace() finally: sys.stdin = _stdin - tens.storage().resize_(0) - + gc.collect() if self.metric_dict: self.metric_dict['post-pause'].append(p.memory_full_info().uss) @@ -463,48 +468,22 @@ def _resume_process(self, handle = None): fname = "rank{}_tensor{}.pt".format(rank(), idx) try: tmp_tens = self.partition_data_manager.load_tensor(fname) - except: - try: - del tmp_tens - tmp_tens = self.partition_data_manager.load_tensor(fname) - except Exception as e: - logger.info("Exception: {}".format(e)) - import ipdb - _stdin = sys.stdin - try: - sys.stdin = open('/dev/stdin') - ipdb.set_trace() - finally: - sys.stdin = _stdin - try: - tens.resize_(tmp_tens.size()) - except: - import numpy as np - try: - tens.storage().resize_(int(np.prod(tmp_tens.size()))) - except Exception as e: - logger.info("Exception: {}".format(e)) - logger.info("tmp_tens: {}".format(tmp_tens)) - logger.info("tmp_tens: {} - func: {} - id: {}".format(tmp_tens, tmp_tens._func, id(tmp_tens))) - logger.info("tens: {} - func: {} - id: {}".format(tens, tens._func, id(tens))) - - try: - tens.copy_(tmp_tens) + if tens.requires_grad: + import numpy as np + tens.storage().resize_(int(np.prod(tens.size()))) + #tens.set_(tmp_tens) + else: + tens.resize_(tens.size()) + tens.set_(tmp_tens) except Exception as e: - logger.info("Exception: {}".format(e)) - try: - logger.info("tmp_tens: {} - func: {} - id: {}".format(tmp_tens, tmp_tens._func, id(tmp_tens))) - logger.info("tens: {} - func: {} - id: {}".format(tens, tens._func, id(tens))) - - except: - import ipdb - _stdin = sys.stdin - try: - sys.stdin = open('/dev/stdin') - ipdb.set_trace() - #ipdb.Pdb.interaction(self, *args, **kwargs) - finally: - sys.stdin = _stdin + logger.info("Exception: {}".format(e)) + import ipdb + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + ipdb.set_trace() + finally: + sys.stdin = _stdin #logger.debug("Memory used post-resume: {}".format(p.memory_full_info())) if self.metric_dict: diff --git a/sar/tensor_utils.py b/sar/tensor_utils.py index b36d771..8cda42c 100644 --- a/sar/tensor_utils.py +++ b/sar/tensor_utils.py @@ -3,7 +3,10 @@ INPLACE_FUNCTIONS = [ torch.Tensor.resize_, torch.Tensor.copy_, - torch.Tensor.storage + torch.Tensor.storage, + torch.Tensor.detach, + torch.Tensor.set_, + torch.Tensor.requires_grad ] class PointerTensor(torch.Tensor): From 23bfe2e1acc4cc6bdc8e5e9d3762f1dbe5966ab9 Mon Sep 17 00:00:00 2001 From: seanmcpherson Date: Wed, 2 Aug 2023 14:59:41 -0700 Subject: [PATCH 4/9] saving and loading now working, issues with backprop --- examples/train_homogeneous_graph_basic.py | 4 +- sar/core/graphshard.py | 45 +++++++++++++++++++++-- sar/tensor_utils.py | 42 ++++++++++++++++++--- 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py index df805a6..1f4a6ca 100644 --- a/examples/train_homogeneous_graph_basic.py +++ b/examples/train_homogeneous_graph_basic.py @@ -330,7 +330,9 @@ def callback_func(handle): optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr) for train_iter_idx in range(args.train_iters): logger.debug(f'{rank} : starting training iteration {train_iter_idx}') - partition_data_manager.features = sar.PointerTensor(partition_data_manager.features, pointer=full_graph_manager.pointer_list) + partition_data_manager.features = 
sar.PointerTensor(partition_data_manager.features, + pointer=full_graph_manager.pointer_list, + linked=full_graph_manager.linked_list) # Train t_1 = time.time() logits = gnn_model(full_graph_manager, partition_data_manager.features) diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py index 9588b50..73e79c9 100644 --- a/sar/core/graphshard.py +++ b/sar/core/graphshard.py @@ -248,6 +248,7 @@ def __init__(self, graph_shards: List[GraphShard], local_src_seeds: Tensor, loca super().__init__() self.graph_shards = graph_shards self.pointer_list = [] + self.linked_list = [] self.set_lock(lock) if lock is None: self.resume_process = None @@ -404,12 +405,26 @@ def _pause_process(self): for idx, tens in enumerate(self.pointer_list): + #if idx >=4: + ''' + import ipdb + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + ipdb.set_trace() + finally: + sys.stdin = _stdin + ''' + logger.debug("presave tens version - {}".format(tens._version)) + version = tens._version try: fname = "rank{}_tensor{}.pt".format(rank(), idx) if tens.requires_grad: tens_d = tens.detach() self.partition_data_manager.save_tensor(tens_d, fname) - tens.storage().resize_(0) + with torch.no_grad(): + tens.set_() + #tens.storage().resize_(0) else: self.partition_data_manager.save_tensor(tens, fname) tens.resize_(0) @@ -423,7 +438,8 @@ def _pause_process(self): ipdb.set_trace() finally: sys.stdin = _stdin - + tens._version = version + logger.debug("post-save tens version - {}".format(tens._version)) gc.collect() if self.metric_dict: self.metric_dict['post-pause'].append(p.memory_full_info().uss) @@ -465,7 +481,9 @@ def _resume_process(self, handle = None): except Exception as e: logger.debug("_resume_process Exception: {}".format(e)) for idx, tens in enumerate(self.pointer_list): + logger.debug("preload tens version - {}".format(tens._version)) fname = "rank{}_tensor{}.pt".format(rank(), idx) + version = tens._version try: tmp_tens = self.partition_data_manager.load_tensor(fname) if tens.requires_grad: @@ -473,8 +491,9 @@ def _resume_process(self, handle = None): tens.storage().resize_(int(np.prod(tens.size()))) #tens.set_(tmp_tens) else: - tens.resize_(tens.size()) - tens.set_(tmp_tens) + tens.resize_(tens.size()) + with torch.no_grad(): + tens.set_(tmp_tens) except Exception as e: logger.info("Exception: {}".format(e)) import ipdb @@ -484,6 +503,24 @@ def _resume_process(self, handle = None): ipdb.set_trace() finally: sys.stdin = _stdin + tens._version = version + logger.debug("post-load tens version - {}".format(tens._version)) + for ref_tens, linked_tens in self.linked_list: + try: + with torch.no_grad(): + linked_tens.set_(ref_tens.storage(), + storage_offset=linked_tens.storage_offset(), + size=linked_tens.size(), + stride=linked_tens.stride()) + except Exception as e: + + import ipdb + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + ipdb.set_trace() + finally: + sys.stdin = _stdin #logger.debug("Memory used post-resume: {}".format(p.memory_full_info())) if self.metric_dict: diff --git a/sar/tensor_utils.py b/sar/tensor_utils.py index 8cda42c..609764b 100644 --- a/sar/tensor_utils.py +++ b/sar/tensor_utils.py @@ -6,18 +6,23 @@ torch.Tensor.storage, torch.Tensor.detach, torch.Tensor.set_, - torch.Tensor.requires_grad + torch.Tensor.requires_grad, + torch.Tensor.data_ptr ] class PointerTensor(torch.Tensor): # Is data even needed? 
-    def __init__(self, data, pointer=[], func="", **kwargs):
+    def __init__(self, data, pointer=[], linked = [], base_tensor = None, func="", **kwargs):
         self._pointer = pointer
-        self._pointer.append(self)
+        self._linked = linked
+        if base_tensor is not None:
+            self._linked.append((base_tensor, self))
+        else:
+            self._pointer.append(self)
         self._func = func

     @staticmethod
-    def __new__(cls, x, pointer=[], *args, **kwargs):
+    def __new__(cls, x, pointer=[], linked = [], base_tensor = None, *args, **kwargs):
         return super().__new__(cls, x, *args, **kwargs)

     @classmethod
@@ -30,8 +35,35 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
         pointers = tuple(a._pointer for a in args if hasattr(a, '_pointer'))
         if len(pointers) == 0:
             pointers = [[]]
+        links = tuple(a._linked for a in args if hasattr(a, '_linked'))
+        if len(links) == 0:
+            links = [[]]
         #import ipdb; ipdb.set_trace()
         parent = super().__torch_function__(func, types, args, kwargs)
+        if not type(parent) in [torch.Tensor, cls]:
+            #print("parent_type: {}".format(type(parent)))
+            return parent
+
+        '''
+        import ipdb, sys
+        _stdin = sys.stdin
+        try:
+            sys.stdin = open('/dev/stdin')
+            ipdb.set_trace()
+        finally:
+            sys.stdin = _stdin
+        '''
+        base_tensor = None
+        for pointer in pointers[0]:
+            if hasattr(parent, 'data_ptr'):
+                if parent.storage().data_ptr() == pointer.storage().data_ptr():
+                    base_tensor = pointer
+                    break
+
         if func not in INPLACE_FUNCTIONS and not hasattr(parent, '_pointer'):
-            parent.__init__([], pointer=pointers[0], func=func)
+            #if hasattr(parent, 'data_ptr') and len(pointers[0]) > 0:
+            #    if not parent.data_ptr() == pointers[-1][-1].data_ptr() and not hasattr(parent, '_pointer'):
+            parent.__init__([], pointer=pointers[0], linked = links[0],
+                            base_tensor= base_tensor, func=func)
+
         return parent

From 7a22bb6cadd7c5d8f8127df8f3f312a40f176e6d Mon Sep 17 00:00:00 2001
From: seanmcpherson
Date: Wed, 2 Aug 2023 15:00:40 -0700
Subject: [PATCH 5/9] saving and loading now working, issues with backprop

---
 sar/core/graphshard.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py
index 73e79c9..565bcc3 100644
--- a/sar/core/graphshard.py
+++ b/sar/core/graphshard.py
@@ -416,7 +416,6 @@ def _pause_process(self):
                 sys.stdin = _stdin
             '''
             logger.debug("presave tens version - {}".format(tens._version))
-            version = tens._version
             try:
                 fname = "rank{}_tensor{}.pt".format(rank(), idx)
                 if tens.requires_grad:
@@ -438,7 +437,6 @@ def _pause_process(self):
                     ipdb.set_trace()
                 finally:
                     sys.stdin = _stdin
-            tens._version = version
             logger.debug("post-save tens version - {}".format(tens._version))
         gc.collect()
@@ -483,7 +481,6 @@ def _resume_process(self, handle = None):
         for idx, tens in enumerate(self.pointer_list):
             logger.debug("preload tens version - {}".format(tens._version))
             fname = "rank{}_tensor{}.pt".format(rank(), idx)
-            version = tens._version
             try:
                 tmp_tens = self.partition_data_manager.load_tensor(fname)
                 if tens.requires_grad:
@@ -503,7 +500,6 @@
                     ipdb.set_trace()
                 finally:
                     sys.stdin = _stdin
-            tens._version = version
             logger.debug("post-load tens version - {}".format(tens._version))
         for ref_tens, linked_tens in self.linked_list:
             try:
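Patches 4 and 5 are both fighting autograd's version counter: saving tens._version and writing it back was an attempt to hide the in-place set_()/resize_() edits from the backward pass, and patch 5 strips that bookkeeping back out. The underlying failure mode is easy to reproduce in isolation; a minimal sketch (not from the patches):

    import torch

    a = torch.randn(4, requires_grad=True)
    x = a * 1.0            # non-leaf tensor, so in-place edits are permitted
    y = x ** 2             # autograd saves x (at version 0) for the backward pass
    x.add_(1.0)            # the in-place update bumps x's version counter
    try:
        y.sum().backward()
    except RuntimeError as e:
        print(e)           # "... has been modified by an inplace operation ..."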
From afe5378759bd9396ca4c6c9d3da30809f8909464 Mon Sep 17 00:00:00 2001
From: seanmcpherson
Date: Mon, 7 Aug 2023 15:00:16 -0700
Subject: [PATCH 6/9] working end-to-end single node training

---
 examples/train_homogeneous_graph_basic.py |  8 ++-
 sar/core/graphshard.py                    | 74 +++++++++++++++++------
 sar/tensor_utils.py                       |  4 +-
 3 files changed, 64 insertions(+), 22 deletions(-)

diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py
index 1f4a6ca..f8c57dc 100644
--- a/examples/train_homogeneous_graph_basic.py
+++ b/examples/train_homogeneous_graph_basic.py
@@ -347,6 +347,10 @@ def callback_func(handle):
                          precall_func=full_graph_manager.pause_process,
                          callback_func=full_graph_manager.resume_process)
         optimizer.step()
+        logits = torch.Tensor(logits)
+        partition_data_manager.features = torch.Tensor(partition_data_manager.features)
+        full_graph_manager.pointer_list = []
+        full_graph_manager.linked_list = []
         train_time = time.time() - t_1

@@ -379,9 +383,7 @@ def callback_func(handle):
             f" |"
         ])
         print(result_message, flush=True)
-        print("Pointer List Length: {}".format(len(full_graph_manager.pointer_list)))
-        full_graph_manager.pointer_list = []
-
+    full_graph_manager.print_metrics()
     lock.release()

diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py
index 565bcc3..b72ba9d 100644
--- a/sar/core/graphshard.py
+++ b/sar/core/graphshard.py
@@ -415,20 +415,35 @@ def _pause_process(self):
             finally:
                 sys.stdin = _stdin
             '''
-            logger.debug("presave tens version - {}".format(tens._version))
             try:
                 fname = "rank{}_tensor{}.pt".format(rank(), idx)
                 if tens.requires_grad:
                     tens_d = tens.detach()
                     self.partition_data_manager.save_tensor(tens_d, fname)
-                    with torch.no_grad():
-                        tens.set_()
-                    #tens.storage().resize_(0)
+                    #with torch.no_grad():
+                    #    tens.set_()
+                    tens.storage().resize_(0)
                 else:
                     self.partition_data_manager.save_tensor(tens, fname)
                     tens.resize_(0)
-            except RuntimeError as e:
+            except Exception as e:
+                # TODO check with tens error to here.
                 import ipdb
                 _stdin = sys.stdin
                 try:
                     sys.stdin = open('/dev/stdin')
                     ipdb.set_trace()
                 finally:
                     sys.stdin = _stdin
+
+        for _, linked_tens in self.linked_list:
+            try:
+                if linked_tens.requires_grad:
+                    linked_tens.storage().resize_(0)
+                else:
+                    linked_tens.resize_(0)
+            except Exception as e:
+                # TODO check with tens error to here.
+                import ipdb
+                _stdin = sys.stdin
@@ -437,7 +452,6 @@ def _pause_process(self):
                     ipdb.set_trace()
                 finally:
                     sys.stdin = _stdin
-            logger.debug("post-save tens version - {}".format(tens._version))
         gc.collect()
         if self.metric_dict:
             self.metric_dict['post-pause'].append(p.memory_full_info().uss)
@@ -479,18 +493,37 @@ def _resume_process(self, handle = None):
         except Exception as e:
             logger.debug("_resume_process Exception: {}".format(e))
         for idx, tens in enumerate(self.pointer_list):
-            logger.debug("preload tens version - {}".format(tens._version))
             fname = "rank{}_tensor{}.pt".format(rank(), idx)
             try:
                 tmp_tens = self.partition_data_manager.load_tensor(fname)
                 if tens.requires_grad:
                     import numpy as np
-                    tens.storage().resize_(int(np.prod(tens.size())))
+                    #tens.storage().resize_(int(np.prod(tmp_tens.size())))
+                    tens.storage().resize_(tmp_tens.numel())
+
+                    '''
+                    import ipdb
+                    _stdin = sys.stdin
+                    try:
+                        sys.stdin = open('/dev/stdin')
+                        ipdb.set_trace()
+                    finally:
+                        sys.stdin = _stdin
+                    '''
                     #tens.set_(tmp_tens)
+                    #tens.data.copy_(tmp_tens.data)
+
                 else:
-                    tens.resize_(tens.size())
-                    with torch.no_grad():
-                        tens.set_(tmp_tens)
+                    tens.resize_(tmp_tens.size())
+                    #tens.data.copy_(tmp_tens.data)
+                    #tens.copy_(tmp_tens)
+                    #with torch.no_grad():
+                    #    tens.set_(tmp_tens)
+                    #tens.copy_(tmp_tens)
+                ref_tens = tens.new(0)
+                ref_tens.set_(tens)
+                ref_tens.copy_(tmp_tens)
+                #tens[:] = tmp_tens
             except Exception as e:
                 logger.info("Exception: {}".format(e))
                 import ipdb
@@ -500,14 +533,20 @@ def _resume_process(self, handle = None):
                     ipdb.set_trace()
                 finally:
                     sys.stdin = _stdin
-            logger.debug("post-load tens version - {}".format(tens._version))
         for ref_tens, linked_tens in self.linked_list:
             try:
-                with torch.no_grad():
-                    linked_tens.set_(ref_tens.storage(),
-                                     storage_offset=linked_tens.storage_offset(),
-                                     size=linked_tens.size(),
-                                     stride=linked_tens.stride())
+                if linked_tens.requires_grad:
+                    linked_tens.storage().resize_(ref_tens.numel())
+                else:
+                    linked_tens.resize_(ref_tens.size())
+
+                if not(linked_tens.storage().data_ptr() == ref_tens.storage().data_ptr()):
+                    raise ValueError("linked_tens and ref_tens data_ptr not equal")
+                #with torch.no_grad():
+                #    linked_tens.set_(ref_tens.storage(),
+                #                     storage_offset=linked_tens.storage_offset(),
+                #                     size=linked_tens.size(),
+                #                     stride=linked_tens.stride())
             except Exception as e:

                 import ipdb
@@ -517,7 +556,6 @@ def _resume_process(self, handle = None):
                     ipdb.set_trace()
                 finally:
                     sys.stdin = _stdin
-
         #logger.debug("Memory used post-resume: {}".format(p.memory_full_info()))
         if self.metric_dict:
             self.metric_dict['post-resume'].append(p.memory_full_info().uss)
diff --git a/sar/tensor_utils.py b/sar/tensor_utils.py
index 609764b..fadde2e 100644
--- a/sar/tensor_utils.py
+++ b/sar/tensor_utils.py
@@ -7,7 +7,8 @@
     torch.Tensor.detach,
     torch.Tensor.set_,
     torch.Tensor.requires_grad,
-    torch.Tensor.data_ptr
+    torch.Tensor.data_ptr,
+    torch.Tensor.new
 ]

 class PointerTensor(torch.Tensor):
@@ -16,6 +17,7 @@ def __init__(self, data, pointer=[], linked = [], base_tensor = None, func="", *
         self._pointer = pointer
         self._linked = linked
         if base_tensor is not None:
+            #pass
             self._linked.append((base_tensor, self))
         else:
             self._pointer.append(self)
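The linked-tensor handling added here encodes a storage-aliasing invariant: a view created before a pause shares a Storage object with its base, so once the base's storage is re-grown and refilled, the view sees the restored data again; that is what the data_ptr() comparison above asserts, instead of rebuilding the view with set_(). A small sketch of the invariant (names are illustrative):

    import torch

    base = torch.randn(8, 4)
    view = base[2:4]                         # a view aliasing base's storage
    assert view.storage().data_ptr() == base.storage().data_ptr()

    backup = base.clone()
    base.storage().resize_(0)                # pause: drop the shared memory
    base.storage().resize_(backup.numel())   # resume: re-grow the same Storage object
    base.copy_(backup)

    # the view still aliases the same storage, so it sees the restored values
    assert view.storage().data_ptr() == base.storage().data_ptr()
    assert torch.equal(view, backup[2:4])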
From 0de1cf1487a735e865f5fde893df0765e8d1b581 Mon Sep 17 00:00:00 2001
From: seanmcpherson
Date: Tue, 8 Aug 2023 09:16:32 -0700
Subject: [PATCH 7/9] renaming example script

---
 ...raph_basic.py => train_homogeneous_graph_basic_single-node.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename examples/{train_homogeneous_graph_basic.py => train_homogeneous_graph_basic_single-node.py} (100%)

diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic_single-node.py
similarity index 100%
rename from examples/train_homogeneous_graph_basic.py
rename to examples/train_homogeneous_graph_basic_single-node.py
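After the rename, the repository carries both launch styles side by side: train_homogeneous_graph_basic_single-node.py forks all ranks from one parent process, while the script restored below expects one process per rank to be launched externally with --rank. A minimal sketch of the fork-per-rank pattern the single-node variant is built on (world size and worker body are illustrative):

    from multiprocessing import Barrier, Lock, Process

    def run(rank, world_size, lock, barrier):
        print(f"worker {rank}/{world_size} started")
        barrier.wait()                      # crude rendezvous across the local workers

    if __name__ == '__main__':
        world_size = 2
        lock, barrier = Lock(), Barrier(world_size)
        workers = [Process(target=run, args=(r, world_size, lock, barrier))
                   for r in range(world_size)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()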
From d406629d7cdb488740b16a891284c501f9aa2b8c Mon Sep 17 00:00:00 2001
From: seanmcpherson
Date: Tue, 8 Aug 2023 09:19:27 -0700
Subject: [PATCH 8/9] resetting training script

---
 examples/train_homogeneous_graph_basic.py | 208 ++++++++++++++++++++++
 1 file changed, 208 insertions(+)
 create mode 100644 examples/train_homogeneous_graph_basic.py

diff --git a/examples/train_homogeneous_graph_basic.py b/examples/train_homogeneous_graph_basic.py
new file mode 100644
index 0000000..8f39f25
--- /dev/null
+++ b/examples/train_homogeneous_graph_basic.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2022 Intel Corporation
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import List, Union, Dict
+from argparse import ArgumentParser
+import os
+import logging
+import time
+import torch
+import torch.nn.functional as F
+from torch import nn
+import torch.distributed as dist
+import dgl  # type: ignore
+
+import sar
+
+
+parser = ArgumentParser(
+    description="GNN training on node classification tasks in homogeneous graphs")
+
+
+parser.add_argument(
+    "--partitioning-json-file",
+    type=str,
+    default="",
+    help="Path to the .json file containing partitioning information "
+)
+
+parser.add_argument('--ip-file', default='./ip_file', type=str,
+                    help='File with ip-address. Worker 0 creates this file and all others read it ')
+
+
+parser.add_argument('--backend', default='nccl', type=str, choices=['ccl', 'nccl', 'mpi'],
+                    help='Communication backend to use '
+                    )
+
+parser.add_argument(
+    "--cpu-run", action="store_true",
+    help="Run on CPUs if set, otherwise run on GPUs "
+)
+
+
+parser.add_argument('--train-iters', default=100, type=int,
+                    help='number of training iterations ')
+
+parser.add_argument(
+    "--lr",
+    type=float,
+    default=1e-2,
+    help="learning rate"
+)
+
+
+parser.add_argument('--rank', default=0, type=int,
+                    help='Rank of the current worker ')
+
+parser.add_argument('--world-size', default=2, type=int,
+                    help='Number of workers ')
+
+parser.add_argument('--hidden-layer-dim', default=256, type=int,
+                    help='Dimension of GNN hidden layer')
+
+
+class GNNModel(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
+        super().__init__()
+
+        self.convs = nn.ModuleList([
+            # pylint: disable=no-member
+            dgl.nn.SAGEConv(in_dim, hidden_dim, aggregator_type='mean'),
+            # pylint: disable=no-member
+            dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean'),
+            # pylint: disable=no-member
+            dgl.nn.SAGEConv(hidden_dim, out_dim, aggregator_type='mean'),
+        ])
+
+    def forward(self, graph: sar.GraphShardManager, features: torch.Tensor):
+        for idx, conv in enumerate(self.convs):
+            features = conv(graph, features)
+            if idx < len(self.convs) - 1:
+                features = F.relu(features, inplace=True)
+
+        return features
+
+
+def main():
+    args = parser.parse_args()
+    print('args', args)
+
+    use_gpu = torch.cuda.is_available() and not args.cpu_run
+    device = torch.device('cuda' if use_gpu else 'cpu')
+
+    # Obtain the ip address of the master through the network file system
+    master_ip_address = sar.nfs_ip_init(args.rank, args.ip_file)
+    sar.initialize_comms(args.rank,
+                         args.world_size, master_ip_address,
+                         args.backend)
+
+    # Load DGL partition data
+    partition_data = sar.load_dgl_partition_data(
+        args.partitioning_json_file, args.rank, device)
+
+    # Obtain train, validation, and test masks
+    # These are stored as node features. Partitioning may prepend
+    # the node type to the mask names. So we use the convenience function
+    # suffix_key_lookup to look up the mask name while ignoring the
+    # arbitrary node type
+    masks = {}
+    for mask_name, indices_name in zip(['train_mask', 'val_mask', 'test_mask'],
+                                       ['train_indices', 'val_indices', 'test_indices']):
+        boolean_mask = sar.suffix_key_lookup(partition_data.node_features,
+                                             mask_name)
+        masks[indices_name] = boolean_mask.nonzero(
+            as_tuple=False).view(-1).to(device)
+
+    labels = sar.suffix_key_lookup(partition_data.node_features,
+                                   'labels').long().to(device)
+
+    # Obtain the number of classes by finding the max label across all workers
+    num_labels = labels.max() + 1
+    sar.comm.all_reduce(num_labels, dist.ReduceOp.MAX, move_to_comm_device=True)
+    num_labels = num_labels.item()
+
+    features = sar.suffix_key_lookup(partition_data.node_features, 'features').to(device)
+    full_graph_manager = sar.construct_full_graph(partition_data).to(device)
+
+    # We do not need the partition data anymore
+    del partition_data
+
+    gnn_model = GNNModel(features.size(1),
+                         args.hidden_layer_dim,
+                         num_labels).to(device)
+    print('model', gnn_model)
+
+    # Synchronize the model parameters across all workers
+    sar.sync_params(gnn_model)
+
+    # Obtain the number of labeled nodes in the training set.
+    # This will be needed to properly obtain a cross entropy loss
+    # normalized by the number of training examples
+    n_train_points = torch.LongTensor([masks['train_indices'].numel()])
+    sar.comm.all_reduce(n_train_points, op=dist.ReduceOp.SUM, move_to_comm_device=True)
+    n_train_points = n_train_points.item()
+
+    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.lr)
+    for train_iter_idx in range(args.train_iters):
+        # Train
+        t_1 = time.time()
+        logits = gnn_model(full_graph_manager, features)
+        loss = F.cross_entropy(logits[masks['train_indices']],
+                               labels[masks['train_indices']], reduction='sum')/n_train_points
+
+        optimizer.zero_grad()
+        loss.backward()
+        # Do not forget to gather the parameter gradients from all workers
+        sar.gather_grads(gnn_model)
+        optimizer.step()
+        train_time = time.time() - t_1
+
+        # Calculate accuracy for train/validation/test
+        results = []
+        for indices_name in ['train_indices', 'val_indices', 'test_indices']:
+            n_correct = (logits[masks[indices_name]].argmax(1) ==
+                         labels[masks[indices_name]]).float().sum()
+            results.extend([n_correct, masks[indices_name].numel()])
+
+        acc_vec = torch.FloatTensor(results)
+        # Sum the n_correct, and number of mask elements across all workers
+        sar.comm.all_reduce(acc_vec, op=dist.ReduceOp.SUM, move_to_comm_device=True)
+        (train_acc, val_acc, test_acc) = \
+            (acc_vec[0] / acc_vec[1],
+             acc_vec[2] / acc_vec[3],
+             acc_vec[4] / acc_vec[5])
+
+        result_message = (
+            f"iteration [{train_iter_idx}/{args.train_iters}] | "
+        )
+        result_message += ', '.join([
+            f"train loss={loss:.4f}, "
+            f"Accuracy: "
+            f"train={train_acc:.4f} "
+            f"valid={val_acc:.4f} "
+            f"test={test_acc:.4f} "
+            f" | train time = {train_time} "
+            f" |"
+        ])
+        print(result_message, flush=True)
+
+
+if __name__ == '__main__':
+    main()
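One detail worth noting in the restored script: per-split accuracy is computed with a single collective. Each worker packs its local correct counts and mask sizes into one vector, and one sum all_reduce then yields globally consistent ratios on every rank. A standalone sketch of that reduction (values are made up, and it assumes torch.distributed is already initialized):

    import torch
    import torch.distributed as dist

    # per-worker counts: [n_correct_train, n_train, n_correct_val, n_val]
    counts = torch.FloatTensor([80.0, 100.0, 40.0, 50.0])
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)   # sum across all workers
    train_acc = counts[0] / counts[1]
    val_acc = counts[2] / counts[3]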
From 503579f4310fe35889f520a66f111867b8678439 Mon Sep 17 00:00:00 2001
From: seanmcpherson
Date: Tue, 8 Aug 2023 09:50:16 -0700
Subject: [PATCH 9/9] cleaning up, removing some debug messages

---
 sar/core/graphshard.py      | 41 -------------------------------------
 sar/core/sar_aggregation.py |  2 --
 2 files changed, 43 deletions(-)

diff --git a/sar/core/graphshard.py b/sar/core/graphshard.py
index b72ba9d..d9ce0ac 100644
--- a/sar/core/graphshard.py
+++ b/sar/core/graphshard.py
@@ -51,41 +51,6 @@
 logger.addHandler(logging.NullHandler())
 logger.setLevel(logging.DEBUG)

-import inspect
-def get_size(obj, seen=None):
-    """Recursively finds size of objects in bytes"""
-    size = sys.getsizeof(obj)
-    if seen is None:
-        seen = set()
-    obj_id = id(obj)
-    if obj_id in seen:
-        return 0
-    # Important mark as seen *before* entering recursion to gracefully handle
-    # self-referential objects
-    seen.add(obj_id)
-    if hasattr(obj, '__dict__'):
-        for cls in obj.__class__.__mro__:
-            if '__dict__' in cls.__dict__:
-                d = cls.__dict__['__dict__']
-                if inspect.isgetsetdescriptor(d) or inspect.ismemberdescriptor(d):
-                    size += get_size(obj.__dict__, seen)
-                break
-    if isinstance(obj, dict):
-        size += sum((get_size(v, seen) for v in obj.values()))
-        size += sum((get_size(k, seen) for k in obj.keys()))
-    elif isinstance(obj, torch.Tensor):
-        size += obj.size(dim=0)
-    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
-        try:
-            size += sum((get_size(i, seen) for i in obj))
-        except TypeError:
-            logging.exception("Unable to get size of %r. This may lead to incorrect sizes. Please report this error.", obj)
-    if hasattr(obj, '__slots__'):  # can have __slots__ with __dict__
-        size += sum(get_size(getattr(obj, s), seen) for s in obj.__slots__ if hasattr(obj, s))
-
-    return size
-
-
 def bytes2human(n):
     # http://code.activestate.com/recipes/578019
     # >>> bytes2human(10000)
@@ -784,12 +749,6 @@ def update_all(self,
         assert isinstance(reduce_func, dgl.function.reducer.SimpleReduceFunction), \
             'only simple reduce functions: sum, min, max, and mean are supported'

-        #import ipdb; ipdb.set_trace()
-        #logger.info("update_all - Pointer List: {}".format(self.pointer_list))
-        #logger.info("update_all - Pointer: {}".format(self.pointer_list[0]._pointer))
-
-        logger.debug('in update_all')
-
         if reduce_func.name == 'mean':
             reduce_func = fn.sum(reduce_func.msg_field,  # pylint: disable=no-member
                                  reduce_func.out_field)
diff --git a/sar/core/sar_aggregation.py b/sar/core/sar_aggregation.py
index 1caaf70..7617422 100644
--- a/sar/core/sar_aggregation.py
+++ b/sar/core/sar_aggregation.py
@@ -479,7 +479,6 @@ class SAROp(torch.autograd.Function):  # pylint: disable = abstract-method
     def forward(ctx, aggregation_data: AggregationData,  # type: ignore
                 *all_input_tensors: Tensor) -> Tensor:  # type: ignore

-        logger.debug('in sar_op.forward')
         logger.debug('aggregation_data %s', aggregation_data)

         # Do not pass the parameter tensors to aggregation routines. They
@@ -532,7 +531,6 @@ def forward(ctx, aggregation_data: AggregationData,  # type: ignore
     # pylint: disable = arguments-differ
     # type: ignore
     def backward(ctx, output_grad) -> Tuple[Optional[Tensor], ...]:
-        logger.debug('in sar_op.backwards')
         logger.debug('backward aggregation data %s', ctx.aggregation_data)
         aggregation_data = ctx.aggregation_data
         backward_manager = ctx.backward_manager
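Stepping back, the whole series hinges on the PointerTensor idea from sar/tensor_utils.py: a torch.Tensor subclass whose __torch_function__ hook sees every operation, so each newly produced tensor can register itself in a shared list and later be paused and resumed by the GraphShardManager. A stripped-down sketch of that interception pattern (heavily simplified relative to the real class):

    import torch

    class TrackedTensor(torch.Tensor):
        # every tensor produced from a TrackedTensor records itself here
        registry = []

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            result = super().__torch_function__(func, types, args, kwargs)
            if isinstance(result, TrackedTensor):
                cls.registry.append(result)
            return result

    x = torch.randn(4).as_subclass(TrackedTensor)
    y = x * 2 + 1                       # x * 2 and the final sum both get registered
    print(len(TrackedTensor.registry))  # 2

The real class additionally compares storage().data_ptr() against known tensors to tell genuinely new allocations (tracked in pointer_list) apart from aliases of existing storage (tracked in linked_list), which is what makes the pause/resume bookkeeping in graphshard.py possible.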