Skip to content

Commit

Permalink
add type hints throughout models
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Dec 21, 2024
1 parent 5a7f0e1 commit c53893c
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 133 deletions.
24 changes: 12 additions & 12 deletions waveorder/models/inplane_oriented_thick_pol3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from waveorder import correction, stokes, util


def generate_test_phantom(yx_shape):
def generate_test_phantom(yx_shape: Tuple[int, int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
star, theta, _ = util.generate_star_target(yx_shape, blur_px=0.1)
retardance = 0.25 * star
orientation = (theta % np.pi) * (star > 1e-3)
Expand All @@ -17,26 +17,26 @@ def generate_test_phantom(yx_shape):


def calculate_transfer_function(
swing,
scheme,
):
swing: float,
scheme: str,
) -> Tensor:
return stokes.calculate_intensity_to_stokes_matrix(swing, scheme=scheme)


def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
def visualize_transfer_function(viewer, intensity_to_stokes_matrix: Tensor) -> None:
viewer.add_image(
intensity_to_stokes_matrix.cpu().numpy(),
name="Intensity to stokes matrix",
)


def apply_transfer_function(
retardance,
orientation,
transmittance,
depolarization,
intensity_to_stokes_matrix,
):
retardance: Tensor,
orientation: Tensor,
transmittance: Tensor,
depolarization: Tensor,
intensity_to_stokes_matrix: Tensor,
) -> Tensor:
stokes_params = stokes.stokes_after_adr(
retardance, orientation, transmittance, depolarization
)
Expand All @@ -59,7 +59,7 @@ def apply_inverse_transfer_function(
project_stokes_to_2d: bool = False,
flip_orientation: bool = False,
rotate_orientation: bool = False,
) -> Tuple[Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Reconstructs retardance, orientation, transmittance, and depolarization
from czyx_data and an intensity_to_stokes_matrix, providing options for
background correction, projection, and orientation transformations.
Expand Down
57 changes: 33 additions & 24 deletions waveorder/models/inplane_oriented_thick_pol3d_vector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import torch
import tqdm
import numpy as np

from torch import Tensor
from typing import Literal
from typing import Literal, TYPE_CHECKING
from torch.nn.functional import avg_pool3d, interpolate
from waveorder import optics, sampling, stokes, util
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer

if TYPE_CHECKING:
import napari

def generate_test_phantom(zyx_shape):

def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
# Simulate
yx_star, yx_theta, _ = util.generate_star_target(
yx_shape=zyx_shape[1:],
Expand All @@ -29,20 +30,22 @@ def generate_test_phantom(zyx_shape):


def calculate_transfer_function(
swing,
scheme,
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_illumination,
z_padding,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
fourier_oversample_factor=1,
transverse_downsample_factor=1,
):
swing: float,
scheme: str,
zyx_shape: tuple[int, int, int],
yx_pixel_size: float,
z_pixel_size: float,
wavelength_illumination: float,
z_padding: int,
index_of_refraction_media: float,
numerical_aperture_illumination: float,
numerical_aperture_detection: float,
invert_phase_contrast: bool = False,
fourier_oversample_factor: int = 1,
transverse_downsample_factor: int = 1,
) -> tuple[
torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]:
if z_padding != 0:
raise NotImplementedError("Padding not implemented for this model")

Expand Down Expand Up @@ -126,11 +129,13 @@ def calculate_transfer_function(
U, S, Vh = calculate_singular_system(cropped)

# Interpolate to final size in YX
def complex_interpolate(tensor, zyx_shape):
def complex_interpolate(
tensor: torch.Tensor, zyx_shape: tuple[int, int, int]
) -> torch.Tensor:
interpolated_real = interpolate(tensor.real, size=zyx_shape)
interpolated_imag = interpolate(tensor.imag, size=zyx_shape)
return interpolated_real + 1j * interpolated_imag

full_cropped = complex_interpolate(cropped, zyx_shape)
full_U = complex_interpolate(U, zyx_shape)
full_S = interpolate(S[None], size=zyx_shape)[0] # S is real
Expand Down Expand Up @@ -292,7 +297,11 @@ def calculate_singular_system(sfZYX_transfer_function):
return singular_system


def visualize_transfer_function(viewer, sfZYX_transfer_function, zyx_scale):
def visualize_transfer_function(
viewer: napari.Viewer,
sfZYX_transfer_function: torch.Tensor,
zyx_scale: tuple[float, float, float],
) -> None:
add_transfer_function_to_viewer(
viewer,
sfZYX_transfer_function,
Expand All @@ -304,10 +313,10 @@ def visualize_transfer_function(viewer, sfZYX_transfer_function, zyx_scale):


def apply_transfer_function(
fzyx_object,
sfZYX_transfer_function,
intensity_to_stokes_matrix, # TODO use this to simulate intensities
):
fzyx_object: torch.Tensor,
sfZYX_transfer_function: torch.Tensor,
intensity_to_stokes_matrix: torch.Tensor, # TODO use this to simulate intensities
) -> torch.Tensor:
fZYX_object = torch.fft.fftn(fzyx_object, dim=(1, 2, 3))
sZYX_data = torch.einsum(
"fzyx,sfzyx->szyx", fZYX_object, sfZYX_transfer_function
Expand Down
50 changes: 25 additions & 25 deletions waveorder/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@


def generate_test_phantom(
zyx_shape,
yx_pixel_size,
z_pixel_size,
sphere_radius,
):
zyx_shape: tuple[int, int, int],
yx_pixel_size: float,
z_pixel_size: float,
sphere_radius: float,
) -> Tensor:
sphere, _, _ = util.generate_sphere_target(
zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
)
Expand All @@ -22,14 +22,14 @@ def generate_test_phantom(


def calculate_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_emission,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
):
zyx_shape: tuple[int, int, int],
yx_pixel_size: float,
z_pixel_size: float,
wavelength_emission: float,
z_padding: int,
index_of_refraction_media: float,
numerical_aperture_detection: float,
) -> Tensor:

transverse_nyquist = sampling.transverse_nyquist(
wavelength_emission,
Expand Down Expand Up @@ -65,14 +65,14 @@ def calculate_transfer_function(


def _calculate_wrap_unsafe_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_emission,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
):
zyx_shape: tuple[int, int, int],
yx_pixel_size: float,
z_pixel_size: float,
wavelength_emission: float,
z_padding: int,
index_of_refraction_media: float,
numerical_aperture_detection: float,
) -> Tensor:
radial_frequencies = util.generate_radial_frequencies(
zyx_shape[1:], yx_pixel_size
)
Expand Down Expand Up @@ -108,7 +108,7 @@ def _calculate_wrap_unsafe_transfer_function(
return optical_transfer_function


def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
def visualize_transfer_function(viewer, optical_transfer_function: Tensor, zyx_scale: tuple[float, float, float]) -> None:
add_transfer_function_to_viewer(
viewer,
torch.real(optical_transfer_function),
Expand All @@ -118,8 +118,8 @@ def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):


def apply_transfer_function(
zyx_object, optical_transfer_function, z_padding, background=10
):
zyx_object: Tensor, optical_transfer_function: Tensor, z_padding: int, background: int = 10
) -> Tensor:
"""Simulate imaging by applying a transfer function
Parameters
Expand Down Expand Up @@ -164,7 +164,7 @@ def apply_inverse_transfer_function(
regularization_strength: float = 1e-3,
TV_rho_strength: float = 1e-3,
TV_iterations: int = 10,
):
) -> Tensor:
"""Reconstructs fluorescence density from zyx_data and
an optical_transfer_function, providing options for z padding and
reconstruction algorithms.
Expand Down
75 changes: 37 additions & 38 deletions waveorder/models/isotropic_thin_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

from waveorder import optics, sampling, util


def generate_test_phantom(
yx_shape,
yx_pixel_size,
wavelength_illumination,
index_of_refraction_media,
index_of_refraction_sample,
sphere_radius,
):
yx_shape: Tuple[int, int],
yx_pixel_size: float,
wavelength_illumination: float,
index_of_refraction_media: float,
index_of_refraction_sample: float,
sphere_radius: float,
) -> Tuple[Tensor, Tensor]:
sphere, _, _ = util.generate_sphere_target(
(3,) + yx_shape,
yx_pixel_size,
Expand All @@ -35,15 +34,15 @@ def generate_test_phantom(


def calculate_transfer_function(
yx_shape,
yx_pixel_size,
z_position_list,
wavelength_illumination,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
yx_shape: Tuple[int, int],
yx_pixel_size: float,
z_position_list: list,
wavelength_illumination: float,
index_of_refraction_media: float,
numerical_aperture_illumination: float,
numerical_aperture_detection: float,
invert_phase_contrast: bool = False,
) -> Tuple[Tensor, Tensor]:
transverse_nyquist = sampling.transverse_nyquist(
wavelength_illumination,
numerical_aperture_illumination,
Expand Down Expand Up @@ -93,15 +92,15 @@ def calculate_transfer_function(


def _calculate_wrap_unsafe_transfer_function(
yx_shape,
yx_pixel_size,
z_position_list,
wavelength_illumination,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
yx_shape: Tuple[int, int],
yx_pixel_size: float,
z_position_list: list,
wavelength_illumination: float,
index_of_refraction_media: float,
numerical_aperture_illumination: float,
numerical_aperture_detection: float,
invert_phase_contrast: bool = False,
) -> Tuple[Tensor, Tensor]:
if invert_phase_contrast:
z_position_list = torch.flip(torch.tensor(z_position_list), dims=(0,))

Expand Down Expand Up @@ -149,9 +148,9 @@ def _calculate_wrap_unsafe_transfer_function(

def visualize_transfer_function(
viewer,
absorption_2d_to_3d_transfer_function,
phase_2d_to_3d_transfer_function,
):
absorption_2d_to_3d_transfer_function: Tensor,
phase_2d_to_3d_transfer_function: Tensor,
) -> None:
"""Note: unlike other `visualize_transfer_function` calls, this transfer
function is a mixed 3D-to-2D transfer function, so it cannot reuse
util.add_transfer_function_to_viewer. If more 3D-to-2D transfer functions
Expand All @@ -178,9 +177,9 @@ def visualize_transfer_function(

def visualize_point_spread_function(
viewer,
absorption_2d_to_3d_transfer_function,
phase_2d_to_3d_transfer_function,
):
absorption_2d_to_3d_transfer_function: Tensor,
phase_2d_to_3d_transfer_function: Tensor,
) -> None:
arrays = [
(torch.fft.ifftn(absorption_2d_to_3d_transfer_function), "absorb PSF"),
(torch.fft.ifftn(phase_2d_to_3d_transfer_function), "phase PSF"),
Expand All @@ -199,11 +198,11 @@ def visualize_point_spread_function(


def apply_transfer_function(
yx_absorption,
yx_phase,
phase_2d_to_3d_transfer_function,
absorption_2d_to_3d_transfer_function,
):
yx_absorption: Tensor,
yx_phase: Tensor,
phase_2d_to_3d_transfer_function: Tensor,
absorption_2d_to_3d_transfer_function: Tensor,
) -> Tensor:
# Very simple simulation, consider adding noise and bkg knobs

# simulate absorbing object
Expand Down Expand Up @@ -240,7 +239,7 @@ def apply_inverse_transfer_function(
TV_rho_strength: float = 1e-3,
TV_iterations: int = 10,
bg_filter: bool = True,
) -> Tuple[Tensor]:
) -> Tuple[Tensor, Tensor]:
"""Reconstructs absorption and phase from zyx_data and a pair of
3D-to-2D transfer functions named absorption_2d_to_3d_transfer_function and
phase_2d_to_3d_transfer_function, providing options for reconstruction
Expand Down
Loading

0 comments on commit c53893c

Please sign in to comment.