diff --git a/minerva/datasets/factory.py b/minerva/datasets/factory.py index f6fd98c65..75a589084 100644 --- a/minerva/datasets/factory.py +++ b/minerva/datasets/factory.py @@ -185,26 +185,64 @@ def get_subdataset( universal_path(data_directory), sub_dataset_params["paths"] ) - sub_dataset: Union[GeoDataset, NonGeoDataset] + sub_dataset: Optional[Union[GeoDataset, NonGeoDataset]] if cache or sub_dataset_params.get("cache_dataset"): this_hash = utils.make_hash(sub_dataset_params) cached_dataset_path = Path(CACHE_DIR) / f"{this_hash}.obj" + print(f"{cached_dataset_path=}") + print(f"{cached_dataset_path.exists()=}") if cached_dataset_path.exists(): + print("\nLoad cached dataset") sub_dataset = load_dataset_from_cache(cached_dataset_path) else: - sub_dataset = create_subdataset( - _sub_dataset, - sub_dataset_paths, - sub_dataset_params, - transformations, - sample_pairs=sample_pairs, - ) + # Ensure that no conflicts from caching datasets made in multiple processes arise. + if dist.is_available() and dist.is_initialized(): + # Get this process's rank. + rank = dist.get_rank() + + # Start a blocking action, ensuring only process 0 can create and cache the dataset. + # All other processes will wait till 0 is finished. + dist.barrier() + + if rank == 0: + print(f"\nCreating dataset on {rank}...") + sub_dataset = create_subdataset( + _sub_dataset, + sub_dataset_paths, + sub_dataset_params, + transformations, + sample_pairs=sample_pairs, + ) + + cache_dataset(sub_dataset, cached_dataset_path) + + # Other processes wait... + else: + sub_dataset = None + + # End of blocking action. + dist.barrier() + + # Now the other processes can load the newly created cached dataset from 0. 
+ if rank != 0: + print(f"\nLoading dataset from cache on {rank}") + sub_dataset = load_dataset_from_cache(cached_dataset_path) + + else: + print("\nCreating dataset...") + sub_dataset = create_subdataset( + _sub_dataset, + sub_dataset_paths, + sub_dataset_params, + transformations, + sample_pairs=sample_pairs, + ) - cache_dataset(sub_dataset, cached_dataset_path) + cache_dataset(sub_dataset, cached_dataset_path) else: sub_dataset = create_subdataset( @@ -215,6 +253,7 @@ def get_subdataset( sample_pairs=sample_pairs, ) + assert sub_dataset is not None return sub_dataset