Add structure for aggregated data AggregatedBinnedAlignedSpikes (#16)
* test aggregated binned aligned spikes

* add method for getting the data corresponding to single event

* ruff

* mistake

* add timestamps

* event_timestamps to timestamps

* add round trip test

* add method for sorting

* add auxiliar method, not sort by default

* documentation first draft

* readme review
h-mayorquin authored Aug 21, 2024
1 parent 78fa053 commit 3247ad5
Showing 7 changed files with 717 additions and 16 deletions.
130 changes: 130 additions & 0 deletions README.md
@@ -179,5 +179,135 @@ binned_aligned_spikes = BinnedAlignedSpikes(

As with the previous example this can be then added to a processing module in an NWB file and written to disk using exactly the same code as before.

### Storing data from multiple events together
In experiments where multiple stimuli are presented to a subject within a single session, it is often useful to store the aggregated spike counts from all events in a single object. For such cases, the `AggregatedBinnedAlignedSpikes` object is ideal. This object functions similarly to the `BinnedAlignedSpikes` object but is designed to store data from multiple events (e.g., different stimuli) together.

Since events may not occur the same number of times, a homogeneous data structure is not possible. The `AggregatedBinnedAlignedSpikes` object therefore includes an additional variable, `event_indices`, to indicate which event each set of counts corresponds to. You can create this object as follows:


```python
from ndx_binned_spikes import AggregatedBinnedAlignedSpikes

aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
data=data, # Shape (number_of_units, aggregated_events_counts, number_of_bins)
timestamps=timestamps, # As many timestamps as the second dimension of data
event_indices=event_indices, # An index indicating which event each of the counts corresponds to
)
```

The `aggregated_events_counts` value represents the total number of repetitions across all the events being aggregated. For example, if data is aggregated from two stimuli where the first stimulus appeared twice and the second appeared three times, the `aggregated_events_counts` would be 5.
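As a small sketch (with hypothetical values matching the two-stimulus scenario just described), the `event_indices` vector and the resulting `aggregated_events_counts` could be built like this:

```python
import numpy as np

# Hypothetical example: the first stimulus appeared twice, the second three times
event_indices = np.concatenate([np.zeros(2, dtype=int), np.ones(3, dtype=int)])

# The total number of repetitions across all aggregated events
aggregated_events_counts = event_indices.size  # 5
```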

The `event_indices` is an indicator vector that should be constructed so that `data[:, event_indices == event_index, :]` corresponds to the binned spike counts around the event with the specified `event_index`. You can retrieve the same data using the convenience method `aggregated_binned_aligned_spikes.get_data_for_stimuli(event_index=event_index)`.
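To illustrate this slicing with toy values (not part of the API, just plain NumPy), note how the boolean mask selects only the repetitions of a given event along the second axis:

```python
import numpy as np

# Toy array: 2 units, 5 aggregated events (2 + 3), 4 bins
data = np.arange(2 * 5 * 4).reshape(2, 5, 4)
event_indices = np.array([0, 0, 1, 1, 1])

# Binned counts for the first event only: two repetitions survive the mask
counts_event_0 = data[:, event_indices == 0, :]
assert counts_event_0.shape == (2, 2, 4)
```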

It's important to note that the timestamps must be in ascending order and must correspond positionally to the event indices and to the second dimension of the data; otherwise, a `ValueError` is raised. To help organize the data correctly, you can use the convenience static method `AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(data=data, timestamps=timestamps, event_indices=event_indices)`, which returns the data properly sorted. Here's how it can be used:

```python
sorted_data, sorted_timestamps, sorted_event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(data=data, timestamps=timestamps, event_indices=event_indices)

aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
data=sorted_data,
timestamps=sorted_timestamps,
event_indices=sorted_event_indices,
)
```

The same result can be achieved with plain NumPy:

```python
import numpy as np

sorted_indices = np.argsort(timestamps)
sorted_data = data[:, sorted_indices, :]
sorted_timestamps = timestamps[sorted_indices]
sorted_event_indices = event_indices[sorted_indices]

aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
data=sorted_data,
timestamps=sorted_timestamps,
event_indices=sorted_event_indices,
)
```

#### Example of building an `AggregatedBinnedAlignedSpikes` object from scratch

To better understand how this object works, let's consider a specific example. Suppose we have data for two distinct events (such as two different stimuli) and their associated timestamps, similar to the `BinnedAlignedSpikes` examples mentioned earlier:

```python
import numpy as np

# Two units and 4 bins
data_for_first_event = np.array(
[
# Unit 1
[
[0, 1, 2, 3], # Bin counts around the first timestamp
[4, 5, 6, 7], # Bin counts around the second timestamp
],
# Unit 2
[
[8, 9, 10, 11], # Bin counts around the first timestamp
[12, 13, 14, 15], # Bin counts around the second timestamp
],
],
)

# Also two units and 4 bins but this event appeared three times
data_for_second_event = np.array(
[
# Unit 1
[
[0, 1, 2, 3], # Bin counts around the first timestamp
[4, 5, 6, 7], # Bin counts around the second timestamp
[8, 9, 10, 11], # Bin counts around the third timestamp
],
# Unit 2
[
[12, 13, 14, 15], # Bin counts around the first timestamp
[16, 17, 18, 19], # Bin counts around the second timestamp
[20, 21, 22, 23], # Bin counts around the third timestamp
],
]
)

timestamps_first_event = [5.0, 15.0]
timestamps_second_event = [1.0, 10.0, 20.0]
```

The way that we would build the data for the `AggregatedBinnedAlignedSpikes` object is as follows:

```python
from ndx_binned_spikes import AggregatedBinnedAlignedSpikes

bin_width_in_milliseconds = 100.0
milliseconds_from_event_to_first_bin = -50.0

data = np.concatenate([data_for_first_event, data_for_second_event], axis=1)
timestamps = np.concatenate([timestamps_first_event, timestamps_second_event])
event_indices = np.concatenate([np.zeros(2, dtype=int), np.ones(3, dtype=int)])

sorted_data, sorted_timestamps, sorted_event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(data=data, timestamps=timestamps, event_indices=event_indices)

aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
data=sorted_data,
timestamps=sorted_timestamps,
event_indices=sorted_event_indices,
)
```

Then we can recover the original data by calling the `get_data_for_stimuli` method:

```python
retrieved_data_for_first_event = aggregated_binned_aligned_spikes.get_data_for_stimuli(event_index=0)
np.testing.assert_array_equal(retrieved_data_for_first_event, data_for_first_event)
```
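The timestamps of each event can be recovered analogously with the `get_timestamps_for_stimuli` method. In plain NumPy terms this is just a masked selection; the sketch below repeats the sorted values from the example above so it is self-contained:

```python
import numpy as np

# Sorted values from the example above, repeated here for self-containment
sorted_timestamps = np.array([1.0, 5.0, 10.0, 15.0, 20.0])
sorted_event_indices = np.array([1, 0, 1, 0, 1])

# Equivalent of get_timestamps_for_stimuli(event_index=1)
timestamps_second_event = sorted_timestamps[sorted_event_indices == 1]
np.testing.assert_array_equal(timestamps_second_event, [1.0, 10.0, 20.0])
```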

The `AggregatedBinnedAlignedSpikes` object can be added to a processing module in an NWB file and written to disk using the same code as before. In addition, a region of the `Units` table can be linked to the `AggregatedBinnedAlignedSpikes` object in the same way as for the `BinnedAlignedSpikes` object.

---
This extension was created using [ndx-template](https://github.com/nwb-extensions/ndx-template).
58 changes: 58 additions & 0 deletions spec/ndx-binned-spikes.extensions.yaml
@@ -48,3 +48,61 @@ groups:
neurodata_type_inc: DynamicTableRegion
doc: A reference to the Units table region that contains the units of the data.
quantity: '?'
- neurodata_type_def: AggregatedBinnedAlignedSpikes
neurodata_type_inc: NWBDataInterface
default_name: AggregatedBinnedAlignedSpikes
doc: A data interface for aggregated binned spike data aligned to multiple events.
The data for all the events is concatenated along the second dimension and a second
array, event_indices, is used to keep track of which event each row of the data
corresponds to.
attributes:
- name: name
dtype: text
value: AggregatedBinnedAlignedSpikes
doc: The name of this container
- name: description
dtype: text
value: Spikes data binned and aligned to the timestamps of multiple events.
doc: A description of what the data represents
- name: bin_width_in_milliseconds
dtype: float64
doc: The length in milliseconds of the bins
- name: milliseconds_from_event_to_first_bin
dtype: float64
default_value: 0.0
doc: The time in milliseconds from the event to the beginning of the first bin.
A negative value indicates that the first bin is before the event whereas a positive
value indicates that the first bin is after the event.
required: false
datasets:
- name: data
dtype: numeric
dims:
- num_units
- number_of_events
- number_of_bins
shape:
- null
- null
- null
doc: The binned data. It should be an array whose first dimension is the number
of units, the second dimension is the total number of events of all stimuli,
and the third dimension is the number of bins.
- name: event_indices
dtype: int64
dims:
- number_of_events
shape:
- null
doc: The index of the event that each element along the second dimension of the data corresponds to.
- name: timestamps
dtype: float64
dims:
- number_of_events
shape:
- null
doc: The timestamps at which the events occurred.
- name: units_region
neurodata_type_inc: DynamicTableRegion
doc: A reference to the Units table region that contains the units of the data.
quantity: '?'
146 changes: 136 additions & 10 deletions src/pynwb/ndx_binned_spikes/__init__.py
@@ -1,16 +1,14 @@
import os

import numpy as np
from typing import Tuple
from pynwb import load_namespaces, get_class
from pynwb import register_class
from pynwb.core import NWBDataInterface
from hdmf.utils import docval
from hdmf.common import DynamicTableRegion

from importlib.resources import files


# Get path to the namespace.yaml file with the expected location when installed not in editable mode
__location_of_this_file = files(__name__)
@@ -35,7 +33,7 @@ class BinnedAlignedSpikes(NWBDataInterface):
"milliseconds_from_event_to_first_bin",
"data",
"event_timestamps",
{"name": "units_region", "child": True},
)

DEFAULT_NAME = "BinnedAlignedSpikes"
@@ -65,7 +63,7 @@ class BinnedAlignedSpikes(NWBDataInterface):
"doc": (
"The time in milliseconds from the event to the beginning of the first bin. A negative value indicates "
"that the first bin is before the event whereas a positive value indicates that the first bin is "
"after the event."
),
"default": 0.0,
},
@@ -100,15 +98,143 @@ def __init__(self, **kwargs):
raise ValueError("The number of event timestamps must match the number of event repetitions in the data.")

name = kwargs.pop("name")
super().__init__(name=name)

for key in kwargs:
setattr(self, key, kwargs[key])



@register_class(neurodata_type="AggregatedBinnedAlignedSpikes", namespace="ndx-binned-spikes") # noqa
class AggregatedBinnedAlignedSpikes(NWBDataInterface):
__nwbfields__ = (
"name",
"description",
"bin_width_in_milliseconds",
"milliseconds_from_event_to_first_bin",
"data",
"timestamps",
"event_indices",
{"name": "units_region", "child": True},  # child=True so the DynamicTableRegion is stored as a child of this container
)

DEFAULT_NAME = "AggregatedBinnedAlignedSpikes"
DEFAULT_DESCRIPTION = "Spikes data binned and aligned to the timestamps of multiple events."

@docval(
{
"name": "name",
"type": str,
"doc": "The name of this container",
"default": DEFAULT_NAME,
},
{
"name": "description",
"type": str,
"doc": "A description of what the data represents",
"default": DEFAULT_DESCRIPTION,
},
{
"name": "bin_width_in_milliseconds",
"type": float,
"doc": "The length in milliseconds of the bins",
},
{
"name": "milliseconds_from_event_to_first_bin",
"type": float,
"doc": (
"The time in milliseconds from the event to the beginning of the first bin. A negative value indicates "
"that the first bin is before the event whereas a positive value indicates that the first bin is "
"after the event."
),
"default": 0.0,
},
{
"name": "data",
"type": "array_data",
"shape": [(None, None, None)],
"doc": (
"The binned data. It should be an array whose first dimension is the number of units, "
"the second dimension is the number of events, and the third dimension is the number of bins."
),
},
{
"name": "timestamps",
"type": "array_data",
"doc": (
"The timestamps at which the events occurred. It is assumed that they map positionally to "
"the second index of the data."
),
"shape": (None,),
},
{
"name": "event_indices",
"type": "array_data",
"doc": "The index of the event that each element of the second dimension of the data corresponds to.",
"shape": (None,),
},
{
"name": "units_region",
"type": DynamicTableRegion,
"doc": "A reference to the Units table region that contains the units of the data.",
"default": None,
},
)
def __init__(self, **kwargs):

name = kwargs.pop("name")
super().__init__(name=name)

timestamps = kwargs["timestamps"]
event_indices = kwargs["event_indices"]
data = kwargs["data"]

assert data.shape[1] == timestamps.shape[0], "The number of timestamps must match the second axis of data."
assert event_indices.shape[0] == timestamps.shape[0], "The number of event_indices must match the number of timestamps."

# Assert timestamps are monotonically increasing
if not np.all(np.diff(timestamps) >= 0):
error_msg = (
"The timestamps must be monotonically increasing and the data and event_indices "
"must be sorted by timestamps. Use the `sort_data_by_timestamps` method to do this "
"automatically before passing the data to the constructor."
)
raise ValueError(error_msg)

for key in kwargs:
setattr(self, key, kwargs[key])

# Should this return an instance of BinnedAlignedSpikes or just the data as it is?
# Going with the simple one for the moment
def get_data_for_stimuli(self, event_index):

mask = self.event_indices == event_index
binned_spikes_for_unit = self.data[:, mask, :]

return binned_spikes_for_unit

def get_timestamps_for_stimuli(self, event_index):

mask = self.event_indices == event_index
timestamps = self.timestamps[mask]

return timestamps

@staticmethod
def sort_data_by_timestamps(
data: np.ndarray,
timestamps: np.ndarray,
event_indices: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:

sorted_indices = np.argsort(timestamps)
data = data[:, sorted_indices, :]
timestamps = timestamps[sorted_indices]
event_indices = event_indices[sorted_indices]

return data, timestamps, event_indices


# Remove these functions from the package
del load_namespaces, get_class
