From 8ca4afc3dc1f675a3b19b702555805c17dce0190 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 29 Mar 2024 14:46:38 -0400 Subject: [PATCH 01/10] Changed PatchSampler to take as base the patche size instead of the input image's chunk sizes --- .gitignore | 1 + zarrdataset/_samplers.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 5bc0292..51e576d 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +example.py # Translations *.mo diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index bff3d6a..b22bdea 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -130,7 +130,7 @@ def _compute_grid(self, chunk_mask: np.ndarray, # with respect to the input image, use the mask coordinates as # reference to overlap the coordinates of the sampling patches. # Otherwise, use the patches coordinates instead. - if all(map(operator.ge, patch_shape, mask_relative_shape)): + if all(map(operator.gt, patch_shape, mask_relative_shape)): active_coordinates = np.nonzero(chunk_mask) ref_axes = mask_axes @@ -138,9 +138,9 @@ def _compute_grid(self, chunk_mask: np.ndarray, shape = mask_relative_shape mask_is_greater = False - + patch_ratio = [ - image_size[ax] // ps + round(image_size[ax] / ps) for ax, ps in zip(mask_axes, patch_shape.astype(np.int64)) if ax in self.spatial_axes ] @@ -153,7 +153,7 @@ def _compute_grid(self, chunk_mask: np.ndarray, else: active_coordinates = np.meshgrid( - *[np.arange(image_size[ax] // ps) + *[np.arange(round(image_size[ax] / ps)) for ax, ps in zip(mask_axes, patch_shape) if ax in self.spatial_axes] ) @@ -285,7 +285,7 @@ def compute_chunks(self, mask = image_collection.collection[image_collection.mask_mode] spatial_chunk_sizes = dict( - (ax, chk) + (ax, self._patch_size[ax] * round(chk / self._patch_size[ax])) for ax, chk in zip(image.axes, image.chunk_size) if ax in self.spatial_axes ) From 4720c7642cbbdc7562377fadcf85d3cee356b148 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 29 Mar 2024 16:46:02 -0400 Subject: [PATCH 02/10] Reverted change in the computation when masks elements are relative smaller than patch sizes --- zarrdataset/_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index b22bdea..a53aa6f 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -153,7 +153,7 @@ def _compute_grid(self, chunk_mask: np.ndarray, else: active_coordinates = np.meshgrid( - *[np.arange(round(image_size[ax] / ps)) + *[np.arange(image_size[ax] // ps) for ax, ps in zip(mask_axes, patch_shape) if ax in self.spatial_axes] ) From 9e1e985f53a037ef09e1dc795cb50f522523529f Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 5 Apr 2024 11:55:35 -0400 Subject: [PATCH 03/10] Fixed spatial chunk size computation when patch sizes are grater than the chunk size --- zarrdataset/_samplers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index a53aa6f..2d0db69 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -285,7 +285,8 @@ def compute_chunks(self, mask = image_collection.collection[image_collection.mask_mode] spatial_chunk_sizes = dict( - (ax, self._patch_size[ax] * round(chk / self._patch_size[ax])) + (ax, + self._patch_size[ax] * max(1, round(chk / self._patch_size[ax]))) for ax, chk in zip(image.axes, 
image.chunk_size) if ax in self.spatial_axes ) From d7fb5f5ddb5cd7cb90fecfeb42c78675442abd6e Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Tue, 30 Apr 2024 10:51:24 -0400 Subject: [PATCH 04/10] Fixed missing patches from chunks smaller than the input image chunk size --- tests/test_samplers.py | 5 +- zarrdataset/_samplers.py | 105 ++++++++++++++++++++---------------- zarrdataset/_zarrdataset.py | 6 +-- 3 files changed, 65 insertions(+), 51 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 24c97c7..d0e5a9b 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -276,7 +276,10 @@ def test_BlueNoisePatchSampler_mask_not2scale(image_collection_mask_not2scale): chunk_tlbr=chunks_toplefts[0] ) - assert len(patches_toplefts) == 0, \ + # Samples can be retrieved from chunks that are not multiple of the patch + # size. The ZarrDataset class should handle these cases, either by droping + # these patches, or by adding padding when allowed by the user. + assert len(patches_toplefts) == 1, \ (f"Expected 0 patches, got {len(patches_toplefts)} instead.") diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index 2d0db69..56ffaca 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -38,10 +38,10 @@ class PatchSampler(object): """ def __init__(self, patch_size: Union[int, Iterable[int], dict], min_area: Union[int, float] = 1, - spatial_axes: str ="ZYX"): + spatial_axes: str = "ZYX"): # The maximum chunk sizes are used to generate a reference sampling # position array used fo every sampled chunk. - self._max_chunk_size = dict((ax, 0) for ax in spatial_axes) + self._max_chunk_size = {ax: 0 for ax in spatial_axes} if isinstance(patch_size, (list, tuple)): if len(patch_size) != len(spatial_axes): @@ -49,11 +49,10 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], f"number of axes in `spatial_axes`, got " f"{len(patch_size)} for {spatial_axes}") - patch_size = dict((ax, ps) - for ax, ps in zip(spatial_axes, patch_size)) + patch_size = {ax: ps for ax, ps in zip(spatial_axes, patch_size)} elif isinstance(patch_size, int): - patch_size = dict((ax, patch_size) for ax in spatial_axes) + patch_size = {ax: patch_size for ax in spatial_axes} elif not isinstance(patch_size, dict): raise ValueError(f"Patch size must be a dictionary specifying the" @@ -64,14 +63,13 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], self.spatial_axes = spatial_axes - self._patch_size = dict( - (ax, patch_size.get(ax, 1)) - for ax in spatial_axes - ) + self._patch_size = {ax: patch_size.get(ax, 1) for ax in spatial_axes} self._min_area = min_area - def _compute_corners(self, non_zero_pos: tuple, axes: str) -> np.ndarray: + def _compute_corners(self, non_zero_pos: tuple, axes: str, + limits_per_dim: Union[np.ndarray, None] = None, + ) -> np.ndarray: toplefts = np.stack(non_zero_pos).T toplefts = toplefts.astype(np.float32) @@ -80,14 +78,18 @@ def _compute_corners(self, non_zero_pos: tuple, axes: str) -> np.ndarray: dim = len(axes) factors = 2 ** np.arange(dim + 1) for d in range(2 ** dim): - corner_value = np.array((d % factors[1:]) // factors[:-1], - dtype=np.float32) + corner_value = np.array((d % factors[1:]) // factors[:-1], + dtype=np.float32) toplefts_corners.append( toplefts + (1 - 1e-4) * corner_value ) corners = np.stack(toplefts_corners) + if limits_per_dim is not None: + corners = np.minimum(corners, + limits_per_dim[None, None, ...] 
- 1e-4) + return corners def _compute_overlap(self, corners: np.ndarray, shape: np.ndarray, @@ -99,9 +101,12 @@ def _compute_overlap(self, corners: np.ndarray, shape: np.ndarray, corners_cut = np.maximum(tls_scaled[0], tls_idx[-1]) - dist2cut = np.fabs(scaled_corners - corners_cut[None]) + dist2cut = np.fabs(corners - corners_cut[None]) coverage = np.prod(dist2cut, axis=-1) + # Scale the coverage to the size of the input shape + coverage *= np.prod(shape) + return coverage, tls_idx.astype(np.int64) def _compute_grid(self, chunk_mask: np.ndarray, @@ -114,7 +119,7 @@ def _compute_grid(self, chunk_mask: np.ndarray, [1 / m_scl for ax, m_scl in mask_scale.items() if ax in self.spatial_axes - ], + ], dtype=np.float32 ) @@ -122,7 +127,7 @@ def _compute_grid(self, chunk_mask: np.ndarray, [patch_size[ax] for ax in mask_axes if ax in self.spatial_axes - ], + ], dtype=np.float32 ) @@ -132,6 +137,8 @@ def _compute_grid(self, chunk_mask: np.ndarray, # Otherwise, use the patches coordinates instead. if all(map(operator.gt, patch_shape, mask_relative_shape)): active_coordinates = np.nonzero(chunk_mask) + limits_per_dim = np.array(chunk_mask.shape) + 1 + ref_axes = mask_axes ref_shape = patch_shape @@ -153,11 +160,17 @@ def _compute_grid(self, chunk_mask: np.ndarray, else: active_coordinates = np.meshgrid( - *[np.arange(image_size[ax] // ps) + *[np.arange(math.ceil(image_size[ax] / ps)) for ax, ps in zip(mask_axes, patch_shape) if ax in self.spatial_axes] ) + limits_per_dim = np.array([ + image_size[ax] / ps + for ax, ps in zip(mask_axes, patch_shape) + if ax in self.spatial_axes + ]) + active_coordinates = tuple( coord_ax.flatten() for coord_ax in active_coordinates @@ -172,10 +185,11 @@ def _compute_grid(self, chunk_mask: np.ndarray, mask_is_greater = True - corners = self._compute_corners(active_coordinates, axes=ref_axes) + corners = self._compute_corners(active_coordinates, axes=ref_axes, + limits_per_dim=limits_per_dim) (coverage, - corners_idx)= self._compute_overlap(corners, shape, ref_shape) + corners_idx) = self._compute_overlap(corners, shape, ref_shape) if mask_is_greater: # The mask ratio is greater than the patches size @@ -247,15 +261,15 @@ def _compute_toplefts_slices(self, mask: ImageBase, image_shape: dict, br = ((chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None else 0) + tls[mask.axes.index(ax)] + patch_size[ax]) - if br <= image_shape[ax]: - curr_tl.append((ax, slice(tl, br))) - else: - break + + curr_tl.append((ax, slice(tl, + br if br <= image_shape[ax] + else image_shape[ax]))) + else: curr_tl.append((ax, slice(0, 1))) - else: - toplefts.append(dict(curr_tl)) + toplefts.append(dict(curr_tl)) return toplefts @@ -284,27 +298,26 @@ def compute_chunks(self, image = image_collection.collection[image_collection.reference_mode] mask = image_collection.collection[image_collection.mask_mode] - spatial_chunk_sizes = dict( - (ax, - self._patch_size[ax] * max(1, round(chk / self._patch_size[ax]))) + # This computes a chunk size in terms of the patch size instead of the + # original array chunk size. 
+ spatial_chunk_sizes = { + ax: (self._patch_size[ax] + * max(1, math.ceil(chk / self._patch_size[ax]))) for ax, chk in zip(image.axes, image.chunk_size) if ax in self.spatial_axes - ) + } - image_shape = dict(map(tuple, zip(image.axes, image.shape))) + image_shape = {ax: s for ax, s in zip(image.axes, image.shape)} - self._max_chunk_size = dict( - (ax, (min(max(self._max_chunk_size[ax], - spatial_chunk_sizes[ax], - self._patch_size[ax]), - image_shape[ax])) - if ax in image.axes else 1) + self._max_chunk_size = { + ax: (min(max(self._max_chunk_size[ax], + spatial_chunk_sizes[ax]), + image_shape[ax])) + if ax in image.axes else 1 for ax in self.spatial_axes - ) + } - chunk_tlbr = dict( - map(tuple, zip(self.spatial_axes, repeat(slice(None)))) - ) + chunk_tlbr = {ax: slice(None) for ax in self.spatial_axes} chunk_mask = mask[chunk_tlbr] @@ -330,14 +343,12 @@ def compute_patches(self, image_collection: ImageCollection, chunk_tlbr: dict) -> Iterable[dict]: image = image_collection.collection[image_collection.reference_mode] mask = image_collection.collection[image_collection.mask_mode] - image_shape = dict(map(tuple, zip(image.axes, image.shape))) - chunk_size = dict( - (ax, - (ctb.stop if ctb.stop is not None else image_shape[ax]) - - (ctb.start if ctb.start is not None else 0) - ) + image_shape = {ax: s for ax, s in zip(image.axes, image.shape)} + chunk_size = { + ax: ((ctb.stop if ctb.stop is not None else image_shape[ax]) + - (ctb.start if ctb.start is not None else 0)) for ax, ctb in chunk_tlbr.items() - ) + } chunk_mask = mask[chunk_tlbr] @@ -388,7 +399,7 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], self._base_chunk_tls = None self._resample_positions = resample_positions self._allow_overlap = allow_overlap - + def compute_sampling_positions(self, force=False) -> None: """Compute the sampling positions using blue-noise sampling. 
diff --git a/zarrdataset/_zarrdataset.py b/zarrdataset/_zarrdataset.py index 34a1082..2b6cd4c 100644 --- a/zarrdataset/_zarrdataset.py +++ b/zarrdataset/_zarrdataset.py @@ -509,7 +509,7 @@ def _initialize(self, force=False): modes = self._collections.keys() for collection in zip(*self._collections.values()): - collection = dict([(m, c) for m, c in zip(modes, collection)]) + collection = {m: c for m, c in zip(modes, collection)} for mode in collection.keys(): collection[mode]["zarr_store"] = self._zarr_store[mode] collection[mode]["image_func"] = self._image_loader_func[mode] @@ -527,8 +527,8 @@ def _initialize(self, force=False): toplefts.append(self._patch_sampler.compute_chunks(curr_img)) else: toplefts.append([ - dict((ax, slice(None)) - for ax in curr_img.collection[self._ref_mod].axes) + {ax: slice(None) + for ax in curr_img.collection[self._ref_mod].axes} ] ) From 18263786150e84a8d85752ccfb9c21ef3e510aad Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Wed, 1 May 2024 16:26:04 -0400 Subject: [PATCH 05/10] Padding and stride added to PatchSampler and ImageBase classes to allow overlapping patches extraction --- tests/test_imageloaders.py | 23 +++++++ tests/test_samplers.py | 115 ++++++++++++++++++++++++++++++----- zarrdataset/_imageloaders.py | 65 +++++++++++--------- zarrdataset/_samplers.py | 86 ++++++++++++++++++++++---- 4 files changed, 233 insertions(+), 56 deletions(-) diff --git a/tests/test_imageloaders.py b/tests/test_imageloaders.py index 366861b..b996140 100644 --- a/tests/test_imageloaders.py +++ b/tests/test_imageloaders.py @@ -163,6 +163,29 @@ def test_ImageBase_slicing(): f"{expected_selection_shape}, got {img_sel_2.shape} instead") +def test_ImageBase_padding(): + shape = (16, 16, 3) + axes = "YXC" + img = zds.ImageBase(shape, chunk_size=None, source_axes=axes, mode="image") + + random.seed(44512) + selection_1 = dict( + (ax, slice(random.randint(-10, 0), + random.randint(1, r_s + 10))) + for ax, r_s in zip(axes, shape) + ) + + expected_selection_shape = tuple( + selection_1[ax].stop - selection_1[ax].start for ax in axes + ) + + img_sel_1 = img[selection_1] + + assert img_sel_1.shape == expected_selection_shape, \ + (f"Expected selection {selection_1} to have shape " + f"{expected_selection_shape}, got {img_sel_1.shape} instead") + + @pytest.mark.parametrize("axes, roi, expected_size", [ (None, None, (16, 16, 3)), (None, slice(None), (16, 16, 3)), diff --git a/tests/test_samplers.py b/tests/test_samplers.py index d0e5a9b..23bbbfa 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -121,7 +121,7 @@ def test_PatchSampler_correct_patch_size(patch_size, spatial_axes, expected_patch_size): patch_sampler = zds.PatchSampler(patch_size=patch_size, spatial_axes=spatial_axes) - + assert patch_sampler._patch_size == expected_patch_size, \ (f"Expected `patch_size` to be a dictionary as {expected_patch_size}, " f"got {patch_sampler._patch_size} instead.") @@ -146,18 +146,18 @@ def test_PatchSampler_chunk_generation(patch_size, image_collection): chunks_toplefts = patch_sampler.compute_chunks(image_collection) - chunk_size = dict( - (ax, cs) + chunk_size = { + ax: cs for ax, cs in zip(image_collection.collection["images"].axes, image_collection.collection["images"].chunk_size) - ) + } - scaled_chunk_size = dict( - (ax, int(cs * image_collection.collection["masks"].scale[ax])) + scaled_chunk_size = { + ax: int(cs * image_collection.collection["masks"].scale[ax]) for ax, cs in zip(image_collection.collection["images"].axes, 
image_collection.collection["images"].chunk_size) if ax in image_collection.collection["masks"].axes - ) + } scaled_mask = transform.downscale_local_mean( image_collection.collection["masks"][:], @@ -194,11 +194,53 @@ def test_PatchSampler(patch_size, image_collection): chunk_tlbr=chunks_toplefts[0] ) - scaled_patch_size = dict( - (ax, int(patch_size * scl)) + scaled_patch_size = { + ax: int(patch_size * scl) for ax, scl in image_collection.collection["masks"].scale.items() + } + + scaled_mask = transform.downscale_local_mean( + image_collection.collection["masks"][chunks_toplefts[0]], + factors=(scaled_patch_size["Y"], scaled_patch_size["X"]) + ) + expected_patches_toplefts = np.nonzero(scaled_mask) + + expected_patches_toplefts = [ + dict( + [("Z", slice(0, 1, None))] + + [ + (ax, slice(tl * patch_size, (tl + 1) * patch_size)) + for ax, tl in zip("YX", tls) + ] + ) + for tls in zip(*expected_patches_toplefts) + ] + + assert all(map(operator.eq, patches_toplefts, expected_patches_toplefts)),\ + (f"Expected patches to be {expected_patches_toplefts[:3]}, got " + f"{patches_toplefts[:3]} instead.") + + +@pytest.mark.parametrize("patch_size, stride, image_collection", [ + (32, 32, IMAGE_SPECS[10]), + (32, 16, IMAGE_SPECS[10]), + (32, 64, IMAGE_SPECS[10]), +], indirect=["image_collection"]) +def test_PatchSampler_stride(patch_size, stride, image_collection): + patch_sampler = zds.PatchSampler(patch_size, stride=stride) + + chunks_toplefts = patch_sampler.compute_chunks(image_collection) + + patches_toplefts = patch_sampler.compute_patches( + image_collection, + chunk_tlbr=chunks_toplefts[0] ) + scaled_patch_size = { + ax: int(stride * scl) + for ax, scl in image_collection.collection["masks"].scale.items() + } + scaled_mask = transform.downscale_local_mean( image_collection.collection["masks"][chunks_toplefts[0]], factors=(scaled_patch_size["Y"], scaled_patch_size["X"]) @@ -209,8 +251,48 @@ def test_PatchSampler(patch_size, image_collection): dict( [("Z", slice(0, 1, None))] + [ - (ax, slice(int(tl * patch_size), - int(math.ceil((tl + 1) * patch_size)))) + (ax, slice(tl * stride, tl * stride + patch_size)) + for ax, tl in zip("YX", tls) + ] + ) + for tls in zip(*expected_patches_toplefts) + ] + assert all(map(operator.eq, patches_toplefts, expected_patches_toplefts)),\ + (f"Expected patches to be {expected_patches_toplefts[:3]}, got " + f"{patches_toplefts[:3]} instead.") + + +@pytest.mark.parametrize("patch_size, pad, image_collection", [ + (32, 0, IMAGE_SPECS[10]), + (32, 2, IMAGE_SPECS[10]), +], indirect=["image_collection"]) +def test_PatchSampler_pad(patch_size, pad, image_collection): + patch_sampler = zds.PatchSampler(patch_size, pad=pad) + + chunks_toplefts = patch_sampler.compute_chunks(image_collection) + + patches_toplefts = patch_sampler.compute_patches( + image_collection, + chunk_tlbr=chunks_toplefts[0] + ) + + scaled_patch_size = { + ax: int(patch_size * scl) + for ax, scl in image_collection.collection["masks"].scale.items() + } + + scaled_mask = transform.downscale_local_mean( + image_collection.collection["masks"][chunks_toplefts[0]], + factors=(scaled_patch_size["Y"], scaled_patch_size["X"]) + ) + expected_patches_toplefts = np.nonzero(scaled_mask) + + # TODO: Change expected patches toplefts for strided ones + expected_patches_toplefts = [ + dict( + [("Z", slice(0, 1, None))] + + [ + (ax, slice(tl * patch_size - pad, (tl + 1) * patch_size + pad)) for ax, tl in zip("YX", tls) ] ) @@ -283,10 +365,12 @@ def 
test_BlueNoisePatchSampler_mask_not2scale(image_collection_mask_not2scale): (f"Expected 0 patches, got {len(patches_toplefts)} instead.") -@pytest.mark.parametrize("patch_size, image_collection, specs", [ - (512, MASKABLE_IMAGE_SPECS[0], MASKABLE_IMAGE_SPECS[0]) +@pytest.mark.parametrize("patch_size, stride, image_collection, specs", [ + (512, 512, MASKABLE_IMAGE_SPECS[0], MASKABLE_IMAGE_SPECS[0]), + (512, 256, MASKABLE_IMAGE_SPECS[0], MASKABLE_IMAGE_SPECS[0]) ], indirect=["image_collection"]) -def test_unique_sampling_PatchSampler(patch_size, image_collection, specs): +def test_unique_sampling_PatchSampler(patch_size, stride, image_collection, + specs): from skimage import color, filters, morphology import zarr @@ -307,7 +391,8 @@ def test_unique_sampling_PatchSampler(patch_size, image_collection, specs): mode="masks") image_collection.reset_scales() - patch_sampler = zds.PatchSampler(patch_size, min_area=1/16 ** 2) + patch_sampler = zds.PatchSampler(patch_size, stride=stride, + min_area=1/16 ** 2) chunks_toplefts = patch_sampler.compute_chunks(image_collection) diff --git a/zarrdataset/_imageloaders.py b/zarrdataset/_imageloaders.py index 5d92dc1..06f6b56 100644 --- a/zarrdataset/_imageloaders.py +++ b/zarrdataset/_imageloaders.py @@ -89,7 +89,7 @@ def image2array(arr_src: Union[str, zarr.Group, zarr.Array, np.ndarray], arr = zarr.array(data=arr_src, shape=arr_src.shape, chunks=arr_src.shape) return arr, None - + # Try to create a connection with the file, to determine if it is a remote # resource or local file. s3_obj = connect_s3(arr_src) @@ -98,7 +98,7 @@ def image2array(arr_src: Union[str, zarr.Group, zarr.Array, np.ndarray], # Try to open the input file with tifffile (if installed). try: if (data_group is None - or (isinstance(data_group, str) and not len(data_group))): + or (isinstance(data_group, str) and not len(data_group))): tiff_args = dict( key=None, level=None, @@ -192,9 +192,9 @@ class ImageBase(object): _image_func = None def __init__(self, shape: Iterable[int], - chunk_size: Union[Iterable[int], None]=None, - source_axes: str="", - mode: str=""): + chunk_size: Union[Iterable[int], None] = None, + source_axes: str = "", + mode: str = ""): if chunk_size is None: chunk_size = shape @@ -204,6 +204,7 @@ def __init__(self, shape: Iterable[int], self.arr = zarr.ones(shape=shape, dtype=bool, chunks=chunk_size) self.roi = tuple([slice(None)] * len(source_axes)) self.mode = mode + self._chunk_size = chunk_size def _iscached(self, coords): @@ -224,17 +225,30 @@ def _cache_chunk(self, index): if not self._iscached(index): self._cached_coords = tuple( map(lambda i, chk, s: - slice(chk * int(i.start / chk) + slice(max(0, chk * int(i.start / chk)) if i.start is not None else 0, min(s, chk * int(math.ceil(i.stop / chk))) - if i.stop is not None else None, - None), + if i.stop is not None else s), index, self.arr.chunks, self.arr.shape) ) + + padding = tuple( + (cc.start - i.start if i.start is not None and i.start < 0 else 0, + i.stop - cc.stop if i.stop is not None and i.stop > s else 0) + for cc, i, s in zip(self._cached_coords, index, self.arr.shape) + ) + self._cache = self.arr[self._cached_coords] + if any([any(pad) for pad in padding]): + self._cache = np.pad(self._cache, padding) + self._cached_coords = tuple( + slice(cc.start - p_low, cc.stop + p_high) + for (p_low, p_high), cc in zip(padding, self._cached_coords) + ) + cached_index = tuple( map(lambda cache, i: slice((i.start - cache.start) if i.start is not None else 0, @@ -257,20 +271,14 @@ def __getitem__(self, index : 
Union[slice, tuple, dict]) -> np.ndarray: if not isinstance(index, dict): # Arrange the indices requested using the reference image axes # ordering. - index = dict( - ((ax, sel) - for ax, sel in zip(spatial_reference_axes, index)) - ) + index = {ax: sel for ax, sel in zip(spatial_reference_axes, index)} mode_index, _ = select_axes(self.axes, index) mode_scales = tuple(self.scale[ax] for ax in self.axes) mode_index = scale_coords(mode_index, mode_scales) - mode_index = dict( - ((ax, sel) - for ax, sel in zip(self.axes, mode_index)) - ) + mode_index = {ax: sel for ax, sel in zip(self.axes, mode_index)} # Locate the mode_index within the ROI: roi_mode_index = translate2roi(mode_index, self.roi, self.source_axes, @@ -430,7 +438,7 @@ def __init__(self, filename: str, source_axes: str, parsed_roi = roi elif isinstance(roi, slice): if (len(source_axes) > 1 - and not (roi.start is None and roi.stop is None)): + and not (roi.start is None and roi.stop is None)): raise ValueError(f"ROIs must specify a slice per axes. " f"Expected {len(source_axes)} slices, got " f"only {roi}") @@ -440,11 +448,10 @@ def __init__(self, filename: str, source_axes: str, raise ValueError(f"Incorrect ROI format, expected a list of " f"slices, or a parsable string, got {roi}") - roi_slices = list( - map(lambda r: - slice(r.start if r.start is not None else 0, r.stop, None), - parsed_roi) - ) + roi_slices = [ + slice(r.start if r.start is not None else 0, r.stop, None) + for r in parsed_roi + ] (self.arr, self._store) = image2array(filename, data_group=data_group, @@ -512,11 +519,11 @@ def __init__(self, collection_args : dict, self.spatial_axes = spatial_axes - self.collection = dict(( - (mode, ImageLoader(spatial_axes=spatial_axes, mode=mode, - **mode_args)) + self.collection = { + mode: ImageLoader(spatial_axes=spatial_axes, mode=mode, + **mode_args) for mode, mode_args in collection_args.items() - )) + } self._generate_mask() self.reset_scales() @@ -574,9 +581,7 @@ def reset_scales(self) -> None: img.rescale(spatial_reference_shape, spatial_reference_axes) def __getitem__(self, index): - collection_set = dict( - (mode, img[index]) - for mode, img in self.collection.items() - ) + collection_set = {mode: img[index] + for mode, img in self.collection.items()} return collection_set diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index 56ffaca..d1a4ea0 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -28,6 +28,14 @@ class PatchSampler(object): of the patch of the axes listed in `spatial_axes`. Use the same convention as how Zarr structure array chunks in order to handle patch shapes and channels correctly. + stride : Union[int, Iterable[int], dict, None] + Distance in pixels of the movement of the sampling sliding window. + If `stride` is less than `patch_size` for an axis, patches will have an + overlap between them. This is usuful in inference mode for avoiding + edge artifacts. If None is passed, the `patch_size` will be used as + `stride`. + pad : Union[int, Iterable[int], dict, None] + Padding in pixels added to the extracted patch at each specificed axis. min_area : Union[int, float] Minimum patch area covered by the mask to consider it samplable. A number in range [0, 1) will be used as percentage of the patch size. A @@ -37,6 +45,8 @@ class PatchSampler(object): The spatial axes from where patches can be extracted. 
""" def __init__(self, patch_size: Union[int, Iterable[int], dict], + stride: Union[int, Iterable[int], dict, None] = None, + pad: Union[int, Iterable[int], dict, None] = None, min_area: Union[int, float] = 1, spatial_axes: str = "ZYX"): # The maximum chunk sizes are used to generate a reference sampling @@ -61,9 +71,53 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], f" or an integer for a cubic patch. Received " f"{patch_size} of type {type(patch_size)}") + if isinstance(stride, (list, tuple)): + if len(stride) != len(spatial_axes): + raise ValueError(f"The size of `stride` must match the " + f"number of axes in `spatial_axes`, got " + f"{len(stride)} for {spatial_axes}") + + stride = {ax: st for ax, st in zip(spatial_axes, stride)} + + elif isinstance(stride, int): + stride = {ax: stride for ax in spatial_axes} + + elif stride is None: + stride = patch_size + + elif not isinstance(stride, dict): + raise ValueError(f"Stride size must be a dictionary specifying the" + f" stride step size of each axes, an iterable (" + f"list, tuple) with the same order as the spatial" + f" axes, or an integer for a cubic patch. " + f"Received {stride} of type {type(stride)}") + + if pad is None: + pad = 0 + + if isinstance(pad, (list, tuple)): + if len(pad) != len(spatial_axes): + raise ValueError(f"The size of `pad` must match the " + f"number of axes in `spatial_axes`, got " + f"{len(pad)} for {spatial_axes}") + + pad = {ax: st for ax, st in zip(spatial_axes, pad)} + + elif isinstance(pad, int): + pad = {ax: pad for ax in spatial_axes} + + elif not isinstance(pad, dict): + raise ValueError(f"Pad size must be a dictionary specifying the" + f" numer of pixels added to each axes, an " + f"iterable (list, tuple) with the same order as " + f"the spatial axes, or an integer for a cubic " + f"patch. Received {pad} of type {type(pad)}") + self.spatial_axes = spatial_axes self._patch_size = {ax: patch_size.get(ax, 1) for ax in spatial_axes} + self._stride = {ax: stride.get(ax, 1) for ax in spatial_axes} + self._pad = {ax: pad.get(ax, 0) for ax in spatial_axes} self._min_area = min_area @@ -114,7 +168,6 @@ def _compute_grid(self, chunk_mask: np.ndarray, mask_scale: dict, patch_size: dict, image_size: dict): - mask_relative_shape = np.array( [1 / m_scl for ax, m_scl in mask_scale.items() @@ -248,7 +301,11 @@ def _compute_valid_toplefts(self, chunk_mask: np.ndarray, mask_axes: str, def _compute_toplefts_slices(self, mask: ImageBase, image_shape: dict, patch_size: dict, valid_mask_toplefts: np.ndarray, - chunk_tlbr: dict): + chunk_tlbr: dict, + pad: Union[dict, None] = None): + if pad is None: + pad = {ax: 0 for ax in self.spatial_axes} + toplefts = [] for tls in valid_mask_toplefts: curr_tl = [] @@ -262,9 +319,10 @@ def _compute_toplefts_slices(self, mask: ImageBase, image_shape: dict, if chunk_tlbr[ax].start is not None else 0) + tls[mask.axes.index(ax)] + patch_size[ax]) - curr_tl.append((ax, slice(tl, - br if br <= image_shape[ax] - else image_shape[ax]))) + curr_tl.append((ax, slice(tl - pad[ax], + (br if br <= image_shape[ax] + else image_shape[ax]) + + pad[ax]))) else: curr_tl.append((ax, slice(0, 1))) @@ -301,11 +359,17 @@ def compute_chunks(self, # This computes a chunk size in terms of the patch size instead of the # original array chunk size. 
spatial_chunk_sizes = { - ax: (self._patch_size[ax] - * max(1, math.ceil(chk / self._patch_size[ax]))) + ax: (self._stride[ax] + * max(1, math.ceil(chk / self._stride[ax]))) for ax, chk in zip(image.axes, image.chunk_size) if ax in self.spatial_axes } + # spatial_chunk_sizes = { + # ax: (self._patch_size[ax] + # * max(1, math.ceil(chk / self._patch_size[ax]))) + # for ax, chk in zip(image.axes, image.chunk_size) + # if ax in self.spatial_axes + # } image_shape = {ax: s for ax, s in zip(image.axes, image.shape)} @@ -356,16 +420,16 @@ def compute_patches(self, image_collection: ImageCollection, chunk_mask, mask.axes, mask.scale, - self._patch_size, - chunk_size - ) + self._stride, + chunk_size) patches_slices = self._compute_toplefts_slices( mask, image_shape=image_shape, patch_size=self._patch_size, valid_mask_toplefts=valid_mask_toplefts, - chunk_tlbr=chunk_tlbr + chunk_tlbr=chunk_tlbr, + pad=self._pad ) return patches_slices From b894985a3ca3459d0c8e7a4d906744f0aa73c610 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Thu, 2 May 2024 10:38:54 -0400 Subject: [PATCH 06/10] Added tests for stride and pad parameters of PatchSampler class --- tests/test_samplers.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 23bbbfa..d38ca9d 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -127,6 +127,32 @@ def test_PatchSampler_correct_patch_size(patch_size, spatial_axes, f"got {patch_sampler._patch_size} instead.") +@pytest.mark.parametrize("stride, spatial_axes, expected_stride", [ + (512, "X", dict(X=512)), + ((128, 64), "XY", dict(X=128, Y=64)), +]) +def test_PatchSampler_correct_stride(stride, spatial_axes, expected_stride): + patch_sampler = zds.PatchSampler(patch_size=512, stride=stride, + spatial_axes=spatial_axes) + + assert patch_sampler._stride == expected_stride, \ + (f"Expected `stride` to be a dictionary as {expected_stride}, " + f"got {patch_sampler._stride} instead.") + + +@pytest.mark.parametrize("pad, spatial_axes, expected_pad", [ + (512, "X", dict(X=512)), + ((128, 64), "XY", dict(X=128, Y=64)), +]) +def test_PatchSampler_correct_pad(pad, spatial_axes, expected_pad): + patch_sampler = zds.PatchSampler(patch_size=512, pad=pad, + spatial_axes=spatial_axes) + + assert patch_sampler._pad == expected_pad, \ + (f"Expected `pad` to be a dictionary as {expected_pad}, " + f"got {patch_sampler._pad} instead.") + + @pytest.mark.parametrize("patch_size, spatial_axes", [ ((512, 128), "X"), ((128, ), "XY"), @@ -138,6 +164,30 @@ def test_PatchSampler_incorrect_patch_size(patch_size, spatial_axes): spatial_axes=spatial_axes) +@pytest.mark.parametrize("stride, spatial_axes", [ + ((512, 128), "X"), + ((128, ), "XY"), + ("stride", "ZYX"), +]) +def test_PatchSampler_incorrect_stride(stride, spatial_axes): + with pytest.raises(ValueError): + patch_sampler = zds.PatchSampler(patch_size=512, + stride=stride, + spatial_axes=spatial_axes) + + +@pytest.mark.parametrize("pad, spatial_axes", [ + ((512, 128), "X"), + ((128, ), "XY"), + ("pad", "ZYX"), +]) +def test_PatchSampler_incorrect_pad(pad, spatial_axes): + with pytest.raises(ValueError): + patch_sampler = zds.PatchSampler(patch_size=512, + pad=pad, + spatial_axes=spatial_axes) + + @pytest.mark.parametrize("patch_size, image_collection", [ (32, IMAGE_SPECS[10]) ], indirect=["image_collection"]) From aaa51f67c8e8a3bd546f4bf15c2c043ab716684e Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Thu, 2 May 2024 15:59:57 
-0400 Subject: [PATCH 07/10] Fixed patch slices generation in PatchSampler to always retrieve patches of the defined shape --- zarrdataset/_samplers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index d1a4ea0..f876a74 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -320,9 +320,7 @@ def _compute_toplefts_slices(self, mask: ImageBase, image_shape: dict, + tls[mask.axes.index(ax)] + patch_size[ax]) curr_tl.append((ax, slice(tl - pad[ax], - (br if br <= image_shape[ax] - else image_shape[ax]) - + pad[ax]))) + br + pad[ax]))) else: curr_tl.append((ax, slice(0, 1))) From 78857494b3338fdd369473b4d7b4be19ae4cda90 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Tue, 7 May 2024 16:32:16 -0400 Subject: [PATCH 08/10] Standardized patch sampling method to handle smaller and bigger mask scales than image scale --- .../advanced_example_pytorch_inference.md | 177 +++++++ tests/test_samplers.py | 36 +- tests/test_zarrdataset.py | 63 ++- zarrdataset/_samplers.py | 474 ++++++++++-------- 4 files changed, 498 insertions(+), 252 deletions(-) create mode 100644 docs/source/examples/advanced_example_pytorch_inference.md diff --git a/docs/source/examples/advanced_example_pytorch_inference.md b/docs/source/examples/advanced_example_pytorch_inference.md new file mode 100644 index 0000000..fdeb3f9 --- /dev/null +++ b/docs/source/examples/advanced_example_pytorch_inference.md @@ -0,0 +1,177 @@ +```python +import zarrdataset as zds + +import torch +from torch.utils.data import DataLoader +``` + + +```python +# These are images from the Image Data Resource (IDR) +# https://idr.openmicroscopy.org/ that are publicly available and were +# converted to the OME-NGFF (Zarr) format by the OME group. More examples +# can be found at Public OME-Zarr data (Nov. 2020) +# https://www.openmicroscopy.org/2020/11/04/zarr-data.html + +filenames = [ + "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr" +] +``` + + +```python +import random +import numpy as np + +# For reproducibility +np.random.seed(478963) +torch.manual_seed(478964) +random.seed(478965) +``` + +## Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI) + +Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference result. +Finally, let the PatchSampler retrieve patches from the edge of the image that would be otherwise smaller than the patch size. + + +```python +patch_size = dict(Y=128, X=128) +pad = dict(Y=16, X=16) +patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True) +``` + +Create a dataset from the list of filenames. All those files should be stored within their respective group "0". + +Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly + + +```python +image_specs = zds.ImagesDatasetSpecs( + filenames=filenames, + data_group="4", + source_axes="TCZYX", + axes="YXC", + roi="0,0,0,0,0:1,-1,1,-1,-1" +) + +my_dataset = zds.ZarrDataset(image_specs, + patch_sampler=patch_sampler, + return_positions=True) +``` + + +```python +my_dataset +``` + + + + + ZarrDataset (PyTorch support:True, tqdm support :True) + Modalities: images + Transforms order: [] + Using images modality as reference. + Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}. 
+ + + +Add a pre-processing step before creating the image batches, where the input arrays are casted from int16 to float32. + + +```python +import torchvision + +img_preprocessing = torchvision.transforms.Compose([ + zds.ToDtype(dtype=np.float32), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize(127, 255) +]) + +my_dataset.add_transform("images", img_preprocessing) +``` + + +```python +my_dataset +``` + + + + + ZarrDataset (PyTorch support:True, tqdm support :True) + Modalities: images + Transforms order: [('images',)] + Using images modality as reference. + Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}. + + + +## Create a DataLoader from the dataset object + +ZarrDataset is compatible with DataLoader from PyTorch since it is inherited from the IterableDataset class of the torch.utils.data module. + + +```python +my_dataloader = DataLoader(my_dataset, num_workers=0) +``` + + +```python +import dask.array as da +import numpy as np +import zarr + +z_arr = zarr.open("https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr/4", mode="r") + +H = z_arr.shape[-2] +W = z_arr.shape[-1] + +pad_H = (128 - H) % 128 +pad_W = (128 - W) % 128 +z_prediction = zarr.zeros((H + pad_H, W + pad_W), dtype=np.float32, chunks=(128, 128)) +z_prediction +``` + + + + + + + + +Set up a simple model for illustration purpose + + +```python +model = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1), + torch.nn.ReLU() +) +``` + + +```python +for i, (pos, sample) in enumerate(my_dataloader): + pred_pos = ( + slice(pos[0, 0, 0].item() + 16, + pos[0, 0, 1].item() - 16), + slice(pos[0, 1, 0].item() + 16, + pos[0, 1, 1].item() - 16) + ) + pred = model(sample) + z_prediction[pred_pos] = pred.detach().cpu().numpy()[0, 0, 16:-16, 16:-16] +``` + +## Visualize the result + + +```python +import matplotlib.pyplot as plt + +plt.subplot(2, 1, 1) +plt.imshow(np.moveaxis(z_arr[0, :, 0, ...], 0, -1)) +plt.subplot(2, 1, 2) +plt.imshow(z_prediction) +plt.show() +``` diff --git a/tests/test_samplers.py b/tests/test_samplers.py index d38ca9d..2b14ed0 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -354,6 +354,33 @@ def test_PatchSampler_pad(patch_size, pad, image_collection): f"{patches_toplefts[:3]} instead.") +@pytest.mark.parametrize("patch_size, allow_incomplete_patches," + "image_collection", [ + (1024, True, IMAGE_SPECS[10]), + (1024, False, IMAGE_SPECS[10]), +], indirect=["image_collection"]) +def test_PatchSampler_incomplete_patches(patch_size, allow_incomplete_patches, + image_collection): + patch_sampler = zds.PatchSampler( + patch_size, + allow_incomplete_patches=allow_incomplete_patches + ) + + chunks_toplefts = patch_sampler.compute_chunks(image_collection) + + patches_toplefts = patch_sampler.compute_patches( + image_collection, + chunk_tlbr=chunks_toplefts[0] + ) + + expected_num_patches = 1 if allow_incomplete_patches else 0 + + assert len(patches_toplefts) == expected_num_patches,\ + (f"Expected to have {expected_num_patches}, when " + f"`allow_incomplete_patches` is {allow_incomplete_patches} " + f"got {len(patches_toplefts)} instead.") + + @pytest.mark.parametrize("patch_size, axes, resample, allow_overlap," "image_collection", [ (dict(X=32, Y=32, Z=1), "XYZ", True, True, IMAGE_SPECS[10]), @@ -379,15 +406,6 @@ def test_BlueNoisePatchSampler(patch_size, axes, resample, allow_overlap, (f"Expected {len(patch_sampler._base_chunk_tls)} patches, got " f"{len(patches_toplefts)} instead.") - patches_toplefts = 
patch_sampler.compute_patches( - image_collection, - chunk_tlbr=chunks_toplefts[-1] - ) - - assert len(patches_toplefts) == len(patch_sampler._base_chunk_tls), \ - (f"Expected {len(patch_sampler._base_chunk_tls)} patches, got " - f"{len(patches_toplefts)} instead.") - @pytest.mark.parametrize("image_collection_mask_not2scale", [ IMAGE_SPECS[10] diff --git a/tests/test_zarrdataset.py b/tests/test_zarrdataset.py index 7112f63..17b0b74 100644 --- a/tests/test_zarrdataset.py +++ b/tests/test_zarrdataset.py @@ -116,8 +116,12 @@ def image_dataset_specs(request): @pytest.fixture(scope="function") def patch_sampler_specs(request): - patch_sampler = zds.PatchSampler(patch_size=request.param) - return patch_sampler, request.param + patch_size, allow_incomplete_patches = request.param + patch_sampler = zds.PatchSampler( + patch_size=patch_size, + allow_incomplete_patches=allow_incomplete_patches + ) + return patch_sampler, patch_size, allow_incomplete_patches @pytest.mark.parametrize("image_dataset_specs", [ @@ -142,7 +146,7 @@ def test_compatibility_no_tqdm(image_dataset_specs): try: next(iter(dataset)) - + except Exception as e: raise AssertionError(f"No exceptions where expected, got {e} " f"instead.") @@ -304,9 +308,9 @@ def test_ZarrDataset(image_dataset_specs, shuffle, return_positions, @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk", [ - (IMAGE_SPECS[10], 32, True, False), - (IMAGE_SPECS[10], 32, True, True), - (IMAGE_SPECS[10], 32, False, True), + (IMAGE_SPECS[10], (32, False), True, False), + (IMAGE_SPECS[10], (32, False), True, True), + (IMAGE_SPECS[10], (32, False), False, True), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) @@ -314,7 +318,7 @@ def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = zds.ZarrDataset( dataset_specs=dataset_specs, @@ -409,33 +413,40 @@ def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs, @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs", [ - (IMAGE_SPECS[10], 1024), + (IMAGE_SPECS[10], (1024, True)), + (IMAGE_SPECS[10], (1024, False)), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) def test_greater_patch_ZarrDataset(image_dataset_specs, patch_sampler_specs): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = zds.ZarrDataset( dataset_specs=dataset_specs, - patch_sampler=patch_sampler, + patch_sampler=patch_sampler ) n_samples = 0 for _ in ds: n_samples += 1 - assert n_samples == 0, ("Expected zero samples since requested patch size" - f" is greater than the image size.") + if allow_incomplete_patches: + assert n_samples > 0, ("Expected at elast one sample when patch" + " size is greater than the image size, and" + " `allow_incomplete_patches` is True.") + else: + assert n_samples == 0, ("Expected zero samples since requested patch" + " size is greater than the image size, and" + " `allow_incomplete_patches` is False.") @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk," "batch_size, num_workers", [ - (IMAGE_SPECS[10], 32, True, False, 2, 2), - ([IMAGE_SPECS[10]] * 4, 32, True, True, 2, 3), - ([IMAGE_SPECS[10]] * 2, 32, True, True, 2, 3), + (IMAGE_SPECS[10], (32, False), 
True, False, 2, 2), + ([IMAGE_SPECS[10]] * 4, (32, False), True, True, 2, 3), + ([IMAGE_SPECS[10]] * 2, (32, False), True, True, 2, 3), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) @@ -446,7 +457,7 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs, num_workers): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = zds.ZarrDataset( dataset_specs=dataset_specs, @@ -514,21 +525,21 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs, @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk," "batch_size, num_workers, repeat_dataset", [ - (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 1), - (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 2), - (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 3), + (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 1), + (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 2), + (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 3), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) def test_multithread_chained_ZarrDataset(image_dataset_specs, - patch_sampler_specs, - shuffle, - draw_same_chunk, - batch_size, - num_workers, - repeat_dataset): + patch_sampler_specs, + shuffle, + draw_same_chunk, + batch_size, + num_workers, + repeat_dataset): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = [zds.ZarrDataset(dataset_specs=dataset_specs, shuffle=shuffle, diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index f876a74..52df198 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -1,8 +1,6 @@ -from typing import Iterable, Union, Tuple +from typing import Iterable, Union, Tuple, List import math import numpy as np -from itertools import repeat -from functools import reduce import operator import poisson_disc @@ -23,16 +21,16 @@ class PatchSampler(object): patches (hyper-cuboids) are supported by now. If a single int is passed, that size is used for all dimensions. If an iterable (list, tuple) is passed, each value will be assigned to the corresponding axes - in `spatial_axes`, the size of `patch_size` must match the lenght of + in `spatial_axes`, the size of `patch_size` must match the lenght of `spatial_axes'. If a dict is passed, this should have at least the size - of the patch of the axes listed in `spatial_axes`. Use the same - convention as how Zarr structure array chunks in order to handle patch + of the patch of the axes listed in `spatial_axes`. Use the same + convention as how Zarr structure array chunks in order to handle patch shapes and channels correctly. stride : Union[int, Iterable[int], dict, None] Distance in pixels of the movement of the sampling sliding window. If `stride` is less than `patch_size` for an axis, patches will have an overlap between them. This is usuful in inference mode for avoiding - edge artifacts. If None is passed, the `patch_size` will be used as + edge artifacts. If None is passed, the `patch_size` will be used as `stride`. pad : Union[int, Iterable[int], dict, None] Padding in pixels added to the extracted patch at each specificed axis. @@ -43,12 +41,17 @@ class PatchSampler(object): covered by the mask. spatial_axes : str The spatial axes from where patches can be extracted. 
+ allow_incomplete_patches : bool + Allow to retrieve patches that are smaller than the patch size. This is + the case of samples at the edge of the image that are usually smaller + than the specified patch size. """ def __init__(self, patch_size: Union[int, Iterable[int], dict], stride: Union[int, Iterable[int], dict, None] = None, pad: Union[int, Iterable[int], dict, None] = None, min_area: Union[int, float] = 1, - spatial_axes: str = "ZYX"): + spatial_axes: str = "ZYX", + allow_incomplete_patches: bool = False): # The maximum chunk sizes are used to generate a reference sampling # position array used fo every sampled chunk. self._max_chunk_size = {ax: 0 for ax in spatial_axes} @@ -120,212 +123,225 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], self._pad = {ax: pad.get(ax, 0) for ax in spatial_axes} self._min_area = min_area + self._allow_incomplete_patches = allow_incomplete_patches - def _compute_corners(self, non_zero_pos: tuple, axes: str, - limits_per_dim: Union[np.ndarray, None] = None, + def _compute_corners(self, coordinates: np.ndarray, scale: np.ndarray ) -> np.ndarray: - toplefts = np.stack(non_zero_pos).T - toplefts = toplefts.astype(np.float32) - toplefts_corners = [] - dim = len(axes) + dim = coordinates.shape[-1] factors = 2 ** np.arange(dim + 1) for d in range(2 ** dim): corner_value = np.array((d % factors[1:]) // factors[:-1], dtype=np.float32) toplefts_corners.append( - toplefts + (1 - 1e-4) * corner_value + coordinates + scale * (1 - 1e-4) * corner_value ) corners = np.stack(toplefts_corners) - if limits_per_dim is not None: - corners = np.minimum(corners, - limits_per_dim[None, None, ...] - 1e-4) - return corners - def _compute_overlap(self, corners: np.ndarray, shape: np.ndarray, - ref_shape: np.ndarray) -> Tuple[np.ndarray, - np.ndarray]: - scaled_corners = corners * shape[None, None] - tls_scaled = scaled_corners / ref_shape[None, None] - tls_idx = np.floor(tls_scaled) + def _compute_reference_indices(self, reference_coordinates: np.ndarray + ) -> Tuple[List[np.ndarray], + List[Tuple[int]]]: + reference_per_axis = list(map( + lambda coords: np.append(np.full((1, ), fill_value=-float("inf")), + np.unique(coords)), + reference_coordinates.T + )) + + reference_idx = map( + lambda coord_axis, ref_axis: + np.argmax(ref_axis[None, ...] + * (coord_axis[..., None] >= ref_axis[None, ...]), + axis=-1), + reference_coordinates.T, + reference_per_axis + ) + reference_idx = np.stack(tuple(reference_idx), axis=-1) + reference_idx = [ + tuple(tls_coord - 1) + for tls_coord in reference_idx.reshape(-1, len(reference_per_axis)) + ] + + return reference_per_axis, reference_idx + + def _compute_overlap(self, corners_coordinates: np.ndarray, + reference_per_axis: np.ndarray) -> Tuple[np.ndarray, + np.ndarray]: + tls_idx = map( + lambda coord_axis, ref_axis: + np.argmax(ref_axis[None, None, ...] 
+ * (coord_axis[..., None] >= ref_axis[None, None, ...]), + axis=-1), + np.moveaxis(corners_coordinates, -1, 0), + reference_per_axis + ) + tls_idx = np.stack(tuple(tls_idx), axis=-1) - corners_cut = np.maximum(tls_scaled[0], tls_idx[-1]) + tls_coordinates = map( + lambda tls_coord, ref_axis: ref_axis[tls_coord], + np.moveaxis(tls_idx, -1, 0), + reference_per_axis + ) + tls_coordinates = np.stack(tuple(tls_coordinates), axis=-1) - dist2cut = np.fabs(corners - corners_cut[None]) - coverage = np.prod(dist2cut, axis=-1) + corners_cut = np.maximum(corners_coordinates[0], tls_coordinates[-1]) - # Scale the coverage to the size of the input shape - coverage *= np.prod(shape) + dist2cut = np.fabs(corners_coordinates - corners_cut[None]) + coverage = np.prod(dist2cut, axis=-1) - return coverage, tls_idx.astype(np.int64) + return coverage, tls_idx - 1 - def _compute_grid(self, chunk_mask: np.ndarray, - mask_axes: str, - mask_scale: dict, + def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_size: dict): - mask_relative_shape = np.array( - [1 / m_scl - for ax, m_scl in mask_scale.items() - if ax in self.spatial_axes - ], - dtype=np.float32 - ) - - patch_shape = np.array( - [patch_size[ax] - for ax in mask_axes - if ax in self.spatial_axes - ], - dtype=np.float32 - ) - - # If the patch sizes are greater than the relative shape of the mask - # with respect to the input image, use the mask coordinates as - # reference to overlap the coordinates of the sampling patches. - # Otherwise, use the patches coordinates instead. - if all(map(operator.gt, patch_shape, mask_relative_shape)): - active_coordinates = np.nonzero(chunk_mask) - limits_per_dim = np.array(chunk_mask.shape) + 1 + image_size: dict, + allow_incomplete_patches: bool = False): + mask_scale = np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + image_scale = np.array([image_size.get(ax, 1) / patch_size.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + round_fn = math.ceil if allow_incomplete_patches else math.floor + + image_blocks = [ + round_fn( + ( + min(image_size.get(ax, 1), + chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float("inf")) + - (chunk_tlbr[ax].start + if chunk_tlbr[ax].start is not None + else 0) + ) / patch_size.get(ax, 1)) + for ax in self.spatial_axes + ] - ref_axes = mask_axes + if min(image_blocks) == 0: + return [] - ref_shape = patch_shape - shape = mask_relative_shape + image_scale = np.array([patch_size.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + image_coordinates = np.array(list(np.ndindex(*image_blocks)), + dtype=np.float32) - mask_is_greater = False + image_coordinates *= image_scale - patch_ratio = [ - round(image_size[ax] / ps) - for ax, ps in zip(mask_axes, patch_shape.astype(np.int64)) - if ax in self.spatial_axes - ] + mask_scale = 1 / np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) - if not all(patch_ratio): - return np.empty( - [0] * len(set(mask_axes).intersection(self.spatial_axes)), - dtype=np.int64 + mask_coordinates = list(np.nonzero(mask[:])) + for ax_i, ax in enumerate(self.spatial_axes): + if ax not in mask.axes: + mask_coordinates.insert( + ax_i, + np.zeros(mask_coordinates[0].size) ) - else: - active_coordinates = np.meshgrid( - *[np.arange(math.ceil(image_size[ax] / ps)) - for ax, ps in zip(mask_axes, patch_shape) - if ax in self.spatial_axes] - ) + mask_coordinates = np.stack(mask_coordinates, dtype=np.float32).T + mask_coordinates *= 
mask_scale[None, ...] - limits_per_dim = np.array([ - image_size[ax] / ps - for ax, ps in zip(mask_axes, patch_shape) - if ax in self.spatial_axes - ]) + # Filter out mask coordinates outside the current selected chunk + chunk_tl_coordinates = np.array( + [chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None else 0 + for ax in self.spatial_axes], + dtype=np.float32 + ) + chunk_br_coordinates = np.array( + [chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float('inf') + for ax in self.spatial_axes], + dtype=np.float32 + ) - active_coordinates = tuple( - coord_ax.flatten() - for coord_ax in active_coordinates - ) + in_chunk = np.all( + np.bitwise_and( + mask_coordinates > (chunk_tl_coordinates - mask_scale - 1e-4), + mask_coordinates < chunk_br_coordinates + 1e-4 + ), + axis=1 + ) + mask_coordinates = mask_coordinates[in_chunk] - ref_axes = "".join([ - ax for ax in self.spatial_axes if ax in mask_axes - ]) + if all(map(operator.ge, image_scale, mask_scale)): + mask_corners = self._compute_corners(mask_coordinates, mask_scale) - ref_shape = mask_relative_shape - shape = patch_shape + (reference_per_axis, + reference_idx) =\ + self._compute_reference_indices(image_coordinates) - mask_is_greater = True + (coverage, + corners_idx) = self._compute_overlap(mask_corners, + reference_per_axis) - corners = self._compute_corners(active_coordinates, axes=ref_axes, - limits_per_dim=limits_per_dim) + covered_indices = [ + reference_idx.index(tuple(idx)) + for idx in corners_idx.reshape(-1, len(self.spatial_axes)) + ] - (coverage, - corners_idx) = self._compute_overlap(corners, shape, ref_shape) + patches_coverage = np.bincount(covered_indices, + weights=coverage.flatten(), + minlength=np.prod(image_blocks)) - if mask_is_greater: - # The mask ratio is greater than the patches size - mask_values = chunk_mask[tuple(corners_idx.T)].T - patches_coverage = coverage * mask_values + else: + image_corners = self._compute_corners(image_coordinates, + image_scale) - covered_tls = corners[0, ...].astype(np.int64) + (reference_per_axis, + reference_idx) = self._compute_reference_indices(mask_coordinates) - else: - # The mask ratio is less than the patches size - patch_coordinates = np.ravel_multi_index(tuple(corners_idx.T), - chunk_mask.shape) - patches_coverage = np.bincount(patch_coordinates.flatten(), - weights=coverage.flatten()) - patches_coverage = np.take(patches_coverage, patch_coordinates).T + (coverage, + corners_idx) = self._compute_overlap(image_corners, + reference_per_axis) - covered_tls = corners_idx[0, ...] + covered_indices = np.array([ + tuple(idx) in reference_idx + for idx in corners_idx.reshape(-1, len(self.spatial_axes)) + ]).reshape(coverage.shape) - patches_coverage = np.sum(patches_coverage, axis=0) + patches_coverage = np.sum(covered_indices * coverage, axis=0) - # Compute minimum area covered by masked regions to sample a patch. min_area = self._min_area if min_area < 1: - min_area *= patch_shape.prod() + min_area *= np.prod(list(patch_size.values())) - minumum_covered_tls = covered_tls[patches_coverage > min_area] + minimum_covered_tls = image_coordinates[patches_coverage > min_area] + minimum_covered_tls = minimum_covered_tls.astype(np.int64) - if not mask_is_greater: - # Collapse to unique coordinates since there will be multiple - # instances of the same patch. 
- minumum_covered_tls = np.ravel_multi_index( - tuple(minumum_covered_tls.T), - patch_ratio, - mode="clip" - ) - - minumum_covered_tls = np.unique(minumum_covered_tls) - - minumum_covered_tls = np.unravel_index( - minumum_covered_tls, - patch_ratio - ) + return minimum_covered_tls - minumum_covered_tls = np.stack(minumum_covered_tls).T - - return minumum_covered_tls * patch_shape[None].astype(np.int64) - - def _compute_valid_toplefts(self, chunk_mask: np.ndarray, mask_axes: str, - mask_scale: dict, + def _compute_valid_toplefts(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_size: dict): - return self._compute_grid(chunk_mask, mask_axes, mask_scale, - patch_size, - image_size) + **kwargs): + return self._compute_grid(chunk_tlbr, mask, patch_size, **kwargs) - def _compute_toplefts_slices(self, mask: ImageBase, image_shape: dict, - patch_size: dict, + def _compute_toplefts_slices(self, chunk_tlbr: dict, valid_mask_toplefts: np.ndarray, - chunk_tlbr: dict, + patch_size: dict, pad: Union[dict, None] = None): if pad is None: pad = {ax: 0 for ax in self.spatial_axes} - toplefts = [] - for tls in valid_mask_toplefts: - curr_tl = [] - - for ax in self.spatial_axes: - if ax in mask.axes: - tl = ((chunk_tlbr[ax].start - if chunk_tlbr[ax].start is not None else 0) - + tls[mask.axes.index(ax)]) - br = ((chunk_tlbr[ax].start - if chunk_tlbr[ax].start is not None else 0) - + tls[mask.axes.index(ax)] + patch_size[ax]) - - curr_tl.append((ax, slice(tl - pad[ax], - br + pad[ax]))) - - else: - curr_tl.append((ax, slice(0, 1))) - - toplefts.append(dict(curr_tl)) + toplefts = [ + {ax: slice( + (chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None + else 0) + tls[self.spatial_axes.index(ax)] + - pad[ax], + (chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None + else 0) + tls[self.spatial_axes.index(ax)] + patch_size[ax] + + pad[ax]) + for ax in self.spatial_axes + } + for tls in valid_mask_toplefts + ] return toplefts @@ -362,41 +378,31 @@ def compute_chunks(self, for ax, chk in zip(image.axes, image.chunk_size) if ax in self.spatial_axes } - # spatial_chunk_sizes = { - # ax: (self._patch_size[ax] - # * max(1, math.ceil(chk / self._patch_size[ax]))) - # for ax, chk in zip(image.axes, image.chunk_size) - # if ax in self.spatial_axes - # } - image_shape = {ax: s for ax, s in zip(image.axes, image.shape)} + image_size = {ax: s for ax, s in zip(image.axes, image.shape)} self._max_chunk_size = { ax: (min(max(self._max_chunk_size[ax], spatial_chunk_sizes[ax]), - image_shape[ax])) + image_size[ax])) if ax in image.axes else 1 for ax in self.spatial_axes } chunk_tlbr = {ax: slice(None) for ax in self.spatial_axes} - chunk_mask = mask[chunk_tlbr] - valid_mask_toplefts = self._compute_grid( - chunk_mask, - mask.axes, - mask.scale, + chunk_tlbr, + mask, self._max_chunk_size, - image_shape + image_size, + allow_incomplete_patches=True ) chunks_slices = self._compute_toplefts_slices( - mask, - image_shape=image_shape, - patch_size=self._max_chunk_size, + chunk_tlbr, valid_mask_toplefts=valid_mask_toplefts, - chunk_tlbr=chunk_tlbr + patch_size=self._max_chunk_size ) return chunks_slices @@ -405,29 +411,36 @@ def compute_patches(self, image_collection: ImageCollection, chunk_tlbr: dict) -> Iterable[dict]: image = image_collection.collection[image_collection.reference_mode] mask = image_collection.collection[image_collection.mask_mode] - image_shape = {ax: s for ax, s in zip(image.axes, image.shape)} - chunk_size = { - ax: ((ctb.stop if ctb.stop is not None else image_shape[ax]) - - (ctb.start if 
ctb.start is not None else 0)) - for ax, ctb in chunk_tlbr.items() + image_size = {ax: s for ax, s in zip(image.axes, image.shape)} + + stride = { + ax: self._stride.get(ax, 1) if image_size.get(ax, 1) > 1 else 1 + for ax in self.spatial_axes } - chunk_mask = mask[chunk_tlbr] + patch_size = { + ax: self._patch_size.get(ax, 1) if image_size.get(ax, 1) > 1 else 1 + for ax in self.spatial_axes + } + + pad = { + ax: self._pad.get(ax, 0) if image_size.get(ax, 1) > 1 else 0 + for ax in self.spatial_axes + } valid_mask_toplefts = self._compute_valid_toplefts( - chunk_mask, - mask.axes, - mask.scale, - self._stride, - chunk_size) + chunk_tlbr, + mask, + stride, + image_size=image_size, + allow_incomplete_patches=self._allow_incomplete_patches + ) patches_slices = self._compute_toplefts_slices( - mask, - image_shape=image_shape, - patch_size=self._patch_size, + chunk_tlbr, valid_mask_toplefts=valid_mask_toplefts, - chunk_tlbr=chunk_tlbr, - pad=self._pad + patch_size=patch_size, + pad=pad ) return patches_slices @@ -457,7 +470,7 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], resample_positions=False, allow_overlap=False, **kwargs): - super(BlueNoisePatchSampler, self).__init__(patch_size) + super(BlueNoisePatchSampler, self).__init__(patch_size, **kwargs) self._base_chunk_tls = None self._resample_positions = resample_positions self._allow_overlap = allow_overlap @@ -499,35 +512,62 @@ def compute_sampling_positions(self, force=False) -> None: self._base_chunk_tls = np.zeros((1, len(self.spatial_axes)), dtype=np.int64) - def _compute_valid_toplefts(self, - chunk_mask: np.ndarray, - mask_axes: str, - mask_scale: dict, + def _compute_valid_toplefts(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_shape: dict): + **kwargs): self.compute_sampling_positions(force=self._resample_positions) # Filter sampling positions that does not contain any mask portion. - sampling_pos = np.hstack( - tuple( - self._base_chunk_tls[:, self.spatial_axes.index(ax), None] - if ax in self.spatial_axes else - np.zeros((len(self._base_chunk_tls), 1), dtype=np.float32) - for ax in mask_axes - ) - ) - spatial_patch_sizes = np.array([ - patch_size[ax] - for ax in mask_axes - if ax in self.spatial_axes + patch_size.get(ax, 1) for ax in self.spatial_axes ]) - mask_corners = self._compute_corners(np.nonzero(chunk_mask), - mask_axes) + mask_scale = np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + mask_scale = 1 / np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + mask_coordinates = list(np.nonzero(mask[:])) + for ax_i, ax in enumerate(self.spatial_axes): + if ax not in mask.axes: + mask_coordinates.insert( + ax_i, + np.zeros(mask_coordinates[0].size) + ) + + mask_coordinates = np.stack(mask_coordinates, dtype=np.float32).T + mask_coordinates *= mask_scale[None, ...] 
+ + # Filter out mask coordinates outside the current selected chunk + chunk_tl_coordinates = np.array( + [chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None else 0 + for ax in self.spatial_axes], + dtype=np.float32 + ) + chunk_br_coordinates = np.array( + [chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float('inf') + for ax in self.spatial_axes], + dtype=np.float32 + ) + + in_chunk = np.all( + np.bitwise_and( + mask_coordinates > (chunk_tl_coordinates - mask_scale - 1e-4), + mask_coordinates < chunk_br_coordinates + 1e-4 + ), + axis=1 + ) + mask_coordinates = mask_coordinates[in_chunk] + + mask_corners = self._compute_corners(mask_coordinates, mask_scale) dist = (mask_corners[None, :, :, :] - - sampling_pos[:, None, None, :].astype(np.float32) + - self._base_chunk_tls[:, None, None, :].astype(np.float32) - spatial_patch_sizes[None, None, None, :] / 2) mask_samplable_pos, = np.nonzero( @@ -537,6 +577,6 @@ def _compute_valid_toplefts(self, ) ) - toplefts = sampling_pos[mask_samplable_pos] + toplefts = self._base_chunk_tls[mask_samplable_pos] return toplefts From a5a409749fc5e1d264a1ce0abc665fada602b08d Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Tue, 7 May 2024 17:01:57 -0400 Subject: [PATCH 09/10] Added example notebook to documentation --- .../advanced_example_pytorch_inference.md | 19 ++++++++++++++++++- docs/source/index.rst | 2 ++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/source/examples/advanced_example_pytorch_inference.md b/docs/source/examples/advanced_example_pytorch_inference.md index fdeb3f9..d3a6cd9 100644 --- a/docs/source/examples/advanced_example_pytorch_inference.md +++ b/docs/source/examples/advanced_example_pytorch_inference.md @@ -1,3 +1,20 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +execution: + timeout: 120 +--- + +# Integration of ZarrDataset with PyTorch's DataLoader for inference (Advanced) + ```python import zarrdataset as zds @@ -32,7 +49,7 @@ random.seed(478965) ## Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI) Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference result. -Finally, let the PatchSampler retrieve patches from the edge of the image that would be otherwise smaller than the patch size. +Finally, let the PatchSampler retrieve patches from the edge of the image that would be otherwise smaller than the patch size by setting `allow_incomplete_patches=True`. ```python diff --git a/docs/source/index.rst b/docs/source/index.rst index 9359764..9394850 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,8 @@ Welcome to ZarrDataset's documentation! 
examples/advanced_example_pytorch + examples/advanced_example_pytorch_inference + REFERENCE ========= From 6ca445d6aa4dc384c5b97fb97fd4fecba6e077b5 Mon Sep 17 00:00:00 2001 From: Fernando Cervantes Sanchez Date: Fri, 10 May 2024 12:58:04 -0400 Subject: [PATCH 10/10] Fixed incorrect sampling of patches on masked regions --- .../advanced_example_pytorch_inference.md | 2 +- zarrdataset/_samplers.py | 58 ++++++++++++------- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/docs/source/examples/advanced_example_pytorch_inference.md b/docs/source/examples/advanced_example_pytorch_inference.md index d3a6cd9..29217cc 100644 --- a/docs/source/examples/advanced_example_pytorch_inference.md +++ b/docs/source/examples/advanced_example_pytorch_inference.md @@ -58,7 +58,7 @@ pad = dict(Y=16, X=16) patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True) ``` -Create a dataset from the list of filenames. All those files should be stored within their respective group "0". +Create a dataset from the list of filenames. All those files should be stored within their respective group "4", which in this case it correspond to a downsampled version of the full resolution image by a factor of 16. Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index 52df198..67bba1d 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -142,26 +142,32 @@ def _compute_corners(self, coordinates: np.ndarray, scale: np.ndarray return corners - def _compute_reference_indices(self, reference_coordinates: np.ndarray + def _compute_reference_indices(self, reference_coordinates: np.ndarray, + reference_axes_sizes: np.ndarray ) -> Tuple[List[np.ndarray], List[Tuple[int]]]: reference_per_axis = list(map( - lambda coords: np.append(np.full((1, ), fill_value=-float("inf")), - np.unique(coords)), - reference_coordinates.T + lambda coords, axis_size: np.concatenate(( + np.full((1, ), fill_value=-float("inf")), + np.unique(coords), + np.full((1, ), fill_value=np.max(coords) + axis_size))), + reference_coordinates.T, + reference_axes_sizes )) reference_idx = map( lambda coord_axis, ref_axis: - np.argmax(ref_axis[None, ...] - * (coord_axis[..., None] >= ref_axis[None, ...]), - axis=-1), + np.max(np.arange(ref_axis.size) + * (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]), + axis=1), reference_coordinates.T, reference_per_axis ) reference_idx = np.stack(tuple(reference_idx), axis=-1) + reference_idx = reference_idx.reshape(reference_coordinates.T.shape) + reference_idx = [ - tuple(tls_coord - 1) + tuple(tls_coord) for tls_coord in reference_idx.reshape(-1, len(reference_per_axis)) ] @@ -172,13 +178,14 @@ def _compute_overlap(self, corners_coordinates: np.ndarray, np.ndarray]: tls_idx = map( lambda coord_axis, ref_axis: - np.argmax(ref_axis[None, None, ...] 
- * (coord_axis[..., None] >= ref_axis[None, None, ...]), - axis=-1), + np.max(np.arange(ref_axis.size) + * (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]), + axis=1), np.moveaxis(corners_coordinates, -1, 0), reference_per_axis ) tls_idx = np.stack(tuple(tls_idx), axis=-1) + tls_idx = tls_idx.reshape(corners_coordinates.shape) tls_coordinates = map( lambda tls_coord, ref_axis: ref_axis[tls_coord], @@ -192,11 +199,12 @@ def _compute_overlap(self, corners_coordinates: np.ndarray, dist2cut = np.fabs(corners_coordinates - corners_cut[None]) coverage = np.prod(dist2cut, axis=-1) - return coverage, tls_idx - 1 + return coverage, tls_idx def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, image_size: dict, + min_area: float, allow_incomplete_patches: bool = False): mask_scale = np.array([mask.scale.get(ax, 1) for ax in self.spatial_axes], @@ -223,7 +231,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, ] if min(image_blocks) == 0: - return [] + return [] image_scale = np.array([patch_size.get(ax, 1) for ax in self.spatial_axes], @@ -254,6 +262,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, for ax in self.spatial_axes], dtype=np.float32 ) + chunk_br_coordinates = np.array( [chunk_tlbr[ax].stop if chunk_tlbr[ax].stop is not None @@ -271,12 +280,16 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, ) mask_coordinates = mask_coordinates[in_chunk] + # Translate the mask coordinates to the origin for comparison with + # image coordinates. + mask_coordinates -= chunk_tl_coordinates + if all(map(operator.ge, image_scale, mask_scale)): mask_corners = self._compute_corners(mask_coordinates, mask_scale) (reference_per_axis, reference_idx) =\ - self._compute_reference_indices(image_coordinates) + self._compute_reference_indices(image_coordinates, image_scale) (coverage, corners_idx) = self._compute_overlap(mask_corners, @@ -284,19 +297,22 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, covered_indices = [ reference_idx.index(tuple(idx)) + if tuple(idx) in reference_idx else len(reference_idx) for idx in corners_idx.reshape(-1, len(self.spatial_axes)) ] patches_coverage = np.bincount(covered_indices, weights=coverage.flatten(), - minlength=np.prod(image_blocks)) + minlength=len(reference_idx) + 1) + patches_coverage = patches_coverage[:-1] else: image_corners = self._compute_corners(image_coordinates, image_scale) (reference_per_axis, - reference_idx) = self._compute_reference_indices(mask_coordinates) + reference_idx) = self._compute_reference_indices(mask_coordinates, + mask_scale) (coverage, corners_idx) = self._compute_overlap(image_corners, @@ -309,10 +325,6 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, patches_coverage = np.sum(covered_indices * coverage, axis=0) - min_area = self._min_area - if min_area < 1: - min_area *= np.prod(list(patch_size.values())) - minimum_covered_tls = image_coordinates[patches_coverage > min_area] minimum_covered_tls = minimum_covered_tls.astype(np.int64) @@ -396,6 +408,7 @@ def compute_chunks(self, mask, self._max_chunk_size, image_size, + min_area=1, allow_incomplete_patches=True ) @@ -428,11 +441,16 @@ def compute_patches(self, image_collection: ImageCollection, for ax in self.spatial_axes } + min_area = self._min_area + if min_area < 1: + min_area *= np.prod(list(patch_size.values())) + valid_mask_toplefts = self._compute_valid_toplefts( chunk_tlbr, mask, stride, image_size=image_size, + min_area=min_area, allow_incomplete_patches=self._allow_incomplete_patches )
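
For context, the usage pattern these changes enable is the one from the inference example added in `docs/source/examples/advanced_example_pytorch_inference.md`. The sketch below is illustrative only: it assumes a `zarrdataset` build with this patch series applied (so that `PatchSampler` accepts the `pad` and `allow_incomplete_patches` arguments), and the `min_area=0.25` value is a made-up example to show how fractional values are interpreted.

```python
import zarrdataset as zds

# Patch configuration from the inference example: 1024x1024 patches with a
# 16-pixel pad per spatial axis, so edge artifacts can be cropped away when
# the per-patch predictions are stitched back together.
patch_size = dict(Y=1024, X=1024)
pad = dict(Y=16, X=16)

patch_sampler = zds.PatchSampler(
    patch_size=patch_size,
    pad=pad,
    # Fractions below 1 are scaled by the patch area inside compute_patches,
    # i.e. here a patch is kept only if the mask covers more than
    # 0.25 * 1024 * 1024 pixels of it (0.25 is an illustrative value).
    min_area=0.25,
    # Also retrieve patches at the image edges that are smaller than
    # patch_size, instead of silently dropping those positions; downstream
    # code decides whether to pad or discard them.
    allow_incomplete_patches=True,
)
```

Note that `compute_chunks` passes `min_area=1` when laying out the chunk grid, so a fractional threshold like the one above only affects patch-level sampling within each chunk.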