
Commit

Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete"
josephnowak committed Sep 21, 2024
1 parent c454cfe commit cc585d0
Showing 2 changed files with 70 additions and 26 deletions.
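As a minimal sketch of the scenario this commit enables, mirroring the updated test further below (an 11-element array stored with Zarr chunks of 3, so the last chunk holds only 2 elements; the store path is illustrative):

    import xarray as xr

    # Write an 11-element array with Zarr chunks of 3: the chunk boundaries are
    # [0, 3), [3, 6), [6, 9), [9, 11), so the last chunk is a "partial" chunk of 2.
    arr = xr.DataArray(
        list(range(11)), dims=["a"], coords={"a": list(range(11))}, name="foo"
    ).chunk(a=3)
    arr.to_zarr("foo.zarr", mode="w")

    # This region covers the last Zarr chunk completely, given its current size
    # of 2, so with this change it is accepted in "r+" mode as well as in "a" mode.
    arr.isel(a=slice(9, 11)).to_zarr("foo.zarr", region="auto", mode="r+")
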
55 changes: 35 additions & 20 deletions xarray/backends/zarr.py
@@ -113,7 +113,7 @@ def __getitem__(self, key):


def _determine_zarr_chunks(
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
):
"""
Given encoding chunks (possibly None or []) and variable chunks
@@ -166,7 +166,7 @@ def _determine_zarr_chunks(
if len(enc_chunks_tuple) != ndim:
# throw away encoding chunks, start over
return _determine_zarr_chunks(
None, var_chunks, ndim, name, safe_chunks, region, mode
None, var_chunks, ndim, name, safe_chunks, region, mode, shape
)

for x in enc_chunks_tuple:
@@ -208,29 +208,38 @@ def _determine_zarr_chunks(
f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)

for zchunk, dchunks, interval in zip(
enc_chunks_tuple, var_chunks, region, strict=True
for zchunk, dchunks, interval, size in zip(
enc_chunks_tuple, var_chunks, region, shape, strict=True
):
if not safe_chunks:
continue

# The first border size is the amount of data that needs to be updated on the
# first chunk taking into account the region slice.
first_border_size = zchunk
if interval.start:
first_border_size = zchunk - interval.start % zchunk
for dchunk in dchunks[1:-1]:
if dchunk % zchunk:
raise ValueError(base_error)

region_start = interval.start if interval.start else 0

if not allow_partial_chunks and first_border_size < zchunk:
# If the border is smaller than zchunk, then it is a partial chunk write
raise ValueError(base_error)
if len(dchunks) > 1:
# The first border size is the amount of data that needs to be updated on the
# first chunk taking into account the region slice.
first_border_size = zchunk
if allow_partial_chunks:
first_border_size = zchunk - region_start % zchunk

for dchunk in dchunks[:end]:
if (dchunk - first_border_size) % zchunk:
if (dchunks[0] - first_border_size) % zchunk:
raise ValueError(base_error)

# The first border is only useful during the first iteration,
# so ignore it in the next validations
first_border_size = 0
if not allow_partial_chunks:
region_stop = interval.stop if interval.stop else size
cover_last_chunk = region_stop > size - size % zchunk

if not cover_last_chunk:
if dchunks[-1] % zchunk:
raise ValueError(base_error)
elif dchunks[-1] % zchunk != size % zchunk:
# The remainder must be equal to the size of the last Zarr chunk
raise ValueError(base_error)

return enc_chunks_tuple
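
For orientation, the new last-chunk branch above can be worked through with the sizes used by the updated test (an 11-element store, Zarr chunk size 3); the snippet below only illustrates that arithmetic:

    # Illustration of the last-chunk rule with size=11 and zchunk=3
    # (the last Zarr chunk holds size % zchunk == 2 elements).
    size, zchunk = 11, 3

    # A region stopping at index 11 reaches the last Zarr chunk.
    region_stop = 11
    cover_last_chunk = region_stop > size - size % zchunk  # 11 > 9 -> True
    # The trailing dask chunk must then leave the same remainder as the store:
    assert 2 % zchunk == size % zchunk  # a trailing dask chunk of 2 is accepted
    assert 3 % zchunk != size % zchunk  # a trailing dask chunk of 3 is rejected

    # A region stopping at index 9 never touches the last Zarr chunk, so the
    # trailing dask chunk must instead end on a chunk boundary.
    region_stop = 9
    cover_last_chunk = region_stop > size - size % zchunk  # 9 > 9 -> False
    assert 3 % zchunk == 0  # a trailing dask chunk of 3 is accepted
    assert 2 % zchunk != 0  # a trailing dask chunk of 2 is rejected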

@@ -279,6 +288,7 @@ def extract_zarr_variable_encoding(
safe_chunks=True,
region=None,
mode=None,
shape=None
):
"""
Extract zarr encoding dictionary from xarray Variable
Expand All @@ -289,9 +299,9 @@ def extract_zarr_variable_encoding(
raise_on_invalid : bool, optional
name: str | Hashable, optional
safe_chunks: bool, optional
region: tuple[slice], optional
region: tuple[slice, ...], optional
mode: str, optional
shape: tuple[int, ...], optional
Returns
-------
encoding : dict
@@ -331,6 +341,7 @@ def extract_zarr_variable_encoding(
safe_chunks=safe_chunks,
region=region,
mode=mode,
shape=shape
)
encoding["chunks"] = chunks
return encoding
@@ -808,6 +819,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
v.encoding = {}

zarr_array = None
zarr_shape = None
write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

@@ -852,6 +864,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

zarr_shape = zarr_array.shape

region = tuple(write_region[dim] for dim in dims)

# We need to do this for both new and existing variables to ensure we're not
@@ -862,11 +876,12 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
# another one for extracting the encoding.
encoding = extract_zarr_variable_encoding(
v,
region=region,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
region=region,
mode=self._mode,
shape=zarr_shape
)

if name not in existing_keys:
41 changes: 35 additions & 6 deletions xarray/tests/test_backends.py
@@ -6157,7 +6157,7 @@ def test_zarr_safe_chunk_region(tmp_path):
store = tmp_path / "foo.zarr"

arr = xr.DataArray(
list(range(10)), dims=["a"], coords={"a": list(range(10))}, name="foo"
list(range(11)), dims=["a"], coords={"a": list(range(11))}, name="foo"
).chunk(a=3)
arr.to_zarr(store, mode="w")

@@ -6187,10 +6187,14 @@ def test_zarr_safe_chunk_region(tmp_path):
# Fully update two contiguous chunks is safe in any mode
arr.isel(a=slice(3, 9)).to_zarr(store, region="auto", mode=mode)

# Writing the last chunk partially is safe in "a" mode
# The last chunk is considered full based on its current size (2)
arr.isel(a=slice(9, 11)).to_zarr(store, region="auto", mode=mode)
arr.isel(a=slice(6, None)).chunk(a=-1).to_zarr(store, region="auto", mode=mode)

# Writing the last chunk of a region partially is safe in "a" mode
arr.isel(a=slice(3, 8)).to_zarr(store, region="auto", mode="a")
with pytest.raises(ValueError):
# with "r+" mode it is invalid to write partial chunk even on the last one
# with "r+" mode it is invalid to write partial chunk
arr.isel(a=slice(3, 8)).to_zarr(store, region="auto", mode="r+")

# This is safe with mode "a", the border size is covered by the first chunk of Dask
@@ -6204,12 +6208,12 @@ def test_zarr_safe_chunk_region(tmp_path):
arr.isel(a=slice(1, 5)).chunk(a=(4,)).to_zarr(store, region="auto", mode="a")

with pytest.raises(ValueError):
# This is unsafe on mode "r+", because there is a single dask
# chunk smaller than the Zarr chunk
# This is unsafe on mode "r+", because the Dask chunk is partially writing
# in the first chunk of Zarr
arr.isel(a=slice(1, 5)).chunk(a=(4,)).to_zarr(store, region="auto", mode="r+")

# The first chunk is completely covering the first Zarr chunk
# and the last chunk is a partial chunk
# and the last chunk is a partial one
arr.isel(a=slice(0, 5)).chunk(a=(3, 2)).to_zarr(store, region="auto", mode="a")

with pytest.raises(ValueError):
@@ -6227,3 +6231,28 @@ def test_zarr_safe_chunk_region(tmp_path):
with pytest.raises(ValueError):
# Validate that the border condition does not affect the "r+" mode
arr.isel(a=slice(1, 9)).to_zarr(store, region="auto", mode="r+")

arr.isel(a=slice(10, 11)).to_zarr(store, region="auto", mode="a")
with pytest.raises(ValueError):
# Validate that writing the last Zarr chunk with a single Dask chunk is still
# unsafe if it does not fully cover it
# (the last Zarr chunk has size 2)
arr.isel(a=slice(10, 11)).to_zarr(store, region="auto", mode="r+")

# Validate the same as the above test, but at the beginning of the last chunk
arr.isel(a=slice(9, 10)).to_zarr(store, region="auto", mode="a")
with pytest.raises(ValueError):
arr.isel(a=slice(9, 10)).to_zarr(store, region="auto", mode="r+")

arr.isel(a=slice(7, None)).chunk(a=-1).to_zarr(store, region="auto", mode="a")
with pytest.raises(ValueError):
# Test that even a Dask chunk that covers the last Zarr chunk can be unsafe
# if it partially covers other Zarr chunks
arr.isel(a=slice(7, None)).chunk(a=-1).to_zarr(store, region="auto", mode="r+")

with pytest.raises(ValueError):
# If the Dask chunk is the same size as the one in the Zarr encoding but the
# write still updates a Zarr chunk partially, then raise an error
arr.isel(a=slice(8, None)).chunk(a=3).to_zarr(store, region="auto", mode="r+")
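
As a small illustration of the chunk layout behind these last cases (same assumptions as the test: 11 elements, Zarr chunks of 3):

    # Zarr chunk boundaries for an 11-element array stored with chunks of 3.
    size, zchunk = 11, 3
    boundaries = [(i, min(i + zchunk, size)) for i in range(0, size, zchunk)]
    print(boundaries)  # [(0, 3), (3, 6), (6, 9), (9, 11)]: the last chunk holds 2 elements

    # slice(7, None) as one dask chunk spans indices 7..10: it covers the last
    # chunk [9, 11) completely but only part of [6, 9), so "r+" rejects it.
    # slice(8, None) rechunked to 3 spans indices 8..10: the dask chunk matches
    # the encoded chunk size, but it starts inside [6, 9), so "r+" rejects it too.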

