Skip to content

Commit

Permalink
Fixed assertions and improved layer transformation management
Browse files Browse the repository at this point in the history
  • Loading branch information
m-albert committed Feb 21, 2024
1 parent db425e0 commit fe6c677
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 47 deletions.
4 changes: 2 additions & 2 deletions src/napari_stitcher/_tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def test_read_mosaic_image_into_list_of_spatial_xarrays():

view_sims = _reader.read_mosaic_image_into_list_of_spatial_xarrays(test_path)

assert(2, len(view_sims))
assert(min([ax in view_sims[0].dims for ax in ['x', 'y']]))
assert 2 == len(view_sims)
assert min([ax in view_sims[0].dims for ax in ['x', 'y']])

return

Expand Down
9 changes: 6 additions & 3 deletions src/napari_stitcher/_tests/test_viewer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
for ndim in [2, 3]
for N_c in [1, 2]
for N_t in [1, 2]
# (3, 1, 2)
]
)
def test_create_image_layer_tuples_from_msims(ndim, N_c, N_t, make_napari_viewer):
Expand Down Expand Up @@ -52,13 +53,15 @@ def test_create_image_layer_tuples_from_msims(ndim, N_c, N_t, make_napari_viewer
lds = viewer_utils.create_image_layer_tuples_from_msims(
msims, transform_key=registered_transform_key)
assert len(lds) == N_c * tiles_x * tiles_y * tiles_z
viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

viewer_utils.add_image_layer_tuples_to_viewer(
viewer, lds, manage_viewer_transformations=True)

lds = viewer_utils.create_image_layer_tuples_from_msims(
[mfused], transform_key=registered_transform_key)
assert len(lds) == N_c
viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

viewer_utils.add_image_layer_tuples_to_viewer(
viewer, lds, manage_viewer_transformations=True)

# wiggle time
if N_t > 1:
Expand Down
12 changes: 6 additions & 6 deletions src/napari_stitcher/_tests/test_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,22 @@ def test_stitcher_q_widget_integrated(make_napari_viewer, capsys):
stitcher_widget.run_registration()

# Check that parameters were obtained
assert(stitcher_widget.params is not None)
assert stitcher_widget.params is not None

# Check that parameters are visualised

# First, view 0 is not shifted
assert(np.allclose(
assert np.allclose(
np.eye(ndim + 1),
viewer.layers[0].affine.affine_matrix[-(ndim+1):, -(ndim+1):]))
viewer.layers[0].affine.affine_matrix[-(ndim+1):, -(ndim+1):])

# Toggle showing the registrations
stitcher_widget.visualization_type_rbuttons.value=_widget.CHOICE_REGISTERED

# Make sure view 0 is shifted now
assert(~np.allclose(
assert ~np.allclose(
np.eye(ndim + 1),
viewer.layers[1].affine.affine_matrix[-(ndim+1):, -(ndim+1):]))
viewer.layers[1].affine.affine_matrix[-(ndim+1):, -(ndim+1):])

# Run fusion
# stitcher_widget.button_fuse.clicked()
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_fusion_without_registration(make_napari_viewer):

# Run stitching
stitcher_widget.run_fusion()
assert(len(viewer.layers) == 3)
assert len(viewer.layers) == 3


def test_vanilla_layers_2D_no_time(make_napari_viewer):
Expand Down
28 changes: 14 additions & 14 deletions src/napari_stitcher/_tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ def test_writer_napari(field_ndim, N_t, N_c, make_napari_viewer):
read_im = tifffile.imread(filepath)

# make sure dimensionality is right
assert(read_im.ndim == field_ndim + int(len(times) > 1) + int(len(channels) > 1))
assert read_im.ndim == field_ndim + int(len(times) > 1) + int(len(channels) > 1)

# test metadata

# https://pypi.org/project/tifffile/#examples
tif = tifffile.TiffFile(filepath)
assert(tif.series[0].axes,
['', 't'][len(times) > 1] +\
['', 'z'][field_ndim > 2] +\
['', 'c'][len(channels) > 1] +\
'YX')
assert tif.series[0].axes == \
['', 'T'][len(times) > 1] +\
['', 'Z'][field_ndim > 2] +\
['', 'C'][len(channels) > 1] +\
'YX'

resolution_unit_checked = False
resolution_value_checked = False
Expand All @@ -77,17 +77,17 @@ def test_writer_napari(field_ndim, N_t, N_c, make_napari_viewer):
p = tif.pages[0]
for tag in p.tags:
print(tag.name, '/', tag.value)
if tag.name == 'ResolutionUnit':
assert(tag.value, 'um')
resolution_unit_checked = True
# if tag.name == 'ResolutionUnit':
# assert tag.value == 'um'
# resolution_unit_checked = True
if tag.name == 'XResolution':
assert(np.isclose(spacing_xy, tag.value[1] / tag.value[0]))
assert np.isclose(spacing_xy, tag.value[1] / tag.value[0])
resolution_value_checked = True
if tag.name == 'BitsPerSample':
assert(tag.value, np.iinfo(im_dtype).bits)
assert tag.value == np.iinfo(im_dtype).bits
bitspersample_checked = True

assert(resolution_unit_checked)
assert(resolution_value_checked)
assert(bitspersample_checked)
# assert resolution_unit_checked
assert resolution_value_checked
assert bitspersample_checked

52 changes: 30 additions & 22 deletions src/napari_stitcher/viewer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,15 @@ def add_image_layer_tuples_to_viewer(
# add callback to manage viewer transformations
# (napari doesn't yet support different affine transforms for a single layer)
if manage_viewer_transformations:
viewer.dims.events.connect(
partial(manage_viewer_transformations_callback, viewer=viewer))

for l in layers:
l.metadata['napari_stitcher_manage_transformations'] = True

if manage_viewer_transformations_callback not in viewer.dims.events.callbacks:
viewer.dims.events.connect(
partial(manage_viewer_transformations_callback,
viewer=viewer)
)

return layers

Expand Down Expand Up @@ -315,17 +322,17 @@ def manage_viewer_transformations_callback(event, viewer):
- for each (compatible) layer loaded in viewer
"""

# compatible_layers = [l for l in self.viewer.layers
# if si.is_spatial_image(l.data[0])]

# consider all layers for now
compatible_layers = viewer.layers
# layers_to_manage = [l for l in viewer.layers if l.name in layer_names_to_manage]

layers_to_manage = [l for l in viewer.layers
if 'napari_stitcher_manage_transformations' in l.metadata.keys()
and l.metadata['napari_stitcher_manage_transformations']]

if not len(compatible_layers): return
if not len(layers_to_manage): return

# determine spatial dimensions from layers
all_spatial_dims = [spatial_image_utils.get_spatial_dims_from_sim(
l.data[0]) for l in compatible_layers]
l.data[0]) for l in layers_to_manage]

highest_sdim = max([len(sdim) for sdim in all_spatial_dims])

Expand All @@ -337,27 +344,27 @@ def manage_viewer_transformations_callback(event, viewer):
else:
curr_tp = 0

for _, l in enumerate(compatible_layers):
for _, l in enumerate(layers_to_manage):

if not 'full_affine_transform' in l.metadata.keys(): continue

layer_sim = l.data[0]

params = l.metadata['full_affine_transform']

try:
if 't' in params.dims:
if 't' in params.dims:
try:
p = np.array(params.sel(t=layer_sim.coords['t'][curr_tp])).squeeze()
else:
p = np.array(params).squeeze()
except:
notifications.notification_manager.receive_info(
'Timepoint %s: no parameters available for tp' % curr_tp)
continue
# if curr_tp not available, use nearest available parameter
# notifications.notification_manager.receive_info(
# 'Timepoint %s: no parameters available, taking nearest available one.' % curr_tp)
# p = np.array(params.sel(t=layer_sim.coords['t'][curr_tp], method='nearest')).squeeze()
except:
notifications.notification_manager.receive_info(
'Timepoint %s: no parameters available for tp' % curr_tp)
# if curr_tp not available, use nearest available parameter
# notifications.notification_manager.receive_info(
# 'Timepoint %s: no parameters available, taking nearest available one.' % curr_tp)
p = np.array(params.sel(t=layer_sim.coords['t'][curr_tp], method='nearest')).squeeze()
continue
else:
p = np.array(params).squeeze()

ndim_layer_data = len(layer_sim.shape)

Expand All @@ -368,6 +375,7 @@ def manage_viewer_transformations_callback(event, viewer):
full_vis_p = np.eye(ndim_layer_data + 1)
full_vis_p[-len(vis_p):, -len(vis_p):] = vis_p

# import pdb; pdb.set_trace()
l.affine.affine_matrix = full_vis_p

# refreshing layers fails sometimes
Expand Down

0 comments on commit fe6c677

Please sign in to comment.