Merge branch 'main' into h5netcdf-new-features
kmuehlbauer committed Sep 22, 2024
2 parents 2d12d54 + 2a6212e commit 98e5071
Showing 10 changed files with 343 additions and 62 deletions.
4 changes: 3 additions & 1 deletion doc/whats-new.rst
@@ -60,7 +60,9 @@ Bug fixes
<https://github.com/spencerkclark>`_.
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
By `Deepak Cherian <https://github.com/dcherian>`_.

- Fix the safe_chunks validation option on the to_zarr method
(:issue:`5511`, :pull:`9513`). By `Joseph Nowak
<https://github.com/josephnowak>`_.

Documentation
~~~~~~~~~~~~~
9 changes: 9 additions & 0 deletions properties/test_pandas_roundtrip.py
@@ -132,3 +132,12 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
roundtripped.columns.name = "cols" # why?
pd.testing.assert_frame_equal(df, roundtripped)
xr.testing.assert_identical(dataset, roundtripped.to_xarray())


def test_roundtrip_1d_pandas_extension_array() -> None:
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
arr = xr.Dataset.from_dataframe(df)["cat"]
roundtripped = arr.to_pandas()
assert (df["cat"] == roundtripped).all()
assert df["cat"].dtype == roundtripped.dtype
xr.testing.assert_identical(arr, roundtripped.to_xarray())
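
For context, a minimal sketch of the behavior this test pins down (names and data are illustrative): the pandas extension dtype and the array name now survive the roundtrip instead of being coerced to object.

import pandas as pd
import xarray as xr

# Build a DataArray backed by a pandas extension array (categorical).
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
arr = xr.Dataset.from_dataframe(df)["cat"]

# to_pandas() returns a Series that keeps the extension dtype and
# inherits the DataArray name.
series = arr.to_pandas()
print(series.dtype)  # category
print(series.name)   # cat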
168 changes: 121 additions & 47 deletions xarray/backends/zarr.py
@@ -112,7 +112,9 @@ def __getitem__(self, key):
# could possibly have a work-around for 0d data here


def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
def _determine_zarr_chunks(
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
):
"""
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
@@ -163,7 +165,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):

if len(enc_chunks_tuple) != ndim:
# throw away encoding chunks, start over
return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)
return _determine_zarr_chunks(
None, var_chunks, ndim, name, safe_chunks, region, mode, shape
)

for x in enc_chunks_tuple:
if not isinstance(x, int):
@@ -189,20 +193,59 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
# TODO: incorporate synchronizer to allow writes from multiple dask
# threads
if var_chunks and enc_chunks_tuple:
for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True):
for dchunk in dchunks[:-1]:
# If it is possible to write on partial chunks then it is not necessary to check
# the last chunk contained in the region
allow_partial_chunks = mode != "r+"

base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r} "
f"on the region {region}. "
f"Writing this array in parallel with dask could lead to corrupted data. "
f"Consider either rechunking using `chunk()`, deleting "
f"or modifying `encoding['chunks']`, or specifying `safe_chunks=False`."
)

for zchunk, dchunks, interval, size in zip(
enc_chunks_tuple, var_chunks, region, shape, strict=True
):
if not safe_chunks:
continue

for dchunk in dchunks[1:-1]:
if dchunk % zchunk:
base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
f"Writing this array in parallel with dask could lead to corrupted data."
)
if safe_chunks:
raise ValueError(
base_error
+ " Consider either rechunking using `chunk()`, deleting "
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)
raise ValueError(base_error)

region_start = interval.start if interval.start else 0

if len(dchunks) > 1:
# The first border size is the amount of data that needs to be updated on the
# first chunk taking into account the region slice.
first_border_size = zchunk
if allow_partial_chunks:
first_border_size = zchunk - region_start % zchunk

if (dchunks[0] - first_border_size) % zchunk:
raise ValueError(base_error)

if not allow_partial_chunks:
chunk_start = sum(dchunks[:-1]) + region_start
if chunk_start % zchunk:
# The last chunk (which can also be the only one) is a partial chunk
# if it is not aligned at the beginning
raise ValueError(base_error)

region_stop = interval.stop if interval.stop else size

if size - region_stop + 1 < zchunk:
# If the region covers the last chunk then check
# if the remainder with the default chunk size
# is equal to the size of the last chunk
if dchunks[-1] % zchunk != size % zchunk:
raise ValueError(base_error)
elif dchunks[-1] % zchunk:
raise ValueError(base_error)

return enc_chunks_tuple

raise AssertionError("We should never get here. Function logic must be wrong.")
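
The invariant the loop above enforces can be restated compactly: interior dask chunks must be whole multiples of the zarr chunk size, and only the border chunks of the region may be ragged, and then only when the mode permits partial-chunk writes. A rough standalone sketch of the interior check (a hypothetical helper, not part of the xarray API):

def interior_chunks_aligned(zchunk: int, dchunks: tuple[int, ...]) -> bool:
    # Every dask chunk except the first and last must cover whole zarr
    # chunks; otherwise two dask tasks could write to the same zarr chunk
    # concurrently and corrupt it.
    return all(dchunk % zchunk == 0 for dchunk in dchunks[1:-1])

print(interior_chunks_aligned(10, (5, 20, 10, 7)))  # True: only the borders are ragged
print(interior_chunks_aligned(10, (5, 15, 10, 7)))  # False: 15 straddles a zarr chunk boundary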
@@ -243,7 +286,14 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):


def extract_zarr_variable_encoding(
variable, raise_on_invalid=False, name=None, safe_chunks=True
variable,
raise_on_invalid=False,
name=None,
*,
safe_chunks=True,
region=None,
mode=None,
shape=None,
):
"""
Extract zarr encoding dictionary from xarray Variable
Expand All @@ -252,12 +302,18 @@ def extract_zarr_variable_encoding(
----------
variable : Variable
raise_on_invalid : bool, optional
name : str | Hashable, optional
safe_chunks : bool, optional
region : tuple[slice, ...], optional
mode : str, optional
shape : tuple[int, ...], optional

Returns
-------
encoding : dict
Zarr encoding for `variable`
"""

shape = shape if shape else variable.shape
encoding = variable.encoding.copy()

safe_to_drop = {"source", "original_shape"}
Expand Down Expand Up @@ -285,7 +341,14 @@ def extract_zarr_variable_encoding(
del encoding[k]

chunks = _determine_zarr_chunks(
encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
enc_chunks=encoding.get("chunks"),
var_chunks=variable.chunks,
ndim=variable.ndim,
name=name,
safe_chunks=safe_chunks,
region=region,
mode=mode,
shape=shape,
)
encoding["chunks"] = chunks
return encoding
@@ -762,16 +825,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
)
zarr_array = None
zarr_shape = None
write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

if name in existing_keys:
# existing variable
@@ -801,7 +858,40 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
)
else:
zarr_array = self.zarr_group[name]
else:

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

zarr_shape = zarr_array.shape

region = tuple(write_region[dim] for dim in dims)

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
# Note: Ideally there should be two functions, one for validating the chunks and
# another one for extracting the encoding.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
region=region,
mode=self._mode,
shape=zarr_shape,
)

if name not in existing_keys:
# new variable
encoded_attrs = {}
# the magic for storing the hidden dimension data
@@ -833,22 +923,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
)
zarr_array = _put_attrs(zarr_array, encoded_attrs)

write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

region = tuple(write_region[dim] for dim in dims)
writer.add(v.data, zarr_array, region)

def close(self) -> None:
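
For reference, a minimal sketch of the append path this block reorganizes (store path illustrative): appending resizes the on-disk array and shifts the write region to start at the previous end of the append dimension.

import numpy as np
import xarray as xr

xr.Dataset({"v": ("t", np.arange(3))}).to_zarr("log.zarr", mode="w")

# The resize branch above grows the array from (3,) to (5,) and the new
# data is written into the region slice(3, None) along "t".
xr.Dataset({"v": ("t", np.arange(3, 5))}).to_zarr("log.zarr", append_dim="t")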
@@ -897,9 +971,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")
if any(v == "auto" for v in region.values()):
if self._mode != "r+":
if self._mode not in ["r+", "a"]:
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got {self._mode!r}"
)
region = self._auto_detect_regions(ds, region)

2 changes: 1 addition & 1 deletion xarray/core/concat.py
@@ -255,7 +255,7 @@ def concat(
except StopIteration as err:
raise ValueError("must supply at least one object to concatenate") from err

if compat not in _VALID_COMPAT:
if compat not in set(_VALID_COMPAT) - {"minimal"}:
raise ValueError(
f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'"
)
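
The practical effect: compat="minimal", which appears in _VALID_COMPAT for other call sites but is not supported by concat, now fails fast like any other invalid value. A brief sketch:

import xarray as xr

a = xr.DataArray([1, 2], dims="x")
b = xr.DataArray([3, 4], dims="x")

try:
    xr.concat([a, b], dim="x", compat="minimal")
except ValueError as err:
    print(err)  # compat='minimal' invalid: must be 'broadcast_equals', ...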
24 changes: 22 additions & 2 deletions xarray/core/dataarray.py
@@ -47,6 +47,7 @@
create_coords_with_default_indexes,
)
from xarray.core.dataset import Dataset
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.formatting import format_item
from xarray.core.indexes import (
Index,
@@ -3857,7 +3858,11 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame:
"""
# TODO: consolidate the info about pandas constructors and the
# attributes that correspond to their indexes into a separate module?
constructors = {0: lambda x: x, 1: pd.Series, 2: pd.DataFrame}
constructors: dict[int, Callable] = {
0: lambda x: x,
1: pd.Series,
2: pd.DataFrame,
}
try:
constructor = constructors[self.ndim]
except KeyError as err:
@@ -3866,7 +3871,14 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame:
"pandas objects. Requires 2 or fewer dimensions."
) from err
indexes = [self.get_index(dim) for dim in self.dims]
return constructor(self.values, *indexes) # type: ignore[operator]
if isinstance(self._variable._data, PandasExtensionArray):
values = self._variable._data.array
else:
values = self.values
pandas_object = constructor(values, *indexes)
if isinstance(pandas_object, pd.Series):
pandas_object.name = self.name
return pandas_object

def to_dataframe(
self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None
@@ -4310,6 +4322,14 @@ def to_zarr(
if Zarr arrays are written in parallel. This option may be useful in combination
with ``compute=False`` to initialize a Zarr store from an existing
DataArray with arbitrary chunk structure.
In addition to the many-to-one relationship validation, it also detects
partial chunk writes when using the ``region`` parameter. These partial
chunks are considered unsafe in mode ``"r+"`` but safe in mode ``"a"``.
Note: Even with these validations it can still be unsafe to write
two or more chunked arrays to the same location in parallel if they are
not writing to independent regions; for those cases it is better to use
a synchronizer.
storage_options : dict, optional
Any additional parameters for the storage backend (ignored for local
paths).
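
A sketch of the safe_chunks distinction described above (store path, sizes, and the explicit chunks encoding are illustrative): a region whose first border cuts through a zarr chunk is rejected with safe_chunks=True in mode "r+" but accepted in mode "a", where the border chunk is updated with a read-modify-write.

import numpy as np
import xarray as xr

ds = xr.Dataset({"v": ("x", np.zeros(8))}).chunk({"x": 4})
ds.to_zarr("store.zarr", mode="w")  # zarr chunks of size 4 along "x"

# Dask chunks (2, 4) align with the zarr chunk borders only after a
# partial first chunk at the start of the region 2:8.
update = xr.Dataset({"v": ("x", np.ones(6))}).chunk({"x": (2, 4)})
update["v"].encoding["chunks"] = (4,)  # match the on-disk zarr chunks

# Rejected in mode "r+" (a partial chunk would have to be rewritten):
# update.to_zarr("store.zarr", region={"x": slice(2, 8)}, mode="r+")  # ValueError
# ...but accepted in mode "a":
update.to_zarr("store.zarr", region={"x": slice(2, 8)}, mode="a")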
8 changes: 8 additions & 0 deletions xarray/core/dataset.py
@@ -2515,6 +2515,14 @@ def to_zarr(
if Zarr arrays are written in parallel. This option may be useful in combination
with ``compute=False`` to initialize a Zarr from an existing
Dataset with arbitrary chunk structure.
In addition to the many-to-one relationship validation, it also detects
partial chunk writes when using the ``region`` parameter. These partial
chunks are considered unsafe in mode ``"r+"`` but safe in mode ``"a"``.
Note: Even with these validations it can still be unsafe to write
two or more chunked arrays to the same location in parallel if they are
not writing to independent regions; for those cases it is better to use
a synchronizer.
storage_options : dict, optional
Any additional parameters for the storage backend (ignored for local
paths).
13 changes: 10 additions & 3 deletions xarray/core/variable.py
@@ -287,10 +287,13 @@ def as_compatible_data(
if isinstance(data, DataArray):
return cast("T_DuckArray", data._variable._data)

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
def convert_non_numpy_type(data):
data = _possibly_convert_datetime_or_timedelta_index(data)
return cast("T_DuckArray", _maybe_wrap_data(data))

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
return convert_non_numpy_type(data)

if isinstance(data, tuple):
data = utils.to_0d_object_array(data)

@@ -303,7 +306,11 @@

# we don't want nested self-described arrays
if isinstance(data, pd.Series | pd.DataFrame):
data = data.values # type: ignore[assignment]
pandas_data = data.values
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
return convert_non_numpy_type(pandas_data)
else:
data = pandas_data

if isinstance(data, np.ma.MaskedArray):
mask = np.ma.getmaskarray(data)
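
A sketch of what the new branch preserves (assuming the post-merge behavior): a Series backed by an extension array no longer has its dtype destroyed by the unconditional .values coercion.

import pandas as pd
import xarray as xr

# The categorical Series is wrapped rather than flattened to an object array.
da = xr.DataArray(pd.Series(pd.Categorical(["a", "b", "a"])))
print(da.dtype)  # category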
@@ -540,7 +547,7 @@ def _dask_finalize(self, results, array_func, *args, **kwargs):
return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding)

@property
def values(self):
def values(self) -> np.ndarray:
"""The variable's data as a numpy.ndarray"""
return _as_array_or_item(self._data)
