diff --git a/mne_gui_addons/_core.py b/mne_gui_addons/_core.py index 6dd5cfe..f9478ee 100644 --- a/mne_gui_addons/_core.py +++ b/mne_gui_addons/_core.py @@ -357,10 +357,11 @@ def _plot_images(self): plot_x_idx, plot_y_idx = self._xy_idx[axis] fig = self._figs[axis] ax = fig.axes[0] - img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T self._images["base"].append( ax.imshow( - img_data, + self._base_data[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T, cmap="gray", aspect="auto", zorder=1, @@ -623,8 +624,9 @@ def _draw(self, axis=None): def _update_base_images(self, axis=None, draw=False): """Update the base images.""" for axis in range(3) if axis is None else [axis]: - img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T - self._images["base"][axis].set_data(img_data) + self._images["base"][axis].set_data( + self._base_data[(slice(None),) * axis + (self._current_slice[axis],)].T + ) if draw: self._draw(axis) diff --git a/mne_gui_addons/_ieeg_locate.py b/mne_gui_addons/_ieeg_locate.py index 81ea458..c01998f 100644 --- a/mne_gui_addons/_ieeg_locate.py +++ b/mne_gui_addons/_ieeg_locate.py @@ -996,16 +996,21 @@ def _update_ch_images(self, axis=None, draw=False): def _update_ct_images(self, axis=None, draw=False): """Update the CT image(s).""" for axis in range(3) if axis is None else [axis]: - ct_data = np.take(self._ct_data, self._current_slice[axis], axis=axis).T + ct_data = ( + self._ct_data[(slice(None),) * axis + (self._current_slice[axis],)] + .copy() + .T + ) # Threshold the CT so only bright objects (electrodes) are visible ct_data[ct_data < self._ct_min_slider.value()] = np.nan ct_data[ct_data > self._ct_max_slider.value()] = np.nan self._images["ct"][axis].set_data(ct_data) if "local_max" in self._images: - ct_max_data = np.take( - self._ct_maxima, self._current_slice[axis], axis=axis - ).T - self._images["local_max"][axis].set_data(ct_max_data) + self._images["local_max"][axis].set_data( + self._ct_maxima[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T + ) if draw: self._draw(axis) @@ -1014,7 +1019,9 @@ def _update_mri_images(self, axis=None, draw=False): if "mri" in self._images: for axis in range(3) if axis is None else [axis]: self._images["mri"][axis].set_data( - np.take(self._mr_data, self._current_slice[axis], axis=axis).T + self._mr_data[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T ) if draw: self._draw(axis) @@ -1150,14 +1157,13 @@ def _toggle_show_max(self): self._update_ct_maxima() self._images["local_max"] = list() for axis in range(3): - ct_max_data = np.take( - self._ct_maxima, self._current_slice[axis], axis=axis - ).T self._images["local_max"].append( self._figs[axis] .axes[0] .imshow( - ct_max_data, + self._ct_maxima[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T, cmap="autumn", aspect="auto", vmin=0, @@ -1182,13 +1188,18 @@ def _toggle_show_brain(self): else: self._images["mri"] = list() for axis in range(3): - mri_data = np.take( - self._mr_data, self._current_slice[axis], axis=axis - ).T self._images["mri"].append( self._figs[axis] .axes[0] - .imshow(mri_data, cmap="hot", aspect="auto", alpha=0.25, zorder=2) + .imshow( + self._mr_data[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T, + cmap="hot", + aspect="auto", + alpha=0.25, + zorder=2, + ) ) self._toggle_brain_button.setText("Hide Brain") self._draw() diff --git a/mne_gui_addons/_segment.py b/mne_gui_addons/_segment.py index 0a79761..27923eb 100644 --- a/mne_gui_addons/_segment.py +++ b/mne_gui_addons/_segment.py @@ -320,7 +320,9 @@ def _update_img_scale(self): def _update_base_images(self, axis=None, draw=False): """Update the CT image(s).""" for axis in range(3) if axis is None else [axis]: - img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T + img_data = self._base_data[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T.copy() img_data[img_data < self._img_min_slider.value()] = np.nan img_data[img_data > self._img_max_slider.value()] = np.nan self._images["base"][axis].set_data(img_data) @@ -335,10 +337,11 @@ def _plot_vol_images(self): for axis in range(3): fig = self._figs[axis] ax = fig.axes[0] - vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T self._images["vol"].append( ax.imshow( - vol_data, + self._vol_img[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T, aspect="auto", zorder=3, cmap=_CMAP, @@ -438,8 +441,9 @@ def _mark_all(self): def _update_vol_images(self, axis=None, draw=False): """Update the volume image(s).""" for axis in range(3) if axis is None else [axis]: - vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T - self._images["vol"][axis].set_data(vol_data) + self._images["vol"][axis].set_data( + self._vol_img[(slice(None),) * axis + (self._current_slice[axis],)].T + ) if draw: self._draw(axis) diff --git a/mne_gui_addons/_vol_stc.py b/mne_gui_addons/_vol_stc.py index 5a704b3..b3301cc 100644 --- a/mne_gui_addons/_vol_stc.py +++ b/mne_gui_addons/_vol_stc.py @@ -306,7 +306,6 @@ def __init__( ] src_coord = self._get_src_coord() for axis in range(3): - stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T x_idx, y_idx = self._xy_idx[axis] extent = [ corners[0][x_idx], @@ -318,7 +317,7 @@ def __init__( self._figs[axis] .axes[0] .imshow( - stc_slice, + self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T, aspect="auto", extent=extent, cmap=self._cmap, @@ -507,7 +506,7 @@ def _apply_vector_norm(self, stc_data, axis=1): # if self._data.dtype in (COMPLEX_DTYPE, BASE_INT_DTYPE): # stc_data = stc_data.round().astype(BASE_INT_DTYPE) else: - stc_data = np.take(stc_data, 0, axis=axis) + stc_data = stc_data[(slice(None),) * axis + (0,)] return stc_data def _apply_baseline_correction(self, stc_data): @@ -541,9 +540,9 @@ def _pick_stc_vertex(self, stc_data): def _pick_stc_tfr(self, stc_data): """Select the frequency and time based on GUI values.""" - stc_data = np.take(stc_data, self._t_idx, axis=-1) + stc_data = stc_data[..., self._t_idx] f_idx = 0 if self._f_idx is None else self._f_idx - stc_data = np.take(stc_data, f_idx, axis=-1) + stc_data = stc_data[..., f_idx] return stc_data def _configure_ui(self): @@ -1381,10 +1380,13 @@ def _plot_stc_images(self, axis=None, draw=True): for axis in range(3): # ensure in bounds if src_coord[axis] >= 0 and src_coord[axis] < self._stc_img.shape[axis]: - stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T + self._images["stc"][axis].set_data( + self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T + ) else: - stc_slice = np.take(self._stc_img, 0, axis=axis).T * np.nan - self._images["stc"][axis].set_data(stc_slice) + self._images["stc"][axis].set_data( + self._stc_img[(slice(None),) * axis + (0,)].copy().T * np.nan + ) if draw and self._update: self._draw(axis) diff --git a/mne_gui_addons/tests/test_segment.py b/mne_gui_addons/tests/test_segment.py index 32c8836..e2f7041 100644 --- a/mne_gui_addons/tests/test_segment.py +++ b/mne_gui_addons/tests/test_segment.py @@ -37,9 +37,10 @@ def test_segment_display(renderer_interactive_pyvistaqt): # test no seghead, fsaverage doesn't have seghead with pytest.warns(RuntimeWarning, match="`seghead` not found"): - gui = VolumeSegmenter( - subject="fsaverage", subjects_dir=subjects_dir, verbose=True - ) + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + gui = VolumeSegmenter( + subject="fsaverage", subjects_dir=subjects_dir, verbose=True + ) # test functions gui.set_RAS([25.37, 0.00, 34.18]) diff --git a/mne_gui_addons/tests/test_vol_stc.py b/mne_gui_addons/tests/test_vol_stc.py index 044a437..23c8147 100644 --- a/mne_gui_addons/tests/test_vol_stc.py +++ b/mne_gui_addons/tests/test_vol_stc.py @@ -57,7 +57,7 @@ def _fake_stc(src_type="vol"): ) + 1j * rng.integers( -1000, 1000, size=(n_epochs, len(info.ch_names), freqs.size, times.size) ) - epochs_tfr = mne.time_frequency.EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = mne.time_frequency.EpochsTFRArray(info, data, times=times, freqs=freqs) nuse = sum([this_src["nuse"] for this_src in src]) stc_data = rng.integers( -1000, 1000, size=(n_epochs, nuse, 3, freqs.size, times.size)