Skip to content

Commit

Permalink
Merge pull request #200 from catalystneuro/segmentation_trace_orienta…
Browse files Browse the repository at this point in the history
…tion_fix

Fix segmentation orientation
  • Loading branch information
h-mayorquin authored Aug 30, 2022
2 parents ae62792 + 6f241df commit b71f17d
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 46 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Upcoming

### Back-compatability break
* The orientation of traces in all `SegmentationExtractor`s has been standardized to have time (frames) as the first axis, and ROIs as the final axis. [PR #200](https://github.com/catalystneuro/roiextractors/pull/200)

### Features
* Add support for newer versions of EXTRACT output files.
* Add support for newer versions of EXTRACT output files. [PR #170](https://github.com/catalystneuro/roiextractors/pull/170)
The `ExtractSegmentationExtractor` class is now abstract and redirects to the newer or older
extractor depending on the version of the file. [PR #170](https://github.com/catalystneuro/roiextractors/pull/170)
* The `ExtractSegmentationExtractor.write_segmentation` method has now been deprecated. [PR #170](https://github.com/catalystneuro/roiextractors/pull/170)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np

from ...extraction_tools import PathType
from ...extraction_tools import PathType, get_package
from ...multisegmentationextractor import MultiSegmentationExtractor
from ...segmentationextractor import SegmentationExtractor

Expand Down Expand Up @@ -70,8 +70,10 @@ def _image_mask_sparse_read(self):
return image_masks

def _trace_extractor_read(self, field):
if self._dataset_file["estimates"].get(field):
return self._dataset_file["estimates"][field] # lazy read dataset)
lazy_ops = get_package(package_name="lazy_ops")

if field in self._dataset_file["estimates"]:
return lazy_ops.DatasetView(self._dataset_file["estimates"][field]).lazy_transpose()

def _summary_image_read(self):
if self._dataset_file["estimates"].get("Cn"):
Expand Down
16 changes: 8 additions & 8 deletions src/roiextractors/extractors/numpyextractors/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,31 +233,31 @@ def __init__(
self.is_dumpable = False
self._image_masks = image_masks
self._roi_response_raw = raw
assert self._image_masks.shape[2] == len(self._roi_response_raw), (
assert self._image_masks.shape[-1] == self._roi_response_raw.shape[-1], (
"Inconsistency between image masks and raw traces. "
"Image masks must be (px, py, num_rois), "
"traces must be (num_rois, num_frames)"
"traces must be (num_frames, num_rois)"
)
self._roi_response_dff = dff
if self._roi_response_dff is not None:
assert self._image_masks.shape[2] == len(self._roi_response_dff), (
assert self._image_masks.shape[-1] == self._roi_response_dff.shape[-1], (
"Inconsistency between image masks and raw traces. "
"Image masks must be (px, py, num_rois), "
"traces must be (num_rois, num_frames)"
"traces must be (num_frames, num_rois)"
)
self._roi_response_neuropil = neuropil
if self._roi_response_neuropil is not None:
assert self._image_masks.shape[2] == len(self._roi_response_neuropil), (
assert self._image_masks.shape[-1] == self._roi_response_neuropil.shape[-1], (
"Inconsistency between image masks and raw traces. "
"Image masks must be (px, py, num_rois), "
"traces must be (num_rois, num_frames)"
"traces must be (num_frames, num_rois)"
)
self._roi_response_deconvolved = deconvolved
if self._roi_response_deconvolved is not None:
assert self._image_masks.shape[2] == len(self._roi_response_deconvolved), (
assert self._image_masks.shape[-1] == self._roi_response_deconvolved.shape[-1], (
"Inconsistency between image masks and raw traces. "
"Image masks must be (px, py, num_rois), "
"traces must be (num_rois, num_frames)"
"traces must be (num_frames, num_rois)"
)
self._kwargs = {
"image_masks": image_masks,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def _image_mask_extractor_read(self):
return DatasetView(self._dataset_file[self._group0[0]]["extractedImages"]).lazy_transpose([1, 2, 0])

def _trace_extractor_read(self):
extracted_signals = DatasetView(self._dataset_file[self._group0[0]]["extractedSignals"])
return extracted_signals.T
return self._dataset_file[self._group0[0]]["extractedSignals"]

def _tot_exptime_extractor_read(self):
return self._dataset_file[self._group0[0]]["time"]["totalTime"][0][0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _image_mask_extractor_read(self) -> DatasetView:

def _trace_extractor_read(self) -> DatasetView:
"""Returns the traces with a shape of number of ROIs and number of frames."""
return DatasetView(self._output_struct["temporal_weights"])
return DatasetView(self._output_struct["temporal_weights"]).lazy_transpose()

def get_accepted_list(self) -> list:
"""
Expand Down Expand Up @@ -276,7 +276,7 @@ def __init__(self, file_path: PathType):
self._image_masks = self._image_mask_extractor_read()
self._roi_response_raw = self._trace_extractor_read()
self._raw_movie_file_location = self._raw_datafile_read()
self._sampling_frequency = self._roi_response_raw.shape[1] / self._tot_exptime_extractor_read()
self._sampling_frequency = self._roi_response_raw.shape[0] / self._tot_exptime_extractor_read()
self._image_correlation = self._summary_image_read()

def __del__(self):
Expand All @@ -292,8 +292,7 @@ def _image_mask_extractor_read(self):
return self._dataset_file[self._group0[0]]["filters"][:].transpose([1, 2, 0])

def _trace_extractor_read(self):
extracted_signals = DatasetView(self._dataset_file[self._group0[0]]["traces"])
return extracted_signals.T
return self._dataset_file[self._group0[0]]["traces"]

def _tot_exptime_extractor_read(self):
return self._dataset_file[self._group0[0]]["time"]["totalTime"][0][0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def __init__(
self.folder_path = Path(folder_path)

self.stat = self._load_npy("stat.npy")
self._roi_response_raw = self._load_npy("F.npy", mmap_mode="r")
self._roi_response_neuropil = self._load_npy("Fneu.npy", mmap_mode="r")
self._roi_response_deconvolved = self._load_npy("spks.npy", mmap_mode="r")
self._roi_response_raw = self._load_npy("F.npy", mmap_mode="r").T
self._roi_response_neuropil = self._load_npy("Fneu.npy", mmap_mode="r").T
self._roi_response_deconvolved = self._load_npy("spks.npy", mmap_mode="r").T
self.iscell = self._load_npy("iscell.npy", mmap_mode="r")
self.ops = self._load_npy("ops.npy").item()

Expand Down
12 changes: 6 additions & 6 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def get_num_frames(self) -> int:
"""
for trace in self.get_traces_dict().values():
if trace is not None and len(trace.shape) > 0:
return trace.shape[1]
return trace.shape[0]

def get_roi_locations(self, roi_ids=None) -> np.array:
def get_roi_locations(self, roi_ids=None) -> np.ndarray:
"""
Returns the locations of the Regions of Interest
Expand Down Expand Up @@ -103,7 +103,7 @@ def get_roi_ids(self) -> list:
"""
return list(range(self.get_num_rois()))

def get_roi_image_masks(self, roi_ids=None) -> np.array:
def get_roi_image_masks(self, roi_ids=None) -> np.ndarray:
"""Returns the image masks extracted from segmentation algorithm.
Parameters
Expand Down Expand Up @@ -173,10 +173,10 @@ def get_traces(self, roi_ids=None, start_frame=None, end_frame=None, name="raw")
roi_idx_ = range(self.get_num_rois())
else:
all_ids = self.get_roi_ids()
roi_idx_ = [all_ids.index(i) for i in roi_ids]
roi_idx_ = [int(all_ids.index(i)) for i in roi_ids]
traces = self.get_traces_dict().get(name)
if traces is not None and len(traces.shape) != 0:
return np.array([traces[int(i), start_frame:end_frame] for i in roi_idx_])
return np.array([traces[start_frame:end_frame, idx] for idx in roi_idx_])

def get_traces_dict(self):
"""
Expand Down Expand Up @@ -243,7 +243,7 @@ def get_num_rois(self):
"""
for trace in self.get_traces_dict().values():
if trace is not None and len(trace.shape) > 0:
return trace.shape[0]
return trace.shape[1]

def get_channel_names(self):
"""
Expand Down
8 changes: 4 additions & 4 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def generate_dummy_segmentation_extractor(
movie_dims = (num_rows, num_columns)

# Create signals
raw = np.random.rand(num_rois, num_frames) if has_raw_signal else None
dff = np.random.rand(num_rois, num_frames) if has_dff_signal else None
deconvolved = np.random.rand(num_rois, num_frames) if has_deconvolved_signal else None
neuropil = np.random.rand(num_rois, num_frames) if has_neuropil_signal else None
raw = np.random.rand(num_frames, num_rois) if has_raw_signal else None
dff = np.random.rand(num_frames, num_rois) if has_dff_signal else None
deconvolved = np.random.rand(num_frames, num_rois) if has_deconvolved_signal else None
neuropil = np.random.rand(num_frames, num_rois) if has_neuropil_signal else None

# Summary images
mean_image = np.random.rand(num_rows, num_columns) if has_summary_images else None
Expand Down
12 changes: 4 additions & 8 deletions tests/test_extractsegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def test_extractor_data_validity(self):
).lazy_transpose()
self.assertEqual(self.extractor._image_masks.shape, spatial_weights.shape)

temporal_weights = DatasetView(segmentation_file[self.output_struct_name]["temporal_weights"])
self.assertEqual(self.extractor._roi_response_dff.shape, temporal_weights.shape)
self.assertEqual(self.extractor._roi_response_dff.shape, (2000, 20))

self.assertEqual(self.extractor._roi_response_raw, None)

Expand All @@ -131,14 +130,11 @@ def test_extractor_data_validity(self):

assert_array_equal(self.extractor.get_image_size(), [50, 50])

num_rois = temporal_weights.shape[0]
self.assertEqual(self.extractor.get_num_rois(), num_rois)

num_frames = temporal_weights.shape[1]
self.assertEqual(self.extractor.get_num_frames(), num_frames)
self.assertEqual(self.extractor.get_num_rois(), 20)
self.assertEqual(self.extractor.get_num_frames(), 2000)

self.assertEqual(self.extractor.get_rejected_list(), [])
self.assertEqual(self.extractor.get_accepted_list(), list(range(num_rois)))
self.assertEqual(self.extractor.get_accepted_list(), list(range(20)))

def test_extractor_config(self):
"""Test that the extractor class returns the expected config."""
Expand Down
8 changes: 1 addition & 7 deletions tests/test_internals/test_testing_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
from hdmf.testing import TestCase
from numpy.testing import assert_array_equal

import numpy as np
from numpy.testing import assert_array_equal

from roiextractors.testing import (
generate_dummy_segmentation_extractor,
_assert_iterable_complete,
)
from roiextractors.testing import generate_dummy_segmentation_extractor, _assert_iterable_complete


class TestDummySegmentationExtractor(TestCase):
Expand Down

0 comments on commit b71f17d

Please sign in to comment.