Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

De-duplicate inherited coordinates & dimensions #9510

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 86 additions & 19 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import itertools
import textwrap
from collections import ChainMap
Expand Down Expand Up @@ -373,6 +374,10 @@ def map( # type: ignore[override]
return Dataset(variables)


# Accepted values for handling non-indexed coordinates that are repeated
# between a parent and a child node. The runtime set and the static Literal
# alias must list exactly the same strings: the set is what is validated at
# runtime, the alias is what type checkers see.
CONFLICTING_COORDS_MODES = {"error", "ignore"}
ConflictingCoordsMode = Literal["ignore", "error"]


class DataTree(
NamedNode,
MappedDatasetMethodsMixin,
Expand Down Expand Up @@ -414,6 +419,7 @@ class DataTree(
_attrs: dict[Hashable, Any] | None
_encoding: dict[Hashable, Any] | None
_close: Callable[[], None] | None
_conflicting_coords_mode: ConflictingCoordsMode

__slots__ = (
"_name",
Expand All @@ -427,6 +433,7 @@ class DataTree(
"_attrs",
"_encoding",
"_close",
"_conflicting_coords_mode",
)

def __init__(
Expand Down Expand Up @@ -459,9 +466,10 @@ def __init__(
--------
DataTree.from_dict
"""
self._conflicting_coords_mode: str = "ignore"
self._set_node_data(_to_new_dataset(dataset))

# comes after setting node data as this will check for clashes between child names and existing variable names
# comes after setting node data as this will check for clashes between
# child names and existing variable names
super().__init__(name=name, children=children)

def _set_node_data(self, dataset: Dataset):
Expand All @@ -486,6 +494,55 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
check_alignment(path, node_ds, parent_ds, self.children)

@contextlib.contextmanager
def _with_conflicting_coords_mode(self, mode: ConflictingCoordsMode):
"""Set handling of duplicated inherited coordinates on this object.

This private context manager is an indirect way to control the handling
of duplicated inherited coordinates, which is done inside _post_attach()
and hence is difficult to directly parameterize.
"""
Comment on lines +497 to +504
Copy link
Member

@TomNicholas TomNicholas Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we want to allow two modes for this? My understanding from the conversation in #9475 (comment) is that we would add a flag for use once during initial construction by open_datatree, but not a global flag or one that sets state of the tree object. Like I would have thought that this could just be a function/method parameter rather than a contextmanager or an attribute.

Copy link
Member

@TomNicholas TomNicholas Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summarizing discussion from the meeting just now: this is an internal implementation detail, not something that is meant to be publicly accessible. The aim is still that all public DataTree objects are de-duplicated.

The issue is that the natural place to put de-duplication in is in ._post_attach, but ._post_attach doesn't accept any parameters (so we can't pass mode). So this global flag & associated context manager are just ways to pass state into the ._post_attach method.

@shoyer maybe the least-worst thing to do here is just to change the signature of ._post_attach to

def _post_attach(self: Tree, parent: Tree, name: str, mode: bool | None) -> None:

where TreeNode._post_attach doesn't use mode?

if mode not in CONFLICTING_COORDS_MODES:
raise ValueError(
f"conflicting coordinates mode must be in {CONFLICTING_COORDS_MODES}, got {mode!r}"
)
assert self.parent is None
original_mode = self._conflicting_coords_mode
self._conflicting_coords_mode = mode
try:
yield
finally:
self._conflicting_coords_mode = original_mode

def _dedup_inherited_coordinates(self):
removed_something = False
for name in self._coord_variables.parents:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a moment here to realise that .parents was ChainMap.parents, not DataTree.parents.

if name in self._node_coord_variables:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the condition for considering two coordinates to be duplicates is just that their names are the same? Isn't that going to cause problems when you do want to override a parent's coordinate with a new coordinate that has the same name but different values?

indexed = name in self._node_indexes
if indexed:
del self._node_indexes[name]
elif self.root._conflicting_coords_mode == "error":
# error only for non-indexed coordinates, which are not
# checked for equality
raise ValueError(
f"coordinate {name!r} on node {self.path!r} is also found on a parent node"
)

del self._node_coord_variables[name]
removed_something = True

if removed_something:
self._node_dims = calculate_dimensions(
self._data_variables | self._node_coord_variables
)

for child in self._children.values():
child._dedup_inherited_coordinates()

def _post_attach(self: DataTree, parent: DataTree, name: str) -> None:
    """Finish attaching this node under ``parent``.

    Runs the base-class hook first, then removes from this node (and,
    recursively, its children) any coordinates whose names are already
    provided by an ancestor node.
    """
    super()._post_attach(parent, name)
    self._dedup_inherited_coordinates()

@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
Expand Down Expand Up @@ -1055,6 +1112,7 @@ def from_dict(
d: Mapping[str, Dataset | DataTree | None],
/,
name: str | None = None,
conflicting_coords: ConflictingCoordsMode = "ignore",
) -> DataTree:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.
Expand All @@ -1070,6 +1128,11 @@ def from_dict(
To assign data to the root node of the tree use "/" as the path.
name : Hashable | None, optional
Name for the root node of the tree. Default is None.
conflicting_coords : "ignore" or "error"
How to handle repeated coordinates without an associated index
between parent and child nodes. By default, such coordinates are
dropped from child nodes. Alternatively, an error can be raised when
they are encountered.

Returns
-------
Expand Down Expand Up @@ -1097,23 +1160,27 @@ def depth(item) -> int:
return len(NodePath(pathstr).parts)

if d_cast:
# Populate tree with children determined from data_objects mapping
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)
# Populate tree with children determined from data_objects mapping.
# Ensure the result object has the right handling of conflicting
# coordinates.
with obj._with_conflicting_coords_mode(conflicting_coords):
# Sort keys by depth so as to insert nodes from root first (see
# GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)

return obj

Expand Down
44 changes: 42 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,37 @@ def test_array_values(self):
with pytest.raises(TypeError):
DataTree.from_dict(data) # type: ignore[arg-type]

def test_conflicting_coords(self):
    """Exercise every ``conflicting_coords`` mode of ``DataTree.from_dict``."""

    def clashing_datasets():
        # Root and child both define a non-indexed scalar coordinate "x".
        return {
            "/": xr.Dataset(coords={"x": 1}),
            "/sub": xr.Dataset(coords={"x": 2}),
        }

    # "ignore": the child's copy of "x" is dropped.
    tree = DataTree.from_dict(clashing_datasets(), conflicting_coords="ignore")
    deduplicated = {
        "/": xr.Dataset(coords={"x": 1}),
        "/sub": xr.Dataset(),
    }
    expected = DataTree.from_dict(deduplicated)
    assert_identical(tree, expected)

    # "error": the same input is rejected.
    duplicate_message = (
        "coordinate 'x' on node '/sub' is also found on a parent node"
    )
    with pytest.raises(ValueError, match=duplicate_message):
        DataTree.from_dict(clashing_datasets(), conflicting_coords="error")

    # Any other mode string is rejected up front.
    with pytest.raises(ValueError, match="conflicting coordinates mode must be in"):
        DataTree.from_dict({"/sub": xr.Dataset()}, conflicting_coords="invalid")  # type: ignore


class TestDatasetView:
def test_view_contents(self):
Expand Down Expand Up @@ -1184,11 +1215,11 @@ def test_inherited_coords_override(self):
)
assert dt.coords.keys() == {"x", "y"}
root_coords = {"x": 1, "y": 2}
sub_coords = {"x": 4, "y": 2, "z": 3}
sub_coords = {"x": 1, "y": 2, "z": 3}
xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords))
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords))
assert dt["/b"].coords.keys() == {"x", "y", "z"}
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords))
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(1, coords=sub_coords))
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))

Expand Down Expand Up @@ -1536,6 +1567,15 @@ def test_binary_op_on_datatree(self):
result = dt * dt # type: ignore[operator]
assert_equal(result, expected)

def test_arithmetic_inherited_coords(self):
    """Arithmetic on a tree must not duplicate inherited coordinates.

    The root node defines coordinate ``x``; the child ``foo`` only uses
    ``x`` as the dimension of ``bar``. After multiplying the tree by 2 the
    data must be doubled and ``foo`` must still not carry its own ``x``.
    """
    tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]}))
    tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])}))
    actual = 2 * tree

    expected = tree.copy()
    expected["/foo/bar"].data = np.array([8, 10, 12])

    # The input tree must be left de-duplicated ...
    assert "x" not in tree.children["foo"].to_dataset(inherited=False).coords
    # ... and so must the *result* of the operation, which is what this
    # test is really about.
    assert "x" not in actual.children["foo"].to_dataset(inherited=False).coords

    # Compare against the doubled tree; comparing against ``tree`` would
    # check the result against the undoubled input.
    assert_identical(actual, expected)


class TestUFuncs:
def test_tree(self, create_test_datatree):
Expand Down
Loading