diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index 945396d..1f5819f 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -246,7 +246,7 @@ def fine_tune(self, dataset_metadata_list: Iterable[ dataset = zds.ZarrDataset( list(dataset_metadata.values()), return_positions=False, - draw_same_chunk=True, + draw_same_chunk=False, patch_sampler=StaticPatchSampler( patch_size=patch_sizes, top_lefts=top_lefts, diff --git a/src/napari_activelearning/_interface.py b/src/napari_activelearning/_interface.py index fee278d..82a60e0 100644 --- a/src/napari_activelearning/_interface.py +++ b/src/napari_activelearning/_interface.py @@ -794,7 +794,17 @@ def edit_labels(self): def commit(self): super(LabelsManagerWidget, self).commit() self.commit_btn.setEnabled(False) - self.edit_labels_btn.setEnabled(True) + self.edit_labels_btn.setEnabled(self._active_label is not None) + + self.prev_img_btn.setEnabled(self._active_label is not None) + self.prev_patch_btn.setEnabled(self._active_label is not None) + self.next_patch_btn.setEnabled(self._active_label is not None) + self.next_img_btn.setEnabled(self._active_label is not None) + + self.remove_labels_btn.setEnabled(self._active_label is not None) + self.remove_labels_group_btn.setEnabled( + self._active_label_group is not None + ) class TunableMethodWidget(TunableMethod, QWidget): diff --git a/src/napari_activelearning/_labels.py b/src/napari_activelearning/_labels.py index 7929360..77c68f2 100644 --- a/src/napari_activelearning/_labels.py +++ b/src/napari_activelearning/_labels.py @@ -123,6 +123,8 @@ def remove_managed_label_group(self, label_group: LabelGroup): if not self.managed_layers[layer]: self.managed_layers.pop(layer) + self.setSelected(True) + def remove_managed_layer(self, event): removed_layer = event.value @@ -130,6 +132,8 @@ def remove_managed_layer(self, event): for label_group in label_group_list: self.removeChild(label_group) + self.setSelected(True) + def addChild(self, child: QTreeWidgetItem): if isinstance(child, LabelGroup): if child.layer_channel: @@ -151,11 +155,15 @@ def removeChild(self, child: QTreeWidgetItem): super(LabelGroupRoot, self).removeChild(child) + self.setSelected(True) + def takeChild(self, index: int): child = super(LabelGroupRoot, self).takeChild(index) if isinstance(child, LabelGroup) and child.layer_channel: self.remove_managed_label_group(child) + self.setSelected(True) + return child def takeChildren(self): @@ -164,6 +172,8 @@ def takeChildren(self): if isinstance(child, LabelGroup) and child.layer_channel: self.remove_managed_label_group(child) + self.setSelected(True) + return children @@ -189,6 +199,50 @@ def __init__(self): self.commit ) + def _load_label_data(self, input_filename, data_group=None): + if isinstance(input_filename, (Path, str)): + spec = { + 'driver': 'zarr', + 'kvstore': { + 'driver': 'file', + 'path': str(Path(input_filename) / data_group), + }, + } + + ts_array = ts.open(spec).result() + + self._transaction = ts.Transaction() + + label_data = ts_array.with_transaction(self._transaction) + selection = ts.d[:][self._active_label.position].translate_to[0] + label_data = label_data[selection] + + elif isinstance(input_filename, MultiScaleData): + label_data = np.array( + input_filename[0][self._active_label.position] + ) + else: + label_data = np.array(input_filename[self._active_label.position]) + + return label_data + + def _write_label_data(self, label_data: np.ndarray): + if self._active_layers_group: + segmentation_channel = self._active_layers_group.child(0) + segmentation_channel_layer = segmentation_channel.layer + if isinstance(segmentation_channel.layer.data, MultiScaleData): + segmentation_channel_data =\ + segmentation_channel_layer.data[0] + else: + segmentation_channel_data = segmentation_channel_layer.data + + if isinstance(self._transaction, ts.Transaction): + self._transaction.commit_async() + elif (self._active_label.position is not None + and segmentation_channel_data is not None): + segmentation_channel_data[self._active_label.position] =\ + label_data + def add_labels(self, layer_channel: LayerChannel, labels: Iterable[LabelItem]): new_label_group = LabelGroup(layer_channel) @@ -212,7 +266,14 @@ def remove_labels(self): self._active_label_group.removeChild(self._active_label) if not self._active_label_group.childCount(): self.remove_labels_group() + else: + self._active_label_group.setSelected(True) + + for child in map(lambda idx: self.labels_group_root.child(idx), + range(self.labels_group_root.childCount())): + child.setSelected(False) + self._active_label = None self._requires_commit = False self.commit() @@ -221,6 +282,14 @@ def remove_labels_group(self): return self.labels_group_root.removeChild(self._active_label_group) + + for child in map(lambda idx: self.labels_group_root.child(idx), + range(self.labels_group_root.childCount())): + child.setSelected(False) + + self.labels_group_root.setSelected(True) + self._active_label_group = None + self._requires_commit = False self.commit() @@ -277,7 +346,7 @@ def focus_region(self, label: Optional[QTreeWidgetItem] = None, if isinstance(label, list) and len(label): label = label[0] - elif not isinstance(label, LabelItem): + elif not isinstance(label, (LabelItem, LabelGroup)): label = None self._active_label_group = None @@ -361,33 +430,11 @@ def edit_labels(self): input_filename = self._active_layers_group.source_data data_group = self._active_layers_group.data_group - if isinstance(input_filename, (Path, str)): - spec = { - 'driver': 'zarr', - 'kvstore': { - 'driver': 'file', - 'path': str(Path(input_filename) / data_group), - }, - } - - ts_array = ts.open(spec).result() - - self._transaction = ts.Transaction() - - lazy_data = ts_array.with_transaction(self._transaction) - lazy_data =\ - lazy_data[ts.d[:][self._active_label.position].translate_to[0]] - - elif isinstance(input_filename, MultiScaleData): - lazy_data = np.array( - input_filename[0][self._active_label.position] - ) - else: - lazy_data = np.array(input_filename[self._active_label.position]) + label_data = self._load_label_data(input_filename, data_group) viewer = napari.current_viewer() self._active_edit_layer = viewer.add_labels( - lazy_data, + label_data, name="Labels edit", blending="translucent_no_depth", opacity=0.7, @@ -415,21 +462,22 @@ def commit(self): if self._active_edit_layer: edit_data = self._active_edit_layer.data - if self._active_layers_group: - segmentation_channel = self._active_layers_group.child(0) - segmentation_channel_layer = segmentation_channel.layer - if isinstance(segmentation_channel.layer.data, MultiScaleData): - segmentation_channel_data =\ - segmentation_channel_layer.data[0] - else: - segmentation_channel_data = segmentation_channel_layer.data - - if isinstance(self._transaction, ts.Transaction): - self._transaction.commit_async() - elif (self._active_label.position is not None - and segmentation_channel_data is not None): - segmentation_channel_data[self._active_label.position] =\ - edit_data + # if self._active_layers_group: + # segmentation_channel = self._active_layers_group.child(0) + # segmentation_channel_layer = segmentation_channel.layer + # if isinstance(segmentation_channel.layer.data, MultiScaleData): + # segmentation_channel_data =\ + # segmentation_channel_layer.data[0] + # else: + # segmentation_channel_data = segmentation_channel_layer.data + + # if isinstance(self._transaction, ts.Transaction): + # self._transaction.commit_async() + # elif (self._active_label.position is not None + # and segmentation_channel_data is not None): + # segmentation_channel_data[self._active_label.position] =\ + # edit_data + self._write_label_data(edit_data) viewer = napari.current_viewer() if (self._active_edit_layer diff --git a/src/napari_activelearning/_layers.py b/src/napari_activelearning/_layers.py index 8af85e2..0ed1907 100644 --- a/src/napari_activelearning/_layers.py +++ b/src/napari_activelearning/_layers.py @@ -869,6 +869,14 @@ def remove_managed_layer(self, event): if not image_group.childCount(): self.removeChild(image_group) + if image_group.labels_group: + if image_group.labels_group.parent(): + image_group.labels_group.parent().removeChild( + image_group.labels_group + ) + + image_group.labels_group = None + if not self.managed_layers[removed_layer]: self.managed_layers.pop(removed_layer) diff --git a/src/napari_activelearning/_utils.py b/src/napari_activelearning/_utils.py index 8654cbf..74fe9c3 100644 --- a/src/napari_activelearning/_utils.py +++ b/src/napari_activelearning/_utils.py @@ -249,7 +249,7 @@ def get_dataloader( train_dataset = zds.ZarrDataset( list(dataset_metadata.values()), return_positions=True, - draw_same_chunk=True, + draw_same_chunk=False, patch_sampler=patch_sampler, shuffle=shuffle )