diff --git a/doc/changes/devel/13028.bugfix.rst b/doc/changes/devel/13028.bugfix.rst new file mode 100644 index 00000000000..13e34189eaf --- /dev/null +++ b/doc/changes/devel/13028.bugfix.rst @@ -0,0 +1 @@ +Fix epoch indexing in :class:`mne.time_frequency.EpochsTFRArray` when initialising the class with the default ``drop_log`` parameter, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index cd3a97ab90a..e68ea9e6e18 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -1218,8 +1218,8 @@ def test_averaging_freqsandtimes_epochsTFR(): avgpower = power.average(method=lambda x: np.mean(x, axis=2), **kwargs) -@pytest.mark.parametrize("n_drop", (0, 2)) -def test_epochstfr_getitem(epochs_full, n_drop): +@pytest.mark.parametrize("n_drop, as_tfr_array", ((0, False), (0, True), (2, False))) +def test_epochstfr_getitem(epochs_full, n_drop, as_tfr_array): """Test EpochsTFR.__getitem__().""" pd = pytest.importorskip("pandas") from pandas.testing import assert_frame_equal @@ -1227,16 +1227,20 @@ def test_epochstfr_getitem(epochs_full, n_drop): epochs_full.metadata = pd.DataFrame(dict(foo=list("aaaabbb"), bar=np.arange(7))) epochs_full.drop(np.arange(n_drop)) tfr = epochs_full.compute_tfr(method="morlet", freqs=freqs_linspace) - # check that various attributes are preserved - assert_frame_equal(tfr.metadata, epochs_full.metadata) - assert epochs_full.drop_log == tfr.drop_log - for attr in ("events", "selection", "times"): - assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr)) - # test pandas query - foo_a = tfr["foo == 'a'"] - bar_3 = tfr["bar <= 3"] - assert foo_a == bar_3 - assert foo_a.shape[0] == 4 - n_drop + if not as_tfr_array: # check that various attributes are preserved + assert_frame_equal(tfr.metadata, epochs_full.metadata) + assert epochs_full.drop_log == tfr.drop_log + for attr in ("events", "selection", "times"): + assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr)) + # test pandas query + foo_a = tfr["foo == 'a'"] + bar_3 = tfr["bar <= 3"] + assert foo_a == bar_3 + assert foo_a.shape[0] == 4 - n_drop + else: # repackage to check __getitem__ also works with unspecified events, etc... + tfr = EpochsTFRArray( + info=tfr.info, data=tfr.data, times=tfr.times, freqs=tfr.freqs + ) # test integer and slice subset_ints = tfr[[0, 1, 2]] subset_slice = tfr[:3] diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index eaf173092bb..1fb9f3f2e07 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -3105,8 +3105,14 @@ def __setstate__(self, state): ).squeeze(axis=0) self.events = state.get("events", _ensure_events(fake_events)) self.event_id = state.get("event_id", _check_event_id(None, self.events)) - self.drop_log = state.get("drop_log", tuple()) self.selection = state.get("selection", np.arange(n_epochs)) + self.drop_log = state.get( + "drop_log", + tuple( + () if k in self.selection else ("IGNORED",) + for k in range(max(len(self.events), max(self.selection) + 1)) + ), + ) self._bad_dropped = True # always true, need for `equalize_event_counts()` def __next__(self, return_event_id=False):