Skip to content

Commit

Permalink
code linting
Browse files Browse the repository at this point in the history
  • Loading branch information
florian-huber committed Oct 29, 2024
1 parent 21ccb84 commit 8bf9a7e
Showing 1 changed file with 63 additions and 31 deletions.
94 changes: 63 additions & 31 deletions ms2deepscore/train_new_model/inchikey_pair_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,37 @@ def select_compound_pairs_wrapper(
if settings.random_seed is not None:
np.random.seed(settings.random_seed)

fingerprints, inchikeys14_unique = compute_fingerprints_for_training(spectra, settings.fingerprint_type,
settings.fingerprint_nbits)
fingerprints, inchikeys14_unique = compute_fingerprints_for_training(
spectra,
settings.fingerprint_type,
settings.fingerprint_nbits
)

available_pairs_per_bin_matrix, available_scores_per_bin_matrix = compute_jaccard_similarity_per_bin(
fingerprints,
settings.max_pairs_per_bin,
settings.same_prob_bins,
settings.include_diagonal)

aimed_nr_of_pairs_per_bin = determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix,
settings,
nr_of_inchikeys=len(inchikeys14_unique))

pair_frequency_matrixes = balanced_selection_of_pairs_per_bin(available_pairs_per_bin_matrix,
settings.max_pair_resampling,
aimed_nr_of_pairs_per_bin)

selected_pairs_per_bin = convert_to_selected_pairs_list(pair_frequency_matrixes, available_pairs_per_bin_matrix,
available_scores_per_bin_matrix, inchikeys14_unique)
settings.include_diagonal
)

aimed_nr_of_pairs_per_bin = determine_aimed_nr_of_pairs_per_bin(
available_pairs_per_bin_matrix,
settings,
nr_of_inchikeys=len(inchikeys14_unique)
)

pair_frequency_matrixes = balanced_selection_of_pairs_per_bin(
available_pairs_per_bin_matrix,
settings.max_pair_resampling,
aimed_nr_of_pairs_per_bin
)

selected_pairs_per_bin = convert_to_selected_pairs_list(
pair_frequency_matrixes,
available_pairs_per_bin_matrix,
available_scores_per_bin_matrix,
inchikeys14_unique
)
return [pair for pairs in selected_pairs_per_bin for pair in pairs]


Expand Down Expand Up @@ -79,7 +91,7 @@ def compute_fingerprints_for_training(

# Compute fingerprints using matchms
spectra_selected = [add_fingerprint(s, fingerprint_type, nbits) \
if s.get("fingerprint") is None else s for s in spectra_selected]
if s.get("fingerprint") is None else s for s in spectra_selected]

# Ignore missing / not-computed fingerprints
fingerprints = [s.get("fingerprint") for s in tqdm(spectra_selected,
Expand Down Expand Up @@ -142,22 +154,40 @@ def compute_jaccard_similarity_per_bin(


def determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix, settings, nr_of_inchikeys):
"""Determines the aimed_nr_of_pairs_per_bin.
If the settings given are higher than the highest possible number of pairs it is lowered to that"""
"""Calculate the target number of pairs per bin based on available pairs and given settings.
This function determines the desired number of pairs per bin based on the `average_pairs_per_bin`
setting and the total number of InChIKeys. If the calculated number exceeds the maximum possible
number of pairs, it is adjusted to the feasible limit.
Parameters:
-----------
available_pairs_per_bin_matrix
A matrix containing the available number of pairs for each bin.
settings:
Settings object containing configuration options.
Required attributes:
- max_pair_resampling: Factor to resample pairs per bin.
- average_pairs_per_bin: Desired average number of pairs per bin.
nr_of_inchikeys:
The total number of InChIKeys.
"""

# Select the nr_of_pairs_per_bin to use
# Get the number of available pairs per bin
nr_of_available_pairs_per_bin = get_nr_of_available_pairs_in_bin(available_pairs_per_bin_matrix)
lowest_max_number_of_pairs = min(nr_of_available_pairs_per_bin) * settings.max_pair_resampling
print(f"The available nr of pairs per bin are: {nr_of_available_pairs_per_bin}")

# Calculate initial target number of pairs per bin
aimed_nr_of_pairs_per_bin = settings.average_pairs_per_bin * nr_of_inchikeys
if lowest_max_number_of_pairs < aimed_nr_of_pairs_per_bin:
print(f"Warning: The average_pairs_per_bin: {settings.average_pairs_per_bin} cannot be reached, "
f"since this would require "
f"{settings.average_pairs_per_bin} * {nr_of_inchikeys} = {aimed_nr_of_pairs_per_bin} pairs."
f"But one of the bins has only {lowest_max_number_of_pairs} available"
f"Instead the lowest number of available pairs in a bin times the resampling is used, "
f"which is: {lowest_max_number_of_pairs}")
print(f"Warning: The target average_pairs_per_bin ({settings.average_pairs_per_bin}) cannot be reached, "
f"as it requires {aimed_nr_of_pairs_per_bin} pairs. However, one of the bins has only "
f"{lowest_max_number_of_pairs} pairs available. The number will be adjusted to this limit.")
aimed_nr_of_pairs_per_bin = lowest_max_number_of_pairs

return aimed_nr_of_pairs_per_bin


Expand Down Expand Up @@ -191,16 +221,17 @@ def balanced_selection_of_pairs_per_bin(
nr_of_pairs_per_bin:
The number of pairs that should be sampled for each tanimoto bin.
"""

inchikey_count = np.zeros(available_pairs_per_bin_matrix.shape[1])
pair_frequency_matrixes = []
for pairs_in_bin in available_pairs_per_bin_matrix:
pair_frequencies, inchikey_count = select_balanced_pairs(pairs_in_bin,
inchikey_count,
nr_of_pairs_per_bin,
max_pair_resampling,
)
pair_frequencies, inchikey_count = select_balanced_pairs(
pairs_in_bin,
inchikey_count,
nr_of_pairs_per_bin,
max_pair_resampling,
)
pair_frequency_matrixes.append(pair_frequencies)

pair_frequency_matrixes = np.array(pair_frequency_matrixes)
pair_frequency_matrixes[pair_frequency_matrixes == 2 * max_pair_resampling] = 0
return pair_frequency_matrixes
Expand Down Expand Up @@ -357,7 +388,8 @@ def select_balanced_pairs(available_pairs_for_bin_matrix: np.ndarray,


def get_nr_of_available_pairs_in_bin(selected_pairs_per_bin_matrix: np.ndarray) -> List[int]:
"""Calculates the number of unique pairs available per bin, discarding duplicated (inverted) pairs"""
"""Calculates the number of unique pairs available per bin, discarding duplicated (inverted) pairs.
"""
nr_of_unique_pairs_per_bin = []
for bin_idx in tqdm(range(selected_pairs_per_bin_matrix.shape[0]),
desc="Determining number of available pairs per bin"):
Expand Down

0 comments on commit 8bf9a7e

Please sign in to comment.