Skip to content

Commit

Permalink
add wrapper and test
Browse files Browse the repository at this point in the history
  • Loading branch information
florian-huber committed Aug 13, 2023
1 parent 04e8b01 commit 1efd346
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 26 deletions.
41 changes: 37 additions & 4 deletions ms2deepscore/spectrum_pair_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,45 @@
from collections import Counter
import numba
import numpy as np
from matchms import Spectrum
from matchms.filtering import add_fingerprint
from matchms.similarity.vector_similarity_functions import jaccard_index
from scipy.sparse import coo_array, lil_array


def compute_spectrum_pairs(spectrums):
pass
def compute_spectrum_pairs(spectrums,
selection_bins: np.ndarray = np.array([(x/10, x/10 + 0.1) for x in range(0, 10)]),
max_pairs_per_bin: int = 20,
include_diagonal: bool = True,
fix_global_bias: bool = True,
fingerprint_type: str = "daylight",
nbits: int = 2048):
"""Function to compute the compound similarities (Tanimoto) and collect a well-balanced set of pairs.
TODO: describe method and arguments
"""
# pylint: disable=too-many-arguments
spectra_selected, inchikeys14_unique = select_inchi_for_unique_inchikeys(spectrums)
print(f"Selected {len(spectra_selected)} spectra with unique inchikeys (out of {len(spectrums)} spectra)")
# Compute fingerprints using matchms
spectra_selected = [add_fingerprint(s, fingerprint_type, nbits) for s in spectra_selected]

# Ignore missing / not-computed fingerprints
fingerprints = [s.get("fingerprint") for s in spectra_selected]
idx = np.array([i for i, x in enumerate(fingerprints) if x is not None]).astype(int)
if len(idx) == 0:
raise ValueError("No fingerprints could be computed")
if len(idx) < len(fingerprints):
print(f"Successfully generated fingerprints for {len(idx)} of {len(fingerprints)} spectra")
fingerprints = [fingerprints[i] for i in idx]
inchikeys14_unique = [inchikeys14_unique[i] for i in idx]
spectra_selected = [spectra_selected[i] for i in idx]
return compute_jaccard_similarity_matrix_cherrypicking(
np.array(fingerprints),
selection_bins,
max_pairs_per_bin,
include_diagonal,
fix_global_bias)


def select_inchi_for_unique_inchikeys(
Expand Down Expand Up @@ -141,7 +174,7 @@ def compute_jaccard_similarity_matrix_cherrypicking(
if i == j and not include_diagonal:
continue
scores_row[j] = jaccard_index(fingerprints[i, :], fingerprints[j, :])

# Cherrypicking
for bin_number, selection_bin in enumerate(selection_bins):
# Indices of scores within the current bin
Expand All @@ -160,5 +193,5 @@ def compute_jaccard_similarity_matrix_cherrypicking(
scores_i.extend(len(idx_selected) * [i])
scores_j.extend(list(idx_selected))
print(max_pairs_global)

return scores_data, scores_i, scores_j
58 changes: 36 additions & 22 deletions tests/test_spectrum_pair_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from matchms import Spectrum
from ms2deepscore.spectrum_pair_selection import (
compute_jaccard_similarity_matrix_cherrypicking,
compute_spectrum_pairs,
jaccard_similarity_matrix_cherrypicking,
select_inchi_for_unique_inchikeys
)
Expand Down Expand Up @@ -32,6 +33,27 @@ def fingerprints():
], dtype=bool)


@pytest.fixture
def spectrums():
metadata = {"precursor_mz": 101.1,
"inchikey": "ABCABCABCABCAB-nonsense",
"inchi": "InChI=1/C6H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1"}
spectrum_1 = Spectrum(mz=np.array([100.]),
intensities=np.array([0.7]),
metadata=metadata)
spectrum_2 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata=metadata)
spectrum_3 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata=metadata)
spectrum_4 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata={"inchikey": 14 * "X",
"inchi": "InChI=1S/C8H10N4O2/c1-10-4-9-6-5(10)7(13)12(3)8(14)11(6)2/h4H,1-3H3"})
return [spectrum_1, spectrum_2, spectrum_3, spectrum_4]


def test_basic_functionality(simple_fingerprints):
matrix = jaccard_similarity_matrix_cherrypicking(simple_fingerprints, random_seed=42)
assert matrix.shape == (4, 4)
Expand Down Expand Up @@ -81,31 +103,23 @@ def test_global_bias_not_possible(fingerprints):
assert (data>0.8).sum() == 8


def test_select_inchi_for_unique_inchikeys():
metadata = {"precursor_mz": 101.1,
"inchikey": "ABCABCABCABCAB-nonsense",
"inchi": "InChI=1/C6H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1"}
spectrum_1 = Spectrum(mz=np.array([100.]),
intensities=np.array([0.7]),
metadata=metadata)
spectrum_2 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata=metadata)
metadata["inchikey"] = "ABCABCABCABCAB-nonsense2"
spectrum_3 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata=metadata)
spectrum_4 = Spectrum(mz=np.array([90.]),
intensities=np.array([0.4]),
metadata={"inchikey": "ABCABCABCABCAB-nonsense2",
"inchi": "InChI=1/C666H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1"})
(spectrums_selected, inchikey14s) = select_inchi_for_unique_inchikeys([spectrum_1, spectrum_2, spectrum_3, spectrum_4])
def test_select_inchi_for_unique_inchikeys(spectrums):
spectrums[2].set("inchikey", "ABCABCABCABCAB-nonsense2")
spectrums[3].set("inchikey", "ABCABCABCABCAB-nonsense3")
(spectrums_selected, inchikey14s) = select_inchi_for_unique_inchikeys(spectrums)
assert inchikey14s == ['ABCABCABCABCAB']
assert spectrums_selected[0].get("inchi").startswith("InChI=1/C6H8O6/")


def test_select_inchi_for_unique_inchikeys_two_inchikeys(spectrums):
# Test for two different inchikeys
spectrum_4.set("inchikey", 14 * "X")
spectrums = [spectrum_1, spectrum_2, spectrum_3, spectrum_4, spectrum_4]
(spectrums_selected, inchikey14s) = select_inchi_for_unique_inchikeys(spectrums)
assert inchikey14s == ['ABCABCABCABCAB', 'XXXXXXXXXXXXXX']
assert [s.get("inchi")[:15] for s in spectrums_selected] == ['InChI=1/C6H8O6/', 'InChI=1/C666H8O']
assert [s.get("inchi")[:15] for s in spectrums_selected] == ['InChI=1/C6H8O6/', 'InChI=1S/C8H10N']


def test_compute_spectrum_pairs(spectrums):
a, b, c = compute_spectrum_pairs(spectrums)
assert b == [0, 0, 1, 1]
assert c == [1, 0, 0, 1]
assert np.allclose(a, [0.1665089877010407, 1.0, 0.1665089877010407, 1.0])

0 comments on commit 1efd346

Please sign in to comment.