diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index aec5c92..87d9a1d 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -9,9 +9,9 @@ try: import torch from torch.utils.data import DataLoader - USING_TORCH = True + USING_PYTORCH = True except ModuleNotFoundError: - USING_TORCH = False + USING_PYTORCH = False import napari from napari.layers._multiscale_data import MultiScaleData @@ -123,7 +123,7 @@ def add_multiscale_output_layer( return output_channel -if USING_TORCH: +if USING_PYTORCH: class DropoutEvalOverrider(torch.nn.Module): def __init__(self, dropout_module): super(DropoutEvalOverrider, self).__init__() @@ -239,7 +239,7 @@ def fine_tune(self, dataset_metadata_list: Iterable[ shuffle=True, ) - if USING_TORCH: + if USING_PYTORCH: dataloader = DataLoader( dataset, num_workers=self._num_workers, @@ -256,7 +256,7 @@ def fine_tune(self, dataset_metadata_list: Iterable[ ) for img, lab in dataloader: - if USING_TORCH: + if USING_PYTORCH: img = img[0].numpy() lab = lab[0].numpy() @@ -453,7 +453,7 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun, self._reset_patch_progressbar() for pos, img, img_sp in dl: - if USING_TORCH: + if USING_PYTORCH: pos = pos[0].numpy() img = img[0].numpy() img_sp = img_sp[0].numpy() diff --git a/src/napari_activelearning/_interface.py b/src/napari_activelearning/_interface.py index 762e181..f77e1d7 100644 --- a/src/napari_activelearning/_interface.py +++ b/src/napari_activelearning/_interface.py @@ -1,8 +1,9 @@ from typing import Optional, Union, Iterable -from qtpy.QtGui import QIntValidator +from qtpy.QtGui import QIntValidator, QDoubleValidator from qtpy.QtCore import Qt, Signal -from qtpy.QtWidgets import (QWidget, QPushButton, QGridLayout, QLineEdit, +from qtpy.QtWidgets import (QWidget, QPushButton, QGridLayout, QVBoxLayout, + QLineEdit, QComboBox, QLabel, QFileDialog, @@ -36,6 +37,8 @@ class MultiSpinBox(QWidget): def __init__(self): super().__init__() + self._dtype = int + self.edit_scale_lyt = QGridLayout() self.setLayout(self.edit_scale_lyt) @@ -62,8 +65,8 @@ def sizes(self): return self._sizes @sizes.setter - def sizes(self, new_sizes: Union[Iterable[int], dict]): - if isinstance(dict): + def sizes(self, new_sizes: Union[Iterable, dict]): + if isinstance(new_sizes, dict): self._sizes = new_sizes self._axes = new_sizes.keys() @@ -78,6 +81,21 @@ def sizes(self, new_sizes: Union[Iterable[int], dict]): self.update_spin_boxes() + @staticmethod + def _create_spinbox(ax_s: int): + power_spn = QSpinBox( + minimum=0, maximum=16, + buttonSymbols=QAbstractSpinBox.UpDownArrows + ) + power_spn.lineEdit().hide() + power_spn.setValue(int(math.log2(ax_s))) + + scale_le = QLineEdit() + scale_le.setValidator(QIntValidator(1, 2**16)) + scale_le.setText(str(ax_s)) + + return scale_le, power_spn + def clear_layer_channel(self): while self._curr_scale_le_list: item = self._curr_scale_le_list.pop() @@ -95,16 +113,7 @@ def update_spin_boxes(self): self.clear_layer_channel() for ax_idx, (ax, ax_s) in enumerate(self._sizes.items()): - power_spn = QSpinBox( - minimum=0, maximum=16, - buttonSymbols=QAbstractSpinBox.UpDownArrows - ) - power_spn.lineEdit().hide() - power_spn.setValue(int(math.log2(ax_s))) - - scale_le = QLineEdit() - scale_le.setValidator(QIntValidator(1, 2**16)) - scale_le.setText(str(ax_s)) + scale_le, power_spn = self._create_spinbox(ax_s) self._curr_scale_le_list.append(scale_le) self.edit_scale_lyt.addWidget(self._curr_scale_le_list[-1], @@ -134,23 +143,41 @@ def _modify_size(self, scale: int, ax_idx: int = 0): def _set_patch_size(self): axes = self._sizes.keys() self._sizes = { - ax: int(scale_le.text()) + ax: self._dtype(scale_le.text()) for ax, scale_le in zip(axes, self._curr_scale_le_list) } self.sizesChanged.emit(self._sizes) +class MultiDoubleSpinBox(MultiSpinBox): + def __init__(self): + super().__init__() + self._dtype = float + + @staticmethod + def _create_spinbox(ax_s: float): + power_spn = QDoubleSpinBox( + minimum=-16, maximum=16, + buttonSymbols=QAbstractSpinBox.UpDownArrows + ) + power_spn.lineEdit().hide() + power_spn.setValue(math.log2(max(1e-12, ax_s))) + + scale_le = QLineEdit() + scale_le.setValidator(QDoubleValidator(1e-12, 1e12, 12)) + scale_le.setText(str(ax_s)) + + return scale_le, power_spn + + class ImageGroupEditorWidget(ImageGroupEditor, QWidget): def __init__(self): super().__init__() - layout = QGridLayout() self.group_name_le = QLineEdit("None selected") self.group_name_le.setEnabled(False) self.group_name_le.returnPressed.connect(self.update_group_name) - layout.addWidget(QLabel("Group name:"), 0, 0) - layout.addWidget(self.group_name_le, 0, 1) self.layers_group_name_cmb = QComboBox() self.layers_group_name_cmb.setEditable(True) @@ -158,25 +185,17 @@ def __init__(self): self.update_layers_group_name ) self.layers_group_name_cmb.setEnabled(False) - layout.addWidget(QLabel("Channels group name:"), 0, 2) - layout.addWidget(self.layers_group_name_cmb, 0, 3) self.display_name_lbl = QLabel("None selected") - layout.addWidget(QLabel("Channel name:"), 1, 0) - layout.addWidget(self.display_name_lbl, 1, 1) self.edit_channel_spn = QSpinBox(minimum=0, maximum=0) self.edit_channel_spn.setEnabled(False) self.edit_channel_spn.editingFinished.connect(self.update_channels) self.edit_channel_spn.valueChanged.connect(self.update_channels) - layout.addWidget(QLabel("Channel:"), 1, 2) - layout.addWidget(self.edit_channel_spn, 1, 3) self.edit_axes_le = QLineEdit("None selected") self.edit_axes_le.setEnabled(False) self.edit_axes_le.returnPressed.connect(self.update_source_axes) - layout.addWidget(QLabel("Axes order:"), 2, 0) - layout.addWidget(self.edit_axes_le, 2, 1) self.output_dir_lbl = QLabel("Output directory:") self.output_dir_le = QLineEdit("Unset") @@ -189,25 +208,65 @@ def __init__(self): self._update_output_dir_edit ) self.output_dir_le.returnPressed.connect(self.update_output_dir) - layout.addWidget(QLabel("Output directory:"), 3, 0) - layout.addWidget(self.output_dir_le, 3, 1, 1, 3) - layout.addWidget(self.output_dir_btn, 3, 3) self.use_as_input_chk = QCheckBox("Use as input") self.use_as_input_chk.setEnabled(False) self.use_as_input_chk.toggled.connect( self.update_use_as_input ) - layout.addWidget(self.use_as_input_chk, 4, 0) self.use_as_sampling_chk = QCheckBox("Use as sampling mask") self.use_as_sampling_chk.setEnabled(False) self.use_as_sampling_chk.toggled.connect( self.update_use_as_sampling ) - layout.addWidget(self.use_as_sampling_chk, 4, 1) - self.setLayout(layout) + self.edit_scale_mdspn = MultiDoubleSpinBox() + self.edit_scale_mdspn.setEnabled(False) + self.edit_scale_mdspn.sizesChanged.connect(self.update_scale) + + self.edit_translate_mdspn = MultiDoubleSpinBox() + self.edit_translate_mdspn.setEnabled(False) + self.edit_translate_mdspn.sizesChanged.connect(self.update_translate) + + show_editor_chk = QCheckBox("Edit group properties") + show_editor_chk.setChecked(False) + show_editor_chk.toggled.connect(self._show_editor) + + editor_grid_lyt = QGridLayout() + editor_grid_lyt.addWidget(QLabel("Group name:"), 0, 0) + editor_grid_lyt.addWidget(self.group_name_le, 0, 1) + editor_grid_lyt.addWidget(QLabel("Channels group name:"), 0, 2) + editor_grid_lyt.addWidget(self.layers_group_name_cmb, 0, 3) + editor_grid_lyt.addWidget(QLabel("Channel name:"), 1, 0) + editor_grid_lyt.addWidget(self.display_name_lbl, 1, 1) + editor_grid_lyt.addWidget(QLabel("Channel:"), 1, 2) + editor_grid_lyt.addWidget(self.edit_channel_spn, 1, 3) + editor_grid_lyt.addWidget(QLabel("Axes order:"), 2, 0) + editor_grid_lyt.addWidget(self.edit_axes_le, 2, 1) + editor_grid_lyt.addWidget(QLabel("Output directory:"), 3, 0) + editor_grid_lyt.addWidget(self.output_dir_le, 3, 1, 1, 3) + editor_grid_lyt.addWidget(self.output_dir_btn, 3, 3) + editor_grid_lyt.addWidget(self.use_as_input_chk, 4, 0) + editor_grid_lyt.addWidget(self.use_as_sampling_chk, 4, 1) + editor_grid_lyt.addWidget(QLabel("Layer scale"), 5, 0) + editor_grid_lyt.addWidget(self.edit_scale_mdspn, 5, 1) + editor_grid_lyt.addWidget(QLabel("Layer translate"), 5, 2) + editor_grid_lyt.addWidget(self.edit_translate_mdspn, 5, 3) + + self.editor_widget = QWidget() + self.editor_widget.setLayout(editor_grid_lyt) + + editor_lyt = QVBoxLayout() + editor_lyt.addWidget(show_editor_chk) + editor_lyt.addWidget(self.editor_widget) + + self.setLayout(editor_lyt) + + self.editor_widget.setVisible(False) + + def _show_editor(self, show: bool): + self.editor_widget.setVisible(show) def _clear_image_group(self): self.layers_group_name_cmb.clear() @@ -292,6 +351,20 @@ def _fill_layer(self): self.edit_channel_spn.setValue(self._active_layer_channel.channel) self.edit_channel_spn.setEnabled(True) + self.edit_scale_mdspn.sizes = { + ax: ax_scl + for ax, ax_scl in zip(self._active_layer_channel.source_axes, + self._active_layer_channel.scale) + } + self.edit_scale_mdspn.setEnabled(True) + + self.edit_translate_mdspn.sizes = { + ax: ax_scl + for ax, ax_scl in zip(self._active_layer_channel.source_axes, + self._active_layer_channel.translate) + } + self.edit_translate_mdspn.setEnabled(True) + def _update_output_dir_edit(self, path): self.output_dir_le.setText(self.output_dir_dlg.selectedFiles()[0]) self.update_output_dir() @@ -319,6 +392,14 @@ def update_use_as_input(self): def update_use_as_sampling(self): super().update_use_as_sampling(self.use_as_sampling_chk.isChecked()) + def update_scale(self): + super().update_scale(list(self.edit_scale_mdspn.sizes.values())) + + def update_translate(self): + super().update_translate( + list(self.edit_translate_mdspn.sizes.values()) + ) + @property def active_image_group(self): return super().active_image_group @@ -355,69 +436,6 @@ def active_layer_channel(self, self._fill_layer() -class LayerScaleEditorWidget(LayerScaleEditor, QWidget): - def __init__(self): - super().__init__() - - self.edit_scale_lyt = QGridLayout() - self.edit_scale_lyt.addWidget(QLabel("Channel(s) scale(s):"), 0, 0) - self._curr_labels_list = [] - self._curr_scale_spn_list = [] - - self.setLayout(self.edit_scale_lyt) - - def _clear_layer_channel(self): - while self._curr_scale_spn_list: - item = self._curr_scale_spn_list.pop() - self.edit_scale_lyt.removeWidget(item) - - while self._curr_labels_list: - item = self._curr_labels_list.pop() - self.edit_scale_lyt.removeWidget(item) - - def _fill_layer(self): - self._clear_layer_channel() - - if self._active_layer_channel: - scales = self._active_layer_channel.scale - source_axes = self._active_layer_channel.source_axes - - for ax_idx, (ax_scl, ax) in enumerate(zip(scales, source_axes)): - edit_scale_spn = QDoubleSpinBox(minimum=1e-12, maximum=1e12, - singleStep=1e-7, - decimals=7) - edit_scale_spn.setValue(ax_scl) - edit_scale_spn.lineEdit().returnPressed.connect( - self.update_scale - ) - self._curr_labels_list.append(QLabel(ax)) - self.edit_scale_lyt.addWidget(self._curr_labels_list[-1], - ax_idx + 1, - 0) - self._curr_scale_spn_list.append(edit_scale_spn) - self.edit_scale_lyt.addWidget(self._curr_scale_spn_list[-1], - ax_idx + 1, - 1) - - def update_scale(self): - new_scale = [ - edit_scale_spn.value() - for edit_scale_spn in self._curr_scale_spn_list - ] - super().update_scale(new_scale) - - @property - def active_layer_channel(self): - return super().active_layer_channel - - @active_layer_channel.setter - def active_layer_channel(self, - active_layer_channel: Union[LayerChannel, None]): - super(LayerScaleEditor, type(self)).active_layer_channel\ - .fset(self, active_layer_channel) - self._fill_layer() - - class MaskGeneratorWidget(MaskGenerator, QWidget): def __init__(self): super().__init__() @@ -436,10 +454,26 @@ def __init__(self): patch_sizes_scr.setWidgetResizable(True) patch_sizes_scr.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.edit_scale_lyt = QGridLayout() - self.edit_scale_lyt.addWidget(self.generate_mask_btn, 0, 0, 1, 3) - self.edit_scale_lyt.addWidget(patch_sizes_scr, 1, 0, 1, 3) - self.setLayout(self.edit_scale_lyt) + edit_mask_lyt = QVBoxLayout() + edit_mask_lyt.addWidget(patch_sizes_scr) + edit_mask_lyt.addWidget(self.generate_mask_btn) + self.edit_mask_widget = QWidget() + self.edit_mask_widget.setLayout(edit_mask_lyt) + + show_editor_chk = QCheckBox("Edit mask properties") + show_editor_chk.setChecked(False) + show_editor_chk.toggled.connect(self._show_editor) + + mask_lyt = QVBoxLayout() + mask_lyt.addWidget(show_editor_chk) + mask_lyt.addWidget(self.edit_mask_widget) + + self.setLayout(mask_lyt) + + self.edit_mask_widget.setVisible(False) + + def _show_editor(self, show: bool): + self.edit_mask_widget.setVisible(show) def _set_patch_size(self, patch_sizes): super().set_patch_size(list(patch_sizes.values())) @@ -464,7 +498,6 @@ def __init__(self, default_axis_labels: str = "TZYX"): # Re-instanciate the following objects with their widget versions. self.image_groups_editor = ImageGroupEditorWidget() - self.layer_scale_editor = LayerScaleEditorWidget() self.mask_generator = MaskGeneratorWidget() self.image_groups_tw = QTreeWidget() @@ -545,31 +578,12 @@ def __init__(self, default_axis_labels: str = "TZYX"): manager_lyt.addWidget(self.remove_layer_btn, 2, 1) manager_lyt.addWidget(self.save_metadata_btn, 2, 2) - self.show_editor_chk = QCheckBox("Edit group properties") - self.show_editor_chk.setChecked(False) - self.show_editor_chk.toggled.connect(self._show_editor) - image_groups_scr = QScrollArea() - image_groups_scr.setWidget(self.image_groups_editor) - image_groups_scr.setWidgetResizable(True) - image_groups_scr.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - - manager_lyt.addWidget(self.show_editor_chk, 3, 0, 1, 1) - manager_lyt.addWidget(image_groups_scr, 4, 0, 1, 3) - manager_lyt.addWidget(self.layer_scale_editor, 5, 0, 1, 3) - manager_lyt.addWidget(self.mask_generator, 6, 0, 1, 3) - manager_lyt.addWidget(self.image_groups_tw, 7, 0, 2, 3) + manager_lyt.addWidget(self.image_groups_editor, 3, 0, 1, 3) + manager_lyt.addWidget(self.mask_generator, 4, 0, 1, 3) + manager_lyt.addWidget(self.image_groups_tw, 5, 0, 1, 3) self.setLayout(manager_lyt) - self.image_groups_editor.setVisible(False) - self.layer_scale_editor.setVisible(False) - self.mask_generator.setVisible(False) - - def _show_editor(self, show: bool = False): - self.image_groups_editor.setVisible(show) - self.layer_scale_editor.setVisible(show) - self.mask_generator.setVisible(show) - def _save_layers_group(self): self.save_layers_group() self.save_layers_group_btn.setEnabled(False) @@ -782,11 +796,23 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget, tunable_segmentation_method) self.patch_sizes_mspn = MultiSpinBox() + patch_sizes_scr = QScrollArea() patch_sizes_scr.setWidget(self.patch_sizes_mspn) patch_sizes_scr.setWidgetResizable(True) patch_sizes_scr.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + patch_sizes_lyt = QVBoxLayout() + patch_sizes_lyt.addWidget(QLabel("Patch size:")) + patch_sizes_lyt.addWidget(patch_sizes_scr) + + self.patch_sizes_widget = QWidget() + self.patch_sizes_widget.setLayout(patch_sizes_lyt) + + patch_sizes_chk = QCheckBox("Edit patch sizes") + patch_sizes_chk.setChecked(False) + patch_sizes_chk.toggled.connect(self._show_patch_sizes) + spatial_input_axes = self.input_axes if "C" in spatial_input_axes: spatial_input_axes = list(spatial_input_axes) @@ -831,31 +857,30 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget, self.finetuning_btn.clicked.connect(self.fine_tune) acquisition_lyt = QGridLayout() - acquisition_lyt.addWidget(QLabel("Patch size:"), 0, 0) - acquisition_lyt.addWidget(patch_sizes_scr, 0, 1, 1, 3) - acquisition_lyt.addWidget(QLabel("Maximum samples:"), 1, 0) - acquisition_lyt.addWidget(self.max_samples_spn, 1, 1) - acquisition_lyt.addWidget(QLabel("Monte Carlo repetitions"), 2, 0) - acquisition_lyt.addWidget(self.MC_repetitions_spn, 2, 1) - acquisition_lyt.addWidget(QLabel("Input axes"), 3, 0) - acquisition_lyt.addWidget(self.input_axes_le, 3, 1) - acquisition_lyt.addWidget(QLabel("Model axes"), 3, 2) - acquisition_lyt.addWidget(self.model_axes_le, 3, 3) - acquisition_lyt.addWidget(self.tunable_segmentation_method, 4, 0, 1, 4) - acquisition_lyt.addWidget(self.execute_selected_btn, 5, 0) - acquisition_lyt.addWidget(self.execute_all_btn, 5, 1) - acquisition_lyt.addWidget(self.finetuning_btn, 6, 1) - acquisition_lyt.addWidget(QLabel("Image queue:"), 7, 0, 1, 1) - acquisition_lyt.addWidget(self.image_pb, 7, 1, 1, 3) - acquisition_lyt.addWidget(QLabel("Patch queue:"), 8, 0, 1, 1) - acquisition_lyt.addWidget(self.patch_pb, 8, 1, 1, 3) + acquisition_lyt.addWidget(patch_sizes_chk, 0, 0) + acquisition_lyt.addWidget(self.patch_sizes_widget, 1, 0, 1, 3) + acquisition_lyt.addWidget(QLabel("Maximum samples:"), 2, 0) + acquisition_lyt.addWidget(self.max_samples_spn, 2, 1) + acquisition_lyt.addWidget(QLabel("Monte Carlo repetitions"), 3, 0) + acquisition_lyt.addWidget(self.MC_repetitions_spn, 3, 1) + acquisition_lyt.addWidget(QLabel("Input axes"), 4, 0) + acquisition_lyt.addWidget(self.input_axes_le, 4, 1) + acquisition_lyt.addWidget(QLabel("Model axes"), 4, 2) + acquisition_lyt.addWidget(self.model_axes_le, 4, 3) + acquisition_lyt.addWidget(self.tunable_segmentation_method, 5, 0, 1, 4) + acquisition_lyt.addWidget(self.execute_selected_btn, 6, 0) + acquisition_lyt.addWidget(self.execute_all_btn, 6, 1) + acquisition_lyt.addWidget(self.finetuning_btn, 7, 1) + acquisition_lyt.addWidget(QLabel("Image queue:"), 8, 0, 1, 1) + acquisition_lyt.addWidget(self.image_pb, 8, 1, 1, 3) + acquisition_lyt.addWidget(QLabel("Patch queue:"), 9, 0, 1, 1) + acquisition_lyt.addWidget(self.patch_pb, 9, 1, 1, 3) self.setLayout(acquisition_lyt) + self.patch_sizes_widget.setVisible(False) - self.labels_manager.setVisible(False) - - def _show_labels_manager(self, show_it: bool): - self.labels_manager.setVisible(show_it) + def _show_patch_sizes(self, show: bool): + self.patch_sizes_widget.setVisible(show) def _reset_image_progressbar(self, num_images: int): self.image_pb.setRange(0, num_images) diff --git a/src/napari_activelearning/_layers.py b/src/napari_activelearning/_layers.py index 2c7dc9f..10e0c00 100644 --- a/src/napari_activelearning/_layers.py +++ b/src/napari_activelearning/_layers.py @@ -917,6 +917,8 @@ def __init__(self): self._group_name = None self._layers_group_name = None self._edit_axes = None + self._edit_scale = None + self._edit_translate = None self._use_as_input = None self._use_as_sampling = None self._edit_channel = None @@ -948,10 +950,7 @@ def update_group_name(self, group_name: Optional[str] = None): self._active_image_group.group_name = self._group_name def update_channels(self, channel: Optional[int] = None): - if not self._active_layer_channel: - return - - if not self._active_layers_group: + if not self._active_layer_channel or not self._active_layers_group: return if channel: @@ -1056,6 +1055,26 @@ def update_use_as_sampling(self, use_it: Optional[bool] = None): == layers_group_idx): self._active_image_group.sampling_mask_layers_group = None + def update_scale(self, scale: Optional[Iterable[float]] = None): + if (not self._active_layers_group or not self._active_layers_group + or not self._active_layer_channel): + return + + if scale is not None: + self._edit_scale = scale + + self._active_layer_channel.scale = self._edit_scale + + def update_translate(self, translate: Optional[Iterable[float]] = None): + if (not self._active_layers_group or not self._active_layers_group + or not self._active_layer_channel): + return + + if translate is not None: + self._edit_translate = translate + + self._active_layer_channel.translate = self._edit_translate + class LayerScaleEditor(PropertiesEditor): def __init__(self): diff --git a/src/napari_activelearning/_tests/test_acquisition.py b/src/napari_activelearning/_tests/test_acquisition.py index e2735f7..bb57400 100644 --- a/src/napari_activelearning/_tests/test_acquisition.py +++ b/src/napari_activelearning/_tests/test_acquisition.py @@ -1,9 +1,14 @@ import pytest from unittest.mock import MagicMock, patch import numpy as np -import torch from napari_activelearning._acquisition import AcquisitionFunction +try: + import torch + USING_PYTORCH = True +except ModuleNotFoundError: + USING_PYTORCH = False + @pytest.fixture def image_groups_manager(): @@ -71,11 +76,19 @@ def test_compute_acquisition(acquisition_function): with (patch('napari_activelearning._acquisition.get_dataloader') as mock_dataloader): - mock_dataloader.return_value = [ - (torch.LongTensor([[[0, 1], [0, 1], [0, 10], [0, 10], [0, -1]]]), - torch.zeros((1, 1, 1, 10, 10, 3)), - torch.zeros((1, 1, 1, 10, 10, 1))) - ] + if USING_PYTORCH: + mock_dataloader.return_value = [ + (torch.LongTensor([[[0, 1], [0, 1], [0, 10], [0, 10], + [0, -1]]]), + torch.zeros((1, 1, 1, 10, 10, 3)), + torch.zeros((1, 1, 1, 10, 10, 1))) + ] + else: + mock_dataloader.return_value = [ + (np.array([[0, 1], [0, 1], [0, 10], [0, 10], [0, -1]]), + np.array((1, 1, 10, 10, 3)), + np.array((1, 1, 10, 10, 1))) + ] result = acquisition_function.compute_acquisition( dataset_metadata, acquisition_fun, segmentation_out, sampling_positions, diff --git a/src/napari_activelearning/_utils.py b/src/napari_activelearning/_utils.py index a453c7c..04eaf2b 100644 --- a/src/napari_activelearning/_utils.py +++ b/src/napari_activelearning/_utils.py @@ -14,9 +14,9 @@ try: import torch from torch.utils.data import DataLoader - USING_TORCH = True + USING_PYTORCH = True except ModuleNotFoundError: - USING_TORCH = False + USING_PYTORCH = False from napari.layers import Layer from napari.layers._multiscale_data import MultiScaleData @@ -254,7 +254,7 @@ def get_dataloader( shuffle=shuffle ) - if USING_TORCH: + if USING_PYTORCH: train_dataloader = DataLoader( train_dataset, num_workers=num_workers,