
Commit

fixing classes
francesco-vaselli committed May 22, 2024
1 parent b46c254 commit de4da1b
Showing 2 changed files with 58 additions and 54 deletions.
src/models/new_train_model.py (6 changes: 3 additions & 3 deletions)
@@ -303,7 +303,7 @@ def train(
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.5, patience=5, min_lr=0.00001
),
tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10),
tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20),
tensorboard_callback,
image_logging_callback,
ClassificationMetrics(
@@ -315,7 +315,7 @@ def train(
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.5, patience=5, min_lr=0.00001
),
- tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10),
+ tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20),
tensorboard_callback,
ClassificationMetrics1(test_dataset, log_dir, test_y=test_y),
]
@@ -324,7 +324,7 @@ def train(
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.5, patience=5, min_lr=0.00001
),
- tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10),
+ tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20),
tensorboard_callback,
ClassificationMetrics3(test_dataset, log_dir, test_y=test_y)
]
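
All three hunks above make the same adjustment: EarlyStopping patience is raised from 10 to 20 while ReduceLROnPlateau keeps patience=5, so the learning rate can be halved a few times before a stalled run is abandoned. A minimal sketch of the resulting callback configuration (standard tf.keras callbacks; the fit call and dataset names are placeholders, not taken from this repository):

import tensorflow as tf

callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5, min_lr=0.00001
    ),
    # patience raised from 10 to 20 epochs without val_loss improvement
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20),
]

# Hypothetical usage:
# model.fit(train_dataset, validation_data=val_dataset, epochs=200, callbacks=callbacks)
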
src/utils.py (106 changes: 55 additions & 51 deletions)
@@ -197,8 +197,10 @@ def check_classification(
pred *= 100
true *= 100

- pred_label = (pred[:, ind] < threshold).astype(int)  # Assuming feature_dim = 1
- true_label = (true[:, ind] < threshold).astype(int)  # Adjust index if different
+ # NOTE: now it is 0 for hypo and 1 for hyper,
+ # different from the exam project!
+ pred_label = (pred[:, ind] > threshold).astype(int)  # Assuming feature_dim = 1
+ true_label = (true[:, ind] > threshold).astype(int)  # Adjust index if different

fpr, tpr, _ = roc_curve(true_label, pred_label)
roc_auc = auc(fpr, tpr)
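
A tiny illustration of the new labelling convention, using made-up numbers (the threshold and values below are not from the project): readings above the threshold now map to 1 (hyper) and readings below it to 0 (hypo), the reverse of the earlier < comparison.

import numpy as np

pred = np.array([[60.0], [180.0], [95.0]])  # made-up values, already rescaled by *100
threshold = 140.0
ind = 0

pred_label = (pred[:, ind] > threshold).astype(int)  # 1 = hyper, 0 = hypo
print(pred_label)  # [0 1 0]
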
@@ -238,7 +240,7 @@ def on_epoch_end(self, epoch, logs=None):
specificity = tn / (tn + fp)
precision = tp / (tp + fp)
npv = tn / (tn + fn)
- f1 = tp / (tp + 1 / 2 * (fp + fn))
+ f1 = 2 * (precision * sensitivity) / (precision + sensitivity)
tf.summary.scalar("Accuracy", accuracy, step=epoch)
tf.summary.scalar("Sensitivity", sensitivity, step=epoch)
tf.summary.scalar("Specificity", specificity, step=epoch)
@@ -252,7 +254,7 @@ def on_epoch_end(self, epoch, logs=None):

cm = confusion_matrix(true_label, pred_label)
figure = plot_confusion_matrix(
- cm, class_names=["Hyper", "Hypo"]
+ cm, class_names=["Hypo", "Hyper"]
)
tf.summary.image("Confusion Matrix", self.plot_to_image(figure), step=epoch)
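
The swap of class_names follows from how sklearn orders the matrix: confusion_matrix sorts the label values, so with the new encoding (0 = hypo, 1 = hyper) the first row and column correspond to Hypo. A small check with made-up labels:

import numpy as np
from sklearn.metrics import confusion_matrix

true_label = np.array([0, 0, 1, 1, 1])  # 0 = hypo, 1 = hyper (made-up)
pred_label = np.array([0, 1, 1, 1, 0])

cm = confusion_matrix(true_label, pred_label)
print(cm)
# [[1 1]   row/column 0 -> label 0 ("Hypo")
#  [1 2]]  row/column 1 -> label 1 ("Hyper")
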

@@ -392,7 +394,9 @@ def plot_to_image(self, figure):

def check_classification1(true, pred, threshold=0.5):
# Assuming true and pred have shape [batch_size, 1]
+ # 0 for hypo and 1 for hyper
pred_label = (pred >= threshold).astype(int)
+ #
true_label = (true >= threshold).astype(int)
print("true_label shape:", true_label.shape)
print("pred_label shape:", pred_label.shape)
@@ -442,7 +446,7 @@ def on_epoch_end(self, epoch, logs=None):
tf.summary.image("ROC Curve", self.plot_to_image(figure), step=epoch)

cm = confusion_matrix(true_label, pred_label)
- figure = plot_confusion_matrix1(cm, class_names=["Hypo", "Hyper"])
+ figure = plot_confusion_matrix(cm, class_names=["Hypo", "Hyper"])
tf.summary.image("Confusion Matrix", self.plot_to_image(figure), step=epoch)

def plot_to_image(self, figure):
@@ -464,28 +468,28 @@ def plot_roc_curve1(fpr, tpr, roc_auc):
ax.legend(loc="lower right")
return fig

- def plot_confusion_matrix1(cm, class_names):
- figure = plt.figure(figsize=(8, 8))
- plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
- plt.title("Confusion matrix")
- plt.colorbar()
- tick_marks = np.arange(len(class_names))
- plt.xticks(tick_marks, class_names, rotation=45)
- plt.yticks(tick_marks, class_names)
+ # def plot_confusion_matrix1(cm, class_names):
+ # figure = plt.figure(figsize=(8, 8))
+ # plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
+ # plt.title("Confusion matrix")
+ # plt.colorbar()
+ # tick_marks = np.arange(len(class_names))
+ # plt.xticks(tick_marks, class_names, rotation=45)
+ # plt.yticks(tick_marks, class_names)

- # Normalize the confusion matrix.
- cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2)
+ # # Normalize the confusion matrix.
+ # cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2)

- # Use white text if squares are dark; otherwise black.
- threshold = cm.max() / 2.0
- for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
- color = "black"
- plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)
+ # # Use white text if squares are dark; otherwise black.
+ # threshold = cm.max() / 2.0
+ # for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+ # color = "black"
+ # plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

- # plt.tight_layout()
- plt.ylabel("True label")
- plt.xlabel("Predicted label")
- return figure
+ # # plt.tight_layout()
+ # plt.ylabel("True label")
+ # plt.xlabel("Predicted label")
+ # return figure

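With the local plot_confusion_matrix1 copy commented out above, these call sites rely on a single shared plot_confusion_matrix helper that is not part of this diff. A minimal sketch of such a helper, assuming it mirrors the commented-out code, would be:

import itertools

import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, class_names):
    # Assumed to match the commented-out helper above, not copied from the shared utility.
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Normalize rows so each cell reads as a per-class fraction.
    cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2)
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j], horizontalalignment="center", color="black")

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    return figure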

def check_classification3(true, pred):
@@ -522,16 +526,16 @@ def on_epoch_end(self, epoch, logs=None):
cm = confusion_matrix(true_label, pred_label)
accuracy = np.trace(cm) / np.sum(cm)
precision = np.diag(cm) / np.sum(cm, axis=0)
- recall = np.diag(cm) / np.sum(cm, axis=1)
- f1 = 2 * precision * recall / (precision + recall)

+ sensitivity = np.diag(cm) / np.sum(cm, axis=1)
+ f1 = 2 * (precision * sensitivity) / (precision + sensitivity)
tf.summary.scalar("Accuracy", accuracy, step=epoch)
for i, cls in enumerate(["Hypo", "Norm", "Hyper"]):
tf.summary.scalar(f"Precision_{cls}", precision[i], step=epoch)
tf.summary.scalar(f"Recall_{cls}", recall[i], step=epoch)
tf.summary.scalar(f"sensitivity_{cls}", sensitivity[i], step=epoch)
tf.summary.scalar(f"F1_{cls}", f1[i], step=epoch)

- figure = plot_confusion_matrix3(cm, class_names=["Hypo", "Norm", "Hyper"])
+ figure = plot_confusion_matrix(cm, class_names=["Hypo", "Norm", "Hyper"])
tf.summary.image("Confusion Matrix", self.plot_to_image(figure), step=epoch)

def plot_to_image(self, figure):
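
For the three-class metrics above, precision, sensitivity and F1 come straight from the confusion matrix: column sums give predicted counts, row sums give true counts. A worked example with a made-up 3x3 matrix (class order 0 = Hypo, 1 = Norm, 2 = Hyper):

import numpy as np

cm = np.array([[8, 2, 0],
               [1, 15, 4],
               [0, 3, 7]])  # rows = true class, columns = predicted class

accuracy = np.trace(cm) / np.sum(cm)            # 30 / 40 = 0.75
precision = np.diag(cm) / np.sum(cm, axis=0)    # [8/9, 15/20, 7/11]
sensitivity = np.diag(cm) / np.sum(cm, axis=1)  # [8/10, 15/20, 7/10]
f1 = 2 * (precision * sensitivity) / (precision + sensitivity)

print(accuracy)                  # 0.75
print(np.round(precision, 3))    # [0.889 0.75  0.636]
print(np.round(sensitivity, 3))  # [0.8   0.75  0.7  ]
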
@@ … @@
image = tf.expand_dims(image, 0)
return image

- def plot_confusion_matrix3(cm, class_names):
- figure = plt.figure(figsize=(8, 8))
- plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
- plt.title("Confusion matrix")
- plt.colorbar()
- tick_marks = np.arange(len(class_names))
- plt.xticks(tick_marks, class_names, rotation=45)
- plt.yticks(tick_marks, class_names)

- # Normalize the confusion matrix.
- cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2)

- # Use white text if squares are dark; otherwise black.
- threshold = cm.max() / 2.0
- for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
- color = "black"
- plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

- # plt.tight_layout()
- plt.ylabel("True label")
- plt.xlabel("Predicted label")
- return figure
+ # def plot_confusion_matrix3(cm, class_names):
+ # figure = plt.figure(figsize=(8, 8))
+ # plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
+ # plt.title("Confusion matrix")
+ # plt.colorbar()
+ # tick_marks = np.arange(len(class_names))
+ # plt.xticks(tick_marks, class_names, rotation=45)
+ # plt.yticks(tick_marks, class_names)

+ # # Normalize the confusion matrix.
+ # cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2)

+ # # Use white text if squares are dark; otherwise black.
+ # threshold = cm.max() / 2.0
+ # for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+ # color = "black"
+ # plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

+ # # plt.tight_layout()
+ # plt.ylabel("True label")
+ # plt.xlabel("Predicted label")
+ # return figure
