Skip to content

Commit

Permalink
Add ablation_demo
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Feb 26, 2024
1 parent 4d74f40 commit 928b64f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions feature_importance/ablation_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -107,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand All @@ -117,7 +117,7 @@
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[5], line 26\u001b[0m\n\u001b[0;32m 23\u001b[0m metric_results[\u001b[39m'\u001b[39m\u001b[39mMSE_before_ablation\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m mean_squared_error(y_test, y_pred)\n\u001b[0;32m 25\u001b[0m \u001b[39m# Ablation\u001b[39;00m\n\u001b[1;32m---> 26\u001b[0m score \u001b[39m=\u001b[39m rf_plus_model\u001b[39m.\u001b[39;49mget_mdi_plus_scores(X_test, y_test, lfi\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, lfi_abs \u001b[39m=\u001b[39;49m \u001b[39m\"\u001b[39;49m\u001b[39moutside\u001b[39;49m\u001b[39m\"\u001b[39;49m, sample_split\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m)\n\u001b[0;32m 27\u001b[0m local_fi_score \u001b[39m=\u001b[39m score[\u001b[39m\"\u001b[39m\u001b[39mlfi\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[0;32m 28\u001b[0m ascending \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m \u001b[39m# False for MDI\u001b[39;00m\n",
"Cell \u001b[1;32mIn[12], line 26\u001b[0m\n\u001b[0;32m 23\u001b[0m metric_results[\u001b[39m'\u001b[39m\u001b[39mMSE_before_ablation\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m mean_squared_error(y_test, y_pred)\n\u001b[0;32m 25\u001b[0m \u001b[39m# Ablation\u001b[39;00m\n\u001b[1;32m---> 26\u001b[0m score \u001b[39m=\u001b[39m rf_plus_model\u001b[39m.\u001b[39;49mget_mdi_plus_scores(X_test, y_test, lfi\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, lfi_abs \u001b[39m=\u001b[39;49m \u001b[39m\"\u001b[39;49m\u001b[39moutside\u001b[39;49m\u001b[39m\"\u001b[39;49m, sample_split\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m)\n\u001b[0;32m 27\u001b[0m local_fi_score \u001b[39m=\u001b[39m score[\u001b[39m\"\u001b[39m\u001b[39mlfi\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[0;32m 28\u001b[0m ascending \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m \u001b[39m# False for MDI\u001b[39;00m\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\rf_plus.py:379\u001b[0m, in \u001b[0;36m_RandomForestPlus.get_mdi_plus_scores\u001b[1;34m(self, X, y, scoring_fns, local_scoring_fns, sample_split, mode, version, lfi, lfi_abs)\u001b[0m\n\u001b[0;32m 367\u001b[0m mdi_plus_obj \u001b[39m=\u001b[39m ForestMDIPlus(estimators\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mestimators_,\n\u001b[0;32m 368\u001b[0m transformers\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformers_,\n\u001b[0;32m 369\u001b[0m scoring_fns\u001b[39m=\u001b[39mscoring_fns,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 376\u001b[0m normalize\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnormalize,\n\u001b[0;32m 377\u001b[0m version\u001b[39m=\u001b[39mversion)\n\u001b[0;32m 378\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmdi_plus_ \u001b[39m=\u001b[39m mdi_plus_obj\n\u001b[1;32m--> 379\u001b[0m mdi_plus_scores \u001b[39m=\u001b[39m mdi_plus_obj\u001b[39m.\u001b[39;49mget_scores(X_array, y, lfi\u001b[39m=\u001b[39;49mlfi,\n\u001b[0;32m 380\u001b[0m lfi_abs\u001b[39m=\u001b[39;49mlfi_abs)\n\u001b[0;32m 381\u001b[0m \u001b[39mif\u001b[39;00m lfi \u001b[39mand\u001b[39;00m local_scoring_fns:\n\u001b[0;32m 382\u001b[0m mdi_plus_lfi \u001b[39m=\u001b[39m mdi_plus_scores[\u001b[39m\"\u001b[39m\u001b[39mlfi\u001b[39m\u001b[39m\"\u001b[39m]\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\mdi_plus.py:126\u001b[0m, in \u001b[0;36mForestMDIPlus.get_scores\u001b[1;34m(self, X, y, lfi, lfi_abs)\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[39m# print(\"IN 'get_scores' METHOD WITHIN THE FOREST MDI PLUS OBJECT\")\u001b[39;00m\n\u001b[0;32m 125\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlfi_abs \u001b[39m=\u001b[39m lfi_abs\n\u001b[1;32m--> 126\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit_importance_scores(X, y)\n\u001b[0;32m 127\u001b[0m \u001b[39mif\u001b[39;00m lfi:\n\u001b[0;32m 128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlocal_scoring_fns:\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\mdi_plus.py:223\u001b[0m, in \u001b[0;36mForestMDIPlus._fit_importance_scores\u001b[1;34m(self, X, y)\u001b[0m\n\u001b[0;32m 208\u001b[0m \u001b[39mfor\u001b[39;00m estimator, transformer, tree_random_state \u001b[39min\u001b[39;00m \\\n\u001b[0;32m 209\u001b[0m \u001b[39mzip\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mestimators, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformers, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtree_random_states):\n\u001b[0;32m 210\u001b[0m tree_mdi_plus \u001b[39m=\u001b[39m TreeMDIPlus(estimator\u001b[39m=\u001b[39mestimator,\n\u001b[0;32m 211\u001b[0m transformer\u001b[39m=\u001b[39mtransformer,\n\u001b[0;32m 212\u001b[0m scoring_fns\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscoring_fns,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 221\u001b[0m num_iters\u001b[39m=\u001b[39mnum_iters,\n\u001b[0;32m 222\u001b[0m lfi_abs\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlfi_abs)\n\u001b[1;32m--> 223\u001b[0m scores \u001b[39m=\u001b[39m tree_mdi_plus\u001b[39m.\u001b[39;49mget_scores(X, y)\n\u001b[0;32m 224\u001b[0m lfi_matrix_lst\u001b[39m.\u001b[39mappend(tree_mdi_plus\u001b[39m.\u001b[39mlfi_matrix)\n\u001b[0;32m 225\u001b[0m \u001b[39mif\u001b[39;00m scores \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
Expand Down Expand Up @@ -156,7 +156,7 @@
"metric_results['MSE_before_ablation'] = mean_squared_error(y_test, y_pred)\n",
"\n",
"# Ablation\n",
"score = rf_plus_model.get_mdi_plus_scores(X_test, y_test, lfi=True, lfi_abs = \"outside\")\n",
"score = rf_plus_model.get_mdi_plus_scores(X_test, y_test, lfi=True, lfi_abs = \"outside\", sample_split=None)\n",
"local_fi_score = score[\"lfi\"]\n",
"ascending = True # False for MDI\n",
"imp_vals = copy.deepcopy(local_fi_score)\n",
Expand Down

0 comments on commit 928b64f

Please sign in to comment.