diff --git a/README.md b/README.md
index 8dd68e1..ba465db 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ from datetime import datetime
 from zoneinfo import ZoneInfo
 from pynwb import NWBHDF5IO, NWBFile
-session_description = "A session of data where PSTH was produced"
+session_description = "A session of data where a PSTH structure was produced"
 session_start_time = datetime.now(ZoneInfo("Asia/Ulaanbaatar"))
 identifier = "a_session_identifier"
 nwbfile = NWBFile(
@@ -95,7 +95,7 @@ Note that in the diagram above, the `milliseconds_from_event_to_first_bin` is ne
 The `data` argument passed to the `BinnedAlignedSpikes` stores counts across all the event timestamps for each of the units. The data is a 3D array where the first dimension indexes the units, the second dimension indexes the event timestamps, and the third dimension indexes the bins where the counts are stored. The shape of the data is `(number_of_units, number_of_events, number_of_bins)`.
-The `event_timestamps` is used to store the timestamps of the events and should have the same length as the second dimension of `data`.
+The `event_timestamps` argument is used to store the timestamps of the events and should have the same length as the second dimension of `data`. Note that the `event_timestamps` should not decrease; in other words, the events are expected to be in ascending order in time.
 The first dimension of `data` works almost like a dictionary. That is, you select a specific unit by indexing the first dimension. For example, `data[0]` would return the data of the first unit. For each of the units, the data is organized with the time on the first axis as this is the convention in the NWB format. As a consequence of this choice the data of each unit is contiguous in memory.
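The shape relationships described in the README hunk above can be sketched in plain NumPy, without the extension installed. All sizes and values below are illustrative placeholders, not part of this PR:

```python
import numpy as np

# Illustrative sizes: 2 units, 3 event repetitions, 4 bins (placeholder values)
number_of_units, number_of_events, number_of_bins = 2, 3, 4

rng = np.random.default_rng(seed=0)
data = rng.integers(low=0, high=10, size=(number_of_units, number_of_events, number_of_bins))

# One timestamp per event repetition, in ascending order,
# matching the second dimension of `data`
event_timestamps = np.array([1.0, 5.0, 15.0])
assert event_timestamps.shape[0] == data.shape[1]

# Indexing the first dimension selects a unit; the result has shape (events, bins)
data_unit_0 = data[0]
assert data_unit_0.shape == (number_of_events, number_of_bins)
```

Because units index the first axis of a C-ordered array, `data[0]` is a contiguous view, matching the memory-layout remark in the README text.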
@@ -106,7 +106,7 @@ The following diagram illustrates the structure of the data for a concrete examp
 ### Linking to units table
-One way to make the information stored in the `BinnedAlignedSpikes` object more useful is to indicate exactly which units or neurons the first dimension of the `data` attribute corresponds to. This is **optional but recommended** as it makes the data more interpretable and useful for future users. In NWB the units are usually stored in a `Units` [table](https://pynwb.readthedocs.io/en/stable/pynwb.misc.html#pynwb.misc.Units). To illustrate how to to create this link let's first create a toy `Units` table:
+One way to make the information stored in the `BinnedAlignedSpikes` object more useful for future users is to indicate exactly which units or neurons the first dimension of the `data` attribute corresponds to. This is **optional but recommended** as it makes the data more meaningful and easier to interpret. In NWB the units are usually stored in a `Units` [table](https://pynwb.readthedocs.io/en/stable/pynwb.misc.html#pynwb.misc.Units). To illustrate how to create this link, let's first create a toy `Units` table:
 ```python
 import numpy as np
@@ -177,70 +177,64 @@ binned_aligned_spikes = BinnedAlignedSpikes(
 ```
-As with the previous example this can be then added to a processing module in an NWB file and written to disk using exactly the same code as before.
+As with the previous example, this can then be added to a processing module in an NWB file and written to disk using exactly the same code as before.
-### Storing data from multiple events together
-In experiments where multiple stimuli are presented to a subject within a single session, it is often useful to store the aggregated spike counts from all events in a single object. For such cases, the `AggregatedBinnedAlignedSpikes` object is ideal.
This object functions similarly to the `BinnedAlignedSpikes` object but is designed to store data from multiple events (e.g., different stimuli) together.
-
-Since events may not occur the same number of times, an homogeneous data structure is not possible. Therefore the `AggregatedBinnedAlignedSpikes` object includes an additional variable, event_indices, to indicate which event each set of counts corresponds to. You can create this object as follows:
+### Storing data from multiple conditions (i.e. multiple stimuli)
+`BinnedAlignedSpikes` can also be used to store data that is aggregated across multiple conditions while at the same time keeping track of which condition each set of counts corresponds to. This is useful when you want to store the spike counts around multiple conditions (e.g., different stimuli, behavioral events, etc.) in a single structure. Since each condition may not occur the same number of times (e.g. different stimuli may not appear with the same frequency), a homogeneous data structure is not possible. Therefore, an extra variable, `condition_indices`, is used to indicate which condition each set of counts corresponds to.
```python -from ndx_binned_spikes import AggregatedBinnedAlignedSpikes +from ndx_binned_spikes import BinnedAlignedSpikes -aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes( +binned_aligned_spikes = BinnedAlignedSpikes( bin_width_in_milliseconds=bin_width_in_milliseconds, milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin, - data=data, # Shape (number_of_units, aggregated_events_counts, number_of_bins) - timestamps=timestamps, # As many timestamps as the second dimension of data - event_indices=event_indices, # An index indicating which event each of the counts corresponds to + data=data, # Shape (number_of_units, number_of_events, number_of_bins) + timestamps=timestamps, # Shape (number_of_events,) + condition_indices=condition_indices, # Shape (number_of_events,) + condition_labels=condition_labels, # Shape (number_of_conditions,) or np.unique(condition_indices).size ) ``` -The `aggregated_events_counts` represents the total number of repetitions for all the events being aggregated. For example, if data is being aggregated from two stimuli where the first stimulus appeared twice and the second appeared three times, the aggregated_events_counts would be 5. +Note that `number_of_events` here represents the total number of repetitions for all the conditions being aggregated. For example, if data is being aggregated from two stimuli where the first stimulus appeared twice and the second appeared three times, the `number_of_events` would be 5. + +The `condition_indices` is an indicator vector that should be constructed so that `data[:, condition_indices == condition_index, :]` corresponds to the binned spike counts for the condition with the specified condition_index. You can retrieve the same data using the convenience method `binned_aligned_spikes.get_data_for_condition(condition_index)`. 
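The mask-based retrieval that `condition_indices` enables can be illustrated in plain NumPy (the arrays below are hypothetical; `get_data_for_condition` itself is the extension's convenience wrapper around exactly this indexing):

```python
import numpy as np

# Hypothetical aggregated counts: 2 units, 5 events, 4 bins
data = np.arange(2 * 5 * 4).reshape(2, 5, 4)

# Events 0 and 2 belong to condition 0; events 1, 3, and 4 to condition 1
condition_indices = np.array([0, 1, 0, 1, 1])

# A boolean mask keeps only the event repetitions of condition 0
mask = condition_indices == 0
data_condition_0 = data[:, mask, :]
assert data_condition_0.shape == (2, 2, 4)  # 2 units, 2 repetitions, 4 bins
```

Boolean masking along the second axis preserves the unit and bin axes, so the selected counts keep the `(number_of_units, ..., number_of_bins)` layout.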
-The `event_indices` is an indicator vector that should be constructed so that `data[:, event_indices == event_index, :]` corresponds to the binned spike counts around the event with the specified event_index. You can retrieve the same data using the convenience method `aggregated_binned_aligned_spikes.get_data_for_event(event_index)`.
+The `condition_labels` argument is optional and can be used to store the labels of the conditions. This is meant to help understand the nature of the conditions.
-It's important to note that the timestamps must be in ascending order and must correspond positionally to the event indices and the second dimension of the data. If they are not, a ValueError will be raised. To help organize the data correctly, you can use the convenience method `AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(data=data, timestamps=timestamps, event_indices=event_indices)`, which ensures the data is properly sorted. Here’s how it can be used:
+It's important to note that the timestamps must be in ascending order and must correspond positionally to the condition indices and the second dimension of the data. If they are not, a ValueError will be raised. To help organize the data correctly, you can use the convenience method `BinnedAlignedSpikes.sort_data_by_event_timestamps(data=data, event_timestamps=event_timestamps, condition_indices=condition_indices)`, which ensures the data is properly sorted.
Here’s how it can be used:
 ```python
-sorted_data, sorted_timestamps, sorted_event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(data=data, timestamps=timestamps, event_indices=event_indices)
+sorted_data, sorted_event_timestamps, sorted_condition_indices = BinnedAlignedSpikes.sort_data_by_event_timestamps(data=data, event_timestamps=event_timestamps, condition_indices=condition_indices)
-aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes(
+binned_aligned_spikes = BinnedAlignedSpikes(
     bin_width_in_milliseconds=bin_width_in_milliseconds,
     milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
     data=sorted_data,
-    timestamps=sorted_timestamps,
-    event_indices=sorted_event_indices,
+    event_timestamps=sorted_event_timestamps,
+    condition_indices=sorted_condition_indices,
+    condition_labels=condition_labels,
 )
 ```
 The same can be achieved by using the following script:
 ```python
-sorted_indices = np.argsort(timestamps)
+sorted_indices = np.argsort(event_timestamps)
 sorted_data = data[:, sorted_indices, :]
-sorted_timestamps = timestamps[sorted_indices]
-sorted_event_indices = event_indices[sorted_indices]
-
-aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes(
-    bin_width_in_milliseconds=bin_width_in_milliseconds,
-    milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin,
-    data=sorted_data,
-    timestamps=sorted_timestamps,
-    event_indices=sorted_event_indices,
-)
+sorted_event_timestamps = event_timestamps[sorted_indices]
+sorted_condition_indices = condition_indices[sorted_indices]
 ```
-#### Example of building an `AggregatedBinnedAlignedSpikes` object from scratch
+#### Example of building a `BinnedAlignedSpikes` object for two conditions
-To better understand how this object works, let's consider a specific example.
Suppose we have data for two distinct events (such as two different stimuli) and their associated timestamps, similar to the `BinnedAlignedSpikes` examples mentioned earlier: +To better understand how this object works, let's consider a specific example. Suppose we have data for two different stimuli and their associated timestamps: ```python import numpy as np # Two units and 4 bins -data_for_first_event = np.array( +data_for_first_stimuli = np.array( [ # Unit 1 [ @@ -256,7 +250,7 @@ data_for_first_event = np.array( ) # Also two units and 4 bins but this event appeared three times -data_for_second_event = np.array( +data_for_second_stimuli = np.array( [ # Unit 1 [ @@ -273,41 +267,40 @@ data_for_second_event = np.array( ] ) -timestamps_first_event = [5.0, 15.0] -timestamps_second_event = [1.0, 10.0, 20.0] +timestamps_first_stimuli = [5.0, 15.0] +timestamps_second_stimuli = [1.0, 10.0, 20.0] ``` -The way that we would build the data for the `AggregatedBinnedAlignedSpikes` object is as follows: +The way that we would build the data for the `BinnedAlignedSpikes` object is as follows: ```python -from ndx_binned_spikes import AggregatedBinnedAlignedSpikes +from ndx_binned_spikes import BinnedAlignedSpikes bin_width_in_milliseconds = 100.0 milliseconds_from_event_to_first_bin = -50.0 -data = np.concatenate([data_for_first_event, data_for_second_event], axis=1) -timestamps = np.concatenate([timestamps_first_event, timestamps_second_event]) -event_indices = np.concatenate([np.zeros(2), np.ones(3)]) +data = np.concatenate([data_for_first_stimuli, data_for_second_stimuli], axis=1) +event_timestamps = np.concatenate([timestamps_first_stimuli, timestamps_second_stimuli]) +condition_indices = np.concatenate([np.zeros(2), np.ones(3)]) +condition_labels = ["a", "b"] -sorted_data, sorted_timestamps, sorted_event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(data=data, timestamps=timestamps, event_indices=event_indices) +sorted_data, sorted_event_timestamps, 
sorted_condition_indices = BinnedAlignedSpikes.sort_data_by_event_timestamps(data=data, event_timestamps=event_timestamps, condition_indices=condition_indices) -aggregated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes( +binned_aligned_spikes = BinnedAlignedSpikes( bin_width_in_milliseconds=bin_width_in_milliseconds, milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin, data=sorted_data, - timestamps=sorted_timestamps, - event_indices=sorted_event_indices, + event_timestamps=sorted_event_timestamps, + condition_indices=sorted_condition_indices, ) ``` -Then we can recover the original data by calling the `get_data_for_event` method: +Then we can recover the original data by calling the `get_data_for_condition` method: ```python -retrieved_data_for_first_event = aggregated_binned_aligned_spikes.get_data_for_stimuli(event_index=0) -np.testing.assert_array_equal(retrieved_data_for_first_event, data_for_first_event) +retrieved_data_for_first_stimuli = binned_aligned_spikes.get_data_for_condition(condition_index=0) +np.testing.assert_array_equal(retrieved_data_for_first_stimuli, data_for_first_stimuli) ``` -The `AggregatedBinnedAlignedSpikes` object can be added to a processing module in an NWB file and written to disk using the same code as before. Plus, a region of the `Units` table can be linked to the `AggregatedBinnedAlignedSpikes` object in the same way as it was done for the `BinnedAlignedSpikes` object. - --- This extension was created using [ndx-template](https://github.com/nwb-extensions/ndx-template). diff --git a/spec/ndx-binned-spikes.extensions.yaml b/spec/ndx-binned-spikes.extensions.yaml index cd17cd3..3486762 100644 --- a/spec/ndx-binned-spikes.extensions.yaml +++ b/spec/ndx-binned-spikes.extensions.yaml @@ -2,7 +2,7 @@ groups: - neurodata_type_def: BinnedAlignedSpikes neurodata_type_inc: NWBDataInterface default_name: BinnedAlignedSpikes - doc: A data interface for binned spike data aligned to an event (e.g. 
a stimuli + doc: A data interface for binned spike data aligned to an event (e.g. a stimulus or the beginning of a trial). attributes: - name: name @@ -11,7 +11,8 @@ groups: doc: The name of this container - name: description dtype: text - value: Spikes data binned and aligned to event timestamps. + value: Spikes data binned and aligned to the event timestamps of one or multiple + conditions. doc: A description of what the data represents - name: bin_width_in_milliseconds dtype: float64 @@ -25,7 +26,7 @@ groups: required: false datasets: - name: data - dtype: numeric + dtype: uint64 dims: - num_units - number_of_events @@ -44,64 +45,27 @@ groups: shape: - null doc: The timestamps at which the events occurred. - - name: units_region - neurodata_type_inc: DynamicTableRegion - doc: A reference to the Units table region that contains the units of the data. - quantity: '?' -- neurodata_type_def: AggregatedBinnedAlignedSpikes - neurodata_type_inc: NWBDataInterface - default_name: AggregatedBinnedAlignedSpikes - doc: A data interface for aggregated binned spike data aligned to multiple events. - The data for all the events is concatenated along the second dimension and a second - array, event_indices, is used to keep track of which event each row of the data - corresponds to. - attributes: - - name: name - dtype: text - value: BinnedAlignedSpikes - doc: The name of this container - - name: description - dtype: text - value: Spikes data binned and aligned to the timestamps of multiple events. - doc: A description of what the data represents - - name: bin_width_in_milliseconds - dtype: float64 - doc: The length in milliseconds of the bins - - name: milliseconds_from_event_to_first_bin - dtype: float64 - default_value: 0.0 - doc: The time in milliseconds from the event to the beginning of the first bin. - A negative value indicatesthat the first bin is before the event whereas a positive - value indicates that the first bin is after the event. 
-  required: false
-  datasets:
-  - name: data
-    dtype: numeric
-    dims:
-    - num_units
-    - number_of_events
-    - number_of_bins
-    shape:
-    - null
-    - null
-    - null
-    doc: The binned data. It should be an array whose first dimension is the number
-      of units, the second dimension is the total number of events of all stimuli,
-      and the third dimension is the number of bins.
-  - name: event_indices
-    dtype: int64
+  - name: condition_indices
+    dtype: uint64
     dims:
     - number_of_events
     shape:
     - null
-    doc: The index of the event that each row of the data corresponds to.
-  - name: timestamps
-    dtype: float64
+    doc: The index of the condition that each timestamp corresponds to (e.g. a stimulus
+      type, trial number, category, etc.). This is only used when the data is aligned
+      to multiple conditions.
+    quantity: '?'
+  - name: condition_labels
+    dtype: text
     dims:
-    - number_of_events
+    - number_of_conditions
     shape:
     - null
-    doc: The timestamps at which the events occurred.
+    doc: The labels of the conditions that the data is aligned to. The size of this
+      array should match the number of conditions. This is only used when the data
+      is aligned to multiple conditions. First condition is index 0, second is index
+      1, etc.
+    quantity: '?'
   - name: units_region
     neurodata_type_inc: DynamicTableRegion
     doc: A reference to the Units table region that contains the units of the data.
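The shape constraints this spec hunk encodes (`condition_indices` sized per event, `condition_labels` sized per condition, labels resolved by position) can be sanity-checked with a short NumPy sketch; the arrays and label strings below are hypothetical:

```python
import numpy as np

# Hypothetical: 5 events drawn from 2 conditions
condition_indices = np.array([0, 1, 0, 1, 1])              # one entry per event
condition_labels = np.array(["stimulus_a", "stimulus_b"])  # one entry per condition

# The number of labels should match the number of distinct condition indices
number_of_conditions = np.unique(condition_indices).size
assert condition_labels.size == number_of_conditions

# First condition is index 0, second is index 1, etc.,
# so a label can be resolved for each event by fancy indexing
labels_per_event = condition_labels[condition_indices]
assert labels_per_event.tolist() == ["stimulus_a", "stimulus_b", "stimulus_a", "stimulus_b", "stimulus_b"]
```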
diff --git a/src/pynwb/ndx_binned_spikes/__init__.py b/src/pynwb/ndx_binned_spikes/__init__.py index 12848eb..53dd29c 100644 --- a/src/pynwb/ndx_binned_spikes/__init__.py +++ b/src/pynwb/ndx_binned_spikes/__init__.py @@ -32,12 +32,13 @@ class BinnedAlignedSpikes(NWBDataInterface): "bin_width_in_milliseconds", "milliseconds_from_event_to_first_bin", "data", - "event_timestamps", - {"name": "units_region", "child": True}, + "timestamps", + "condition_indices", + {"name": "units_region", "child": True}, # TODO, I forgot why this is included ) DEFAULT_NAME = "BinnedAlignedSpikes" - DEFAULT_DESCRIPTION = "Spikes data binned and aligned to event timestamps." + DEFAULT_DESCRIPTION = "Spikes data binned and aligned to the event timestamps of one or multiple conditions." @docval( { @@ -79,100 +80,33 @@ class BinnedAlignedSpikes(NWBDataInterface): { "name": "event_timestamps", "type": "array_data", - "doc": "The timestamps at which the events occurred.", - "shape": (None,), - }, - { - "name": "units_region", - "type": DynamicTableRegion, - "doc": "A reference to the Units table region that contains the units of the data.", - "default": None, - }, - ) - def __init__(self, **kwargs): - - data = kwargs["data"] - event_timestamps = kwargs["event_timestamps"] - - if data.shape[1] != event_timestamps.shape[0]: - raise ValueError("The number of event timestamps must match the number of event repetitions in the data.") - - super().__init__(name=kwargs["name"]) - - name = kwargs.pop("name") - super().__init__(name=name) - - for key in kwargs: - setattr(self, key, kwargs[key]) - - -@register_class(neurodata_type="AggregatedBinnedAlignedSpikes", namespace="ndx-binned-spikes") # noqa -class AggregatedBinnedAlignedSpikes(NWBDataInterface): - __nwbfields__ = ( - "name", - "description", - "bin_width_in_milliseconds", - "milliseconds_from_event_to_first_bin", - "data", - "timestamps", - "event_indices", - {"name": "units_region", "child": True}, # TODO, I forgot why this is included - ) - 
-    DEFAULT_NAME = "AggregatedBinnedAlignedSpikes"
-    DEFAULT_DESCRIPTION = "Spikes data binned and aligned to the timestamps of multiple events."
-
-    @docval(
-        {
-            "name": "name",
-            "type": str,
-            "doc": "The name of this container",
-            "default": DEFAULT_NAME,
-        },
-        {
-            "name": "description",
-            "type": str,
-            "doc": "A description of what the data represents",
-            "default": DEFAULT_DESCRIPTION,
-        },
-        {
-            "name": "bin_width_in_milliseconds",
-            "type": float,
-            "doc": "The length in milliseconds of the bins",
-        },
-        {
-            "name": "milliseconds_from_event_to_first_bin",
-            "type": float,
             "doc": (
-                "The time in milliseconds from the event to the beginning of the first bin. A negative value indicates"
-                "that the first bin is before the event whereas a positive value indicates that the first bin is "
-                "after the event."
+                "The timestamps at which the events occurred. It is assumed that they map positionally to "
+                "the second index of the data."
             ),
-            "default": 0.0,
+            "shape": (None,),
         },
         {
-            "name": "data",
+            "name": "condition_indices",
             "type": "array_data",
-            "shape": [(None, None, None)],
             "doc": (
-                "The binned data. It should be an array whose first dimension is the number of units, "
-                "the second dimension is the number of events, and the third dimension is the number of bins."
+                "The index of the condition that each entry of `event_timestamps` corresponds to "
+                "(e.g. a stimulus type, trial number, category, etc.). "
+                "This is only used when the data is aligned to multiple conditions."
             ),
+            "shape": (None,),
+            "default": None,
         },
         {
-            "name": "timestamps",
+            "name": "condition_labels",
             "type": "array_data",
             "doc": (
-                "The timestamps at which the events occurred. It is assumed that they map positionally to "
-                "the second index of the data.",
+                "The labels of the conditions that the data is aligned to. The size of this array should match "
+                "the number of conditions. This is only used when the data is aligned to multiple conditions. 
"
+                "First condition is index 0, second is index 1, etc."
             ),
             "shape": (None,),
-        },
-        {
-            "name": "event_indices",
-            "type": "array_data",
-            "doc": "The timestamps at which the events occurred.",
-            "shape": (None,),
+            "default": None,
         },
         {
             "name": "units_region",
@@ -186,54 +120,70 @@ def __init__(self, **kwargs):
         name = kwargs.pop("name")
         super().__init__(name=name)
-        timestamps = kwargs["timestamps"]
-        event_indices = kwargs["event_indices"]
+        event_timestamps = kwargs["event_timestamps"]
         data = kwargs["data"]
-        assert data.shape[1] == timestamps.shape[0], "The number of timestamps must match the second axis of data."
-        assert event_indices.shape[0] == timestamps.shape[0], "The number of timestamps must match the event_indices."
+        if data.shape[1] != event_timestamps.shape[0]:
+            msg = (
+                f"The number of event_timestamps must match the second axis of data: \n"
+                f"event_timestamps.size: {event_timestamps.size} \n"
+                f"data.shape[1]: {data.shape[1]}"
+            )
+            raise ValueError(msg)
         # Assert timestamps are monotonically increasing
-        if not np.all(np.diff(kwargs["timestamps"]) >= 0):
+        if not np.all(np.diff(kwargs["event_timestamps"]) >= 0):
             error_msg = (
-                "The timestamps must be monotonically increasing and the data and event_indices "
-                "must be sorted by timestamps. Use the `sort_data_by_timestamps` method to do this "
-                "automatically before passing the data to the constructor."
+                "The event_timestamps must be monotonically increasing and the data and condition_indices "
+                "must be sorted by event_timestamps. Use the `BinnedAlignedSpikes.sort_data_by_event_timestamps` "
+                "method to do this automatically before initializing `BinnedAlignedSpikes`."
) raise ValueError(error_msg) + # Condition indices check + condition_indices = kwargs.get("condition_indices", None) + self.has_multiple_conditions = condition_indices is not None + if self.has_multiple_conditions: + assert ( + condition_indices.shape[0] == event_timestamps.shape[0] + ), "The number of event_timestamps must match the condition_indices." + for key in kwargs: setattr(self, key, kwargs[key]) - # Should this return an instance of BinnedAlignedSpikes or just the data as it is? - # Going with the simple one for the moment - def get_data_for_stimuli(self, event_index): + def get_data_for_condition(self, condition_index): - mask = self.event_indices == event_index + if not self.has_multiple_conditions: + return self.data + + mask = self.condition_indices == condition_index binned_spikes_for_unit = self.data[:, mask, :] return binned_spikes_for_unit - def get_timestamps_for_stimuli(self, event_index): + def get_event_timestamps_for_condition(self, condition_index): + + if not self.has_multiple_conditions: + return self.event_timestamps - mask = self.event_indices == event_index - timestamps = self.timestamps[mask] + mask = self.condition_indices == condition_index + event_timestamps = self.event_timestamps[mask] - return timestamps + return event_timestamps @staticmethod - def sort_data_by_timestamps( + def sort_data_by_event_timestamps( data: np.ndarray, - timestamps: np.ndarray, - event_indices: np.ndarray, + event_timestamps: np.ndarray, + condition_indices: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - sorted_indices = np.argsort(timestamps) + sorted_indices = np.argsort(event_timestamps) data = data[:, sorted_indices, :] - timestamps = timestamps[sorted_indices] - event_indices = event_indices[sorted_indices] + event_timestamps = event_timestamps[sorted_indices] + condition_indices = condition_indices[sorted_indices] - return data, timestamps, event_indices + return data, event_timestamps, condition_indices # Remove these functions 
from the package diff --git a/src/pynwb/ndx_binned_spikes/testing/mock.py b/src/pynwb/ndx_binned_spikes/testing/mock.py index 9f89a5c..939b928 100644 --- a/src/pynwb/ndx_binned_spikes/testing/mock.py +++ b/src/pynwb/ndx_binned_spikes/testing/mock.py @@ -1,99 +1,12 @@ from typing import Optional -from ndx_binned_spikes import BinnedAlignedSpikes, AggregatedBinnedAlignedSpikes +from ndx_binned_spikes import BinnedAlignedSpikes import numpy as np from pynwb import NWBFile from pynwb.misc import Units from hdmf.common import DynamicTableRegion -def mock_BinnedAlignedSpikes( - number_of_units: int = 2, - number_of_events: int = 4, - number_of_bins: int = 3, - bin_width_in_milliseconds: float = 20.0, - milliseconds_from_event_to_first_bin: float = 1.0, - seed: int = 0, - event_timestamps: Optional[np.ndarray] = None, - data: Optional[np.ndarray] = None, - units_region: Optional[DynamicTableRegion] = None, -) -> "BinnedAlignedSpikes": - """ - Generate a mock BinnedAlignedSpikes object with specified parameters or from given data. - - Parameters - ---------- - number_of_units : int, optional - The number of different units (channels, neurons, etc.) to simulate. - number_of_events : int, optional - The number of timestamps of the event that the data is aligned to. - number_of_bins : int, optional - The number of bins. - bin_width_in_milliseconds : float, optional - The width of each bin in milliseconds. - milliseconds_from_event_to_first_bin : float, optional - The time in milliseconds from the event start to the first bin. - seed : int, optional - Seed for the random number generator to ensure reproducibility. - event_timestamps : np.ndarray, optional - An array of timestamps for each event. If not provided, it will be automatically generated. - It should have size `number_of_events`. - data : np.ndarray, optional - A 3D array of shape (number_of_units, number_of_events, number_of_bins) representing - the binned spike data. 
If provided, it overrides the generation of mock data based on other parameters. - Its shape should match the expected number of units, event repetitions, and bins. - units_region: DynamicTableRegion, optional - A reference to the Units table region that contains the units of the data. - - Returns - ------- - BinnedAlignedSpikes - A mock BinnedAlignedSpikes object populated with the provided or generated data and parameters. - - Raises - ------ - AssertionError - If `event_timestamps` is provided and its shape does not match the expected number of event repetitions. - - Notes - ----- - This function simulates a BinnedAlignedSpikes object, which is typically used for neural data analysis, - representing binned spike counts aligned to specific events. - - Examples - -------- - >>> mock_bas = mock_BinnedAlignedSpikes() - >>> print(mock_bas.data.shape) - (2, 4, 3) - """ - - if data is not None: - number_of_units, number_of_events, number_of_bins = data.shape - else: - rng = np.random.default_rng(seed=seed) - data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins)) - - if event_timestamps is None: - event_timestamps = np.arange(number_of_events, dtype="float64") - else: - assert ( - event_timestamps.shape[0] == number_of_events - ), "The shape of `event_timestamps` does not match `number_of_events`." 
- event_timestamps = np.array(event_timestamps, dtype="float64") - - if event_timestamps.shape[0] != data.shape[1]: - raise ValueError("The shape of `event_timestamps` does not match `number_of_events`.") - - binned_aligned_spikes = BinnedAlignedSpikes( - bin_width_in_milliseconds=bin_width_in_milliseconds, - milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin, - data=data, - event_timestamps=event_timestamps, - units_region=units_region, - ) - return binned_aligned_spikes - - # TODO: Remove once pynwb 2.7.0 is released and use the mock class there def mock_Units( num_units: int = 10, @@ -124,34 +37,23 @@ def mock_Units( return units_table -""" -Ok, so for the first structure what we can align to: -- A specific stimulus (the event here is every time the stimulus occurs) -- A column of the trials table (e.g. ). The event here is every trial -- A time column of any dynamic table. -- Add event from ndx-events - -When do we need to aggregate? -Stimulus, because not all stimulus happen the same number of times. -What else is inhomogeneous in this way? 
A column of the trials table -""" - - -def mock_AggregatedBinnedAlignedSpikes( +def mock_BinnedAlignedSpikes( number_of_units: int = 2, + number_of_events: int = 10, number_of_bins: int = 3, - aggregated_events_counts: int = 5, - number_of_events: int = 2, + number_of_conditions: int = 5, bin_width_in_milliseconds: float = 20.0, milliseconds_from_event_to_first_bin: float = 1.0, seed: int = 0, - event_indices: Optional[np.ndarray] = None, - timestamps: Optional[np.ndarray] = None, + event_timestamps: Optional[np.ndarray] = None, data: Optional[np.ndarray] = None, + condition_indices: Optional[np.ndarray] = None, + condition_labels: Optional[np.ndarray] = None, units_region: Optional[DynamicTableRegion] = None, -) -> "AggregatedBinnedAlignedSpikes": + sort_data: bool = True, +) -> BinnedAlignedSpikes: """ - Generate a mock AggregatedBinnedAlignedSpikes object with specified parameters or from given data. + Generate a mock BinnedAlignedSpikes object with specified parameters or from given data. Parameters ---------- @@ -161,8 +63,8 @@ def mock_AggregatedBinnedAlignedSpikes( The number of timestamps of the event that the data is aligned to. number_of_bins : int, optional The number of bins. - number_of_different_events : int, optional - The number of different events that the data is aligned to. + number_of_conditions : int, optional + The number of different conditions that the data is aligned to. It should be less than `number_of_events`. bin_width_in_milliseconds : float, optional The width of each bin in milliseconds. milliseconds_from_event_to_first_bin : float, optional @@ -173,68 +75,88 @@ def mock_AggregatedBinnedAlignedSpikes( A 3D array of shape (number_of_units, number_of_events, number_of_bins) representing the binned spike data. If provided, it overrides the generation of mock data based on other parameters. Its shape should match the expected number of units, event repetitions, and bins. 
- timestamps : np.ndarray, optional - An array of timestamps for each event. If not provided, it will be automatically generated. + event_timestamps : np.ndarray, optional + An array of event_timestamps for each event. If not provided, it will be automatically generated. It should have size `number_of_events`. + condition_indices : np.ndarray, optional + An array of indices characterizing each condition. If not provided, it will be automatically generated + from the number of conditions and number of events. It should have size `number_of_events`. + If provided, the `number_of_conditions` parameter will be ignored and the number of conditions will be + inferred from the unique values in `condition_indices`. + condition_labels: np.ndarray, optional + An array of labels for each condition. It should have size `number_of_conditions`. units_region: DynamicTableRegion, optional A reference to the Units table region that contains the units of the data. - event_indices : np.ndarray, optional - An array of indices for each event. If not provided, it will be automatically generated. + sort_data: bool, optional + If True, the data will be sorted by timestamps. Returns ------- - AggregatedBinnedAlignedSpikes - A mock AggregatedBinnedAlignedSpikes object populated with the provided or generated data and parameters. + BinnedAlignedSpikes + A mock BinnedAlignedSpikes object populated with the provided or generated data and parameters. 
""" if data is not None: - number_of_units, aggregated_events_counts, number_of_bins = data.shape + number_of_units, number_of_events, number_of_bins = data.shape else: rng = np.random.default_rng(seed=seed) - data = rng.integers(low=0, high=100, size=(number_of_units, aggregated_events_counts, number_of_bins)) - - if timestamps is None: - timestamps = np.arange(aggregated_events_counts, dtype="float64") - - if event_indices is None: - event_indices = np.zeros(aggregated_events_counts, dtype=int) - all_indices = np.arange(number_of_events, dtype=int) - - # Ensure all indices appear at least once - event_indices[:number_of_events] = rng.choice(all_indices, size=number_of_events, replace=False) - # Then fill the rest randomly - event_indices[number_of_events:] = rng.choice( - event_indices[:number_of_events], - size=aggregated_events_counts - number_of_events, - replace=True, - ) + data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins)) # Assert data shapes assertion_msg = ( - "The shape of `data` should be (number_of_units, aggregated_events_counts, number_of_bins), " + "The shape of `data` should be `(number_of_units, number_of_events, number_of_bins)`, " f"The actual shape is {data.shape} \n " - "but {number_of_bins=}, {aggregated_events_counts=}, {number_of_units=} was passed" + f"but {number_of_bins=}, {number_of_events=}, {number_of_units=} was passed" ) - assert data.shape == (number_of_units, aggregated_events_counts, number_of_bins), assertion_msg + assert data.shape == (number_of_units, number_of_events, number_of_bins), assertion_msg + + if event_timestamps is None: + event_timestamps = np.arange(number_of_events, dtype="float64") - if timestamps.shape[0] != aggregated_events_counts: - raise ValueError("The shape of `timestamps` does not match `aggregated_events_counts`.") + if event_timestamps.shape[0] != number_of_events: + raise ValueError("The shape of `event_timestamps` does not match `number_of_events`.") + + if 
condition_indices is None and number_of_conditions > 0: + + assert ( + number_of_conditions < number_of_events + ), "The number of conditions should be less than the number of events." + + condition_indices = np.zeros(number_of_events, dtype=int) + all_indices = np.arange(number_of_conditions, dtype=int) + + # Ensure all condition indices appear at least once + condition_indices[:number_of_conditions] = rng.choice(all_indices, size=number_of_conditions, replace=False) + # Then fill the rest with random samples from the already assigned indices + condition_indices[number_of_conditions:] = rng.choice( + condition_indices[:number_of_conditions], + size=number_of_events - number_of_conditions, + replace=True, + ) - assert ( - event_indices.shape[0] == aggregated_events_counts - ), "The shape of `event_indices` does not match `aggregated_events_counts`." - event_indices = np.array(event_indices, dtype=int) + + if condition_indices is not None: + number_of_conditions = np.unique(condition_indices).size + + if condition_labels is not None: + condition_labels = np.asarray(condition_labels, dtype="U") + + if condition_labels.size != number_of_conditions: + raise ValueError("The number of condition labels should match the number of conditions.") # Sort the data by timestamps - sorted_indices = np.argsort(timestamps) - data = data[:, sorted_indices, :] - event_indices = event_indices[sorted_indices] + if sort_data: + sorted_indices = np.argsort(event_timestamps) + data = data[:, sorted_indices, :] + if condition_indices is not None: + condition_indices = condition_indices[sorted_indices] - aggreegated_binned_aligned_spikes = AggregatedBinnedAlignedSpikes( + binned_aligned_spikes = BinnedAlignedSpikes( bin_width_in_milliseconds=bin_width_in_milliseconds, milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin, data=data, - timestamps=timestamps, - event_indices=event_indices, + event_timestamps=event_timestamps, + condition_indices=condition_indices, + condition_labels=condition_labels, 
units_region=units_region, ) - return aggreegated_binned_aligned_spikes + return binned_aligned_spikes diff --git a/src/pynwb/tests/test_aggregated_binned_aligned_spikes.py b/src/pynwb/tests/test_aggregated_binned_aligned_spikes.py deleted file mode 100644 index 0bc6401..0000000 --- a/src/pynwb/tests/test_aggregated_binned_aligned_spikes.py +++ /dev/null @@ -1,203 +0,0 @@ -import numpy as np - -from pynwb import NWBHDF5IO -from pynwb.testing.mock.file import mock_NWBFile -from pynwb.testing import TestCase, remove_test_file - -from ndx_binned_spikes import AggregatedBinnedAlignedSpikes -from ndx_binned_spikes.testing.mock import mock_AggregatedBinnedAlignedSpikes, mock_Units -from hdmf.common import DynamicTableRegion - - -class TestAggregatedBinnedAlignedSpikesConstructor(TestCase): - """Simple unit test for creating a AggregatedBinnedAlignedSpikes.""" - - def setUp(self): - """Set up an NWB file..""" - - self.number_of_units = 2 - self.number_of_bins = 4 - self.number_of_events = 5 - - self.bin_width_in_milliseconds = 20.0 - self.milliseconds_from_event_to_first_bin = -100.0 - - # Two units in total and 4 bins, and event with two timestamps - self.data_for_first_event = np.array( - [ - # Unit 1 data - [ - [0, 1, 2, 3], # Bin counts around the first timestamp - [4, 5, 6, 7], # Bin counts around the second timestamp - ], - # Unit 2 data - [ - [8, 9, 10, 11], # Bin counts around the first timestamp - [12, 13, 14, 15], # Bin counts around the second timestamp - ], - ], - ) - - # Also two units and 4 bins but this event appeared three times - self.data_for_second_event = np.array( - [ - # Unit 1 data - [ - [0, 1, 2, 3], # Bin counts around the first timestamp - [4, 5, 6, 7], # Bin counts around the second timestamp - [8, 9, 10, 11], # Bin counts around the third timestamp - ], - # Unit 2 data - [ - [12, 13, 14, 15], # Bin counts around the first timestamp - [16, 17, 18, 19], # Bin counts around the second timestamp - [20, 21, 22, 23], # Bin counts around the third 
timestamp - ], - ] - ) - - self.timestamps_first_event = [5.0, 15.0] - self.timestamps_second_event = [0.0, 10.0, 20.0] - - self.event_indices = np.concatenate( - [ - np.full(event_data.shape[1], event_index) - for event_index, event_data in enumerate([self.data_for_first_event, self.data_for_second_event]) - ] - ) - - self.data = np.concatenate([self.data_for_first_event, self.data_for_second_event], axis=1) - self.timestamps = np.concatenate([self.timestamps_first_event, self.timestamps_second_event]) - - self.sorted_indices = np.argsort(self.timestamps) - - def test_constructor(self): - """Test that the constructor for AggregatedBinnedAlignedSpikes sets values as expected.""" - - with self.assertRaises(ValueError): - AggregatedBinnedAlignedSpikes( - bin_width_in_milliseconds=self.bin_width_in_milliseconds, - milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, - data=self.data, - timestamps=self.timestamps, - event_indices=self.event_indices, - ) - - - data, timestamps, event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps( - self.data, - self.timestamps, - self.event_indices, - ) - - aggregated_binnned_align_spikes = AggregatedBinnedAlignedSpikes( - bin_width_in_milliseconds=self.bin_width_in_milliseconds, - milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, - data=data, - timestamps=timestamps, - event_indices=event_indices, - ) - - np.testing.assert_array_equal(aggregated_binnned_align_spikes.data, self.data[:, self.sorted_indices, :]) - np.testing.assert_array_equal( - aggregated_binnned_align_spikes.event_indices, self.event_indices[self.sorted_indices] - ) - np.testing.assert_array_equal(aggregated_binnned_align_spikes.timestamps, self.timestamps[self.sorted_indices]) - self.assertEqual(aggregated_binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds) - self.assertEqual( - aggregated_binnned_align_spikes.milliseconds_from_event_to_first_bin, - 
self.milliseconds_from_event_to_first_bin, - ) - - self.assertEqual(aggregated_binnned_align_spikes.data.shape[0], self.number_of_units) - self.assertEqual(aggregated_binnned_align_spikes.data.shape[1], self.number_of_events) - self.assertEqual(aggregated_binnned_align_spikes.data.shape[2], self.number_of_bins) - - def test_get_single_event_data_methods(self): - - - data, timestamps, event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps( - self.data, - self.timestamps, - self.event_indices, - ) - - aggregated_binnned_align_spikes = AggregatedBinnedAlignedSpikes( - bin_width_in_milliseconds=self.bin_width_in_milliseconds, - milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, - data=data, - timestamps=timestamps, - event_indices=event_indices, - ) - - data_for_stimuli_1 = aggregated_binnned_align_spikes.get_data_for_stimuli(event_index=0) - np.testing.assert_allclose(data_for_stimuli_1, self.data_for_first_event) - - data_for_stimuli_2 = aggregated_binnned_align_spikes.get_data_for_stimuli(event_index=1) - np.testing.assert_allclose(data_for_stimuli_2, self.data_for_second_event) - - timestamps_stimuli_1 = aggregated_binnned_align_spikes.get_timestamps_for_stimuli(event_index=0) - np.testing.assert_allclose(timestamps_stimuli_1, self.timestamps_first_event) - - timestamps_stimuli_2 = aggregated_binnned_align_spikes.get_timestamps_for_stimuli(event_index=1) - np.testing.assert_allclose(timestamps_stimuli_2, self.timestamps_second_event) - - -class TestAggregatedBinnedAlignedSpikesSimpleRoundtrip(TestCase): - """Simple roundtrip test for AggregatedBinnedAlignedSpikes.""" - - def setUp(self): - self.nwbfile = mock_NWBFile() - - self.path = "test.nwb" - - def tearDown(self): - remove_test_file(self.path) - - def test_roundtrip_acquisition(self): - - self.aggregated_binned_aligned_spikes = mock_AggregatedBinnedAlignedSpikes() - - self.nwbfile.add_acquisition(self.aggregated_binned_aligned_spikes) - - with NWBHDF5IO(self.path, 
mode="w") as io: - io.write(self.nwbfile) - - with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: - read_nwbfile = io.read() - read_container = read_nwbfile.acquisition["AggregatedBinnedAlignedSpikes"] - self.assertContainerEqual(self.aggregated_binned_aligned_spikes, read_container) - - def test_roundtrip_processing_module(self): - self.aggregated_binned_aligned_spikes = mock_AggregatedBinnedAlignedSpikes() - - ecephys_processinng_module = self.nwbfile.create_processing_module(name="ecephys", description="a description") - ecephys_processinng_module.add(self.aggregated_binned_aligned_spikes) - - with NWBHDF5IO(self.path, mode="w") as io: - io.write(self.nwbfile) - - with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: - read_nwbfile = io.read() - read_container = read_nwbfile.processing["ecephys"]["AggregatedBinnedAlignedSpikes"] - self.assertContainerEqual(self.aggregated_binned_aligned_spikes, read_container) - - def test_roundtrip_with_units_table(self): - - units = mock_Units(num_units=3) - self.nwbfile.units = units - region_indices = [0, 3] - units_region = DynamicTableRegion( - data=region_indices, table=units, description="region of units table", name="units_region" - ) - - aggregated_binned_aligned_spikes_with_region = mock_AggregatedBinnedAlignedSpikes(units_region=units_region) - self.nwbfile.add_acquisition(aggregated_binned_aligned_spikes_with_region) - - with NWBHDF5IO(self.path, mode="w") as io: - io.write(self.nwbfile) - - with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: - read_nwbfile = io.read() - read_container = read_nwbfile.acquisition["AggregatedBinnedAlignedSpikes"] - self.assertContainerEqual(aggregated_binned_aligned_spikes_with_region, read_container) diff --git a/src/pynwb/tests/test_binned_aligned_spikes.py b/src/pynwb/tests/test_binned_aligned_spikes.py index 4830476..2582987 100644 --- a/src/pynwb/tests/test_binned_aligned_spikes.py +++ b/src/pynwb/tests/test_binned_aligned_spikes.py @@ -1,5 
+1,4 @@ -"""Unit and integration tests for the example BinnedAlignedSpikes extension neurodata type. -""" +"""Unit and integration tests for the example BinnedAlignedSpikes extension neurodata type.""" import numpy as np @@ -62,7 +61,6 @@ def test_constructor(self): def test_constructor_units_region(self): - units_table = Units() units_table.add_column(name="unit_name", description="a readable identifier for the units") @@ -99,7 +97,7 @@ def test_constructor_units_region(self): def test_constructor_inconsistent_timestamps_and_data_error(self): shorter_timestamps = self.event_timestamps[:-1] - + with self.assertRaises(ValueError): BinnedAlignedSpikes( bin_width_in_milliseconds=self.bin_width_in_milliseconds, @@ -107,7 +105,149 @@ def test_constructor_inconsistent_timestamps_and_data_error(self): data=self.data, event_timestamps=shorter_timestamps, ) - + + +class TestBinnedAlignedSpikesMultipleConditions(TestCase): + """Simple unit test for creating a BinnedAlignedSpikes with multiple conditions.""" + + def setUp(self): + """Set up an NWB file.""" + + self.number_of_units = 2 + self.number_of_bins = 4 + self.number_of_events = 5 + self.number_of_conditions = 2 + + self.bin_width_in_milliseconds = 20.0 + self.milliseconds_from_event_to_first_bin = -100.0 + + # Two units in total and 4 bins; the first condition's event occurred at two timestamps + self.data_for_first_condition = np.array( + [ + # Unit 1 data + [ + [0, 1, 2, 3], # Bin counts around the first timestamp + [4, 5, 6, 7], # Bin counts around the second timestamp + ], + # Unit 2 data + [ + [8, 9, 10, 11], # Bin counts around the first timestamp + [12, 13, 14, 15], # Bin counts around the second timestamp + ], + ], + ) + + # Also two units and 4 bins, but the second condition's event occurred at three timestamps + self.data_for_second_condition = np.array( + [ + # Unit 1 data + [ + [0, 1, 2, 3], # Bin counts around the first timestamp + [4, 5, 6, 7], # Bin counts around the second timestamp + [8, 9, 10, 11], # Bin counts around the third timestamp + ], + # 
Unit 2 data + [ + [12, 13, 14, 15], # Bin counts around the first timestamp + [16, 17, 18, 19], # Bin counts around the second timestamp + [20, 21, 22, 23], # Bin counts around the third timestamp + ], + ] + ) + + self.timestamps_first_condition = [5.0, 15.0] + self.timestamps_second_condition = [0.0, 10.0, 20.0] + + data_list = [self.data_for_first_condition, self.data_for_second_condition] + self.data = np.concatenate(data_list, axis=1) + + indices_list = [np.full(data.shape[1], condition_index) for condition_index, data in enumerate(data_list)] + self.condition_indices = np.concatenate(indices_list) + + self.event_timestamps = np.concatenate([self.timestamps_first_condition, self.timestamps_second_condition]) + + self.sorted_indices = np.argsort(self.event_timestamps) + + self.condition_labels = ["first", "second"] + + def test_constructor(self): + """Test that the constructor for BinnedAlignedSpikes sets values as expected.""" + + # Test error if the timestamps are not sorted and/or aligned to conditions + with self.assertRaises(ValueError): + BinnedAlignedSpikes( + bin_width_in_milliseconds=self.bin_width_in_milliseconds, + milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, + data=self.data, + event_timestamps=self.event_timestamps, + condition_indices=self.condition_indices, + ) + + data, event_timestamps, condition_indices = BinnedAlignedSpikes.sort_data_by_event_timestamps( + self.data, + self.event_timestamps, + self.condition_indices, + ) + + aggregated_binnned_align_spikes = BinnedAlignedSpikes( + bin_width_in_milliseconds=self.bin_width_in_milliseconds, + milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, + data=data, + event_timestamps=event_timestamps, + condition_indices=condition_indices, + condition_labels=self.condition_labels, + ) + + np.testing.assert_array_equal(aggregated_binnned_align_spikes.data, self.data[:, self.sorted_indices, :]) + np.testing.assert_array_equal( + 
aggregated_binnned_align_spikes.condition_indices, self.condition_indices[self.sorted_indices] + ) + np.testing.assert_array_equal( + aggregated_binnned_align_spikes.event_timestamps, self.event_timestamps[self.sorted_indices] + ) + + np.testing.assert_array_equal( + aggregated_binnned_align_spikes.condition_labels, self.condition_labels + ) + + self.assertEqual(aggregated_binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds) + self.assertEqual( + aggregated_binnned_align_spikes.milliseconds_from_event_to_first_bin, + self.milliseconds_from_event_to_first_bin, + ) + + self.assertEqual(aggregated_binnned_align_spikes.data.shape[0], self.number_of_units) + self.assertEqual(aggregated_binnned_align_spikes.data.shape[1], self.number_of_events) + self.assertEqual(aggregated_binnned_align_spikes.data.shape[2], self.number_of_bins) + + def test_get_single_condition_data_methods(self): + + data, event_timestamps, condition_indices = BinnedAlignedSpikes.sort_data_by_event_timestamps( + self.data, + self.event_timestamps, + self.condition_indices, + ) + + aggregated_binnned_align_spikes = BinnedAlignedSpikes( + bin_width_in_milliseconds=self.bin_width_in_milliseconds, + milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, + data=data, + event_timestamps=event_timestamps, + condition_indices=condition_indices, + ) + + data_condition1 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=0) + np.testing.assert_allclose(data_condition1, self.data_for_first_condition) + + data_condition2 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=1) + np.testing.assert_allclose(data_condition2, self.data_for_second_condition) + + timestamps_condition1 = aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=0) + np.testing.assert_allclose(timestamps_condition1, self.timestamps_first_condition) + + timestamps_condition2 = 
aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=1) + np.testing.assert_allclose(timestamps_condition2, self.timestamps_second_condition) + class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase): """Simple roundtrip test for BinnedAlignedSpikes.""" @@ -115,7 +255,6 @@ class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase): def setUp(self): self.nwbfile = mock_NWBFile() - self.path = "test.nwb" def tearDown(self): @@ -126,7 +265,9 @@ def test_roundtrip_acquisition(self): Add a BinnedAlignedSpikes to an NWBFile, write it to file, read the file and test that the BinnedAlignedSpikes from the file matches the original BinnedAlignedSpikes. """ - self.binned_aligned_spikes = mock_BinnedAlignedSpikes() + + # Exercise the roundtrip with condition indices and condition labels + self.binned_aligned_spikes = mock_BinnedAlignedSpikes(number_of_conditions=3, condition_labels=["a", "b", "c"]) self.nwbfile.add_acquisition(self.binned_aligned_spikes) @@ -164,7 +305,6 @@ def test_roundtrip_with_units_table(self): binned_aligned_spikes_with_region = mock_BinnedAlignedSpikes(units_region=units_region) self.nwbfile.add_acquisition(binned_aligned_spikes_with_region) - with NWBHDF5IO(self.path, mode="w") as io: io.write(self.nwbfile) @@ -172,4 +312,3 @@ def test_roundtrip_with_units_table(self): read_nwbfile = io.read() read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"] self.assertContainerEqual(binned_aligned_spikes_with_region, read_container) - diff --git a/src/spec/create_extension_spec.py b/src/spec/create_extension_spec.py index 3bf1e5d..cfd38c2 100644 --- a/src/spec/create_extension_spec.py +++ b/src/spec/create_extension_spec.py @@ -21,9 +21,7 @@ def main(): ) ns_builder.include_namespace("core") - # TODO: if your extension builds on another extension, include the namespace - # of the other extension below - # ns_builder.include_namespace("ndx-other-extension") + binned_aligned_spikes_data = NWBDatasetSpec( name="data", doc=( "The binned data. 
It should be an array whose first dimension is the number of units, the second dimension " "is the number of events, and the third dimension is the number of bins." ), - dtype="numeric", # TODO should this be a uint64? + dtype="uint64", shape=[None, None, None], dims=["num_units", "number_of_events", "number_of_bins"], ) - - event_timestamps = NWBDatasetSpec( - name="event_timestamps", - doc="The timestamps at which the events occurred.", - dtype="float64", - shape=[None], - dims=["number_of_events"], - ) units_region = NWBDatasetSpec( name="units_region", @@ -52,81 +42,46 @@ def main(): ) - binned_aligned_spikes = NWBGroupSpec( - neurodata_type_def="BinnedAlignedSpikes", - neurodata_type_inc="NWBDataInterface", - default_name="BinnedAlignedSpikes", - doc="A data interface for binned spike data aligned to an event (e.g. a stimuli or the beginning of a trial).", - datasets=[binned_aligned_spikes_data, event_timestamps, units_region], - attributes=[ - NWBAttributeSpec( - name="name", - doc="The name of this container", - dtype="text", - value="BinnedAlignedSpikes", - ), - NWBAttributeSpec( - name="description", - doc="A description of what the data represents", - dtype="text", - value="Spikes data binned and aligned to event timestamps.", - ), - NWBAttributeSpec( - name="bin_width_in_milliseconds", - doc="The length in milliseconds of the bins", - dtype="float64", - ), - NWBAttributeSpec( - name="milliseconds_from_event_to_first_bin", - doc=( - "The time in milliseconds from the event to the beginning of the first bin. A negative value indicates" - "that the first bin is before the event whereas a positive value indicates that the first bin is " - "after the event." - ), - dtype="float64", - default_value=0.0, - ) - ], - ) - - aggregated_binned_aligned_spikes_data = NWBDatasetSpec( - name="data", - doc=( - "The binned data. 
It should be an array whose first dimension is the number of units, the second dimension " - "is the total number of events of all stimuli, and the third dimension is the number of bins." - ), - dtype="numeric", # TODO should this be a uint64? - shape=[None, None, None], - dims=["num_units", "number_of_events", "number_of_bins"], - ) - - timestamps = NWBDatasetSpec( - name="timestamps", + event_timestamps = NWBDatasetSpec( + name="event_timestamps", doc="The timestamps at which the events occurred.", dtype="float64", shape=[None], dims=["number_of_events"], ) - event_indices = NWBDatasetSpec( - name="event_indices", - doc="The index of the event that each row of the data corresponds to.", - dtype="int64", + condition_indices = NWBDatasetSpec( + name="condition_indices", + doc= ( + "The index of the condition that each timestamps corresponds to " + "(e.g. a stimulus type, trial number, category, etc.)." + "This is only used when the data is aligned to multiple conditions" + ), + dtype="uint64", shape=[None], dims=["number_of_events"], + quantity="?", ) - - # TODO: This probably can inherit from the simple class and then add the stimuli index. - aggregated_binned_aligned_spikes = NWBGroupSpec( - neurodata_type_def="AggregatedBinnedAlignedSpikes", - neurodata_type_inc="NWBDataInterface", - default_name="AggregatedBinnedAlignedSpikes", + + condition_labels = NWBDatasetSpec( + name="condition_labels", doc=( - "A data interface for aggregated binned spike data aligned to multiple events. " - "The data for all the events is concatenated along the second dimension and a second array, " - "event_indices, is used to keep track of which event each row of the data corresponds to." - ), - datasets=[aggregated_binned_aligned_spikes_data, event_indices, timestamps, units_region], + "The labels of the conditions that the data is aligned to. The size of this array should match " + "the number of conditions. This is only used when the data is aligned to multiple conditions. 
" + "First condition is index 0, second is index 1, etc." + ), + dtype="text", + shape=[None], + dims=["number_of_conditions"], + quantity="?", + ) + + binned_aligned_spikes = NWBGroupSpec( + neurodata_type_def="BinnedAlignedSpikes", + neurodata_type_inc="NWBDataInterface", + default_name="BinnedAlignedSpikes", + doc="A data interface for binned spike data aligned to an event (e.g. a stimulus or the beginning of a trial).", + datasets=[binned_aligned_spikes_data, event_timestamps, condition_indices, condition_labels, units_region], attributes=[ NWBAttributeSpec( name="name", @@ -138,7 +93,7 @@ def main(): name="description", doc="A description of what the data represents", dtype="text", - value="Spikes data binned and aligned to the timestamps of multiple events.", + value="Spikes data binned and aligned to the event timestamps of one or multiple conditions.", ), NWBAttributeSpec( name="bin_width_in_milliseconds", @@ -159,8 +114,7 @@ def main(): ) - # TODO: add all of your new data types to this list - new_data_types = [binned_aligned_spikes, aggregated_binned_aligned_spikes] + new_data_types = [binned_aligned_spikes] # export the spec to yaml files in the spec folder output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "spec"))