Commit: new losses comparisons
francesco-vaselli committed Oct 13, 2023
1 parent 3683f07 commit 1bd8584
Showing 2 changed files with 45 additions and 0 deletions.
src/models/eval_baseline_models.py (30 additions, 0 deletions)
@@ -378,6 +378,21 @@ def check_classification(
mae = mae * scale_paper
rmse = rmse * scale_paper
print(f"for GP paper scaled MAE: {mae}, paper scaled RMSE: {rmse}")

# now compute the same MAE and RMSE, but only for the last point
mae = mean_absolute_error(true_label_gp[:, -1], pred_label_gp[:, -1])
rmse = mean_squared_error(true_label_gp[:, -1], pred_label_gp[:, -1], squared=False)
print(f"last point MAE: {mae}, last point RMSE: {rmse}")

# get back to original scale
mae = mae * 57.941
rmse = rmse * 57.941
print(f"last point scaled MAE: {mae}, last point scaled RMSE: {rmse}")

# get values for comparison with other paper
mae = mae * scale_paper
rmse = rmse * scale_paper
print(f"paper scaled last point MAE: {mae}, paper scaled last point RMSE: {rmse}")

plot_beautiful_fig_gp(
test_x[:3],
@@ -442,6 +457,21 @@ def check_classification(
rmse = rmse * scale_paper
print(f"for SVM paper scaled MAE: {mae}, paper scaled RMSE: {rmse}")

# now compute the same MAE and RMSE, but only for the last point
mae = mean_absolute_error(true_label_svm[:, -1], pred_label_svm[:, -1])
rmse = mean_squared_error(true_label_svm[:, -1], pred_label_svm[:, -1], squared=False)
print(f"last point MAE: {mae}, last point RMSE: {rmse}")

# get back to original scale
mae = mae * 57.941
rmse = rmse * 57.941
print(f"last point scaled MAE: {mae}, last point scaled RMSE: {rmse}")

# get values for comparison with other paper
mae = mae * scale_paper
rmse = rmse * scale_paper
print(f"paper scaled last point MAE: {mae}, paper scaled last point RMSE: {rmse}")

# Use your existing function to plot results
plot_beautiful_fig(
test_x[:3],
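The same last-point metric block is added three times in this commit (for the GP and SVM baselines above, and again in eval_saved_model.py below). A minimal sketch of how the repetition could be factored into one helper, assuming the sklearn metrics already imported in these scripts; the helper name and its default scale arguments are hypothetical and not part of the commit:

from sklearn.metrics import mean_absolute_error, mean_squared_error

def report_last_point_metrics(true_y, pred_y, tag, scale=57.941, scale_paper=1.0):
    # hypothetical helper: MAE/RMSE on the final point of each sequence only
    mae = mean_absolute_error(true_y[:, -1], pred_y[:, -1])
    rmse = mean_squared_error(true_y[:, -1], pred_y[:, -1], squared=False)  # RMSE
    print(f"{tag} last point MAE: {mae}, last point RMSE: {rmse}")
    # rescale to the original units, then to the reference paper's scale
    print(f"{tag} last point scaled MAE: {mae * scale}, scaled RMSE: {rmse * scale}")
    print(f"{tag} paper scaled last point MAE: {mae * scale * scale_paper}, "
          f"paper scaled last point RMSE: {rmse * scale * scale_paper}")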
src/models/eval_saved_model.py (15 additions, 0 deletions)
@@ -187,6 +187,21 @@ def eval():
rmse = rmse * scale_paper
print(f"for {log_name} paper scaled MAE: {mae}, paper scaled RMSE: {rmse}")

# now compute the same MAE and RMSE, but only for the last point
mae = mean_absolute_error(new_test_y[:, -1], pred_y[:, -1])
rmse = mean_squared_error(new_test_y[:, -1], pred_y[:, -1], squared=False)
print(f"for {log_name} last point MAE: {mae}, last point RMSE: {rmse}")

# get back to original scale
mae = mae * 57.941
rmse = rmse * 57.941
print(f"for {log_name} last point scaled MAE: {mae}, last point scaled RMSE: {rmse}")

# get values for comparison with other paper
mae = mae * scale_paper
rmse = rmse * scale_paper
print(f"for {log_name} paper scaled last point MAE: {mae}, paper scaled last point RMSE: {rmse}")

new_test_x = new_test_x.reshape(-1, 7)

plot_beautiful_fig(new_test_x[:3], new_test_y[:3], pred_y[:3], "new_test", log_dir, 144.98, 57.94)
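For reference, passing squared=False to mean_squared_error returns the root of the mean squared error, which is what all three hunks rely on (newer scikit-learn releases deprecate this flag in favour of root_mean_squared_error). A quick check on dummy data, not taken from the commit:

import numpy as np
from sklearn.metrics import mean_squared_error

y_true = np.array([[0.1, 0.2], [0.3, 0.4]])
y_pred = np.array([[0.1, 0.3], [0.2, 0.5]])

# last-point RMSE, as in the diffs above
rmse = mean_squared_error(y_true[:, -1], y_pred[:, -1], squared=False)
assert np.isclose(rmse, np.sqrt(np.mean((y_true[:, -1] - y_pred[:, -1]) ** 2)))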
