From 0541f289329430abcb0870d9ceacc66c34fdc53f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 22 Sep 2024 13:18:24 -0700 Subject: [PATCH] Remove duplicate coordinates with indexes on sub-trees This is a _partial_ solution to the duplicate coordinates issue from #9475. Here we remove all duplicate coordinates between parent and child nodes with an index (these are already checked for equality via the alignment check). Other repeated coordinates (which we cannot automatically check for equality) are still allowed for now. We will need an alternative solution for these, as discussed in #9475, but it is less obvious what the right solution is so I'm holding off on it for now. --- xarray/core/datatree.py | 31 ++++++++++++++++++++++++++++++- xarray/tests/test_datatree.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bd583ac86cb..27a2da88701 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -374,7 +374,7 @@ def map( # type: ignore[override] class DataTree( - NamedNode, + NamedNode["DataTree"], MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmeticMixin, @@ -486,6 +486,35 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) check_alignment(path, node_ds, parent_ds, self.children) + def _dedup_inherited_coordinates(self): + # This method removes repeated indexes (and correpsonding coordinates) + # that are repeated between a DataTree and its parents. + # + # TODO(shoyer): Decide how to handle repeated coordinates *without* an + # index. Should these be allowed, in which case we probably want to + # exclude them from inheritance, or should they be automatically + # dropped? + # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 + removed_something = False + for name in self._indexes.parents: + if name in self._node_indexes: + # Indexes on a Dataset always have a corresponding coordinate + del self._node_indexes[name] + del self._node_coord_variables[name] + removed_something = True + + if removed_something: + self._node_dims = calculate_dimensions( + self._data_variables | self._node_coord_variables + ) + + for child in self._children.values(): + child._dedup_inherited_coordinates() + + def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._post_attach(parent, name) + self._dedup_inherited_coordinates() + @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 4e22ba57e2e..e009e7c47ab 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1190,6 +1190,22 @@ def test_inherited_coords_override(self): xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + def test_inherited_coords_with_index_are_deduplicated(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2]}), + "/b": xr.Dataset(coords={"x": [1, 2]}), + } + ) + child_dataset = dt.children["b"].to_dataset(inherited=False) + expected = xr.Dataset() + assert_identical(child_dataset, expected) + + dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]}) + child_dataset = dt.children["c"].to_dataset(inherited=False) + expected = xr.Dataset({"foo": ("x", [4, 5])}) + assert_identical(child_dataset, expected) + def test_inconsistent_dims(self): expected_msg = _exact_match( """ @@ -1534,6 +1550,18 @@ def test_binary_op_on_datatree(self): result = dt * dt # type: ignore[operator] assert_equal(result, expected) + def test_arithmetic_inherited_coords(self): + tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]})) + tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) + actual = 2 * tree # type: ignore + + actual_dataset = actual.children["foo"].to_dataset(inherited=False) + assert "x" not in actual_dataset.coords + + expected = tree.copy() + expected["/foo/bar"].data = np.array([8, 10, 12]) + assert_identical(actual, expected) + class TestUFuncs: def test_tree(self, create_test_datatree):