Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove duplicate coordinates with indexes on sub-trees #9531

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def map( # type: ignore[override]


class DataTree(
NamedNode,
NamedNode["DataTree"],
MappedDatasetMethodsMixin,
MappedDataWithCoords,
DataTreeArithmeticMixin,
Expand Down Expand Up @@ -486,6 +486,35 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
check_alignment(path, node_ds, parent_ds, self.children)

def _dedup_inherited_coordinates(self):
    """Drop node-level indexes (and their coordinates) already held by a parent.

    Every name found in the parents' portion of ``self._indexes`` that is
    also present in ``self._node_indexes`` is deleted from both
    ``self._node_indexes`` and ``self._node_coord_variables``. If at least
    one entry was deleted, ``self._node_dims`` is recomputed from the data
    variables and the remaining node-level coordinate variables. The same
    procedure is then applied to every child node.
    """
    # TODO(shoyer): Decide how to handle repeated coordinates *without* an
    # index. Should these be allowed, in which case we probably want to
    # exclude them from inheritance, or should they be automatically
    # dropped?
    # https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264
    dims_are_stale = False
    for index_name in self._indexes.parents:
        if index_name not in self._node_indexes:
            # Nothing stored on this node under that name; keep looking.
            continue
        # Indexes on a Dataset always have a corresponding coordinate, so
        # both entries can be deleted together.
        del self._node_indexes[index_name]
        del self._node_coord_variables[index_name]
        dims_are_stale = True

    if dims_are_stale:
        # A removed coordinate may have been contributing a dimension, so
        # rebuild the node's dimensions from the variables that are left.
        remaining_variables = self._data_variables | self._node_coord_variables
        self._node_dims = calculate_dimensions(remaining_variables)

    # Walk the rest of the sub-tree so descendants are de-duplicated too.
    for descendant in self._children.values():
        descendant._dedup_inherited_coordinates()

def _post_attach(self: DataTree, parent: DataTree, name: str) -> None:
    """Run the inherited post-attach hook, then drop duplicated coordinates.

    Parameters
    ----------
    parent : DataTree
        Forwarded unchanged to the superclass hook.
    name : str
        Forwarded unchanged to the superclass hook.
    """
    # Let the base class finish its own attachment bookkeeping first.
    super()._post_attach(parent, name)
    # Then remove indexed coordinates on this node (and, recursively, on its
    # children) that are already available from the parents.
    self._dedup_inherited_coordinates()

@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,22 @@ def test_inherited_coords_override(self):
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))

def test_inherited_coords_with_index_are_deduplicated(self):
    """Indexed coords repeated from the root are not kept on child nodes."""
    tree = DataTree.from_dict(
        {
            "/": xr.Dataset(coords={"x": [1, 2]}),
            "/b": xr.Dataset(coords={"x": [1, 2]}),
        }
    )
    # Node "b" was built with only the repeated "x" coordinate, so its
    # non-inherited dataset is expected to be empty.
    node_b_local = tree.children["b"].to_dataset(inherited=False)
    assert_identical(node_b_local, xr.Dataset())

    # A node added by item assignment is expected to lose the repeated "x"
    # as well, while keeping its own data variable.
    tree["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]})
    node_c_local = tree.children["c"].to_dataset(inherited=False)
    assert_identical(node_c_local, xr.Dataset({"foo": ("x", [4, 5])}))

def test_inconsistent_dims(self):
expected_msg = _exact_match(
"""
Expand Down Expand Up @@ -1534,6 +1550,18 @@ def test_binary_op_on_datatree(self):
result = dt * dt # type: ignore[operator]
assert_equal(result, expected)

def test_arithmetic_inherited_coords(self):
    """Arithmetic on a tree must not copy inherited coords onto children."""
    root = DataTree(xr.Dataset(coords={"x": [1, 2, 3]}))
    root["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])}))
    doubled: DataTree = 2 * root  # type: ignore[assignment,operator]

    # "x" lives on the root only; the child of the result should still get
    # it by inheritance rather than holding its own copy.
    foo_local = doubled.children["foo"].to_dataset(inherited=False)
    assert "x" not in foo_local.coords

    # Build the expected tree by doubling the single data variable by hand.
    doubled_by_hand = root.copy()
    doubled_by_hand["/foo/bar"].data = np.array([8, 10, 12])
    assert_identical(doubled, doubled_by_hand)


class TestUFuncs:
def test_tree(self, create_test_datatree):
Expand Down
Loading