Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pyramid reproject map blocks #12

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0d1aaec
factor pyramid_reproject into its own module
Nov 23, 2021
4b3543d
[wip] update pyramid reproject to use map_blocks
Nov 23, 2021
4927891
fix pytest invocation in gha
Nov 29, 2021
aa49e99
updated implementation of make_grid_ds
Dec 7, 2021
bbddf14
fix numpy test
Dec 7, 2021
7e7b2cd
fix xy ordering
Dec 7, 2021
dffcd88
fix test comparison
Dec 7, 2021
8115415
update to latest datatree syntax
Dec 17, 2021
0cf2e60
update test env
Dec 17, 2021
806013e
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jan 3, 2022
9887f2e
Bump actions/checkout from 2.4.0 to 3 (#19)
dependabot[bot] Mar 24, 2022
682a6bb
Fix CI: use micromamba & gh action concurrency feature (#21)
andersy005 Mar 24, 2022
50df185
Update demo notebook (#20)
andersy005 Mar 24, 2022
ac2abf1
Update demo.ipynb (#22)
andersy005 Mar 24, 2022
1408888
Update pre-commit hooks (#23)
andersy005 Mar 25, 2022
787903a
factor pyramid_reproject into its own module
Nov 23, 2021
33d8b67
[wip] update pyramid reproject to use map_blocks
Nov 23, 2021
92f0ae7
enable pyupgrade
andersy005 Mar 25, 2022
a057963
set periodic=True as default in regridders (#16)
Mar 25, 2022
7f681ce
[wip] update pyramid reproject to use map_blocks
Nov 23, 2021
45d2ed4
Merge branch 'main' into feature/pyramid-reproject-map-blocks
andersy005 Mar 25, 2022
8646539
Merge branch 'main' into feature/pyramid-reproject-map-blocks
andersy005 Mar 25, 2022
dcf2e2d
Remove unused imports
andersy005 Mar 25, 2022
48cfee0
Merge branch 'main' into feature/pyramid-reproject-map-blocks
andersy005 Mar 31, 2022
5107983
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ndpyramid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pkg_resources import DistributionNotFound, get_distribution

from .core import pyramid_coarsen, pyramid_reproject
from .core import pyramid_coarsen
from .regrid import pyramid_regrid
from .reproject import pyramid_reproject

try:
__version__ = get_distribution(__name__).version
Expand Down
60 changes: 0 additions & 60 deletions ndpyramid/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections import defaultdict
from typing import List

import datatree as dt
Expand Down Expand Up @@ -35,62 +34,3 @@ def pyramid_coarsen(ds, factors: List[int], dims: List[str], **kwargs) -> dt.Dat
pyramid[skey] = ds.coarsen(**kwargs).mean()

return pyramid


def pyramid_reproject(
    ds, levels: int = None, pixels_per_tile=128, resampling='average', extra_dim=None
) -> dt.DataTree:
    """Create a multiscale pyramid by reprojecting ``ds`` to EPSG:3857.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset. Every data variable is reprojected through its ``.rio``
        accessor.
    levels : int
        Number of pyramid levels. Level ``n`` is reprojected onto a square grid
        of ``2 ** n * pixels_per_tile`` pixels per side. The signature defaults
        to None for backward compatibility, but a value is required.
    pixels_per_tile : int, optional
        Number of pixels per side at level 0, by default 128
    resampling : str or dict, optional
        Name of a ``rasterio.warp.Resampling`` member. Either a single string
        used for every variable, or a mapping from variable name to member
        name, by default 'average'
    extra_dim : str, optional
        Name of the dimension to iterate over when a variable is 4-D. Required
        only when such a variable is present.

    Returns
    -------
    dt.DataTree
        Tree holding the ``multiscales`` attributes on its root dataset and one
        child dataset per level, keyed ``"0"``, ``"1"``, ...

    Raises
    ------
    TypeError
        If ``levels`` is None.
    ValueError
        If a 4-D variable is encountered and ``extra_dim`` is None.
    """
    import rioxarray  # noqa: F401
    from rasterio.transform import Affine
    from rasterio.warp import Resampling

    if levels is None:
        # range(None) further down would raise the same exception type with an
        # unhelpful message; fail early and say what is missing.
        raise TypeError("'levels' must be an integer number of pyramid levels, got None")

    # multiscales spec
    save_kwargs = {'levels': levels, 'pixels_per_tile': pixels_per_tile}
    attrs = {
        'multiscales': multiscales_template(
            datasets=[{'path': str(i)} for i in range(levels)],
            type='reduce',
            method='pyramid_reproject',
            version=get_version(),
            kwargs=save_kwargs,
        )
    }

    # A single method name applies to every variable; a dict is used as given.
    if isinstance(resampling, str):
        resampling_dict = defaultdict(lambda: resampling)
    else:
        resampling_dict = resampling

    # Hard-coded half-extents of the target grid.
    # NOTE(review): presumably the Web Mercator bounds in metres - confirm.
    x_extent = 20026376.39
    y_extent = 20048966.10

    def reproject(da, var, dim, dst_transform):
        # Takes the level-specific values as arguments so the helper can be
        # defined once rather than re-created on every loop iteration.
        return da.rio.reproject(
            'EPSG:3857',
            resampling=Resampling[resampling_dict[var]],
            shape=(dim, dim),
            transform=dst_transform,
        )

    # set up pyramid
    root = xr.Dataset(attrs=attrs)
    pyramid = dt.DataTree(data_objects={"root": root})

    # pyramid data
    for level in range(levels):
        lkey = str(level)
        dim = 2 ** level * pixels_per_tile
        dst_transform = Affine.translation(-x_extent, y_extent) * Affine.scale(
            (x_extent * 2) / dim, -(y_extent * 2) / dim
        )

        pyramid[lkey] = xr.Dataset(attrs=ds.attrs)
        for k, da in ds.items():
            if da.ndim == 4:
                if extra_dim is None:
                    raise ValueError("must specify 'extra_dim' to iterate over 4d data")
                # Reproject one slice along extra_dim at a time, then stitch
                # the slices back together along that dimension.
                da_all = [
                    reproject(da.sel({extra_dim: index}), k, dim, dst_transform)
                    for index in ds[extra_dim]
                ]
                pyramid[lkey].ds[k] = xr.concat(da_all, ds[extra_dim])
            else:
                pyramid[lkey].ds[k] = reproject(da, k, dim, dst_transform)
    return pyramid
2 changes: 1 addition & 1 deletion ndpyramid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def pyramid_regrid(

attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i) for i in range(levels)}],
datasets=[{'path': str(i)} for i in range(levels)],
type='reduce',
method='pyramid_regrid',
version=get_version(),
Expand Down
131 changes: 131 additions & 0 deletions ndpyramid/reproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from collections import defaultdict
from typing import Dict, Tuple, TypeVar

import dask
import datatree as dt
import numpy as np
import xarray as xr

from .utils import get_version, multiscales_template

ResamplingType = TypeVar('ResamplingType', str, Dict[str, str])


def _add_x_y_coords(da: xr.DataArray, shape: Tuple[int, int], transform) -> xr.DataArray:
    '''Attach 1-D pixel-centre ``x``/``y`` coordinates to ``da``.

    Parameters
    ----------
    da : xr.DataArray
        Array to receive the coordinates.
    shape : tuple of int
        ``(n_rows, n_cols)`` of the grid. ``n_cols`` sets the length of ``x``
        and ``n_rows`` the length of ``y``.
    transform
        Object that maps a ``(col, row)`` pair to an ``(x, y)`` pair via
        ``transform * (col, row)``.

    Returns
    -------
    xr.DataArray
        ``da`` with ``x`` and ``y`` coordinates assigned.

    Raises
    ------
    ValueError
        If ``shape`` does not have exactly two entries.
    '''

    # Unpacking rejects non-2-D shapes up front instead of silently building
    # coordinates from the wrong axes.
    n_rows, n_cols = shape

    # Only the first row of x values and the first column of y values are used
    # as coordinates, so evaluate the transform along those two lines only
    # rather than over the whole grid. The 0.5 offsets select pixel centres.
    xs = np.fromiter(
        ((transform * (col + 0.5, 0.5))[0] for col in range(n_cols)),
        dtype=float,
        count=n_cols,
    )
    ys = np.fromiter(
        ((transform * (0.5, row + 0.5))[1] for row in range(n_rows)),
        dtype=float,
        count=n_rows,
    )

    return da.assign_coords(
        {"x": xr.DataArray(xs, dims=["x"]), "y": xr.DataArray(ys, dims=["y"])}
    )


def _make_template(shape: Tuple[int], dst_transform, attrs: dict) -> xr.DataArray:
    '''Build the lazy template DataArray describing one reprojected block.

    Parameters
    ----------
    shape : tuple of int
        Shape of the template; it is also used as the single chunk size.
    dst_transform
        Transform passed on to ``_add_x_y_coords`` to derive the x/y
        coordinates.
    attrs : dict
        Attributes to attach to the template.

    Returns
    -------
    xr.DataArray
        Uninitialised dask-backed array with dims ``("y", "x")``, x/y
        coordinates, and a scalar ``spatial_ref`` coordinate.
    '''
    # A bare `import dask` does not load the `dask.array` subpackage, so import
    # it explicitly here rather than relying on another module having done so.
    import dask.array

    template = xr.DataArray(
        data=dask.array.empty(shape, chunks=shape), dims=("y", "x"), attrs=attrs
    )
    template = _add_x_y_coords(template, shape, dst_transform)
    # Placeholder scalar coordinate.
    # NOTE(review): presumably mirrors the `spatial_ref` coordinate that
    # rioxarray adds to reprojected output - confirm.
    template.coords["spatial_ref"] = xr.DataArray(np.array(1.0))
    return template


def _reproject(da: xr.DataArray, shape=None, dst_transform=None, resampling="average"):
    '''Reproject ``da`` to EPSG:3857 through its ``.rio`` accessor.

    Parameters
    ----------
    da : xr.DataArray
        Array to reproject.
    shape : tuple of int, optional
        Output shape forwarded to ``rio.reproject``.
    dst_transform : optional
        Output transform forwarded to ``rio.reproject``.
    resampling : str, optional
        Name of a ``rasterio.warp.Resampling`` member, by default "average"

    Returns
    -------
    xr.DataArray
        The reprojected array.
    '''
    # Importing rioxarray registers the `.rio` accessor. Do it here so the
    # accessor exists in whichever process runs this function, not only in the
    # process that called `pyramid_reproject`.
    import rioxarray  # noqa: F401
    from rasterio.warp import Resampling

    return da.rio.reproject(
        "EPSG:3857",
        resampling=Resampling[resampling],
        shape=shape,
        transform=dst_transform,
    )


def pyramid_reproject(
ds,
levels: int = None,
pixels_per_tile=128,
resampling: ResamplingType = 'average',
) -> dt.DataTree:
"""[summary]

Parameters
----------
ds : xr.Dataset
Input dataset
levels : int, optional
Number of levels in pyramid, by default None
pixels_per_tile : int, optional
Number of pixels to include along each axis in individual tiles, by default 128
resampling : str or dict, optional
Rasterio resampling method. Can be provided as a string or a per-variable
dict, by default 'average'

Returns
-------
dt.DataTree
Multiscale data pyramid
"""
import rioxarray # noqa: F401
from rasterio.transform import Affine

# multiscales spec
save_kwargs = {"levels": levels, "pixels_per_tile": pixels_per_tile}
attrs = {
"multiscales": multiscales_template(
datasets=[{"path": str(i)} for i in range(levels)],
type="reduce",
method="pyramid_reproject",
version=get_version(),
kwargs=save_kwargs,
)
}

resampling_dict: ResamplingType
if isinstance(resampling, str):
resampling_dict = defaultdict(lambda: resampling)
else:
resampling_dict = resampling

# set up pyramid
root = xr.Dataset(attrs=attrs)
pyramid = dt.DataTree(data_objects={"root": root})

for level in range(levels):
lkey = str(level)
dim = 2 ** level * pixels_per_tile

dst_transform = Affine.translation(-20026376.39, 20048966.10) * Affine.scale(
(20026376.39 * 2) / dim, -(20048966.10 * 2) / dim
)

pyramid[lkey] = xr.Dataset(attrs=ds.attrs)
shape = (dim, dim)
for k, da in ds.items():
template_shape = (chunked_dim_sizes) + shape # TODO: pick up here.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is chunked_dim_sizes supposed to come from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Clearly I didn't finish this part.

If I recall correctly, this is meant to support cases where non-spatial dims are chunked. The most common examples would be chunks along the band or time dimensions.

To do this, we need to construct a template that includes the correct shape for each block.

template = _make_template(template_shape, dst_transform, ds[k].attrs)
print(resampling_dict[k])
pyramid[lkey].ds[k] = xr.map_blocks(
_reproject,
da,
kwargs=dict(
shape=(dim, dim), dst_transform=dst_transform, resampling=resampling_dict[k]
),
template=template,
)

return pyramid
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ per-file-ignores = __init__.py:F401

[isort]
known_first_party=carbonplan
known_third_party=datatree,numpy,pkg_resources,pytest,setuptools,xarray,zarr
known_third_party=dask,datatree,numpy,pkg_resources,pytest,setuptools,xarray,zarr
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
Expand Down
14 changes: 14 additions & 0 deletions tests/test_pyramids.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def test_reprojected_pyramid(temperature):
pyramid.to_zarr(MemoryStore())


def test_reprojected_pyramid_dask(temperature):
    """Reproject a dask-chunked dataset into a pyramid and write it to zarr."""
    rioxarray = pytest.importorskip("rioxarray")  # noqa: F841
    levels = 2
    temperature = temperature.rio.write_crs('EPSG:4326')
    # Chunk along time so the reprojection goes through the dask code path.
    # Use the `levels` variable so the call and the assertion below cannot
    # drift apart.
    pyramid = pyramid_reproject(temperature.chunk({'time': 1}), levels=levels)
    for child in pyramid.children:
        child.ds = child.ds.chunk({"x": 128, "y": 128})
    assert pyramid.ds.attrs['multiscales']
    assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels
    pyramid.to_zarr(MemoryStore())


def test_regridded_pyramid(temperature):
xesmf = pytest.importorskip("xesmf") # noqa: F841
pyramid = pyramid_regrid(temperature, levels=2)
Expand Down