Skip to content

Commit

Permalink
BUG: Fix largest culprit in speed issue (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrockhill authored Apr 30, 2024
1 parent 5e2aaa4 commit d0b83ff
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 35 deletions.
10 changes: 6 additions & 4 deletions mne_gui_addons/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,11 @@ def _plot_images(self):
plot_x_idx, plot_y_idx = self._xy_idx[axis]
fig = self._figs[axis]
ax = fig.axes[0]
img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
self._images["base"].append(
ax.imshow(
img_data,
self._base_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
cmap="gray",
aspect="auto",
zorder=1,
Expand Down Expand Up @@ -623,8 +624,9 @@ def _draw(self, axis=None):
def _update_base_images(self, axis=None, draw=False):
"""Update the base images."""
for axis in range(3) if axis is None else [axis]:
img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
self._images["base"][axis].set_data(img_data)
self._images["base"][axis].set_data(
self._base_data[(slice(None),) * axis + (self._current_slice[axis],)].T
)
if draw:
self._draw(axis)

Expand Down
39 changes: 25 additions & 14 deletions mne_gui_addons/_ieeg_locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,16 +996,21 @@ def _update_ch_images(self, axis=None, draw=False):
def _update_ct_images(self, axis=None, draw=False):
"""Update the CT image(s)."""
for axis in range(3) if axis is None else [axis]:
ct_data = np.take(self._ct_data, self._current_slice[axis], axis=axis).T
ct_data = (
self._ct_data[(slice(None),) * axis + (self._current_slice[axis],)]
.copy()
.T
)
# Threshold the CT so only bright objects (electrodes) are visible
ct_data[ct_data < self._ct_min_slider.value()] = np.nan
ct_data[ct_data > self._ct_max_slider.value()] = np.nan
self._images["ct"][axis].set_data(ct_data)
if "local_max" in self._images:
ct_max_data = np.take(
self._ct_maxima, self._current_slice[axis], axis=axis
).T
self._images["local_max"][axis].set_data(ct_max_data)
self._images["local_max"][axis].set_data(
self._ct_maxima[
(slice(None),) * axis + (self._current_slice[axis],)
].T
)
if draw:
self._draw(axis)

Expand All @@ -1014,7 +1019,9 @@ def _update_mri_images(self, axis=None, draw=False):
if "mri" in self._images:
for axis in range(3) if axis is None else [axis]:
self._images["mri"][axis].set_data(
np.take(self._mr_data, self._current_slice[axis], axis=axis).T
self._mr_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T
)
if draw:
self._draw(axis)
Expand Down Expand Up @@ -1150,14 +1157,13 @@ def _toggle_show_max(self):
self._update_ct_maxima()
self._images["local_max"] = list()
for axis in range(3):
ct_max_data = np.take(
self._ct_maxima, self._current_slice[axis], axis=axis
).T
self._images["local_max"].append(
self._figs[axis]
.axes[0]
.imshow(
ct_max_data,
self._ct_maxima[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
cmap="autumn",
aspect="auto",
vmin=0,
Expand All @@ -1182,13 +1188,18 @@ def _toggle_show_brain(self):
else:
self._images["mri"] = list()
for axis in range(3):
mri_data = np.take(
self._mr_data, self._current_slice[axis], axis=axis
).T
self._images["mri"].append(
self._figs[axis]
.axes[0]
.imshow(mri_data, cmap="hot", aspect="auto", alpha=0.25, zorder=2)
.imshow(
self._mr_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
cmap="hot",
aspect="auto",
alpha=0.25,
zorder=2,
)
)
self._toggle_brain_button.setText("Hide Brain")
self._draw()
Expand Down
14 changes: 9 additions & 5 deletions mne_gui_addons/_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ def _update_img_scale(self):
def _update_base_images(self, axis=None, draw=False):
"""Update the CT image(s)."""
for axis in range(3) if axis is None else [axis]:
img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
img_data = self._base_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T.copy()
img_data[img_data < self._img_min_slider.value()] = np.nan
img_data[img_data > self._img_max_slider.value()] = np.nan
self._images["base"][axis].set_data(img_data)
Expand All @@ -335,10 +337,11 @@ def _plot_vol_images(self):
for axis in range(3):
fig = self._figs[axis]
ax = fig.axes[0]
vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T
self._images["vol"].append(
ax.imshow(
vol_data,
self._vol_img[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
aspect="auto",
zorder=3,
cmap=_CMAP,
Expand Down Expand Up @@ -438,8 +441,9 @@ def _mark_all(self):
def _update_vol_images(self, axis=None, draw=False):
"""Update the volume image(s)."""
for axis in range(3) if axis is None else [axis]:
vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T
self._images["vol"][axis].set_data(vol_data)
self._images["vol"][axis].set_data(
self._vol_img[(slice(None),) * axis + (self._current_slice[axis],)].T
)
if draw:
self._draw(axis)

Expand Down
18 changes: 10 additions & 8 deletions mne_gui_addons/_vol_stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def __init__(
]
src_coord = self._get_src_coord()
for axis in range(3):
stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T
x_idx, y_idx = self._xy_idx[axis]
extent = [
corners[0][x_idx],
Expand All @@ -318,7 +317,7 @@ def __init__(
self._figs[axis]
.axes[0]
.imshow(
stc_slice,
self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T,
aspect="auto",
extent=extent,
cmap=self._cmap,
Expand Down Expand Up @@ -507,7 +506,7 @@ def _apply_vector_norm(self, stc_data, axis=1):
# if self._data.dtype in (COMPLEX_DTYPE, BASE_INT_DTYPE):
# stc_data = stc_data.round().astype(BASE_INT_DTYPE)
else:
stc_data = np.take(stc_data, 0, axis=axis)
stc_data = stc_data[(slice(None),) * axis + (0,)]
return stc_data

def _apply_baseline_correction(self, stc_data):
Expand Down Expand Up @@ -541,9 +540,9 @@ def _pick_stc_vertex(self, stc_data):

def _pick_stc_tfr(self, stc_data):
"""Select the frequency and time based on GUI values."""
stc_data = np.take(stc_data, self._t_idx, axis=-1)
stc_data = stc_data[..., self._t_idx]
f_idx = 0 if self._f_idx is None else self._f_idx
stc_data = np.take(stc_data, f_idx, axis=-1)
stc_data = stc_data[..., f_idx]
return stc_data

def _configure_ui(self):
Expand Down Expand Up @@ -1381,10 +1380,13 @@ def _plot_stc_images(self, axis=None, draw=True):
for axis in range(3):
# ensure in bounds
if src_coord[axis] >= 0 and src_coord[axis] < self._stc_img.shape[axis]:
stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T
self._images["stc"][axis].set_data(
self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T
)
else:
stc_slice = np.take(self._stc_img, 0, axis=axis).T * np.nan
self._images["stc"][axis].set_data(stc_slice)
self._images["stc"][axis].set_data(
self._stc_img[(slice(None),) * axis + (0,)].copy().T * np.nan
)
if draw and self._update:
self._draw(axis)

Expand Down
7 changes: 4 additions & 3 deletions mne_gui_addons/tests/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def test_segment_display(renderer_interactive_pyvistaqt):

# test no seghead, fsaverage doesn't have seghead
with pytest.warns(RuntimeWarning, match="`seghead` not found"):
gui = VolumeSegmenter(
subject="fsaverage", subjects_dir=subjects_dir, verbose=True
)
with pytest.warns(RuntimeWarning, match="`pial` surface not found"):
gui = VolumeSegmenter(
subject="fsaverage", subjects_dir=subjects_dir, verbose=True
)

# test functions
gui.set_RAS([25.37, 0.00, 34.18])
Expand Down
2 changes: 1 addition & 1 deletion mne_gui_addons/tests/test_vol_stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _fake_stc(src_type="vol"):
) + 1j * rng.integers(
-1000, 1000, size=(n_epochs, len(info.ch_names), freqs.size, times.size)
)
epochs_tfr = mne.time_frequency.EpochsTFR(info, data, times=times, freqs=freqs)
epochs_tfr = mne.time_frequency.EpochsTFRArray(info, data, times=times, freqs=freqs)
nuse = sum([this_src["nuse"] for this_src in src])
stc_data = rng.integers(
-1000, 1000, size=(n_epochs, nuse, 3, freqs.size, times.size)
Expand Down

0 comments on commit d0b83ff

Please sign in to comment.