Skip to content

Commit

Permalink
Working on testing _acquisition.py module
Browse files Browse the repository at this point in the history
  • Loading branch information
fercer committed Jul 29, 2024
1 parent 5f88b64 commit 66a7d1f
Show file tree
Hide file tree
Showing 5 changed files with 538 additions and 115 deletions.
128 changes: 126 additions & 2 deletions src/napari_activelearning/_tests/test_acquisition.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import pytest
import shutil
from unittest.mock import MagicMock, patch
from pathlib import Path
import numpy as np
import zarr

from napari_activelearning._acquisition import (AcquisitionFunction,
add_multiscale_output_layer)
from napari_activelearning._layers import LayerChannel
from napari_activelearning._layers import LayerChannel, ImageGroup, LayersGroup
from napari_activelearning._models import SimpleTunable

try:
import torch
Expand Down Expand Up @@ -183,3 +183,127 @@ def test_add_multiscale_output_layer(output_array, image_group,
)

assert isinstance(output_channel, LayerChannel)


def test_prepare_datasets_metadata(acquisition_function):
# Define the input parameters for the method
image_group = ImageGroup()
output_axes = "TCZYX"
displayed_source_axes = "TCZYX"
displayed_shape = [1, 3, 10, 10, 10]
layer_types = [(LayersGroup(), "images")]

# Call the method
dataset_metadata, sampling_positions = acquisition_function._prepare_datasets_metadata(
image_group,
output_axes,
displayed_source_axes,
displayed_shape,
layer_types
)

# Assert the output values
assert dataset_metadata == {
"images": {
"filenames": None,
"data_group": None,
"source_axes": "TCZYX",
"axes": "TCZYX",
"roi": [[slice(None), slice(None), slice(None), slice(None), slice(None)]],
"modality": "images"
}
}
assert sampling_positions is None


def test_compute_acquisition_layers(acquisition_function):
# Mock the necessary dependencies and setup the test data
image_group = MagicMock()
acquisition_function.image_groups_manager.groups_root.childCount.return_value = 1
acquisition_function.image_groups_manager.groups_root.child.return_value = image_group
acquisition_function.image_groups_manager.groups_root.child.return_value.getSelected.return_value = True
acquisition_function.image_groups_manager.groups_root.child.return_value.group_name = "test_group"
acquisition_function.image_groups_manager.groups_root.child.return_value.group_dir = "/path/to/group"
acquisition_function.image_groups_manager.groups_root.child.return_value.input_layers_group = 0
acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.source_axes = "TCZYX"
acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.shape = (1, 3, 10, 10, 10)
acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.scale = (1.0, 1.0, 1.0, 1.0, 1.0)
acquisition_function.image_groups_manager.groups_root.child.return_value.sampling_mask_layers_group = None
acquisition_function._prepare_datasets_metadata.return_value = ({}, None)

# Call the method under test
acquisition_function.compute_acquisition_layers(run_all=True, segmentation_group_name="segmentation", segmentation_only=False)

# Assert that the necessary methods were called with the expected arguments
acquisition_function.image_groups_manager.groups_root.childCount.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.assert_called_once_with(0)
acquisition_function.image_groups_manager.groups_root.child.return_value.getSelected.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.return_value.group_name.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.return_value.group_dir.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.return_value.input_layers_group.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.return_value.child.assert_called_once_with(0)
acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.source_axes.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.shape.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value.scale.assert_called_once()
acquisition_function.image_groups_manager.groups_root.child.return_value.sampling_mask_layers_group.assert_called_once()
acquisition_function._prepare_datasets_metadata.assert_called_once_with(
image_group,
"TCZYX",
"TCZYX",
(1, 3, 10, 10, 10),
[(acquisition_function.image_groups_manager.groups_root.child.return_value.child.return_value, "images"),
(acquisition_function.image_groups_manager.groups_root.child.return_value.sampling_mask_layers_group, "masks")]
)


@pytest.fixture
def tunable_method():
return SimpleTunable()


def test_fine_tune(tunable_method):
dataset_metadata_list = [
(
{
"images": {
"filenames": ["image1.tif", "image2.tif"],
"data_group": "data",
"source_axes": "YXC",
"axes": "YXC",
"roi": None,
"modality": "images"
}
},
[
[1, 2, 3],
[4, 5, 6]
]
),
(
{
"images": {
"filenames": ["image3.tif", "image4.tif"],
"data_group": "data",
"source_axes": "YXC",
"axes": "YXC",
"roi": None,
"modality": "images"
}
},
[
[7, 8, 9],
[10, 11, 12]
]
)
]

train_data_proportion = 0.8
patch_sizes = 256
model_axes = "YXC"

tunable_method.fine_tune(dataset_metadata_list, train_data_proportion, patch_sizes, model_axes)

# Add assertions here to verify the behavior of the fine_tune method
assert tunable_method.train_data_proportion == train_data_proportion
assert tunable_method.patch_sizes == patch_sizes
assert tunable_method.model_axes == model_axes
2 changes: 1 addition & 1 deletion src/napari_activelearning/_tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def layer_channel(sample_layer):


def test_initialization(layer_channel, sample_layer):
assert layer_channel.layer == sample_layer, print(id(sample_layer), id(layer_channel.layer))
assert layer_channel.layer == sample_layer
assert layer_channel.channel == 1
assert layer_channel.source_axes == "TZYX"
assert layer_channel.name == "sample_layer"
Expand Down
Loading

0 comments on commit 66a7d1f

Please sign in to comment.