Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Sep 6, 2024
1 parent 31b2d4c commit a42c6b4
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 36 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data = np.array(
[2, 7, 4, 1], # Bin counts around the third timestamp
],
],
dtype="uint64",
)

event_timestamps = np.array([0.25, 5.0, 12.25]) # The timestamps to which we align the counts
Expand Down
19 changes: 19 additions & 0 deletions src/pynwb/ndx_binned_spikes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,25 @@ def sort_data_by_event_timestamps(

return data, event_timestamps, condition_indices

@property
def number_of_units(self):
return self.data.shape[0]

@property
def number_of_events(self):
return self.data.shape[1]

@property
def number_of_bins(self):
return self.data.shape[2]


@property
def number_of_conditions(self):
if self.has_multiple_conditions:
return np.unique(self.condition_indices).size
else:
return 1

# Remove these functions from the package
del load_namespaces, get_class
6 changes: 3 additions & 3 deletions src/pynwb/ndx_binned_spikes/testing/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def mock_BinnedAlignedSpikes(
number_of_units, number_of_events, number_of_bins = data.shape
else:
rng = np.random.default_rng(seed=seed)
data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins))
data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins), dtype="uint64")

# Assert data shapes
assertion_msg = (
Expand All @@ -121,8 +121,8 @@ def mock_BinnedAlignedSpikes(
number_of_conditions < number_of_events
), "The number of conditions should be less than the number of events."

condition_indices = np.zeros(number_of_events, dtype=int)
all_indices = np.arange(number_of_conditions, dtype=int)
condition_indices = np.zeros(number_of_events, dtype="uint64")
all_indices = np.arange(number_of_conditions, dtype='uint64')

# Ensure all conditions indices appear at least once
condition_indices[:number_of_conditions] = rng.choice(all_indices, size=number_of_conditions, replace=False)
Expand Down
72 changes: 39 additions & 33 deletions src/pynwb/tests/test_binned_aligned_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def test_constructor(self):

np.testing.assert_array_equal(binned_aligned_spikes.data, self.data)
np.testing.assert_array_equal(binned_aligned_spikes.event_timestamps, self.event_timestamps)




self.assertEqual(binned_aligned_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
self.assertEqual(
binned_aligned_spikes.milliseconds_from_event_to_first_bin, self.milliseconds_from_event_to_first_bin
Expand Down Expand Up @@ -124,7 +122,7 @@ def setUp(self):
self.bin_width_in_milliseconds = 20.0
self.milliseconds_from_event_to_first_bin = -100.0

# Two units in total and 4 bins, and event with two timestamps
# Two units in total and 4 bins, and condition with two timestamps
self.data_for_first_condition = np.array(
[
# Unit 1 data
Expand All @@ -138,9 +136,11 @@ def setUp(self):
[12, 13, 14, 15], # Bin counts around the second timestamp
],
],
dtype="uint64",

)

# Also two units and 4 bins but this event appeared three times
# Also two units and 4 bins but this condition appeared three times
self.data_for_second_condition = np.array(
[
# Unit 1 data
Expand All @@ -155,7 +155,8 @@ def setUp(self):
[16, 17, 18, 19], # Bin counts around the second timestamp
[20, 21, 22, 23], # Bin counts around the third timestamp
],
]
],
dtype="uint64",
)

self.timestamps_first_condition = [5.0, 15.0]
Expand All @@ -170,7 +171,7 @@ def setUp(self):
self.event_timestamps = np.concatenate([self.timestamps_first_condition, self.timestamps_second_condition])

self.sorted_indices = np.argsort(self.event_timestamps)

self.condition_labels = ["first", "second"]

def test_constructor(self):
Expand All @@ -192,7 +193,7 @@ def test_constructor(self):
self.condition_indices,
)

aggregated_binnned_align_spikes = BinnedAlignedSpikes(
binnned_align_spikes = BinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=data,
Expand All @@ -201,27 +202,23 @@ def test_constructor(self):
condition_labels=self.condition_labels,
)

np.testing.assert_array_equal(aggregated_binnned_align_spikes.data, self.data[:, self.sorted_indices, :])
np.testing.assert_array_equal(
aggregated_binnned_align_spikes.condition_indices, self.condition_indices[self.sorted_indices]
)
np.testing.assert_array_equal(
aggregated_binnned_align_spikes.event_timestamps, self.event_timestamps[self.sorted_indices]
)

np.testing.assert_array_equal(binnned_align_spikes.data, self.data[:, self.sorted_indices, :])
np.testing.assert_array_equal(
aggregated_binnned_align_spikes.condition_labels, self.condition_labels
binnned_align_spikes.condition_indices, self.condition_indices[self.sorted_indices]
)

self.assertEqual(aggregated_binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
np.testing.assert_array_equal(binnned_align_spikes.event_timestamps, self.event_timestamps[self.sorted_indices])

np.testing.assert_array_equal(binnned_align_spikes.condition_labels, self.condition_labels)

self.assertEqual(binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
self.assertEqual(
aggregated_binnned_align_spikes.milliseconds_from_event_to_first_bin,
binnned_align_spikes.milliseconds_from_event_to_first_bin,
self.milliseconds_from_event_to_first_bin,
)

self.assertEqual(aggregated_binnned_align_spikes.number_of_units, self.number_of_units)
self.assertEqual(aggregated_binnned_align_spikes.number_of_events, self.number_of_events)
self.assertEqual(aggregated_binnned_align_spikes.data.number_of_bins, self.number_of_bins)
self.assertEqual(binnned_align_spikes.number_of_units, self.number_of_units)
self.assertEqual(binnned_align_spikes.number_of_events, self.number_of_events)
self.assertEqual(binnned_align_spikes.number_of_bins, self.number_of_bins)

def test_get_single_condition_data_methods(self):

Expand All @@ -231,24 +228,24 @@ def test_get_single_condition_data_methods(self):
self.condition_indices,
)

aggregated_binnned_align_spikes = BinnedAlignedSpikes(
binnned_align_spikes = BinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=data,
event_timestamps=event_timestamps,
condition_indices=condition_indices,
)

data_condition1 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=0)
data_condition1 = binnned_align_spikes.get_data_for_condition(condition_index=0)
np.testing.assert_allclose(data_condition1, self.data_for_first_condition)

data_condition2 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=1)
data_condition2 = binnned_align_spikes.get_data_for_condition(condition_index=1)
np.testing.assert_allclose(data_condition2, self.data_for_second_condition)

timestamps_condition1 = aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=0)
timestamps_condition1 = binnned_align_spikes.get_event_timestamps_for_condition(condition_index=0)
np.testing.assert_allclose(timestamps_condition1, self.timestamps_first_condition)

timestamps_condition2 = aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=1)
timestamps_condition2 = binnned_align_spikes.get_event_timestamps_for_condition(condition_index=1)
np.testing.assert_allclose(timestamps_condition2, self.timestamps_second_condition)


Expand All @@ -268,12 +265,20 @@ def test_roundtrip_acquisition(self):
Add a BinnedAlignedSpikes to an NWBFile, write it to file, read the file
and test that the BinnedAlignedSpikes from the file matches the original BinnedAlignedSpikes.
"""

# Testing here
number_of_units = 5
number_of_bins = 10
number_of_events = 100
self.binned_aligned_spikes = mock_BinnedAlignedSpikes(number_of_conditions=3, condition_labels=["a", "b", "c"])
number_of_conditions = 3
condition_labels = ["a", "b", "c"]

self.binned_aligned_spikes = mock_BinnedAlignedSpikes(
number_of_units=number_of_units,
number_of_bins=number_of_bins,
number_of_events=number_of_events,
number_of_conditions=number_of_conditions,
condition_labels=condition_labels,
)

self.nwbfile.add_acquisition(self.binned_aligned_spikes)

Expand All @@ -284,12 +289,11 @@ def test_roundtrip_acquisition(self):
read_nwbfile = io.read()
read_binned_aligned_spikes = read_nwbfile.acquisition["BinnedAlignedSpikes"]
self.assertContainerEqual(self.binned_aligned_spikes, read_binned_aligned_spikes)

assert read_binned_aligned_spikes.number_of_units == number_of_units
assert read_binned_aligned_spikes.number_of_bins == number_of_bins
assert read_binned_aligned_spikes.number_of_events == number_of_events


assert read_binned_aligned_spikes.number_of_conditions == number_of_conditions

def test_roundtrip_processing_module(self):
self.binned_aligned_spikes = mock_BinnedAlignedSpikes()
Expand Down Expand Up @@ -324,3 +328,5 @@ def test_roundtrip_with_units_table(self):
read_nwbfile = io.read()
read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"]
self.assertContainerEqual(binned_aligned_spikes_with_region, read_container)


0 comments on commit a42c6b4

Please sign in to comment.