Skip to content

Commit

Permalink
new rescaling
Browse files Browse the repository at this point in the history
  • Loading branch information
GemmaTuron committed Apr 26, 2023
1 parent 8af2372 commit ef4c013
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions zairachem/pool/bagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import joblib
import h5py
import collections
from scipy.special import expit
from sklearn.linear_model import LogisticRegressionCV, LinearRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.preprocessing import RobustScaler
from sklearn.preprocessing import RobustScaler, PowerTransformer
from sklearn.metrics import roc_curve, auc, r2_score

from .. import ZairaBase
Expand Down Expand Up @@ -155,8 +156,8 @@ def save(self, filename):


class PoolClassifier(object):
def __init__(self, path, mode="scaling"):
assert mode in ["scaling", "median", "model"]
def __init__(self, path, mode="weighting"):
assert mode in ["weighting", "median", "model"]
self.path = path
if not os.path.exists(self.path):
os.makedirs(self.path, exist_ok=True)
Expand All @@ -168,7 +169,8 @@ def _get_model_filename(self, n):
def _fit_just_median(self, df_X, df_y):
return np.median(np.array(df_X), axis=1)

def _fit_scaling(self, df_X, df_y):
def _fit_weighting(self, df_X, df_y):
y = np.array(df_y).ravel()
cols = list(df_X.columns)
X = np.array(df_X)
p25 = np.percentile(X.ravel(), 25)
Expand All @@ -177,18 +179,21 @@ def _fit_scaling(self, df_X, df_y):
scale = (p25, p50, p75)
for c in cols:
X = np.array(df_X[c]).reshape(-1, 1)
mdl = RobustScaler()
mdl.fit(X)
mdl0 = PowerTransformer()
mdl0.fit(X)
X = mdl0.transform(X)
mdl1 = LogisticRegressionCV()
mdl1.fit(X, y)
filename = self._get_model_filename(c)
joblib.dump(mdl, filename)
joblib.dump((mdl0, mdl1), filename)
filename = self._get_model_filename("overall")
joblib.dump(scale, filename)
filename = self._get_model_filename("weighting")
ws = WeightSchemes(df_X, df_y, "classification")
ws.distance_to_leads()
ws.importance()
ws.save(filename)
return self._predict_scaling(df_X)
return self._predict_weighting(df_X)

def _fit_model(self, df_X, df_y):
y = np.array(df_y).ravel()
Expand All @@ -204,22 +209,19 @@ def _fit_model(self, df_X, df_y):
def _predict_just_median(self, df_X):
return np.median(np.array(df_X), axis=1)

def _predict_scaling(self, df_X):
def _predict_weighting(self, df_X):
cols = list(df_X.columns)
Y_hat = []
for c in cols:
filename = self._get_model_filename(c)
if os.path.exists(filename):
mdl = joblib.load(filename)
mdl0, mdl1 = joblib.load(filename)
X = np.array(df_X[c]).reshape(-1, 1)
y_hat = mdl.transform(X).ravel()
X = mdl0.transform(X)
y_hat = mdl1.predict_proba(X)[:,1]
Y_hat += [y_hat]
Y_hat = np.array(Y_hat).T
filename = self._get_model_filename("overall")
scale = joblib.load(filename)
iqr = scale[-1] - scale[0]
med = scale[1]
Y_hat = Y_hat * iqr + med
filename = self._get_model_filename("weighting")
weights = joblib.load(filename)
wvals = weights["weights"]
Expand Down Expand Up @@ -248,16 +250,16 @@ def _predict_model(self, df_X):
return np.median(Y_hat, axis=1)

def fit(self, df_X, df_y):
if self.mode == "scaling":
return self._fit_scaling(df_X, df_y)
if self.mode == "weighting":
return self._fit_weighting(df_X, df_y)
if self.mode == "median":
return self._fit_just_median(df_X, df_y)
if self.mdoe == "model":
return self._fit_model(df_X, df_y)

def predict(self, df_X):
if self.mode == "scaling":
return self._predict_scaling(df_X)
if self.mode == "weighting":
return self._predict_weighting(df_X)
if self.mode == "median":
return self._predict_just_median(df_X)
if self.mode == "model":
Expand Down

0 comments on commit ef4c013

Please sign in to comment.