From e8a3c358bb9e607f7de98f7e7b47a1dda0be3c21 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 28 Sep 2024 15:12:30 -0700 Subject: [PATCH 1/7] Stop inheriting non-indexed coordinates for DataTree This is option (4) from https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 --- xarray/core/datatree.py | 8 +++++++- xarray/tests/test_datatree.py | 36 ++++++++++++++++++----------------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 52d44bec96f..6d9a5a6d8d0 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -438,6 +438,7 @@ class DataTree( _cache: dict[str, Any] # used by _CachedAccessor _data_variables: dict[Hashable, Variable] _node_coord_variables: dict[Hashable, Variable] + _node_indexed_coord_variables: dict[Hashable, Variable] _node_dims: dict[Hashable, int] _node_indexes: dict[Hashable, Index] _attrs: dict[Hashable, Any] | None @@ -451,6 +452,7 @@ class DataTree( "_cache", # used by _CachedAccessor "_data_variables", "_node_coord_variables", + "_node_indexed_coord_variables", "_node_dims", "_node_indexes", "_attrs", @@ -498,6 +500,9 @@ def _set_node_data(self, dataset: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(dataset) self._data_variables = data_vars self._node_coord_variables = coord_vars + self._node_indexed_coord_variables = { + k: coord_vars[k] for k in dataset._indexes + } self._node_dims = dataset._dims self._node_indexes = dataset._indexes self._encoding = dataset._encoding @@ -519,7 +524,8 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( - self._node_coord_variables, *(p._node_coord_variables for p in self.parents) + self._node_coord_variables, + *(p._node_indexed_coord_variables for p in self.parents), ) @property diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 30934f83c63..3365e493090 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -216,16 +216,16 @@ def test_is_hollow(self): class TestToDataset: - def test_to_dataset(self): - base = xr.Dataset(coords={"a": 1}) - sub = xr.Dataset(coords={"b": 2}) + def test_to_dataset_inherited(self): + base = xr.Dataset(coords={"a": [1], "b": 2}) + sub = xr.Dataset(coords={"c": [3]}) tree = DataTree.from_dict({"/": base, "/sub": sub}) subtree = typing.cast(DataTree, tree["sub"]) assert_identical(tree.to_dataset(inherited=False), base) assert_identical(subtree.to_dataset(inherited=False), sub) - sub_and_base = xr.Dataset(coords={"a": 1, "b": 2}) + sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]}) # no "b" assert_identical(tree.to_dataset(inherited=True), base) assert_identical(subtree.to_dataset(inherited=True), sub_and_base) @@ -714,7 +714,8 @@ def test_inherited(self): dt["child"] = DataTree() child = dt["child"] - assert set(child.coords) == {"x", "y", "a", "b"} + assert set(dt.coords) == {"x", "y", "a", "b"} + assert set(child.coords) == {"x", "y"} actual = child.copy(deep=True) actual.coords["x"] = ("x", ["a", "b"]) @@ -729,7 +730,7 @@ def test_inherited(self): with pytest.raises(KeyError): # cannot delete inherited coordinate from child node - del child["b"] + del child["x"] # TODO requires a fix for #9472 # actual = child.copy(deep=True) @@ -1278,22 +1279,23 @@ def test_inherited_coords_index(self): assert "x" in dt["/b"].coords xr.testing.assert_identical(dt["/x"], dt["/b/x"]) - def test_inherited_coords_override(self): + def test_inherit_only_index_coords(self): dt = DataTree.from_dict( { - "/": xr.Dataset(coords={"x": 1, "y": 2}), - "/b": xr.Dataset(coords={"x": 4, "z": 3}), + "/": xr.Dataset(coords={"x": [1], "y": 2}), + "/b": xr.Dataset(coords={"z": 3}), } ) assert dt.coords.keys() == {"x", "y"} - root_coords = {"x": 1, "y": 2} - sub_coords = {"x": 4, "y": 2, "z": 3} - xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords)) - xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords)) - assert dt["/b"].coords.keys() == {"x", "y", "z"} - xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords)) - xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) - xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + xr.testing.assert_equal( + dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2}) + ) + xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2})) + assert dt["/b"].coords.keys() == {"x", "z"} + xr.testing.assert_equal( + dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3}) + ) + xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3})) def test_inherited_coords_with_index_are_deduplicated(self): dt = DataTree.from_dict( From 2709120f05e20ae8c806cb6f05ba57a23dc010b7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 5 Oct 2024 16:32:09 -0700 Subject: [PATCH 2/7] alternative implementaiton --- xarray/core/dataarray.py | 7 +++---- xarray/core/dataset.py | 6 +++--- xarray/core/datatree.py | 18 ++++++++---------- xarray/core/utils.py | 17 +++++++++++------ 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2adf862f1fd..ab99044ab98 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -69,7 +69,7 @@ ) from xarray.core.utils import ( Default, - HybridMappingProxy, + FilteredMapping, ReprObject, _default, either_dict_or_kwargs, @@ -929,11 +929,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: @property def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: """Places to look-up items for key-completion""" - yield HybridMappingProxy(keys=self._coords, mapping=self.coords) + yield FilteredMapping(keys=self._coords, mapping=self.coords) # virtual coordinates - # uses empty dict -- everything here can already be found in self.coords. - yield HybridMappingProxy(keys=self.dims, mapping={}) + yield FilteredMapping(keys=self.dims, mapping=self.coords) def __contains__(self, key: Any) -> bool: return key in self.data diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 82b60d7abc8..59eb05f0b07 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -103,9 +103,9 @@ ) from xarray.core.utils import ( Default, + FilteredMapping, Frozen, FrozenMappingWarningOnValuesAccess, - HybridMappingProxy, OrderedSet, _default, decode_numpy_dict_values, @@ -1507,10 +1507,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: """Places to look-up items for key-completion""" yield self.data_vars - yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + yield FilteredMapping(keys=self._coord_names, mapping=self.coords) # virtual coordinates - yield HybridMappingProxy(keys=self.sizes, mapping=self) + yield FilteredMapping(keys=self.sizes, mapping=self) def __contains__(self, key: object) -> bool: """The 'in' operator will return true or false depending on whether diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 6d9a5a6d8d0..bebfb6324b1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -40,8 +40,8 @@ from xarray.core.treenode import NamedNode, NodePath from xarray.core.utils import ( Default, + FilteredMapping, Frozen, - HybridMappingProxy, _default, either_dict_or_kwargs, maybe_wrap_array, @@ -438,7 +438,6 @@ class DataTree( _cache: dict[str, Any] # used by _CachedAccessor _data_variables: dict[Hashable, Variable] _node_coord_variables: dict[Hashable, Variable] - _node_indexed_coord_variables: dict[Hashable, Variable] _node_dims: dict[Hashable, int] _node_indexes: dict[Hashable, Index] _attrs: dict[Hashable, Any] | None @@ -452,7 +451,6 @@ class DataTree( "_cache", # used by _CachedAccessor "_data_variables", "_node_coord_variables", - "_node_indexed_coord_variables", "_node_dims", "_node_indexes", "_attrs", @@ -500,9 +498,6 @@ def _set_node_data(self, dataset: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(dataset) self._data_variables = data_vars self._node_coord_variables = coord_vars - self._node_indexed_coord_variables = { - k: coord_vars[k] for k in dataset._indexes - } self._node_dims = dataset._dims self._node_indexes = dataset._indexes self._encoding = dataset._encoding @@ -525,7 +520,10 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( self._node_coord_variables, - *(p._node_indexed_coord_variables for p in self.parents), + *( + FilteredMapping(keys=p._node_indexes, mapping=p._node_coord_variables) + for p in self.parents + ), ) @property @@ -665,7 +663,7 @@ def variables(self) -> Mapping[Hashable, Variable]: Dataset invariants. It contains all variable objects constituting this DataTree node, including both data variables and coordinates. """ - return Frozen(self._data_variables | self._coord_variables) + return Frozen(dict(**self._data_variables, **self._coord_variables)) @property def attrs(self) -> dict[Hashable, Any]: @@ -726,10 +724,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: def _item_sources(self) -> Iterable[Mapping[Any, Any]]: """Places to look-up items for key-completion""" yield self.data_vars - yield HybridMappingProxy(keys=self._coord_variables, mapping=self.coords) + yield FilteredMapping(keys=self._coord_variables, mapping=self.coords) # virtual coordinates - yield HybridMappingProxy(keys=self.dims, mapping=self) + yield FilteredMapping(keys=self.dims, mapping=self) # immediate child nodes yield self.children diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 3c1dee7a36d..ff8bceb2b32 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -465,7 +465,7 @@ def values(self) -> ValuesView[V]: return super().values() -class HybridMappingProxy(Mapping[K, V]): +class FilteredMapping(Mapping[K, V]): """Implements the Mapping interface. Uses the wrapped mapping for item lookup and a separate wrapped keys collection for iteration. @@ -473,25 +473,30 @@ class HybridMappingProxy(Mapping[K, V]): eagerly accessing its items or when a mapping object is expected but only iteration over keys is actually used. - Note: HybridMappingProxy does not validate consistency of the provided `keys` + Note: FilteredMapping does not validate consistency of the provided `keys` and `mapping`. It is the caller's responsibility to ensure that they are suitable for the task at hand. """ - __slots__ = ("_keys", "mapping") + __slots__ = ("keys", "mapping") def __init__(self, keys: Collection[K], mapping: Mapping[K, V]): - self._keys = keys + self.keys = keys self.mapping = mapping def __getitem__(self, key: K) -> V: + if key not in self.keys: + raise KeyError(key) return self.mapping[key] def __iter__(self) -> Iterator[K]: - return iter(self._keys) + return iter(self.keys) def __len__(self) -> int: - return len(self._keys) + return len(self.keys) + + def __repr__(self) -> str: + return f"{type(self).__name__}(keys={self.keys!r}, mapping={self.mapping!r})" class OrderedSet(MutableSet[T]): From 4962ccae0e58779a0db7a24e56156754bd442654 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 5 Oct 2024 16:32:59 -0700 Subject: [PATCH 3/7] docstring --- xarray/core/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ff8bceb2b32..d3504f4400c 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -473,9 +473,10 @@ class FilteredMapping(Mapping[K, V]): eagerly accessing its items or when a mapping object is expected but only iteration over keys is actually used. - Note: FilteredMapping does not validate consistency of the provided `keys` - and `mapping`. It is the caller's responsibility to ensure that they are - suitable for the task at hand. + Note: keys should be a subset of mapping, but FilteredMapping does not + validate consistency of the provided `keys` and `mapping`. It is the + caller's responsibility to ensure that they are suitable for the task at + hand. """ __slots__ = ("keys", "mapping") From 8ee900146a2fa0ebbcdbc25ae99037fe06b0289a Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 5 Oct 2024 17:50:42 -0700 Subject: [PATCH 4/7] fix type error --- xarray/core/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index d3504f4400c..e5168342e1e 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -479,25 +479,25 @@ class FilteredMapping(Mapping[K, V]): hand. """ - __slots__ = ("keys", "mapping") + __slots__ = ("keys_", "mapping") def __init__(self, keys: Collection[K], mapping: Mapping[K, V]): - self.keys = keys + self.keys_ = keys # .keys is already a property on Mapping self.mapping = mapping def __getitem__(self, key: K) -> V: - if key not in self.keys: + if key not in self.keys_: raise KeyError(key) return self.mapping[key] def __iter__(self) -> Iterator[K]: - return iter(self.keys) + return iter(self.keys_) def __len__(self) -> int: - return len(self.keys) + return len(self.keys_) def __repr__(self) -> str: - return f"{type(self).__name__}(keys={self.keys!r}, mapping={self.mapping!r})" + return f"{type(self).__name__}(keys={self.keys_!r}, mapping={self.mapping!r})" class OrderedSet(MutableSet[T]): From d5f9cb3c84c87b7110747359435490d3f9cc2ea2 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 5 Oct 2024 17:59:52 -0700 Subject: [PATCH 5/7] refactor --- xarray/core/datatree.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bebfb6324b1..bff5c2f8d7f 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -516,14 +516,17 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: check_alignment(path, node_ds, parent_ds, self.children) _deduplicate_inherited_coordinates(self, parent) + @property + def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]: + return FilteredMapping( + keys=self._node_indexes, mapping=self._node_coord_variables + ) + @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( self._node_coord_variables, - *( - FilteredMapping(keys=p._node_indexes, mapping=p._node_coord_variables) - for p in self.parents - ), + *(p._node_coord_variables_with_index for p in self.parents), ) @property From 568a97252c1d475865b02dbdd2a5c38735bcb8b3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 5 Oct 2024 18:55:55 -0700 Subject: [PATCH 6/7] revert variables --- xarray/core/datatree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bff5c2f8d7f..82769fcdda1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -666,7 +666,7 @@ def variables(self) -> Mapping[Hashable, Variable]: Dataset invariants. It contains all variable objects constituting this DataTree node, including both data variables and coordinates. """ - return Frozen(dict(**self._data_variables, **self._coord_variables)) + return Frozen(self._data_variables | self._coord_variables) @property def attrs(self) -> dict[Hashable, Any]: From 3bb6a47dd88340da5bd11e4075650e8e4954175e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 5 Oct 2024 19:01:56 -0700 Subject: [PATCH 7/7] tests for FilteredMapping --- xarray/tests/test_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 86e34d151a8..9ef4a688302 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -139,6 +139,16 @@ def test_frozen(self): "Frozen({'b': 'B', 'a': 'A'})", ) + def test_filtered(self): + x = utils.FilteredMapping(keys={"a"}, mapping={"a": 1, "b": 2}) + assert "a" in x + assert "b" not in x + assert x["a"] == 1 + assert list(x) == ["a"] + assert len(x) == 1 + assert repr(x) == "FilteredMapping(keys={'a'}, mapping={'a': 1, 'b': 2})" + assert dict(x) == {"a": 1} + def test_repr_object(): obj = utils.ReprObject("foo")