diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 73a3308685d..f7c6dc51a4c 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -471,6 +471,7 @@ def tfr_array_multitaper( output="complex", n_jobs=None, *, + return_weights=False, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -502,8 +503,14 @@ def tfr_array_multitaper( * ``'itc'`` : inter-trial coherence. * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. + %(n_jobs)s The parallelization is implemented across channels. + return_weights : bool, default False + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. + + .. versionadded:: 1.9.0 %(verbose)s Returns @@ -520,6 +527,9 @@ def tfr_array_multitaper( If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + ``return_weights=True``. See Also -------- @@ -550,6 +560,7 @@ def tfr_array_multitaper( use_fft=use_fft, decim=decim, output=output, + return_weights=return_weights, n_jobs=n_jobs, verbose=verbose, ) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index cd3a97ab90a..1799692d6ce 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -432,17 +432,21 @@ def test_tfr_morlet(): def test_dpsswavelet(): """Test DPSS tapers.""" freqs = np.arange(5, 25, 3) - Ws = _make_dpss( - 1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True + Ws, weights = _make_dpss( + 1000, + freqs=freqs, + n_cycles=freqs / 2.0, + time_bandwidth=4.0, + zero_mean=True, + return_weights=True, ) - assert len(Ws) == 3 # 3 tapers expected + assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected + assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) # Check that zero mean is true assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 - assert len(Ws[0]) == len(freqs) # As many wavelets as asked for - @pytest.mark.slowtest def test_tfr_multitaper(): diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index eaf173092bb..3a04c9910ce 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -264,8 +264,11 @@ def _make_dpss( ------- Ws : list of array The wavelets time series. + Cs : list of array + The concentration weights. Only returned if return_weights=True. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +284,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +306,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -428,6 +435,7 @@ def _compute_tfr( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -479,6 +487,9 @@ def _compute_tfr( * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. + return_weights : bool, default False + Whether to return the taper weights. Only applies if method='multitaper' and + output='complex' or 'phase'. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. @@ -495,6 +506,9 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if method='multitaper', output='complex' or + 'phase', and return_weights=True. """ # Check data epoch_data = np.asarray(epoch_data) @@ -516,6 +530,9 @@ def _compute_tfr( decim, output, ) + return_weights = ( + return_weights and method == "multitaper" and output in ["complex", "phase"] + ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): @@ -531,13 +548,18 @@ def _compute_tfr( Ws = [W] # to have same dimensionality as the 'multitaper' case elif method == "multitaper": - Ws = _make_dpss( + out = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=return_weights, ) + if return_weights: + Ws, weights = out + else: + Ws = out # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -561,6 +583,8 @@ def _compute_tfr( out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + if return_weights: + weights = np.array(weights) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -585,6 +609,9 @@ def _compute_tfr( out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) + + if return_weights: + return out, weights return out @@ -1187,9 +1214,6 @@ def __init__( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) - # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) - if method == "morlet": - method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): @@ -1203,6 +1227,9 @@ def __init__( method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs + # always store weights for per-taper outputs + if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: + method_kw["return_weights"] = True # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, @@ -1224,6 +1251,7 @@ def __init__( self._method = method self._inst_type = type(inst) self._baseline = None + self._weights = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] @@ -1382,6 +1410,7 @@ def __getstate__(self): info=self.info, baseline=self._baseline, decim=self._decim, + weights=self._weights, ) def __setstate__(self, state): @@ -1410,6 +1439,7 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) + self._weights = state.get("weights") # objs saved before #12910 won't have # Handle instance type. Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked @@ -1516,6 +1546,10 @@ def _compute_tfr(self, data, n_jobs, verbose): if self.method == "stockwell": self._data, self._itc, freqs = result assert np.array_equal(self._freqs, freqs) + elif self.method == "multitaper" and self._tfr_func.keywords.get( + "output", "" + ) in ["complex", "phase"]: + self._data, self._weights = result elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: @@ -1694,6 +1728,11 @@ def times(self): """The time points present in the data (in seconds).""" return self._times_readonly + @property + def weights(self): + """The weights used for each taper in the time-frequency estimates.""" + return self._weights + @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place.