Skip to content

Commit

Permalink
add pair generator in new SelectedCompoundPairs class
Browse files Browse the repository at this point in the history
  • Loading branch information
florian-huber committed Aug 14, 2023
1 parent fd6376c commit 0528d3c
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 2 deletions.
57 changes: 56 additions & 1 deletion ms2deepscore/spectrum_pair_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from matchms import Spectrum
from matchms.filtering import add_fingerprint
from matchms.similarity.vector_similarity_functions import jaccard_index
from scipy.sparse import coo_array, lil_array
from scipy.sparse import coo_array


def compute_spectrum_pairs(spectrums,
Expand Down Expand Up @@ -195,3 +195,58 @@ def select_inchi_for_unique_inchikeys(
spectra_selected.append(list_of_spectra[ID].clone())

return spectra_selected, inchikeys14_unique


class SelectedCompoundPairs:
    """Store sparse ("cherrypicked") compound pairs and their respective scores.

    This is meant to be used with the results of the `compute_spectrum_pairs()`
    function. The therein selected (cherrypicked) scores are stored similar to
    a list-of-lists format: for every inchikey (row) there is one array of
    column indices and one array of the corresponding scores.

    Parameters
    ----------
    coo_array
        Object exposing `row`, `col` and `data` arrays of equal length
        (COO layout), e.g. a `scipy.sparse.coo_array`.
    inchikeys
        Inchikeys in row/column order; position `i` labels row and column `i`.
        Entries whose row index is not covered by `inchikeys` are ignored.
    """

    def __init__(self, coo_array, inchikeys):
        self._idx_to_inchikey = dict(enumerate(inchikeys))
        self._inchikey_to_idx = {key: idx for idx, key in enumerate(inchikeys)}
        n_rows = len(self._idx_to_inchikey)

        # Group all entries by row with a single stable sort instead of
        # building one boolean mask per row (which costs n_rows * nnz).
        # The stable sort keeps the original order of entries within a row.
        rows = np.asarray(coo_array.row)
        order = np.argsort(rows, kind="stable")
        sorted_rows = rows[order]
        sorted_cols = np.asarray(coo_array.col)[order]
        sorted_scores = np.asarray(coo_array.data)[order]

        # bounds[i]:bounds[i + 1] is the slice holding the entries of row i.
        bounds = np.searchsorted(sorted_rows, np.arange(n_rows + 1), side="left")
        self._cols = [sorted_cols[start:stop]
                      for start, stop in zip(bounds[:-1], bounds[1:])]
        self._scores = [sorted_scores[start:stop]
                        for start, stop in zip(bounds[:-1], bounds[1:])]

        # Initialize counter for each row
        self._row_generator_index = np.zeros(n_rows, dtype=int)

    def shuffle(self):
        """Shuffle all scores for all inchikeys."""
        for row_index in range(len(self._scores)):
            self._shuffle_row(row_index)

    def _shuffle_row(self, row_index):
        """Shuffle the column and scores of row with row_index."""
        # One shared permutation keeps every column aligned with its score.
        permutation = np.random.permutation(len(self._cols[row_index]))
        self._cols[row_index] = self._cols[row_index][permutation]
        self._scores[row_index] = self._scores[row_index][permutation]

    def next_pair_for_inchikey(self, inchikey):
        """Return the next `(score, inchikey)` pair stored for `inchikey`.

        Consecutive calls walk through the pairs of the row and start again
        at the first pair once the end of the row is reached.

        Raises
        ------
        KeyError
            If `inchikey` is unknown.
        IndexError
            If no pair is stored for `inchikey`.
        """
        row_idx = self._inchikey_to_idx[inchikey]
        row_cols = self._cols[row_idx]
        if len(row_cols) == 0:
            raise IndexError(
                f"No selected pairs available for inchikey {inchikey!r}."
            )

        # Retrieve the next pair
        position = self._row_generator_index[row_idx]
        col_idx = row_cols[position]
        score = self._scores[row_idx][position]

        # Update the counter, wrapping around if necessary
        self._row_generator_index[row_idx] = (position + 1) % len(row_cols)

        return score, self._idx_to_inchikey[col_idx]

    @property
    def scores(self):
        """List with one array of scores per inchikey (row)."""
        return self._scores

    def __str__(self):
        return f"SelectedCompoundPairs with {len(self._scores)} columns."
61 changes: 60 additions & 1 deletion tests/test_spectrum_pair_selection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np
import pytest
from scipy.sparse import coo_array
from matchms import Spectrum
from ms2deepscore.spectrum_pair_selection import (
compute_jaccard_similarity_matrix_cherrypicking,
compute_spectrum_pairs,
jaccard_similarity_matrix_cherrypicking,
select_inchi_for_unique_inchikeys
select_inchi_for_unique_inchikeys,
SelectedCompoundPairs
)


Expand Down Expand Up @@ -54,6 +56,15 @@ def spectrums():
return [spectrum_1, spectrum_2, spectrum_3, spectrum_4]


@pytest.fixture
def dummy_data():
    """Provide COO-style test data as a `(data, row, col, inchikeys)` tuple."""
    inchikeys = [
        "Inchikey0",
        "Inchikey1",
        "Inchikey2",
        "Inchikey3",
        "Inchikey4",
        "Inchikey5",
        "Inchikey6",
    ]
    # Row index 2 is left out on purpose, so one inchikey has no pairs.
    row_indices = np.array([0, 1, 1, 1, 3, 4, 5, 6])
    col_indices = np.array([0, 2, 3, 4, 6, 5, 2, 6])
    score_values = np.array([0.5, 0.7, 0.3, 0.9, 0.8, 0.2, 0.6, 1.0])
    return score_values, row_indices, col_indices, inchikeys


def test_basic_functionality(simple_fingerprints):
matrix = jaccard_similarity_matrix_cherrypicking(simple_fingerprints, random_seed=42)
assert matrix.shape == (4, 4)
Expand Down Expand Up @@ -136,3 +147,51 @@ def test_compute_spectrum_pairs_vary_parameters(spectrums):
assert scores.shape == (2, 2)
assert len(scores.row) == 4
assert np.allclose(scores.data, [1.0, 1.0, 1.0, 1.0])


# Test SelectedCompoundPairs class
def test_SCP_initialization(dummy_data):
    """Check per-inchikey storage sizes and both index lookups after init."""
    scores, rows, cols, inchikeys = dummy_data
    pairs = SelectedCompoundPairs(coo_array((scores, (rows, cols))), inchikeys)

    # One entry per inchikey, also for inchikeys without any pair
    n_inchikeys = len(inchikeys)
    assert len(pairs._cols) == n_inchikeys
    assert len(pairs._scores) == n_inchikeys

    # Lookups work in both directions
    assert pairs._idx_to_inchikey[0] == "Inchikey0"
    assert pairs._inchikey_to_idx["Inchikey0"] == 0


def test_SCP_shuffle(dummy_data):
    """Check that shuffling reorders rows while keeping col/score pairs aligned.

    Only one row of the dummy data holds more than one entry (three entries),
    so a single unseeded shuffle leaves everything unchanged with probability
    1/6. The shuffle is therefore repeated a bounded number of times until a
    changed order is observed, which avoids a randomly failing test.
    """
    data, row, col, inchikeys = dummy_data
    coo = coo_array((data, (row, col)))
    scp = SelectedCompoundPairs(coo, inchikeys)

    original_cols = [r.copy() for r in scp._cols]
    original_scores = [s.copy() for s in scp._scores]

    def sorted_pairs(cols, scores):
        return sorted(zip(cols.tolist(), scores.tolist()))

    order_changed = False
    for _ in range(100):
        scp.shuffle()

        # Every shuffle must keep each column index attached to its score
        for idx, (cols, scores) in enumerate(zip(scp._cols, scp._scores)):
            assert sorted_pairs(cols, scores) == \
                sorted_pairs(original_cols[idx], original_scores[idx])

        if any(not np.array_equal(o, n) for o, n in zip(original_cols, scp._cols)):
            order_changed = True
            break

    # Check that the data has been shuffled
    assert order_changed
    assert not all(np.array_equal(o, n) for o, n in zip(original_scores, scp._scores))


def test_SCP_next_pair_for_inchikey(dummy_data):
    """Check that pairs are returned in order and that the counter wraps around."""
    data, row, col, inchikeys = dummy_data
    coo = coo_array((data, (row, col)))
    scp = SelectedCompoundPairs(coo, inchikeys)

    score, inchikey2 = scp.next_pair_for_inchikey("Inchikey1")
    assert score == 0.7
    assert inchikey2 == "Inchikey2"

    score, inchikey2 = scp.next_pair_for_inchikey("Inchikey1")
    assert scp._row_generator_index[1] == 2
    assert score == 0.3
    assert inchikey2 == "Inchikey3"

    # Third call returns the last pair of the row and must reset the counter
    score, inchikey2 = scp.next_pair_for_inchikey("Inchikey1")
    assert score == 0.9
    assert inchikey2 == "Inchikey4"
    assert scp._row_generator_index[1] == 0

    # Test wrap around: the fourth call starts again at the first pair
    score, inchikey2 = scp.next_pair_for_inchikey("Inchikey1")
    assert score == 0.7
    assert inchikey2 == "Inchikey2"

0 comments on commit 0528d3c

Please sign in to comment.