From 9a80acaeed5098fbce24c34b440d267e1e28bb41 Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Tue, 21 May 2024 12:40:44 +0200 Subject: [PATCH] debug printout --- src/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.py b/src/utils.py index 4149202..f8e6bb8 100644 --- a/src/utils.py +++ b/src/utils.py @@ -490,6 +490,7 @@ def plot_confusion_matrix1(cm, class_names): def check_classification3(true, pred): # Assuming true and pred have shape [batch_size, 3] + print("true shape:", true.shape) true_label = np.argmax(true, axis=1) pred_label = np.argmax(pred, axis=1) print("true_label shape:", true_label.shape)