
Commit

make code cleaner
zyliang2001 committed May 13, 2024
1 parent a7fd722 commit a7070d2
Showing 16 changed files with 11,092 additions and 46,965 deletions.
2 changes: 1 addition & 1 deletion feature_importance/01_ablation_regression_script.sh
@@ -4,7 +4,7 @@
#SBATCH --partition=yugroup

source activate mdi
command="01_run_ablation_regression.py --nreps 1 --config mdi_local.real_data_regression --split_seed ${1} --ignore_cache --create_rmd --result_name diabetes_regr"
command="01_run_ablation_regression.py --nreps 1 --config mdi_local.real_data_regression --split_seed ${1} --ignore_cache --create_rmd --result_name diabetes_regr_new"

# Execute the command
python $command
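For context on the command string above: ${1} is the SLURM script's first positional argument, which the submit loop in 01_ablation_script.sh (next diff) fills with the rep number, so each submitted job runs the pipeline with a different split seed. Below is a minimal sketch of how the runner might consume these flags; the flag names come from the command itself, but the types, defaults, and help text are assumptions, not the actual 01_run_ablation_regression.py interface.

```python
# Hedged sketch of an argument parser matching the flags in the command string above.
# Only the flag names are taken from the script; everything else is assumed.
import argparse

def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Ablation study runner (sketch).")
    parser.add_argument("--nreps", type=int, default=1, help="number of repetitions")
    parser.add_argument("--config", type=str, help="config module, e.g. mdi_local.real_data_regression")
    parser.add_argument("--split_seed", type=int, help="train/test split seed (the SLURM script's ${1})")
    parser.add_argument("--ignore_cache", action="store_true", help="recompute even if cached results exist")
    parser.add_argument("--create_rmd", action="store_true", help="also render an R Markdown summary")
    parser.add_argument("--result_name", type=str, help="name for the output results folder")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Example invocation mirroring the SLURM command string
    print(parse_args(["--split_seed", "1", "--result_name", "diabetes_regr_new"]))
```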
4 changes: 2 additions & 2 deletions feature_importance/01_ablation_script.sh
@@ -1,8 +1,8 @@
#!/bin/bash

slurm_script="01_ablation_regression_script.sh"
slurm_script="01_ablation_classification_script.sh"

for rep in {1..2}
for rep in {1..5}
do
sbatch $slurm_script $rep # Submit SLURM job using the specified script
done
128 changes: 60 additions & 68 deletions feature_importance/01_run_ablation_classification.py

Large diffs are not rendered by default.

117 changes: 54 additions & 63 deletions feature_importance/01_run_ablation_regression.py

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
Binary file modified feature_importance/diabetes_classification_test.png
Binary file modified feature_importance/diabetes_classification_test_subset_1.png
Binary file modified feature_importance/diabetes_classification_test_subset_2.png
Binary file modified feature_importance/diabetes_classification_train.png
Original file line number Diff line number Diff line change
@@ -15,10 +15,10 @@

FI_ESTIMATORS = [
[FIModelConfig('TreeSHAP_RF', tree_shap_evaluation_RF, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RF', LFI_evaluation_RF_MDI, model_type='tree', splitting_strategy = "train-test", other_params={"include_raw":False, "fit_on":"inbag", "prediction_model": Ridge(alpha=1e-6)})],
[FIModelConfig('LFI_fit_on_OOB_RF', LFI_evaluation_RF_OOB, model_type='tree', splitting_strategy = "train-test", other_params={"fit_on":"oob"})],
[FIModelConfig('LFI_evaluate_on_all_RF_plus', LFI_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_evaluate_on_oob_RF_plus', LFI_evaluation_RF_plus_OOB, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RF', LFI_evaluation_RF_MDI_classification, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"include_raw":False, "fit_on":"inbag", "prediction_model": Ridge(alpha=1e-6)})],
[FIModelConfig('LFI_fit_on_OOB_RF', LFI_evaluation_RF_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"fit_on":"oob"})],
[FIModelConfig('LFI_evaluate_on_all_RF_plus', LFI_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_evaluate_on_oob_RF_plus', LFI_evaluation_RF_plus_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('Kernel_SHAP_RF_plus', kernel_shap_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LIME_RF_plus', lime_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
]
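This hunk (and the matching regression hunk further down) adds an ascending = False argument to the LFI-based estimator configurations. Presumably this tells downstream ranking or ablation code which direction of the importance scores counts as more important; the exact convention is not visible in this diff, so the following is only a generic illustration of such a flag, not the repository's FIModelConfig implementation.

```python
# Generic illustration of an `ascending`-style flag for ranking features by importance.
# The semantics assumed here (ascending=False means a smaller score = more important)
# are an assumption, not necessarily FIModelConfig's convention.
import numpy as np

def rank_features(scores: np.ndarray, ascending: bool) -> np.ndarray:
    """Return feature indices ordered from most to least important."""
    order = np.argsort(scores)                      # smallest score first
    return order if not ascending else order[::-1]  # flip if larger = more important

scores = np.array([0.9, 0.1, 0.5])
print(rank_features(scores, ascending=False))  # -> [1 2 0]: feature 1 ranked most important
print(rank_features(scores, ascending=True))   # -> [0 2 1]: feature 0 ranked most important
```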
19 changes: 19 additions & 0 deletions feature_importance/fi_config/mdi_local/real_data_regression/dgp.py
@@ -9,12 +9,31 @@
"data_name": "diabetes_regr",
"sample_row_n": None
}
# X_PARAMS_DICT = {
# "source": "imodels",
# "data_name": "satellite_image",
# "sample_row_n": None
# }
# X_PARAMS_DICT = {
# "source": "openml",
# "task_id": 359946,
# "sample_row_n": None
# }

Y_DGP = sample_real_data_y
Y_PARAMS_DICT = {
"source": "imodels",
"data_name": "diabetes_regr"
}
# Y_PARAMS_DICT = {
# "source": "imodels",
# "data_name": "satellite_image"
# }
# Y_PARAMS_DICT = {
# "source": "openml",
# "task_id": 359946
# }

# vary one parameter
VARY_PARAM_NAME = "sample_row_n"
VARY_PARAM_VALS = {"keep_all_rows": None}
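The dgp.py config above points both X and y at the imodels "diabetes_regr" dataset, with commented-out alternatives (the imodels satellite_image dataset and OpenML task 359946), and sweeps a single parameter (sample_row_n, with None meaning keep every row). The repository's sample_real_data_X / sample_real_data_y loaders are not shown in this diff; the sketch below shows one way a (source, data_name) pair like this could be resolved, assuming imodels' get_clean_dataset helper as a stand-in.

```python
# Hedged sketch: resolve the (source, data_name) pair from the config into arrays.
# get_clean_dataset and the "diabetes_regr" name are assumptions based on the config,
# not the repository's actual sample_real_data_X / sample_real_data_y code.
from imodels import get_clean_dataset

X, y, feature_names = get_clean_dataset("diabetes_regr", data_source="imodels")
print(X.shape, y.shape, feature_names[:3])

# A non-None value for the varied parameter (e.g. 500) would presumably subsample that
# many rows before the ablation runs; None ("keep_all_rows") leaves the dataset untouched.
```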
Original file line number Diff line number Diff line change
@@ -15,10 +15,10 @@

FI_ESTIMATORS = [
[FIModelConfig('TreeSHAP_RF', tree_shap_evaluation_RF, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RF', LFI_evaluation_RF_MDI, model_type='tree', splitting_strategy = "train-test", other_params={"include_raw":False, "fit_on":"inbag", "prediction_model": Ridge(alpha=1e-6)})],
[FIModelConfig('LFI_fit_on_OOB_RF', LFI_evaluation_RF_OOB, model_type='tree', splitting_strategy = "train-test", other_params={"fit_on":"oob"})],
[FIModelConfig('LFI_evaluate_on_all_RF_plus', LFI_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_evaluate_on_oob_RF_plus', LFI_evaluation_RF_plus_OOB, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RF', LFI_evaluation_RF_MDI, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"include_raw":False, "fit_on":"inbag", "prediction_model": Ridge(alpha=1e-6)})],
[FIModelConfig('LFI_fit_on_OOB_RF', LFI_evaluation_RF_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"fit_on":"oob"})],
[FIModelConfig('LFI_evaluate_on_all_RF_plus', LFI_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_evaluate_on_oob_RF_plus', LFI_evaluation_RF_plus_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('Kernel_SHAP_RF_plus', kernel_shap_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LIME_RF_plus', lime_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
]
72 changes: 39 additions & 33 deletions feature_importance/real_data_ablation_visulization_new.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -15145,7 +15145,7 @@
"13 2 "
]
},
"execution_count": 3,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -15156,7 +15156,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -15182,23 +15182,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'Column not found: ablation_time'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m# Print the ablation time\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m averages \u001b[39m=\u001b[39m combined_df\u001b[39m.\u001b[39;49mgroupby(\u001b[39m'\u001b[39;49m\u001b[39mfi\u001b[39;49m\u001b[39m'\u001b[39;49m)[\u001b[39m'\u001b[39;49m\u001b[39mablation_time\u001b[39;49m\u001b[39m'\u001b[39;49m]\u001b[39m.\u001b[39mmean()\u001b[39m.\u001b[39mreset_index()\n\u001b[1;32m 3\u001b[0m \u001b[39mprint\u001b[39m(averages)\n",
"File \u001b[0;32m/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/pandas/core/groupby/generic.py:1415\u001b[0m, in \u001b[0;36mDataFrameGroupBy.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1406\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(key, \u001b[39mtuple\u001b[39m) \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(key) \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m 1407\u001b[0m \u001b[39m# if len == 1, then it becomes a SeriesGroupBy and this is actually\u001b[39;00m\n\u001b[1;32m 1408\u001b[0m \u001b[39m# valid syntax, so don't raise warning\u001b[39;00m\n\u001b[1;32m 1409\u001b[0m warnings\u001b[39m.\u001b[39mwarn(\n\u001b[1;32m 1410\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mIndexing with multiple keys (implicitly converted to a tuple \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1411\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mof keys) will be deprecated, use a list instead.\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 1412\u001b[0m \u001b[39mFutureWarning\u001b[39;00m,\n\u001b[1;32m 1413\u001b[0m stacklevel\u001b[39m=\u001b[39mfind_stack_level(),\n\u001b[1;32m 1414\u001b[0m )\n\u001b[0;32m-> 1415\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__getitem__\u001b[39;49m(key)\n",
"File \u001b[0;32m/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/pandas/core/base.py:248\u001b[0m, in \u001b[0;36mSelectionMixin.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 247\u001b[0m \u001b[39mif\u001b[39;00m key \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobj:\n\u001b[0;32m--> 248\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mKeyError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mColumn not found: \u001b[39m\u001b[39m{\u001b[39;00mkey\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 249\u001b[0m subset \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobj[key]\n\u001b[1;32m 250\u001b[0m ndim \u001b[39m=\u001b[39m subset\u001b[39m.\u001b[39mndim\n",
"\u001b[0;31mKeyError\u001b[0m: 'Column not found: ablation_time'"
]
}
],
"outputs": [],
"source": [
"# # Print the ablation time\n",
"# averages = combined_df.groupby('fi')['ablation_time'].mean().reset_index()\n",
@@ -15207,16 +15193,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"#################### Change the following according to the dataset ####################\n",
"task = \"regression\" #\"regression\" \"classification\"\n",
"########################################################################################\n",
"methods_rf = [\"TreeSHAP_RF\", \"LFI_fit_on_inbag_RF\", \"LFI_fit_on_OOB_RF\", \"LFI_evaluate_on_all_RF_plus\", \"LFI_evaluate_on_oob_RF_plus\",\n",
"methods = [\"TreeSHAP_RF\", \"LFI_fit_on_inbag_RF\", \"LFI_fit_on_OOB_RF\", \"LFI_evaluate_on_all_RF_plus\", \"LFI_evaluate_on_oob_RF_plus\",\n",
" \"Kernel_SHAP_RF_plus\", \"LIME_RF_plus\"]\n",
"methods_rf_plus = [\"Kernel_SHAP_RF_plus\",\"LFI_with_raw_RF_plus\", \"LIME_RF_plus\"]\n",
"n_testsize = combined_df[['train_size', 'test_size']].drop_duplicates()\n",
"num_features = combined_df['num_features'].drop_duplicates()[0]\n",
"metrics = {\"regression\": [\"MSE\", \"R_2\"], \"classification\": [\"AUROC\",\"AUPRC\", \"F1\"]}"
@@ -15226,19 +15211,40 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: RF\n",
"MSE before ablation: 3161.100213831893\n",
"R2 before ablation: 0.4470820005009606\n",
"\n",
"Model: RF_plus\n",
"MSE before ablation: 2943.1783568546443\n",
"R2 before ablation: 0.4856215513988536\n",
"\n"
"Model: RF\n"
]
},
{
"ename": "KeyError",
"evalue": "'test_all_mse'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/pandas/core/indexes/base.py:3803\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3802\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 3803\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_engine\u001b[39m.\u001b[39;49mget_loc(casted_key)\n\u001b[1;32m 3804\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n",
"File \u001b[0;32m/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/pandas/_libs/index.pyx:138\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/pandas/_libs/index.pyx:165\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5745\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5753\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: 'test_all_mse'",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[16], line 13\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[39mfor\u001b[39;00m model, group_df \u001b[39min\u001b[39;00m grouped:\n\u001b[1;32m 12\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mModel:\u001b[39m\u001b[39m\"\u001b[39m, model)\n\u001b[0;32m---> 13\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mMSE before ablation:\u001b[39m\u001b[39m\"\u001b[39m, group_df[\u001b[39m\"\u001b[39;49m\u001b[39mtest_all_mse\u001b[39;49m\u001b[39m\"\u001b[39;49m]\u001b[39m.\u001b[39mmean())\n\u001b[1;32m 14\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mR2 before ablation:\u001b[39m\u001b[39m\"\u001b[39m, group_df[\u001b[39m\"\u001b[39m\u001b[39mtest_all_r2\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mmean())\n\u001b[1;32m 15\u001b[0m \u001b[39mprint\u001b[39m()\n",
"File \u001b[0;32m/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/pandas/core/frame.py:3805\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3803\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcolumns\u001b[39m.\u001b[39mnlevels \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m 3804\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 3805\u001b[0m indexer \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcolumns\u001b[39m.\u001b[39;49mget_loc(key)\n\u001b[1;32m 3806\u001b[0m \u001b[39mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m 3807\u001b[0m indexer \u001b[39m=\u001b[39m [indexer]\n",
"File \u001b[0;32m/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3803\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_engine\u001b[39m.\u001b[39mget_loc(casted_key)\n\u001b[1;32m 3804\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n\u001b[0;32m-> 3805\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mKeyError\u001b[39;00m(key) \u001b[39mfrom\u001b[39;00m \u001b[39merr\u001b[39;00m\n\u001b[1;32m 3806\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m 3807\u001b[0m \u001b[39m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3808\u001b[0m \u001b[39m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3809\u001b[0m \u001b[39m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3810\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_indexing_error(key)\n",
"\u001b[0;31mKeyError\u001b[0m: 'test_all_mse'"
]
}
],
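The truncated cell above fails with KeyError: 'test_all_mse', which suggests the results produced by the refactored runner no longer carry that column (or it was renamed). Below is a small defensive variant of the same pre-ablation summary, using a toy stand-in for the notebook's combined_df and an assumed "model" grouping column, which only reports metrics whose columns are actually present.

```python
# Defensive sketch of the pre-ablation summary; toy data, illustration only.
import pandas as pd

# Stand-in for the notebook's combined_df (built in earlier cells); values are made up.
combined_df = pd.DataFrame({
    "model": ["RF", "RF", "RF_plus", "RF_plus"],
    "test_all_r2": [0.44, 0.45, 0.48, 0.49],  # note: no "test_all_mse" column
})

metric_columns = {"MSE before ablation": "test_all_mse",
                  "R2 before ablation": "test_all_r2"}

for model, group_df in combined_df.groupby("model"):
    print("Model:", model)
    for label, col in metric_columns.items():
        if col in group_df.columns:
            print(f"{label}: {group_df[col].mean():.4f}")
        else:
            print(f"{label}: column '{col}' not found in results")
    print()
```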
