diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bd583ac86cb..27a2da88701 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -374,7 +374,7 @@ def map( # type: ignore[override] class DataTree( - NamedNode, + NamedNode["DataTree"], MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmeticMixin, @@ -486,6 +486,35 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) check_alignment(path, node_ds, parent_ds, self.children) + def _dedup_inherited_coordinates(self): + # This method removes repeated indexes (and correpsonding coordinates) + # that are repeated between a DataTree and its parents. + # + # TODO(shoyer): Decide how to handle repeated coordinates *without* an + # index. Should these be allowed, in which case we probably want to + # exclude them from inheritance, or should they be automatically + # dropped? + # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 + removed_something = False + for name in self._indexes.parents: + if name in self._node_indexes: + # Indexes on a Dataset always have a corresponding coordinate + del self._node_indexes[name] + del self._node_coord_variables[name] + removed_something = True + + if removed_something: + self._node_dims = calculate_dimensions( + self._data_variables | self._node_coord_variables + ) + + for child in self._children.values(): + child._dedup_inherited_coordinates() + + def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._post_attach(parent, name) + self._dedup_inherited_coordinates() + @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 4e22ba57e2e..e009e7c47ab 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1190,6 +1190,22 @@ def test_inherited_coords_override(self): xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + def test_inherited_coords_with_index_are_deduplicated(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2]}), + "/b": xr.Dataset(coords={"x": [1, 2]}), + } + ) + child_dataset = dt.children["b"].to_dataset(inherited=False) + expected = xr.Dataset() + assert_identical(child_dataset, expected) + + dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]}) + child_dataset = dt.children["c"].to_dataset(inherited=False) + expected = xr.Dataset({"foo": ("x", [4, 5])}) + assert_identical(child_dataset, expected) + def test_inconsistent_dims(self): expected_msg = _exact_match( """ @@ -1534,6 +1550,18 @@ def test_binary_op_on_datatree(self): result = dt * dt # type: ignore[operator] assert_equal(result, expected) + def test_arithmetic_inherited_coords(self): + tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]})) + tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) + actual = 2 * tree # type: ignore + + actual_dataset = actual.children["foo"].to_dataset(inherited=False) + assert "x" not in actual_dataset.coords + + expected = tree.copy() + expected["/foo/bar"].data = np.array([8, 10, 12]) + assert_identical(actual, expected) + class TestUFuncs: def test_tree(self, create_test_datatree):