Update dgps: Correct typo; Update models.py: include no raw features; Update competing_methods: adapt to Zach changes
zyliang2001 committed Jan 21, 2024
1 parent 61a2332 commit 1f23f3e
Showing 4 changed files with 410 additions and 110 deletions.
@@ -24,7 +24,8 @@
 # VARY_PARAM_VALS = {"100": 100, "250": 250, "500": 500, "1000": 1000}
 
 # vary two parameters in a grid
-VARY_PARAM_NAME = ["heritability", "sample_row_n"]
+### "sample_row_n"
+VARY_PARAM_NAME = ["heritability", "n"]
 VARY_PARAM_VALS = {"heritability": {"0.1": 0.1, "0.2": 0.2, "0.4": 0.4, "0.8": 0.8},
                    "n": {"100": 100, "250": 250, "500": 500, "1000": 1000}}

@@ -10,9 +10,12 @@
 ]
 
 FI_ESTIMATORS = [
-    [FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, ascending = False, model_type='tree')],
-    [FIModelConfig('MDI_sub_stumps', MDI_local_sub_stumps, ascending = False, model_type='tree')],
+    [FIModelConfig('MDI_all_stumps_with_raw', MDI_local_all_stumps, ascending = False, model_type='tree')],
+    [FIModelConfig('MDI_sub_stumps_with_raw', MDI_local_sub_stumps, ascending = False, model_type='tree')],
+    [FIModelConfig('MDI_all_stumps_without_raw', MDI_local_all_stumps, ascending = False, model_type='tree', include_raw=False)],
+    [FIModelConfig('MDI_sub_stumps_without_raw', MDI_local_sub_stumps, ascending = False, model_type='tree', include_raw=False)],
     [FIModelConfig('TreeSHAP', tree_shap_local, model_type='tree')],
     [FIModelConfig('Permutation', permutation_local, model_type='tree')],
     [FIModelConfig('LIME', lime_local, model_type='tree')],
 ]
+# [FIModelConfig('Permutation', permutation_local, model_type='tree')],
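Each entry pairs a display name with a scoring function plus options. As a rough reconstruction of the record these calls imply (the real FIModelConfig lives in the repo's config machinery and likely carries more fields, so treat this as an assumption):

from dataclasses import dataclass
from typing import Callable

# Hypothetical reconstruction of the config record implied by the usage above.
@dataclass
class FIModelConfig:
    name: str                 # display name, e.g. 'MDI_all_stumps_with_raw'
    fi_fn: Callable           # scorer with signature (X, y, fit) -> DataFrame (n_samples, n_features)
    ascending: bool = True    # direction flag used by the harness (semantics not shown in this diff)
    model_type: str = 'tree'  # estimator family the scorer expects
    include_raw: bool = True  # whether raw features enter the MDI+ transformer

The new include_raw=False variants rerun the same two MDI+ scorers with the raw-feature block excluded, which is presumably what the commit message's "include no raw features" refers to.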
341 changes: 330 additions & 11 deletions feature_importance/local_MDI_plus_visulization.ipynb

Large diffs are not rendered by default.

165 changes: 71 additions & 94 deletions feature_importance/scripts/competing_methods_local.py
@@ -15,76 +15,6 @@
 from sklearn.metrics import r2_score, mean_absolute_error, accuracy_score, roc_auc_score, mean_squared_error
 
 
-def neg_mae(y_true, y_pred, **kwargs):
-    """
-    Evaluates negative mean absolute error
-    """
-    return -mean_absolute_error(y_true, y_pred, **kwargs)
-
-def tree_shap_local(X, y, fit):
-    """
-    Compute average treeshap value across observations
-    :param X: design matrix
-    :param y: response
-    :param fit: fitted model of interest (tree-based)
-    :return: dataframe of shape: (n_samples, n_features)
-    """
-    explainer = shap.TreeExplainer(fit)
-    shap_values = explainer.shap_values(X, check_additivity=False)
-    if sklearn.base.is_classifier(fit):
-        def add_abs(a, b):
-            return abs(a) + abs(b)
-        results = reduce(add_abs, shap_values)
-    else:
-        results = abs(shap_values)
-    result_table = pd.DataFrame(results, columns=[f'Feature_{i}' for i in range(X.shape[1])])
-    # results = results.mean(axis=0)
-    # results = pd.DataFrame(data=results, columns=['importance'])
-    # # Use column names from dataframe if possible
-    # if isinstance(X, pd.DataFrame):
-    #     results.index = X.columns
-    #     results.index.name = 'var'
-    #     results.reset_index(inplace=True)
-
-    return result_table
-
-def permutation_local(X, y, fit, num_permutations=100):
-    """
-    Compute local permutation importance for each feature and sample.
-    :param X: design matrix
-    :param y: response
-    :param fit: fitted model of interest (tree-based)
-    :num_permutations: Number of permutations for each feature (default is 100)
-    :return: dataframe of shape: (n_samples, n_features)
-    """
-
-    # Get the number of samples and features
-    num_samples, num_features = X.shape
-
-    # Initialize array to store local permutation importance
-    lpi = np.zeros((num_samples, num_features))
-
-    # For each feature
-    for k in range(num_features):
-        # Permute X_k num_permutations times
-        for b in range(num_permutations):
-            X_permuted = X.copy()
-            X_permuted[:, k] = np.random.permutation(X[:, k])
-
-            # Feed permuted data through the fitted model
-            y_pred_permuted = fit.predict(X_permuted)
-
-            # Calculate MSE for each sample
-            for i in range(num_samples):
-                lpi[i, k] += (y[i]-y_pred_permuted[i])**2
-
-    lpi /= num_permutations
-
-    # Convert the array to a DataFrame
-    result_table = pd.DataFrame(lpi, columns=[f'Feature_{i}' for i in range(num_features)])
-
-    return result_table
-
 def MDI_local_sub_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs):
     """
     Compute local MDI importance for each feature and sample.
@@ -106,26 +36,19 @@ def MDI_local_sub_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs):
     rf_plus_model.fit(X, y)
 
     try:
-        mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error, version = "zach")
+        mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error, version = "sub", lfi=True)["lfi"]
         if return_stability_scores:
+            raise NotImplementedError
             stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
     except ValueError as e:
         if str(e) == 'Transformer representation was empty for all trees.':
-            mdi_plus_scores = np.zeros((num_samples, num_features)) ### Not sure if this is right
-            # mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
-            # if isinstance(X, pd.DataFrame):
-            #     mdi_plus_scores.index = X.columns
-            #     mdi_plus_scores.index.name = 'var'
-            #     mdi_plus_scores.reset_index(inplace=True)
+            mdi_plus_scores = np.zeros((num_samples, num_features))
             stability_scores = None
         else:
             raise
-    # mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
-    # if return_stability_scores:
-    #     mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)
-    result = mdi_plus_scores["local"].values
 
     # Convert the array to a DataFrame
-    result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])
+    result_table = pd.DataFrame(mdi_plus_scores, columns=[f'Feature_{i}' for i in range(num_features)])
 
     return result_table

@@ -159,32 +82,26 @@ def MDI_local_all_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=False, **kwargs):
     rf_plus_model = RFPlus(rf_model=fit, **kwargs)
     rf_plus_model.fit(X, y)
     try:
-        mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error, version = "tiffany")
+        mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error, version = "all", lfi=True)["lfi"]
         if return_stability_scores:
+            raise NotImplementedError
             stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
     except ValueError as e:
         if str(e) == 'Transformer representation was empty for all trees.':
-            mdi_plus_scores = np.zeros((num_samples, num_features)) ### Not sure if this is right
-            # mdi_plus_scores = pd.DataFrame(data=np.zeros(X.shape[1]), columns=['importance'])
-            # if isinstance(X, pd.DataFrame):
-            #     mdi_plus_scores.index = X.columns
-            #     mdi_plus_scores.index.name = 'var'
-            #     mdi_plus_scores.reset_index(inplace=True)
+            mdi_plus_scores = np.zeros((num_samples, num_features))
            stability_scores = None
         else:
             raise
-    # mdi_plus_scores["prediction_score"] = rf_plus_model.prediction_score_
-    # if return_stability_scores:
-    #     mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)
-    result = mdi_plus_scores["local"].values
 
     # Convert the array to a DataFrame
-    result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])
+    result_table = pd.DataFrame(mdi_plus_scores, columns=[f'Feature_{i}' for i in range(num_features)])
 
     return result_table
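Both MDI+ scorers follow the same recipe — wrap the fitted forest in RFPlus, refit on (X, y), and pull the local feature importance ("lfi") table — differing only in the version argument ("sub" vs "all" stumps). A minimal usage sketch, assuming numpy arrays and a scikit-learn forest (the RFPlus import comes from the MDI+/imodels code and is not shown in this diff):

import numpy as np
from sklearn.ensemble import RandomForestRegressor

# Toy data: y depends on the first two features only.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = X[:, 0] - X[:, 1] + 0.1 * rng.normal(size=100)
rf = RandomForestRegressor(n_estimators=50, min_samples_leaf=5).fit(X, y)

# Each call returns an (n_samples, n_features) DataFrame of local scores.
lfi_sub = MDI_local_sub_stumps(X, y, rf)
lfi_all = MDI_local_all_stumps(X, y, rf)
print(lfi_sub.shape, lfi_all.shape)  # (100, 5) (100, 5)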

 def lime_local(X, y, fit):
     """
     Compute LIME local importance for each feature and sample.
+    Larger values indicate more important features.
     :param X: design matrix
     :param y: response
     :param fit: fitted model of interest (tree-based)
@@ -205,4 +122,64 @@ def lime_local(X, y, fit):
     # Convert the array to a DataFrame
     result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])
 
     return result_table
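The body of lime_local is largely collapsed in this diff. For orientation, the standard LimeTabularExplainer pattern it presumably wraps looks like this — a sketch under that assumption, not the repo's exact code:

import numpy as np
import lime.lime_tabular

def lime_importances(X, fit, num_samples=1000):
    # One LIME explanation per row; absolute weights are collected into an
    # (n_samples, n_features) array, matching lime_local's return shape.
    explainer = lime.lime_tabular.LimeTabularExplainer(X, mode="regression", discretize_continuous=False)
    n, p = X.shape
    out = np.zeros((n, p))
    for i in range(n):
        exp = explainer.explain_instance(X[i], fit.predict, num_features=p, num_samples=num_samples)
        for feat_idx, weight in exp.local_exp[1]:  # regression mode keys local_exp by 1
            out[i, feat_idx] = abs(weight)
    return out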

+def tree_shap_local(X, y, fit):
+    """
+    Compute local TreeSHAP values for each feature and sample.
+    Larger absolute values indicate more important features.
+    :param X: design matrix
+    :param y: response
+    :param fit: fitted model of interest (tree-based)
+    :return: dataframe of shape: (n_samples, n_features)
+    """
+    explainer = shap.TreeExplainer(fit)
+    shap_values = explainer.shap_values(X, check_additivity=False)
+    if sklearn.base.is_classifier(fit):
+        # SHAP values are returned as a list of arrays, one for each class
+        def add_abs(a, b):
+            return abs(a) + abs(b)
+        results = reduce(add_abs, shap_values)
+    else:
+        results = abs(shap_values)
+    result_table = pd.DataFrame(results, columns=[f'Feature_{i}' for i in range(X.shape[1])])
+
+    return result_table
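A quick sanity check for tree_shap_local on synthetic data, where the single informative feature should receive the largest mean |SHAP| column (a sketch, assuming the shap package is installed):

import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = 2 * X[:, 0] + rng.normal(scale=0.1, size=200)
rf = RandomForestRegressor(n_estimators=100).fit(X, y)

shap_table = tree_shap_local(X, y, rf)   # DataFrame of shape (200, 4)
print(shap_table.mean(axis=0).idxmax())  # expected: 'Feature_0'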

+def permutation_local(X, y, fit, num_permutations=100):
+    """
+    Compute local permutation importance for each feature and sample.
+    Larger values indicate more important features.
+    :param X: design matrix
+    :param y: response
+    :param fit: fitted model of interest (tree-based)
+    :param num_permutations: number of permutations for each feature (default is 100)
+    :return: dataframe of shape: (n_samples, n_features)
+    """
+
+    # Get the number of samples and features
+    num_samples, num_features = X.shape
+
+    # Initialize array to store local permutation importance
+    lpi = np.zeros((num_samples, num_features))
+
+    # For each feature
+    for k in range(num_features):
+        # Permute X_k num_permutations times
+        for b in range(num_permutations):
+            X_permuted = X.copy()
+            X_permuted[:, k] = np.random.permutation(X[:, k])
+
+            # Feed permuted data through the fitted model
+            y_pred_permuted = fit.predict(X_permuted)
+
+            # Calculate MSE for each sample
+            for i in range(num_samples):
+                lpi[i, k] += (y[i] - y_pred_permuted[i]) ** 2
+
+    lpi /= num_permutations
+
+    # Convert the array to a DataFrame
+    result_table = pd.DataFrame(lpi, columns=[f'Feature_{i}' for i in range(num_features)])
+
+    return result_table
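Note the cost: permutation_local calls fit.predict num_features × num_permutations times on full copies of X, so it is by far the slowest scorer in this module. A small smoke test (a sketch; X must be a numpy array, since the function indexes it as X[:, k]):

import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(1)
X = rng.normal(size=(150, 3))
y = 3 * X[:, 0] + rng.normal(scale=0.1, size=150)
rf = RandomForestRegressor(n_estimators=50).fit(X, y)

lpi = permutation_local(X, y, rf, num_permutations=20)
print(lpi.mean(axis=0))  # the 'Feature_0' column should dominate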
