From 1bd858456caab587dafca876e50b5b92499a09cc Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Fri, 13 Oct 2023 12:39:26 +0200 Subject: [PATCH] new losses comparisons --- src/models/eval_baseline_models.py | 30 ++++++++++++++++++++++++++++++ src/models/eval_saved_model.py | 15 +++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/models/eval_baseline_models.py b/src/models/eval_baseline_models.py index b680056..7a7103b 100644 --- a/src/models/eval_baseline_models.py +++ b/src/models/eval_baseline_models.py @@ -378,6 +378,21 @@ def check_classification( mae = mae * scale_paper rmse = rmse * scale_paper print(f"for GP paper scaled MAE: {mae}, paper scaled RMSE: {rmse}") + + # now we do the same MAE and RMSE but only for the last point + mae = mean_absolute_error(true_label_gp[:, -1], pred_label_gp[:, -1]) + rmse = mean_squared_error(true_label_gp[:, -1], pred_label_gp[:, -1], squared=False) + print(f"last point MAE: {mae}, last point RMSE: {rmse}") + + # get back to original scale + mae = mae * 57.941 + rmse = rmse * 57.941 + print(f"last point scaled MAE: {mae}, last point scaled RMSE: {rmse}") + + # get values for comparison with other paper + mae = mae * scale_paper + rmse = rmse * scale_paper + print(f"paper scaled last point MAE: {mae}, paper scaled last point RMSE: {rmse}") plot_beautiful_fig_gp( test_x[:3], @@ -442,6 +457,21 @@ def check_classification( rmse = rmse * scale_paper print(f"for SVM paper scaled MAE: {mae}, paper scaled RMSE: {rmse}") + # now we do the same MAE and RMSE but only for the last point + mae = mean_absolute_error(true_label_svm[:, -1], pred_label_svm[:, -1]) + rmse = mean_squared_error(true_label_svm[:, -1], pred_label_svm[:, -1], squared=False) + print(f"last point MAE: {mae}, last point RMSE: {rmse}") + + # get back to original scale + mae = mae * 57.941 + rmse = rmse * 57.941 + print(f"last point scaled MAE: {mae}, last point scaled RMSE: {rmse}") + + # get values for comparison with other paper + mae = mae * scale_paper + rmse = rmse * scale_paper + print(f"paper scaled last point MAE: {mae}, paper scaled last point RMSE: {rmse}") + # Use your existing function to plot results plot_beautiful_fig( test_x[:3], diff --git a/src/models/eval_saved_model.py b/src/models/eval_saved_model.py index d1d8044..00af7e3 100644 --- a/src/models/eval_saved_model.py +++ b/src/models/eval_saved_model.py @@ -187,6 +187,21 @@ def eval(): rmse = rmse * scale_paper print(f"for {log_name} paper scaled MAE: {mae}, paper scaled RMSE: {rmse}") + # now we do the same MAE and RMSE but only for the last point + mae = mean_absolute_error(new_test_y[:, -1], pred_y[:, -1]) + rmse = mean_squared_error(new_test_y[:, -1], pred_y[:, -1], squared=False) + print(f"for {log_name} last point MAE: {mae}, last point RMSE: {rmse}") + + # get back to original scale + mae = mae * 57.941 + rmse = rmse * 57.941 + print(f"for {log_name} last point scaled MAE: {mae}, last point scaled RMSE: {rmse}") + + # get values for comparison with other paper + mae = mae * scale_paper + rmse = rmse * scale_paper + print(f"for {log_name} paper scaled last point MAE: {mae}, paper scaled last point RMSE: {rmse}") + new_test_x = new_test_x.reshape(-1, 7) plot_beautiful_fig(new_test_x[:3], new_test_y[:3], pred_y[:3], "new_test", log_dir, 144.98, 57.94)