Skip to content

Commit

Permalink
Updtae ablation demo
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Mar 8, 2024
1 parent f8b56b9 commit 0ffcef9
Showing 1 changed file with 254 additions and 0 deletions.
254 changes: 254 additions & 0 deletions feature_importance/ablation_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,260 @@
" return result_table"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"n = 200\n",
"d = 10\n",
"mean = [[0]*5 + [0]*5, [10]*5 + [0]*5]\n",
"scale = [[1]*10,[1]*10]\n",
"s = 5\n",
"X = sample_normal_X_subgroups(n, d, mean, scale)\n",
"beta = np.concatenate((np.ones(s), np.zeros(d-s)))\n",
"y = np.matmul(X, beta)\n",
"split_seed = 0\n",
"X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, \"train-test\", split_seed)\n",
"\n",
"rf_regressor = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=331)\n",
"rf_regressor.fit(X_train, y_train)\n",
"seed = 0\n",
"rf_plus_model = RandomForestPlusRegressor(rf_model=copy.deepcopy(rf_regressor), include_raw=False)\n",
"rf_plus_model = RandomForestPlusRegressor(rf_model=rf_regressor, include_raw=False)\n",
"rf_plus_model.fit(X_train, y_train)\n",
"\n",
"score = rf_plus_model.get_mdi_plus_scores(X_test, y_test, lfi=True, lfi_abs = \"outside\", sample_split=None)\n",
"local_fi_score = score[\"lfi\"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3.440747</td>\n",
" <td>4.255933</td>\n",
" <td>6.700399</td>\n",
" <td>8.639199</td>\n",
" <td>5.333437</td>\n",
" <td>0.325618</td>\n",
" <td>0.640087</td>\n",
" <td>0.223689</td>\n",
" <td>0.185920</td>\n",
" <td>0.176673</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2.805189</td>\n",
" <td>4.695955</td>\n",
" <td>6.887698</td>\n",
" <td>5.888381</td>\n",
" <td>5.774552</td>\n",
" <td>0.462473</td>\n",
" <td>0.569851</td>\n",
" <td>0.166908</td>\n",
" <td>0.252426</td>\n",
" <td>0.145043</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2.994209</td>\n",
" <td>4.769220</td>\n",
" <td>7.218987</td>\n",
" <td>5.956899</td>\n",
" <td>5.627507</td>\n",
" <td>0.449225</td>\n",
" <td>0.525465</td>\n",
" <td>0.166205</td>\n",
" <td>0.216060</td>\n",
" <td>0.127154</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.183443</td>\n",
" <td>4.734916</td>\n",
" <td>8.226171</td>\n",
" <td>4.774889</td>\n",
" <td>7.257146</td>\n",
" <td>0.282170</td>\n",
" <td>0.363478</td>\n",
" <td>0.157185</td>\n",
" <td>0.224982</td>\n",
" <td>0.166350</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.262282</td>\n",
" <td>4.380160</td>\n",
" <td>7.201394</td>\n",
" <td>5.638692</td>\n",
" <td>5.665000</td>\n",
" <td>0.589775</td>\n",
" <td>0.425661</td>\n",
" <td>0.138957</td>\n",
" <td>0.197739</td>\n",
" <td>0.126664</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61</th>\n",
" <td>4.045584</td>\n",
" <td>4.820155</td>\n",
" <td>6.939586</td>\n",
" <td>5.668030</td>\n",
" <td>7.662810</td>\n",
" <td>0.294419</td>\n",
" <td>0.379058</td>\n",
" <td>0.130060</td>\n",
" <td>0.160509</td>\n",
" <td>0.136078</td>\n",
" </tr>\n",
" <tr>\n",
" <th>62</th>\n",
" <td>3.239955</td>\n",
" <td>3.826604</td>\n",
" <td>6.492881</td>\n",
" <td>5.832585</td>\n",
" <td>5.394387</td>\n",
" <td>0.666766</td>\n",
" <td>0.847282</td>\n",
" <td>0.292280</td>\n",
" <td>0.235026</td>\n",
" <td>0.101330</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63</th>\n",
" <td>4.852265</td>\n",
" <td>3.986563</td>\n",
" <td>6.360245</td>\n",
" <td>7.147750</td>\n",
" <td>5.026830</td>\n",
" <td>0.525546</td>\n",
" <td>0.372760</td>\n",
" <td>0.084053</td>\n",
" <td>0.113026</td>\n",
" <td>0.166653</td>\n",
" </tr>\n",
" <tr>\n",
" <th>64</th>\n",
" <td>3.000221</td>\n",
" <td>4.311225</td>\n",
" <td>6.570626</td>\n",
" <td>5.680667</td>\n",
" <td>5.084522</td>\n",
" <td>0.680960</td>\n",
" <td>0.919951</td>\n",
" <td>0.111115</td>\n",
" <td>0.261413</td>\n",
" <td>0.163289</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65</th>\n",
" <td>3.974615</td>\n",
" <td>4.496878</td>\n",
" <td>6.546694</td>\n",
" <td>7.366003</td>\n",
" <td>5.269932</td>\n",
" <td>0.339395</td>\n",
" <td>0.458521</td>\n",
" <td>0.099534</td>\n",
" <td>0.238012</td>\n",
" <td>0.151942</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>66 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 \\\n",
"0 3.440747 4.255933 6.700399 8.639199 5.333437 0.325618 0.640087 \n",
"1 2.805189 4.695955 6.887698 5.888381 5.774552 0.462473 0.569851 \n",
"2 2.994209 4.769220 7.218987 5.956899 5.627507 0.449225 0.525465 \n",
"3 4.183443 4.734916 8.226171 4.774889 7.257146 0.282170 0.363478 \n",
"4 3.262282 4.380160 7.201394 5.638692 5.665000 0.589775 0.425661 \n",
".. ... ... ... ... ... ... ... \n",
"61 4.045584 4.820155 6.939586 5.668030 7.662810 0.294419 0.379058 \n",
"62 3.239955 3.826604 6.492881 5.832585 5.394387 0.666766 0.847282 \n",
"63 4.852265 3.986563 6.360245 7.147750 5.026830 0.525546 0.372760 \n",
"64 3.000221 4.311225 6.570626 5.680667 5.084522 0.680960 0.919951 \n",
"65 3.974615 4.496878 6.546694 7.366003 5.269932 0.339395 0.458521 \n",
"\n",
" 7 8 9 \n",
"0 0.223689 0.185920 0.176673 \n",
"1 0.166908 0.252426 0.145043 \n",
"2 0.166205 0.216060 0.127154 \n",
"3 0.157185 0.224982 0.166350 \n",
"4 0.138957 0.197739 0.126664 \n",
".. ... ... ... \n",
"61 0.130060 0.160509 0.136078 \n",
"62 0.292280 0.235026 0.101330 \n",
"63 0.084053 0.113026 0.166653 \n",
"64 0.111115 0.261413 0.163289 \n",
"65 0.099534 0.238012 0.151942 \n",
"\n",
"[66 rows x 10 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"local_fi_score"
]
},
{
"cell_type": "code",
"execution_count": 5,
Expand Down

0 comments on commit 0ffcef9

Please sign in to comment.