From 2a6212e1255ea56065ec1bfad8d484fbdad33945 Mon Sep 17 00:00:00 2001 From: joseph nowak Date: Sat, 21 Sep 2024 21:30:32 -0400 Subject: [PATCH] Improve safe chunk validation (#9527) * fix safe chunks validation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix safe chunks validation * Update xarray/tests/test_backends.py Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * The validation of the chunks now is able to detect full or partial chunk and raise a proper error based on the mode selected, it is also possible to use the auto region detection with the mode "a" * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * The test_extract_zarr_variable_encoding does not need to use the region parameter * Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+ * Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+ * Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete" * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete" * Add a typehint to the modes to avoid issues with mypy --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- doc/whats-new.rst | 4 +- xarray/backends/zarr.py | 168 +++++++++++++++++++++++---------- xarray/core/dataarray.py | 8 ++ xarray/core/dataset.py | 8 ++ xarray/tests/test_backends.py | 169 ++++++++++++++++++++++++++++++++-- 5 files changed, 303 insertions(+), 54 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e4b2a06a3e7..89c8d3b4599 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -58,7 +58,9 @@ Bug fixes `_. - Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`). By `Deepak Cherian `_. - +- Fix the safe_chunks validation option on the to_zarr method + (:issue:`5511`, :pull:`9513`). By `Joseph Nowak + `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 5a6b043eef8..2c6b50b3589 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -112,7 +112,9 @@ def __getitem__(self, key): # could possibly have a work-around for 0d data here -def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): +def _determine_zarr_chunks( + enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape +): """ Given encoding chunks (possibly None or []) and variable chunks (possibly None or []). @@ -163,7 +165,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): if len(enc_chunks_tuple) != ndim: # throw away encoding chunks, start over - return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks) + return _determine_zarr_chunks( + None, var_chunks, ndim, name, safe_chunks, region, mode, shape + ) for x in enc_chunks_tuple: if not isinstance(x, int): @@ -189,20 +193,59 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): # TODO: incorporate synchronizer to allow writes from multiple dask # threads if var_chunks and enc_chunks_tuple: - for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True): - for dchunk in dchunks[:-1]: + # If it is possible to write on partial chunks then it is not necessary to check + # the last one contained on the region + allow_partial_chunks = mode != "r+" + + base_error = ( + f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for " + f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r} " + f"on the region {region}. " + f"Writing this array in parallel with dask could lead to corrupted data." + f"Consider either rechunking using `chunk()`, deleting " + f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`." + ) + + for zchunk, dchunks, interval, size in zip( + enc_chunks_tuple, var_chunks, region, shape, strict=True + ): + if not safe_chunks: + continue + + for dchunk in dchunks[1:-1]: if dchunk % zchunk: - base_error = ( - f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for " - f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. " - f"Writing this array in parallel with dask could lead to corrupted data." - ) - if safe_chunks: - raise ValueError( - base_error - + " Consider either rechunking using `chunk()`, deleting " - "or modifying `encoding['chunks']`, or specify `safe_chunks=False`." - ) + raise ValueError(base_error) + + region_start = interval.start if interval.start else 0 + + if len(dchunks) > 1: + # The first border size is the amount of data that needs to be updated on the + # first chunk taking into account the region slice. + first_border_size = zchunk + if allow_partial_chunks: + first_border_size = zchunk - region_start % zchunk + + if (dchunks[0] - first_border_size) % zchunk: + raise ValueError(base_error) + + if not allow_partial_chunks: + chunk_start = sum(dchunks[:-1]) + region_start + if chunk_start % zchunk: + # The last chunk which can also be the only one is a partial chunk + # if it is not aligned at the beginning + raise ValueError(base_error) + + region_stop = interval.stop if interval.stop else size + + if size - region_stop + 1 < zchunk: + # If the region is covering the last chunk then check + # if the reminder with the default chunk size + # is equal to the size of the last chunk + if dchunks[-1] % zchunk != size % zchunk: + raise ValueError(base_error) + elif dchunks[-1] % zchunk: + raise ValueError(base_error) + return enc_chunks_tuple raise AssertionError("We should never get here. Function logic must be wrong.") @@ -243,7 +286,14 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr): def extract_zarr_variable_encoding( - variable, raise_on_invalid=False, name=None, safe_chunks=True + variable, + raise_on_invalid=False, + name=None, + *, + safe_chunks=True, + region=None, + mode=None, + shape=None, ): """ Extract zarr encoding dictionary from xarray Variable @@ -252,12 +302,18 @@ def extract_zarr_variable_encoding( ---------- variable : Variable raise_on_invalid : bool, optional - + name: str | Hashable, optional + safe_chunks: bool, optional + region: tuple[slice, ...], optional + mode: str, optional + shape: tuple[int, ...], optional Returns ------- encoding : dict Zarr encoding for `variable` """ + + shape = shape if shape else variable.shape encoding = variable.encoding.copy() safe_to_drop = {"source", "original_shape"} @@ -285,7 +341,14 @@ def extract_zarr_variable_encoding( del encoding[k] chunks = _determine_zarr_chunks( - encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks + enc_chunks=encoding.get("chunks"), + var_chunks=variable.chunks, + ndim=variable.ndim, + name=name, + safe_chunks=safe_chunks, + region=region, + mode=mode, + shape=shape, ) encoding["chunks"] = chunks return encoding @@ -762,16 +825,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if v.encoding == {"_FillValue": None} and fill_value is None: v.encoding = {} - # We need to do this for both new and existing variables to ensure we're not - # writing to a partial chunk, even though we don't use the `encoding` value - # when writing to an existing variable. See - # https://github.com/pydata/xarray/issues/8371 for details. - encoding = extract_zarr_variable_encoding( - v, - raise_on_invalid=vn in check_encoding_set, - name=vn, - safe_chunks=self._safe_chunks, - ) + zarr_array = None + zarr_shape = None + write_region = self._write_region if self._write_region is not None else {} + write_region = {dim: write_region.get(dim, slice(None)) for dim in dims} if name in existing_keys: # existing variable @@ -801,7 +858,40 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No ) else: zarr_array = self.zarr_group[name] - else: + + if self._append_dim is not None and self._append_dim in dims: + # resize existing variable + append_axis = dims.index(self._append_dim) + assert write_region[self._append_dim] == slice(None) + write_region[self._append_dim] = slice( + zarr_array.shape[append_axis], None + ) + + new_shape = list(zarr_array.shape) + new_shape[append_axis] += v.shape[append_axis] + zarr_array.resize(new_shape) + + zarr_shape = zarr_array.shape + + region = tuple(write_region[dim] for dim in dims) + + # We need to do this for both new and existing variables to ensure we're not + # writing to a partial chunk, even though we don't use the `encoding` value + # when writing to an existing variable. See + # https://github.com/pydata/xarray/issues/8371 for details. + # Note: Ideally there should be two functions, one for validating the chunks and + # another one for extracting the encoding. + encoding = extract_zarr_variable_encoding( + v, + raise_on_invalid=vn in check_encoding_set, + name=vn, + safe_chunks=self._safe_chunks, + region=region, + mode=self._mode, + shape=zarr_shape, + ) + + if name not in existing_keys: # new variable encoded_attrs = {} # the magic for storing the hidden dimension data @@ -833,22 +923,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No ) zarr_array = _put_attrs(zarr_array, encoded_attrs) - write_region = self._write_region if self._write_region is not None else {} - write_region = {dim: write_region.get(dim, slice(None)) for dim in dims} - - if self._append_dim is not None and self._append_dim in dims: - # resize existing variable - append_axis = dims.index(self._append_dim) - assert write_region[self._append_dim] == slice(None) - write_region[self._append_dim] = slice( - zarr_array.shape[append_axis], None - ) - - new_shape = list(zarr_array.shape) - new_shape[append_axis] += v.shape[append_axis] - zarr_array.resize(new_shape) - - region = tuple(write_region[dim] for dim in dims) writer.add(v.data, zarr_array, region) def close(self) -> None: @@ -897,9 +971,9 @@ def _validate_and_autodetect_region(self, ds) -> None: if not isinstance(region, dict): raise TypeError(f"``region`` must be a dict, got {type(region)}") if any(v == "auto" for v in region.values()): - if self._mode != "r+": + if self._mode not in ["r+", "a"]: raise ValueError( - f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}" + f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got {self._mode!r}" ) region = self._auto_detect_regions(ds, region) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b5441fc273a..bcc57acd316 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4316,6 +4316,14 @@ def to_zarr( if Zarr arrays are written in parallel. This option may be useful in combination with ``compute=False`` to initialize a Zarr store from an existing DataArray with arbitrary chunk structure. + In addition to the many-to-one relationship validation, it also detects partial + chunks writes when using the region parameter, + these partial chunks are considered unsafe in the mode "r+" but safe in + the mode "a". + Note: Even with these validations it can still be unsafe to write + two or more chunked arrays in the same location in parallel if they are + not writing in independent regions, for those cases it is better to use + a synchronizer. storage_options : dict, optional Any additional parameters for the storage backend (ignored for local paths). diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7b9b4819245..b1ce264cbc8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2509,6 +2509,14 @@ def to_zarr( if Zarr arrays are written in parallel. This option may be useful in combination with ``compute=False`` to initialize a Zarr from an existing Dataset with arbitrary chunk structure. + In addition to the many-to-one relationship validation, it also detects partial + chunks writes when using the region parameter, + these partial chunks are considered unsafe in the mode "r+" but safe in + the mode "a". + Note: Even with these validations it can still be unsafe to write + two or more chunked arrays in the same location in parallel if they are + not writing in independent regions, for those cases it is better to use + a synchronizer. storage_options : dict, optional Any additional parameters for the storage backend (ignored for local paths). diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index cbc0b9e019d..ccf1bc73dd6 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5989,9 +5989,10 @@ def test_zarr_region_append(self, tmp_path): } ) - # Don't allow auto region detection in append mode due to complexities in - # implementing the overlap logic and lack of safety with parallel writes - with pytest.raises(ValueError): + # Now it is valid to use auto region detection with the append mode, + # but it is still unsafe to modify dimensions or metadata using the region + # parameter. + with pytest.raises(KeyError): ds_new.to_zarr( tmp_path / "test.zarr", mode="a", append_dim="x", region="auto" ) @@ -6096,6 +6097,162 @@ def test_zarr_region_chunk_partial_offset(tmp_path): store, safe_chunks=False, region="auto" ) - # This write is unsafe, and should raise an error, but does not. - # with pytest.raises(ValueError): - # da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto") + with pytest.raises(ValueError): + da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto") + + +@requires_zarr +@requires_dask +def test_zarr_safe_chunk_append_dim(tmp_path): + store = tmp_path / "foo.zarr" + data = np.ones((20,)) + da = xr.DataArray(data, dims=["x"], coords={"x": range(20)}, name="foo").chunk(x=5) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + with pytest.raises(ValueError): + # If the first chunk is smaller than the border size then raise an error + da.isel(x=slice(7, 11)).chunk(x=(2, 2)).to_zarr( + store, append_dim="x", safe_chunks=True + ) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + # If the first chunk is of the size of the border size then it is valid + da.isel(x=slice(7, 11)).chunk(x=(3, 1)).to_zarr( + store, safe_chunks=True, append_dim="x" + ) + assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 11))) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + # If the first chunk is of the size of the border size + N * zchunk then it is valid + da.isel(x=slice(7, 17)).chunk(x=(8, 2)).to_zarr( + store, safe_chunks=True, append_dim="x" + ) + assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 17))) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + with pytest.raises(ValueError): + # If the first chunk is valid but the other are not then raise an error + da.isel(x=slice(7, 14)).chunk(x=(3, 3, 1)).to_zarr( + store, append_dim="x", safe_chunks=True + ) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + with pytest.raises(ValueError): + # If the first chunk have a size bigger than the border size but not enough + # to complete the size of the next chunk then an error must be raised + da.isel(x=slice(7, 14)).chunk(x=(4, 3)).to_zarr( + store, append_dim="x", safe_chunks=True + ) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + # Append with a single chunk it's totally valid, + # and it does not matter the size of the chunk + da.isel(x=slice(7, 19)).chunk(x=-1).to_zarr(store, append_dim="x", safe_chunks=True) + assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 19))) + + +@requires_zarr +@requires_dask +def test_zarr_safe_chunk_region(tmp_path): + store = tmp_path / "foo.zarr" + + arr = xr.DataArray( + list(range(11)), dims=["a"], coords={"a": list(range(11))}, name="foo" + ).chunk(a=3) + arr.to_zarr(store, mode="w") + + modes: list[Literal["r+", "a"]] = ["r+", "a"] + for mode in modes: + with pytest.raises(ValueError): + # There are two Dask chunks on the same Zarr chunk, + # which means that it is unsafe in any mode + arr.isel(a=slice(0, 3)).chunk(a=(2, 1)).to_zarr( + store, region="auto", mode=mode + ) + + with pytest.raises(ValueError): + # the first chunk is covering the border size, but it is not + # completely covering the second chunk, which means that it is + # unsafe in any mode + arr.isel(a=slice(1, 5)).chunk(a=(3, 1)).to_zarr( + store, region="auto", mode=mode + ) + + with pytest.raises(ValueError): + # The first chunk is safe but the other two chunks are overlapping with + # the same Zarr chunk + arr.isel(a=slice(0, 5)).chunk(a=(3, 1, 1)).to_zarr( + store, region="auto", mode=mode + ) + + # Fully update two contiguous chunks is safe in any mode + arr.isel(a=slice(3, 9)).to_zarr(store, region="auto", mode=mode) + + # The last chunk is considered full based on their current size (2) + arr.isel(a=slice(9, 11)).to_zarr(store, region="auto", mode=mode) + arr.isel(a=slice(6, None)).chunk(a=-1).to_zarr(store, region="auto", mode=mode) + + # Write the last chunk of a region partially is safe in "a" mode + arr.isel(a=slice(3, 8)).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + # with "r+" mode it is invalid to write partial chunk + arr.isel(a=slice(3, 8)).to_zarr(store, region="auto", mode="r+") + + # This is safe with mode "a", the border size is covered by the first chunk of Dask + arr.isel(a=slice(1, 4)).chunk(a=(2, 1)).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + # This is considered unsafe in mode "r+" because it is writing in a partial chunk + arr.isel(a=slice(1, 4)).chunk(a=(2, 1)).to_zarr(store, region="auto", mode="r+") + + # This is safe on mode "a" because there is a single dask chunk + arr.isel(a=slice(1, 5)).chunk(a=(4,)).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + # This is unsafe on mode "r+", because the Dask chunk is partially writing + # in the first chunk of Zarr + arr.isel(a=slice(1, 5)).chunk(a=(4,)).to_zarr(store, region="auto", mode="r+") + + # The first chunk is completely covering the first Zarr chunk + # and the last chunk is a partial one + arr.isel(a=slice(0, 5)).chunk(a=(3, 2)).to_zarr(store, region="auto", mode="a") + + with pytest.raises(ValueError): + # The last chunk is partial, so it is considered unsafe on mode "r+" + arr.isel(a=slice(0, 5)).chunk(a=(3, 2)).to_zarr(store, region="auto", mode="r+") + + # The first chunk is covering the border size (2 elements) + # and also the second chunk (3 elements), so it is valid + arr.isel(a=slice(1, 8)).chunk(a=(5, 2)).to_zarr(store, region="auto", mode="a") + + with pytest.raises(ValueError): + # The first chunk is not fully covering the first zarr chunk + arr.isel(a=slice(1, 8)).chunk(a=(5, 2)).to_zarr(store, region="auto", mode="r+") + + with pytest.raises(ValueError): + # Validate that the border condition is not affecting the "r+" mode + arr.isel(a=slice(1, 9)).to_zarr(store, region="auto", mode="r+") + + arr.isel(a=slice(10, 11)).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + # Validate that even if we write with a single Dask chunk on the last Zarr + # chunk it is still unsafe if it is not fully covering it + # (the last Zarr chunk has size 2) + arr.isel(a=slice(10, 11)).to_zarr(store, region="auto", mode="r+") + + # Validate the same as the above test but in the beginning of the last chunk + arr.isel(a=slice(9, 10)).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + arr.isel(a=slice(9, 10)).to_zarr(store, region="auto", mode="r+") + + arr.isel(a=slice(7, None)).chunk(a=-1).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + # Test that even a Dask chunk that covers the last Zarr chunk can be unsafe + # if it is partial covering other Zarr chunks + arr.isel(a=slice(7, None)).chunk(a=-1).to_zarr(store, region="auto", mode="r+") + + with pytest.raises(ValueError): + # If the chunk is of size equal to the one in the Zarr encoding, but + # it is partially writing in the first chunk then raise an error + arr.isel(a=slice(8, None)).chunk(a=3).to_zarr(store, region="auto", mode="r+") + + with pytest.raises(ValueError): + arr.isel(a=slice(5, -1)).chunk(a=5).to_zarr(store, region="auto", mode="r+")