Skip to content

Commit

Permalink
Whole brain affs (#24)
Browse files Browse the repository at this point in the history
* Merge pull request #1 from flatironinstitute/main (#19)

* Add files via upload

* boundary augmentation

* commented reneu

* commented reneu

* added

* stuff

* made change

* tensorboard

* here

* changes

* changes

* change

* yes

* changes

* update the samples

* loaded image chunks, but the image dimension is not correctly handled

* fixed a bug of patch cutout

* add cleanup script to clean up the files produced during training

* affinity map training is still not working. the random_patch function is not calling the inherited transform function

* fix some array size

* changes to ba file

* print more info

* save the status, not working yet. The patch location generators are done.

* initial

* rename patch_location_generator to patch_bounding_box_generator

* documentation

* inherit from trainerbase to fix a bug where the target was all zero.

* whole brain volume sampler

---------

Co-authored-by: Manuel Paez <[email protected]>
Co-authored-by: Manuel Paez <[email protected]>
Co-authored-by: mannypaeza <[email protected]>
  • Loading branch information
4 people authored Apr 27, 2023
1 parent 2b0561d commit 442dc65
Show file tree
Hide file tree
Showing 13 changed files with 694 additions and 244 deletions.
Binary file added examples/.boundary_augmentation.yaml.swp
Binary file not shown.
56 changes: 21 additions & 35 deletions examples/affs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,27 @@ system:
seed: 1

dataset:
training:
s3vol01700:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_01700/img.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_01700/label_v3.h5"
# s3vol02299:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02299/img.h5",]
# label: "~/dropbox/40_gt/13_wasp_sample3/vol_02299/label_v3.h5"
s3vol02400:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02400/img_zyx_2400-2656_5700-5956_2770-3026.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_02400/label_v1.h5"
#s3vol02794:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02794/img_zyx_2794-3050_5811-6067_8757-9013.h5",]
# label: "~/dropbox/40_gt/13_wasp_sample3/vol_02794/seg_v1_cropped.h5"
s3vol03290:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03290/img_zyx_3290-3546_2375-2631_8450-8706.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03290/label_v1.h5"
s3vol03700:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03700/img_zyx_3700-3956_5000-5256_4250-4506.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03700/label_v3.h5"
s3vol03998:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03998/img_zyx_3998-4254_4280-4536_4035-4291.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_03998/label_v2.h5"
s3vol04900:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04900/img.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_04900/label_v1.h5"
s3vol05250:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05250/img_zyx_5250-5506_4600-4856_5500-5756.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_05250/label_v3_remove_contact.h5"
s3vol05450:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05450/img_zyx_5450-5706_5350-5606_7000-7256.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_05450/label_v4_chiyip.h5"
validation:
s3vol04000:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04000/img_zyx_4000-4256_3400-3656_8150-8406.h5",]
label: "~/dropbox/40_gt/13_wasp_sample3/vol_04000/label_v3.h5"
sample3:
images: [
"precomputed://file:///mnt/ceph/users/neuro/wasp_em/jwu/sample3/05_yuri_v3",
"precomputed://file:///mnt/ceph/users/neuro/wasp_em/jwu/sample3/07_yuri_v5",
"precomputed://file:///mnt/ceph/users/neuro/wasp_em/jwu/sample3/04_clahe"
]
mask: "precomputed://file:///mnt/ceph/users/neuro/wasp_em/jwu/sample3/20_segmentation_mask"
labels:
training: [
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_01700/label_v3.h5",
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_02400/label_v1.h5",
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_03290/label_v1.h5",
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_03700/label_v3.h5",
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_03998/label_v2.h5",
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_04900/label_v1.h5",
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_05250/label_v3_remove_contact.h5",
]
validation: [
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_04000/label_v3.h5",
"/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_05450/label_v4_chiyip.h5",
]
model:
in_channels: 1
out_channels: 3
Expand Down
43 changes: 24 additions & 19 deletions examples/boundary_augmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,32 @@ system:
dataset:
training:
s3vol01700:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_01700/affs_160k.h5",]
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_01700/affs_160k.h5",]
s3vol02299:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02299/affs_03_160k.h5",]
#s3vol02794:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02794/affs_03_160k.h5",]
#s3vol03290:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03290/affs_03_160k.h5",]
#s3vol03700:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_03700/affs_03_160k.h5",]
#s3vol04900:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04900/affs_160k.h5",]
#s3vol05250:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05250/affs_03_160k.h5",]
#s3vol05450:
# images: ["~/dropbox/40_gt/13_wasp_sample3/vol_05450/affs_160k.h5",]
validation:
s3vol02400:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_02400/affs_03_160k.h5",]
s3vol04000:
images: ["~/dropbox/40_gt/13_wasp_sample3/vol_04000/affs_03_160k.h5",]
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_02299/affs_03_160k.h5",]
s3vol02794:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_02794/affs_03_160k.h5",]
s3vol03290:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_03290/affs_03_160k.h5",]
s3vol03700:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_03700/affs_03_160k.h5",]
s3vol04900:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_04900/affs_160k.h5",]
s3vol05250:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_05250/affs_03_160k.h5",]
s3vol05450:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_05450/affs_160k.h5",]

validation:
s3test1:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/58_broken_membrane/31_test_3072-3584_5120-5632_8196-8708/aff_zyx_3072-3584_5120-5632_8196-8708.h5",]
s3test2:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/58_broken_membrane/32_test_5120-5632_5632-6144_10240-10752/aff_zyx_5120-5632_5632-6144_10240-10752.h5",]
s3test3:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/58_broken_membrane/33_test_2560-3072_5632-6144_8704-9216/aff_zyx_2560-3072_5632-6144_8704-9216.h5",]
s3test4:
images: ["/mnt/ceph/users/neuro/wasp_em/jwu/58_broken_membrane/41_test_2560-3584_5120-6144_8192-9216/aff_zyx_2560-3584_5120-6144_8192-9216.h5",]

model:
in_channels: 3
out_channels: 3
Expand Down
27 changes: 23 additions & 4 deletions neutorch/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from chunkflow.lib.cartesian_coordinate import Cartesian
from yacs.config import CfgNode

from neutorch.data.sample import SemanticSample, SelfSupervisedSample, AffinityMapSample
from neutorch.data.sample import *
from neutorch.data.transform import *

DEFAULT_PATCH_SIZE = Cartesian(128, 128, 128)
Expand Down Expand Up @@ -101,9 +101,10 @@ def random_patch(self):
return patch.image, patch.label

def __next__(self):
image, label = self.random_patch
image = to_tensor(image)
label = to_tensor(label)
image_chunk, label_chunk = self.random_patch
image = to_tensor(image_chunk.array)
label = to_tensor(label_chunk.array)

return image, label

def __iter__(self):
Expand Down Expand Up @@ -224,6 +225,24 @@ def __next__(self):

return image, target

class AffinityMapVolumeWithMask(DatasetBase):
    """Dataset of affinity-map volume samples that may carry a mask."""

    def __init__(self, samples: list):
        super().__init__(samples)

    @classmethod
    def from_config(cls, cfg: CfgNode, **kwargs):
        """Build the dataset from a YACS config node.

        Args:
            cfg (CfgNode): configuration providing ``train.patch_size`` and a
                ``samples`` mapping; each sample entry names its sample class
                in a ``type`` field.

        Returns:
            AffinityMapVolumeWithMask: dataset wrapping one sample per config
                entry.
        """
        output_patch_size = Cartesian.from_collection(
            cfg.train.patch_size)

        samples = []
        for sample_name in cfg.samples:
            sample_cfg = cfg.samples[sample_name]
            # Resolve the sample class by name from this module's namespace
            # (populated by the star import of neutorch.data.sample) instead
            # of eval(): the config is data, and eval() would execute
            # arbitrary expressions embedded in it.
            sample_class = globals()[sample_cfg.type]
            sample = sample_class.from_config(
                sample_cfg, output_patch_size)
            samples.append(sample)
        return cls(samples)

class AffinityMapDataset(DatasetBase):
def __init__(self, samples: list):
super().__init__(samples)
Expand Down
58 changes: 28 additions & 30 deletions neutorch/data/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,36 @@

# torch.multiprocessing.set_start_method('spawn')

# from chunkflow.lib.cartesian_coordinate import Cartesian
from chunkflow.lib.cartesian_coordinate import Cartesian
from chunkflow.chunk import Chunk


class Patch(object):
def __init__(self, image: np.ndarray, label: np.ndarray):
def __init__(self, image: Chunk, label: Chunk,
mask: Chunk = None):
"""A patch of volume containing both image and label
Args:
image (np.ndarray): image
label (np.ndarray): label
image (Chunk): image
label (Chunk): label
"""
assert image.shape == label.shape

image = self._expand_to_5d(image)
label = self._expand_to_5d(label)

assert image.voxel_offset == label.voxel_offset
if mask is not None:
mask.shape == label.shape
assert mask.ndim == 3
assert mask.voxel_offset == image.voxel_offset

image.array = self._expand_to_5d(image.array)
label.array = self._expand_to_5d(label.array)

self.image = image
self.label = label
self.mask = mask

@cached_property
def has_mask(self):
    """Return True when an optional mask chunk was supplied to this patch."""
    return self.mask is not None

def _expand_to_5d(self, arr: np.ndarray):
if arr.ndim == 4:
Expand All @@ -37,29 +49,18 @@ def _expand_to_5d(self, arr: np.ndarray):
return arr

def shrink(self, size: tuple):
assert len(size) == 6
_, _, z, y, x = self.shape
self.image = self.image[
...,
size[0]:z-size[3],
size[1]:y-size[4],
size[2]:x-size[5],
]
self.label = self.label[
...,
size[0]:z-size[3],
size[1]:y-size[4],
size[2]:x-size[5],
]


self.image.shrink(size)
self.label.shrink(size)
if self.has_mask:
self.mask.shrink(size)

@property
def shape(self):
return self.image.shape

@cached_property
def center(self):
return tuple(ps // 2 for ps in self.shape[-3:])
return Cartesian.from_collection(self.shape[-3:]) // 2

def normalize(self):
def _normalize(arr):
Expand All @@ -70,14 +71,11 @@ def _normalize(arr):
arr = arr.type(torch.float32)
arr /= 255.
return arr
self.image = _normalize(self.image)
self.label = _normalize(self.label)
self.image.array = _normalize(self.image.array)
self.label.array = _normalize(self.label.array)

def collate_batch(batch):
    """Collate a batch of Patch objects for a DataLoader.

    NOTE(review): this returns ``patch`` — the *last* element iterated — and
    discards the accumulated ``patch_list``. Presumably the batch size is
    always 1 here; confirm with the DataLoader configuration, or return
    ``patch_list`` instead if multi-patch batches are expected.
    """

    patch_list = []

    for patch in batch:
        patch_list.append(patch)

    return patch
Loading

0 comments on commit 442dc65

Please sign in to comment.