Skip to content

Commit

Permalink
Merge pull request #14 from TheJacksonLaboratory/dev
Browse files Browse the repository at this point in the history
Fixing labels manager interface
  • Loading branch information
fercer authored Oct 11, 2024
2 parents 1084f05 + 349d2b5 commit 31145db
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/napari_activelearning/_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def fine_tune(self, dataset_metadata_list: Iterable[
dataset = zds.ZarrDataset(
list(dataset_metadata.values()),
return_positions=False,
draw_same_chunk=True,
draw_same_chunk=False,
patch_sampler=StaticPatchSampler(
patch_size=patch_sizes,
top_lefts=top_lefts,
Expand Down
12 changes: 11 additions & 1 deletion src/napari_activelearning/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,17 @@ def edit_labels(self):
def commit(self):
super(LabelsManagerWidget, self).commit()
self.commit_btn.setEnabled(False)
self.edit_labels_btn.setEnabled(True)
self.edit_labels_btn.setEnabled(self._active_label is not None)

self.prev_img_btn.setEnabled(self._active_label is not None)
self.prev_patch_btn.setEnabled(self._active_label is not None)
self.next_patch_btn.setEnabled(self._active_label is not None)
self.next_img_btn.setEnabled(self._active_label is not None)

self.remove_labels_btn.setEnabled(self._active_label is not None)
self.remove_labels_group_btn.setEnabled(
self._active_label_group is not None
)


class TunableMethodWidget(TunableMethod, QWidget):
Expand Down
128 changes: 88 additions & 40 deletions src/napari_activelearning/_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,17 @@ def remove_managed_label_group(self, label_group: LabelGroup):
if not self.managed_layers[layer]:
self.managed_layers.pop(layer)

self.setSelected(True)

def remove_managed_layer(self, event):
removed_layer = event.value

label_group_list = self.managed_layers.get(removed_layer, [])
for label_group in label_group_list:
self.removeChild(label_group)

self.setSelected(True)

def addChild(self, child: QTreeWidgetItem):
if isinstance(child, LabelGroup):
if child.layer_channel:
Expand All @@ -151,11 +155,15 @@ def removeChild(self, child: QTreeWidgetItem):

super(LabelGroupRoot, self).removeChild(child)

self.setSelected(True)

def takeChild(self, index: int):
child = super(LabelGroupRoot, self).takeChild(index)
if isinstance(child, LabelGroup) and child.layer_channel:
self.remove_managed_label_group(child)

self.setSelected(True)

return child

def takeChildren(self):
Expand All @@ -164,6 +172,8 @@ def takeChildren(self):
if isinstance(child, LabelGroup) and child.layer_channel:
self.remove_managed_label_group(child)

self.setSelected(True)

return children


Expand All @@ -189,6 +199,50 @@ def __init__(self):
self.commit
)

def _load_label_data(self, input_filename, data_group=None):
if isinstance(input_filename, (Path, str)):
spec = {
'driver': 'zarr',
'kvstore': {
'driver': 'file',
'path': str(Path(input_filename) / data_group),
},
}

ts_array = ts.open(spec).result()

self._transaction = ts.Transaction()

label_data = ts_array.with_transaction(self._transaction)
selection = ts.d[:][self._active_label.position].translate_to[0]
label_data = label_data[selection]

elif isinstance(input_filename, MultiScaleData):
label_data = np.array(
input_filename[0][self._active_label.position]
)
else:
label_data = np.array(input_filename[self._active_label.position])

return label_data

def _write_label_data(self, label_data: np.ndarray):
if self._active_layers_group:
segmentation_channel = self._active_layers_group.child(0)
segmentation_channel_layer = segmentation_channel.layer
if isinstance(segmentation_channel.layer.data, MultiScaleData):
segmentation_channel_data =\
segmentation_channel_layer.data[0]
else:
segmentation_channel_data = segmentation_channel_layer.data

if isinstance(self._transaction, ts.Transaction):
self._transaction.commit_async()
elif (self._active_label.position is not None
and segmentation_channel_data is not None):
segmentation_channel_data[self._active_label.position] =\
label_data

def add_labels(self, layer_channel: LayerChannel,
labels: Iterable[LabelItem]):
new_label_group = LabelGroup(layer_channel)
Expand All @@ -212,7 +266,14 @@ def remove_labels(self):
self._active_label_group.removeChild(self._active_label)
if not self._active_label_group.childCount():
self.remove_labels_group()
else:
self._active_label_group.setSelected(True)

for child in map(lambda idx: self.labels_group_root.child(idx),
range(self.labels_group_root.childCount())):
child.setSelected(False)

self._active_label = None
self._requires_commit = False
self.commit()

Expand All @@ -221,6 +282,14 @@ def remove_labels_group(self):
return

self.labels_group_root.removeChild(self._active_label_group)

for child in map(lambda idx: self.labels_group_root.child(idx),
range(self.labels_group_root.childCount())):
child.setSelected(False)

self.labels_group_root.setSelected(True)
self._active_label_group = None

self._requires_commit = False
self.commit()

Expand Down Expand Up @@ -277,7 +346,7 @@ def focus_region(self, label: Optional[QTreeWidgetItem] = None,

if isinstance(label, list) and len(label):
label = label[0]
elif not isinstance(label, LabelItem):
elif not isinstance(label, (LabelItem, LabelGroup)):
label = None

self._active_label_group = None
Expand Down Expand Up @@ -361,33 +430,11 @@ def edit_labels(self):
input_filename = self._active_layers_group.source_data
data_group = self._active_layers_group.data_group

if isinstance(input_filename, (Path, str)):
spec = {
'driver': 'zarr',
'kvstore': {
'driver': 'file',
'path': str(Path(input_filename) / data_group),
},
}

ts_array = ts.open(spec).result()

self._transaction = ts.Transaction()

lazy_data = ts_array.with_transaction(self._transaction)
lazy_data =\
lazy_data[ts.d[:][self._active_label.position].translate_to[0]]

elif isinstance(input_filename, MultiScaleData):
lazy_data = np.array(
input_filename[0][self._active_label.position]
)
else:
lazy_data = np.array(input_filename[self._active_label.position])
label_data = self._load_label_data(input_filename, data_group)

viewer = napari.current_viewer()
self._active_edit_layer = viewer.add_labels(
lazy_data,
label_data,
name="Labels edit",
blending="translucent_no_depth",
opacity=0.7,
Expand Down Expand Up @@ -415,21 +462,22 @@ def commit(self):
if self._active_edit_layer:
edit_data = self._active_edit_layer.data

if self._active_layers_group:
segmentation_channel = self._active_layers_group.child(0)
segmentation_channel_layer = segmentation_channel.layer
if isinstance(segmentation_channel.layer.data, MultiScaleData):
segmentation_channel_data =\
segmentation_channel_layer.data[0]
else:
segmentation_channel_data = segmentation_channel_layer.data

if isinstance(self._transaction, ts.Transaction):
self._transaction.commit_async()
elif (self._active_label.position is not None
and segmentation_channel_data is not None):
segmentation_channel_data[self._active_label.position] =\
edit_data
# if self._active_layers_group:
# segmentation_channel = self._active_layers_group.child(0)
# segmentation_channel_layer = segmentation_channel.layer
# if isinstance(segmentation_channel.layer.data, MultiScaleData):
# segmentation_channel_data =\
# segmentation_channel_layer.data[0]
# else:
# segmentation_channel_data = segmentation_channel_layer.data

# if isinstance(self._transaction, ts.Transaction):
# self._transaction.commit_async()
# elif (self._active_label.position is not None
# and segmentation_channel_data is not None):
# segmentation_channel_data[self._active_label.position] =\
# edit_data
self._write_label_data(edit_data)

viewer = napari.current_viewer()
if (self._active_edit_layer
Expand Down
8 changes: 8 additions & 0 deletions src/napari_activelearning/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,14 @@ def remove_managed_layer(self, event):
if not image_group.childCount():
self.removeChild(image_group)

if image_group.labels_group:
if image_group.labels_group.parent():
image_group.labels_group.parent().removeChild(
image_group.labels_group
)

image_group.labels_group = None

if not self.managed_layers[removed_layer]:
self.managed_layers.pop(removed_layer)

Expand Down
2 changes: 1 addition & 1 deletion src/napari_activelearning/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def get_dataloader(
train_dataset = zds.ZarrDataset(
list(dataset_metadata.values()),
return_positions=True,
draw_same_chunk=True,
draw_same_chunk=False,
patch_sampler=patch_sampler,
shuffle=shuffle
)
Expand Down

0 comments on commit 31145db

Please sign in to comment.