From 85c49d3d6995a78f2cb337bf017307d5050d19d8 Mon Sep 17 00:00:00 2001 From: Joseph Gonzalez Date: Wed, 18 Sep 2024 08:19:59 -0400 Subject: [PATCH 1/4] fix safe chunks validation --- xarray/backends/zarr.py | 111 ++++++++++++++++++++-------------- xarray/tests/test_backends.py | 68 ++++++++++++++++++--- 2 files changed, 128 insertions(+), 51 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 52d2175621f..52de392e85d 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -112,7 +112,7 @@ def __getitem__(self, key): # could possibly have a work-around for 0d data here -def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): +def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks, region): """ Given encoding chunks (possibly None or []) and variable chunks (possibly None or []). @@ -163,7 +163,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): if len(enc_chunks_tuple) != ndim: # throw away encoding chunks, start over - return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks) + return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks, region) for x in enc_chunks_tuple: if not isinstance(x, int): @@ -189,20 +189,36 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): # TODO: incorporate synchronizer to allow writes from multiple dask # threads if var_chunks and enc_chunks_tuple: - for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True): - for dchunk in dchunks[:-1]: + base_error = ( + f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for " + f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. " + f"Writing this array in parallel with dask could lead to corrupted data." + f"Consider either rechunking using `chunk()`, deleting " + f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`." + ) + + for zchunk, dchunks, interval in zip(enc_chunks_tuple, var_chunks, region, strict=True): + if not safe_chunks or len(dchunks) <= 1: + # It is not necessary to perform any additional validation if the + # safe_chunks is False, or there are less than two dchunks + continue + + start = 0 + if interval.start: + # If the start of the interval is not None or 0, it means that the data + # is being appended or updated, and in both cases it is mandatory that + # the residue of the division between the first dchunk and the zchunk + # being equal to the border size + border_size = zchunk - interval.start % zchunk + if dchunks[0] % zchunk != border_size: + raise ValueError(base_error) + # Avoid validating the first chunk inside the loop + start = 1 + + for dchunk in dchunks[start:-1]: if dchunk % zchunk: - base_error = ( - f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for " - f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. " - f"Writing this array in parallel with dask could lead to corrupted data." - ) - if safe_chunks: - raise ValueError( - base_error - + " Consider either rechunking using `chunk()`, deleting " - "or modifying `encoding['chunks']`, or specify `safe_chunks=False`." - ) + raise ValueError(base_error) + return enc_chunks_tuple raise AssertionError("We should never get here. Function logic must be wrong.") @@ -243,7 +259,7 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr): def extract_zarr_variable_encoding( - variable, raise_on_invalid=False, name=None, safe_chunks=True + variable, region, raise_on_invalid=False, name=None, safe_chunks=True ): """ Extract zarr encoding dictionary from xarray Variable @@ -251,6 +267,7 @@ def extract_zarr_variable_encoding( Parameters ---------- variable : Variable + region: tuple[slice] raise_on_invalid : bool, optional Returns @@ -285,7 +302,7 @@ def extract_zarr_variable_encoding( del encoding[k] chunks = _determine_zarr_chunks( - encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks + encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks, region ) encoding["chunks"] = chunks return encoding @@ -762,16 +779,9 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if v.encoding == {"_FillValue": None} and fill_value is None: v.encoding = {} - # We need to do this for both new and existing variables to ensure we're not - # writing to a partial chunk, even though we don't use the `encoding` value - # when writing to an existing variable. See - # https://github.com/pydata/xarray/issues/8371 for details. - encoding = extract_zarr_variable_encoding( - v, - raise_on_invalid=vn in check_encoding_set, - name=vn, - safe_chunks=self._safe_chunks, - ) + zarr_array = None + write_region = self._write_region if self._write_region is not None else {} + write_region = {dim: write_region.get(dim, slice(None)) for dim in dims} if name in existing_keys: # existing variable @@ -801,7 +811,36 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No ) else: zarr_array = self.zarr_group[name] - else: + + if self._append_dim is not None and self._append_dim in dims: + # resize existing variable + append_axis = dims.index(self._append_dim) + assert write_region[self._append_dim] == slice(None) + write_region[self._append_dim] = slice( + zarr_array.shape[append_axis], None + ) + + new_shape = list(zarr_array.shape) + new_shape[append_axis] += v.shape[append_axis] + zarr_array.resize(new_shape) + + region = tuple(write_region[dim] for dim in dims) + + # We need to do this for both new and existing variables to ensure we're not + # writing to a partial chunk, even though we don't use the `encoding` value + # when writing to an existing variable. See + # https://github.com/pydata/xarray/issues/8371 for details. + # Note: Ideally there should be two functions, one for validating the chunks and + # another one for extracting the encoding. + encoding = extract_zarr_variable_encoding( + v, + region=region, + raise_on_invalid=vn in check_encoding_set, + name=vn, + safe_chunks=self._safe_chunks, + ) + + if name not in existing_keys: # new variable encoded_attrs = {} # the magic for storing the hidden dimension data @@ -833,22 +872,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No ) zarr_array = _put_attrs(zarr_array, encoded_attrs) - write_region = self._write_region if self._write_region is not None else {} - write_region = {dim: write_region.get(dim, slice(None)) for dim in dims} - - if self._append_dim is not None and self._append_dim in dims: - # resize existing variable - append_axis = dims.index(self._append_dim) - assert write_region[self._append_dim] == slice(None) - write_region[self._append_dim] = slice( - zarr_array.shape[append_axis], None - ) - - new_shape = list(zarr_array.shape) - new_shape[append_axis] += v.shape[append_axis] - zarr_array.resize(new_shape) - - region = tuple(write_region[dim] for dim in dims) writer.add(v.data, zarr_array, region) def close(self) -> None: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 13258fcf6ea..a78b583598b 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5496,24 +5496,26 @@ def test_encode_zarr_attr_value() -> None: @requires_zarr def test_extract_zarr_variable_encoding() -> None: + # The region is not useful in these cases, but I still think that it must be mandatory + # because the validation of the chunks is in the same function var = xr.Variable("x", [1, 2]) - actual = backends.zarr.extract_zarr_variable_encoding(var) + actual = backends.zarr.extract_zarr_variable_encoding(var, region=tuple()) assert "chunks" in actual assert actual["chunks"] is None var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)}) - actual = backends.zarr.extract_zarr_variable_encoding(var) + actual = backends.zarr.extract_zarr_variable_encoding(var, region=tuple()) assert actual["chunks"] == (1,) # does not raise on invalid var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) - actual = backends.zarr.extract_zarr_variable_encoding(var) + actual = backends.zarr.extract_zarr_variable_encoding(var, region=tuple()) # raises on invalid var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) with pytest.raises(ValueError, match=r"unexpected encoding parameters"): actual = backends.zarr.extract_zarr_variable_encoding( - var, raise_on_invalid=True + var, raise_on_invalid=True, region=tuple() ) @@ -6096,6 +6098,58 @@ def test_zarr_region_chunk_partial_offset(tmp_path): store, safe_chunks=False, region="auto" ) - # This write is unsafe, and should raise an error, but does not. - # with pytest.raises(ValueError): - # da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto") + with pytest.raises(ValueError): + da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto") + + +@requires_zarr +@requires_dask +def test_zarr_safe_chunk(tmp_path): + # https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545 + store = tmp_path / "foo.zarr" + data = np.ones((20,)) + da = xr.DataArray(data, dims=["x"], coords={"x": range(20)}, name="foo").chunk(x=5) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + with pytest.raises(ValueError): + # If the first chunk is smaller than the border size then raise an error + da.isel(x=slice(7, 11)).chunk(x=(2, 2)).to_zarr( + store, append_dim="x", safe_chunks=True + ) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + # If the first chunk is of the size of the border size then it is valid + da.isel(x=slice(7, 11)).chunk(x=(3, 1)).to_zarr( + store, safe_chunks=True, append_dim="x" + ) + assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 11))) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + # If the first chunk is of the size of the border size + N * zchunk then it is valid + da.isel(x=slice(7, 17)).chunk(x=(8, 2)).to_zarr( + store, safe_chunks=True, append_dim="x" + ) + assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 17))) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + with pytest.raises(ValueError): + # If the first chunk is valid but the other are not then raise an error + da.isel(x=slice(7, 14)).chunk(x=(3, 3, 1)).to_zarr( + store, append_dim="x", safe_chunks=True + ) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + with pytest.raises(ValueError): + # If the first chunk have a size bigger than the border size but not enough + # to complete the size of the next chunk then an error must be raised + da.isel(x=slice(7, 14)).chunk(x=(4, 3)).to_zarr( + store, append_dim="x", safe_chunks=True + ) + + da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") + # Append with a single chunk it's totally valid, + # and it does not matter the size of the chunk + da.isel(x=slice(7, 19)).chunk(x=-1).to_zarr( + store, append_dim="x", safe_chunks=True + ) + assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 19))) From 0160d48ee35153f26e96515f887affca61a89348 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:25:29 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/zarr.py | 11 +++++++++-- xarray/tests/test_backends.py | 4 +--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 52de392e85d..c4099f1f5fe 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -197,7 +197,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks, regi f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`." ) - for zchunk, dchunks, interval in zip(enc_chunks_tuple, var_chunks, region, strict=True): + for zchunk, dchunks, interval in zip( + enc_chunks_tuple, var_chunks, region, strict=True + ): if not safe_chunks or len(dchunks) <= 1: # It is not necessary to perform any additional validation if the # safe_chunks is False, or there are less than two dchunks @@ -302,7 +304,12 @@ def extract_zarr_variable_encoding( del encoding[k] chunks = _determine_zarr_chunks( - encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks, region + encoding.get("chunks"), + variable.chunks, + variable.ndim, + name, + safe_chunks, + region, ) encoding["chunks"] = chunks return encoding diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a78b583598b..06646e6ec4a 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -6149,7 +6149,5 @@ def test_zarr_safe_chunk(tmp_path): da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w") # Append with a single chunk it's totally valid, # and it does not matter the size of the chunk - da.isel(x=slice(7, 19)).chunk(x=-1).to_zarr( - store, append_dim="x", safe_chunks=True - ) + da.isel(x=slice(7, 19)).chunk(x=-1).to_zarr(store, append_dim="x", safe_chunks=True) assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 19))) From 60a7a3f18e2b450590d311141ca2ee4b79df6dc8 Mon Sep 17 00:00:00 2001 From: Joseph Gonzalez Date: Wed, 18 Sep 2024 08:55:26 -0400 Subject: [PATCH 3/4] fix safe chunks validation --- doc/whats-new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 264c07f562b..56f4dda4cca 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -51,7 +51,9 @@ Bug fixes the non-missing times could in theory be encoded with integers (:issue:`9488`, :pull:`9497`). By `Spencer Clark `_. - +- Fix the safe_chunks validation option on the to_zarr method + (:issue:`5511`, :pull:`9513`). By `Joseph Nowak + `_. Documentation ~~~~~~~~~~~~~ From 6c41f4beb059d4ac0a8c04cda117177284e3fd62 Mon Sep 17 00:00:00 2001 From: joseph nowak Date: Wed, 18 Sep 2024 15:26:00 -0400 Subject: [PATCH 4/4] Update xarray/tests/test_backends.py Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- xarray/tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 06646e6ec4a..a2419cf9145 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -6104,7 +6104,7 @@ def test_zarr_region_chunk_partial_offset(tmp_path): @requires_zarr @requires_dask -def test_zarr_safe_chunk(tmp_path): +def test_zarr_safe_chunk_append_dim(tmp_path): # https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545 store = tmp_path / "foo.zarr" data = np.ones((20,))