From b06fe07a81ec52168cb2dda925eef1dd5c355d4d Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Tue, 21 May 2024 12:31:12 +0200 Subject: [PATCH] fixes --- configs/new_train_config.yaml | 2 +- src/models/new_train_model.py | 1 + src/utils.py | 33 +++++++++++++++++++++++++-------- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/configs/new_train_config.yaml b/configs/new_train_config.yaml index b9ce8e7..f958b16 100644 --- a/configs/new_train_config.yaml +++ b/configs/new_train_config.yaml @@ -21,7 +21,7 @@ model_config: cnn_config: n_conv_layers: 6 filters: 320 - kernel_size: 3 + kernel_size: 2 activation: "relu" input_shape: [7, 1] n_dense_layers: 2 diff --git a/src/models/new_train_model.py b/src/models/new_train_model.py index 6f40de3..da157ff 100644 --- a/src/models/new_train_model.py +++ b/src/models/new_train_model.py @@ -123,6 +123,7 @@ def build_model(model_config, model_name): kernel_size=cnn_config["kernel_size"], activation=cnn_config["activation"], input_shape=cnn_config["input_shape"], + padding="same", ) ) diff --git a/src/utils.py b/src/utils.py index e5064e6..506822b 100644 --- a/src/utils.py +++ b/src/utils.py @@ -394,6 +394,9 @@ def check_classification1(true, pred, threshold=0.5): # Assuming true and pred have shape [batch_size, 1] pred_label = (pred >= threshold).astype(int) true_label = (true >= threshold).astype(int) + print("true_label shape:", true_label.shape) + print("pred_label shape:", pred_label.shape) + print("example pred vs true:", pred_label[0], true_label[0]) fpr, tpr, _ = roc_curve(true_label, pred_label) roc_auc = auc(fpr, tpr) @@ -462,14 +465,28 @@ def plot_roc_curve1(fpr, tpr, roc_auc): return fig def plot_confusion_matrix1(cm, class_names): - fig, ax = plt.subplots() - cax = ax.matshow(cm, cmap=plt.cm.Blues) - fig.colorbar(cax) - ax.set_xticklabels([''] + class_names) - ax.set_yticklabels([''] + class_names) - plt.xlabel('Predicted') - plt.ylabel('True') - return fig + figure = plt.figure(figsize=(8, 8)) + plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues) + plt.title("Confusion matrix") + plt.colorbar() + tick_marks = np.arange(len(class_names)) + plt.xticks(tick_marks, class_names, rotation=45) + plt.yticks(tick_marks, class_names) + + # Normalize the confusion matrix. + cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2) + + # Use white text if squares are dark; otherwise black. + threshold = cm.max() / 2.0 + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + color = "black" + plt.text(j, i, cm[i, j], horizontalalignment="center", color=color) + + # plt.tight_layout() + plt.ylabel("True label") + plt.xlabel("Predicted label") + return figure + def check_classification3(true, pred): # Assuming true and pred have shape [batch_size, 3]