Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Stop inheriting non-indexed coordinates for DataTree #9555

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ class DataTree(
_cache: dict[str, Any] # used by _CachedAccessor
_data_variables: dict[Hashable, Variable]
_node_coord_variables: dict[Hashable, Variable]
_node_indexed_coord_variables: dict[Hashable, Variable]
_node_dims: dict[Hashable, int]
_node_indexes: dict[Hashable, Index]
_attrs: dict[Hashable, Any] | None
Expand All @@ -451,6 +452,7 @@ class DataTree(
"_cache", # used by _CachedAccessor
"_data_variables",
"_node_coord_variables",
"_node_indexed_coord_variables",
"_node_dims",
"_node_indexes",
"_attrs",
Expand Down Expand Up @@ -498,6 +500,9 @@ def _set_node_data(self, dataset: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(dataset)
self._data_variables = data_vars
self._node_coord_variables = coord_vars
self._node_indexed_coord_variables = {
k: coord_vars[k] for k in dataset._indexes
}
self._node_dims = dataset._dims
self._node_indexes = dataset._indexes
self._encoding = dataset._encoding
Expand All @@ -519,7 +524,8 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
self._node_coord_variables, *(p._node_coord_variables for p in self.parents)
self._node_coord_variables,
*(p._node_indexed_coord_variables for p in self.parents),
)

@property
Expand Down
36 changes: 19 additions & 17 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,16 @@ def test_is_hollow(self):


class TestToDataset:
def test_to_dataset(self):
base = xr.Dataset(coords={"a": 1})
sub = xr.Dataset(coords={"b": 2})
def test_to_dataset_inherited(self):
base = xr.Dataset(coords={"a": [1], "b": 2})
sub = xr.Dataset(coords={"c": [3]})
tree = DataTree.from_dict({"/": base, "/sub": sub})
subtree = typing.cast(DataTree, tree["sub"])

assert_identical(tree.to_dataset(inherited=False), base)
assert_identical(subtree.to_dataset(inherited=False), sub)

sub_and_base = xr.Dataset(coords={"a": 1, "b": 2})
sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]}) # no "b"
assert_identical(tree.to_dataset(inherited=True), base)
assert_identical(subtree.to_dataset(inherited=True), sub_and_base)

Expand Down Expand Up @@ -714,7 +714,8 @@ def test_inherited(self):
dt["child"] = DataTree()
child = dt["child"]

assert set(child.coords) == {"x", "y", "a", "b"}
assert set(dt.coords) == {"x", "y", "a", "b"}
assert set(child.coords) == {"x", "y"}

actual = child.copy(deep=True)
actual.coords["x"] = ("x", ["a", "b"])
Expand All @@ -729,7 +730,7 @@ def test_inherited(self):

with pytest.raises(KeyError):
# cannot delete inherited coordinate from child node
del child["b"]
del child["x"]

# TODO requires a fix for #9472
# actual = child.copy(deep=True)
Expand Down Expand Up @@ -1278,22 +1279,23 @@ def test_inherited_coords_index(self):
assert "x" in dt["/b"].coords
xr.testing.assert_identical(dt["/x"], dt["/b/x"])

def test_inherited_coords_override(self):
def test_inherit_only_index_coords(self):
dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 1, "y": 2}),
"/b": xr.Dataset(coords={"x": 4, "z": 3}),
"/": xr.Dataset(coords={"x": [1], "y": 2}),
"/b": xr.Dataset(coords={"z": 3}),
}
)
assert dt.coords.keys() == {"x", "y"}
root_coords = {"x": 1, "y": 2}
sub_coords = {"x": 4, "y": 2, "z": 3}
xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords))
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords))
assert dt["/b"].coords.keys() == {"x", "y", "z"}
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords))
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))
xr.testing.assert_equal(
dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2})
)
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2}))
assert dt["/b"].coords.keys() == {"x", "z"}
xr.testing.assert_equal(
dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3})
)
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3}))

def test_inherited_coords_with_index_are_deduplicated(self):
dt = DataTree.from_dict(
Expand Down
Loading