Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
florian-huber authored Oct 28, 2024
1 parent 4afe6b6 commit dbd2bd0
Showing 1 changed file with 28 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@


class ValidationLossCalculator:
"""A class to calculate validation loss for MS2DeepScore models based on validation spectra and Tanimoto scores.
"""
def __init__(self,
val_spectrums,
settings: SettingsMS2Deepscore):
"""
Parameters:
-----------
val_spectrums:
A list of spectra used for validation.
settings:
Configuration settings for MS2DeepScore.
"""
self.val_spectrums = val_spectrums
self.score_bins = settings.same_prob_bins
self.settings = settings
Expand All @@ -22,20 +32,31 @@ def __init__(self,
def compute_binned_validation_loss(self,
model: SiameseSpectralModel,
loss_types):
"""Benchmark the model against a validation set.
"""
Compute the validation loss for a model based on binned Tanimoto scores.
Parameters:
-----------
model : SiameseSpectralModel
The Siamese spectral model to be benchmarked.
loss_types : list
A list of loss types to calculate (e.g., 'mse', 'mae').
"""
ms2deepscore_model = MS2DeepScore(model)
ms2ds_scores = create_embedding_matrix_symmetric(ms2deepscore_model, self.val_spectrums)
predictions_and_tanimoto_scores = PredictionsAndTanimotoScores(
predictions_df=ms2ds_scores, tanimoto_df=self.tanimoto_scores, symmetric=True)
predictions_df=ms2ds_scores,
tanimoto_df=self.tanimoto_scores,
symmetric=True
)
losses_per_bin = {}
for loss_type in loss_types:
_, average_loss_per_bin = predictions_and_tanimoto_scores.get_average_loss_per_bin_per_inchikey_pair(
loss_type,
self.settings.same_prob_bins)
self.settings.same_prob_bins
)
losses_per_bin[loss_type] = average_loss_per_bin
average_losses = {}
for loss_type in loss_types:
average_losses[loss_type] = np.mean(losses_per_bin[loss_type])
return average_losses, losses_per_bin

average_losses = {loss_type: np.mean(losses_per_bin[loss_type]) for loss_type in loss_types}

return average_losses, losses_per_bin

0 comments on commit dbd2bd0

Please sign in to comment.