Skip to content

Commit

Permalink
Working on generalization of the axes for 2D and 3D models
Browse files Browse the repository at this point in the history
  • Loading branch information
fercer committed Jul 19, 2024
1 parent 11fe3a6 commit b9f6d1b
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 114 deletions.
126 changes: 74 additions & 52 deletions src/napari_activelearning/_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ def compute_acquisition_superpixel(probs, super_pixel_labels):

return u_sp_lab


if USING_TORCH:
class DropoutEvalOverrider(torch.nn.Module):
def __init__(self, dropout_module):
super(DropoutEvalOverrider, self).__init__()

self._dropout = type(dropout_module)(dropout_module.p,
inplace=dropout_module.inplace)
self._dropout = type(dropout_module)(
dropout_module.p, inplace=dropout_module.inplace)

def forward(self, input):
training_temp = self._dropout.training
Expand All @@ -68,16 +69,15 @@ def forward(self, input):

return out


def add_dropout(net, p=0.05):
# First step checks if there is any Dropout layer existing in the model
has_dropout = False
for module in net.modules():
if isinstance(module, torch.nn.Sequential):
for l_idx, layer in enumerate(module):
if isinstance(layer, (torch.nn.Dropout, torch.nn.Dropout1d,
torch.nn.Dropout2d,
torch.nn.Dropout3d)):
torch.nn.Dropout2d,
torch.nn.Dropout3d)):
has_dropout = True
break
else:
Expand Down Expand Up @@ -216,7 +216,10 @@ def __init__(self, image_groups_manager: ImageGroupsManager,
tunable_segmentation_method: TunableMethod):
self._patch_size = 128
self._max_samples = 1
self._MC_repetitions = 30
self._MC_repetitions = 3

viewer = napari.current_viewer()
self._input_axes = "".join(viewer.dims.axis_labels).upper()

self.image_groups_manager = image_groups_manager
self.labels_manager = labels_manager
Expand All @@ -226,21 +229,16 @@ def __init__(self, image_groups_manager: ImageGroupsManager,

def _update_roi_from_position(self):
viewer = napari.current_viewer()
displayed_axes = "".join(viewer.dims.axis_labels).upper()
viewer_axes = "".join(viewer.dims.axis_labels).upper()
position = viewer.cursor.position
axes_order = viewer.dims.order

roi_start = [0] * len(axes_order)
roi_length = [-1] * len(axes_order)
for ord in axes_order[:-viewer.dims.ndisplay]:
roi_start[ord] = int(position[ord])
roi_length[ord] = 1

self._roi = {
ax: slice(ax_start if ax_length > 0 else None,
(ax_start + ax_length) if ax_length > 0 else None)
for ax, ax_start, ax_length in zip(displayed_axes, roi_start,
roi_length)
viewer_axes[ord]:
slice(None)
if viewer_axes[ord] in self._input_axes
else slice(int(position[ord]), int(position[ord]) + 1)
for ord in axes_order
}

def _compute_acquisition_fun(self, img, img_sp, MC_repetitions):
Expand Down Expand Up @@ -275,14 +273,31 @@ def _update_patch_progressbar(self, curr_patch_index: int):
def compute_acquisition(self, dataset_metadata, acquisition_fun,
segmentation_out,
sampling_positions=None,
segmentation_only=False):
segmentation_only=False,
spatial_axes="ZYX"):
dl = get_dataloader(dataset_metadata, patch_size=self._patch_size,
sampling_positions=sampling_positions,
spatial_axes=spatial_axes,
shuffle=True)
segmentation_max = 0
n_samples = 0
img_sampling_positions = []

pred_spatial_axes = list(dataset_metadata["images"]["axes"])
if "C" in pred_spatial_axes:
pred_spatial_axes.remove("C")
pred_spatial_axes = "".join(pred_spatial_axes)

if "masks" in dataset_metadata:
mask_axes = dataset_metadata["masks"]["source_axes"]
else:
mask_axes = pred_spatial_axes

pred_sel = tuple(
slice(None) if ax in pred_spatial_axes else None
for ax in mask_axes
)

self._reset_patch_progressbar()
for pos, img, img_sp in dl:
if USING_TORCH:
Expand All @@ -292,18 +307,24 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun,
else:
img_sp = img_sp[..., 0]

pos = {
ax: slice(pos_ax[0], pos_ax[1])
for ax, pos_ax in zip(dataset_metadata["images"]["axes"],
pos)
}

# TODO: Add the current selected slice from _update_roi_from_position
pos_u_lab = (slice(pos[0, 0], pos[0, 1]),
slice(pos[1, 0], pos[1, 1]))
pos_u_lab = tuple(
pos.get(ax, self._roi[ax])
for ax in mask_axes
)

if not segmentation_only:
u_sp_lab = self._compute_acquisition_fun(
img,
img_sp,
self._MC_repetitions,
)
acquisition_fun[pos_u_lab] = u_sp_lab
acquisition_fun[pos_u_lab] = u_sp_lab[pred_sel]
acquisition_val = u_sp_lab.max()
else:
acquisition_val = 0
Expand All @@ -312,7 +333,7 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun,
img,
segmentation_max
)
segmentation_out[pos_u_lab] = seg_out
segmentation_out[pos_u_lab] = seg_out[pred_sel]
segmentation_max = max(segmentation_max, seg_out.max())

img_sampling_positions.append(
Expand Down Expand Up @@ -352,7 +373,13 @@ def compute_acquisition_layers(
self._update_roi_from_position()

self._reset_image_progressbar(len(image_groups))

viewer = napari.current_viewer()
spatial_axes = self._input_axes
if "C" in spatial_axes:
spatial_axes = list(spatial_axes)
spatial_axes.remove("C")
spatial_axes = "".join(spatial_axes)

for n, image_group in enumerate(image_groups):
image_group.setSelected(True)
Expand Down Expand Up @@ -383,7 +410,7 @@ def compute_acquisition_layers(
for ax, ax_s, ax_scl in zip(displayed_source_axes,
displayed_shape,
displayed_scale)
if ax in "ZYX" and ax_s > 1
if ax in spatial_axes and ax_s > 1
]))

if not segmentation_only:
Expand Down Expand Up @@ -433,16 +460,8 @@ def compute_acquisition_layers(
for ax in layers_group.source_axes
]

spatial_axes = "".join([
ax for ax in layers_group.source_axes
if ax in displayed_source_axes[-viewer.dims.ndisplay:]
])

if "images" in layer_type:
if "C" in displayed_source_axes:
spatial_axes += "C"

for ax in ["Z", "Y", "X"]:
for ax in spatial_axes:
if ax not in displayed_source_axes:
continue

Expand All @@ -466,9 +485,13 @@ def compute_acquisition_layers(
dataset_metadata[layer_type]["filenames"] =\
dataset_metadata[layer_type]["filenames"][0]

dataset_metadata[layer_type]["axes"] = spatial_axes
dataset_metadata[layer_type]["modality"] = layer_type

if "images" in layer_type:
dataset_metadata[layer_type]["axes"] = self._input_axes
else:
dataset_metadata[layer_type]["axes"] = spatial_axes

if image_group.labels_group:
sampling_positions = list(
map(lambda child:
Expand All @@ -485,7 +508,8 @@ def compute_acquisition_layers(
acquisition_fun=acquisition_fun_grp,
segmentation_out=segmentation_grp,
sampling_positions=sampling_positions,
segmentation_only=segmentation_only
segmentation_only=segmentation_only,
spatial_axes=spatial_axes
)

self._update_image_progressbar(n + 1)
Expand All @@ -500,7 +524,7 @@ def compute_acquisition_layers(
# Downsample the acquisition function
acquisition_fun_ms = downsample_image(
acquisition_root,
source_axes="YX",
source_axes=spatial_axes,
data_group="labels/acquisition_fun/0",
scale=4,
num_scales=5,
Expand Down Expand Up @@ -532,7 +556,7 @@ def compute_acquisition_layers(
if acquisition_layers_group is None:
acquisition_layers_group = image_group.add_layers_group(
"acquisition",
source_axes="YX",
source_axes=spatial_axes,
use_as_input_image=False,
use_as_sampling_mask=False
)
Expand All @@ -551,7 +575,7 @@ def compute_acquisition_layers(
# Downsample the segmentation output
segmentation_ms = downsample_image(
segmentation_root,
source_axes="YX",
source_axes=spatial_axes,
data_group=f"labels/{segmentation_group_name}/0",
scale=4,
num_scales=5,
Expand All @@ -578,7 +602,7 @@ def compute_acquisition_layers(
if segmentation_layers_group is None:
segmentation_layers_group = image_group.add_layers_group(
segmentation_group_name,
source_axes="YX",
source_axes=spatial_axes,
use_as_input_image=False,
use_as_sampling_mask=False
)
Expand Down Expand Up @@ -617,7 +641,12 @@ def fine_tune(self):

dataset_metadata_list = []

viewer = napari.current_viewer()
spatial_axes = self._input_axes
if "C" in spatial_axes:
spatial_axes = list(spatial_axes)
spatial_axes.remove("C")
spatial_axes = "".join(spatial_axes)

for image_group in image_groups:
image_group.setSelected(True)

Expand All @@ -633,8 +662,6 @@ def fine_tune(self):

input_layers_group = image_group.child(input_layers_group_idx)

displayed_source_axes = input_layers_group.source_axes

dataset_metadata = {}

for layers_group, layer_type in [
Expand All @@ -649,15 +676,6 @@ def fine_tune(self):
)
]

spatial_axes = "".join([
ax for ax in layers_group.source_axes
if ax in displayed_source_axes[-viewer.dims.ndisplay:]
])

if "images" in layer_type:
if "C" in displayed_source_axes:
spatial_axes += "C"

if isinstance(dataset_metadata[layer_type]["filenames"],
MultiScaleData):
dataset_metadata[layer_type]["filenames"] =\
Expand All @@ -668,9 +686,13 @@ def fine_tune(self):
dataset_metadata[layer_type]["filenames"] =\
dataset_metadata[layer_type]["filenames"].compute()

dataset_metadata[layer_type]["axes"] = spatial_axes
dataset_metadata[layer_type]["modality"] = layer_type

if "images" in layer_type:
dataset_metadata[layer_type]["axes"] = self._input_axes
else:
dataset_metadata[layer_type]["axes"] = spatial_axes

sampling_positions = list(
map(lambda child: [ax_pos.start for ax_pos in child.position],
map(lambda idx: image_group.labels_group.child(idx),
Expand Down
42 changes: 23 additions & 19 deletions src/napari_activelearning/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,6 @@ def update_channels(self):
def update_source_axes(self):
super().update_source_axes(self.edit_axes_le.text())

display_source_axes = list(self._edit_axes.lower())
if "c" in display_source_axes:
display_source_axes.remove("c")
display_source_axes = tuple(display_source_axes)

viewer = napari.current_viewer()
if display_source_axes != viewer.dims.axis_labels:
viewer.dims.axis_labels = display_source_axes

def update_layers_group_name(self):
if super().update_layers_group_name(
self.layers_group_name_cmb.lineEdit().text()):
Expand Down Expand Up @@ -346,6 +337,8 @@ def _clear_layer_channel(self):
item = self._curr_power_spn_list.pop()
self.edit_scale_lyt.removeWidget(item)

self.generate_mask_btn.setEnabled(False)

def _update_reference_info(self):
super()._update_reference_info()
self._clear_layer_channel()
Expand Down Expand Up @@ -716,17 +709,23 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget,
tunable_segmentation_method)

self.patch_size_spn = QSpinBox(minimum=128, maximum=1024,
value=self._patch_size,
singleStep=128)
self.patch_size_spn.valueChanged.connect(self._set_patch_size)

self.max_samples_spn = QSpinBox(minimum=1, maximum=10000, value=100,
self.max_samples_spn = QSpinBox(minimum=1, maximum=10000,
value=self._max_samples,
singleStep=10)
self.max_samples_spn.valueChanged.connect(self._set_max_samples)

self.MC_repetitions_spn = QSpinBox(minimum=2, maximum=100, value=30,
self.MC_repetitions_spn = QSpinBox(minimum=2, maximum=100,
value=self._MC_repetitions,
singleStep=10)
self.MC_repetitions_spn.valueChanged.connect(self._set_MC_repetitions)

self.input_axes_le = QLineEdit(self._input_axes)
self.input_axes_le.editingFinished.connect(self._set_input_axes)

self.execute_selected_btn = QPushButton("Run on selected image groups")
self.execute_selected_btn.clicked.connect(
partial(self.compute_acquisition_layers, run_all=False)
Expand All @@ -750,14 +749,16 @@ def __init__(self, image_groups_manager: ImageGroupsManagerWidget,
acquisition_lyt.addWidget(self.max_samples_spn, 1, 1)
acquisition_lyt.addWidget(QLabel("Monte Carlo repetitions"), 2, 0)
acquisition_lyt.addWidget(self.MC_repetitions_spn, 2, 1)
acquisition_lyt.addWidget(self.tunable_segmentation_method, 3, 0, 1, 2)
acquisition_lyt.addWidget(self.execute_selected_btn, 4, 0)
acquisition_lyt.addWidget(self.execute_all_btn, 4, 1)
acquisition_lyt.addWidget(self.finetuning_btn, 5, 0, 1, 2)
acquisition_lyt.addWidget(QLabel("Image queue:"), 6, 0, 1, 1)
acquisition_lyt.addWidget(self.image_pb, 6, 1)
acquisition_lyt.addWidget(QLabel("Patch queue:"), 7, 0, 1, 1)
acquisition_lyt.addWidget(self.patch_pb, 7, 1)
acquisition_lyt.addWidget(QLabel("Input axes order"), 3, 0)
acquisition_lyt.addWidget(self.input_axes_le, 3, 1)
acquisition_lyt.addWidget(self.tunable_segmentation_method, 4, 0, 1, 2)
acquisition_lyt.addWidget(self.execute_selected_btn, 5, 0)
acquisition_lyt.addWidget(self.execute_all_btn, 5, 1)
acquisition_lyt.addWidget(self.finetuning_btn, 6, 0, 1, 2)
acquisition_lyt.addWidget(QLabel("Image queue:"), 7, 0, 1, 1)
acquisition_lyt.addWidget(self.image_pb, 7, 1)
acquisition_lyt.addWidget(QLabel("Patch queue:"), 8, 0, 1, 1)
acquisition_lyt.addWidget(self.patch_pb, 8, 1)

self.setLayout(acquisition_lyt)

Expand Down Expand Up @@ -788,3 +789,6 @@ def _set_MC_repetitions(self):

def _set_max_samples(self):
self._max_samples = self.max_samples_spn.value()

def _set_input_axes(self):
self._input_axes = self.input_axes_le.text()
Loading

0 comments on commit b9f6d1b

Please sign in to comment.