Skip to content

Commit

Permalink
Merge pull request #25 from TheJacksonLaboratory/dev
Browse files Browse the repository at this point in the history
Improve labels handling
  • Loading branch information
fercer authored Nov 27, 2024
2 parents bbb9764 + b34f1e3 commit d057b46
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 70 deletions.
95 changes: 78 additions & 17 deletions src/napari_activelearning/_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def add_multiscale_output_layer(
output_filename: Optional[Path] = None,
contrast_limits: Optional[Iterable[float]] = None,
colormap: Optional[str] = None,
use_as_input_labels: bool = False,
add_func: Optional[Callable] = napari.Viewer.add_image
):
if output_filename:
Expand Down Expand Up @@ -130,6 +131,7 @@ def add_multiscale_output_layer(
layers_group_name,
source_axes=axes,
use_as_input_image=False,
use_as_input_labels=use_as_input_labels,
use_as_sampling_mask=False
)

Expand Down Expand Up @@ -244,15 +246,24 @@ def fine_tune(self, dataset_metadata_list: Iterable[
transform = self._get_transform()

for dataset_metadata, top_lefts in dataset_metadata_list:
if top_lefts is not None:
patch_sampler = StaticPatchSampler(
patch_size=patch_sizes,
top_lefts=top_lefts,
spatial_axes=dataset_metadata["labels"]["axes"]
)
else:
patch_sampler = zds.PatchSampler(
patch_size=patch_sizes,
spatial_axes=dataset_metadata["labels"]["axes"],
min_area=0.05
)

dataset = zds.ZarrDataset(
list(dataset_metadata.values()),
return_positions=False,
draw_same_chunk=False,
patch_sampler=StaticPatchSampler(
patch_size=patch_sizes,
top_lefts=top_lefts,
spatial_axes=dataset_metadata["labels"]["axes"]
),
patch_sampler=patch_sampler,
shuffle=True,
)

Expand Down Expand Up @@ -355,7 +366,7 @@ def _prepare_datasets_metadata(
dataset_metadata[layer_type] = layers_group.metadata
dataset_metadata[layer_type]["roi"] = None

if layer_type in ["images", "labels"]:
if layer_type in ["images", "labels", "masks"]:
dataset_metadata[layer_type]["roi"] = [tuple(
slice(0, ax_s - ax_s % self._patch_sizes.get(ax, 1))
if (ax != "C"
Expand All @@ -365,7 +376,7 @@ def _prepare_datasets_metadata(
for ax, ax_s in zip(displayed_source_axes,
displayed_shape)
if (layer_type == "images"
or (layer_type == "labels" and ax != "C"))
or (layer_type in ["labels", "masks"] and ax != "C"))
)]

if isinstance(dataset_metadata[layer_type]["filenames"],
Expand Down Expand Up @@ -415,6 +426,7 @@ def _prepare_datasets_metadata(

def compute_acquisition(self, dataset_metadata, acquisition_fun,
segmentation_out,
sampled_mask=None,
sampling_positions=None,
segmentation_only=False):
model_spatial_axes = [
Expand Down Expand Up @@ -498,6 +510,9 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun,
segmentation_out[pos_u_lab] = seg_out[pred_sel]
segmentation_max = max(segmentation_max, seg_out.max())

if sampled_mask is not None:
sampled_mask[pos_u_lab] = True

img_sampling_positions.append(
LabelItem(acquisition_val, position=pos_u_lab)
)
Expand Down Expand Up @@ -613,11 +628,28 @@ def compute_acquisition_layers(
(sampling_mask_layers_group, "masks")]
)

if sampling_positions is None:
sampled_root = save_zarr(
output_filename,
data=None,
shape=output_shape,
chunk_size=True,
name="sampled_positions",
dtype=bool,
is_label=True,
is_multiscale=True
)

sampled_grp = sampled_root["labels/sampled_positions/0"]
else:
sampled_grp = None

# Compute acquisition function of the current image
img_sampling_positions = self.compute_acquisition(
dataset_metadata,
acquisition_fun=acquisition_fun_grp,
segmentation_out=segmentation_grp,
sampled_mask=sampled_grp,
sampling_positions=sampling_positions,
segmentation_only=segmentation_only
)
Expand Down Expand Up @@ -646,6 +678,21 @@ def compute_acquisition_layers(
add_func=viewer.add_image
)

if sampled_grp is not None:
add_multiscale_output_layer(
sampled_root,
axes=output_axes,
scale=output_scale,
data_group="labels/sampled_positions/0",
group_name=group_name + " sampled positions",
layers_group_name="sampled positions",
image_group=image_group,
reference_source_axes=displayed_source_axes,
reference_scale=displayed_scale,
output_filename=output_filename,
add_func=viewer.add_labels
)

segmentation_channel = add_multiscale_output_layer(
segmentation_root,
axes=output_axes,
Expand All @@ -657,10 +704,13 @@ def compute_acquisition_layers(
reference_source_axes=displayed_source_axes,
reference_scale=displayed_scale,
output_filename=output_filename,
use_as_input_labels=True,
add_func=viewer.add_labels
)

if not segmentation_only and image_group.labels_group is None:
if (not segmentation_only
and image_group is not None
and image_group.labels_group is None):
new_label_group = self.labels_manager.add_labels(
segmentation_channel,
img_sampling_positions
Expand All @@ -679,8 +729,7 @@ def fine_tune(self):
range(self.image_groups_manager.groups_root.childCount()))
))

if (not image_groups
or not self.labels_manager.labels_group_root.childCount()):
if not image_groups:
return False

dataset_metadata_list = []
Expand All @@ -689,16 +738,29 @@ def fine_tune(self):
image_group.setSelected(True)

input_layers_group_idx = image_group.input_layers_group

segmentation_layers_group = image_group.getLayersGroup(
layers_group_name="segmentation"
)
label_layers_group_idx = image_group.labels_layers_group

if (input_layers_group_idx is None
or segmentation_layers_group is None):
or label_layers_group_idx is None):
continue

sampling_mask_layers_group = None
if image_group.sampling_mask_layers_group is not None:
sampling_mask_layers_group = image_group.child(
image_group.sampling_mask_layers_group
)

input_layers_group = image_group.child(input_layers_group_idx)
label_layers_group = image_group.child(label_layers_group_idx)

layer_types = [
(input_layers_group, "images"),
(label_layers_group, "labels")
]

if (sampling_mask_layers_group is not None
and image_group.labels_group is None):
layer_types.append((sampling_mask_layers_group, "masks"))

displayed_source_axes = input_layers_group.source_axes
displayed_shape = input_layers_group.shape
Expand All @@ -715,8 +777,7 @@ def fine_tune(self):
output_axes,
displayed_source_axes,
displayed_shape,
[(input_layers_group, "images"),
(segmentation_layers_group, "labels")]
layer_types,
)

dataset_metadata_list.append((dataset_metadata,
Expand Down
21 changes: 20 additions & 1 deletion src/napari_activelearning/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ def __init__(self):
self.update_use_as_input
)

self.use_as_labels_chk = QCheckBox("Use as labels")
self.use_as_labels_chk.setEnabled(False)
self.use_as_labels_chk.toggled.connect(
self.update_use_as_labels
)

self.use_as_sampling_chk = QCheckBox("Use as sampling mask")
self.use_as_sampling_chk.setEnabled(False)
self.use_as_sampling_chk.toggled.connect(
Expand Down Expand Up @@ -251,7 +257,8 @@ def __init__(self):
editor_grid_lyt.addWidget(self.output_dir_le, 3, 1, 1, 3)
editor_grid_lyt.addWidget(self.output_dir_btn, 3, 3)
editor_grid_lyt.addWidget(self.use_as_input_chk, 4, 0)
editor_grid_lyt.addWidget(self.use_as_sampling_chk, 4, 1)
editor_grid_lyt.addWidget(self.use_as_labels_chk, 4, 1)
editor_grid_lyt.addWidget(self.use_as_sampling_chk, 4, 2)
editor_grid_lyt.addWidget(QLabel("Layer scale"), 5, 0)
editor_grid_lyt.addWidget(self.edit_scale_mdspn, 5, 1)
editor_grid_lyt.addWidget(QLabel("Layer translate"), 5, 2)
Expand Down Expand Up @@ -285,7 +292,12 @@ def _clear_layers_group(self):
self.edit_axes_le.setEnabled(False)
self.layers_group_name_cmb.setEnabled(False)
self.use_as_input_chk.setEnabled(False)
self.use_as_input_chk.setChecked(False)
self.use_as_labels_chk.setEnabled(False)
self.use_as_labels_chk.setChecked(False)
self.use_as_sampling_chk.setEnabled(False)
self.use_as_sampling_chk.setChecked(False)

self.edit_scale_mdspn.axes = ""
self.edit_translate_mdspn.axes = ""

Expand Down Expand Up @@ -331,13 +343,17 @@ def _fill_layers_group(self):
self.use_as_input_chk.setChecked(
self._active_layers_group.use_as_input_image
)
self.use_as_labels_chk.setChecked(
self._active_layers_group.use_as_input_labels
)
self.use_as_sampling_chk.setChecked(
self._active_layers_group.use_as_sampling_mask
)

self.layers_group_name_cmb.setEnabled(True)
self.edit_axes_le.setEnabled(True)
self.use_as_input_chk.setEnabled(True)
self.use_as_labels_chk.setEnabled(True)
self.use_as_sampling_chk.setEnabled(True)

self._fill_layer()
Expand Down Expand Up @@ -398,6 +414,9 @@ def update_layers_group_name(self):
def update_use_as_input(self):
super().update_use_as_input(self.use_as_input_chk.isChecked())

def update_use_as_labels(self):
super().update_use_as_labels(self.use_as_labels_chk.isChecked())

def update_use_as_sampling(self):
super().update_use_as_sampling(self.use_as_sampling_chk.isChecked())

Expand Down
37 changes: 14 additions & 23 deletions src/napari_activelearning/_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ def __init__(self):

def add_managed_label_group(self, label_group: LabelGroup):
layer = label_group.layer_channel.layer
layers_group = label_group.layer_channel.parent()
if layers_group is not None:
image_group = layers_group.parent()
else:
image_group = None

if layer not in self.managed_layers:
self.managed_layers[layer] = []

self.managed_layers[layer].append(label_group)
self.managed_layers[layer] = (label_group, image_group)

viewer = napari.current_viewer()
viewer.layers.events.removed.connect(
Expand All @@ -108,22 +110,13 @@ def add_managed_label_group(self, label_group: LabelGroup):
def remove_managed_label_group(self, label_group: LabelGroup):
layer = label_group.layer_channel.layer

if (layer in self.managed_layers
and label_group in self.managed_layers[layer]):
self.managed_layers[layer].remove(label_group)
if layer in self.managed_layers:
(label_group,
image_group) = self.managed_layers.pop(layer)

layers_group = label_group.layer_channel.parent()
image_group = None
if layers_group is not None:
image_group = layers_group.parent()

if (image_group is not None
and label_group == image_group.labels_group):
if image_group is not None:
image_group.labels_group = None

if not self.managed_layers[layer]:
self.managed_layers.pop(layer)

self.setSelected(True)

def remove_managed_layer(self, event):
Expand Down Expand Up @@ -442,8 +435,6 @@ def edit_labels(self):
return True

def commit(self):
segmentation_channel_layer = None

edit_data = None

if self._requires_commit:
Expand All @@ -457,10 +448,10 @@ def commit(self):
and self._active_edit_layer in viewer.layers):
viewer.layers.remove(self._active_edit_layer)

if segmentation_channel_layer:
segmentation_channel_layer.refresh()
segmentation_channel_layer.visible = True
viewer.layers.selection.add(segmentation_channel_layer)
if self._active_layer_channel:
self._active_layer_channel.layer.refresh()
self._active_layer_channel.visible = True
self._active_layer_channel.selected = True

self._transaction = None
self._active_edit_layer = None
Expand Down
Loading

0 comments on commit d057b46

Please sign in to comment.