Skip to content

Commit

Permalink
add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Dec 2, 2024
1 parent 96313fa commit 6e1818b
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions viscy/data/gpu_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@


class GPUTransformDataModule(ABC, LightningDataModule):
"""Abstract data module with GPU transforms."""

train_dataset: Dataset
val_dataset: Dataset
batch_size: int
Expand Down Expand Up @@ -92,6 +94,26 @@ def val_gpu_transforms(self) -> Compose: ...


class CachedOmeZarrDataset(Dataset):
"""Dataset for cached OME-Zarr arrays.
Parameters
----------
positions : list[Position]
List of FOVs to load images from.
channel_names : list[str]
List of channel names to load.
cache_map : DictProxy
Shared dictionary for caching loaded volumes.
transform : Compose | None, optional
Composed transforms to be applied on the CPU, by default None
array_key : str, optional
The image array key name (multi-scale level), by default "0"
load_normalization_metadata : bool, optional
Load normalization metadata in the sample dictionary, by default True
skip_cache : bool, optional
Skip caching to save RAM, by default False
"""

def __init__(
self,
positions: list[Position],
Expand Down Expand Up @@ -148,6 +170,35 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:


class CachedOmeZarrDataModule(GPUTransformDataModule):
"""Data module for cached OME-Zarr arrays.
Parameters
----------
data_path : Path
Path to the HCS OME-Zarr dataset.
channels : str | list[str]
Channel names to load.
batch_size : int
Batch size for training and validation.
num_workers : int
Number of workers for data-loaders.
split_ratio : float
Fraction of the FOVs used for the training split.
The rest will be used for validation.
train_cpu_transforms : list[DictTransform]
Transforms to be applied on the CPU during training.
val_cpu_transforms : list[DictTransform]
Transforms to be applied on the CPU during validation.
train_gpu_transforms : list[DictTransform]
Transforms to be applied on the GPU during training.
val_gpu_transforms : list[DictTransform]
Transforms to be applied on the GPU during validation.
pin_memory : bool, optional
Use page-locked memory in data-loaders, by default True
skip_cache : bool, optional
Skip caching for this dataset, by default False
"""

def __init__(
self,
data_path: Path,
Expand Down

0 comments on commit 6e1818b

Please sign in to comment.