Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop inheriting non-indexed coordinates for DataTree #9555

Merged
merged 8 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
)
from xarray.core.utils import (
Default,
HybridMappingProxy,
FilteredMapping,
ReprObject,
_default,
either_dict_or_kwargs,
Expand Down Expand Up @@ -929,11 +929,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
# Generator property: yields one mapping per source of item names -- first the
# coordinates keyed by self._coords, then the dimension names in self.dims.
# NOTE(review): each HybridMappingProxy line is directly followed by an
# equivalent FilteredMapping line; this reads as a before/after diff pair
# rather than two live yields -- confirm against the actual file.
@property
def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for key-completion"""
yield HybridMappingProxy(keys=self._coords, mapping=self.coords)
yield FilteredMapping(keys=self._coords, mapping=self.coords)

# virtual coordinates
# uses empty dict -- everything here can already be found in self.coords.
yield HybridMappingProxy(keys=self.dims, mapping={})
yield FilteredMapping(keys=self.dims, mapping=self.coords)

def __contains__(self, key: Any) -> bool:
return key in self.data
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@
)
from xarray.core.utils import (
Default,
FilteredMapping,
Frozen,
FrozenMappingWarningOnValuesAccess,
HybridMappingProxy,
OrderedSet,
_default,
decode_numpy_dict_values,
Expand Down Expand Up @@ -1507,10 +1507,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for key-completion"""
yield self.data_vars
yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords)
yield FilteredMapping(keys=self._coord_names, mapping=self.coords)

# virtual coordinates
yield HybridMappingProxy(keys=self.sizes, mapping=self)
yield FilteredMapping(keys=self.sizes, mapping=self)

def __contains__(self, key: object) -> bool:
"""The 'in' operator will return true or false depending on whether
Expand Down
15 changes: 11 additions & 4 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from xarray.core.treenode import NamedNode, NodePath
from xarray.core.utils import (
Default,
FilteredMapping,
Frozen,
HybridMappingProxy,
_default,
either_dict_or_kwargs,
maybe_wrap_array,
Expand Down Expand Up @@ -516,10 +516,17 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
check_alignment(path, node_ds, parent_ds, self.children)
_deduplicate_inherited_coordinates(self, parent)

# View of this node's own coordinate variables restricted to the names present
# in self._node_indexes. _coord_variables reads this property on every parent,
# so only these filtered entries are chained in from parent nodes.
@property
def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]:
return FilteredMapping(
keys=self._node_indexes, mapping=self._node_coord_variables
)

@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
self._node_coord_variables, *(p._node_coord_variables for p in self.parents)
self._node_coord_variables,
*(p._node_coord_variables_with_index for p in self.parents),
Comment on lines +519 to +529
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a much neater implementation, nice!!

)

@property
Expand Down Expand Up @@ -720,10 +727,10 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
def _item_sources(self) -> Iterable[Mapping[Any, Any]]:
"""Places to look-up items for key-completion"""
yield self.data_vars
yield HybridMappingProxy(keys=self._coord_variables, mapping=self.coords)
yield FilteredMapping(keys=self._coord_variables, mapping=self.coords)

# virtual coordinates
yield HybridMappingProxy(keys=self.dims, mapping=self)
yield FilteredMapping(keys=self.dims, mapping=self)

# immediate child nodes
yield self.children
Expand Down
22 changes: 14 additions & 8 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,33 +465,39 @@ def values(self) -> ValuesView[V]:
return super().values()


# NOTE(review): this span shows the HybridMappingProxy -> FilteredMapping rename
# as interleaved old/new lines (class header, docstring note, __slots__,
# attribute name, __iter__/__len__ bodies) -- confirm which lines are live
# before editing.
class HybridMappingProxy(Mapping[K, V]):
class FilteredMapping(Mapping[K, V]):
"""Implements the Mapping interface. Uses the wrapped mapping for item lookup
and a separate wrapped keys collection for iteration.

Can be used to construct a mapping object from another dict-like object without
eagerly accessing its items or when a mapping object is expected but only
iteration over keys is actually used.

Note: HybridMappingProxy does not validate consistency of the provided `keys`
and `mapping`. It is the caller's responsibility to ensure that they are
suitable for the task at hand.
Note: keys should be a subset of mapping, but FilteredMapping does not
validate consistency of the provided `keys` and `mapping`. It is the
caller's responsibility to ensure that they are suitable for the task at
hand.
"""

__slots__ = ("_keys", "mapping")
__slots__ = ("keys_", "mapping")

def __init__(self, keys: Collection[K], mapping: Mapping[K, V]):
self._keys = keys
self.keys_ = keys # .keys is already a property on Mapping
self.mapping = mapping

def __getitem__(self, key: K) -> V:
# Gate lookups on the key collection: a key missing from keys_ raises
# KeyError even if the wrapped mapping contains it.
if key not in self.keys_:
raise KeyError(key)
return self.mapping[key]

# Iteration and length are driven by the keys collection only; the wrapped
# mapping is not consulted here.
def __iter__(self) -> Iterator[K]:
return iter(self._keys)
return iter(self.keys_)

def __len__(self) -> int:
return len(self._keys)
return len(self.keys_)

# Shows both wrapped objects, prefixed with the runtime class name.
def __repr__(self) -> str:
return f"{type(self).__name__}(keys={self.keys_!r}, mapping={self.mapping!r})"


class OrderedSet(MutableSet[T]):
Expand Down
36 changes: 19 additions & 17 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,16 @@ def test_is_hollow(self):


class TestToDataset:
def test_to_dataset(self):
base = xr.Dataset(coords={"a": 1})
sub = xr.Dataset(coords={"b": 2})
def test_to_dataset_inherited(self):
base = xr.Dataset(coords={"a": [1], "b": 2})
sub = xr.Dataset(coords={"c": [3]})
tree = DataTree.from_dict({"/": base, "/sub": sub})
subtree = typing.cast(DataTree, tree["sub"])

assert_identical(tree.to_dataset(inherited=False), base)
assert_identical(subtree.to_dataset(inherited=False), sub)

sub_and_base = xr.Dataset(coords={"a": 1, "b": 2})
sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]}) # no "b"
assert_identical(tree.to_dataset(inherited=True), base)
assert_identical(subtree.to_dataset(inherited=True), sub_and_base)

Expand Down Expand Up @@ -714,7 +714,8 @@ def test_inherited(self):
dt["child"] = DataTree()
child = dt["child"]

assert set(child.coords) == {"x", "y", "a", "b"}
assert set(dt.coords) == {"x", "y", "a", "b"}
assert set(child.coords) == {"x", "y"}

actual = child.copy(deep=True)
actual.coords["x"] = ("x", ["a", "b"])
Expand All @@ -729,7 +730,7 @@ def test_inherited(self):

with pytest.raises(KeyError):
# cannot delete inherited coordinate from child node
del child["b"]
del child["x"]

# TODO requires a fix for #9472
# actual = child.copy(deep=True)
Expand Down Expand Up @@ -1278,22 +1279,23 @@ def test_inherited_coords_index(self):
assert "x" in dt["/b"].coords
xr.testing.assert_identical(dt["/x"], dt["/b/x"])

def test_inherited_coords_override(self):
def test_inherit_only_index_coords(self):
dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 1, "y": 2}),
"/b": xr.Dataset(coords={"x": 4, "z": 3}),
"/": xr.Dataset(coords={"x": [1], "y": 2}),
"/b": xr.Dataset(coords={"z": 3}),
Comment on lines +1285 to +1286
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we're planning to eventually change the API of the Dataset constructor to make it more explicit which coordinates have indexes, it might be nice to at least add comments here to describe the intention.

(sort of related to #8959)

}
)
assert dt.coords.keys() == {"x", "y"}
root_coords = {"x": 1, "y": 2}
sub_coords = {"x": 4, "y": 2, "z": 3}
xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords))
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords))
assert dt["/b"].coords.keys() == {"x", "y", "z"}
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords))
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))
xr.testing.assert_equal(
dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2})
)
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2}))
assert dt["/b"].coords.keys() == {"x", "z"}
xr.testing.assert_equal(
dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3})
)
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3}))

def test_inherited_coords_with_index_are_deduplicated(self):
dt = DataTree.from_dict(
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ def test_frozen(self):
"Frozen({'b': 'B', 'a': 'A'})",
)

# Checks FilteredMapping with keys={"a"} over a two-entry dict: "b" is hidden
# from membership, iteration, len() and dict() conversion, while "a" still
# resolves through the wrapped mapping; also pins the exact repr format.
def test_filtered(self):
x = utils.FilteredMapping(keys={"a"}, mapping={"a": 1, "b": 2})
assert "a" in x
assert "b" not in x
assert x["a"] == 1
assert list(x) == ["a"]
assert len(x) == 1
assert repr(x) == "FilteredMapping(keys={'a'}, mapping={'a': 1, 'b': 2})"
assert dict(x) == {"a": 1}


def test_repr_object():
obj = utils.ReprObject("foo")
Expand Down
Loading