Skip to content

Commit

Permalink
update simulation utils for survival
Browse files Browse the repository at this point in the history
  • Loading branch information
tiffanymtang committed Sep 15, 2023
1 parent 9215c5f commit 1f9a083
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 19 deletions.
22 changes: 22 additions & 0 deletions feature_importance/fi_config/mdi_plus_survival/test/dgp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
sys.path.append("../..")
from feature_importance.scripts.simulations_util import *

import sksurv
from sksurv import datasets


# Data-generating configuration for the survival test run.
# Features and responses both come from the same loader callable
# (sksurv.datasets.load_aids), which is passed to load_X_data / load_y_data
# as their ``data_fn`` argument.
X_DGP = load_X_data
X_PARAMS_DICT = dict(data_fn=sksurv.datasets.load_aids)

Y_DGP = load_y_data
Y_PARAMS_DICT = dict(data_fn=sksurv.datasets.load_aids)

# NOTE(review): None presumably means "no parameter is varied across runs";
# confirm against the experiment driver that reads these names.
VARY_PARAM_NAME = VARY_PARAM_VALS = None
# VARY_PARAM_NAME = ["corrupt_size", "sample_row_n"]
# VARY_PARAM_VALS = {"corrupt_size": {"0": 0, "0.01": 0.005, "0.025": 0.0125, "0.05": 0.025},
# "sample_row_n": {"100": 100, "250": 250, "472": 472}}
31 changes: 31 additions & 0 deletions feature_importance/fi_config/mdi_plus_survival/test/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import copy
from feature_importance.util import ModelConfig, FIModelConfig
from feature_importance.scripts.competing_methods import tree_mdi_plus, tree_mdi, tree_mdi_OOB, tree_mda, tree_shap
from sksurv.ensemble import RandomSurvivalForest

from imodels.importance.rf_plus import RandomForestPlusSurvival
from imodels.importance.ppms import CoxnetSurvivalPPM



# Template estimators; each config below receives its own deep copy so the
# templates themselves are never handed to a model config directly.
rf_model = RandomSurvivalForest(
    n_estimators=100,
    min_samples_leaf=5,
    max_features=0.33,
    random_state=42,
)
ppm_model = CoxnetSurvivalPPM()

# Prediction models to fit: a list of single-element lists of ModelConfig.
ESTIMATORS = [
    # [ModelConfig('RSF', RandomSurvivalForest, model_type='rsf',
    # other_params={'n_estimators': 100, 'min_samples_leaf': 5, 'max_features': 0.33, 'random_state': 27})],
    [
        ModelConfig(
            'RSF+',
            RandomForestPlusSurvival,
            model_type='rsf+',
            other_params=dict(
                rf_model=copy.deepcopy(rf_model),
                prediction_model=copy.deepcopy(ppm_model),
            ),
        )
    ],
]

# Feature-importance methods to evaluate: a list of single-element lists of
# FIModelConfig. Only MDI+ is active; the remaining entries are kept
# commented out for reference.
# NOTE(review): 'refit' and 'mdiplus_kwargs' are presumably forwarded as
# keyword arguments to tree_mdi_plus; confirm against the driver script.
FI_ESTIMATORS = [
    [
        FIModelConfig(
            'MDI+',
            tree_mdi_plus,
            model_type='rsf+',
            splitting_strategy='train-test',
            other_params=dict(
                refit=False,
                sample_split=None,
                mdiplus_kwargs=dict(sample_split=None),
            ),
        )
    ],
    # [FIModelConfig('MDI+_ridge_loo_mae', tree_mdi_plus, model_type='tree', ascending=False, other_params={'scoring_fns': mean_absolute_error})],
    # [FIModelConfig('MDI+_Huber_loo_huber_loss', tree_mdi_plus, model_type='tree', ascending=False, other_params={'prediction_model': RobustRegressorPPM(), 'scoring_fns': huber_loss})],
    # [FIModelConfig('MDI', tree_mdi, model_type='tree')],
    # [FIModelConfig('MDI-oob', tree_mdi_OOB, model_type='tree')],
    # [FIModelConfig('MDA', tree_mda, model_type='tree')],
    # [FIModelConfig('TreeSHAP', tree_shap, model_type='tree')]
]
57 changes: 38 additions & 19 deletions feature_importance/scripts/competing_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import numpy as np
import sklearn.base
from sklearn.base import RegressorMixin, ClassifierMixin
from sksurv.base import SurvivalAnalysisMixin
from functools import reduce

import shap
from imodels.importance.rf_plus import RandomForestPlusRegressor, RandomForestPlusClassifier
from imodels.importance.rf_plus import RandomForestPlusRegressor, \
RandomForestPlusClassifier, RandomForestPlusSurvival
from feature_importance.scripts.mdi_oob import MDI_OOB
from feature_importance.scripts.mda import MDA

Expand All @@ -21,9 +23,10 @@ def tree_mdi_plus_ensemble(X, y, fit, scoring_fns="auto", **kwargs):
the column names are used in the output
:param y: ndarray of shape (n_samples, n_targets)
The observed responses.
:param rf_model: scikit-learn random forest object or None
The RF model to be used for interpretation. If None, then a new
RandomForestRegressor or RandomForestClassifier is instantiated.
:param fit: scikit-learn random forest object, RandomForestPlus object, or None
The RF(+) model to be used for interpretation. If None, then a new
RandomForestPlus object is instantiated.
:param scoring_fns: list of scoring functions to use for MDI+ scoring
:param kwargs: additional arguments to pass to
RandomForestPlusRegressor or RandomForestPlusClassifier class.
:return: dataframe - [Var, Importance]
Expand All @@ -35,6 +38,8 @@ def tree_mdi_plus_ensemble(X, y, fit, scoring_fns="auto", **kwargs):
RFPlus = RandomForestPlusRegressor
elif isinstance(fit, ClassifierMixin):
RFPlus = RandomForestPlusClassifier
elif isinstance(fit, SurvivalAnalysisMixin):
RFPlus = RandomForestPlusSurvival
else:
raise ValueError("Unknown task.")

Expand Down Expand Up @@ -65,7 +70,8 @@ def tree_mdi_plus_ensemble(X, y, fit, scoring_fns="auto", **kwargs):
return mdi_plus_ranks_df


def tree_mdi_plus(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs):
def tree_mdi_plus(X, y, fit, scoring_fns="auto", refit=True, mdiplus_kwargs=None,
return_stability_scores=False, **kwargs):
"""
Wrapper around MDI+ object to get feature importance scores
Expand All @@ -74,26 +80,36 @@ def tree_mdi_plus(X, y, fit, scoring_fns="auto", return_stability_scores=False,
the column names are used in the output
:param y: ndarray of shape (n_samples, n_targets)
The observed responses.
:param rf_model: scikit-learn random forest object or None
The RF model to be used for interpretation. If None, then a new
RandomForestRegressor or RandomForestClassifier is instantiated.
:param kwargs: additional arguments to pass to
RandomForestPlusRegressor or RandomForestPlusClassifier class.
:param fit: scikit-learn random forest object, RandomForestPlus object, or None
The RF(+) model to be used for interpretation. If None, then a new
RandomForestPlus object is instantiated.
:param scoring_fns: list of scoring functions to use for MDI+ scoring
:param refit: whether to refit the model
:param return_stability_scores: whether to return stability scores
:param mdiplus_kwargs: kwargs to pass to RandomForestPlus.get_mdi_plus_scores()
:param kwargs: additional arguments to pass to RandomForestPlus* class.
:return: dataframe - [Var, Importance]
Var: variable name
Importance: MDI+ score
"""

if isinstance(fit, RegressorMixin):
RFPlus = RandomForestPlusRegressor
elif isinstance(fit, ClassifierMixin):
RFPlus = RandomForestPlusClassifier
if refit:
if isinstance(fit, RegressorMixin):
RFPlus = RandomForestPlusRegressor
elif isinstance(fit, ClassifierMixin):
RFPlus = RandomForestPlusClassifier
elif isinstance(fit, SurvivalAnalysisMixin):
RFPlus = RandomForestPlusSurvival
else:
raise ValueError("Unknown task.")
rf_plus_model = RFPlus(rf_model=fit, **kwargs)
rf_plus_model.fit(X, y)
else:
raise ValueError("Unknown task.")
rf_plus_model = RFPlus(rf_model=fit, **kwargs)
rf_plus_model.fit(X, y)
rf_plus_model = fit
try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, scoring_fns=scoring_fns)
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(
X=X, y=y, scoring_fns=scoring_fns, **mdiplus_kwargs
)
if return_stability_scores:
stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
except ValueError as e:
Expand All @@ -106,7 +122,10 @@ def tree_mdi_plus(X, y, fit, scoring_fns="auto", return_stability_scores=False,
stability_scores = None
else:
raise
mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
if isinstance(rf_plus_model, SurvivalAnalysisMixin):
mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_["cindex_ipcw"]
else:
mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
if return_stability_scores:
mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)

Expand Down
11 changes: 11 additions & 0 deletions feature_importance/scripts/simulations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,17 @@ def entropy_y(X, c=3, return_support=False):
return y


def load_X_data(data_fn):
    """
    Load a feature matrix and discard its categorical columns.

    :param data_fn: zero-argument callable whose return value is indexable;
        element 0 must be a pandas DataFrame of features.
    :return: a new DataFrame containing every column of the loaded features
        except those with pandas ``category`` dtype.
    """
    features = data_fn()[0]
    # Identify columns by dtype first, then drop them by name.
    to_drop = features.select_dtypes(include=['category']).columns
    return features.drop(columns=to_drop)


def load_y_data(X, data_fn, return_support=False):
    """
    Load the response from a dataset loader.

    :param X: unused by this function.
    :param data_fn: zero-argument callable whose return value is indexable;
        element 1 is taken as the response.
    :param return_support: unused; the same 3-tuple is returned regardless.
    :return: tuple ``(y, None, None)``.
    """
    # NOTE(review): the two trailing None values are presumably placeholders
    # for outputs other response generators provide; confirm with the caller.
    response = data_fn()[1]
    return response, None, None


class IndexedArray(np.ndarray):
def __new__(cls, input_array, index=None):
obj = np.asarray(input_array).view(cls)
Expand Down

0 comments on commit 1f9a083

Please sign in to comment.