Skip to content

Commit

Permalink
matplotlib visuals docs and type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Dec 21, 2024
1 parent 9cdc21d commit 5a7f0e1
Showing 1 changed file with 92 additions and 31 deletions.
123 changes: 92 additions & 31 deletions waveorder/visuals/matplotlib_visuals.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,70 @@
import matplotlib.pyplot as plt

from waveorder.visuals.utils import complex_tensor_to_rgb
from waveorder.sampling import nd_fourier_central_cuboid

import numpy as np
import os
import torch


def plot_5d_ortho(
rcCzyx_data,
filename,
voxel_size,
zyx_slice,
color_funcs,
row_labels=None,
column_labels=None,
rose_path=None,
inches_per_column=1.5,
label_size=1,
ortho_line_width=0.5,
row_column_line_width=0.5,
xyz_labels=True,
background_color="white",
**kwargs,
):
rcCzyx_data: np.ndarray,
filename: str,
voxel_size: tuple[float, float, float],
zyx_slice: tuple[int, int, int],
color_funcs: list[list[callable]],
row_labels: list[str] = None,
column_labels: list[str] = None,
rose_path: str = None,
inches_per_column: float = 1.5,
label_size: int = 1,
ortho_line_width: float = 0.5,
row_column_line_width: float = 0.5,
xyz_labels: bool = True,
background_color: str = "white",
**kwargs: dict,
) -> None:
"""
Plot 5D multi-channel data in a grid or ortho-slice views.
Input data is a 6D array with (row, column, channels, Z, Y, X) dimensions.
`color_funcs` permits different RGB color maps for each row and column.
Parameters
----------
rcCzyx_data : numpy.ndarray
5D array with shape (R, C, Ch, Z, Y, X) containing the data to plot.
[r]ows and [c]olumns form a grid
[C]hannels contain multiple color channels
[ZYX] contain 3D volumes.
filename : str
Path to save the output plot.
voxel_size : tuple[float, float, float]
Size of each voxel in (Z, Y, X) dimensions.
zyx_slice : tuple[int, int, int]
Indices of the ortho-slices to plot in (Z, Y, X) indices.
color_funcs : list[list[callable]]
A list of lists of callables, one for each element of the plot grid,
with len(color_funcs) == R and len(colors_funcs[0] == C).
Each callable accepts [C]hannel arguments and returns RGB color values,
enabling different RGB color maps for each member of the grid.
row_labels : list[str], optional
Labels for the rows, by default None.
column_labels : list[str], optional
Labels for the columns, by default None.
rose_path : str, optional
Path to an image to display in the top-left corner, by default None.
inches_per_column : float, optional
Width of each column in inches, by default 1.5.
label_size : int, optional
Size of the labels, by default 1.
ortho_line_width : float, optional
Width of the orthogonal lines, by default 0.5.
row_column_line_width : float, optional
Width of the lines between rows and columns, by default 0.5.
xyz_labels : bool, optional
Whether to display XYZ labels, by default True.
background_color : str, optional
Background color of the plot, by default "white".
**kwargs : dict
Additional keyword arguments passed to color_funcs.
"""
R, C, Ch, Z, Y, X = rcCzyx_data.shape

# Extent
Expand Down Expand Up @@ -199,16 +239,37 @@ def plot_5d_ortho(


def add_ortho_lines_to_axis(
axis,
yx_slice,
axis_labels,
outer_color,
vertical_color,
horizontal_color,
line_width=0,
text_color="white",
):
axis: plt.Axes,
yx_slice: tuple[int, int],
axis_labels: tuple[str, str],
outer_color: str,
vertical_color: str,
horizontal_color: str,
line_width: float = 0,
text_color: str = "white",
) -> None:
"""
Add orthogonal lines and labels to a given axis.
Parameters
----------
axis : matplotlib.axes.Axes
The axis to which the orthogonal lines and labels will be added.
yx_slice : tuple[int, int]
The (Y, X) slice indices for the orthogonal lines.
axis_labels : tuple[str, str]
The labels for the Y and X axes.
outer_color : str
The color of the outer rectangle.
vertical_color : str
The color of the vertical line.
horizontal_color : str
The color of the horizontal line.
line_width : float, optional
The width of the lines, by default 0.
text_color : str, optional
The color of the text labels, by default "white".
"""
xmin, xmax = axis.get_xlim()
ymin, ymax = axis.get_ylim()

Expand Down

0 comments on commit 5a7f0e1

Please sign in to comment.