-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #244 from matchms/update_plotting_functions
Update plotting functions
- Loading branch information
Showing
45 changed files
with
963 additions
and
1,795 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
ms2deepscore/benchmarking/CalculateScoresBetweenAllIonmodes.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
|
||
from ms2deepscore import MS2DeepScore | ||
from ms2deepscore.validation_loss_calculation.PredictionsAndTanimotoScores import PredictionsAndTanimotoScores | ||
from ms2deepscore.validation_loss_calculation.calculate_scores_for_validation import create_embedding_matrix_symmetric, \ | ||
create_embedding_matrix_not_symmetric, calculate_tanimoto_scores_unique_inchikey | ||
from ms2deepscore.models.load_model import load_model | ||
|
||
|
||
class CalculateScoresBetweenAllIonmodes:
    """Calculates the true tanimoto scores and predicted scores between all ionmodes.

    On construction the model is loaded from ``model_file_name`` and three
    comparisons are computed and stored as attributes:

    * ``pos_vs_neg_scores`` -- positive spectra against negative spectra
    * ``pos_vs_pos_scores`` -- positive spectra against themselves
    * ``neg_vs_neg_scores`` -- negative spectra against themselves

    The model is deleted again at the end of ``__init__``, so
    ``get_tanimoto_and_prediction_pairs`` cannot be called on a finished
    instance (``self.model`` no longer exists).
    """
    def __init__(self,
                 model_file_name, positive_validation_spectra, negative_validation_spectra):
        self.model_file_name = model_file_name
        self.positive_validation_spectra = positive_validation_spectra
        self.negative_validation_spectra = negative_validation_spectra
        self.model = MS2DeepScore(load_model(model_file_name))

        self.pos_vs_neg_scores = self.get_tanimoto_and_prediction_pairs(
            positive_validation_spectra, negative_validation_spectra, label="positive vs negative")
        self.pos_vs_pos_scores = self.get_tanimoto_and_prediction_pairs(
            positive_validation_spectra, label="positive vs positive")
        self.neg_vs_neg_scores = self.get_tanimoto_and_prediction_pairs(
            negative_validation_spectra, label="negative vs negative")
        # Avoid memory leakage. The reference to the model has to be dropped
        # *before* the cache is emptied: empty_cache() only releases memory
        # that is no longer referenced by any tensor.
        del self.model
        torch.cuda.empty_cache()

    @property
    def postive_validation_spectra(self):
        """Misspelled alias of ``positive_validation_spectra``, kept for backward compatibility."""
        return self.positive_validation_spectra

    @postive_validation_spectra.setter
    def postive_validation_spectra(self, spectra):
        self.positive_validation_spectra = spectra

    def get_tanimoto_and_prediction_pairs(self, spectra_1, spectra_2=None, label="") -> PredictionsAndTanimotoScores:
        """Return the predictions and tanimoto scores between two sets of spectra.

        When ``spectra_2`` is not given, ``spectra_1`` is compared against
        itself and the result is flagged as symmetric.
        """
        symmetric = spectra_2 is None
        if symmetric:
            spectra_2 = spectra_1
            predictions_df = create_embedding_matrix_symmetric(self.model, spectra_1)
        else:
            predictions_df = create_embedding_matrix_not_symmetric(self.model, spectra_1, spectra_2)
        tanimoto_scores_df = calculate_tanimoto_scores_unique_inchikey(spectra_1, spectra_2)
        return PredictionsAndTanimotoScores(predictions_df, tanimoto_scores_df, symmetric, label)

    def list_of_predictions_and_tanimoto_scores(self):
        """Return the three stored comparisons: pos vs pos, pos vs neg, neg vs neg."""
        return [self.pos_vs_pos_scores,
                self.pos_vs_neg_scores,
                self.neg_vs_neg_scores, ]
117 changes: 0 additions & 117 deletions
117
ms2deepscore/benchmarking/calculate_scores_for_validation.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from matplotlib import pyplot as plt | ||
|
||
from ms2deepscore.benchmarking.CalculateScoresBetweenAllIonmodes import CalculateScoresBetweenAllIonmodes | ||
from ms2deepscore.utils import create_evenly_spaced_bins | ||
|
||
|
||
def plot_average_per_bin(scores_between_ionmodes: CalculateScoresBetweenAllIonmodes, nr_of_bins):
    """Plot the average prediction per tanimoto bin, one line per ionmode comparison.

    The bins come from ``create_evenly_spaced_bins(nr_of_bins)``; every line is
    drawn at the centers of those bins and labelled with the ``label`` of the
    comparison it belongs to. Returns the matplotlib figure.
    """
    bins = create_evenly_spaced_bins(nr_of_bins)

    # x position of every point: the middle between the two borders of a bin
    bin_centers = []
    for bin_borders in bins:
        bin_centers.append((bin_borders[0] + bin_borders[1])/2)

    fig, ax = plt.subplots()
    for scores in scores_between_ionmodes.list_of_predictions_and_tanimoto_scores():
        average_predictions = scores.get_average_prediction_per_inchikey_pair()
        average_per_bin = scores.get_average_per_bin(average_predictions, bins)[1]
        ax.plot(bin_centers, average_per_bin, label=scores.label)

    ax.set_xlabel("True chemical similarity")
    ax.set_ylabel("Average predicted chemical similarity")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend()
    return fig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from typing import Tuple, List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from matplotlib import pyplot as plt | ||
|
||
from ms2deepscore.benchmarking.CalculateScoresBetweenAllIonmodes import CalculateScoresBetweenAllIonmodes | ||
from ms2deepscore.validation_loss_calculation.PredictionsAndTanimotoScores import PredictionsAndTanimotoScores | ||
|
||
|
||
def create_3_heatmaps(pairs: CalculateScoresBetweenAllIonmodes, nr_of_bins):
    """Create three side by side heatmaps of predicted vs true chemical similarity.

    One heatmap is drawn for each ionmode comparison (negative vs negative,
    positive vs positive, positive vs negative). All three share the same
    bins, the same axis limits and the same color scale, so the single
    colorbar is valid for every panel. Returns the matplotlib figure.
    """
    # The y-axis covers at least [0, 1] and is widened when predictions fall outside of it.
    minimum_y_axis = 0
    maximum_y_axis = 1
    for predictions_and_tanimoto_score in pairs.list_of_predictions_and_tanimoto_scores():
        average_pred_per_inchikey_pair = predictions_and_tanimoto_score.get_average_prediction_per_inchikey_pair()
        minimum_y_axis = min(minimum_y_axis, average_pred_per_inchikey_pair.min().min())
        maximum_y_axis = max(maximum_y_axis, average_pred_per_inchikey_pair.max().max())

    x_bins = np.linspace(0, 1, nr_of_bins + 1)
    y_bins = np.linspace(minimum_y_axis, maximum_y_axis + 0.00001, nr_of_bins + 1)

    # Take the average per bin, listed in the order in which the panels are drawn
    heatmaps_and_titles = [
        (create_normalized_heatmap_data(pairs.neg_vs_neg_scores, x_bins, y_bins), "Negative vs negative"),
        (create_normalized_heatmap_data(pairs.pos_vs_pos_scores, x_bins, y_bins), "Positive vs positive"),
        (create_normalized_heatmap_data(pairs.pos_vs_neg_scores, x_bins, y_bins), "Positive vs negative"),
    ]

    # nanmax instead of the builtin max: a heatmap may contain NaN values and
    # comparisons with NaN would make the result depend on the argument order.
    maximum_heatmap_intensity = np.nanmax([heatmap for heatmap, _ in heatmaps_and_titles])

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    image = None
    for ax, (heatmap, title) in zip(axes, heatmaps_and_titles):
        # vmin and vmax are fixed for all panels, since they share one colorbar.
        image = ax.imshow(heatmap.T, origin='lower', interpolation='nearest',
                          cmap="inferno", vmin=0, vmax=maximum_heatmap_intensity,
                          extent=[0, 1, minimum_y_axis, maximum_y_axis])
        ax.set_title(title)
        ax.set_xlabel("True chemical similarity")
        ax.set_ylabel("Predicted chemical similarity")
        ax.set_xlim(0, 1)
        ax.set_ylim(minimum_y_axis, maximum_y_axis)

    cbar = fig.colorbar(image, ax=axes, orientation='vertical', fraction=0.02, pad=0.04)
    cbar.set_label('Density')  # Label for the colorbar
    return fig
|
||
|
||
def create_normalized_heatmap_data(prediction_and_tanimoto_scores: PredictionsAndTanimotoScores,
                                   x_bins, y_bins):
    """Return a 2D histogram of tanimoto score vs average prediction, normalized per tanimoto bin.

    The first axis of the result follows ``x_bins`` (tanimoto scores), the
    second axis follows ``y_bins`` (average predictions). Every row is divided
    by its own total, so each non-empty row sums to 1. Rows without any pair
    are returned as zeros.
    """
    average_prediction = \
        prediction_and_tanimoto_scores.get_average_prediction_per_inchikey_pair()
    list_of_tanimoto_scores, list_of_average_predictions = convert_dataframes_to_lists_with_matching_pairs(
        prediction_and_tanimoto_scores.tanimoto_df,
        average_prediction)
    heatmap = np.histogram2d(list_of_tanimoto_scores,
                             list_of_average_predictions,
                             bins=(x_bins, y_bins))[0]
    pairs_per_tanimoto_bin = heatmap.sum(axis=1, keepdims=True)
    # Only divide rows that contain pairs; a plain division would give 0/0 = NaN
    # (and a RuntimeWarning) for every empty tanimoto bin.
    normalized_heatmap = np.divide(heatmap, pairs_per_tanimoto_bin,
                                   out=np.zeros_like(heatmap),
                                   where=pairs_per_tanimoto_bin > 0)
    return normalized_heatmap
|
||
|
||
def convert_dataframes_to_lists_with_matching_pairs(tanimoto_df: pd.DataFrame,
                                                    average_predictions_per_inchikey_pair: pd.DataFrame
                                                    ) -> Tuple[List[float], List[float]]:
    """Takes in two dataframes with inchikeys as index and returns two lists with scores, which correspond to pairs

    The pairs are listed row by row of ``average_predictions_per_inchikey_pair``.
    Pairs for which the prediction is NaN are left out of both lists.

    Returns a tuple ``(tanimoto_scores, predictions)``.
    Raises a KeyError if an inchikey of the predictions is missing in ``tanimoto_df``.
    """
    predictions = average_predictions_per_inchikey_pair.to_numpy(dtype=float)
    # Select the tanimoto scores by label, so they line up with the predictions
    # even if tanimoto_df is ordered differently or contains additional inchikeys.
    tanimoto_scores = tanimoto_df.loc[average_predictions_per_inchikey_pair.index,
                                      average_predictions_per_inchikey_pair.columns].to_numpy(dtype=float)
    # don't include pairs where the prediction is NaN (this is the case when only a pair against itself is available)
    has_prediction = ~np.isnan(predictions)
    # Boolean masking of a 2D array flattens row by row, which keeps the pair order.
    return tanimoto_scores[has_prediction].tolist(), predictions[has_prediction].tolist()
Oops, something went wrong.