Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix EpochsTFRArray default drop log initialisation #13028

Merged
merged 6 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/13028.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix epoch indexing in :class:`mne.time_frequency.EpochsTFRArray` when initialising the class with the default ``drop_log`` parameter, by `Thomas Binns`_.
28 changes: 16 additions & 12 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,25 +1218,29 @@ def test_averaging_freqsandtimes_epochsTFR():
avgpower = power.average(method=lambda x: np.mean(x, axis=2), **kwargs)


@pytest.mark.parametrize("n_drop", (0, 2))
def test_epochstfr_getitem(epochs_full, n_drop):
@pytest.mark.parametrize("n_drop, as_tfr_array", ((0, False), (0, True), (2, False)))
def test_epochstfr_getitem(epochs_full, n_drop, as_tfr_array):
"""Test EpochsTFR.__getitem__()."""
pd = pytest.importorskip("pandas")
from pandas.testing import assert_frame_equal

epochs_full.metadata = pd.DataFrame(dict(foo=list("aaaabbb"), bar=np.arange(7)))
epochs_full.drop(np.arange(n_drop))
tfr = epochs_full.compute_tfr(method="morlet", freqs=freqs_linspace)
# check that various attributes are preserved
assert_frame_equal(tfr.metadata, epochs_full.metadata)
assert epochs_full.drop_log == tfr.drop_log
for attr in ("events", "selection", "times"):
assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr))
# test pandas query
foo_a = tfr["foo == 'a'"]
bar_3 = tfr["bar <= 3"]
assert foo_a == bar_3
assert foo_a.shape[0] == 4 - n_drop
if not as_tfr_array: # check that various attributes are preserved
assert_frame_equal(tfr.metadata, epochs_full.metadata)
assert epochs_full.drop_log == tfr.drop_log
for attr in ("events", "selection", "times"):
assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr))
# test pandas query
foo_a = tfr["foo == 'a'"]
bar_3 = tfr["bar <= 3"]
assert foo_a == bar_3
assert foo_a.shape[0] == 4 - n_drop
else: # repackage to check __getitem__ also works with unspecified events, etc...
tfr = EpochsTFRArray(
info=tfr.info, data=tfr.data, times=tfr.times, freqs=tfr.freqs
)
# test integer and slice
subset_ints = tfr[[0, 1, 2]]
subset_slice = tfr[:3]
Expand Down
8 changes: 7 additions & 1 deletion mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3105,8 +3105,14 @@ def __setstate__(self, state):
).squeeze(axis=0)
self.events = state.get("events", _ensure_events(fake_events))
self.event_id = state.get("event_id", _check_event_id(None, self.events))
self.drop_log = state.get("drop_log", tuple())
self.selection = state.get("selection", np.arange(n_epochs))
self.drop_log = state.get(
"drop_log",
tuple(
() if k in self.selection else ("IGNORED",)
for k in range(max(len(self.events), max(self.selection) + 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scratching my head about this line. When would we not want just range(len(self.events))?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always have len(self.events) == len(epochs). But let's say you originally had 100 events and just selected the last 10 on way or another, that resulting Epochs object shouldhave for example

len(epochs) == 10
len(epochs.selection) == 10
len(epochs.events) == 10
list(epochs.selection) == list(range(90, 100))

and in general for anything not in epochs.selection you should be able to query that index in epochs.drop_log to see why it's not part of epochs. So epochs.drop_log[30] for example should exist and be non-empty meaning it was dropped for a reason, whereas epochs.drop_log[90:] should exist and all be empty (i.e., those are the epochs that have been kept).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, I forgot about events getting truncated/sliced/subselected.

Am I right then that we don't need to care about the case of having 30 events, and selecting only the middle 10? IIUC, if that happened via normal means, the drop_log would be present and the code in this diff won't be reached. It's only for EpochsTFRArray (where we don't know the dropping history) where we have to spawn a "fake" drop log, and in that case we can't know if (say) the last 10 epochs were dropped, but we also don't need to know that for things to work correctly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think we just need some drop log that satisfies the necessary expectations about length of the drop log, selection values, and number of events.

Copy link
Contributor Author

@tsbinns tsbinns Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So something like

tuple(() if k in self.selection else ("IGNORED",) for k in len(self.events))

to account for the fact that a non-default selection param could be passed?

Or just the super simple

tuple(() for k in self.events)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No the len(self.events) won't be enough, I think what you have here with the max(...) is more correct

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... or you have to make sure self.selection = np.arange(len(self.events))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry. Yeah at the moment whatever selection can be passed so self.selection = np.arange(len(self.events)) can't be assumed.

),
)
self._bad_dropped = True # always true, need for `equalize_event_counts()`

def __next__(self, return_event_id=False):
Expand Down
Loading