Skip to content

Commit

Permalink
fix plots for smoothnes
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Nov 22, 2024
1 parent cf6dbe7 commit a98e882
Showing 1 changed file with 62 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tqdm import tqdm
import pandas as pd


plt.style.use("../evaluation/figure.mplstyle")

# plotting
Expand All @@ -36,10 +37,9 @@ def compute_piece_wise_dissimilarity(
piece_wise_rank_difference_per_track = []
for name, subdata in features_df.groupby(["fov_name", "track_id"]):
if len(subdata) > 1:
single_track_dissimilarity = select_block(cross_dist, subdata.index.values)
single_track_rank_fraction = select_block(
rank_fractions, subdata.index.values
)
indices = subdata.index.values
single_track_dissimilarity = select_block(cross_dist, indices)
single_track_rank_fraction = select_block(rank_fractions, indices)
piece_wise_dissimilarity = compare_time_offset(
single_track_dissimilarity, time_offset=1
)
Expand All @@ -64,20 +64,25 @@ def plot_histogram(


# %%
PATH_TO_GDRIVE_FIGUE = "/home/eduardo.hirata/mydata/gdrive/publications/learning_impacts_of_infection/fig_manuscript/rev2_ICLR_fig/"

prediction_path_1 = Path(
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_1.zarr"
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr"
)
prediction_path_2 = Path(
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_1.zarr"
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr"
)
for prediction_path in tqdm([prediction_path_1, prediction_path_2]):

for prediction_path, loss_name in tqdm(
[(prediction_path_1, "ntxent"), (prediction_path_2, "triplet")]
):

# Read the dataset
embeddings = read_embedding_dataset(prediction_path)
features = embeddings["features"]

scaled_features = StandardScaler().fit_transform(features.values)
# COmpute the cosine dissimilarity
# Compute the cosine dissimilarity
cross_dist = cross_dissimilarity(scaled_features, metric="cosine")
rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True)

Expand All @@ -91,43 +96,58 @@ def plot_histogram(
compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions)
)

# Get the median/mode of the off diagonal elements
median_piece_wise_dissimilarity = [
np.median(track) for track in piece_wise_dissimilarity_per_track
]
p99_piece_wise_dissimilarity = [
np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track
]
p1_percentile_piece_wise_dissimilarity = [
np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track
]
all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track)

# # Get the median/mode of the off diagonal elements
# median_piece_wise_dissimilarity = np.array(
# [np.median(track) for track in piece_wise_dissimilarity_per_track]
# )
p99_piece_wise_dissimilarity = np.array(
[np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track]
)
p1_percentile_piece_wise_dissimilarity = np.array(
[np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track]
)

# Random sampling values in the dissimilarity matrix
n_samples = 2000
sampled_values = [
cross_dist[
np.random.randint(0, len(cross_dist)), np.random.randint(0, len(cross_dist))
]
for _ in range(n_samples)
]
n_samples = 3000
random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2))
sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]]

print(f"Dissimilarity Statistics for {prediction_path.stem}")
print(f"Mean: {np.mean(all_dissimilarity)}")
print(f"Std: {np.std(all_dissimilarity)}")
print(f"Median: {np.median(all_dissimilarity)}")

print(f"Distance Statistics for random sampling")
print(f"Mean: {np.mean(sampled_values)}")
print(f"Std: {np.std(sampled_values)}")
print(f"Median: {np.median(sampled_values)}")

if VERBOSE:
# Plot histograms
# plot_histogram(
# median_piece_wise_dissimilarity,
# "Adjacent Frame Median Dissimilarity per Track",
# "Cosine Dissimilarity",
# "Frequency",
# )
# plot_histogram(
# p1_percentile_piece_wise_dissimilarity,
# "Adjacent Frame 1 Percentile Dissimilarity per Track",
# "Cosine Dissimilarity",
# "Frequency",
# )
# plot_histogram(
# p99_piece_wise_dissimilarity,
# "Adjacent Frame 99 Percentile Dissimilarity per Track",
# "Cosine Dissimilarity",
# "Frequency",
# )

plot_histogram(
median_piece_wise_dissimilarity,
"Adjacent Frame Median Dissimilarity per Track",
"Cosine Dissimilarity",
"Frequency",
)
plot_histogram(
p1_percentile_piece_wise_dissimilarity,
"Adjacent Frame 1 Percentile Dissimilarity per Track",
"Cosine Dissimilarity",
"Frequency",
)
plot_histogram(
p99_piece_wise_dissimilarity,
"Adjacent Frame 99 Percentile Dissimilarity per Track",
piece_wise_dissimilarity_per_track,
"Adjacent Frame Dissimilarity per Track",
"Cosine Dissimilarity",
"Frequency",
)
Expand All @@ -145,7 +165,7 @@ def plot_histogram(
# Plot the median and the random sampling in one plot each with different colors
fig = plt.figure()
sns.histplot(
median_piece_wise_dissimilarity,
all_dissimilarity,
bins=30,
kde=True,
color="cyan",
Expand All @@ -161,8 +181,8 @@ def plot_histogram(
plt.legend(["Adjacent Frame", "Random Sample"])
plt.show()
fig.savefig(
f"./cosine_dissimilarity_smoothness_{prediction_path.stem}.pdf",
dpi=300,
f"{PATH_TO_GDRIVE_FIGUE}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf",
dpi=600,
)

# %%

0 comments on commit a98e882

Please sign in to comment.