Skip to content

Commit

Permalink
Merge pull request #6 from catalystneuro/reference_units_table_region
Browse files Browse the repository at this point in the history
Make units a `DynamicTableRegion`
  • Loading branch information
h-mayorquin authored Mar 20, 2024
2 parents 34724b6 + 618adf0 commit 5f1f9cc
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 49 deletions.
6 changes: 3 additions & 3 deletions spec/ndx-binned-spikes.extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ groups:
required: false
- name: units
dtype:
target_type: Units
reftype: object
doc: A link to the units Table that contains the units of the data.
target_type: DynamicTableRegion
reftype: region
doc: A reference to the Units table region that contains the units of the data.
required: false
datasets:
- name: data
Expand Down
13 changes: 7 additions & 6 deletions src/pynwb/ndx_binned_spikes/testing/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ndx_binned_spikes import BinnedAlignedSpikes
import numpy as np


def mock_BinnedAlignedSpikes(
number_of_units: int = 2,
number_of_event_repetitions: int = 4,
Expand All @@ -12,7 +13,7 @@ def mock_BinnedAlignedSpikes(
seed: int = 0,
event_timestamps: Optional[np.ndarray] = None,
data: Optional[np.ndarray] = None,
) -> 'BinnedAlignedSpikes':
) -> "BinnedAlignedSpikes":
"""
Generate a mock BinnedAlignedSpikes object with specified parameters or from given data.
Expand Down Expand Up @@ -65,13 +66,13 @@ def mock_BinnedAlignedSpikes(
else:
rng = np.random.default_rng(seed=seed)
data = rng.integers(low=0, high=100, size=(number_of_units, number_of_event_repetitions, number_of_bins))

if event_timestamps is None:
event_timestamps = np.arange(number_of_event_repetitions, dtype="float64")
else:
assert event_timestamps.shape[0] == number_of_event_repetitions, (
"The shape of `event_timestamps` does not match `number_of_event_repetitions`."
)
assert (
event_timestamps.shape[0] == number_of_event_repetitions
), "The shape of `event_timestamps` does not match `number_of_event_repetitions`."
event_timestamps = np.array(event_timestamps, dtype="float64")

if event_timestamps.shape[0] != data.shape[1]:
Expand All @@ -83,4 +84,4 @@ def mock_BinnedAlignedSpikes(
data=data,
event_timestamps=event_timestamps,
)
return binned_aligned_spikes
return binned_aligned_spikes
104 changes: 72 additions & 32 deletions src/pynwb/tests/test_binned_aligned_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,46 +18,90 @@ class TestBinnedAlignedSpikesConstructor(TestCase):

def setUp(self):
"""Set up an NWB file. Necessary because BinnedAlignedSpikes requires references to electrodes."""

self.number_of_units = 2
self.number_of_bins = 3
self.number_of_event_repetitions = 4
self.bin_width_in_milliseconds = 20.0
self.milliseconds_from_event_to_first_bin = -100.0
rng = np.random.default_rng(seed=0)

self.data = rng.integers(
low=0,
high=100,
size=(
self.number_of_units,
self.number_of_event_repetitions,
self.number_of_bins,
),
)

self.event_timestamps = np.arange(self.number_of_event_repetitions, dtype="float64")

self.nwbfile = mock_NWBFile()

def test_constructor(self):
"""Test that the constructor for BinnedAlignedSpikes sets values as expected."""

number_of_units = 2
number_of_bins = 3
number_of_event_repetitions = 4
bin_width_in_milliseconds = 20.0
milliseconds_from_event_to_first_bin = 1.0

rng = np.random.default_rng(seed=0)
data = rng.integers(low=0, high=100, size=(number_of_units, number_of_event_repetitions, number_of_bins))
event_timestamps = np.arange(number_of_event_repetitions, dtype="float64")

binned_aligned_spikes = BinnedAlignedSpikes(
bin_width_in_milliseconds=bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
data=data,
event_timestamps=event_timestamps
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
event_timestamps=self.event_timestamps,
)
np.testing.assert_array_equal(binned_aligned_spikes.data, data)
np.testing.assert_array_equal(binned_aligned_spikes.event_timestamps, event_timestamps)
self.assertEqual(binned_aligned_spikes.bin_width_in_milliseconds, bin_width_in_milliseconds)

np.testing.assert_array_equal(binned_aligned_spikes.data, self.data)
np.testing.assert_array_equal(binned_aligned_spikes.event_timestamps, self.event_timestamps)
self.assertEqual(binned_aligned_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
self.assertEqual(
binned_aligned_spikes.milliseconds_from_event_to_first_bin, milliseconds_from_event_to_first_bin
binned_aligned_spikes.milliseconds_from_event_to_first_bin, self.milliseconds_from_event_to_first_bin
)

self.assertEqual(binned_aligned_spikes.data.shape[0], number_of_units)
self.assertEqual(binned_aligned_spikes.data.shape[1], number_of_event_repetitions)
self.assertEqual(binned_aligned_spikes.data.shape[2], number_of_bins)

self.assertEqual(binned_aligned_spikes.data.shape[0], self.number_of_units)
self.assertEqual(binned_aligned_spikes.data.shape[1], self.number_of_event_repetitions)
self.assertEqual(binned_aligned_spikes.data.shape[2], self.number_of_bins)

def test_constructor_units_region(self):
from pynwb.misc import Units
from hdmf.common import DynamicTableRegion

units_table = Units()
units_table.add_column(name="unit_name", description="a readable identifier for the units")

class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase):
"""Simple roundtrip test for BinnedAlignedSpikes."""
unit_name_a = "a"
spike_times_a = [1.1, 2.2, 3.3]
units_table.add_row(spike_times=spike_times_a, unit_name=unit_name_a)

unit_name_b = "b"
spike_times_b = [4.4, 5.5, 6.6]
units_table.add_row(spike_times=spike_times_b, unit_name=unit_name_b)

unit_name_c = "c"
spike_times_c = [7.7, 8.8, 9.9]
units_table.add_row(spike_times=spike_times_c, unit_name=unit_name_c)

region_indices = [0, 2]
units_region = DynamicTableRegion(
data=region_indices, table=units_table, description="region of units table", name="units_region"
)

binned_aligned_spikes = BinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
event_timestamps=self.event_timestamps,
units=units_region,
)

unit_table_indices = binned_aligned_spikes.units.data
unit_table_names = binned_aligned_spikes.units.table["unit_name"][unit_table_indices]

expected_names = [unit_name_a, unit_name_c]
self.assertListEqual(unit_table_names, expected_names)


class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase):
"""Simple roundtrip test for BinnedAlignedSpikes."""

def setUp(self):
self.nwbfile = mock_NWBFile()
Expand Down Expand Up @@ -85,17 +129,13 @@ def test_roundtrip_acquisition(self):
self.assertContainerEqual(self.binned_aligned_spikes, read_nwbfile.acquisition["BinnedAlignedSpikes"])

def test_roundtrip_processing_module(self):


ecephys_processinng_module = self.nwbfile.create_processing_module(
name="ecephys", description="a description"
)
ecephys_processinng_module = self.nwbfile.create_processing_module(name="ecephys", description="a description")
ecephys_processinng_module.add(self.binned_aligned_spikes)

with NWBHDF5IO(self.path, mode="w") as io:
io.write(self.nwbfile)

with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io:
read_nwbfile = io.read()
read_container = read_nwbfile.processing["ecephys"]["BinnedAlignedSpikes"]
self.assertContainerEqual(self.binned_aligned_spikes, read_container)
self.assertContainerEqual(self.binned_aligned_spikes, read_container)
12 changes: 4 additions & 8 deletions src/spec/create_extension_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, NWBAttributeSpec, NWBRefSpec, NWBDatasetSpec


def main():
# these arguments were auto-generated from your cookiecutter inputs
ns_builder = NWBNamespaceBuilder(
Expand All @@ -24,19 +25,14 @@ def main():
# of the other extension below
# ns_builder.include_namespace("ndx-other-extension")

# TODO: define your new data types
# see https://pynwb.readthedocs.io/en/stable/tutorials/general/extensions.html
# for more information


binned_aligned_spikes_data = NWBDatasetSpec(
name="data",
doc="TODO",
dtype="numeric", # TODO should this be a uint64?
shape=[(None, None, None)],
dims=[("num_units", "number_of_event_repetitions", "number_of_bins")],
)

event_timestamps = NWBDatasetSpec(
name="event_timestamps",
doc="The timestamps at which the event occurred.",
Expand Down Expand Up @@ -68,9 +64,9 @@ def main():
),
NWBAttributeSpec(
name="units",
doc="A link to the Units table that contains the units of the data.",
doc="A reference to the Units table region that contains the units of the data.",
required=False,
dtype=NWBRefSpec(target_type="Units", reftype="object"),
dtype=NWBRefSpec(target_type="DynamicTableRegion", reftype="region"),
),
],
)
Expand Down

0 comments on commit 5f1f9cc

Please sign in to comment.