Skip to content

Commit

Permalink
Update Ablation
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Feb 21, 2024
1 parent b22c2f4 commit 7caf5c4
Show file tree
Hide file tree
Showing 4 changed files with 549 additions and 416 deletions.
5 changes: 3 additions & 2 deletions feature_importance/01_run_importance_local_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def compare_estimators(estimators: List[ModelConfig],
feature_importance_list = []

# loop over model estimators
for model in tqdm(estimators, leave=False):
for model in estimators:
est = model.cls(**model.kwargs)

# get kwargs for all fi_ests
Expand Down Expand Up @@ -110,7 +110,7 @@ def compare_estimators(estimators: List[ModelConfig],

# loop over fi estimators
seed = np.random.randint(0, 100000)
for fi_est in fi_ests:
for fi_est in tqdm(fi_ests):
metric_results = {
'model': model.name,
'fi': fi_est.name,
Expand Down Expand Up @@ -140,6 +140,7 @@ def compare_estimators(estimators: List[ModelConfig],
metric_results[f'MSE_after_ablation_{i+1}'] = mean_squared_error(y_test, est.predict(ablation_X_test))
end = time.time()
metric_results['ablation_time'] = end - start
metric_results['test_size'] = X_test.shape[0]
print(f"data_size: {X_test.shape[0]}, fi: {fi_est.name}, done with time: {end - start}")

# initialize results with metadata and metric results
Expand Down
Loading

0 comments on commit 7caf5c4

Please sign in to comment.