Skip to content

Commit

Permalink
update position validation
Browse files Browse the repository at this point in the history
  • Loading branch information
stephprince committed Sep 4, 2024
1 parent dd5da88 commit 33110b7
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/pynwb/ecephys.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import numpy as np
from collections.abc import Iterable

from hdmf.common import DynamicTableRegion
Expand Down Expand Up @@ -27,25 +28,29 @@ class ElectrodeGroup(NWBContainer):
{'name': 'device', 'type': Device, 'doc': 'the device that was used to record from this electrode group'},
{'name': 'position', 'type': 'array_data',
'doc': 'Compound dataset with stereotaxic position of this electrode group (x, y, z). '
'Each element of the data array must have three elements or the dtype of the '
'The data array must have three elements or the dtype of the '
'array must be ``(float, float, float)``', 'default': None})
def __init__(self, **kwargs):
args_to_set = popargs_to_dict(('description', 'location', 'device', 'position'), kwargs)
super().__init__(**kwargs)

# position is a compound dataset, i.e., this must be an array with a compound data type of three floats
# or is a list/tuple of three entries or a list with a single element with three entries
if args_to_set['position'] is not None and len(args_to_set['position']) > 0:
# If we have a dtype, then check that it is valid length
is_valid_dtype = not hasattr(args_to_set['position'], 'dtype') or len(args_to_set['position'].dtype) == 3

# check data is a valid shape
is_valid_shape = len(args_to_set['position']) == 3 or \
(len(args_to_set['position']) == 1 and len(args_to_set['position'][0]) == 3)

if not is_valid_dtype or not is_valid_shape:
raise ValueError('ElectrodeGroup position dataset must have three components (x, y, z) or [(x, y, z)] '
'but received: %s' % str(args_to_set['position']))
# position is a compound dataset, i.e., this must be a scalar with a
# compound data type of three floats or a list/tuple of three entries
position = args_to_set['position']
if position:
# check position argument is valid
position_dtype_invalid = (
(hasattr(position, 'dtype') and len(position.dtype) != 3) or
(not hasattr(position, 'dtype') and len(position) != 3) or
(len(np.shape(position)) > 1)
)
if position_dtype_invalid:
raise ValueError(f"ElectrodeGroup position argument must have three elements: x, y, z,"
f"but received: {position}")

# convert position to scalar with compound data type if needed
if not hasattr(position, 'dtype'):
args_to_set['position'] = np.array(tuple(position), dtype=[('x', float), ('y', float), ('z', float)])

for key, val in args_to_set.items():
setattr(self, key, val)
Expand Down

0 comments on commit 33110b7

Please sign in to comment.