Skip to content

Commit

Permalink
Added capability to define patch sizes per axis for masks
Browse files Browse the repository at this point in the history
  • Loading branch information
fercer committed Jul 18, 2024
1 parent c07abcb commit 11fe3a6
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 220 deletions.
2 changes: 2 additions & 0 deletions src/napari_activelearning/_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun,
else:
img_sp = img_sp[..., 0]


# TODO: Add the current selected slice from _update_roi_from_position
pos_u_lab = (slice(pos[0, 0], pos[0, 1]),
slice(pos[1, 0], pos[1, 1]))

Expand Down
202 changes: 86 additions & 116 deletions src/napari_activelearning/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from functools import partial
import napari
import math

from ._acquisition import AcquisitionFunction, TunableMethod
from ._layers import (ImageGroupEditor, ImageGroupsManager, LayerScaleEditor,
Expand Down Expand Up @@ -325,7 +326,7 @@ def __init__(self):
self.generate_mask_btn.clicked.connect(self.generate_mask_layer)

self.edit_scale_lyt = QGridLayout()
self.edit_scale_lyt.addWidget(self.generate_mask_btn, 0, 0, 0, 2)
self.edit_scale_lyt.addWidget(self.generate_mask_btn, 0, 0, 1, 3)
self.setLayout(self.edit_scale_lyt)

self._curr_power_spn_list = []
Expand All @@ -345,12 +346,18 @@ def _clear_layer_channel(self):
item = self._curr_power_spn_list.pop()
self.edit_scale_lyt.removeWidget(item)

def _fill_layer(self):
def _update_reference_info(self):
super()._update_reference_info()
self._clear_layer_channel()

if self._active_layer_channel:
patch_sizes = self._patch_sizes
source_axes = self._mask_axes
patch_sizes = self._patch_sizes
if patch_sizes is None:
patch_sizes = [
min(ax_s, 128)
for ax_s in self._img_shape
]

for ax_idx, (ax_ps, ax) in enumerate(zip(patch_sizes,
source_axes)):
Expand All @@ -362,20 +369,7 @@ def _fill_layer(self):

scale_le = QLineEdit()
scale_le.setValidator(QIntValidator(0, 2**16))
scale_le.returnPressed.connect(
partial(self._set_patch_size, ax_idx=ax_idx)
)
scale_le.setEnabled(False)

power_spn.valueChanged.connect(
partial(self._modify_patch_size, ax_idx=ax_idx)
)
power_spn.setValue(7)

self._curr_labels_list.append(QLabel(ax))
self.edit_scale_lyt.addWidget(self._curr_labels_list[-1],
ax_idx + 1,
0)
self._curr_scale_le_list.append(scale_le)
self.edit_scale_lyt.addWidget(self._curr_scale_le_list[-1],
ax_idx + 1,
Expand All @@ -384,38 +378,43 @@ def _fill_layer(self):
self.edit_scale_lyt.addWidget(self._curr_power_spn_list[-1],
ax_idx + 1,
2)
self._curr_labels_list.append(QLabel(ax))
self.edit_scale_lyt.addWidget(self._curr_labels_list[-1],
ax_idx + 1,
0)

scale_le.returnPressed.connect(
partial(self._set_patch_size, ax_idx=ax_idx)
)
power_spn.valueChanged.connect(
partial(self._modify_patch_size, ax_idx=ax_idx)
)
power_spn.setValue(int(math.log2(ax_ps)))

self.generate_mask_btn.setEnabled(True)

def _modify_patch_size(self, scale: int, ax_idx: int = 0):
self._curr_scale_le_list[ax_idx].setText(str(2 ** scale))
self._set_patch_size(ax_idx)
self._set_patch_size()

def _set_patch_size(self, ax_idx: int = 0):
super().set_patch_size(int(self._curr_scale_le_list[ax_idx].text()))

@property
def active_image_group(self):
return super().active_image_group
def _set_patch_size(self):
patch_sizes = [
int(scale_le.text())
for scale_le in self._curr_scale_le_list
]

@active_image_group.setter
def active_image_group(self, active_image_group: ImageGroup):
super(MaskGeneratorWidget, type(self)).active_image_group\
.fset(self, active_image_group)
super().set_patch_size(patch_sizes)

self.generate_mask_btn.setEnabled(
self._active_image_group is not None
)
self.patch_size_le.setEnabled(
self._active_image_group is not None
)
self.patch_size_spn.setEnabled(
self._active_image_group is not None
)
def generate_mask_layer(self):
self._set_patch_size()
super().generate_mask_layer()


class ImageGroupsManagerWidget(ImageGroupsManager, QWidget):
def __init__(self, default_axis_labels: str = "TZYX"):
super().__init__(default_axis_labels)

# Re-instanciate the following objects with their widget versions.
self.image_groups_editor = ImageGroupEditorWidget()
self.layer_scale_editor = LayerScaleEditorWidget()
self.mask_generator = MaskGeneratorWidget()
Expand Down Expand Up @@ -443,17 +442,17 @@ def __init__(self, default_axis_labels: str = "TZYX"):
self.image_groups_tw.addTopLevelItem(self.groups_root)
self.groups_root.setExpanded(True)

self.new_group_btn = QPushButton("New")
self.new_group_btn = QPushButton("New image group")
self.new_group_btn.setToolTip("Create a new group. If layers are "
"selected, add these to the new group.")
self.new_group_btn.clicked.connect(self.create_group)

self.add_group_btn = QPushButton("Add")
self.add_group_btn = QPushButton("Add to group")
self.add_group_btn.setEnabled(False)
self.add_group_btn.setToolTip("Add selected layers to current group")
self.add_group_btn.clicked.connect(self.update_group)

self.remove_group_btn = QPushButton("Remove")
self.remove_group_btn = QPushButton("Remove group")
self.remove_group_btn.setEnabled(False)
self.remove_group_btn.setToolTip("Remove selected group. This will not"
" remove the layers from napari "
Expand Down Expand Up @@ -485,36 +484,30 @@ def __init__(self, default_axis_labels: str = "TZYX"):
self.remove_layer_btn.clicked.connect(self.remove_layer)
self.save_metadata_btn.clicked.connect(self.dump_dataset_specs)

self.group_buttons_lyt = QHBoxLayout()
self.group_buttons_lyt.addWidget(self.new_group_btn)
self.group_buttons_lyt.addWidget(self.add_group_btn)
self.group_buttons_lyt.addWidget(self.remove_group_btn)
manager_lyt = QGridLayout()
manager_lyt.addWidget(self.new_group_btn, 0, 0)
manager_lyt.addWidget(self.add_group_btn, 0, 1)
manager_lyt.addWidget(self.remove_group_btn, 0, 2)

self.layers_group_buttons_lyt = QHBoxLayout()
self.layers_group_buttons_lyt.addWidget(self.new_layers_group_btn)
self.layers_group_buttons_lyt.addWidget(self.remove_layers_group_btn)
self.layers_group_buttons_lyt.addWidget(self.save_layers_group_btn)
manager_lyt.addWidget(self.new_layers_group_btn, 1, 0)
manager_lyt.addWidget(self.remove_layers_group_btn, 1, 1)
manager_lyt.addWidget(self.save_layers_group_btn, 1, 2)

self.layer_buttons_lyt = QHBoxLayout()
self.layer_buttons_lyt.addWidget(self.add_layer_btn)
self.layer_buttons_lyt.addWidget(self.remove_layer_btn)
self.layer_buttons_lyt.addWidget(self.save_metadata_btn)
manager_lyt.addWidget(self.add_layer_btn, 2, 0)
manager_lyt.addWidget(self.remove_layer_btn, 2, 1)
manager_lyt.addWidget(self.save_metadata_btn, 2, 2)

self.show_editor_chk = QCheckBox("Edit group properties")
self.show_editor_chk.setChecked(False)
self.show_editor_chk.toggled.connect(self._show_editor)

self.group_lyt = QVBoxLayout()
self.group_lyt.addLayout(self.group_buttons_lyt)
self.group_lyt.addLayout(self.layers_group_buttons_lyt)
self.group_lyt.addLayout(self.layer_buttons_lyt)
self.group_lyt.addWidget(self.show_editor_chk)
self.group_lyt.addWidget(self.image_groups_editor)
self.group_lyt.addWidget(self.layer_scale_editor)
self.group_lyt.addWidget(self.mask_generator)
self.group_lyt.addWidget(self.image_groups_tw)
manager_lyt.addWidget(self.show_editor_chk, 3, 0, 1, 1)
manager_lyt.addWidget(self.image_groups_editor, 4, 0, 1, 3)
manager_lyt.addWidget(self.layer_scale_editor, 5, 0, 1, 3)
manager_lyt.addWidget(self.mask_generator, 6, 0, 1, 3)
manager_lyt.addWidget(self.image_groups_tw, 7, 0, 1, 3)

self.setLayout(self.group_lyt)
self.setLayout(manager_lyt)

self.image_groups_editor.setVisible(False)
self.layer_scale_editor.setVisible(False)
Expand Down Expand Up @@ -677,22 +670,16 @@ def __init__(self):
self.commit_btn.setEnabled(False)
self.commit_btn.clicked.connect(self.commit)

self.navigation_layout = QHBoxLayout()
self.navigation_layout.addWidget(self.prev_img_btn)
self.navigation_layout.addWidget(self.prev_patch_btn)
self.navigation_layout.addWidget(self.next_patch_btn)
self.navigation_layout.addWidget(self.next_img_btn)

self.edit_layout = QHBoxLayout()
self.edit_layout.addWidget(self.edit_labels_btn)
self.edit_layout.addWidget(self.commit_btn)
manager_lyt = QGridLayout()
manager_lyt.addWidget(self.labels_table_tw, 0, 0, 1, 4)
manager_lyt.addWidget(self.prev_img_btn, 1, 0)
manager_lyt.addWidget(self.prev_patch_btn, 1, 1)
manager_lyt.addWidget(self.next_patch_btn, 1, 2)
manager_lyt.addWidget(self.next_img_btn, 1, 3)
manager_lyt.addWidget(self.edit_labels_btn, 2, 0, 1, 2)
manager_lyt.addWidget(self.commit_btn, 2, 2, 1, 2)

self.manager_layout = QVBoxLayout()
self.manager_layout.addWidget(self.labels_table_tw)
self.manager_layout.addLayout(self.navigation_layout)
self.manager_layout.addLayout(self.edit_layout)

self.setLayout(self.manager_layout)
self.setLayout(manager_lyt)

def focus_region(self, label: Optional[QTreeWidgetItem] = None,
edit_focused_label: bool = False):
Expand Down Expand Up @@ -728,33 +715,18 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget,
super().__init__(image_groups_manager, labels_manager,
tunable_segmentation_method)

self.patch_size_lbl = QLabel("Patch size:")
self.patch_size_spn = QSpinBox(minimum=128, maximum=1024,
singleStep=128)
self.patch_size_spn.valueChanged.connect(self._set_patch_size)

self.patch_size_lyt = QHBoxLayout()
self.patch_size_lyt.addWidget(self.patch_size_lbl)
self.patch_size_lyt.addWidget(self.patch_size_spn)

self.max_samples_lbl = QLabel("Maximum samples:")
self.max_samples_spn = QSpinBox(minimum=1, maximum=10000, value=100,
singleStep=10)
self.max_samples_spn.valueChanged.connect(self._set_max_samples)

self.max_samples_lyt = QHBoxLayout()
self.max_samples_lyt.addWidget(self.max_samples_lbl)
self.max_samples_lyt.addWidget(self.max_samples_spn)

self.MC_repetitions_lbl = QLabel("Monte Carlo repetitions")
self.MC_repetitions_spn = QSpinBox(minimum=1, maximum=100, value=30,
self.MC_repetitions_spn = QSpinBox(minimum=2, maximum=100, value=30,
singleStep=10)
self.MC_repetitions_spn.valueChanged.connect(self._set_MC_repetitions)

self.MC_repetitions_lyt = QHBoxLayout()
self.MC_repetitions_lyt.addWidget(self.MC_repetitions_lbl)
self.MC_repetitions_lyt.addWidget(self.MC_repetitions_spn)

self.execute_selected_btn = QPushButton("Run on selected image groups")
self.execute_selected_btn.clicked.connect(
partial(self.compute_acquisition_layers, run_all=False)
Expand All @@ -765,36 +737,34 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget,
partial(self.compute_acquisition_layers, run_all=True)
)

self.image_lbl = QLabel("Image queue:")
self.image_pb = QProgressBar()
self.image_lyt = QHBoxLayout()
self.image_lyt.addWidget(self.image_lbl)
self.image_lyt.addWidget(self.image_pb)

self.patch_lbl = QLabel("Patch queue:")
self.patch_pb = QProgressBar()
self.patch_lyt = QHBoxLayout()
self.patch_lyt.addWidget(self.patch_lbl)
self.patch_lyt.addWidget(self.patch_pb)

self.finetuning_btn = QPushButton("Fine tune model")
self.finetuning_btn.clicked.connect(self.fine_tune)

self.execute_lyt = QHBoxLayout()
self.execute_lyt.addWidget(self.execute_selected_btn)
self.execute_lyt.addWidget(self.execute_all_btn)

self.acquisition_lyt = QVBoxLayout()
self.acquisition_lyt.addLayout(self.patch_size_lyt)
self.acquisition_lyt.addLayout(self.max_samples_lyt)
self.acquisition_lyt.addLayout(self.MC_repetitions_lyt)
self.acquisition_lyt.addWidget(self.tunable_segmentation_method)
self.acquisition_lyt.addLayout(self.execute_lyt)
self.acquisition_lyt.addLayout(self.image_lyt)
self.acquisition_lyt.addLayout(self.patch_lyt)
self.acquisition_lyt.addWidget(self.finetuning_btn)

self.setLayout(self.acquisition_lyt)
acquisition_lyt = QGridLayout()
acquisition_lyt.addWidget(QLabel("Patch size:"), 0, 0)
acquisition_lyt.addWidget(self.patch_size_spn, 0, 1)
acquisition_lyt.addWidget(QLabel("Maximum samples:"), 1, 0)
acquisition_lyt.addWidget(self.max_samples_spn, 1, 1)
acquisition_lyt.addWidget(QLabel("Monte Carlo repetitions"), 2, 0)
acquisition_lyt.addWidget(self.MC_repetitions_spn, 2, 1)
acquisition_lyt.addWidget(self.tunable_segmentation_method, 3, 0, 1, 2)
acquisition_lyt.addWidget(self.execute_selected_btn, 4, 0)
acquisition_lyt.addWidget(self.execute_all_btn, 4, 1)
acquisition_lyt.addWidget(self.finetuning_btn, 5, 0, 1, 2)
acquisition_lyt.addWidget(QLabel("Image queue:"), 6, 0, 1, 1)
acquisition_lyt.addWidget(self.image_pb, 6, 1)
acquisition_lyt.addWidget(QLabel("Patch queue:"), 7, 0, 1, 1)
acquisition_lyt.addWidget(self.patch_pb, 7, 1)

self.setLayout(acquisition_lyt)

self.labels_manager.setVisible(False)

def _show_labels_manager(self, show_it: bool):
self.labels_manager.setVisible(show_it)

def _reset_image_progressbar(self, num_images: int):
self.image_pb.setRange(0, num_images)
Expand Down
Loading

0 comments on commit 11fe3a6

Please sign in to comment.