-
Notifications
You must be signed in to change notification settings - Fork 0
/
accuracy_loss_plotter.py
45 lines (30 loc) · 1.14 KB
/
accuracy_loss_plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import matplotlib.pyplot as plt
from matplotlib import style
# Use matplotlib's "ggplot" style sheet for every figure drawn by this script.
style.use("ggplot")

# Identifier of the training run whose curves should be plotted; it is passed
# to create_acc_loss_graph below. Pick whichever model name you want here.
# In a notebook you could reference MODEL_NAME directly instead.
model_name = "model-1638307554"
def create_acc_loss_graph(model_name, log_path="model.log"):
    """Plot accuracy and loss curves for one training run from a log file.

    Each relevant line of the log must hold seven comma-separated fields:
    ``name, timestamp, acc, loss, val_acc, val_loss, epoch``. Lines whose
    name field equals ``model_name`` are collected and drawn on two stacked
    subplots that share the x axis (the timestamp). The top subplot shows
    training/validation accuracy and the bottom one shows
    training/validation loss.

    Args:
        model_name: Run identifier. It must equal the first field of a log
            line exactly, so a prefix such as ``"model-16"`` does not pull in
            lines belonging to other runs.
        log_path: Path of the log file to read. Defaults to ``"model.log"``.

    Raises:
        OSError: If ``log_path`` cannot be opened.
        ValueError: If a matching line does not have exactly seven fields,
            contains a non-numeric metric, or if no line matches
            ``model_name``.
    """
    times, accuracies, losses, val_accs, val_losses = [], [], [], [], []

    # The context manager guarantees the file handle is closed.
    with open(log_path, "r", encoding="utf-8") as log_file:
        for line in log_file:
            fields = line.strip().split(",")
            # Exact match on the name field; a substring test would also
            # accept other runs whose names contain model_name.
            if fields[0] != model_name:
                continue
            # Unpacking raises ValueError if the line is malformed.
            _name, timestamp, acc, loss, val_acc, val_loss, _epoch = fields
            times.append(float(timestamp))
            accuracies.append(float(acc))
            losses.append(float(loss))
            val_accs.append(float(val_acc))
            val_losses.append(float(val_loss))

    # Fail loudly instead of showing empty axes (e.g. a mistyped model name).
    if not times:
        raise ValueError(f"no entries for model {model_name!r} in {log_path!r}")

    # Start a fresh figure so the subplots below are not drawn onto an
    # existing one.
    plt.figure()
    ax1 = plt.subplot2grid((2, 1), (0, 0))
    ax2 = plt.subplot2grid((2, 1), (1, 0), sharex=ax1)

    ax1.plot(times, accuracies, label="acc")
    ax1.plot(times, val_accs, label="val_acc")
    ax1.legend(loc="upper left")

    ax2.plot(times, losses, label="loss")
    ax2.plot(times, val_losses, label="val_loss")
    ax2.legend(loc="upper left")

    plt.show()
# Only draw the plot when run as a script, not when this module is imported.
if __name__ == "__main__":
    create_acc_loss_graph(model_name)