diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ae01ccd1..8d9331868 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,8 @@ * Propagate `output_struct_name` argument to `ExtractSegmentationInterface` to match its extractor arguments. [PR #128](https://github.com/catalystneuro/neuroconv/pull/128) * Added compression and iteration (with options control) to all Fluorescence traces in `write_segmentation`. [PR #120](https://github.com/catalystneuro/neuroconv/pull/120) * For irregular recordings, timestamps can now be saved along with all traces in `write_segmentation`. [PR #130](https://github.com/catalystneuro/neuroconv/pull/130) +* Added `mask_type` argument to `tools.roiextractors.add_plane_segmentation` function and all upstream calls. This allows users to request writing not just the image_masks (still the default) but also pixels, voxels or `None` of the above. [PR #119](https://github.com/catalystneuro/neuroconv/pull/119) +* `utils.json_schema.get_schema_from_method_signature` now allows `Optional[...]` annotation typing and subsequent `None` values during validation as long as it is still only applied to a simple non-conflicting type (no `Optional[Union[..., ...]]`). 
[PR #119](https://github.com/catalystneuro/neuroconv/pull/119) ### Documentation and tutorial enhancements: diff --git a/src/neuroconv/datainterfaces/ophys/basesegmentationextractorinterface.py b/src/neuroconv/datainterfaces/ophys/basesegmentationextractorinterface.py index fb7cc9947..71778ebb8 100644 --- a/src/neuroconv/datainterfaces/ophys/basesegmentationextractorinterface.py +++ b/src/neuroconv/datainterfaces/ophys/basesegmentationextractorinterface.py @@ -67,6 +67,7 @@ def run_conversion( stub_test: bool = False, stub_frames: int = 100, include_roi_centroids: bool = True, + mask_type: Optional[str] = "image", # Optional[Literal["image", "pixel", "voxel"]] iterator_options: Optional[dict] = None, compression_options: Optional[dict] = None, ): @@ -86,6 +87,7 @@ def run_conversion( overwrite=overwrite, verbose=self.verbose, include_roi_centroids=include_roi_centroids, + mask_type=mask_type, iterator_options=iterator_options, compression_options=compression_options, ) diff --git a/src/neuroconv/tools/roiextractors/roiextractors.py b/src/neuroconv/tools/roiextractors/roiextractors.py index 5254681df..7631ef147 100644 --- a/src/neuroconv/tools/roiextractors/roiextractors.py +++ b/src/neuroconv/tools/roiextractors/roiextractors.py @@ -549,6 +549,7 @@ def add_plane_segmentation( metadata: Optional[dict], plane_segmentation_index: int = 0, include_roi_centroids: bool = True, + mask_type: Optional[str] = "image", # Optional[Literal["image", "pixel", "voxel"]] iterator_options: Optional[dict] = None, compression_options: Optional[dict] = None, ) -> NWBFile: @@ -572,6 +573,17 @@ def add_plane_segmentation( If there are a very large number of ROIs (such as in whole-brain recordings), you may wish to disable this for faster write speeds. Defaults to True. + mask_type : str, optional + There are three types of ROI masks in NWB: ImageMasks, PixelMasks, and VoxelMasks. 
+ Image masks have the same shape as the reference images the segmentation was applied to, and weight each pixel + by its contribution to the ROI (typically boolean, with 0 meaning 'not in the ROI'). + Pixel masks are instead indexed by ROI, with the data at each index being the shape of the image by the number + of pixels in each ROI. + Voxel masks are instead indexed by ROI, with the data at each index being the shape of the volume by the number + of voxels in each ROI. + Specify your choice among these as mask_type='image', 'pixel', 'voxel', or None. + If None, the mask information is not written to the NWB file. + Defaults to 'image'. iterator_options : dict, optional The options to use when iterating over the image masks of the segmentation extractor. compression_options : dict, optional @@ -582,14 +594,14 @@ def add_plane_segmentation( NWBFile The nwbfile passed as an input with the plane segmentation added. """ + assert mask_type in ["image", "pixel", "voxel", None], ( + "Keyword argument 'mask_type' must be one of either 'image', 'pixel', 'voxel', " + f"or None (to not write any masks)! Received '{mask_type}'." 
+ ) + iterator_options = iterator_options or dict() compression_options = compression_options or dict(compression="gzip") - def image_mask_iterator(): - for roi_id in segmentation_extractor.get_roi_ids(): - image_masks = segmentation_extractor.get_roi_image_masks(roi_ids=[roi_id]).T.squeeze() - yield image_masks - # Set the defaults and required infrastructure metadata_copy = deepcopy(metadata) default_metadata = get_default_ophys_metadata() @@ -599,16 +611,8 @@ def image_mask_iterator(): plane_segmentation_metadata = image_segmentation_metadata["plane_segmentations"][plane_segmentation_index] plane_segmentation_name = plane_segmentation_metadata["name"] - add_imaging_plane( - nwbfile=nwbfile, - metadata=metadata_copy, - imaging_plane_index=plane_segmentation_index, - ) - - add_image_segmentation( - nwbfile=nwbfile, - metadata=metadata_copy, - ) + add_imaging_plane(nwbfile=nwbfile, metadata=metadata_copy, imaging_plane_index=plane_segmentation_index) + add_image_segmentation(nwbfile=nwbfile, metadata=metadata_copy) ophys = get_module(nwbfile, "ophys") image_segmentation_name = image_segmentation_metadata["name"] @@ -624,39 +628,63 @@ def image_mask_iterator(): imaging_plane_name = imaging_plane_metadata["name"] imaging_plane = nwbfile.imaging_planes[imaging_plane_name] - plane_segmentation_kwargs = dict( - **plane_segmentation_metadata, - imaging_plane=imaging_plane, - columns=[ - VectorData( - data=H5DataIO( - DataChunkIterator(image_mask_iterator(), **iterator_options), - **compression_options, - ), - name="image_mask", - description="image masks", - ), - VectorData( - data=accepted_ids, - name="Accepted", - description="1 if ROI was accepted or 0 if rejected as a cell during segmentation operation", - ), - VectorData( - data=rejected_ids, - name="Rejected", - description="1 if ROI was rejected or 0 if accepted as a cell during segmentation operation", - ), - ], - id=roi_ids, - ) - plane_segmentation = PlaneSegmentation(**plane_segmentation_kwargs) + 
plane_segmentation_kwargs = dict(**plane_segmentation_metadata, imaging_plane=imaging_plane) + if mask_type is None: + plane_segmentation = PlaneSegmentation(id=roi_ids, **plane_segmentation_kwargs) + elif mask_type == "image": + plane_segmentation = PlaneSegmentation(id=roi_ids, **plane_segmentation_kwargs) + plane_segmentation.add_column( + name="image_mask", + description="Image masks for each ROI.", + data=H5DataIO(segmentation_extractor.get_roi_image_masks().T, **compression_options), + ) + elif mask_type == "pixel" or mask_type == "voxel": + pixel_masks = segmentation_extractor.get_roi_pixel_masks() + num_pixel_dims = pixel_masks[0].shape[1] + + assert num_pixel_dims in [3, 4], ( + "The segmentation extractor returned a pixel mask that is not 3- or 4- dimensional! " + "Please open a ticket with https://github.com/catalystneuro/roiextractors/issues" + ) + if mask_type == "pixel" and num_pixel_dims == 4: + warn( + "Specified mask_type='pixel', but ROIExtractors returned 4-dimensional masks. " + "Using mask_type='voxel' instead." + ) + mask_type = "voxel" + if mask_type == "voxel" and num_pixel_dims == 3: + warn( + "Specified mask_type='voxel', but ROIExtractors returned 3-dimensional masks. " + "Using mask_type='pixel' instead." 
+ ) + mask_type = "pixel" + + mask_type_kwarg = f"{mask_type}_mask" + plane_segmentation = PlaneSegmentation(**plane_segmentation_kwargs) + for roi_id, pixel_mask in zip(roi_ids, pixel_masks): + plane_segmentation.add_roi(**{"id": roi_id, mask_type_kwarg: [tuple(x) for x in pixel_mask]}) + if include_roi_centroids: # ROIExtractors uses height x width x (depth), but NWB uses width x height x depth tranpose_image_convention = (1, 0) if len(segmentation_extractor.get_image_size()) == 2 else (1, 0, 2) - roi_locations = segmentation_extractor.get_roi_locations()[tranpose_image_convention, :] + roi_locations = segmentation_extractor.get_roi_locations()[tranpose_image_convention, :].T plane_segmentation.add_column( - name="ROICentroids", description="The x, y, (z) centroids of each ROI.", data=roi_locations.T + name="ROICentroids", + description="The x, y, (z) centroids of each ROI.", + data=H5DataIO(roi_locations, **compression_options), ) + + plane_segmentation.add_column( + name="Accepted", + description="1 if ROI was accepted or 0 if rejected as a cell during segmentation operation.", + data=H5DataIO(accepted_ids, **compression_options), + ) + plane_segmentation.add_column( + name="Rejected", + description="1 if ROI was rejected or 0 if accepted as a cell during segmentation operation.", + data=H5DataIO(rejected_ids, **compression_options), + ) + image_segmentation.add_plane_segmentation(plane_segmentations=[plane_segmentation]) return nwbfile @@ -728,10 +756,7 @@ def add_fluorescence_traces( plane_index=plane_index, ) - roi_response_series_kwargs = dict( - rois=roi_table_region, - unit="n.a.", - ) + roi_response_series_kwargs = dict(rois=roi_table_region, unit="n.a.") # Add timestamps or rate timestamps = segmentation_extractor.frame_to_time(np.arange(segmentation_extractor.get_num_frames())) @@ -772,7 +797,6 @@ def add_fluorescence_traces( trace_metadata = next( trace_metadata for trace_metadata in response_series_metadata if trace_name == trace_metadata["name"] ) - 
# Build the roi response series roi_response_series_kwargs.update( data=H5DataIO(SliceableDataChunkIterator(trace, **iterator_options), **compression_options), @@ -883,6 +907,7 @@ def write_segmentation( buffer_size: int = 10, plane_num: int = 0, include_roi_centroids: bool = True, + mask_type: Optional[str] = "image", # Optional[Literal["image", "pixel", "voxel"]] iterator_options: Optional[dict] = None, compression_options: Optional[dict] = None, ): @@ -920,6 +945,17 @@ def write_segmentation( If there are a very large number of ROIs (such as in whole-brain recordings), you may wish to disable this for faster write speeds. Defaults to True. + mask_type : str, optional + There are three types of ROI masks in NWB: ImageMasks, PixelMasks, and VoxelMasks. + Image masks have the same shape as the reference images the segmentation was applied to, and weight each pixel + by its contribution to the ROI (typically boolean, with 0 meaning 'not in the ROI'). + Pixel masks are instead indexed by ROI, with the data at each index being the shape of the image by the number + of pixels in each ROI. + Voxel masks are instead indexed by ROI, with the data at each index being the shape of the volume by the number + of voxels in each ROI. + Specify your choice among these as mask_type='image', 'pixel', 'voxel', or None. + If None, the mask information is not written to the NWB file. + Defaults to 'image'. 
""" assert ( nwbfile_path is None or nwbfile is None @@ -958,7 +994,7 @@ def write_segmentation( nwbfile_path=nwbfile_path, nwbfile=nwbfile, metadata=metadata_base_common, overwrite=overwrite, verbose=verbose ) as nwbfile_out: - ophys = get_module(nwbfile=nwbfile_out, name="ophys", description="contains optical physiology processed data") + _ = get_module(nwbfile=nwbfile_out, name="ophys", description="contains optical physiology processed data") for plane_no_loop, (segmentation_extractor, metadata) in enumerate( zip(segmentation_extractors, metadata_base_list) ): @@ -967,11 +1003,11 @@ def write_segmentation( add_devices(nwbfile=nwbfile_out, metadata=metadata) # ImageSegmentation: - image_segmentation_name = ( - "ImageSegmentation" if plane_no_loop == 0 else f"ImageSegmentation_Plane{plane_no_loop}" - ) - add_image_segmentation(nwbfile=nwbfile_out, metadata=metadata) - image_segmentation = ophys.data_interfaces.get(image_segmentation_name) + # image_segmentation_name = ( + # "ImageSegmentation" if plane_no_loop == 0 else f"ImageSegmentation_Plane{plane_no_loop}" + # ) + # add_image_segmentation(nwbfile=nwbfile_out, metadata=metadata) + # image_segmentation = ophys.data_interfaces.get(image_segmentation_name) # Add imaging plane add_imaging_plane(nwbfile=nwbfile_out, metadata=metadata) @@ -982,6 +1018,7 @@ def write_segmentation( nwbfile=nwbfile_out, metadata=metadata, include_roi_centroids=include_roi_centroids, + mask_type=mask_type, iterator_options=iterator_options, compression_options=compression_options, ) diff --git a/src/neuroconv/utils/json_schema.py b/src/neuroconv/utils/json_schema.py index 00173bbed..f3294e2b1 100644 --- a/src/neuroconv/utils/json_schema.py +++ b/src/neuroconv/utils/json_schema.py @@ -78,11 +78,17 @@ def get_schema_from_method_signature(class_method: classmethod, exclude: list = param_types = [annotation_json_type_map[x.__name__] for x in np.array(args)[valid_args]] else: raise ValueError("No valid arguments were found in the json type 
mapping!") - if len(set(param_types)) > 1: - raise ValueError( - "Conflicting json parameter types were detected from the annotation! " - f"{param.annotation.__args__} found." - ) + num_params = len(set(param_types)) + conflict_message = ( + "Conflicting json parameter types were detected from the annotation! " + f"{param.annotation.__args__} found." + ) + # Normally cannot support Union[...] of multiple annotation types + if num_params > 2: + raise ValueError(conflict_message) + # Special condition for Optional[...] + if num_params == 2 and not args[1] is type(None): # noqa: E721 + raise ValueError(conflict_message) param_type = param_types[0] else: arg = param.annotation diff --git a/tests/test_ophys/test_tools_roiextractors.py b/tests/test_ophys/test_tools_roiextractors.py index d385a67f7..2dcdcaff1 100644 --- a/tests/test_ophys/test_tools_roiextractors.py +++ b/tests/test_ophys/test_tools_roiextractors.py @@ -3,9 +3,12 @@ from tempfile import mkdtemp from pathlib import Path from datetime import datetime +from typing import Optional, List +from types import MethodType import psutil import numpy as np +from numpy.typing import ArrayLike from hdmf.data_utils import DataChunkIterator from hdmf.testing import TestCase from numpy.testing import assert_array_equal, assert_raises @@ -278,6 +281,22 @@ def test_add_image_segmentation(self): self.assertEqual(image_segmentation.name, self.image_segmentation_name) +def _generate_test_masks(num_rois: int, mask_type: str): # Literal["pixel", "voxel"] + masks = list() + size = 3 if mask_type == "pixel" else 4 + for idx in range(1, num_rois + 1): + masks.append(np.arange(idx, idx + size * idx, dtype=np.dtype("uint8")).reshape(-1, size)) + return masks + + +def _generate_casted_test_masks(num_rois: int, mask_type: str): # Literal["pixel", "voxel"] + original_mask = _generate_test_masks(num_rois=num_rois, mask_type=mask_type) + casted_masks = list() + for per_roi_mask in original_mask: + casted_masks.append([tuple(x) for x in 
per_roi_mask]) + return casted_masks + + class TestAddPlaneSegmentation(unittest.TestCase): @classmethod def setUpClass(cls): @@ -349,15 +368,9 @@ def test_add_plane_segmentation(self): assert_array_equal(plane_segmentation_roi_centroid_data, expected_roi_centroid_data) - image_mask_iterator = plane_segmentation["image_mask"].data - - data_chunks = np.zeros((self.num_rois, self.num_columns, self.num_rows)) - for data_chunk in image_mask_iterator: - data_chunks[data_chunk.selection] = data_chunk.data - # transpose to num_rois x image_width x image_height expected_image_masks = self.segmentation_extractor.get_roi_image_masks().T - assert_array_equal(data_chunks, expected_image_masks) + assert_array_equal(plane_segmentation["image_mask"], expected_image_masks) def test_do_not_include_roi_centroids(self): """Test that setting `include_roi_centroids=False` prevents the centroids from being calculated and added.""" @@ -434,6 +447,163 @@ def test_rejected_roi_ids(self, rejected_list, expected_rejected_roi_ids): plane_segmentation_accepted_roi_ids = plane_segmentation["Accepted"].data assert_array_equal(plane_segmentation_accepted_roi_ids, accepted_roi_ids) + def test_pixel_masks(self): + """Test the pixel mask option for writing a plane segmentation table.""" + segmentation_extractor = generate_dummy_segmentation_extractor( + num_rois=self.num_rois, + num_frames=self.num_frames, + num_rows=self.num_rows, + num_columns=self.num_columns, + ) + + def get_roi_pixel_masks(self, roi_ids: Optional[ArrayLike] = None) -> List[np.ndarray]: + roi_ids = roi_ids or range(self.get_num_rois()) + pixel_masks = _generate_test_masks(num_rois=len(roi_ids), mask_type="pixel") + return pixel_masks + + segmentation_extractor.get_roi_pixel_masks = MethodType(get_roi_pixel_masks, segmentation_extractor) + + add_plane_segmentation( + segmentation_extractor=segmentation_extractor, + nwbfile=self.nwbfile, + metadata=self.metadata, + mask_type="pixel", + ) + + image_segmentation = 
self.nwbfile.processing["ophys"].get(self.image_segmentation_name) + plane_segmentations = image_segmentation.plane_segmentations + + plane_segmentation = plane_segmentations[self.plane_segmentation_name] + + true_pixel_masks = _generate_casted_test_masks(num_rois=self.num_rois, mask_type="pixel") + assert_array_equal(plane_segmentation["pixel_mask"], true_pixel_masks) + + def test_voxel_masks(self): + """Test the voxel mask option for writing a plane segmentation table.""" + segmentation_extractor = generate_dummy_segmentation_extractor( + num_rois=self.num_rois, + num_frames=self.num_frames, + num_rows=self.num_rows, + num_columns=self.num_columns, + ) + + def get_roi_pixel_masks(self, roi_ids: Optional[ArrayLike] = None) -> List[np.ndarray]: + roi_ids = roi_ids or range(self.get_num_rois()) + voxel_masks = _generate_test_masks(num_rois=len(roi_ids), mask_type="voxel") + return voxel_masks + + segmentation_extractor.get_roi_pixel_masks = MethodType(get_roi_pixel_masks, segmentation_extractor) + + add_plane_segmentation( + segmentation_extractor=segmentation_extractor, + nwbfile=self.nwbfile, + metadata=self.metadata, + mask_type="voxel", + ) + + image_segmentation = self.nwbfile.processing["ophys"].get(self.image_segmentation_name) + plane_segmentations = image_segmentation.plane_segmentations + + plane_segmentation = plane_segmentations[self.plane_segmentation_name] + + true_voxel_masks = _generate_casted_test_masks(num_rois=self.num_rois, mask_type="voxel") + assert_array_equal(plane_segmentation["voxel_mask"], true_voxel_masks) + + def test_none_masks(self): + """Test the None mask_type option for writing a plane segmentation table.""" + segmentation_extractor = generate_dummy_segmentation_extractor( + num_rois=self.num_rois, + num_frames=self.num_frames, + num_rows=self.num_rows, + num_columns=self.num_columns, + ) + + add_plane_segmentation( + segmentation_extractor=segmentation_extractor, nwbfile=self.nwbfile, metadata=self.metadata, mask_type=None + ) + 
+ image_segmentation = self.nwbfile.processing["ophys"].get(self.image_segmentation_name) + plane_segmentations = image_segmentation.plane_segmentations + + plane_segmentation = plane_segmentations[self.plane_segmentation_name] + assert "image_mask" not in plane_segmentation + assert "pixel_mask" not in plane_segmentation + assert "voxel_mask" not in plane_segmentation + + def test_pixel_masks_auto_switch(self): + segmentation_extractor = generate_dummy_segmentation_extractor( + num_rois=self.num_rois, + num_frames=self.num_frames, + num_rows=self.num_rows, + num_columns=self.num_columns, + ) + + def get_roi_pixel_masks(self, roi_ids: Optional[ArrayLike] = None) -> List[np.ndarray]: + roi_ids = roi_ids or range(self.get_num_rois()) + pixel_masks = _generate_test_masks(num_rois=len(roi_ids), mask_type="pixel") + return pixel_masks + + segmentation_extractor.get_roi_pixel_masks = MethodType(get_roi_pixel_masks, segmentation_extractor) + + with self.assertWarnsRegex( + expected_warning=UserWarning, + expected_regex=( + "Specified mask_type='voxel', but ROIExtractors returned 3-dimensional masks. " + "Using mask_type='pixel' instead." 
+ ), + ): + add_plane_segmentation( + segmentation_extractor=segmentation_extractor, + nwbfile=self.nwbfile, + metadata=self.metadata, + mask_type="voxel", + ) + + image_segmentation = self.nwbfile.processing["ophys"].get(self.image_segmentation_name) + plane_segmentations = image_segmentation.plane_segmentations + + plane_segmentation = plane_segmentations[self.plane_segmentation_name] + + true_voxel_masks = _generate_casted_test_masks(num_rois=self.num_rois, mask_type="pixel") + assert_array_equal(plane_segmentation["pixel_mask"], true_voxel_masks) + + def test_voxel_masks_auto_switch(self): + segmentation_extractor = generate_dummy_segmentation_extractor( + num_rois=self.num_rois, + num_frames=self.num_frames, + num_rows=self.num_rows, + num_columns=self.num_columns, + ) + + def get_roi_pixel_masks(self, roi_ids: Optional[ArrayLike] = None) -> List[np.ndarray]: + roi_ids = roi_ids or range(self.get_num_rois()) + voxel_masks = _generate_test_masks(num_rois=len(roi_ids), mask_type="voxel") + return voxel_masks + + segmentation_extractor.get_roi_pixel_masks = MethodType(get_roi_pixel_masks, segmentation_extractor) + + with self.assertWarnsRegex( + expected_warning=UserWarning, + expected_regex=( + "Specified mask_type='pixel', but ROIExtractors returned 4-dimensional masks. " + "Using mask_type='voxel' instead." 
+ ), + ): + add_plane_segmentation( + segmentation_extractor=segmentation_extractor, + nwbfile=self.nwbfile, + metadata=self.metadata, + mask_type="pixel", + ) + + image_segmentation = self.nwbfile.processing["ophys"].get(self.image_segmentation_name) + plane_segmentations = image_segmentation.plane_segmentations + + plane_segmentation = plane_segmentations[self.plane_segmentation_name] + + true_voxel_masks = _generate_casted_test_masks(num_rois=self.num_rois, mask_type="voxel") + assert_array_equal(plane_segmentation["voxel_mask"], true_voxel_masks) + def test_not_overwriting_plane_segmentation_if_same_name(self): """Test that adding a plane segmentation with the same name will not overwrite the existing plane segmentation.""" @@ -994,7 +1164,7 @@ def test_non_iterative_write_assertion(self): mock_imaging.get_num_frames.return_value = num_frames_to_overflow reg_expression = ( - f"Memory error, full TwoPhotonSeries data is (.*?) GB are available! Please use iterator_type='v2'" + "Memory error, full TwoPhotonSeries data is (.*?) GB are available! Please use iterator_type='v2'" ) with self.assertRaisesRegex(MemoryError, reg_expression):