diff --git a/src/napari_activelearning/_tests/test_acquisition.py b/src/napari_activelearning/_tests/test_acquisition.py index 3f72d71..611334f 100644 --- a/src/napari_activelearning/_tests/test_acquisition.py +++ b/src/napari_activelearning/_tests/test_acquisition.py @@ -1,5 +1,4 @@ import pytest -import shutil from unittest.mock import MagicMock, patch from pathlib import Path import numpy as np @@ -7,7 +6,8 @@ from napari_activelearning._acquisition import (AcquisitionFunction, add_multiscale_output_layer) -from napari_activelearning._layers import LayerChannel +from napari_activelearning._layers import LayerChannel, ImageGroup, LayersGroup +from napari_activelearning._models import SimpleTunable try: import torch @@ -183,3 +183,127 @@ def test_add_multiscale_output_layer(output_array, image_group, ) assert isinstance(output_channel, LayerChannel) + + +def test_prepare_datasets_metadata(acquisition_function): + # Define the input parameters for the method + image_group = ImageGroup() + output_axes = "TCZYX" + displayed_source_axes = "TCZYX" + displayed_shape = [1, 3, 10, 10, 10] + layer_types = [(LayersGroup(), "images")] + + # Call the method + dataset_metadata, sampling_positions = acquisition_function._prepare_datasets_metadata( + image_group, + output_axes, + displayed_source_axes, + displayed_shape, + layer_types + ) + + # Assert the output values + assert dataset_metadata == { + "images": { + "filenames": None, + "data_group": None, + "source_axes": "TCZYX", + "axes": "TCZYX", + "roi": [[slice(None), slice(None), slice(None), slice(None), slice(None)]], + "modality": "images" + } + } + assert sampling_positions is None + + +def test_compute_acquisition_layers(acquisition_function): + # Mock the necessary dependencies and setup the test data + image_group = MagicMock() + acquisition_function.image_groups_manager.groups_root.childCount.return_value = 1 + acquisition_function.image_groups_manager.groups_root.child.return_value = image_group + acquisition_function.image_groups_manager.groups_root.child.return_value.getSelected.return_value = True + acquisition_function.image_groups_manager.groups_root.child.return_value.group_name = "test_group" + acquisition_function.image_groups_manager.groups_root.child.return_value.group_dir = "/path/to/group" + acquisition_function.image_groups_manager.groups_root.child.return_value.input_layers_group = 0 + acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.source_axes = "TCZYX" + acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.shape = (1, 3, 10, 10, 10) + acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.scale = (1.0, 1.0, 1.0, 1.0, 1.0) + acquisition_function.image_groups_manager.groups_root.child.return_value.sampling_mask_layers_group = None + acquisition_function._prepare_datasets_metadata.return_value = ({}, None) + + # Call the method under test + acquisition_function.compute_acquisition_layers(run_all=True, segmentation_group_name="segmentation", segmentation_only=False) + + # Assert that the necessary methods were called with the expected arguments + acquisition_function.image_groups_manager.groups_root.childCount.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.assert_called_once_with(0) + acquisition_function.image_groups_manager.groups_root.child.return_value.getSelected.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.return_value.group_name.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.return_value.group_dir.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.return_value.input_layers_group.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.return_value.child.assert_called_once_with(0) + acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.source_axes.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.shape.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.scale.assert_called_once() + acquisition_function.image_groups_manager.groups_root.child.return_value.sampling_mask_layers_group.assert_called_once() + acquisition_function._prepare_datasets_metadata.assert_called_once_with( + image_group, + "TCZYX", + "TCZYX", + (1, 3, 10, 10, 10), + [(acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value, "images"), + (acquisition_function.image_groups_manager.groups_root.child.return_value.sampling_mask_layers_group, "masks")] + ) + + +@pytest.fixture +def tunable_method(): + return SimpleTunable() + + +def test_fine_tune(tunable_method): + dataset_metadata_list = [ + ( + { + "images": { + "filenames": ["image1.tif", "image2.tif"], + "data_group": "data", + "source_axes": "YXC", + "axes": "YXC", + "roi": None, + "modality": "images" + } + }, + [ + [1, 2, 3], + [4, 5, 6] + ] + ), + ( + { + "images": { + "filenames": ["image3.tif", "image4.tif"], + "data_group": "data", + "source_axes": "YXC", + "axes": "YXC", + "roi": None, + "modality": "images" + } + }, + [ + [7, 8, 9], + [10, 11, 12] + ] + ) + ] + + train_data_proportion = 0.8 + patch_sizes = 256 + model_axes = "YXC" + + tunable_method.fine_tune(dataset_metadata_list, train_data_proportion, patch_sizes, model_axes) + + # Add assertions here to verify the behavior of the fine_tune method + assert tunable_method.train_data_proportion == train_data_proportion + assert tunable_method.patch_sizes == patch_sizes + assert tunable_method.model_axes == model_axes diff --git a/src/napari_activelearning/_tests/test_layers.py b/src/napari_activelearning/_tests/test_layers.py index e8b4128..3448af0 100644 --- a/src/napari_activelearning/_tests/test_layers.py +++ b/src/napari_activelearning/_tests/test_layers.py @@ -24,7 +24,7 @@ def layer_channel(sample_layer): def test_initialization(layer_channel, sample_layer): - assert layer_channel.layer == sample_layer, print(id(sample_layer), id(layer_channel.layer)) + assert layer_channel.layer == sample_layer assert layer_channel.channel == 1 assert layer_channel.source_axes == "TZYX" assert layer_channel.name == "sample_layer" diff --git a/src/napari_activelearning/_tests/test_utils.py b/src/napari_activelearning/_tests/test_utils.py new file mode 100644 index 0000000..33251f9 --- /dev/null +++ b/src/napari_activelearning/_tests/test_utils.py @@ -0,0 +1,394 @@ +import pytest + +import shutil +from pathlib import Path +import operator + +import numpy as np +import zarr +import zarrdataset as zds + +from napari.layers import Image +from napari.layers._source import Source +from napari.layers._multiscale_data import MultiScaleData +from napari_activelearning._utils import (get_source_data, downsample_image, + save_zarr, + validate_name, + get_basename, + get_dataloader, + StaticPatchSampler, + SuperPixelGenerator) + + +@pytest.fixture +def dataset_metadata(): + return { + "images": { + "filenames": ["image1.tif", "image2.tif"], + "data_group": "data", + "source_axes": "YXC", + "axes": "YXC", + "roi": None, + "modality": "images" + } + } + + +@pytest.fixture(scope="module", params=[True, False]) +def output_dir(request, tmpdir_factory): + if request.param: + tmp_dir = tmpdir_factory.mktemp("temp") + tmp_dir_path = Path(tmp_dir) + else: + tmp_dir_path = None + + yield tmp_dir_path + + +@pytest.fixture(scope="module", params=[Path, None, zarr.Group]) +def output_group(request, tmpdir_factory): + group_type = request.param + if group_type is Path: + tmp_dir = tmpdir_factory.mktemp("temp") + zarr_group = Path(tmp_dir) / "output.zarr" + elif group_type is zarr.Group: + zarr_group = zarr.open() + else: + zarr_group = None + + yield zarr_group + + +@pytest.fixture(scope="module", params=[None, "0"]) +def data_group(request): + return request.param + + +def single_scale_array(*args): + shape = (1, 3, 10, 10, 10) + data = np.random.random(shape) + return data, None, None, shape + + +def multiscale_array(*args): + shape = (1, 3, 10, 10, 10) + data = np.random.random(shape) + data = [data, data[..., ::2, ::2, ::2], data[..., ::4, ::4, ::4]] + shape = [arr.shape for arr in data] + + return data, None, None, shape + + +def single_scale_zarr(output_dir, data_group): + input_filename = None + + sample_data, _, _, shape = single_scale_array() + + if output_dir: + input_filename = output_dir / "input.zarr" + z_root = zarr.open(input_filename) + else: + z_root = zarr.open() + + if data_group: + z_group = z_root.create_group(data_group) + else: + z_group = z_root + + z_group.create_dataset(name="0", data=sample_data, overwrite=True) + if data_group: + data_group = str(Path(data_group) / "0") + else: + data_group = "0" + + return z_root, input_filename, data_group, shape + + +def multiscale_zarr(output_dir, data_group): + input_filename = None + + sample_data, _, _, shape = multiscale_array() + + if output_dir: + input_filename = output_dir / "input.zarr" + z_root = zarr.open(input_filename) + + else: + z_root = zarr.open() + + if data_group: + z_group = z_root.create_group(data_group) + else: + z_group = z_root + + source_data = [] + for lvl, data in enumerate(sample_data): + z_group.create_dataset(name="%i" % lvl, data=data, overwrite=True) + source_data.append(z_group["%i" % lvl]) + + return source_data, input_filename, data_group, shape + + +@pytest.fixture +def image_collection(): + source_data, input_filename, data_group, shape = single_scale_zarr(None, + None) + collection = zds.ImageCollection( + dict( + images=dict( + filename=source_data, + data_group=data_group, + source_axes="TCZYX", + axes="TCZYX" + ) + ), + spatial_axes="ZYX" + ) + return collection + + +@pytest.fixture(scope="module", params=[single_scale_array, + single_scale_zarr, + multiscale_array, + multiscale_zarr]) +def sample_layer(request, output_dir, data_group): + (source_data, + input_filename, + data_group, + _) = request.param(output_dir, data_group) + + if isinstance(source_data, zarr.Group): + source_data = source_data[data_group] + + layer = Image( + data=source_data, + name="sample_layer", + scale=[1.0, 1.0, 1.0, 1.0], + translate=[0.0, 0.0, 0.0, 0.0], + visible=True + ) + + if input_filename: + if data_group: + layer._source = Source(path=str(input_filename / data_group)) + else: + layer._source = Source(path=str(input_filename)) + + if isinstance(layer.data, (MultiScaleData, list)): + if data_group: + data_group = str(Path(data_group) / "0") + else: + data_group = "0" + + return layer, source_data, input_filename, data_group + + +@pytest.fixture(scope="module", params=[single_scale_array, + single_scale_zarr]) +def single_scale_type_variant_array(request, output_dir, data_group): + return request.param(output_dir, data_group) + + +def test_get_source_data(sample_layer): + layer, org_source_data, org_input_filename, org_data_group = sample_layer + input_filename, data_group = get_source_data(layer) + + assert (not isinstance(input_filename, (Path, str)) + or input_filename == str(org_input_filename)) + assert (isinstance(input_filename, (Path, str)) + or (isinstance(input_filename, (MultiScaleData, list)) + and all(map(np.array_equal, input_filename, org_source_data))) + or np.array_equal(input_filename, org_source_data)) + assert (not isinstance(input_filename, (Path, str)) + or data_group == org_data_group) + + +def test_downsample_image(single_scale_type_variant_array): + (source_data, + input_filename, + data_group, + array_shape) = single_scale_type_variant_array + + scale = 2 + num_scales = 10 + if data_group and "/" in data_group: + data_group_root = data_group.split("/")[0] + else: + data_group_root = "" + + if input_filename is not None: + source_data = input_filename + + downsampled_zarr = downsample_image( + source_data, + "TCZYX", + data_group, + scale=scale, + num_scales=num_scales, + reference_source_axes="TCZYX", + reference_scale=(1, 1, 1, 1, 1), + reference_units=None + ) + + if isinstance(array_shape, list): + array_shape = array_shape[0] + + min_spatial_shape = min(array_shape["TCZYX".index(ax)] for ax in "ZYX") + + expected_scales = min(num_scales, + int(np.log(min_spatial_shape) / np.log(scale))) + + expected_shapes = [ + [int(np.ceil(ax_s / (scale ** s))) if ax in "ZYX" else ax_s + for ax, ax_s in zip("TCZYX", array_shape) + ] + for s in range(expected_scales) + ] + + assert len(downsampled_zarr) == expected_scales + assert all(map(lambda src_shape, dwn_arr: + all(map(operator.eq, src_shape, dwn_arr.shape)), + expected_shapes, + downsampled_zarr)) + + if isinstance(input_filename, (Path, str)): + z_root = zarr.open(input_filename, mode="r") + assert all(map(lambda scl: str(scl) in z_root[data_group_root], + range(expected_scales))) + assert "multiscales" in z_root[data_group_root].attrs + + for scl in range(1, expected_scales): + shutil.rmtree(input_filename / data_group_root / str(scl)) + + +def test_save_zarr(sample_layer, output_group): + layer, source_data, input_filename, data_group = sample_layer + name = "test_data" + group_name = "labels/" + name + + is_multiscale = isinstance(layer.data, (MultiScaleData, list)) + + out_grp = save_zarr(output_group, layer.data, layer.data.shape, + True, name, + layer.data.dtype, + is_multiscale=is_multiscale, + metadata=None, + is_label=True) + + assert group_name in out_grp + assert (not is_multiscale + or len(out_grp[group_name]) == len(layer.data)) + assert (isinstance(out_grp.store, zarr.MemoryStore) + or "image-label" in out_grp[group_name].attrs) + + +def test_validate_name(): + group_names = {"Group1", "Group2", "Group3"} + + # Test case 1: New child name is not in group names + previous_child_name = None + new_child_name = "Group4" + expected_result = "Group4" + assert validate_name(group_names, + previous_child_name, + new_child_name) == expected_result + + # Test case 2: New child name is already in group names + previous_child_name = "Group1" + new_child_name = "Group2" + expected_result = "Group2 (1)" + assert validate_name(group_names, + previous_child_name, + new_child_name) == expected_result + + # Test case 3: New child name is empty + previous_child_name = "Group2 (1)" + new_child_name = "" + expected_result = "" + assert validate_name(group_names, + previous_child_name, + new_child_name) == expected_result + + # Test case 4: Previous child name is not in group names + previous_child_name = "Group1" + new_child_name = "Group5" + expected_result = "Group5" + assert validate_name(group_names, + previous_child_name, + new_child_name) == expected_result + + +def test_get_basename(): + layer_name = "sample_layer" + expected_result = "sample_layer" + assert get_basename(layer_name) == expected_result + + layer_name = "sample_layer 1" + expected_result = "sample_layer" + assert get_basename(layer_name) == expected_result + + +def test_get_dataloader(dataset_metadata): + patch_size = {"Y": 64, "X": 64} + sampling_positions = [[0, 0], [0, 64], [64, 0], [64, 64]] + shuffle = True + num_workers = 4 + batch_size = 8 + spatial_axes = "YX" + model_input_axes = "YXC" + + dataloader = get_dataloader( + dataset_metadata, + patch_size=patch_size, + sampling_positions=sampling_positions, + shuffle=shuffle, + num_workers=num_workers, + batch_size=batch_size, + spatial_axes=spatial_axes, + model_input_axes=model_input_axes + ) + + assert isinstance(dataloader._patch_sampler, StaticPatchSampler) + + +def test_compute_chunks(image_collection): + patch_size = {"Z": 1, "Y": 5, "X": 5} + top_lefts = [[3, 0, 0], [3, 0, 5], [3, 5, 0], [3, 5, 5]] + + patch_sampler = StaticPatchSampler(patch_size=patch_size, + top_lefts=top_lefts) + + expected_output = [dict(X=slice(0, 10), Y=slice(0, 10), Z=slice(0, 10))] + + chunks_slices = patch_sampler.compute_chunks(image_collection) + + assert chunks_slices == expected_output + + +def test_compute_patches(image_collection): + patch_size = {"Z": 1, "Y": 5, "X": 5} + top_lefts = [[3, 0, 0], [3, 0, 5], [3, 5, 0], [3, 5, 5]] + chunk_tl = dict(X=slice(None), Y=slice(None), Z=slice(None)) + + patch_sampler = StaticPatchSampler(patch_size=patch_size, + top_lefts=top_lefts) + + chunks_slices = patch_sampler.compute_patches(image_collection, chunk_tl) + + # Assert that the number of chunks is equal to the number of top_lefts + assert len(chunks_slices) == len(top_lefts) + + # Assert that each chunk slice has the correct shape + for chunk_slices in chunks_slices: + assert chunk_slices["Z"].stop - chunk_slices["Z"].start == patch_size["Z"] + assert chunk_slices["Y"].stop - chunk_slices["Y"].start == patch_size["Y"] + assert chunk_slices["X"].stop - chunk_slices["X"].start == patch_size["X"] + + +def test_compute_transform(): + generator = SuperPixelGenerator(num_superpixels=25, axes="YXC", + model_axes="YXC") + image = np.random.random((10, 10, 3)) + labels = generator._compute_transform(image) + assert labels.shape == (10, 10, 1) + assert np.unique(labels).size == 25 diff --git a/src/napari_activelearning/_tests/tst_utils.py b/src/napari_activelearning/_tests/tst_utils.py deleted file mode 100644 index f067585..0000000 --- a/src/napari_activelearning/_tests/tst_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest - -import os - -import numpy as np -import zarr -import dask.array as da - -from napari.layers import Layer -from napari_activelearning._utils import (get_source_data, downsample_image, - save_zarr, - validate_name) - - -def test_get_source_data(): - layer = Layer(data=[1, 2, 3]) - input_filename, data_group = get_source_data(layer) - assert input_filename is None - assert data_group is None - - layer = Layer(data=[1, 2, 3], metadata={'path': '/path/to/data.zarr', 'data_group': 'group'}) - input_filename, data_group = get_source_data(layer) - assert input_filename == '/path/to/data.zarr' - assert data_group == 'group' - - -def test_downsample_image(): - # Create a test input array - input_array = np.random.rand(100, 100) - input_dask_array = da.from_array(input_array) - - # Create a test Zarr store - zarr_store = zarr.MemoryStore() - zarr_group = zarr.group(store=zarr_store) - - # Save the input array to the Zarr store - zarr_group.create_dataset('data', data=input_dask_array) - - # Call the downsample_image function - downsampled_zarr = downsample_image(zarr_group, ['Y', 'X'], 'data', scale=2, num_scales=3) - - # Check the shape of the downsampled array - assert downsampled_zarr[0].shape == (100, 100) - assert downsampled_zarr[1].shape == (50, 50) - assert downsampled_zarr[2].shape == (25, 25) - - -def test_save_zarr(): - # Create a test output filename - output_filename = "test_output.zarr" - - # Create a test data array - data = np.random.rand(100, 100) - - # Create test parameters - shape = data.shape - chunk_size = 10 - name = "test_data" - dtype = data.dtype - - # Call the save_zarr function - save_zarr(output_filename, data, shape, chunk_size, name, dtype) - - # Check if the output file exists - assert os.path.exists(output_filename) - - # Check if the saved data matches the original data - saved_data = zarr.open(output_filename, mode="r")[name][:] - assert np.array_equal(saved_data, data) - - # Clean up the test output file - os.remove(output_filename) - - -def test_validate_name(): - group_names = {"Group1", "Group2", "Group3"} - - # Test case 1: New child name is not in group names - previous_child_name = "Group1" - new_child_name = "Group4" - expected_result = "Group4" - assert validate_name(group_names, previous_child_name, new_child_name) == expected_result - - # Test case 2: New child name is already in group names - previous_child_name = "Group1" - new_child_name = "Group2" - expected_result = "Group2 (1)" - assert validate_name(group_names, previous_child_name, new_child_name) == expected_result - - # Test case 3: New child name is empty - previous_child_name = "Group1" - new_child_name = "" - expected_result = "Group1" - assert validate_name(group_names, previous_child_name, new_child_name) == expected_result - - # Test case 4: Previous child name is not in group names - previous_child_name = "Group4" - new_child_name = "Group5" - expected_result = "Group5" - assert validate_name(group_names, previous_child_name, new_child_name) == expected_result - - # Test case 5: Previous child name is empty - previous_child_name = "" - new_child_name = "Group6" - expected_result = "Group6" - assert validate_name(group_names, previous_child_name, new_child_name) == expected_result \ No newline at end of file diff --git a/src/napari_activelearning/_utils.py b/src/napari_activelearning/_utils.py index 0eca2bb..16dd274 100644 --- a/src/napari_activelearning/_utils.py +++ b/src/napari_activelearning/_utils.py @@ -396,7 +396,11 @@ def downsample_image(z_root, source_axes, data_group, scale=4, num_scales=5, source_arr = da.from_zarr(z_root[data_group]) z_ms = [source_arr] - data_group = "/".join(data_group.split("/")[:-1]) + if data_group is None: + data_group = "" + else: + data_group = "/".join(data_group.split("/")[:-1]) + groups_root = data_group + "/%i" source_arr_shape = {ax: source_arr.shape[source_axes.index(ax)] @@ -478,15 +482,22 @@ def get_source_data(layer: Layer): data_group = "" if input_filename: + input_filename = Path(input_filename) + input_filename_parts = input_filename.parts + extension_idx = list(filter(lambda idx: + ".zarr" in input_filename_parts[idx], + range(len(input_filename_parts)))) + if extension_idx: + extension_idx = extension_idx[0] + data_group = str(Path(*input_filename_parts[extension_idx + 1:])) + input_filename = Path(*input_filename_parts[:extension_idx + 1]) + input_filename = str(input_filename) - data_group = "/".join(input_filename.split(".")[-1].split("/")[1:]) + else: return layer.data, None - if data_group: - input_filename = input_filename[:-len(data_group) - 1] - - if input_filename and isinstance(layer.data, MultiScaleData): + if input_filename and isinstance(layer.data, (MultiScaleData, list)): data_group = str(Path(data_group) / "0") if not input_filename: