Skip to content

Commit

Permalink
Fix bug in permutation_local
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Jan 13, 2024
1 parent a4320c3 commit ae5a3d9
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions feature_importance/scripts/competing_methods_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,10 @@ def permutation_local(X, y, fit, num_permutations=100):
y_pred_permuted = fit.predict(X_permuted)

# Calculate MSE for each sample
mse_values = mean_squared_error(y, y_pred_permuted)
for i in range(num_samples):
lpi[i, k] += (y[i]-y_pred_permuted[i])**2

# Store MSE values in the array
lpi[:, k] += mse_values

# Average MSE values across permutations for each sample
lpi[:, k] /= num_permutations
lpi /= num_permutations

# Convert the array to a DataFrame
result_table = pd.DataFrame(lpi, columns=[f'Feature_{i}' for i in range(num_features)])
Expand Down

0 comments on commit ae5a3d9

Please sign in to comment.