Skip to content

Commit

Permalink
Added parameters and exception behaviour to pqdm (#792)
Browse files Browse the repository at this point in the history
Co-authored-by: Chuck Daniels <[email protected]>
Co-authored-by: Matt Fisher <[email protected]>
  • Loading branch information
3 people authored Nov 4, 2024
1 parent 9784e4c commit 84be54e
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 11 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

## [Unreleased]

- Fix `earthaccess.download` to not ignore errors by default
([#581](https://github.com/nsidc/earthaccess/issues/581))
([**@Sherwin-14**](https://github.com/Sherwin-14),
[**@chuckwondo**](https://github.com/chuckwondo),
[**@mfisher87**](https://github.com/mfisher87))

### Changed

- Use built-in `assert` statements instead of `unittest` assertions in
Expand Down
27 changes: 24 additions & 3 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests
import s3fs
from fsspec import AbstractFileSystem
from typing_extensions import Any, Dict, List, Optional, Union, deprecated
from typing_extensions import Any, Dict, List, Mapping, Optional, Union, deprecated

import earthaccess
from earthaccess.services import DataServices
Expand Down Expand Up @@ -205,6 +205,7 @@ def download(
local_path: Optional[str],
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand All @@ -217,6 +218,9 @@ def download(
local_path: local directory to store the remote data granules
provider: if we download a list of URLs, we need to specify the provider.
threads: parallel number of threads to use to download the files, adjust as necessary, default = 8
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
List of downloaded files
Expand All @@ -225,12 +229,19 @@ def download(
Exception: A file download failed.
"""
provider = _normalize_location(provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
"n_jobs": threads,
**(pqdm_kwargs or {}),
}
if isinstance(granules, DataGranule):
granules = [granules]
elif isinstance(granules, str):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
results = earthaccess.__store__.get(
granules, local_path, provider, threads, pqdm_kwargs
)
except AttributeError as err:
logger.error(
f"{err}: You must call earthaccess.login() before you can download data"
Expand All @@ -242,6 +253,7 @@ def download(
def open(
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[AbstractFileSystem]:
"""Returns a list of file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -250,12 +262,21 @@ def open(
granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`.
If a list of URLs is passed, we need to specify the data provider.
provider: e.g. POCLOUD, NSIDC_CPRD, etc.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
provider = _normalize_location(provider)
results = earthaccess.__store__.open(granules=granules, provider=provider)
pqdm_kwargs = {
"exception_behavior": "immediate",
**(pqdm_kwargs or {}),
}
results = earthaccess.__store__.open(
granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs
)
return results


Expand Down
55 changes: 47 additions & 8 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,20 @@ def _open_files(
url_mapping: Mapping[str, Union[DataGranule, None]],
fs: fsspec.AbstractFileSystem,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile:
url, granule = data
return EarthAccessFile(fs.open(url), granule) # type: ignore

fileset = pqdm(url_mapping.items(), multi_thread_open, n_jobs=threads)
pqdm_kwargs = {
"exception_behavior": "immediate",
**(pqdm_kwargs or {}),
}

fileset = pqdm(
url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs
)
return fileset


Expand Down Expand Up @@ -336,6 +344,7 @@ def open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.spec.AbstractBufferedFile]:
"""Returns a list of file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -344,19 +353,23 @@ def open(
granules: a list of granule instances **or** list of URLs, e.g. `s3://some-granule`.
If a list of URLs is passed, we need to specify the data provider.
provider: e.g. POCLOUD, NSIDC_CPRD, etc.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
A list of "file pointers" to remote (i.e. s3 or https) files.
"""
if len(granules):
return self._open(granules, provider)
return self._open(granules, provider, pqdm_kwargs)
return []

@singledispatchmethod
def _open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
raise NotImplementedError("granules should be a list of DataGranule or URLs")

Expand Down Expand Up @@ -420,6 +433,7 @@ def _open_urls(
granules: List[str],
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
fileset: List = []

Expand Down Expand Up @@ -447,6 +461,7 @@ def _open_urls(
url_mapping,
fs=s3_fs,
threads=threads,
pqdm_kwargs=pqdm_kwargs,
)
except Exception as e:
raise RuntimeError(
Expand All @@ -466,7 +481,7 @@ def _open_urls(
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(url_mapping, threads)
fileset = self._open_urls_https(url_mapping, threads, pqdm_kwargs)
return fileset

def get(
Expand All @@ -475,6 +490,7 @@ def get(
local_path: Union[Path, str, None] = None,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand All @@ -491,6 +507,9 @@ def get(
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Parallel number of threads to use to download the files;
adjust as necessary, default = 8.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
List of downloaded files
Expand All @@ -503,7 +522,7 @@ def get(
local_path = Path(local_path)

if len(granules):
files = self._get(granules, local_path, provider, threads)
files = self._get(granules, local_path, provider, threads, pqdm_kwargs)
return files
else:
raise ValueError("List of URLs or DataGranule instances expected")
Expand All @@ -515,6 +534,7 @@ def _get(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand All @@ -531,6 +551,9 @@ def _get(
provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
threads: Parallel number of threads to use to download the files;
adjust as necessary, default = 8.
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
None
Expand All @@ -544,6 +567,7 @@ def _get_urls(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links = granules
downloaded_files: List = []
Expand All @@ -565,7 +589,9 @@ def _get_urls(

else:
# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, threads)
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
)

@_get.register
def _get_granules(
Expand All @@ -574,6 +600,7 @@ def _get_granules(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[str]:
data_links: List = []
downloaded_files: List = []
Expand Down Expand Up @@ -614,7 +641,9 @@ def _get_granules(
else:
# if the data are cloud-based, but we are not in AWS,
# it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, threads)
return self._download_onprem_granules(
data_links, local_path, threads, pqdm_kwargs
)

def _download_file(self, url: str, directory: Path) -> str:
"""Download a single file from an on-prem location, a DAAC data center.
Expand Down Expand Up @@ -652,7 +681,11 @@ def _download_file(self, url: str, directory: Path) -> str:
return str(path)

def _download_onprem_granules(
self, urls: List[str], directory: Path, threads: int = 8
self,
urls: List[str],
directory: Path,
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[Any]:
"""Downloads a list of URLS into the data directory.
Expand All @@ -661,6 +694,9 @@ def _download_onprem_granules(
directory: local directory to store the downloaded files
threads: parallel number of threads to use to download the files;
adjust as necessary, default = 8
pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
See pqdm documentation for available options. Default is to use immediate exception behavior
and the number of jobs specified by the `threads` parameter.
Returns:
A list of local filepaths to which the files were downloaded.
Expand All @@ -674,23 +710,26 @@ def _download_onprem_granules(
directory.mkdir(parents=True, exist_ok=True)

arguments = [(url, directory) for url in urls]

results = pqdm(
arguments,
self._download_file,
n_jobs=threads,
argument_type="args",
**pqdm_kwargs,
)
return results

def _open_urls_https(
self,
url_mapping: Mapping[str, Union[DataGranule, None]],
threads: int = 8,
pqdm_kwargs: Optional[Mapping[str, Any]] = None,
) -> List[fsspec.AbstractFileSystem]:
https_fs = self.get_fsspec_session()

try:
return _open_files(url_mapping, https_fs, threads)
return _open_files(url_mapping, https_fs, threads, pqdm_kwargs)
except Exception:
logger.exception(
"An exception occurred while trying to access remote files via HTTPS"
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import Mock

import earthaccess
import pytest


def test_download_immediate_failure(monkeypatch):
    """With default pqdm settings, a failing ``store.get`` propagates immediately.

    ``earthaccess.download`` defaults to pqdm's immediate exception behavior,
    so the first error raised by the store surfaces to the caller.
    """
    earthaccess.login()

    granules = earthaccess.search_data(
        short_name="ATL06",
        bounding_box=(-10, 20, 10, 50),
        temporal=("1999-02", "2019-03"),
        count=10,
    )

    # Stand-in for Store.get that always fails, regardless of input.
    def failing_get(*args, **kwargs):
        raise Exception("Download failed")

    fake_store = Mock()
    monkeypatch.setattr(fake_store, "get", failing_get)
    monkeypatch.setattr(earthaccess, "__store__", fake_store)

    with pytest.raises(Exception, match="Download failed"):
        earthaccess.download(granules, "/home/download-folder")


def test_download_deferred_failure(monkeypatch):
    """With deferred exception behavior, ``download`` returns the exceptions
    instead of raising them.

    Fixes two fragilities in the original test: the name ``results`` was
    reused for both the search results and the download results (the mock
    closed over the same binding it was later compared against), and the
    final length assertion hard-coded ``10`` even though ``count=10`` in
    ``search_data`` is only an upper bound on how many granules match.
    """
    earthaccess.login()

    granules = earthaccess.search_data(
        short_name="ATL06",
        bounding_box=(-10, 20, 10, 50),
        temporal=("1999-02", "2019-03"),
        count=10,
    )

    def mock_get(*args, **kwargs):
        # Deferred behavior collects one exception per input item rather
        # than raising on the first failure.
        return [Exception("Download failed")] * len(granules)

    mock_store = Mock()
    monkeypatch.setattr(earthaccess, "__store__", mock_store)
    monkeypatch.setattr(mock_store, "get", mock_get)

    # NOTE(review): pqdm itself spells this kwarg "exception_behaviour"
    # (British spelling) — the key is ignored here because get is mocked,
    # but verify the spelling the real API forwards to pqdm.
    results = earthaccess.download(
        granules, "/home/download-folder", None, 8, {"exception_behavior": "deferred"}
    )

    # Every entry is an Exception, one per granule actually searched —
    # compare against len(granules), not a hard-coded count.
    assert all(isinstance(e, Exception) for e in results)
    assert len(results) == len(granules)

0 comments on commit 84be54e

Please sign in to comment.