From 0541f289329430abcb0870d9ceacc66c34fdc53f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 22 Sep 2024 13:18:24 -0700 Subject: [PATCH 1/3] Remove duplicate coordinates with indexes on sub-trees This is a _partial_ solution to the duplicate coordinates issue from #9475. Here we remove all duplicate coordinates between parent and child nodes with an index (these are already checked for equality via the alignment check). Other repeated coordinates (which we cannot automatically check for equality) are still allowed for now. We will need an alternative solution for these, as discussed in #9475, but it is less obvious what the right solution is so I'm holding off on it for now. --- xarray/core/datatree.py | 31 ++++++++++++++++++++++++++++++- xarray/tests/test_datatree.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index bd583ac86cb..27a2da88701 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -374,7 +374,7 @@ def map( # type: ignore[override] class DataTree( - NamedNode, + NamedNode["DataTree"], MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmeticMixin, @@ -486,6 +486,35 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) check_alignment(path, node_ds, parent_ds, self.children) + def _dedup_inherited_coordinates(self): + # This method removes repeated indexes (and correpsonding coordinates) + # that are repeated between a DataTree and its parents. + # + # TODO(shoyer): Decide how to handle repeated coordinates *without* an + # index. Should these be allowed, in which case we probably want to + # exclude them from inheritance, or should they be automatically + # dropped? + # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 + removed_something = False + for name in self._indexes.parents: + if name in self._node_indexes: + # Indexes on a Dataset always have a corresponding coordinate + del self._node_indexes[name] + del self._node_coord_variables[name] + removed_something = True + + if removed_something: + self._node_dims = calculate_dimensions( + self._data_variables | self._node_coord_variables + ) + + for child in self._children.values(): + child._dedup_inherited_coordinates() + + def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._post_attach(parent, name) + self._dedup_inherited_coordinates() + @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 4e22ba57e2e..e009e7c47ab 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1190,6 +1190,22 @@ def test_inherited_coords_override(self): xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + def test_inherited_coords_with_index_are_deduplicated(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1, 2]}), + "/b": xr.Dataset(coords={"x": [1, 2]}), + } + ) + child_dataset = dt.children["b"].to_dataset(inherited=False) + expected = xr.Dataset() + assert_identical(child_dataset, expected) + + dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]}) + child_dataset = dt.children["c"].to_dataset(inherited=False) + expected = xr.Dataset({"foo": ("x", [4, 5])}) + assert_identical(child_dataset, expected) + def test_inconsistent_dims(self): expected_msg = _exact_match( """ @@ -1534,6 +1550,18 @@ def test_binary_op_on_datatree(self): result = dt * dt # type: ignore[operator] assert_equal(result, expected) + def test_arithmetic_inherited_coords(self): + tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]})) + tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) + actual = 2 * tree # type: ignore + + actual_dataset = actual.children["foo"].to_dataset(inherited=False) + assert "x" not in actual_dataset.coords + + expected = tree.copy() + expected["/foo/bar"].data = np.array([8, 10, 12]) + assert_identical(actual, expected) + class TestUFuncs: def test_tree(self, create_test_datatree): From b3d3eaafecef1600cc8aef508eafb91055e33de3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 22 Sep 2024 14:06:27 -0700 Subject: [PATCH 2/3] silence mypy error --- xarray/tests/test_datatree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index e009e7c47ab..9f51f5fd99f 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1553,7 +1553,7 @@ def test_binary_op_on_datatree(self): def test_arithmetic_inherited_coords(self): tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]})) tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) - actual = 2 * tree # type: ignore + actual: DataTree = 2 * tree # type: ignore[assignment,operator] actual_dataset = actual.children["foo"].to_dataset(inherited=False) assert "x" not in actual_dataset.coords From 9cde26459d8f890d618be764f424dfb9ccf0045f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 22 Sep 2024 16:18:00 -0700 Subject: [PATCH 3/3] edits per review --- xarray/core/datatree.py | 60 +++++++++++++-------------- xarray/tests/test_datatree_mapping.py | 13 +++++- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 27a2da88701..ff498a77480 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -156,6 +156,34 @@ def check_alignment( check_alignment(child_path, child_ds, base_ds, child.children) +def _deduplicate_inherited_coordinates(child: DataTree, parent: DataTree) -> None: + # This method removes repeated indexes (and corresponding coordinates) + # that are repeated between a DataTree and its parents. + # + # TODO(shoyer): Decide how to handle repeated coordinates *without* an + # index. Should these be allowed, in which case we probably want to + # exclude them from inheritance, or should they be automatically + # dropped? + # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 + removed_something = False + for name in parent._indexes: + if name in child._node_indexes: + # Indexes on a Dataset always have a corresponding coordinate. + # We already verified that these coordinates match in the + # check_alignment() call from _pre_attach(). + del child._node_indexes[name] + del child._node_coord_variables[name] + removed_something = True + + if removed_something: + child._node_dims = calculate_dimensions( + child._data_variables | child._node_coord_variables + ) + + for grandchild in child._children.values(): + _deduplicate_inherited_coordinates(grandchild, child) + + def _check_for_slashes_in_names(variables: Iterable[Hashable]) -> None: offending_variable_names = [ name for name in variables if isinstance(name, str) and "/" in name @@ -485,35 +513,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: node_ds = self.to_dataset(inherited=False) parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) check_alignment(path, node_ds, parent_ds, self.children) - - def _dedup_inherited_coordinates(self): - # This method removes repeated indexes (and correpsonding coordinates) - # that are repeated between a DataTree and its parents. - # - # TODO(shoyer): Decide how to handle repeated coordinates *without* an - # index. Should these be allowed, in which case we probably want to - # exclude them from inheritance, or should they be automatically - # dropped? - # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264 - removed_something = False - for name in self._indexes.parents: - if name in self._node_indexes: - # Indexes on a Dataset always have a corresponding coordinate - del self._node_indexes[name] - del self._node_coord_variables[name] - removed_something = True - - if removed_something: - self._node_dims = calculate_dimensions( - self._data_variables | self._node_coord_variables - ) - - for child in self._children.values(): - child._dedup_inherited_coordinates() - - def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: - super()._post_attach(parent, name) - self._dedup_inherited_coordinates() + _deduplicate_inherited_coordinates(self, parent) @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: @@ -1379,7 +1379,7 @@ def map_over_subtree( func: Callable, *args: Iterable[Any], **kwargs: Any, - ) -> DataTree | tuple[DataTree]: + ) -> DataTree | tuple[DataTree, ...]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index c7e0e93b89b..9a7d3009c3b 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -7,7 +7,7 @@ check_isomorphic, map_over_subtree, ) -from xarray.testing import assert_equal +from xarray.testing import assert_equal, assert_identical empty = xr.Dataset() @@ -306,6 +306,17 @@ def fail_on_specific_node(ds): ): dt.map_over_subtree(fail_on_specific_node) + def test_inherited_coordinates_with_index(self): + root = xr.Dataset(coords={"x": [1, 2]}) + child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates + tree = xr.DataTree.from_dict({"/": root, "/child": child}) + actual = tree.map_over_subtree(lambda ds: ds) # identity + assert isinstance(actual, xr.DataTree) + assert_identical(tree, actual) + + actual_child = actual.children["child"].to_dataset(inherited=False) + assert_identical(actual_child, child) + class TestMutableOperations: def test_construct_using_type(self):