Skip to content

Commit

Permalink
Fix resolution of extension classes that have references (#1183)
Browse files Browse the repository at this point in the history
* Fix resolution of extension classes that have references

* Update changelog

* Remove unnecessary if

* Update CHANGELOG.md

Co-authored-by: Oliver Ruebel <[email protected]>

---------

Co-authored-by: Oliver Ruebel <[email protected]>
  • Loading branch information
rly and oruebel authored Aug 30, 2024
1 parent d378dec commit 1fc6212
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
### Bug fixes
- Fixed issue where scalar datasets with a compound data type were being written as non-scalar datasets @stephprince [#1176](https://github.com/hdmf-dev/hdmf/pull/1176)
- Fixed H5DataIO not exposing `maxshape` on non-dci dsets. @cboulay [#1149](https://github.com/hdmf-dev/hdmf/pull/1149)
- Fixed generation of classes in an extension that contain attributes or datasets storing references to other types defined in the extension.
@rly [#1183](https://github.com/hdmf-dev/hdmf/pull/1183)

## HDMF 3.14.3 (July 29, 2024)

Expand Down
17 changes: 15 additions & 2 deletions src/hdmf/build/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator
from ..container import AbstractContainer, Container, Data
from ..term_set import TypeConfigurator
from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog
from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog, RefSpec
from ..spec.spec import BaseStorageSpec
from ..utils import docval, getargs, ExtenderMeta, get_docval

Expand Down Expand Up @@ -480,6 +480,7 @@ def load_namespaces(self, **kwargs):
load_namespaces here has the advantage of being able to keep track of type dependencies across namespaces.
'''
deps = self.__ns_catalog.load_namespaces(**kwargs)
# register container types for each dependent type in each dependent namespace
for new_ns, ns_deps in deps.items():
for src_ns, types in ns_deps.items():
for dt in types:
Expand Down Expand Up @@ -529,7 +530,7 @@ def get_dt_container_cls(self, **kwargs):
namespace = ns_key
break
if namespace is None:
raise ValueError("Namespace could not be resolved.")
raise ValueError(f"Namespace could not be resolved for data type '{data_type}'.")

cls = self.__get_container_cls(namespace, data_type)

Expand All @@ -549,6 +550,8 @@ def get_dt_container_cls(self, **kwargs):

def __check_dependent_types(self, spec, namespace):
"""Ensure that classes for all types used by this type exist in this namespace and generate them if not.
`spec` should be a GroupSpec or DatasetSpec in the `namespace`
"""
def __check_dependent_types_helper(spec, namespace):
if isinstance(spec, (GroupSpec, DatasetSpec)):
Expand All @@ -564,6 +567,16 @@ def __check_dependent_types_helper(spec, namespace):

if spec.data_type_inc is not None:
self.get_dt_container_cls(spec.data_type_inc, namespace)

# handle attributes that have a reference dtype
for attr_spec in spec.attributes:
if isinstance(attr_spec.dtype, RefSpec):
self.get_dt_container_cls(attr_spec.dtype.target_type, namespace)
# handle datasets that have a reference dtype
if isinstance(spec, DatasetSpec):
if isinstance(spec.dtype, RefSpec):
self.get_dt_container_cls(spec.dtype.target_type, namespace)
# recurse into nested types
if isinstance(spec, GroupSpec):
for child_spec in (spec.groups + spec.datasets + spec.links):
__check_dependent_types_helper(child_spec, namespace)
Expand Down
180 changes: 178 additions & 2 deletions tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from hdmf.build import TypeMap, CustomClassGenerator
from hdmf.build.classgenerator import ClassGenerator, MCIClassGenerator
from hdmf.container import Container, Data, MultiContainerInterface, AbstractContainer
from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec
from hdmf.spec import (
GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec, RefSpec
)
from hdmf.testing import TestCase
from hdmf.utils import get_docval, docval

Expand Down Expand Up @@ -734,9 +736,18 @@ def _build_separate_namespaces(self):
GroupSpec(data_type_inc='Bar', doc='a bar', quantity='?')
]
)
moo_spec = DatasetSpec(
doc='A test dataset that is a 1D array of object references of Baz',
data_type_def='Moo',
shape=(None,),
dtype=RefSpec(
reftype='object',
target_type='Baz'
)
)
create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[baz_spec],
specs=[baz_spec, moo_spec],
output_dir=self.test_dir,
incl_types={
CORE_NAMESPACE: ['Bar'],
Expand Down Expand Up @@ -828,6 +839,171 @@ def test_get_class_include_from_separate_ns_4(self):

self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2)

class TestGetClassObjectReferences(TestCase):

def setUp(self):
self.test_dir = tempfile.mkdtemp()
if os.path.exists(self.test_dir): # start clean
self.tearDown()
os.mkdir(self.test_dir)
self.type_map = TypeMap()

def tearDown(self):
shutil.rmtree(self.test_dir)

def test_get_class_include_dataset_of_references(self):
"""Test that get_class resolves datasets of object references."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
moo_spec = DatasetSpec(
doc='A test dataset that is a 1D array of object references of Qux',
data_type_def='Moo',
shape=(None,),
dtype=RefSpec(
reftype='object',
target_type='Qux'
),
)

create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, moo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Moo', 'ndx-test')
# now, Moo and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 2
assert "Moo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]

def test_get_class_include_attribute_object_reference(self):
"""Test that get_class resolves data types with an attribute that is an object reference."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
woo_spec = DatasetSpec(
doc='A test dataset that has a scalar object reference to a Qux',
data_type_def='Woo',
attributes=[
AttributeSpec(
name='attr1',
doc='a string attribute',
dtype=RefSpec(reftype='object', target_type='Qux')
),
]
)
create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, woo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Woo', 'ndx-test')
# now, Woo and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 2
assert "Woo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]

def test_get_class_include_nested_object_reference(self):
"""Test that get_class resolves nested datasets that are object references."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
spam_spec = DatasetSpec(
doc='A test extension',
data_type_def='Spam',
shape=(None,),
dtype=RefSpec(
reftype='object',
target_type='Qux'
),
)
goo_spec = GroupSpec(
doc='A test dataset that has a nested dataset (Spam) that has a scalar object reference to a Qux',
data_type_def='Goo',
datasets=[
DatasetSpec(
doc='a dataset',
data_type_inc='Spam',
),
],
)

create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, spam_spec, goo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Goo', 'ndx-test')
# now, Goo, Spam, and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 3
assert "Goo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Spam" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]

def test_get_class_include_nested_attribute_object_reference(self):
"""Test that get_class resolves nested datasets that have an attribute that is an object reference."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
bam_spec = DatasetSpec(
doc='A test extension',
data_type_def='Bam',
attributes=[
AttributeSpec(
name='attr1',
doc='a string attribute',
dtype=RefSpec(reftype='object', target_type='Qux')
),
],
)
boo_spec = GroupSpec(
doc='A test dataset that has a nested dataset (Spam) that has a scalar object reference to a Qux',
data_type_def='Boo',
datasets=[
DatasetSpec(
doc='a dataset',
data_type_inc='Bam',
),
],
)

create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, bam_spec, boo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Boo', 'ndx-test')
# now, Boo, Bam, and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 3
assert "Boo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Bam" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]


class EmptyBar(Container):
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/build_tests/test_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_get_dt_container_cls(self):
self.assertIs(ret, Foo)

def test_get_dt_container_cls_no_namespace(self):
with self.assertRaisesWith(ValueError, "Namespace could not be resolved."):
with self.assertRaisesWith(ValueError, "Namespace could not be resolved for data type 'Unknown'."):
self.type_map.get_dt_container_cls(data_type="Unknown")


Expand Down

0 comments on commit 1fc6212

Please sign in to comment.