Skip to content

Commit

Permalink
Further cleaning and adding docstring to PredictionsAndTanimotoScores.py
Browse files Browse the repository at this point in the history
  • Loading branch information
niekdejonge committed Oct 24, 2024
1 parent dfd3fc0 commit 7aec3ef
Showing 1 changed file with 14 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def __init__(self, predictions_df: pd.DataFrame,
self.symmetric = symmetric
self.label = label

self.check_input_data()
self._check_input_data()
# remove predictions between the same spectrum
if self.symmetric:
np.fill_diagonal(self.predictions_df.values, np.nan)

def check_input_data(self):
def _check_input_data(self):
"""Checks that the prediction df and tanimoto df have the expected format"""
if not isinstance(self.predictions_df, pd.DataFrame) or not isinstance(self.tanimoto_df, pd.DataFrame):
raise TypeError("Expected a pandas DF as input")
Expand Down Expand Up @@ -61,6 +61,7 @@ def get_average_loss_per_bin_per_inchikey_pair(self,
loss_type: str,
tanimoto_bins: np.ndarray):
"""Calculates the loss per tanimoto bin.
First the average prediction for each inchikey pair is calculated,
by taking the average over all predictions between spectra matching these inchikeys"""
loss_type = loss_type.lower()
Expand All @@ -77,22 +78,28 @@ def get_average_loss_per_bin_per_inchikey_pair(self,
f"rmse, mae, risk_mse and risk_mae")
average_losses_per_inchikey_pair = get_average_per_inchikey_pair(losses_per_spectrum_pair)

bin_content, bounds, average_loss_per_bin = self.get_average_loss_per_bin(average_losses_per_inchikey_pair,
tanimoto_bins)
bin_content, bounds, average_loss_per_bin = self._get_average_loss_per_bin(average_losses_per_inchikey_pair,
tanimoto_bins)
if loss_type == "RMSE":
average_loss_per_bin = [average_loss ** 0.5 for average_loss in average_loss_per_bin]
return bin_content, bounds, average_loss_per_bin

def _get_absolute_error_per_spectrum_pair(self):
    """Return the element-wise absolute error between predictions and tanimoto scores.

    The result is abs(self.predictions_df - self.tanimoto_df). No mean is
    taken here: these values are used to get the MAE, but the averaging
    happens after binning over the tanimoto bins.
    """
    difference = self.predictions_df - self.tanimoto_df
    return abs(difference)

def _get_squared_error_per_spectrum_pair(self):
    """Return the element-wise squared error between predictions and tanimoto scores.

    The result is (self.predictions_df - self.tanimoto_df) ** 2. No mean is
    taken here: these values are used to get the MSE or RMSE, but the
    averaging happens after binning over the tanimoto bins.
    """
    difference = self.predictions_df - self.tanimoto_df
    return difference ** 2

def _get_risk_aware_squared_error_per_spectrum_pair(self):
"""MSE weighted by target position on scale 0 to 1.
"""Squared error weighted by target position on scale 0 to 1.
"""
errors = self.tanimoto_df - self.predictions_df
errors = np.sign(errors) * errors ** 2
Expand All @@ -103,7 +110,7 @@ def _get_risk_aware_squared_error_per_spectrum_pair(self):
return risk_aware_squared_error

def _get_risk_aware_absolute_error_per_spectrum_pair(self):
"""MAE weighted by target position on scale 0 to 1.
"""Absolute errors weighted by target position on scale 0 to 1.
"""
errors = self.tanimoto_df - self.predictions_df
uppers = self.tanimoto_df * errors
Expand All @@ -113,7 +120,7 @@ def _get_risk_aware_absolute_error_per_spectrum_pair(self):
columns=lowers.columns)
return risk_aware_absolute_error

def get_average_loss_per_bin(self,
def _get_average_loss_per_bin(self,
average_loss_per_inchikey_pair: pd.DataFrame,
ref_score_bins: np.ndarray):
"""Compute average loss per tanimoto score bin
Expand Down

0 comments on commit 7aec3ef

Please sign in to comment.