forked from tfunck/minc_keras
-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_metrics.py
37 lines (32 loc) · 1.36 KB
/
plot_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import json
import os
from os.path import splitext, basename
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
def _save_curve_plot(train_values, val_values, train_label, val_label, ylabel, out_fn):
    """Plot a training curve (and, if given, a validation curve) per epoch and save it to *out_fn*.

    Each series is plotted against its own epoch indices, so the two series
    do not need to have the same length.  The plot is drawn on a private
    figure that is always closed, so any figure the caller has open is left
    untouched.
    """
    fig, ax = plt.subplots()
    try:
        train_values = np.asarray(train_values)
        ax.plot(np.arange(len(train_values)), train_values, label=train_label)
        if val_values is not None:
            val_values = np.asarray(val_values)
            ax.plot(np.arange(len(val_values)), val_values, label=val_label)
        # One legend call: calling legend() twice replaces the first legend and
        # silently drops its loc/prop settings.
        ax.legend(loc="upper right", ncol=1, prop={'size': 8}, shadow=True)
        ax.set_xlabel("Training Epoch Number")
        ax.set_ylabel(ylabel)
        fig.tight_layout()
        # NOTE: 'width' was dropped from the original call; it is not a savefig
        # option and newer Matplotlib versions reject unknown savefig keywords.
        fig.savefig(out_fn, bbox_inches="tight", dpi=500)
    finally:
        plt.close(fig)


def plot_loss(metric, history_fn, model_fn, report_dir):
    """Write metric and loss training-curve plots for a model to *report_dir*.

    Parameters
    ----------
    metric : str
        Key of the metric series in the history.  The validation series is
        read from ``'val_' + metric`` when present.
    history_fn : str
        Path to a JSON file holding a dict of per-epoch value lists.  It must
        contain *metric* and ``'loss'``.  ``'val_' + metric`` and ``'val_loss'``
        are plotted when present.  (Presumably a saved Keras ``History.history``
        dict — confirm against the caller.)
    model_fn : str
        Model file path.  Only its base name without extension is used, as the
        prefix of the output files.
    report_dir : str
        Directory the PNG files are written into.

    Writes ``<stem>_metric_plot.png`` and ``<stem>_loss_plot.png``.

    Raises
    ------
    KeyError
        If *metric* or ``'loss'`` is missing from the history.
    """
    with open(history_fn, 'r') as fp:
        history = json.load(fp)

    stem = splitext(basename(model_fn))[0]

    # Label curves with the actual metric name rather than assuming accuracy.
    metric_fn = os.path.join(report_dir, stem + '_metric_plot.png')
    _save_curve_plot(
        history[metric],
        history.get('val_' + metric),
        'Training ' + metric,
        'Validation ' + metric,
        metric,
        metric_fn,
    )
    print('Model training plot written to ', metric_fn)

    loss_fn = os.path.join(report_dir, stem + '_loss_plot.png')
    _save_curve_plot(
        history['loss'],
        history.get('val_loss'),
        'Training Loss',
        'Validation Loss',
        'Loss',
        loss_fn,
    )
    print('Model training plot written to ', loss_fn)