Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-vaselli committed May 21, 2024
1 parent d54be2c commit b06fe07
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
2 changes: 1 addition & 1 deletion configs/new_train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ model_config:
cnn_config:
n_conv_layers: 6
filters: 320
kernel_size: 3
kernel_size: 2
activation: "relu"
input_shape: [7, 1]
n_dense_layers: 2
Expand Down
1 change: 1 addition & 0 deletions src/models/new_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def build_model(model_config, model_name):
kernel_size=cnn_config["kernel_size"],
activation=cnn_config["activation"],
input_shape=cnn_config["input_shape"],
padding="same",
)
)

Expand Down
33 changes: 25 additions & 8 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ def check_classification1(true, pred, threshold=0.5):
# Assuming true and pred have shape [batch_size, 1]
pred_label = (pred >= threshold).astype(int)
true_label = (true >= threshold).astype(int)
print("true_label shape:", true_label.shape)
print("pred_label shape:", pred_label.shape)
print("example pred vs true:", pred_label[0], true_label[0])

fpr, tpr, _ = roc_curve(true_label, pred_label)
roc_auc = auc(fpr, tpr)
Expand Down Expand Up @@ -462,14 +465,28 @@ def plot_roc_curve1(fpr, tpr, roc_auc):
return fig

def plot_confusion_matrix1(cm, class_names):
fig, ax = plt.subplots()
cax = ax.matshow(cm, cmap=plt.cm.Blues)
fig.colorbar(cax)
ax.set_xticklabels([''] + class_names)
ax.set_yticklabels([''] + class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
return fig
figure = plt.figure(figsize=(8, 8))
plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
plt.title("Confusion matrix")
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)

# Normalize the confusion matrix.
cm = np.around(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis], decimals=2)

# Use white text if squares are dark; otherwise black.
threshold = cm.max() / 2.0
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
color = "black"
plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

# plt.tight_layout()
plt.ylabel("True label")
plt.xlabel("Predicted label")
return figure


def check_classification3(true, pred):
# Assuming true and pred have shape [batch_size, 3]
Expand Down

0 comments on commit b06fe07

Please sign in to comment.