Skip to content

Commit

Permalink
Merge pull request #422 from Pale-Blue-Dot-97/421-bug-in-caching-dataset
Browse files Browse the repository at this point in the history
421 bug in caching dataset
  • Loading branch information
Pale-Blue-Dot-97 committed Jan 23, 2024
2 parents c37c3c4 + 1893988 commit fe8c15f
Showing 1 changed file with 48 additions and 9 deletions.
57 changes: 48 additions & 9 deletions minerva/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,64 @@ def get_subdataset(
universal_path(data_directory), sub_dataset_params["paths"]
)

sub_dataset: GeoDataset
sub_dataset: Optional[GeoDataset]

if cache or sub_dataset_params.get("cache_dataset"):
this_hash = utils.make_hash(sub_dataset_params)

cached_dataset_path = Path(CACHE_DIR) / f"{this_hash}.obj"

if cached_dataset_path.exists():
print(f"\nLoad cached dataset {this_hash}")
sub_dataset = load_dataset_from_cache(cached_dataset_path)

else:
sub_dataset = create_subdataset(
_sub_dataset,
sub_dataset_paths,
sub_dataset_params,
transformations,
sample_pairs=sample_pairs,
)
# Ensure that no conflicts from caching datasets made in multiple processes arises.
if dist.is_available() and dist.is_initialized(): # pragma: no cover
# Get this process#s rank.
rank = dist.get_rank()

# Start a blocking action, ensuring only process 0 can create and cache the dataset.
# All other processes will wait till 0 is finished.
dist.barrier()

if rank == 0:
print(f"\nCreating dataset on {rank}...")
sub_dataset = create_subdataset(
_sub_dataset,
sub_dataset_paths,
sub_dataset_params,
transformations,
sample_pairs=sample_pairs,
)

print(f"\nSaving dataset {this_hash}")
cache_dataset(sub_dataset, cached_dataset_path)

# Other processes wait...
else:
sub_dataset = None

# End of blocking action.
dist.barrier()

# Now the other processes can load the newly created cached dataset from 0.
if rank != 0:
print(f"\nLoading dataset from cache {this_hash} on {rank}")
sub_dataset = load_dataset_from_cache(cached_dataset_path)

else:
print("\nCreating dataset...")
sub_dataset = create_subdataset(
_sub_dataset,
sub_dataset_paths,
sub_dataset_params,
transformations,
sample_pairs=sample_pairs,
)

cache_dataset(sub_dataset, cached_dataset_path)
print(f"\nSaving dataset {this_hash}")
cache_dataset(sub_dataset, cached_dataset_path)

else:
sub_dataset = create_subdataset(
Expand All @@ -199,6 +237,7 @@ def get_subdataset(
sample_pairs=sample_pairs,
)

assert sub_dataset is not None
return sub_dataset


Expand Down

0 comments on commit fe8c15f

Please sign in to comment.