Skip to content

Commit

Permalink
move functions, linting
Browse files Browse the repository at this point in the history
  • Loading branch information
florian-huber committed Aug 11, 2023
1 parent 40f6b52 commit 9fb69d0
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 34 deletions.
39 changes: 39 additions & 0 deletions ms2deepscore/spectrum_pair_selection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,48 @@
from typing import List, Tuple
from collections import Counter
import numba
import numpy as np
from matchms.similarity.vector_similarity_functions import jaccard_index
from scipy.sparse import coo_array, lil_array


def compute_spectrum_pairs(spectrums):
pass


def select_inchi_for_unique_inchikeys(
list_of_spectra: List['Spectrum']
) -> Tuple[List['Spectrum'], List[str]]:
"""Select spectra with most frequent inchi for unique inchikeys.
Method needed to calculate Tanimoto scores.
"""
# Extract inchi's and inchikeys from spectra metadata
inchikeys_list = [s.get("inchikey") for s in list_of_spectra]
inchi_list = [s.get("inchi") for s in list_of_spectra]

inchi_array = np.array(inchi_list)
inchikeys14_array = np.array([x[:14] for x in inchikeys_list])

# Find unique inchikeys
inchikeys14_unique = sorted(set(inchikeys14_array))

spectra_selected = []
for inchikey14 in inchikeys14_unique:
# Indices of matching inchikeys
idx = np.where(inchikeys14_array == inchikey14)[0]

# Find the most frequent inchi for the inchikey
most_common_inchi = Counter(inchi_array[idx]).most_common(1)[0][0]

# ID of the spectrum with the most frequent inchi
ID = idx[np.where(inchi_array[idx] == most_common_inchi)[0][0]]

spectra_selected.append(list_of_spectra[ID].clone())

return spectra_selected, inchikeys14_unique


def jaccard_similarity_matrix_cherrypicking(
fingerprints: np.ndarray,
selection_bins: np.ndarray = np.array([(x/10, x/10 + 0.1) for x in range(0, 10)]),
Expand Down
34 changes: 1 addition & 33 deletions ms2deepscore/train_new_model/calculate_tanimoto_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
jaccard_similarity_matrix
from rdkit import Chem
from tqdm import tqdm
from ..spectrum_pair_selection import select_inchi_for_unique_inchikeys


def calculate_tanimoto_scores_unique_inchikey(
Expand Down Expand Up @@ -44,39 +45,6 @@ def calculate_tanimoto_scores_unique_inchikey(
return tanimoto_df


def select_inchi_for_unique_inchikeys(
list_of_spectra: List[Spectrum],
) -> (List[Spectrum], List[str]):
""" "Select spectra with most frequent inchi for unique inchikeys
Method needed to calculate tanimoto scores"""
# Select all inchi's and inchikeys from spectra metadata
inchikeys_list = []
inchi_list = []
for s in list_of_spectra:
inchikeys_list.append(s.get("inchikey"))
inchi_list.append(s.get("inchi"))
inchi_array = np.array(inchi_list)
inchikeys14_array = np.array([x[:14] for x in inchikeys_list])

# Select unique inchikeys
inchikeys14_unique = sorted(list({x[:14] for x in inchikeys_list}))

spectra_with_most_frequent_inchi_per_unique_inchikey = []
for inchikey14 in inchikeys14_unique:
# Select inchis for inchikey14
idx = np.where(inchikeys14_array == inchikey14)[0]
inchis_for_inchikey14 = [list_of_spectra[i].get("inchi") for i in idx]
# Select the most frequent inchi per inchikey
inchi = Counter(inchis_for_inchikey14).most_common(1)[0][0]
# Store the ID of the spectrum with the most frequent inchi
ID = idx[np.where(inchi_array[idx] == inchi)[0][0]]
spectra_with_most_frequent_inchi_per_unique_inchikey.append(
list_of_spectra[ID].clone()
)
return spectra_with_most_frequent_inchi_per_unique_inchikey, inchikeys14_unique


def calculate_tanimoto_scores_from_smiles(
list_of_smiles_1: List[str], list_of_smiles_2: List[str]
) -> np.ndarray:
Expand Down
29 changes: 28 additions & 1 deletion tests/test_spectrum_pair_selection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import numpy as np
import pytest
from matchms import Spectrum
from ms2deepscore.spectrum_pair_selection import (
compute_jaccard_similarity_matrix_cherrypicking,
jaccard_similarity_matrix_cherrypicking)
jaccard_similarity_matrix_cherrypicking,
select_inchi_for_unique_inchikey
)


@pytest.fixture
Expand Down Expand Up @@ -76,3 +79,27 @@ def test_global_bias_not_possible(fingerprints):
data = np.array(data)
assert (data <= 0.5).sum() == ((data>0.5) & (data<=0.8)).sum() == 16
assert (data>0.8).sum() == 8


def test_select_inchi_for_unique_inchikey():
#ms2ds_binner = SpectrumBinner(100, mz_min=0.0, mz_max=100.0, peak_scaling=1.0)
spectrum_1 = Spectrum(mz=np.array([100.]),
intensities=np.array([0.7]),
metadata={"inchikey": "ABCABCABCABCAB-nonsense",
"inchi": "InChI=1/C6H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1"})
spectrum_2 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata={"inchikey": "ABCABCABCABCAB-nonsense",
"inchi": "InChI=1/C6H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1"})
spectrum_3 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata={"inchikey": "ABCABCABCABCAB-nonsense2",
"inchi": "InChI=1/C666H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1"})

select_inchi_for_unique_inchikey([spectrum_1, spectrum_2, spectrum_3])
assert ms2ds_binner.known_bins == [10, 40, 50, 90, 100], "Expected different known bins."
assert len(binned_spectrums) == 2, "Expected 2 binned spectrums."
assert binned_spectrums[0].binned_peaks == {0: 0.7, 2: 0.2, 4: 0.1}, \
"Expected different binned spectrum."
assert binned_spectrums[0].get("inchikey") == "test_inchikey_01", \
"Expected different inchikeys."

0 comments on commit 9fb69d0

Please sign in to comment.