From de4da1b1e3d682eb77031a83b7d0d44488fbf72c Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Wed, 22 May 2024 11:35:31 +0200 Subject: [PATCH] fixing classes --- src/models/new_train_model.py | 6 +- src/utils.py | 106 ++++++++++++++++++---------------- 2 files changed, 58 insertions(+), 54 deletions(-) diff --git a/src/models/new_train_model.py b/src/models/new_train_model.py index 69821a9..7155f12 100644 --- a/src/models/new_train_model.py +++ b/src/models/new_train_model.py @@ -303,7 +303,7 @@ def train( tf.keras.callbacks.ReduceLROnPlateau( monitor="val_loss", factor=0.5, patience=5, min_lr=0.00001 ), - tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10), + tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20), tensorboard_callback, image_logging_callback, ClassificationMetrics( @@ -315,7 +315,7 @@ def train( tf.keras.callbacks.ReduceLROnPlateau( monitor="val_loss", factor=0.5, patience=5, min_lr=0.00001 ), - tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10), + tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20), tensorboard_callback, ClassificationMetrics1(test_dataset, log_dir, test_y=test_y), ] @@ -324,7 +324,7 @@ def train( tf.keras.callbacks.ReduceLROnPlateau( monitor="val_loss", factor=0.5, patience=5, min_lr=0.00001 ), - tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10), + tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20), tensorboard_callback, ClassificationMetrics3(test_dataset, log_dir, test_y=test_y) ] diff --git a/src/utils.py b/src/utils.py index 0701173..aaf9c0d 100644 --- a/src/utils.py +++ b/src/utils.py @@ -197,8 +197,10 @@ def check_classification( pred *= 100 true *= 100 - pred_label = (pred[:, ind] < threshold).astype(int) # Assuming feature_dim = 1 - true_label = (true[:, ind] < threshold).astype(int) # Adjust index if different + # NOTE now it is 0 for hypo and 1 for hyper + # diffrent from exam project!! + pred_label = (pred[:, ind] > threshold).astype(int) # Assuming feature_dim = 1 + true_label = (true[:, ind] > threshold).astype(int) # Adjust index if different fpr, tpr, _ = roc_curve(true_label, pred_label) roc_auc = auc(fpr, tpr) @@ -238,7 +240,7 @@ def on_epoch_end(self, epoch, logs=None): specificity = tn / (tn + fp) precision = tp / (tp + fp) npv = tn / (tn + fn) - f1 = tp / (tp + 1 / 2 * (fp + fn)) + f1 = 2 * (precision * sensitivity) / (precision + sensitivity) tf.summary.scalar("Accuracy", accuracy, step=epoch) tf.summary.scalar("Sensitivity", sensitivity, step=epoch) tf.summary.scalar("Specificity", specificity, step=epoch) @@ -252,7 +254,7 @@ def on_epoch_end(self, epoch, logs=None): cm = confusion_matrix(true_label, pred_label) figure = plot_confusion_matrix( - cm, class_names=["Hyper", "Hypo"] + cm, class_names=["Hypo", "Hyper"] ) tf.summary.image("Confusion Matrix", self.plot_to_image(figure), step=epoch) @@ -392,7 +394,9 @@ def plot_to_image(self, figure): def check_classification1(true, pred, threshold=0.5): # Assuming true and pred have shape [batch_size, 1] + # 0 for hypo and 1 for hyper pred_label = (pred >= threshold).astype(int) + # true_label = (true >= threshold).astype(int) print("true_label shape:", true_label.shape) print("pred_label shape:", pred_label.shape) @@ -442,7 +446,7 @@ def on_epoch_end(self, epoch, logs=None): tf.summary.image("ROC Curve", self.plot_to_image(figure), step=epoch) cm = confusion_matrix(true_label, pred_label) - figure = plot_confusion_matrix1(cm, class_names=["Hypo", "Hyper"]) + figure = plot_confusion_matrix(cm, class_names=["Hypo", "Hyper"]) tf.summary.image("Confusion Matrix", self.plot_to_image(figure), step=epoch) def plot_to_image(self, figure): @@ -464,28 +468,28 @@ def plot_roc_curve1(fpr, tpr, roc_auc): ax.legend(loc="lower right") return fig -def plot_confusion_matrix1(cm, class_names): - figure = plt.figure(figsize=(8, 8)) - plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues) - plt.title("Confusion matrix") - plt.colorbar() - tick_marks = np.arange(len(class_names)) - plt.xticks(tick_marks, class_names, rotation=45) - plt.yticks(tick_marks, class_names) +# def plot_confusion_matrix1(cm, class_names): +# figure = plt.figure(figsize=(8, 8)) +# plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues) +# plt.title("Confusion matrix") +# plt.colorbar() +# tick_marks = np.arange(len(class_names)) +# plt.xticks(tick_marks, class_names, rotation=45) +# plt.yticks(tick_marks, class_names) - # Normalize the confusion matrix. - cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2) +# # Normalize the confusion matrix. +# cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2) - # Use white text if squares are dark; otherwise black. - threshold = cm.max() / 2.0 - for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): - color = "black" - plt.text(j, i, cm[i, j], horizontalalignment="center", color=color) +# # Use white text if squares are dark; otherwise black. +# threshold = cm.max() / 2.0 +# for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): +# color = "black" +# plt.text(j, i, cm[i, j], horizontalalignment="center", color=color) - # plt.tight_layout() - plt.ylabel("True label") - plt.xlabel("Predicted label") - return figure +# # plt.tight_layout() +# plt.ylabel("True label") +# plt.xlabel("Predicted label") +# return figure def check_classification3(true, pred): @@ -522,16 +526,16 @@ def on_epoch_end(self, epoch, logs=None): cm = confusion_matrix(true_label, pred_label) accuracy = np.trace(cm) / np.sum(cm) precision = np.diag(cm) / np.sum(cm, axis=0) - recall = np.diag(cm) / np.sum(cm, axis=1) - f1 = 2 * precision * recall / (precision + recall) - + sensitivity = np.diag(cm) / np.sum(cm, axis=1) + f1 = 2 * (precision * sensitivity) / (precision + sensitivity) + tf.summary.scalar("Accuracy", accuracy, step=epoch) for i, cls in enumerate(["Hypo", "Norm", "Hyper"]): tf.summary.scalar(f"Precision_{cls}", precision[i], step=epoch) - tf.summary.scalar(f"Recall_{cls}", recall[i], step=epoch) + tf.summary.scalar(f"sensitivity_{cls}", sensitivity[i], step=epoch) tf.summary.scalar(f"F1_{cls}", f1[i], step=epoch) - figure = plot_confusion_matrix3(cm, class_names=["Hypo", "Norm", "Hyper"]) + figure = plot_confusion_matrix(cm, class_names=["Hypo", "Norm", "Hyper"]) tf.summary.image("Confusion Matrix", self.plot_to_image(figure), step=epoch) def plot_to_image(self, figure): @@ -543,25 +547,25 @@ def plot_to_image(self, figure): image = tf.expand_dims(image, 0) return image -def plot_confusion_matrix3(cm, class_names): - figure = plt.figure(figsize=(8, 8)) - plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues) - plt.title("Confusion matrix") - plt.colorbar() - tick_marks = np.arange(len(class_names)) - plt.xticks(tick_marks, class_names, rotation=45) - plt.yticks(tick_marks, class_names) - - # Normalize the confusion matrix. - cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2) - - # Use white text if squares are dark; otherwise black. - threshold = cm.max() / 2.0 - for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): - color = "black" - plt.text(j, i, cm[i, j], horizontalalignment="center", color=color) - - # plt.tight_layout() - plt.ylabel("True label") - plt.xlabel("Predicted label") - return figure +# def plot_confusion_matrix3(cm, class_names): +# figure = plt.figure(figsize=(8, 8)) +# plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues) +# plt.title("Confusion matrix") +# plt.colorbar() +# tick_marks = np.arange(len(class_names)) +# plt.xticks(tick_marks, class_names, rotation=45) +# plt.yticks(tick_marks, class_names) + +# # Normalize the confusion matrix. +# cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2) + +# # Use white text if squares are dark; otherwise black. +# threshold = cm.max() / 2.0 +# for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): +# color = "black" +# plt.text(j, i, cm[i, j], horizontalalignment="center", color=color) + +# # plt.tight_layout() +# plt.ylabel("True label") +# plt.xlabel("Predicted label") +# return figure