Merge pull request #24 from ersilia-os/dev
ZairaChem v1
Showing 19 changed files with 1,018 additions and 32 deletions.
Installation script changes:

```diff
@@ -31,14 +31,17 @@ python3 -m pip install autogluon.tabular[all]==0.5.2
 python3 -m pip install "xgboost==1.3.3"
 python3 -m pip install "SQLAlchemy<1.4.0"
 
-# install zairachem
+# install extra dependencies
 python3 -m pip install git+https://github.com/chembl/[email protected]
+python3 -m pip install -q -U keras-tuner==1.1.3
 
 # install ersilia
 python3 -m pip install git+https://github.com/ersilia-os/ersilia.git
 ersilia --help
 
+# install ersilia compound embedding
+python3 -m pip install git+https://github.com/ersilia-os/compound-embedding-lite.git
 
 # install isaura
 python3 -m pip install git+https://github.com/ersilia-os/isaura.git@ce293244ad0bdd6d7d4f796d2a84b17208a87b56
 
@@ -51,5 +54,11 @@ python3 -m pip install git+https://github.com/ersilia-os/lazy-qsar.git
 # install melloddy-tuner
 python3 -m pip install git+https://github.com/melloddy/[email protected]
 
+# install tabpfn
+python3 -m pip install tabpfn==0.1.8
+
+# install imblearn
+python3 -m pip install imbalanced-learn==0.10.1
+
 # install zairachem
 python3 -m pip install -e .
```
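This change pins several new dependencies (keras-tuner, compound-embedding-lite, tabpfn, imbalanced-learn) that back the new estimator code below. As an optional sanity check after running the updated install script, the minimal sketch below simply imports them; the import names keras_tuner, eosce, tabpfn and imblearn are assumed from the pinned packages and the new modules in this PR, so adjust if your environment resolves them differently.

```python
# Optional post-install check for the dependencies added in this PR.
# Import names (keras_tuner, eosce, tabpfn, imblearn) are assumptions based on
# the packages pinned above and the new modules below.
import keras_tuner
import eosce
import tabpfn
import imblearn

print("keras-tuner:", keras_tuner.__version__)
print("imbalanced-learn:", imblearn.__version__)
print("tabpfn and eosce imported successfully")
```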
New file, +142 lines (TabPFN-based binary classifier):

```python
import numpy as np
from lol import LOL
import random
import collections
from tabpfn import TabPFNClassifier
from imblearn.combine import SMOTETomek
from imblearn.over_sampling import KMeansSMOTE
from imblearn.under_sampling import EditedNearestNeighbours
import joblib


class TabPFNBinaryClassifier(object):
    # Ensemble wrapper around TabPFN: project features with LOL, build several
    # rebalanced, size-capped training subsets, and average TabPFN predictions.
    def __init__(self, device="cpu", N_ensemble_configurations=4):
        self.device = device
        self.N_ensemble_configurations = N_ensemble_configurations
        self.max_samples = 1000

    def _get_balanced_datasets(self, X, y):
        # Three resampling strategies; fall back to the original data if one fails.
        try:
            smp = SMOTETomek(sampling_strategy="auto")
            X_0, y_0 = smp.fit_resample(X, y)
        except:
            X_0, y_0 = X, y
        try:
            smp = KMeansSMOTE(sampling_strategy="auto")
            X_1, y_1 = smp.fit_resample(X, y)
        except:
            X_1, y_1 = X, y
        try:
            smp = EditedNearestNeighbours(sampling_strategy="auto")
            X_2, y_2 = smp.fit_resample(X, y)
        except:
            X_2, y_2 = X, y
        results = [(X_0, y_0), (X_1, y_1), (X_2, y_2)]
        return results

    def _cap_samples(self, X, y):
        # Keep training sets at max_samples rows; draw up to three random subsets,
        # skipping any subset that ends up with no positive labels.
        if X.shape[0] <= self.max_samples:
            return [(X, y)]
        idxs = [i for i in range(X.shape[0])]
        R = []
        for _ in range(3):
            smp_idxs = random.sample(idxs, self.max_samples)
            X_, y_ = X[smp_idxs], y[smp_idxs]
            if np.sum(y_) == 0:
                continue
            R += [(X_, y_)]
        return R

    def _get_ensemble(self, X, y):
        R = []
        for X_0, y_0 in self._get_balanced_datasets(X, y):
            for X_1, y_1 in self._cap_samples(X_0, y_0):
                R += [(X_1, y_1)]
        return R

    def fit(self, X, y):
        # Reduce to 100 components with LOL, then precompute the training subsets.
        self.reducer = LOL(n_components=100)
        self.reducer.fit(X, y)
        X = self.reducer.transform(X)
        self.ensemble = self._get_ensemble(X, y)

    def predict_proba(self, X):
        model = TabPFNClassifier(
            device=self.device, N_ensemble_configurations=self.N_ensemble_configurations
        )
        X = self.reducer.transform(X)
        R = []
        for X_tr, y_tr in self.ensemble:
            # print(X_tr.shape, np.sum(y_tr))
            model.fit(X_tr, y_tr)
            R += [model.predict_proba(X)[:, 1]]
            model.remove_models_from_memory()
        # Average the positive-class probability over all ensemble members.
        R = np.array(R).T
        y_h1 = np.mean(R, axis=1)
        y_h0 = 1 - y_h1
        y_h = np.array([y_h0, y_h1]).T
        return y_h

    def save(self, file_name):
        data = {
            "device": self.device,
            "N_ensemble_configurations": self.N_ensemble_configurations,
            "reducer": self.reducer,
            "ensemble": self.ensemble,
        }
        joblib.dump(data, file_name)

    def load(self, file_name):
        # Restores a saved model and wraps it in an artifact with a 0.5 threshold.
        data = joblib.load(file_name)
        model = TabPFNBinaryClassifier(
            device=data["device"],
            N_ensemble_configurations=data["N_ensemble_configurations"],
        )
        model.ensemble = data["ensemble"]
        model.reducer = data["reducer"]
        return TabPFNClassifierArtifact(model, 0.5)


class Binarizer(object):
    def __init__(self, threshold):
        self.threshold = threshold

    def binarize(self, y_hat):
        y_bin = []
        for y in y_hat:
            if y > self.threshold:
                y_bin += [1]
            else:
                y_bin += [0]
        return np.array(y_bin, dtype=np.uint8)


class TabPFNClassifierArtifact(object):
    # Thin wrapper exposing predict_proba / predict / run on top of the classifier,
    # binarizing probabilities at the given threshold when one is provided.
    def __init__(self, model, threshold):
        self.model = model
        self.threshold = threshold
        if threshold is not None:
            self.binarizer = Binarizer(self.threshold)
        else:
            self.binarizer = None

    def predict_proba(self, X):
        return self.model.predict_proba(X)[:, 1]

    def predict(self, X):
        if self.binarizer is not None:
            y_hat = self.predict_proba(X)
            y_bin = self.binarizer.binarize(y_hat)
        else:
            y_bin = self.model.predict(X)
        return y_bin

    def run(self, X, y=None):
        results = collections.OrderedDict()
        results["main"] = {
            "idxs": None,
            "y": y,
            "y_hat": self.predict_proba(X),
            "b_hat": self.predict(X),
        }
        return results
```
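For orientation, here is a minimal usage sketch of the new classes, assuming the module above is importable and using synthetic data; the file name is hypothetical. In ZairaChem these classes are driven by the estimator pipeline rather than called directly.

```python
import numpy as np

# Assumes TabPFNBinaryClassifier (defined above) is in scope.
rng = np.random.RandomState(0)
X_train = rng.rand(200, 256)              # synthetic features
y_train = np.array([0] * 160 + [1] * 40)  # imbalanced binary labels
X_test = rng.rand(20, 256)

clf = TabPFNBinaryClassifier(device="cpu", N_ensemble_configurations=4)
clf.fit(X_train, y_train)            # LOL reduction + ensemble of resampled subsets
proba = clf.predict_proba(X_test)    # shape (20, 2): columns are P(y=0) and P(y=1)

clf.save("tabpfn_estimator.joblib")             # hypothetical file name
artifact = clf.load("tabpfn_estimator.joblib")  # TabPFNClassifierArtifact, threshold 0.5
out = artifact.run(X_test)  # OrderedDict with "y_hat" probabilities and "b_hat" binary calls
```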
New file, +63 lines (Ersilia compound embedding descriptors):

```python
import os
import pandas as pd
import h5py

from eosce.models import ErsiliaCompoundEmbeddings
from ..utils.matrices import Hdf5
from .. import ZairaBase

from ..setup import SMILES_COLUMN
from ..vars import DATA_SUBFOLDER, DATA_FILENAME, DESCRIPTORS_SUBFOLDER

EOSCE_FILE_NAME = "eosce.h5"


class EosceEmbedder(ZairaBase):
    # Computes Ersilia compound embeddings for a list of SMILES strings.
    def __init__(self):
        ZairaBase.__init__(self)
        self.model = ErsiliaCompoundEmbeddings()

    def calculate(self, smiles_list, output_h5):
        X = self.model.transform(smiles_list)
        # With no output file, return the embedding matrix directly.
        if output_h5 is None:
            return X
        keys = ["key-{0}".format(i) for i in range(len(smiles_list))]
        features = ["feat-{0}".format(i) for i in range(X.shape[1])]
        inputs = smiles_list
        with h5py.File(output_h5, "w") as f:
            f.create_dataset("Keys", data=keys)
            f.create_dataset("Features", data=features)
            f.create_dataset("Inputs", data=inputs)
            f.create_dataset("Values", data=X)


class EosceLoader(ZairaBase):
    # Opens a previously computed embeddings file from the descriptors folder.
    def __init__(self):
        ZairaBase.__init__(self)
        self.path = self.get_output_dir()

    def open(self, eos_id):
        path = os.path.join(self.path, DESCRIPTORS_SUBFOLDER, eos_id, EOSCE_FILE_NAME)
        return Hdf5(path)


class EosceDescriptors(ZairaBase):
    # Reads SMILES from the session data file and writes eosce.h5 under the
    # descriptors subfolder of the current output directory.
    def __init__(self):
        ZairaBase.__init__(self)
        self.path = self.get_output_dir()
        self.input_csv = os.path.join(self.path, DATA_SUBFOLDER, DATA_FILENAME)
        self.smiles_list = self._get_smiles_list()

    def _get_smiles_list(self):
        df = pd.read_csv(self.input_csv)
        return list(df[SMILES_COLUMN])

    def output_h5_filename(self):
        path = os.path.join(self.path, DESCRIPTORS_SUBFOLDER)
        os.makedirs(path, exist_ok=True)
        return os.path.join(path, EOSCE_FILE_NAME)

    def run(self):
        output_h5 = self.output_h5_filename()
        ref = EosceEmbedder()
        ref.calculate(self.smiles_list, output_h5)
```
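As a rough illustration of the embedder on its own, the sketch below embeds a few example SMILES; it assumes compound-embedding-lite (eosce) is installed, that ZairaBase can be instantiated outside of a full ZairaChem run, and that the output file name is hypothetical.

```python
# Assumes EosceEmbedder (defined above) is in scope.
smiles = ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]  # illustrative molecules

embedder = EosceEmbedder()
X = embedder.calculate(smiles, None)  # no output file: returns the embedding matrix
print(X.shape)                        # (3, number_of_embedding_dimensions)

# Or persist to HDF5 with the same Keys/Features/Inputs/Values layout used by run():
embedder.calculate(smiles, "eosce_example.h5")
```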
New file, +1 line:

```python
ESTIMATORS_FAMILY_SUBFOLDER = "ersilia_embedding"
```