Skip to content

Commit

Permalink
edits per review
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Sep 22, 2024
1 parent b3d3eaa commit 9cde264
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 31 deletions.
60 changes: 30 additions & 30 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,34 @@ def check_alignment(
check_alignment(child_path, child_ds, base_ds, child.children)


def _deduplicate_inherited_coordinates(child: DataTree, parent: DataTree) -> None:
# This method removes repeated indexes (and corresponding coordinates)
# that are repeated between a DataTree and its parents.
#
# TODO(shoyer): Decide how to handle repeated coordinates *without* an
# index. Should these be allowed, in which case we probably want to
# exclude them from inheritance, or should they be automatically
# dropped?
# https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264
removed_something = False
for name in parent._indexes:
if name in child._node_indexes:
# Indexes on a Dataset always have a corresponding coordinate.
# We already verified that these coordinates match in the
# check_alignment() call from _pre_attach().
del child._node_indexes[name]
del child._node_coord_variables[name]
removed_something = True

if removed_something:
child._node_dims = calculate_dimensions(
child._data_variables | child._node_coord_variables
)

for grandchild in child._children.values():
_deduplicate_inherited_coordinates(grandchild, child)


def _check_for_slashes_in_names(variables: Iterable[Hashable]) -> None:
offending_variable_names = [
name for name in variables if isinstance(name, str) and "/" in name
Expand Down Expand Up @@ -485,35 +513,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
node_ds = self.to_dataset(inherited=False)
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
check_alignment(path, node_ds, parent_ds, self.children)

def _dedup_inherited_coordinates(self):
# This method removes repeated indexes (and correpsonding coordinates)
# that are repeated between a DataTree and its parents.
#
# TODO(shoyer): Decide how to handle repeated coordinates *without* an
# index. Should these be allowed, in which case we probably want to
# exclude them from inheritance, or should they be automatically
# dropped?
# https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264
removed_something = False
for name in self._indexes.parents:
if name in self._node_indexes:
# Indexes on a Dataset always have a corresponding coordinate
del self._node_indexes[name]
del self._node_coord_variables[name]
removed_something = True

if removed_something:
self._node_dims = calculate_dimensions(
self._data_variables | self._node_coord_variables
)

for child in self._children.values():
child._dedup_inherited_coordinates()

def _post_attach(self: DataTree, parent: DataTree, name: str) -> None:
super()._post_attach(parent, name)
self._dedup_inherited_coordinates()
_deduplicate_inherited_coordinates(self, parent)

@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
Expand Down Expand Up @@ -1379,7 +1379,7 @@ def map_over_subtree(
func: Callable,
*args: Iterable[Any],
**kwargs: Any,
) -> DataTree | tuple[DataTree]:
) -> DataTree | tuple[DataTree, ...]:
"""
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
Expand Down
13 changes: 12 additions & 1 deletion xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
check_isomorphic,
map_over_subtree,
)
from xarray.testing import assert_equal
from xarray.testing import assert_equal, assert_identical

empty = xr.Dataset()

Expand Down Expand Up @@ -306,6 +306,17 @@ def fail_on_specific_node(ds):
):
dt.map_over_subtree(fail_on_specific_node)

def test_inherited_coordinates_with_index(self):
root = xr.Dataset(coords={"x": [1, 2]})
child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates
tree = xr.DataTree.from_dict({"/": root, "/child": child})
actual = tree.map_over_subtree(lambda ds: ds) # identity
assert isinstance(actual, xr.DataTree)
assert_identical(tree, actual)

actual_child = actual.children["child"].to_dataset(inherited=False)
assert_identical(actual_child, child)


class TestMutableOperations:
def test_construct_using_type(self):
Expand Down

0 comments on commit 9cde264

Please sign in to comment.