diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 52d44bec96..6d9a5a6d8d 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -438,6 +438,7 @@ class DataTree( _cache: dict[str, Any] # used by _CachedAccessor _data_variables: dict[Hashable, Variable] _node_coord_variables: dict[Hashable, Variable] + _node_indexed_coord_variables: dict[Hashable, Variable] _node_dims: dict[Hashable, int] _node_indexes: dict[Hashable, Index] _attrs: dict[Hashable, Any] | None @@ -451,6 +452,7 @@ class DataTree( "_cache", # used by _CachedAccessor "_data_variables", "_node_coord_variables", + "_node_indexed_coord_variables", "_node_dims", "_node_indexes", "_attrs", @@ -498,6 +500,9 @@ def _set_node_data(self, dataset: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(dataset) self._data_variables = data_vars self._node_coord_variables = coord_vars + self._node_indexed_coord_variables = { + k: coord_vars[k] for k in dataset._indexes + } self._node_dims = dataset._dims self._node_indexes = dataset._indexes self._encoding = dataset._encoding @@ -519,7 +524,8 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( - self._node_coord_variables, *(p._node_coord_variables for p in self.parents) + self._node_coord_variables, + *(p._node_indexed_coord_variables for p in self.parents), ) @property diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 30934f83c6..3365e49309 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -216,16 +216,16 @@ def test_is_hollow(self): class TestToDataset: - def test_to_dataset(self): - base = xr.Dataset(coords={"a": 1}) - sub = xr.Dataset(coords={"b": 2}) + def test_to_dataset_inherited(self): + base = xr.Dataset(coords={"a": [1], "b": 2}) + sub = xr.Dataset(coords={"c": [3]}) tree = DataTree.from_dict({"/": base, "/sub": sub}) subtree = typing.cast(DataTree, tree["sub"]) assert_identical(tree.to_dataset(inherited=False), base) assert_identical(subtree.to_dataset(inherited=False), sub) - sub_and_base = xr.Dataset(coords={"a": 1, "b": 2}) + sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]}) # no "b" assert_identical(tree.to_dataset(inherited=True), base) assert_identical(subtree.to_dataset(inherited=True), sub_and_base) @@ -714,7 +714,8 @@ def test_inherited(self): dt["child"] = DataTree() child = dt["child"] - assert set(child.coords) == {"x", "y", "a", "b"} + assert set(dt.coords) == {"x", "y", "a", "b"} + assert set(child.coords) == {"x", "y"} actual = child.copy(deep=True) actual.coords["x"] = ("x", ["a", "b"]) @@ -729,7 +730,7 @@ def test_inherited(self): with pytest.raises(KeyError): # cannot delete inherited coordinate from child node - del child["b"] + del child["x"] # TODO requires a fix for #9472 # actual = child.copy(deep=True) @@ -1278,22 +1279,23 @@ def test_inherited_coords_index(self): assert "x" in dt["/b"].coords xr.testing.assert_identical(dt["/x"], dt["/b/x"]) - def test_inherited_coords_override(self): + def test_inherit_only_index_coords(self): dt = DataTree.from_dict( { - "/": xr.Dataset(coords={"x": 1, "y": 2}), - "/b": xr.Dataset(coords={"x": 4, "z": 3}), + "/": xr.Dataset(coords={"x": [1], "y": 2}), + "/b": xr.Dataset(coords={"z": 3}), } ) assert dt.coords.keys() == {"x", "y"} - root_coords = {"x": 1, "y": 2} - sub_coords = {"x": 4, "y": 2, "z": 3} - xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords)) - xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords)) - assert dt["/b"].coords.keys() == {"x", "y", "z"} - xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords)) - xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) - xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + xr.testing.assert_equal( + dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2}) + ) + xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2})) + assert dt["/b"].coords.keys() == {"x", "z"} + xr.testing.assert_equal( + dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3}) + ) + xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3})) def test_inherited_coords_with_index_are_deduplicated(self): dt = DataTree.from_dict(