diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bd583ac86cb..ff498a77480 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -156,6 +156,34 @@ def check_alignment( check_alignment(child_path, child_ds, base_ds, child.children) +def _deduplicate_inherited_coordinates(child: DataTree, parent: DataTree) -> None: + # This method removes repeated indexes (and corresponding coordinates) + # that are repeated between a DataTree and its parents. + # + # TODO(shoyer): Decide how to handle repeated coordinates *without* an + # index. Should these be allowed, in which case we probably want to + # exclude them from inheritance, or should they be automatically + # dropped? + # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 + removed_something = False + for name in parent._indexes: + if name in child._node_indexes: + # Indexes on a Dataset always have a corresponding coordinate. + # We already verified that these coordinates match in the + # check_alignment() call from _pre_attach(). + del child._node_indexes[name] + del child._node_coord_variables[name] + removed_something = True + + if removed_something: + child._node_dims = calculate_dimensions( + child._data_variables | child._node_coord_variables + ) + + for grandchild in child._children.values(): + _deduplicate_inherited_coordinates(grandchild, child) + + def _check_for_slashes_in_names(variables: Iterable[Hashable]) -> None: offending_variable_names = [ name for name in variables if isinstance(name, str) and "/" in name @@ -374,7 +402,7 @@ def map( # type: ignore[override] class DataTree( - NamedNode, + NamedNode["DataTree"], MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmeticMixin, @@ -485,6 +513,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: node_ds = self.to_dataset(inherited=False) parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) check_alignment(path, node_ds, parent_ds, self.children) + _deduplicate_inherited_coordinates(self, parent) @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: @@ -1350,7 +1379,7 @@ def map_over_subtree( func: Callable, *args: Iterable[Any], **kwargs: Any, - ) -> DataTree | tuple[DataTree]: + ) -> DataTree | tuple[DataTree, ...]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 4e22ba57e2e..9f51f5fd99f 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1190,6 +1190,22 @@ def test_inherited_coords_override(self): xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + def test_inherited_coords_with_index_are_deduplicated(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2]}), + "/b": xr.Dataset(coords={"x": [1, 2]}), + } + ) + child_dataset = dt.children["b"].to_dataset(inherited=False) + expected = xr.Dataset() + assert_identical(child_dataset, expected) + + dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]}) + child_dataset = dt.children["c"].to_dataset(inherited=False) + expected = xr.Dataset({"foo": ("x", [4, 5])}) + assert_identical(child_dataset, expected) + def test_inconsistent_dims(self): expected_msg = _exact_match( """ @@ -1534,6 +1550,18 @@ def test_binary_op_on_datatree(self): result = dt * dt # type: ignore[operator] assert_equal(result, expected) + def test_arithmetic_inherited_coords(self): + tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]})) + tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) + actual: DataTree = 2 * tree # type: ignore[assignment,operator] + + actual_dataset = actual.children["foo"].to_dataset(inherited=False) + assert "x" not in actual_dataset.coords + + expected = tree.copy() + expected["/foo/bar"].data = np.array([8, 10, 12]) + assert_identical(actual, expected) + class TestUFuncs: def test_tree(self, create_test_datatree): diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index c7e0e93b89b..9a7d3009c3b 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -7,7 +7,7 @@ check_isomorphic, map_over_subtree, ) -from xarray.testing import assert_equal +from xarray.testing import assert_equal, assert_identical empty = xr.Dataset() @@ -306,6 +306,17 @@ def fail_on_specific_node(ds): ): dt.map_over_subtree(fail_on_specific_node) + def test_inherited_coordinates_with_index(self): + root = xr.Dataset(coords={"x": [1, 2]}) + child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates + tree = xr.DataTree.from_dict({"/": root, "/child": child}) + actual = tree.map_over_subtree(lambda ds: ds) # identity + assert isinstance(actual, xr.DataTree) + assert_identical(tree, actual) + + actual_child = actual.children["child"].to_dataset(inherited=False) + assert_identical(actual_child, child) + class TestMutableOperations: def test_construct_using_type(self):