Skip to content

Commit

Permalink
De-duplicate inherited coordinates & dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Sep 17, 2024
1 parent 18e5c87 commit 2ba1da6
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 21 deletions.
105 changes: 86 additions & 19 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import itertools
import textwrap
from collections import ChainMap
Expand Down Expand Up @@ -373,6 +374,10 @@ def map( # type: ignore[override]
return Dataset(variables)


CONFLICTING_COORDS_MODES = {"error", "ignore"}
ConflictingCoordsMode = Literal["ignore", "raise"]


class DataTree(
NamedNode,
MappedDatasetMethodsMixin,
Expand Down Expand Up @@ -414,6 +419,7 @@ class DataTree(
_attrs: dict[Hashable, Any] | None
_encoding: dict[Hashable, Any] | None
_close: Callable[[], None] | None
_conflicting_coords_mode: ConflictingCoordsMode

__slots__ = (
"_name",
Expand All @@ -427,6 +433,7 @@ class DataTree(
"_attrs",
"_encoding",
"_close",
"_conflicting_coords_mode",
)

def __init__(
Expand Down Expand Up @@ -459,9 +466,10 @@ def __init__(
--------
DataTree.from_dict
"""
self._conflicting_coords_mode: str = "ignore"
self._set_node_data(_to_new_dataset(dataset))

# comes after setting node data as this will check for clashes between child names and existing variable names
# comes after setting node data as this will check for clashes between
# child names and existing variable names
super().__init__(name=name, children=children)

def _set_node_data(self, dataset: Dataset):
Expand All @@ -486,6 +494,55 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
check_alignment(path, node_ds, parent_ds, self.children)

@contextlib.contextmanager
def _with_conflicting_coords_mode(self, mode: ConflictingCoordsMode):
"""Set handling of duplicated inherited coordinates on this object.
This private context manager is an indirect way to control the handling
of duplicated inherited coordinates, which is done inside _post_attach()
and hence is difficult to directly parameterize.
"""
if mode not in CONFLICTING_COORDS_MODES:
raise ValueError(
f"conflicting coordinates mode must be in {CONFLICTING_COORDS_MODES}, got {mode!r}"
)
assert self.parent is None
original_mode = self._conflicting_coords_mode
self._conflicting_coords_mode = mode
try:
yield
finally:
self._conflicting_coords_mode = original_mode

def _dedup_inherited_coordinates(self):
removed_something = False
for name in self._coord_variables.parents:
if name in self._node_coord_variables:
indexed = name in self._node_indexes
if indexed:
del self._node_indexes[name]
elif self.root._conflicting_coords_mode == "error":
# error only for non-indexed coordinates, which are not
# checked for equality
raise ValueError(
f"coordinate {name!r} on node {self.path!r} is also found on a parent node"
)

del self._node_coord_variables[name]
removed_something = True

if removed_something:
self._node_dims = calculate_dimensions(
self._data_variables | self._node_coord_variables
)

for child in self._children.values():
child._dedup_inherited_coordinates()

def _post_attach(self: DataTree, parent: DataTree, name: str) -> None:
super()._post_attach(parent, name)
self._dedup_inherited_coordinates()

@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
Expand Down Expand Up @@ -1055,6 +1112,7 @@ def from_dict(
d: Mapping[str, Dataset | DataTree | None],
/,
name: str | None = None,
conflicting_coords: ConflictingCoordsMode = "ignore",
) -> DataTree:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.
Expand All @@ -1070,6 +1128,11 @@ def from_dict(
To assign data to the root node of the tree use "/" as the path.
name : Hashable | None, optional
Name for the root node of the tree. Default is None.
conflicting_coords : "ignore" or "error"
How to handle repeated coordinates without an associated index
between parent and child nodes. By default, such coordinates are
dropped from child nodes. Alternatively, an error can be raise when
they are encountered.
Returns
-------
Expand Down Expand Up @@ -1097,23 +1160,27 @@ def depth(item) -> int:
return len(NodePath(pathstr).parts)

if d_cast:
# Populate tree with children determined from data_objects mapping
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)
# Populate tree with children determined from data_objects mapping.
# Ensure the result object has the right handling of conflicting
# coordinates.
with obj._with_conflicting_coords_mode(conflicting_coords):
# Sort keys by depth so as to insert nodes from root first (see
# GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)

return obj

Expand Down
44 changes: 42 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,37 @@ def test_array_values(self):
with pytest.raises(TypeError):
DataTree.from_dict(data) # type: ignore[arg-type]

def test_conflicting_coords(self):
tree = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 1}),
"/sub": xr.Dataset(coords={"x": 2}),
},
conflicting_coords="ignore",
)
expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 1}),
"/sub": xr.Dataset(),
}
)
assert_identical(tree, expected)

with pytest.raises(
ValueError,
match="coordinate 'x' on node '/sub' is also found on a parent node",
):
DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 1}),
"/sub": xr.Dataset(coords={"x": 2}),
},
conflicting_coords="error",
)

with pytest.raises(ValueError, match="conflicting coordinates mode must be in"):
DataTree.from_dict({"/sub": xr.Dataset()}, conflicting_coords="invalid") # type: ignore


class TestDatasetView:
def test_view_contents(self):
Expand Down Expand Up @@ -1166,11 +1197,11 @@ def test_inherited_coords_override(self):
)
assert dt.coords.keys() == {"x", "y"}
root_coords = {"x": 1, "y": 2}
sub_coords = {"x": 4, "y": 2, "z": 3}
sub_coords = {"x": 1, "y": 2, "z": 3}
xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords))
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords))
assert dt["/b"].coords.keys() == {"x", "y", "z"}
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords))
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(1, coords=sub_coords))
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))

Expand Down Expand Up @@ -1518,6 +1549,15 @@ def test_binary_op_on_datatree(self):
result = dt * dt # type: ignore[operator]
assert_equal(result, expected)

def test_arithmetic_inherited_coords(self):
tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]}))
tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])}))
actual = 2 * tree
expected = tree.copy()
expected["/foo/bar"].data = np.array([8, 10, 12])
assert "x" not in tree.children["foo"].to_dataset(inherited=False).coords
assert_identical(actual, tree)


class TestUFuncs:
def test_tree(self, create_test_datatree):
Expand Down

0 comments on commit 2ba1da6

Please sign in to comment.