Skip to content

Commit

Permalink
Merge branch 'main' into force-scalars-to-ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
keewis authored Sep 3, 2024
2 parents 2fa2a2c + a8c9896 commit 42f7e8c
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pypi-release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ jobs:
path: dist
- name: Publish package to TestPyPI
if: github.event_name == 'push'
uses: pypa/gh-action-pypi-publish@v1.9.0
uses: pypa/gh-action-pypi-publish@v1.10.0
with:
repository_url: https://test.pypi.org/legacy/
verbose: true
Expand All @@ -111,6 +111,6 @@ jobs:
name: releases
path: dist
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@v1.9.0
uses: pypa/gh-action-pypi-publish@v1.10.0
with:
verbose: true
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: mixed-line-ending
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: 'v0.6.2'
rev: 'v0.6.3'
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
Expand Down
25 changes: 21 additions & 4 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def check_name(name: Hashable):
check_name(k)


def _validate_attrs(dataset, invalid_netcdf=False):
def _validate_attrs(dataset, engine, invalid_netcdf=False):
"""`attrs` must have a string key and a value which is either: a number,
a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_.
Expand All @@ -177,8 +177,8 @@ def _validate_attrs(dataset, invalid_netcdf=False):
`invalid_netcdf=True`.
"""

valid_types = (str, Number, np.ndarray, np.number, list, tuple)
if invalid_netcdf:
valid_types = (str, Number, np.ndarray, np.number, list, tuple, bytes)
if invalid_netcdf and engine == "h5netcdf":
valid_types += (np.bool_,)

def check_attr(name, value, valid_types):
Expand All @@ -202,6 +202,23 @@ def check_attr(name, value, valid_types):
f"{', '.join([vtype.__name__ for vtype in valid_types])}"
)

if isinstance(value, bytes) and engine == "h5netcdf":
try:
value.decode("utf-8")
except UnicodeDecodeError as e:
raise ValueError(
f"Invalid value provided for attribute '{name!r}': {value!r}. "
"Only binary data derived from UTF-8 encoded strings is allowed "
f"for the '{engine}' engine. Consider using the 'netcdf4' engine."
) from e

if b"\x00" in value:
raise ValueError(
f"Invalid value provided for attribute '{name!r}': {value!r}. "
f"Null characters are not permitted for the '{engine}' engine. "
"Consider using the 'netcdf4' engine."
)

# Check attrs on the dataset itself
for k, v in dataset.attrs.items():
check_attr(k, v, valid_types)
Expand Down Expand Up @@ -1353,7 +1370,7 @@ def to_netcdf(

# validate Dataset keys, DataArray names, and attr keys/values
_validate_dataset_names(dataset)
_validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf")
_validate_attrs(dataset, engine, invalid_netcdf)

try:
store_open = WRITEABLE_STORES[engine]
Expand Down
16 changes: 13 additions & 3 deletions xarray/core/accessor_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
is_np_timedelta_like,
)
from xarray.core.types import T_DataArray
from xarray.core.variable import IndexVariable
from xarray.core.variable import IndexVariable, Variable
from xarray.namedarray.utils import is_duck_dask_array

if TYPE_CHECKING:
Expand Down Expand Up @@ -244,12 +244,22 @@ def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray:
if dtype is None:
dtype = self._obj.dtype
result = _get_date_field(_index_or_data(self._obj), name, dtype)
newvar = self._obj.variable.copy(data=result, deep=False)
newvar = Variable(
dims=self._obj.dims,
attrs=self._obj.attrs,
encoding=self._obj.encoding,
data=result,
)
return self._obj._replace(newvar, name=name)

def _tslib_round_accessor(self, name: str, freq: str) -> T_DataArray:
    """Apply the named tslib rounding operation (``floor``/``ceil``/``round``)
    at frequency *freq* to the wrapped object's datetime data.

    The result wraps a freshly constructed ``Variable`` (rather than a copy of
    ``self._obj.variable``) — presumably so the returned variable is a base
    ``Variable`` even when the source was an ``IndexVariable``; see the
    related ``test_field_access`` assertion. TODO confirm.
    """
    rounded = _round_field(_index_or_data(self._obj), name, freq)
    replacement = Variable(
        data=rounded,
        dims=self._obj.dims,
        attrs=self._obj.attrs,
        encoding=self._obj.encoding,
    )
    return self._obj._replace(replacement, name=name)

def floor(self, freq: str) -> T_DataArray:
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,7 +1491,7 @@ def _unstack_once(
dim: Hashable,
fill_value=dtypes.NA,
sparse: bool = False,
) -> Self:
) -> Variable:
"""
Unstacks this variable given an index to unstack and the name of the
dimension to which the index refers.
Expand Down Expand Up @@ -1551,10 +1551,10 @@ def _unstack_once(
# case the destinations will be NaN / zero.
data[(..., *indexer)] = reordered

return self._replace(dims=new_dims, data=data)
return self.to_base_variable()._replace(dims=new_dims, data=data)

@partial(deprecate_dims, old_name="dimensions")
def unstack(self, dim=None, **dim_kwargs):
def unstack(self, dim=None, **dim_kwargs) -> Variable:
"""
Unstack an existing dimension into multiple new dimensions.
Expand Down
4 changes: 3 additions & 1 deletion xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def assert_identical(a, b, from_root=True):
if isinstance(a, Variable):
assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
elif isinstance(a, DataArray):
assert a.name == b.name
assert (
a.name == b.name
), f"DataArray names are different. L: {a.name}, R: {b.name}"
assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
elif isinstance(a, Dataset | Variable):
assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical")
Expand Down
20 changes: 20 additions & 0 deletions xarray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,26 @@ def d(request, backend, type) -> DataArray | Dataset:
raise ValueError


@pytest.fixture
def byte_attrs_dataset():
    """Fixture for testing issue #9407: a Dataset carrying ``bytes``-valued
    attrs, the Dataset expected after a write/read roundtrip, and the error
    pattern the h5netcdf engine should raise for the null-byte attribute."""
    null_byte = b"\x00"
    other_bytes = bytes(range(1, 256))

    ds = Dataset({"x": 1}, coords={"x_coord": [1]})
    ds["x"].attrs.update(null_byte=null_byte, other_bytes=other_bytes)

    # Roundtripping is expected to stringify the bytes attrs: the null byte
    # becomes an empty string, arbitrary bytes decode with replacement chars.
    expected = ds.copy()
    expected["x"].attrs.update(
        null_byte="",
        other_bytes=other_bytes.decode(errors="replace"),
    )

    return {
        "input": ds,
        "expected": expected,
        "h5netcdf_error": r"Invalid value provided for attribute .*: .*\. Null characters .*",
    }


@pytest.fixture(scope="module")
def create_test_datatree():
"""
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/test_accessor_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test_field_access(self, field) -> None:
actual = getattr(self.data.time.dt, field)
else:
actual = getattr(self.data.time.dt, field)
assert not isinstance(actual.variable, xr.IndexVariable)

assert expected.dtype == actual.dtype
assert_identical(expected, actual)
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,13 @@ def test_refresh_from_disk(self) -> None:
a.close()
b.close()

def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
    """Roundtrip a Dataset with bytes-valued attrs; regression test for
    issue #9407. Subclasses for engines that reject such attrs override
    this to assert the expected error instead."""
    # Renamed from ``input`` to avoid shadowing the builtin.
    ds_in = byte_attrs_dataset["input"]
    expected = byte_attrs_dataset["expected"]
    with self.roundtrip(ds_in) as actual:
        assert_identical(actual, expected)


_counter = itertools.count()

Expand Down Expand Up @@ -3861,6 +3868,10 @@ def test_decode_utf8_warning(self) -> None:
assert ds.title == title
assert "attribute 'title' of h5netcdf object '/'" in str(w[0].message)

def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
    # The h5netcdf engine rejects attribute values containing null bytes,
    # so the roundtrip performed by the base-class test is expected to
    # raise rather than succeed; the fixture supplies the error pattern.
    with pytest.raises(ValueError, match=byte_attrs_dataset["h5netcdf_error"]):
        super().test_byte_attrs(byte_attrs_dataset)


@requires_h5netcdf
@requires_netCDF4
Expand Down
14 changes: 14 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7202,3 +7202,17 @@ def test_lazy_data_variable_not_loaded():
da = xr.DataArray(v)
# No data needs to be accessed, so no error should be raised
xr.DataArray(da)


def test_unstack_index_var() -> None:
    """Unstack a DataArray built from an index coordinate after setting a
    multi-index on it, and check the result against an explicit expected
    DataArray (missing combinations fill with NaN)."""
    base = xr.DataArray(range(2), dims=["x"], coords=[["a", "b"]])
    da = (
        base.x.assign_coords(y=("x", ["c", "d"]), z=("x", ["e", "f"]))
        .set_index(x=["y", "z"])
    )

    expected = xr.DataArray(
        np.array([["a", np.nan], [np.nan, "b"]], dtype=object),
        coords={"y": ["c", "d"], "z": ["e", "f"]},
        name="x",
    )
    assert_identical(da.unstack("x"), expected)
2 changes: 1 addition & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ def test_chunk_by_frequency(self, freq: str, calendar: str, add_gap: bool) -> No
ΔN = 28
time = xr.date_range(
"2001-01-01", periods=N + ΔN, freq="D", calendar=calendar
).to_numpy()
).to_numpy(copy=True)
if add_gap:
# introduce an empty bin
time[31 : 31 + ΔN] = np.datetime64("NaT")
Expand Down

0 comments on commit 42f7e8c

Please sign in to comment.