fix safe chunks validation
josephnowak committed Sep 18, 2024
1 parent 9cb9958 commit 85c49d3
Showing 2 changed files with 128 additions and 51 deletions.
111 changes: 67 additions & 44 deletions xarray/backends/zarr.py
@@ -112,7 +112,7 @@ def __getitem__(self, key):
# could possibly have a work-around for 0d data here


def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks, region):
"""
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
@@ -163,7 +163,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):

if len(enc_chunks_tuple) != ndim:
# throw away encoding chunks, start over
return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)
return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks, region)

for x in enc_chunks_tuple:
if not isinstance(x, int):
@@ -189,20 +189,36 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
# TODO: incorporate synchronizer to allow writes from multiple dask
# threads
if var_chunks and enc_chunks_tuple:
for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True):
for dchunk in dchunks[:-1]:
base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
f"Writing this array in parallel with dask could lead to corrupted data."
f"Consider either rechunking using `chunk()`, deleting "
f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)

for zchunk, dchunks, interval in zip(enc_chunks_tuple, var_chunks, region, strict=True):
if not safe_chunks or len(dchunks) <= 1:
# It is not necessary to perform any additional validation if
# safe_chunks is False, or there are fewer than two dchunks
continue

start = 0
if interval.start:
# If the start of the interval is not None or 0, the data is being
# appended or updated; in both cases the remainder of dividing the
# first dchunk by the zchunk must be equal to the border size
border_size = zchunk - interval.start % zchunk
if dchunks[0] % zchunk != border_size:
raise ValueError(base_error)
# Avoid validating the first chunk inside the loop
start = 1

for dchunk in dchunks[start:-1]:
if dchunk % zchunk:
base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
f"Writing this array in parallel with dask could lead to corrupted data."
)
if safe_chunks:
raise ValueError(
base_error
+ " Consider either rechunking using `chunk()`, deleting "
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)
raise ValueError(base_error)

return enc_chunks_tuple

raise AssertionError("We should never get here. Function logic must be wrong.")
@@ -243,14 +259,15 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):


def extract_zarr_variable_encoding(
variable, raise_on_invalid=False, name=None, safe_chunks=True
variable, region, raise_on_invalid=False, name=None, safe_chunks=True
):
"""
Extract zarr encoding dictionary from xarray Variable
Parameters
----------
variable : Variable
region : tuple[slice]
raise_on_invalid : bool, optional
Returns
@@ -285,7 +302,7 @@ def extract_zarr_variable_encoding(
del encoding[k]

chunks = _determine_zarr_chunks(
encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks, region
)
encoding["chunks"] = chunks
return encoding
@@ -762,16 +779,9 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
)
zarr_array = None
write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

if name in existing_keys:
# existing variable
@@ -801,7 +811,36 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
)
else:
zarr_array = self.zarr_group[name]
else:

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

region = tuple(write_region[dim] for dim in dims)

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
# Note: Ideally there should be two functions, one for validating the chunks and
# another one for extracting the encoding.
encoding = extract_zarr_variable_encoding(
v,
region=region,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
)

if name not in existing_keys:
# new variable
encoded_attrs = {}
# the magic for storing the hidden dimension data
Expand Down Expand Up @@ -833,22 +872,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
)
zarr_array = _put_attrs(zarr_array, encoded_attrs)

write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

region = tuple(write_region[dim] for dim in dims)
writer.add(v.data, zarr_array, region)

def close(self) -> None:
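To make the new rule concrete, the sketch below re-implements, standalone, the per-dimension check that _determine_zarr_chunks now performs when a region starts part-way into a zarr chunk. The helper name validate_dask_chunks_for_region and the sample numbers are illustrative only and are not part of this commit; the rule itself (the first dask chunk must end exactly at the zarr chunk border, and every interior dask chunk must cover whole zarr chunks) is the one added in the diff above.

def validate_dask_chunks_for_region(zchunk: int, dchunks: tuple[int, ...], start: int) -> None:
    """Raise ValueError if writing ``dchunks`` at offset ``start`` could touch
    a zarr chunk of size ``zchunk`` from more than one dask chunk."""
    if len(dchunks) <= 1:
        # A single dask chunk can never collide with another writer on this dimension
        return
    first = 0
    if start:
        # The region starts inside a zarr chunk, so the first dask chunk must
        # stop exactly at the zarr chunk boundary (the "border")
        border_size = zchunk - start % zchunk
        if dchunks[0] % zchunk != border_size:
            raise ValueError("first dask chunk does not line up with the zarr chunk border")
        first = 1
    for dchunk in dchunks[first:-1]:
        # Every interior dask chunk must be a multiple of the zarr chunk
        if dchunk % zchunk:
            raise ValueError("a dask chunk would overlap multiple zarr chunks")

# Worked example mirroring the new test: zarr chunks of 5, appending at offset 7,
# so border_size = 5 - 7 % 5 = 3 and the first dask chunk must be 3 (or 3 + N * 5).
validate_dask_chunks_for_region(5, (3, 1), 7)   # passes, like chunk(x=(3, 1))
validate_dask_chunks_for_region(5, (8, 2), 7)   # passes, 8 = 3 + 5
for bad in [(2, 2), (3, 3, 1), (4, 3)]:
    try:
        validate_dask_chunks_for_region(5, bad, 7)
    except ValueError as err:
        print(bad, "->", err)
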
68 changes: 61 additions & 7 deletions xarray/tests/test_backends.py
@@ -5496,24 +5496,26 @@ def test_encode_zarr_attr_value() -> None:

@requires_zarr
def test_extract_zarr_variable_encoding() -> None:
# The region is not useful in these cases, but it still needs to be mandatory
# because the chunk validation happens in the same function
var = xr.Variable("x", [1, 2])
actual = backends.zarr.extract_zarr_variable_encoding(var)
actual = backends.zarr.extract_zarr_variable_encoding(var, region=tuple())
assert "chunks" in actual
assert actual["chunks"] is None

var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)})
actual = backends.zarr.extract_zarr_variable_encoding(var)
actual = backends.zarr.extract_zarr_variable_encoding(var, region=tuple())
assert actual["chunks"] == (1,)

# does not raise on invalid
var = xr.Variable("x", [1, 2], encoding={"foo": (1,)})
actual = backends.zarr.extract_zarr_variable_encoding(var)
actual = backends.zarr.extract_zarr_variable_encoding(var, region=tuple())

# raises on invalid
var = xr.Variable("x", [1, 2], encoding={"foo": (1,)})
with pytest.raises(ValueError, match=r"unexpected encoding parameters"):
actual = backends.zarr.extract_zarr_variable_encoding(
var, raise_on_invalid=True
var, raise_on_invalid=True, region=tuple()
)


@@ -6096,6 +6098,58 @@ def test_zarr_region_chunk_partial_offset(tmp_path):
store, safe_chunks=False, region="auto"
)

# This write is unsafe, and should raise an error, but does not.
# with pytest.raises(ValueError):
# da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto")
with pytest.raises(ValueError):
da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto")


@requires_zarr
@requires_dask
def test_zarr_safe_chunk(tmp_path):
# https://github.com/pydata/xarray/pull/8459#issuecomment-1819417545
store = tmp_path / "foo.zarr"
data = np.ones((20,))
da = xr.DataArray(data, dims=["x"], coords={"x": range(20)}, name="foo").chunk(x=5)

da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w")
with pytest.raises(ValueError):
# If the first chunk is smaller than the border size, then raise an error
da.isel(x=slice(7, 11)).chunk(x=(2, 2)).to_zarr(
store, append_dim="x", safe_chunks=True
)

da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w")
# If the first chunk has exactly the size of the border, then it is valid
da.isel(x=slice(7, 11)).chunk(x=(3, 1)).to_zarr(
store, safe_chunks=True, append_dim="x"
)
assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 11)))

da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w")
# If the first chunk has the size of the border plus N * zchunk, then it is valid
da.isel(x=slice(7, 17)).chunk(x=(8, 2)).to_zarr(
store, safe_chunks=True, append_dim="x"
)
assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 17)))

da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w")
with pytest.raises(ValueError):
# If the first chunk is valid but the others are not, then raise an error
da.isel(x=slice(7, 14)).chunk(x=(3, 3, 1)).to_zarr(
store, append_dim="x", safe_chunks=True
)

da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w")
with pytest.raises(ValueError):
# If the first chunk is bigger than the border size but does not fill
# the next zarr chunk completely, then an error must be raised
da.isel(x=slice(7, 14)).chunk(x=(4, 3)).to_zarr(
store, append_dim="x", safe_chunks=True
)

da.isel(x=slice(0, 7)).to_zarr(store, safe_chunks=True, mode="w")
# Appending with a single chunk is always valid,
# regardless of the size of the chunk
da.isel(x=slice(7, 19)).chunk(x=-1).to_zarr(
store, append_dim="x", safe_chunks=True
)
assert xr.open_zarr(store)["foo"].equals(da.isel(x=slice(0, 19)))

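For completeness, a usage sketch of what the stricter validation means for to_zarr callers; the store path and array sizes below are made up, but the calls mirror the tests above (an initial write of 7 elements with zarr chunks of 5, then an append whose first dask chunk matches the 3-element border).

import numpy as np
import xarray as xr

# Illustrative only: a store whose "x" dimension uses zarr chunks of 5
da = xr.DataArray(np.ones(20), dims=["x"], coords={"x": range(20)}, name="foo").chunk(x=5)
da.isel(x=slice(0, 7)).to_zarr("foo.zarr", mode="w", safe_chunks=True)

# The last zarr chunk currently holds only elements 5 and 6, so the border is 3.
# Aligning the first dask chunk with that border keeps the append safe:
da.isel(x=slice(7, 17)).chunk(x=(3, 5, 2)).to_zarr(
    "foo.zarr", append_dim="x", safe_chunks=True
)

# Callers that coordinate their own writes can still opt out of the check with
# safe_chunks=False, as suggested by the error message added in this commit.
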