Skip to content

Commit

Permalink
Improve safe chunk validation (#9559)
Browse files Browse the repository at this point in the history
* fix safe chunks validation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix safe chunks validation

* Update xarray/tests/test_backends.py

Co-authored-by: Maximilian Roos <[email protected]>

* The validation of the chunks now is able to detect full or partial chunk and raise a proper error based on the mode selected, it is also possible to use the auto region detection with the mode "a"

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* The test_extract_zarr_variable_encoding does not need to use the region parameter

* Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+

* Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+

* Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete"

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete"

* Add a typehint to the modes to avoid issues with mypy

* Fix the detection of the last chunk

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the whats-new and add mode="w" to the new test case

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Maximilian Roos <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
  • Loading branch information
4 people authored Sep 30, 2024
1 parent 7bdc6d4 commit 095d47f
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 54 deletions.
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ Bug fixes
<https://github.com/spencerkclark>`_.
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
By `Deepak Cherian <https://github.com/dcherian>`_.

- Fix the safe_chunks validation option on the to_zarr method
(:issue:`5511`, :pull:`9559`). By `Joseph Nowak
<https://github.com/josephnowak>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
167 changes: 120 additions & 47 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def __getitem__(self, key):
# could possibly have a work-around for 0d data here


def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
def _determine_zarr_chunks(
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
):
"""
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
Expand Down Expand Up @@ -163,7 +165,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):

if len(enc_chunks_tuple) != ndim:
# throw away encoding chunks, start over
return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)
return _determine_zarr_chunks(
None, var_chunks, ndim, name, safe_chunks, region, mode, shape
)

for x in enc_chunks_tuple:
if not isinstance(x, int):
Expand All @@ -189,20 +193,58 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
# TODO: incorporate synchronizer to allow writes from multiple dask
# threads
if var_chunks and enc_chunks_tuple:
for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True):
for dchunk in dchunks[:-1]:
# If it is possible to write on partial chunks then it is not necessary to check
# the last one contained on the region
allow_partial_chunks = mode != "r+"

base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r} "
f"on the region {region}. "
f"Writing this array in parallel with dask could lead to corrupted data."
f"Consider either rechunking using `chunk()`, deleting "
f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)

for zchunk, dchunks, interval, size in zip(
enc_chunks_tuple, var_chunks, region, shape, strict=True
):
if not safe_chunks:
continue

for dchunk in dchunks[1:-1]:
if dchunk % zchunk:
base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
f"Writing this array in parallel with dask could lead to corrupted data."
)
if safe_chunks:
raise ValueError(
base_error
+ " Consider either rechunking using `chunk()`, deleting "
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)
raise ValueError(base_error)

region_start = interval.start if interval.start else 0

if len(dchunks) > 1:
# The first border size is the amount of data that needs to be updated on the
# first chunk taking into account the region slice.
first_border_size = zchunk
if allow_partial_chunks:
first_border_size = zchunk - region_start % zchunk

if (dchunks[0] - first_border_size) % zchunk:
raise ValueError(base_error)

if not allow_partial_chunks:
region_stop = interval.stop if interval.stop else size

if region_start % zchunk:
# The last chunk which can also be the only one is a partial chunk
# if it is not aligned at the beginning
raise ValueError(base_error)

if np.ceil(region_stop / zchunk) == np.ceil(size / zchunk):
# If the region is covering the last chunk then check
# if the reminder with the default chunk size
# is equal to the size of the last chunk
if dchunks[-1] % zchunk != size % zchunk:
raise ValueError(base_error)
elif dchunks[-1] % zchunk:
raise ValueError(base_error)

return enc_chunks_tuple

raise AssertionError("We should never get here. Function logic must be wrong.")
Expand Down Expand Up @@ -243,7 +285,14 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):


def extract_zarr_variable_encoding(
variable, raise_on_invalid=False, name=None, safe_chunks=True
variable,
raise_on_invalid=False,
name=None,
*,
safe_chunks=True,
region=None,
mode=None,
shape=None,
):
"""
Extract zarr encoding dictionary from xarray Variable
Expand All @@ -252,12 +301,18 @@ def extract_zarr_variable_encoding(
----------
variable : Variable
raise_on_invalid : bool, optional
name: str | Hashable, optional
safe_chunks: bool, optional
region: tuple[slice, ...], optional
mode: str, optional
shape: tuple[int, ...], optional
Returns
-------
encoding : dict
Zarr encoding for `variable`
"""

shape = shape if shape else variable.shape
encoding = variable.encoding.copy()

safe_to_drop = {"source", "original_shape"}
Expand Down Expand Up @@ -285,7 +340,14 @@ def extract_zarr_variable_encoding(
del encoding[k]

chunks = _determine_zarr_chunks(
encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
enc_chunks=encoding.get("chunks"),
var_chunks=variable.chunks,
ndim=variable.ndim,
name=name,
safe_chunks=safe_chunks,
region=region,
mode=mode,
shape=shape,
)
encoding["chunks"] = chunks
return encoding
Expand Down Expand Up @@ -762,16 +824,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
)
zarr_array = None
zarr_shape = None
write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

if name in existing_keys:
# existing variable
Expand Down Expand Up @@ -801,7 +857,40 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
)
else:
zarr_array = self.zarr_group[name]
else:

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

zarr_shape = zarr_array.shape

region = tuple(write_region[dim] for dim in dims)

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
# Note: Ideally there should be two functions, one for validating the chunks and
# another one for extracting the encoding.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
region=region,
mode=self._mode,
shape=zarr_shape,
)

if name not in existing_keys:
# new variable
encoded_attrs = {}
# the magic for storing the hidden dimension data
Expand Down Expand Up @@ -833,22 +922,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
)
zarr_array = _put_attrs(zarr_array, encoded_attrs)

write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

region = tuple(write_region[dim] for dim in dims)
writer.add(v.data, zarr_array, region)

def close(self) -> None:
Expand Down Expand Up @@ -897,9 +970,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")
if any(v == "auto" for v in region.values()):
if self._mode != "r+":
if self._mode not in ["r+", "a"]:
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got {self._mode!r}"
)
region = self._auto_detect_regions(ds, region)

Expand Down
8 changes: 8 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4316,6 +4316,14 @@ def to_zarr(
if Zarr arrays are written in parallel. This option may be useful in combination
with ``compute=False`` to initialize a Zarr store from an existing
DataArray with arbitrary chunk structure.
In addition to the many-to-one relationship validation, it also detects partial
chunks writes when using the region parameter,
these partial chunks are considered unsafe in the mode "r+" but safe in
the mode "a".
Note: Even with these validations it can still be unsafe to write
two or more chunked arrays in the same location in parallel if they are
not writing in independent regions, for those cases it is better to use
a synchronizer.
storage_options : dict, optional
Any additional parameters for the storage backend (ignored for local
paths).
Expand Down
8 changes: 8 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2509,6 +2509,14 @@ def to_zarr(
if Zarr arrays are written in parallel. This option may be useful in combination
with ``compute=False`` to initialize a Zarr from an existing
Dataset with arbitrary chunk structure.
In addition to the many-to-one relationship validation, it also detects partial
chunks writes when using the region parameter,
these partial chunks are considered unsafe in the mode "r+" but safe in
the mode "a".
Note: Even with these validations it can still be unsafe to write
two or more chunked arrays in the same location in parallel if they are
not writing in independent regions, for those cases it is better to use
a synchronizer.
storage_options : dict, optional
Any additional parameters for the storage backend (ignored for local
paths).
Expand Down
Loading

0 comments on commit 095d47f

Please sign in to comment.