Skip to content

Commit

Permalink
Stop inheriting non-indexed coordinates for DataTree
Browse files (browse the repository at this point in the history)
This implements option (4) from #9475 (comment): child nodes now inherit only the indexed coordinates of their parents.
  • Loading branch information
shoyer committed Sep 28, 2024
1 parent cde720f commit e8a3c35
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
8 changes: 7 additions & 1 deletion xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ class DataTree(
_cache: dict[str, Any] # used by _CachedAccessor
_data_variables: dict[Hashable, Variable]
_node_coord_variables: dict[Hashable, Variable]
_node_indexed_coord_variables: dict[Hashable, Variable]
_node_dims: dict[Hashable, int]
_node_indexes: dict[Hashable, Index]
_attrs: dict[Hashable, Any] | None
Expand All @@ -451,6 +452,7 @@ class DataTree(
"_cache", # used by _CachedAccessor
"_data_variables",
"_node_coord_variables",
"_node_indexed_coord_variables",
"_node_dims",
"_node_indexes",
"_attrs",
Expand Down Expand Up @@ -498,6 +500,9 @@ def _set_node_data(self, dataset: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(dataset)
self._data_variables = data_vars
self._node_coord_variables = coord_vars
self._node_indexed_coord_variables = {
k: coord_vars[k] for k in dataset._indexes
}
self._node_dims = dataset._dims
self._node_indexes = dataset._indexes
self._encoding = dataset._encoding
Expand All @@ -519,7 +524,8 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
self._node_coord_variables, *(p._node_coord_variables for p in self.parents)
self._node_coord_variables,
*(p._node_indexed_coord_variables for p in self.parents),
)

@property
Expand Down
36 changes: 19 additions & 17 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,16 @@ def test_is_hollow(self):


class TestToDataset:
def test_to_dataset(self):
base = xr.Dataset(coords={"a": 1})
sub = xr.Dataset(coords={"b": 2})
def test_to_dataset_inherited(self):
base = xr.Dataset(coords={"a": [1], "b": 2})
sub = xr.Dataset(coords={"c": [3]})
tree = DataTree.from_dict({"/": base, "/sub": sub})
subtree = typing.cast(DataTree, tree["sub"])

assert_identical(tree.to_dataset(inherited=False), base)
assert_identical(subtree.to_dataset(inherited=False), sub)

sub_and_base = xr.Dataset(coords={"a": 1, "b": 2})
sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]}) # no "b"
assert_identical(tree.to_dataset(inherited=True), base)
assert_identical(subtree.to_dataset(inherited=True), sub_and_base)

Expand Down Expand Up @@ -714,7 +714,8 @@ def test_inherited(self):
dt["child"] = DataTree()
child = dt["child"]

assert set(child.coords) == {"x", "y", "a", "b"}
assert set(dt.coords) == {"x", "y", "a", "b"}
assert set(child.coords) == {"x", "y"}

actual = child.copy(deep=True)
actual.coords["x"] = ("x", ["a", "b"])
Expand All @@ -729,7 +730,7 @@ def test_inherited(self):

with pytest.raises(KeyError):
# cannot delete inherited coordinate from child node
del child["b"]
del child["x"]

# TODO requires a fix for #9472
# actual = child.copy(deep=True)
Expand Down Expand Up @@ -1278,22 +1279,23 @@ def test_inherited_coords_index(self):
assert "x" in dt["/b"].coords
xr.testing.assert_identical(dt["/x"], dt["/b/x"])

def test_inherited_coords_override(self):
def test_inherit_only_index_coords(self):
dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 1, "y": 2}),
"/b": xr.Dataset(coords={"x": 4, "z": 3}),
"/": xr.Dataset(coords={"x": [1], "y": 2}),
"/b": xr.Dataset(coords={"z": 3}),
}
)
assert dt.coords.keys() == {"x", "y"}
root_coords = {"x": 1, "y": 2}
sub_coords = {"x": 4, "y": 2, "z": 3}
xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords))
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords))
assert dt["/b"].coords.keys() == {"x", "y", "z"}
xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords))
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))
xr.testing.assert_equal(
dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2})
)
xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2}))
assert dt["/b"].coords.keys() == {"x", "z"}
xr.testing.assert_equal(
dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3})
)
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3}))

def test_inherited_coords_with_index_are_deduplicated(self):
dt = DataTree.from_dict(
Expand Down

0 comments on commit e8a3c35

Please sign in to comment.