-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
De-duplicate inherited coordinates & dimensions #9510
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
import contextlib | ||
import itertools | ||
import textwrap | ||
from collections import ChainMap | ||
|
@@ -373,6 +374,10 @@ def map( # type: ignore[override] | |
return Dataset(variables) | ||
|
||
|
||
# Accepted values for the conflicting-coordinates mode; membership in this set
# is what gets checked at runtime.
CONFLICTING_COORDS_MODES = {"error", "ignore"}
# Static type for the same setting. Its members must match
# CONFLICTING_COORDS_MODES exactly: the code compares against "error", so the
# literal is "error" rather than "raise".
ConflictingCoordsMode = Literal["error", "ignore"]
|
||
|
||
class DataTree( | ||
NamedNode, | ||
MappedDatasetMethodsMixin, | ||
|
@@ -414,6 +419,7 @@ class DataTree( | |
_attrs: dict[Hashable, Any] | None | ||
_encoding: dict[Hashable, Any] | None | ||
_close: Callable[[], None] | None | ||
_conflicting_coords_mode: ConflictingCoordsMode | ||
|
||
__slots__ = ( | ||
"_name", | ||
|
@@ -427,6 +433,7 @@ class DataTree( | |
"_attrs", | ||
"_encoding", | ||
"_close", | ||
"_conflicting_coords_mode", | ||
) | ||
|
||
def __init__( | ||
|
@@ -459,9 +466,10 @@ def __init__( | |
-------- | ||
DataTree.from_dict | ||
""" | ||
self._conflicting_coords_mode: str = "ignore" | ||
self._set_node_data(_to_new_dataset(dataset)) | ||
|
||
# comes after setting node data as this will check for clashes between child names and existing variable names | ||
# comes after setting node data as this will check for clashes between | ||
# child names and existing variable names | ||
super().__init__(name=name, children=children) | ||
|
||
def _set_node_data(self, dataset: Dataset): | ||
|
@@ -486,6 +494,55 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: | |
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) | ||
check_alignment(path, node_ds, parent_ds, self.children) | ||
|
||
@contextlib.contextmanager
def _with_conflicting_coords_mode(self, mode: ConflictingCoordsMode):
    """Temporarily override how duplicated inherited coordinates are handled.

    De-duplication happens inside ``_post_attach()``, which is difficult to
    parameterize directly. This private context manager works around that
    by storing ``mode`` on this node for the duration of the ``with`` block
    and putting the previous value back afterwards, even if the body raises.

    Parameters
    ----------
    mode : str
        One of the values in ``CONFLICTING_COORDS_MODES``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the recognised modes.
    """
    if mode not in CONFLICTING_COORDS_MODES:
        raise ValueError(
            f"conflicting coordinates mode must be in {CONFLICTING_COORDS_MODES}, got {mode!r}"
        )
    # The de-duplication code reads the mode from the root node, so this
    # helper is only meaningful when called on a node without a parent.
    assert self.parent is None
    # Swap the new mode in while remembering the one it replaces.
    previous_mode, self._conflicting_coords_mode = (
        self._conflicting_coords_mode,
        mode,
    )
    try:
        yield
    finally:
        self._conflicting_coords_mode = previous_mode
|
||
def _dedup_inherited_coordinates(self):
    """Drop coordinates on this node that are also present on a parent.

    Every coordinate name found both among the inherited coordinates and
    in this node's own coordinates is deleted from this node, together
    with its index if it has one. Node dimensions are recalculated when
    anything was removed, and the same procedure is then applied to every
    child node.

    Raises
    ------
    ValueError
        If a duplicated coordinate has no index and the root node's
        conflicting-coordinates mode is ``"error"``.
    """
    # NOTE(review): a reviewer remarked that this line took a moment to
    # understand. ``_coord_variables`` is a ChainMap, so ``.parents`` is
    # ``ChainMap.parents`` (all mappings except the first) -- presumably
    # the coordinates inherited from ancestor nodes; confirm against the
    # ``_coord_variables`` property.
    inherited_names = self._coord_variables.parents

    # NOTE(review): duplicates are matched by name only. A reviewer asked
    # whether that causes problems when a child is meant to override a
    # parent's coordinate with a new coordinate of the same name but
    # different values.
    dropped_any = False
    for name in inherited_names:
        if name not in self._node_coord_variables:
            continue

        if name in self._node_indexes:
            del self._node_indexes[name]
        elif self.root._conflicting_coords_mode == "error":
            # error only for non-indexed coordinates, which are not
            # checked for equality
            raise ValueError(
                f"coordinate {name!r} on node {self.path!r} is also found on a parent node"
            )

        del self._node_coord_variables[name]
        dropped_any = True

    if dropped_any:
        # Removing coordinates may change which dimensions this node has.
        remaining_variables = self._data_variables | self._node_coord_variables
        self._node_dims = calculate_dimensions(remaining_variables)

    for child in self._children.values():
        child._dedup_inherited_coordinates()
|
||
def _post_attach(self: DataTree, parent: DataTree, name: str) -> None:
    """Run the superclass ``_post_attach`` and then de-duplicate coordinates.

    After delegating to ``super()._post_attach(parent, name)``, coordinates
    on this node (and, recursively, its children) that are also found on a
    parent are removed via ``_dedup_inherited_coordinates()``.
    """
    super()._post_attach(parent, name)
    self._dedup_inherited_coordinates()
|
||
@property | ||
def _coord_variables(self) -> ChainMap[Hashable, Variable]: | ||
return ChainMap( | ||
|
@@ -1055,6 +1112,7 @@ def from_dict( | |
d: Mapping[str, Dataset | DataTree | None], | ||
/, | ||
name: str | None = None, | ||
conflicting_coords: ConflictingCoordsMode = "ignore", | ||
) -> DataTree: | ||
""" | ||
Create a datatree from a dictionary of data objects, organised by paths into the tree. | ||
|
@@ -1070,6 +1128,11 @@ def from_dict( | |
To assign data to the root node of the tree use "/" as the path. | ||
name : Hashable | None, optional | ||
Name for the root node of the tree. Default is None. | ||
conflicting_coords : "ignore" or "error" | ||
How to handle repeated coordinates without an associated index | ||
between parent and child nodes. By default, such coordinates are | ||
dropped from child nodes. Alternatively, an error can be raised when | ||
they are encountered. | ||
|
||
Returns | ||
------- | ||
|
@@ -1097,23 +1160,27 @@ def depth(item) -> int: | |
return len(NodePath(pathstr).parts) | ||
|
||
if d_cast: | ||
# Populate tree with children determined from data_objects mapping | ||
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276) | ||
for path, data in sorted(d_cast.items(), key=depth): | ||
# Create and set new node | ||
node_name = NodePath(path).name | ||
if isinstance(data, DataTree): | ||
new_node = data.copy() | ||
elif isinstance(data, Dataset) or data is None: | ||
new_node = cls(name=node_name, dataset=data) | ||
else: | ||
raise TypeError(f"invalid values: {data}") | ||
obj._set_item( | ||
path, | ||
new_node, | ||
allow_overwrite=False, | ||
new_nodes_along_path=True, | ||
) | ||
# Populate tree with children determined from data_objects mapping. | ||
# Ensure the result object has the right handling of conflicting | ||
# coordinates. | ||
with obj._with_conflicting_coords_mode(conflicting_coords): | ||
# Sort keys by depth so as to insert nodes from root first (see | ||
# GH issue #9276) | ||
for path, data in sorted(d_cast.items(), key=depth): | ||
# Create and set new node | ||
node_name = NodePath(path).name | ||
if isinstance(data, DataTree): | ||
new_node = data.copy() | ||
elif isinstance(data, Dataset) or data is None: | ||
new_node = cls(name=node_name, dataset=data) | ||
else: | ||
raise TypeError(f"invalid values: {data}") | ||
obj._set_item( | ||
path, | ||
new_node, | ||
allow_overwrite=False, | ||
new_nodes_along_path=True, | ||
) | ||
|
||
return obj | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we want to allow two modes for this? My understanding from the conversation in #9475 (comment) is that we would add a flag for use once during initial construction by `open_datatree`, but not a global flag or one that sets state of the tree object. Like I would have thought that this could just be a function/method parameter rather than a contextmanager or an attribute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summarizing discussion from the meeting just now: this is an internal implementation detail, not something that is meant to be publicly accessible. The aim is still that all public `DataTree` objects are de-duplicated.

The issue is that the natural place to put de-duplication in is in `._post_attach`, but `._post_attach` doesn't accept any parameters (so we can't pass `mode`). So this global flag & associated context manager are just ways to pass state into the `._post_attach` method.

@shoyer maybe the least-worst thing to do here is just to change the signature of `._post_attach` to [code snippet not captured in this copy] where `TreeNode._post_attach` doesn't use `mode`?