Add tutorial for units and tests for writing unit region #12

Merged
merged 5 commits into from
Apr 3, 2024
74 changes: 74 additions & 0 deletions README.md
@@ -91,6 +91,80 @@ The following diagram illustrates the structure of the data for a concrete examp
</div>


### Linking to units table
One way to make the information stored in the `BinnedAlignedSpikes` object more useful is to indicate exactly which units or neurons the first dimension of the `data` attribute corresponds to. This is **optional but recommended**, as it makes the data easier to interpret and more useful for future users. In NWB, units are usually stored in a `Units` [table](https://pynwb.readthedocs.io/en/stable/pynwb.misc.html#pynwb.misc.Units). To illustrate how to create this link, let's first create a toy `Units` table:

```python
import numpy as np
from pynwb.misc import Units
from hdmf.common import DynamicTableRegion
from pynwb.testing.mock.file import mock_NWBFile

num_units = 5
max_spikes_per_unit = 10

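units_table = Units(name="units")
# Custom column holding a readable identifier for each unit
units_table.add_column(name="unit_name", description="a readable identifier for the unit")

rng = np.random.default_rng(seed=0)

# Cumulative sums of positive random numbers give increasing spike times
times = rng.random(size=(num_units, max_spikes_per_unit)).cumsum(axis=1)
spikes_per_unit = rng.integers(1, max_spikes_per_unit, size=num_units)

for unit_index in range(num_units):
    # Not all units have the same number of spikes
    spike_times = times[unit_index, : spikes_per_unit[unit_index]]
    unit_name = f"unit_{unit_index}"
    units_table.add_unit(spike_times=spike_times, unit_name=unit_name)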


# We then create a mock NWB file and add the units table
nwbfile = mock_NWBFile()
nwbfile.units = units_table
```

This creates a `Units` table with 5 units. We can then link the `BinnedAlignedSpikes` object to this table by creating a `DynamicTableRegion` object, which lets us specify exactly which units the data in the `BinnedAlignedSpikes` object corresponds to. The following code illustrates how to create the `DynamicTableRegion` object and link it to the `BinnedAlignedSpikes` object:
```python
from ndx_binned_spikes import BinnedAlignedSpikes

# One region index per unit in `data` (rows 0 and 1 of the units table)
region_indices = [0, 1]
units_region = DynamicTableRegion(
    data=region_indices, table=units_table, description="region of units table", name="units_region"
)

# Now we create the BinnedAlignedSpikes object and link it to the units table
data = np.array(
    [
        [  # Data of the first unit
            [5, 1, 3, 2],  # First timestamp bins
            [6, 3, 4, 3],  # Second timestamp bins
            [4, 2, 1, 4],  # Third timestamp bins
        ],
        [  # Data of the second unit
            [8, 4, 0, 2],  # First timestamp bins
            [3, 3, 4, 2],  # Second timestamp bins
            [2, 7, 4, 1],  # Third timestamp bins
        ],
    ],
)

event_timestamps = np.array([0.25, 5.0, 12.25])
milliseconds_from_event_to_first_bin = -50.0  # The first bin is 50 ms before the event
bin_width_in_milliseconds = 100.0
name = "BinnedAlignedSpikesForMyPurpose"
description = "Spike counts that are binned and aligned to events."
binned_aligned_spikes = BinnedAlignedSpikes(
    data=data,
    event_timestamps=event_timestamps,
    bin_width_in_milliseconds=bin_width_in_milliseconds,
    milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
    description=description,
    name=name,
    units_region=units_region,
)
```
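
Since each entry of the region is meant to describe one row along the first dimension of `data`, it can help to check that the two agree before building the object. The following is a minimal sketch using plain Python lists; `check_units_region` is a hypothetical helper written for illustration and is not part of `ndx_binned_spikes` or `pynwb`:

```python
def check_units_region(region_indices, number_of_data_units, number_of_table_rows):
    """Check that a region can describe the first dimension of `data`.

    Hypothetical helper for illustration only; not part of ndx_binned_spikes or pynwb.
    """
    # One region entry is expected per unit in `data`
    if len(region_indices) != number_of_data_units:
        raise ValueError(
            f"region has {len(region_indices)} entries but data has {number_of_data_units} units"
        )
    # Every index must point at an existing row of the units table
    out_of_range = [index for index in region_indices if not 0 <= index < number_of_table_rows]
    if out_of_range:
        raise ValueError(f"indices {out_of_range} are outside a table with {number_of_table_rows} rows")
    return True


# Two units in `data`, five rows in the units table
check_units_region([0, 1], number_of_data_units=2, number_of_table_rows=5)
```

A check like this is cheap to run right before constructing the `DynamicTableRegion`, and it fails early instead of leaving a mismatched link in the file.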

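The timing parameters determine where each bin sits relative to an event. As a rough sketch, assuming bins are contiguous and equally wide and that `event_timestamps` are expressed in seconds (an assumption based on the usual NWB convention, not something stated above), the start of each bin can be computed as follows. `bin_start_times_in_seconds` is a hypothetical helper, not part of the extension:

```python
def bin_start_times_in_seconds(
    event_timestamp,
    number_of_bins,
    bin_width_in_milliseconds,
    milliseconds_from_event_to_first_bin=0.0,
):
    """Return the start time (in seconds) of every bin for one event.

    Hypothetical helper for illustration; assumes timestamps are in seconds.
    """
    first_bin_start = event_timestamp + milliseconds_from_event_to_first_bin / 1000.0
    bin_width_in_seconds = bin_width_in_milliseconds / 1000.0
    return [first_bin_start + bin_index * bin_width_in_seconds for bin_index in range(number_of_bins)]


# Event at 5.0 s, four bins of 100 ms, first bin starting 50 ms before the event
starts = bin_start_times_in_seconds(
    event_timestamp=5.0,
    number_of_bins=4,
    bin_width_in_milliseconds=100.0,
    milliseconds_from_event_to_first_bin=-50.0,
)
```

With these values the first bin starts roughly 50 ms before the event (about 4.95 s) and each later bin starts 100 ms after the previous one.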

---
26 changes: 14 additions & 12 deletions spec/ndx-binned-spikes.extensions.yaml
@@ -23,23 +23,17 @@ groups:
A negative value indicates that the first bin is before the event whereas a positive
value indicates that the first bin is after the event.
required: false
- name: units
dtype:
target_type: DynamicTableRegion
reftype: region
doc: A reference to the Units table region that contains the units of the data.
required: false
datasets:
- name: data
dtype: numeric
dims:
- - num_units
- number_of_events
- number_of_bins
- num_units
- number_of_events
- number_of_bins
shape:
- - null
- null
- null
- null
- null
- null
doc: The binned data. It should be an array whose first dimension is the number
of units, the second dimension is the number of events, and the third dimension
is the number of bins.
@@ -50,3 +44,11 @@ groups:
shape:
- null
doc: The timestamps at which the events occurred.
- name: units_region
neurodata_type_inc: DynamicTableRegion
dims:
- number_of_units
shape:
- null
doc: A reference to the Units table region that contains the units of the data.
quantity: '?'
7 changes: 4 additions & 3 deletions src/pynwb/ndx_binned_spikes/__init__.py
@@ -4,6 +4,7 @@
from pynwb import register_class
from pynwb.core import NWBDataInterface
from hdmf.utils import docval
from hdmf.common import DynamicTableRegion

try:
from importlib.resources import files
@@ -34,7 +35,7 @@ class BinnedAlignedSpikes(NWBDataInterface):
"milliseconds_from_event_to_first_bin",
"data",
"event_timestamps",
"units",
{"name":"units_region", "child":True},
)

DEFAULT_NAME = "BinnedAlignedSpikes"
@@ -84,8 +85,8 @@ class BinnedAlignedSpikes(NWBDataInterface):
"shape": (None,),
},
{
"name": "units",
"type": "DynamicTableRegion",
"name": "units_region",
"type": DynamicTableRegion,
"doc": "A reference to the Units table region that contains the units of the data.",
"default": None,
},
39 changes: 38 additions & 1 deletion src/pynwb/ndx_binned_spikes/testing/mock.py
@@ -2,6 +2,9 @@

from ndx_binned_spikes import BinnedAlignedSpikes
import numpy as np
from pynwb import NWBFile
from pynwb.misc import Units
from hdmf.common import DynamicTableRegion


def mock_BinnedAlignedSpikes(
@@ -13,6 +16,7 @@ def mock_BinnedAlignedSpikes(
seed: int = 0,
event_timestamps: Optional[np.ndarray] = None,
data: Optional[np.ndarray] = None,
units_region: Optional[DynamicTableRegion] = None,
) -> "BinnedAlignedSpikes":
"""
Generate a mock BinnedAlignedSpikes object with specified parameters or from given data.
@@ -22,7 +26,7 @@ def mock_BinnedAlignedSpikes(
number_of_units : int, optional
The number of different units (channels, neurons, etc.) to simulate.
number_of_events : int, optional
The number of times an event is repeated.
The number of timestamps of the event that the data is aligned to.
number_of_bins : int, optional
The number of bins.
bin_width_in_milliseconds : float, optional
@@ -38,6 +42,8 @@ def mock_BinnedAlignedSpikes(
A 3D array of shape (number_of_units, number_of_events, number_of_bins) representing
the binned spike data. If provided, it overrides the generation of mock data based on other parameters.
Its shape should match the expected number of units, event repetitions, and bins.
units_region: DynamicTableRegion, optional
A reference to the Units table region that contains the units of the data.

Returns
-------
@@ -83,5 +89,36 @@ def mock_BinnedAlignedSpikes(
milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
data=data,
event_timestamps=event_timestamps,
units_region=units_region
)
return binned_aligned_spikes


#TODO: Remove once pynwb 2.7.0 is released and use the mock class there
def mock_Units(
num_units: int = 10,
max_spikes_per_unit: int = 10,
seed: int = 0,
nwbfile: Optional[NWBFile] = None,
) -> Units:

units_table = Units(name="units") # This is for nwbfile.units= mock_Units() to work
units_table.add_column(name="unit_name", description="a readable identifier for the unit")

rng = np.random.default_rng(seed=seed)

times = rng.random(size=(num_units, max_spikes_per_unit)).cumsum(axis=1)
spikes_per_unit = rng.integers(1, max_spikes_per_unit, size=num_units)

spike_times = []
for unit_index in range(num_units):

# Not all units have the same number of spikes
spike_times = times[unit_index, : spikes_per_unit[unit_index]]
unit_name = f"unit_{unit_index}"
units_table.add_unit(spike_times=spike_times, unit_name=unit_name)

if nwbfile is not None:
nwbfile.units = units_table

return units_table
44 changes: 35 additions & 9 deletions src/pynwb/tests/test_binned_aligned_spikes.py
@@ -8,9 +8,10 @@
from pynwb import NWBHDF5IO
from pynwb.testing.mock.file import mock_NWBFile
from pynwb.testing import TestCase, remove_test_file

from hdmf.common import DynamicTableRegion
from pynwb.misc import Units
from ndx_binned_spikes import BinnedAlignedSpikes
from ndx_binned_spikes.testing.mock import mock_BinnedAlignedSpikes
from ndx_binned_spikes.testing.mock import mock_BinnedAlignedSpikes, mock_Units


class TestBinnedAlignedSpikesConstructor(TestCase):
@@ -62,8 +63,7 @@ def test_constructor(self):
self.assertEqual(binned_aligned_spikes.data.shape[2], self.number_of_bins)

def test_constructor_units_region(self):
from pynwb.misc import Units
from hdmf.common import DynamicTableRegion


units_table = Units()
units_table.add_column(name="unit_name", description="a readable identifier for the units")
@@ -90,11 +90,11 @@ def test_constructor_units_region(self):
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
event_timestamps=self.event_timestamps,
units=units_region,
units_region=units_region,
)

unit_table_indices = binned_aligned_spikes.units.data
unit_table_names = binned_aligned_spikes.units.table["unit_name"][unit_table_indices]
unit_table_indices = binned_aligned_spikes.units_region.data
unit_table_names = binned_aligned_spikes.units_region.table["unit_name"][unit_table_indices]

expected_names = [unit_name_a, unit_name_c]
self.assertListEqual(unit_table_names, expected_names)
@@ -109,14 +109,14 @@ def test_constructor_inconsistent_timestamps_and_data_error(self):
data=self.data,
event_timestamps=shorter_timestamps,
)


class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase):
"""Simple roundtrip test for BinnedAlignedSpikes."""

def setUp(self):
self.nwbfile = mock_NWBFile()

self.binned_aligned_spikes = mock_BinnedAlignedSpikes()

self.path = "test.nwb"

@@ -128,6 +128,7 @@ def test_roundtrip_acquisition(self):
Add a BinnedAlignedSpikes to an NWBFile, write it to file, read the file
and test that the BinnedAlignedSpikes from the file matches the original BinnedAlignedSpikes.
"""
self.binned_aligned_spikes = mock_BinnedAlignedSpikes()

self.nwbfile.add_acquisition(self.binned_aligned_spikes)

@@ -136,9 +137,12 @@ def test_roundtrip_acquisition(self):

with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io:
read_nwbfile = io.read()
self.assertContainerEqual(self.binned_aligned_spikes, read_nwbfile.acquisition["BinnedAlignedSpikes"])
read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"]
self.assertContainerEqual(self.binned_aligned_spikes, read_container)

def test_roundtrip_processing_module(self):
self.binned_aligned_spikes = mock_BinnedAlignedSpikes()

ecephys_processing_module = self.nwbfile.create_processing_module(name="ecephys", description="a description")
ecephys_processing_module.add(self.binned_aligned_spikes)

@@ -149,3 +153,25 @@ def test_roundtrip_processing_module(self):
read_nwbfile = io.read()
read_container = read_nwbfile.processing["ecephys"]["BinnedAlignedSpikes"]
self.assertContainerEqual(self.binned_aligned_spikes, read_container)

def test_roundtrip_with_units_table(self):

units = mock_Units(num_units=3)
self.nwbfile.units = units
region_indices = [0, 2]  # valid rows of the 3-unit table are 0, 1 and 2
units_region = DynamicTableRegion(
data=region_indices, table=units, description="region of units table", name="units_region"
)

binned_aligned_spikes_with_region = mock_BinnedAlignedSpikes(units_region=units_region)
self.nwbfile.add_acquisition(binned_aligned_spikes_with_region)


with NWBHDF5IO(self.path, mode="w") as io:
io.write(self.nwbfile)

with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io:
read_nwbfile = io.read()
read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"]
self.assertContainerEqual(binned_aligned_spikes_with_region, read_container)

34 changes: 19 additions & 15 deletions src/spec/create_extension_spec.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import os.path

from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, NWBAttributeSpec, NWBRefSpec, NWBDatasetSpec
from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, NWBAttributeSpec, NWBDatasetSpec


def main():
@@ -32,24 +32,34 @@ def main():
"is the number of events, and the third dimension is the number of bins."
),
dtype="numeric", # TODO should this be a uint64?
shape=[(None, None, None)],
dims=[("num_units", "number_of_events", "number_of_bins")],
shape=[None, None, None],
dims=["num_units", "number_of_events", "number_of_bins"],
)

event_timestamps = NWBDatasetSpec(
name="event_timestamps",
doc="The timestamps at which the events occurred.",
dtype="float64",
shape=(None,),
dims=("number_of_events",),
shape=[None],
dims=["number_of_events"],
)


units_region = NWBDatasetSpec(
name="units_region",
neurodata_type_inc="DynamicTableRegion",
doc="A reference to the Units table region that contains the units of the data.",
shape=[None],
dims=["number_of_units"],
quantity="?",

)

binned_aligned_spikes = NWBGroupSpec(
neurodata_type_def="BinnedAlignedSpikes",
neurodata_type_inc="NWBDataInterface",
default_name="BinnedAlignedSpikes",
doc="A data interface for binned spike data aligned to an event (e.g. a stimulus or the beginning of a trial).",
datasets=[binned_aligned_spikes_data, event_timestamps],
datasets=[binned_aligned_spikes_data, event_timestamps, units_region],
attributes=[
NWBAttributeSpec(
name="name",
@@ -74,16 +84,10 @@ def main():
"The time in milliseconds from the event to the beginning of the first bin. A negative value indicates "
"that the first bin is before the event whereas a positive value indicates that the first bin is "
"after the event."
),
),
dtype="float64",
default_value=0.0,
),
NWBAttributeSpec(
name="units",
doc="A reference to the Units table region that contains the units of the data.",
required=False,
dtype=NWBRefSpec(target_type="DynamicTableRegion", reftype="region"),
),
)
],
)
