Skip to content

Commit

Permalink
Add MDI to competing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Jan 14, 2024
1 parent b519455 commit cd3baac
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
]

FI_ESTIMATORS = [
# [FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, model_type='tree')],
[FIModelConfig('MDI_sub_stumps', MDI_local_sub_stumps, model_type='tree')],
[FIModelConfig('TreeSHAP', tree_shap_local, model_type='tree')],
[FIModelConfig('Permutation', permutation_local, model_type='tree')],
[FIModelConfig('LIME', lime_local, model_type='tree')],
[FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, model_type='tree')],
[FIModelConfig('MDI_sub_stumps', MDI_local_sub_stumps, model_type='tree')],
]
57 changes: 36 additions & 21 deletions feature_importance/scripts/competing_methods_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def permutation_local(X, y, fit, num_permutations=100):

return result_table

# TODO(Zach): add the implementations of local MDI and MDI+ below.
def MDI_local_sub_stumps(X, y, fit):
def MDI_local_sub_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs):
"""
Compute local MDI importance for each feature and sample.
:param X: design matrix
Expand All @@ -106,9 +105,24 @@ def MDI_local_sub_stumps(X, y, fit):
rf_plus_model = RFPlus(rf_model=fit, **kwargs)
rf_plus_model.fit(X, y)

mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X, y, scoring_fns={"r2_score": _fast_r2_score, "negative_mae": neg_mae}, local_scoring_fns=True)
result = mdi_plus_scores["local"]["negative_mae"]

try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error)
if return_stability_scores:
stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
except ValueError as e:
if str(e) == 'Transformer representation was empty for all trees.':
mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
if isinstance(X, pd.DataFrame):
mdi_plus_scores.index = X.columns
mdi_plus_scores.index.name = 'var'
mdi_plus_scores.reset_index(inplace=True)
stability_scores = None
else:
raise
# mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
# if return_stability_scores:
# mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)
result = mdi_plus_scores["local"]
# Convert the array to a DataFrame
result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])

Expand Down Expand Up @@ -139,27 +153,28 @@ def MDI_local_all_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=
RFPlus = RandomForestPlusClassifier
else:
raise ValueError("Unknown task.")

rf_plus_model = RFPlus(rf_model=fit, **kwargs)
rf_plus_model.fit(X, y)
# try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error)
result = mdi_plus_scores["local"]
# if return_stability_scores:
# stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
# except ValueError as e:
# if str(e) == 'Transformer representation was empty for all trees.':
# mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
# if isinstance(X, pd.DataFrame):
# mdi_plus_scores.index = X.columns
# mdi_plus_scores.index.name = 'var'
# mdi_plus_scores.reset_index(inplace=True)
# stability_scores = None
# else:
# raise
try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error)
if return_stability_scores:
stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
except ValueError as e:
if str(e) == 'Transformer representation was empty for all trees.':
mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
if isinstance(X, pd.DataFrame):
mdi_plus_scores.index = X.columns
mdi_plus_scores.index.name = 'var'
mdi_plus_scores.reset_index(inplace=True)
stability_scores = None
else:
raise
# mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
# if return_stability_scores:
# mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)

result = mdi_plus_scores["local"]
# Convert the array to a DataFrame
result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])

return result_table
Expand Down

0 comments on commit cd3baac

Please sign in to comment.