diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 27a2da88701..ff498a77480 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -156,6 +156,34 @@ def check_alignment( check_alignment(child_path, child_ds, base_ds, child.children) +def _deduplicate_inherited_coordinates(child: DataTree, parent: DataTree) -> None: + # This method removes repeated indexes (and corresponding coordinates) + # that are repeated between a DataTree and its parents. + # + # TODO(shoyer): Decide how to handle repeated coordinates *without* an + # index. Should these be allowed, in which case we probably want to + # exclude them from inheritance, or should they be automatically + # dropped? + # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 + removed_something = False + for name in parent._indexes: + if name in child._node_indexes: + # Indexes on a Dataset always have a corresponding coordinate. + # We already verified that these coordinates match in the + # check_alignment() call from _pre_attach(). + del child._node_indexes[name] + del child._node_coord_variables[name] + removed_something = True + + if removed_something: + child._node_dims = calculate_dimensions( + child._data_variables | child._node_coord_variables + ) + + for grandchild in child._children.values(): + _deduplicate_inherited_coordinates(grandchild, child) + + def _check_for_slashes_in_names(variables: Iterable[Hashable]) -> None: offending_variable_names = [ name for name in variables if isinstance(name, str) and "/" in name @@ -485,35 +513,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: node_ds = self.to_dataset(inherited=False) parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) check_alignment(path, node_ds, parent_ds, self.children) - - def _dedup_inherited_coordinates(self): - # This method removes repeated indexes (and correpsonding coordinates) - # that are repeated between a DataTree and its parents. - # - # TODO(shoyer): Decide how to handle repeated coordinates *without* an - # index. Should these be allowed, in which case we probably want to - # exclude them from inheritance, or should they be automatically - # dropped? - # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 - removed_something = False - for name in self._indexes.parents: - if name in self._node_indexes: - # Indexes on a Dataset always have a corresponding coordinate - del self._node_indexes[name] - del self._node_coord_variables[name] - removed_something = True - - if removed_something: - self._node_dims = calculate_dimensions( - self._data_variables | self._node_coord_variables - ) - - for child in self._children.values(): - child._dedup_inherited_coordinates() - - def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: - super()._post_attach(parent, name) - self._dedup_inherited_coordinates() + _deduplicate_inherited_coordinates(self, parent) @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: @@ -1379,7 +1379,7 @@ def map_over_subtree( func: Callable, *args: Iterable[Any], **kwargs: Any, - ) -> DataTree | tuple[DataTree]: + ) -> DataTree | tuple[DataTree, ...]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index c7e0e93b89b..9a7d3009c3b 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -7,7 +7,7 @@ check_isomorphic, map_over_subtree, ) -from xarray.testing import assert_equal +from xarray.testing import assert_equal, assert_identical empty = xr.Dataset() @@ -306,6 +306,17 @@ def fail_on_specific_node(ds): ): dt.map_over_subtree(fail_on_specific_node) + def test_inherited_coordinates_with_index(self): + root = xr.Dataset(coords={"x": [1, 2]}) + child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates + tree = xr.DataTree.from_dict({"/": root, "/child": child}) + actual = tree.map_over_subtree(lambda ds: ds) # identity + assert isinstance(actual, xr.DataTree) + assert_identical(tree, actual) + + actual_child = actual.children["child"].to_dataset(inherited=False) + assert_identical(actual_child, child) + class TestMutableOperations: def test_construct_using_type(self):