Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove duplicate coordinates with indexes on sub-trees #9531

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,34 @@ def check_alignment(
check_alignment(child_path, child_ds, base_ds, child.children)


def _deduplicate_inherited_coordinates(child: DataTree, parent: DataTree) -> None:
# This method removes repeated indexes (and corresponding coordinates)
# that are repeated between a DataTree and its parents.
#
# TODO(shoyer): Decide how to handle repeated coordinates *without* an
# index. Should these be allowed, in which case we probably want to
# exclude them from inheritance, or should they be automatically
# dropped?
# https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264
removed_something = False
for name in parent._indexes:
if name in child._node_indexes:
# Indexes on a Dataset always have a corresponding coordinate.
# We already verified that these coordinates match in the
# check_alignment() call from _pre_attach().
del child._node_indexes[name]
del child._node_coord_variables[name]
removed_something = True

if removed_something:
child._node_dims = calculate_dimensions(
child._data_variables | child._node_coord_variables
)

for grandchild in child._children.values():
_deduplicate_inherited_coordinates(grandchild, child)


def _check_for_slashes_in_names(variables: Iterable[Hashable]) -> None:
offending_variable_names = [
name for name in variables if isinstance(name, str) and "/" in name
Expand Down Expand Up @@ -374,7 +402,7 @@ def map( # type: ignore[override]


class DataTree(
NamedNode,
NamedNode["DataTree"],
MappedDatasetMethodsMixin,
MappedDataWithCoords,
DataTreeArithmeticMixin,
Expand Down Expand Up @@ -485,6 +513,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
node_ds = self.to_dataset(inherited=False)
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
check_alignment(path, node_ds, parent_ds, self.children)
_deduplicate_inherited_coordinates(self, parent)

@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
Expand Down Expand Up @@ -1350,7 +1379,7 @@ def map_over_subtree(
func: Callable,
*args: Iterable[Any],
**kwargs: Any,
) -> DataTree | tuple[DataTree]:
) -> DataTree | tuple[DataTree, ...]:
"""
Apply a function to every dataset in this subtree, returning a new tree which stores the results.

Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,22 @@ def test_inherited_coords_override(self):
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))

def test_inherited_coords_with_index_are_deduplicated(self):
dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2]}),
"/b": xr.Dataset(coords={"x": [1, 2]}),
}
)
child_dataset = dt.children["b"].to_dataset(inherited=False)
expected = xr.Dataset()
assert_identical(child_dataset, expected)

dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]})
child_dataset = dt.children["c"].to_dataset(inherited=False)
expected = xr.Dataset({"foo": ("x", [4, 5])})
assert_identical(child_dataset, expected)

def test_inconsistent_dims(self):
expected_msg = _exact_match(
"""
Expand Down Expand Up @@ -1534,6 +1550,18 @@ def test_binary_op_on_datatree(self):
result = dt * dt # type: ignore[operator]
assert_equal(result, expected)

def test_arithmetic_inherited_coords(self):
tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]}))
shoyer marked this conversation as resolved.
Show resolved Hide resolved
tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])}))
actual: DataTree = 2 * tree # type: ignore[assignment,operator]

actual_dataset = actual.children["foo"].to_dataset(inherited=False)
assert "x" not in actual_dataset.coords

expected = tree.copy()
expected["/foo/bar"].data = np.array([8, 10, 12])
assert_identical(actual, expected)


class TestUFuncs:
def test_tree(self, create_test_datatree):
Expand Down
13 changes: 12 additions & 1 deletion xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
check_isomorphic,
map_over_subtree,
)
from xarray.testing import assert_equal
from xarray.testing import assert_equal, assert_identical

empty = xr.Dataset()

Expand Down Expand Up @@ -306,6 +306,17 @@ def fail_on_specific_node(ds):
):
dt.map_over_subtree(fail_on_specific_node)

def test_inherited_coordinates_with_index(self):
root = xr.Dataset(coords={"x": [1, 2]})
child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates
tree = xr.DataTree.from_dict({"/": root, "/child": child})
actual = tree.map_over_subtree(lambda ds: ds) # identity
assert isinstance(actual, xr.DataTree)
assert_identical(tree, actual)

actual_child = actual.children["child"].to_dataset(inherited=False)
assert_identical(actual_child, child)


class TestMutableOperations:
def test_construct_using_type(self):
Expand Down
Loading