From cc585d0fb4b7003822d75c8199b30a0afb47b278 Mon Sep 17 00:00:00 2001 From: Joseph Gonzalez Date: Sat, 21 Sep 2024 18:14:24 -0400 Subject: [PATCH] Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete" --- xarray/backends/zarr.py | 55 ++++++++++++++++++++++------------- xarray/tests/test_backends.py | 41 ++++++++++++++++++++++---- 2 files changed, 70 insertions(+), 26 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index b10f3c8da94..e6fe93a398a 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -113,7 +113,7 @@ def __getitem__(self, key): def _determine_zarr_chunks( - enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode + enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape ): """ Given encoding chunks (possibly None or []) and variable chunks @@ -166,7 +166,7 @@ def _determine_zarr_chunks( if len(enc_chunks_tuple) != ndim: # throw away encoding chunks, start over return _determine_zarr_chunks( - None, var_chunks, ndim, name, safe_chunks, region, mode + None, var_chunks, ndim, name, safe_chunks, region, mode, shape ) for x in enc_chunks_tuple: @@ -208,29 +208,38 @@ def _determine_zarr_chunks( f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`." ) - for zchunk, dchunks, interval in zip( - enc_chunks_tuple, var_chunks, region, strict=True + for zchunk, dchunks, interval, size in zip( + enc_chunks_tuple, var_chunks, region, shape, strict=True ): if not safe_chunks: continue - # The first border size is the amount of data that needs to be updated on the - # first chunk taking into account the region slice. 
- first_border_size = zchunk - if interval.start: - first_border_size = zchunk - interval.start % zchunk + for dchunk in dchunks[1:-1]: + if dchunk % zchunk: + raise ValueError(base_error) + + region_start = interval.start if interval.start else 0 - if not allow_partial_chunks and first_border_size < zchunk: - # If the border is smaller than zchunk, then it is a partial chunk write - raise ValueError(base_error) + if len(dchunks) > 1: + # The first border size is the amount of data that needs to be updated on the + # first chunk taking into account the region slice. + first_border_size = zchunk + if allow_partial_chunks: + first_border_size = zchunk - region_start % zchunk - for dchunk in dchunks[:end]: - if (dchunk - first_border_size) % zchunk: + if (dchunks[0] - first_border_size) % zchunk: raise ValueError(base_error) - # The first border is only useful during the first iteration, - # so ignore it in the next validations - first_border_size = 0 + if not allow_partial_chunks: + region_stop = interval.stop if interval.stop else size + cover_last_chunk = region_stop > size - size % zchunk + + if not cover_last_chunk: + if dchunks[-1] % zchunk: + raise ValueError(base_error) + elif dchunks[-1] % zchunk != size % zchunk: + # The remainder must be equal to the size of the last Zarr chunk + raise ValueError(base_error) return enc_chunks_tuple @@ -279,6 +288,7 @@ def extract_zarr_variable_encoding( safe_chunks=True, region=None, mode=None, + shape=None ): """ Extract zarr encoding dictionary from xarray Variable @@ -289,9 +299,9 @@ def extract_zarr_variable_encoding( raise_on_invalid : bool, optional name: str | Hashable, optional safe_chunks: bool, optional - region: tuple[slice], optional + region: tuple[slice, ...], optional mode: str, optional - + shape: tuple[int, ...], optional Returns ------- encoding : dict @@ -331,6 +341,7 @@ def extract_zarr_variable_encoding( safe_chunks=safe_chunks, region=region, mode=mode, + shape=shape ) encoding["chunks"] = chunks 
return encoding @@ -808,6 +819,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No v.encoding = {} zarr_array = None + zarr_shape = None write_region = self._write_region if self._write_region is not None else {} write_region = {dim: write_region.get(dim, slice(None)) for dim in dims} @@ -852,6 +864,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No new_shape[append_axis] += v.shape[append_axis] zarr_array.resize(new_shape) + zarr_shape = zarr_array.shape + region = tuple(write_region[dim] for dim in dims) # We need to do this for both new and existing variables to ensure we're not @@ -862,11 +876,12 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No # another one for extracting the encoding. encoding = extract_zarr_variable_encoding( v, - region=region, raise_on_invalid=vn in check_encoding_set, name=vn, safe_chunks=self._safe_chunks, + region=region, mode=self._mode, + shape=zarr_shape ) if name not in existing_keys: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 78d50fcbdac..c04f71ae61c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -6157,7 +6157,7 @@ def test_zarr_safe_chunk_region(tmp_path): store = tmp_path / "foo.zarr" arr = xr.DataArray( - list(range(10)), dims=["a"], coords={"a": list(range(10))}, name="foo" + list(range(11)), dims=["a"], coords={"a": list(range(11))}, name="foo" ).chunk(a=3) arr.to_zarr(store, mode="w") @@ -6187,10 +6187,14 @@ def test_zarr_safe_chunk_region(tmp_path): # Fully update two contiguous chunks is safe in any mode arr.isel(a=slice(3, 9)).to_zarr(store, region="auto", mode=mode) - # Write the last chunk partially is safe in "a" mode + # The last chunk is considered full based on their current size (2) + arr.isel(a=slice(9, 11)).to_zarr(store, region="auto", mode=mode) + arr.isel(a=slice(6, None)).chunk(a=-1).to_zarr(store, region="auto", mode=mode) + + # Write 
the last chunk of a region partially is safe in "a" mode arr.isel(a=slice(3, 8)).to_zarr(store, region="auto", mode="a") with pytest.raises(ValueError): - # with "r+" mode it is invalid to write partial chunk even on the last one + # with "r+" mode it is invalid to write a partial chunk arr.isel(a=slice(3, 8)).to_zarr(store, region="auto", mode="r+") # This is safe with mode "a", the border size is covered by the first chunk of Dask @@ -6204,12 +6208,12 @@ def test_zarr_safe_chunk_region(tmp_path): arr.isel(a=slice(1, 5)).chunk(a=(4,)).to_zarr(store, region="auto", mode="a") with pytest.raises(ValueError): - # This is unsafe on mode "r+", because there is a single dask - # chunk smaller than the Zarr chunk + # This is unsafe on mode "r+", because the Dask chunk is partially writing + # in the first chunk of Zarr arr.isel(a=slice(1, 5)).chunk(a=(4,)).to_zarr(store, region="auto", mode="r+") # The first chunk is completely covering the first Zarr chunk - # and the last chunk is a partial chunk + # and the last chunk is a partial one arr.isel(a=slice(0, 5)).chunk(a=(3, 2)).to_zarr(store, region="auto", mode="a") with pytest.raises(ValueError): @@ -6227,3 +6231,28 @@ def test_zarr_safe_chunk_region(tmp_path): with pytest.raises(ValueError): # Validate that the border condition is not affecting the "r+" mode arr.isel(a=slice(1, 9)).to_zarr(store, region="auto", mode="r+") + + arr.isel(a=slice(10, 11)).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + # Validate that even if we write with a single Dask chunk on the last Zarr + # chunk it is still unsafe if it is not fully covering it + # (the last Zarr chunk has size 2) + arr.isel(a=slice(10, 11)).to_zarr(store, region="auto", mode="r+") + + # Validate the same as the above test but at the beginning of the last chunk + arr.isel(a=slice(9, 10)).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + arr.isel(a=slice(9, 10)).to_zarr(store, region="auto", mode="r+") + + 
arr.isel(a=slice(7, None)).chunk(a=-1).to_zarr(store, region="auto", mode="a") + with pytest.raises(ValueError): + # Test that even a Dask chunk that covers the last Zarr chunk can be unsafe + # if it is partially covering other Zarr chunks + arr.isel(a=slice(7, None)).chunk(a=-1).to_zarr(store, region="auto", mode="r+") + + with pytest.raises(ValueError): + # If the chunk is of size equal to the one in the Zarr encoding, but + # it is partially writing in the last chunk then raise an error + arr.isel(a=slice(8, None)).chunk(a=3).to_zarr(store, region="auto", mode="r+") + +