Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forbid slashes in datatree coordinate names #9492

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,10 +846,11 @@ def to_dataset(self) -> Dataset:
def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
) -> None:
from xarray.core.datatree import check_alignment
from xarray.core.datatree import check_alignment, validate_variable_names

# create updated node (`.to_dataset` makes a copy so this doesn't modify in-place)
node_ds = self._data.to_dataset(inherited=False)
validate_variable_names(list(coords.keys()))
node_ds.coords._update_coords(coords, indexes)

# check consistency *before* modifying anything in-place
Expand Down
12 changes: 3 additions & 9 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,7 @@
either_dict_or_kwargs,
maybe_wrap_array,
)
from xarray.core.variable import Variable

try:
from xarray.core.variable import calculate_dimensions
except ImportError:
# for xarray versions 2022.03.0 and earlier
from xarray.core.dataset import calculate_dimensions
from xarray.core.variable import Variable, calculate_dimensions

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -156,7 +150,7 @@ def check_alignment(
check_alignment(child_path, child_ds, base_ds, child.children)


def _check_for_slashes_in_names(variables: Iterable[Hashable]) -> None:
def validate_variable_names(variables: Iterable[Hashable]) -> None:
offending_variable_names = [
name for name in variables if isinstance(name, str) and "/" in name
]
Expand Down Expand Up @@ -465,7 +459,7 @@ def __init__(
super().__init__(name=name, children=children)

def _set_node_data(self, dataset: Dataset):
_check_for_slashes_in_names(dataset.variables)
validate_variable_names(dataset.variables)
data_vars, coord_vars = _collect_data_and_coord_variables(dataset)
self._data_variables = data_vars
self._node_coord_variables = coord_vars
Expand Down
24 changes: 13 additions & 11 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def __init__(self, *pathsegments):
# TODO should we also forbid suffixes to avoid node names with dots in them?


def validate_name(name: str | None) -> None:
    """Check that ``name`` is acceptable as a node name.

    ``None`` is always accepted. Any other value must be a string that
    does not contain a forward slash.

    Raises
    ------
    TypeError
        If ``name`` is neither ``None`` nor a string.
    ValueError
        If ``name`` is a string containing ``"/"``.
    """
    if name is None:
        return
    if not isinstance(name, str):
        raise TypeError("node name must be a string or None")
    if "/" in name:
        raise ValueError("node names cannot contain forward slashes")


Tree = TypeVar("Tree", bound="TreeNode")


Expand Down Expand Up @@ -205,6 +213,8 @@ def _check_children(children: Mapping[str, Tree]) -> None:

seen = set()
for name, child in children.items():
validate_name(name)

if not isinstance(child, TreeNode):
raise TypeError(
f"Cannot add object {name}. It is of type {type(child)}, "
Expand Down Expand Up @@ -640,14 +650,6 @@ def same_tree(self, other: Tree) -> bool:
AnyNamedNode = TypeVar("AnyNamedNode", bound="NamedNode")


def _validate_name(name: str | None) -> None:
if name is not None:
if not isinstance(name, str):
raise TypeError("node name must be a string or None")
if "/" in name:
raise ValueError("node names cannot contain forward slashes")


class NamedNode(TreeNode, Generic[Tree]):
"""
A TreeNode which knows its own name.
Expand All @@ -661,7 +663,7 @@ class NamedNode(TreeNode, Generic[Tree]):

def __init__(self, name=None, children=None):
super().__init__(children=children)
_validate_name(name)
validate_name(name)
self._name = name

@property
Expand All @@ -677,7 +679,7 @@ def name(self, name: str | None) -> None:
"Consider creating a detached copy of this node via .copy() "
"on the parent node."
)
_validate_name(name)
validate_name(name)
self._name = name

def __repr__(self, level=0):
Expand All @@ -692,7 +694,7 @@ def __str__(self) -> str:

def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
"""Ensures child has name attribute corresponding to key under which it has been stored."""
_validate_name(name) # is this check redundant?
# no need to re-validate `name` here: it is the key under which this node is stored in its parent's children, and `_check_children` already validates those keys
self._name = name

def _copy_node(
Expand Down
6 changes: 6 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,12 @@ def test_inherited(self):
# expected = child.assign_coords({"c": 11})
# assert_identical(expected, actual)

def test_forbid_paths_as_names(self):
    """Regression test for GH issue #9485: a coordinate name containing a path is rejected."""
    root_ds = Dataset(coords={"x": 0})
    tree = DataTree(root_ds, children={"child": DataTree()})

    # assigning through `.coords` with a path-like key must raise
    with pytest.raises(ValueError, match="cannot have names containing"):
        tree.coords["/child/y"] = 2


def test_delitem():
ds = Dataset({"a": 0}, coords={"x": ("x", [1, 2]), "z": "a"})
Expand Down
39 changes: 39 additions & 0 deletions xarray/tests/test_treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,45 @@ def test_del_child(self):
del john["Mary"]


class TestValidNames:
    """Tests that node names and child keys are validated."""

    def test_child_keys(self):
        """Keys of the children mapping must be strings without forward slashes."""
        root: TreeNode = TreeNode()

        # a key containing a slash is rejected
        with pytest.raises(ValueError, match="cannot contain forward slashes"):
            root.children = {"a/b": TreeNode()}

        # a non-string key is rejected
        with pytest.raises(TypeError, match="must be a string or None"):
            root.children = {0: TreeNode()}  # type: ignore[dict-item]

    def test_node_names(self):
        """The ``name`` passed to the constructor is validated the same way."""
        with pytest.raises(ValueError, match="cannot contain forward slashes"):
            NamedNode(name="a/b")

        with pytest.raises(TypeError, match="must be a string or None"):
            NamedNode(name=0)

    def test_names(self):
        """Names can be read, set on detached nodes, and are taken from child keys."""
        # a node built without a name reports None
        unnamed: NamedNode = NamedNode()
        assert unnamed.name is None

        # a name given at construction can be read back and reassigned
        named: NamedNode = NamedNode(name="foo")
        assert named.name == "foo"

        named.name = "bar"
        assert named.name == "bar"

        # a child takes its name from the key it is stored under ...
        parent: NamedNode = NamedNode(children={"foo": NamedNode()})
        assert parent.children["foo"].name == "foo"

        # ... and cannot be renamed while it is attached to a parent
        expected_msg = "cannot set the name of a node which already has a parent"
        with pytest.raises(ValueError, match=expected_msg):
            parent.children["foo"].name = "bar"

        # a copy of the child is detached and may be renamed freely
        detached = parent.children["foo"].copy()
        assert detached.name == "foo"
        detached.name = "bar"
        assert detached.name == "bar"


def create_test_tree() -> tuple[NamedNode, NamedNode]:
# a
# ├── b
Expand Down
Loading