Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PHATE #210

Merged
merged 45 commits into from
Dec 21, 2024
Merged

PHATE #210

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
db80819
translation: fix validation loss aggregation (#202)
ziw-liu Nov 8, 2024
820c805
exposing prefetch and persistent worker (#203)
edyoshikun Nov 13, 2024
a73b9a0
metrics for dynamic, smoothness and docstrings
alishbaimran Nov 19, 2024
79db1e4
updated metrics and plots for distance
alishbaimran Nov 20, 2024
42cf756
fixed CI test cases
alishbaimran Nov 20, 2024
0a23fc4
nexnt loss prototype
edyoshikun Nov 13, 2024
f2d75cd
fix bug with z_scale_range in hcs datamodule. If the value is an int …
edyoshikun Nov 13, 2024
aa74eff
exclude the negative pair from dataloader and forward pass
edyoshikun Nov 13, 2024
771d33d
adding option using pytorch-metric-learning implementation
edyoshikun Nov 14, 2024
46974bc
removing our implementation of NTXentLoss and using pytorch metric
edyoshikun Nov 14, 2024
1fde4ab
ruff
edyoshikun Nov 14, 2024
3f8363d
prototype for phate and umap plot
edyoshikun Nov 20, 2024
26ebe74
- proofreading the calculations
edyoshikun Nov 20, 2024
7a3b036
methods to rank nearest neighbors in embeddings
ziw-liu Nov 21, 2024
3992d9d
example script to plot state change of a single track
ziw-liu Nov 21, 2024
cae8290
test using scaled features
ziw-liu Nov 21, 2024
30f4246
phate embeddings
edyoshikun Nov 21, 2024
356c15d
removing dataframe from the compute_phate
edyoshikun Nov 21, 2024
57abb19
adding phate to the prediction writer and moving it as dependency.
edyoshikun Nov 21, 2024
19fdcf2
changing the phate defaults in the prediction writer.
edyoshikun Nov 21, 2024
1ee8af1
ruff
edyoshikun Nov 21, 2024
14b7368
fixing bug in phate in predict writer
edyoshikun Nov 21, 2024
db653a5
adding code for measuring the smoothness
edyoshikun Nov 22, 2024
cf6dbe7
cleanup to run on triplet and ntxent
edyoshikun Nov 22, 2024
a98e882
fix plots for smoothnes
edyoshikun Nov 22, 2024
b75ddd1
nexnt loss prototype
edyoshikun Nov 13, 2024
98b36de
exclude the negative pair from dataloader and forward pass
edyoshikun Nov 13, 2024
5bc2ff6
adding option using pytorch-metric-learning implementation
edyoshikun Nov 14, 2024
472be4e
removing our implementation of NTXentLoss and using pytorch metric
edyoshikun Nov 14, 2024
66b3899
ruff
edyoshikun Nov 14, 2024
9cc35cf
remove blank line diff
ziw-liu Dec 2, 2024
3c96f13
remove blank line diff
ziw-liu Dec 2, 2024
d3d4715
simplying the engine
edyoshikun Dec 10, 2024
70dfb5d
explicit target shape argument in the HCS data module
ziw-liu Nov 21, 2024
a845847
Revert "explicit target shape argument in the HCS data module"
ziw-liu Nov 21, 2024
a6e2318
Explicit target shape argument in the HCS data module (#212)
ziw-liu Nov 27, 2024
fa74b0d
Gradio example (#158)
edyoshikun Dec 2, 2024
076bb5c
Merge branch 'ntxent_loss' into phate
edyoshikun Dec 19, 2024
ea79aad
adding configurable phate arguments via config
edyoshikun Dec 19, 2024
47d84b7
script to recompute phate and overwrite the previous phate data
edyoshikun Dec 19, 2024
2546cc9
ruff
edyoshikun Dec 19, 2024
f40f316
solving redundancies
edyoshikun Dec 20, 2024
745911a
modularizing the smoothness
edyoshikun Dec 21, 2024
965abbf
removing redundant _fit_phate()
edyoshikun Dec 21, 2024
3f01a96
ruff
edyoshikun Dec 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 312 additions & 0 deletions applications/contrastive_phenotyping/evaluation/ALFI_displacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
# %%
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.representation.evaluation.distance import (
calculate_normalized_euclidean_distance_cell,
compute_displacement,
compute_dynamic_range,
compute_rms_per_track,
)
from collections import defaultdict
from tabulate import tabulate

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from collections import OrderedDict

# %% function

# Removed redundant compute_displacement_mean_std_full function
# Removed redundant compute_dynamic_range and compute_rms_per_track functions


def plot_rms_histogram(rms_values, label, bins=30):
"""
Plot histogram of RMS values across tracks.

Parameters:
rms_values : list
List of RMS values, one for each track.
label : str
Label for the dataset (used in the title).
bins : int, optional
Number of bins for the histogram. Default is 30.

Returns:
None: Displays the histogram.
"""
plt.figure(figsize=(10, 6))
plt.hist(rms_values, bins=bins, alpha=0.7, color="blue", edgecolor="black")
plt.title(f"Histogram of RMS Values Across Tracks ({label})", fontsize=16)
plt.xlabel("RMS of Time Derivative", fontsize=14)
plt.ylabel("Frequency", fontsize=14)
plt.grid(True)
plt.show()


def plot_displacement(
mean_displacement, std_displacement, label, metrics_no_track=None
):
"""
Plot embedding displacement over time with mean and standard deviation.

Parameters:
mean_displacement : dict
Mean displacement for each tau.
std_displacement : dict
Standard deviation of displacement for each tau.
label : str
Label for the dataset.
metrics_no_track : dict, optional
Metrics for the "Classical Contrastive (No Tracking)" dataset to compare against.

Returns:
None: Displays the plot.
"""
plt.figure(figsize=(10, 6))
taus = list(mean_displacement.keys())
mean_values = list(mean_displacement.values())
std_values = list(std_displacement.values())

plt.plot(taus, mean_values, marker="o", label=f"{label}", color="green")
plt.fill_between(
taus,
np.array(mean_values) - np.array(std_values),
np.array(mean_values) + np.array(std_values),
color="green",
alpha=0.3,
label=f"Std Dev ({label})",
)

if metrics_no_track:
mean_values_no_track = list(metrics_no_track["mean_displacement"].values())
std_values_no_track = list(metrics_no_track["std_displacement"].values())

plt.plot(
taus,
mean_values_no_track,
marker="o",
label="Classical Contrastive (No Tracking)",
color="blue",
)
plt.fill_between(
taus,
np.array(mean_values_no_track) - np.array(std_values_no_track),
np.array(mean_values_no_track) + np.array(std_values_no_track),
color="blue",
alpha=0.3,
label="Std Dev (No Tracking)",
)

plt.xlabel("Time Shift (τ)", fontsize=14)
plt.ylabel("Euclidean Distance", fontsize=14)
plt.title(f"Embedding Displacement Over Time ({label})", fontsize=16)
plt.grid(True)
plt.legend(fontsize=12)
plt.show()


def plot_overlay_displacement(overlay_displacement_data):
"""
Plot embedding displacement over time for all datasets in one plot.

Parameters:
overlay_displacement_data : dict
A dictionary containing mean displacement per tau for all datasets.

Returns:
None: Displays the plot.
"""
plt.figure(figsize=(12, 8))
for label, mean_displacement in overlay_displacement_data.items():
taus = list(mean_displacement.keys())
mean_values = list(mean_displacement.values())
plt.plot(taus, mean_values, marker="o", label=label)

plt.xlabel("Time Shift (τ)", fontsize=14)
plt.ylabel("Euclidean Distance", fontsize=14)
plt.title("Overlayed Embedding Displacement Over Time", fontsize=16)
plt.grid(True)
plt.legend(fontsize=12)
plt.show()


# %% hist stats
def plot_boxplot_rms_across_models(datasets_rms):
"""
Plot a boxplot for the distribution of RMS values across models.

Parameters:
datasets_rms : dict
A dictionary where keys are dataset names and values are lists of RMS values.

Returns:
None: Displays the boxplot.
"""
plt.figure(figsize=(12, 6))
labels = list(datasets_rms.keys())
data = list(datasets_rms.values())
print(labels)
print(data)
# Plot the boxplot
plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True)

plt.title(
"Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16
)
plt.ylabel("RMS of Time Derivative", fontsize=14)
plt.xticks(rotation=45, fontsize=12)
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()
plt.show()


def plot_histogram_absolute_differences(datasets_abs_diff):
"""
Plot histograms of absolute differences across embeddings for all models.

Parameters:
datasets_abs_diff : dict
A dictionary where keys are dataset names and values are lists of absolute differences.

Returns:
None: Displays the histograms.
"""
plt.figure(figsize=(12, 6))
for label, abs_diff in datasets_abs_diff.items():
plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True)

plt.title("Histograms of Absolute Differences Across Models", fontsize=16)
plt.xlabel("Absolute Difference", fontsize=14)
plt.ylabel("Density", fontsize=14)
plt.legend(fontsize=12)
plt.grid(alpha=0.7)
plt.tight_layout()
plt.show()


# %% Paths to datasets
feature_paths = {
"7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr",
"21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr",
"28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr",
"56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr",
"Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr",
}

no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr"

# %% Process Datasets
max_tau = 69
metrics = {}

overlay_displacement_data = {}
datasets_rms = {}
datasets_abs_diff = {}

# Process "No Tracking" dataset
features_path_no_track = Path(no_track_path)
embedding_dataset_no_track = read_embedding_dataset(features_path_no_track)

mean_displacement_no_track, std_displacement_no_track = compute_displacement(
embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True
)
dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track)
metrics["No Tracking"] = {
"dynamic_range": dynamic_range_no_track,
"mean_displacement": mean_displacement_no_track,
"std_displacement": std_displacement_no_track,
}

overlay_displacement_data["No Tracking"] = mean_displacement_no_track

print("\nProcessing No Tracking dataset...")
print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}")

plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking")

rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track)
datasets_rms["No Tracking"] = rms_values_no_track

print(f"Plotting histogram of RMS values for No Tracking dataset...")
plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30)

# Compute absolute differences for "No Tracking"
abs_diff_no_track = np.concatenate(
[
np.linalg.norm(
np.diff(embedding_dataset_no_track["features"].values[indices], axis=0),
axis=-1,
)
for indices in np.split(
np.arange(len(embedding_dataset_no_track["track_id"])),
np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1,
)
]
)
datasets_abs_diff["No Tracking"] = abs_diff_no_track

# Process other datasets
for label, path in feature_paths.items():
print(f"\nProcessing {label} dataset...")

features_path = Path(path)
embedding_dataset = read_embedding_dataset(features_path)

mean_displacement, std_displacement = compute_displacement(
embedding_dataset, max_tau=max_tau, return_mean_std=True
)
dynamic_range = compute_dynamic_range(mean_displacement)
metrics[label] = {
"dynamic_range": dynamic_range,
"mean_displacement": mean_displacement,
"std_displacement": std_displacement,
}

overlay_displacement_data[label] = mean_displacement

print(f"Dynamic Range for {label}: {dynamic_range}")

plot_displacement(
mean_displacement,
std_displacement,
label,
metrics_no_track=metrics.get("No Tracking", None),
)

rms_values = compute_rms_per_track(embedding_dataset)
datasets_rms[label] = rms_values

print(f"Plotting histogram of RMS values for {label}...")
plot_rms_histogram(rms_values, label, bins=30)

abs_diff = np.concatenate(
[
np.linalg.norm(
np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1
)
for indices in np.split(
np.arange(len(embedding_dataset["track_id"])),
np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1,
)
]
)
datasets_abs_diff[label] = abs_diff

print("\nPlotting overlayed displacement for all datasets...")
plot_overlay_displacement(overlay_displacement_data)

print("\nSummary of Dynamic Ranges:")
for label, metric in metrics.items():
print(f"{label}: Dynamic Range = {metric['dynamic_range']}")

print("\nPlotting RMS boxplot across models...")
plot_boxplot_rms_across_models(datasets_rms)

print("\nPlotting histograms of absolute differences across models...")
plot_histogram_absolute_differences(datasets_abs_diff)


# %%
Loading
Loading