-
Notifications
You must be signed in to change notification settings - Fork 13
/
utils.py
73 lines (61 loc) · 2.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
This file contains utility functions for plotting training logs stored as CSV files and for tiling a batch of images into a single grid image.
Code credits to Xifeng Guo: `https://github.com/XifengGuo/CapsNet-Keras`
"""
import numpy as np
from matplotlib import pyplot as plt
import csv
import math
def plot_log(filename, show=True, save_path=None):
    """Plot training-loss and accuracy curves from a CSV training log.

    The CSV is read with ``csv.DictReader``: its header row supplies the
    column names and every data cell is converted with ``float``. The first
    column is used as the x-axis and is shifted by +1 before plotting
    (presumably turning 0-based epoch indices into 1-based ones -- confirm
    against whatever writes the log).

    The top subplot shows every column whose name contains ``'loss'`` but
    not ``'val'``; the bottom subplot shows every column whose name contains
    ``'acc'``.

    Args:
        filename: Path of the CSV log to read.
        show: If true, call ``plt.show()`` after building the figure.
        save_path: Optional path; when given, the figure is written there
            with ``fig.savefig`` (before it is shown).

    Raises:
        ValueError: If the file has no header row, or (from ``float``) if a
            cell is not a valid number.
    """
    # newline='' is what the csv module documentation requires for files
    # handed to a csv reader.
    with open(filename, 'r', newline='') as f:
        reader = csv.DictReader(f)
        keys = list(reader.fieldnames or [])
        rows = [[float(row[key]) for key in keys] for row in reader]
    if not keys:
        raise ValueError('%r has no CSV header row' % (filename,))

    # Explicit reshape so a header-only log still gives a (0, n_columns)
    # array; the shape is passed positionally because the ``newshape``
    # keyword is deprecated since NumPy 2.1.
    values = np.array(rows, dtype=float).reshape(-1, len(keys))
    values[:, 0] += 1

    fig = plt.figure(figsize=(4, 6))
    fig.subplots_adjust(top=0.95, bottom=0.05, right=0.95)

    fig.add_subplot(211)
    for i, key in enumerate(keys):
        if 'loss' in key and 'val' not in key:  # training loss only
            plt.plot(values[:, 0], values[:, i], label=key)
    plt.legend()
    plt.title('Training loss')

    fig.add_subplot(212)
    for i, key in enumerate(keys):
        if 'acc' in key:  # both training and validation accuracy
            plt.plot(values[:, 0], values[:, i], label=key)
    plt.legend()
    plt.title('Training and validation accuracy')

    if save_path is not None:
        fig.savefig(save_path)
    if show:
        plt.show()
def combine_images(generated_images, height=None, width=None):
    """Tile a batch of images into one 2-D grid image, filled row by row.

    Args:
        generated_images: Array of shape ``(num, rows, cols)`` or
            ``(num, rows, cols, channels)``; for the 4-D form only channel 0
            of each image is used.
        height: Number of grid rows. Derived from ``width`` when omitted.
        width: Number of grid columns. Derived from ``height`` when omitted;
            when both are omitted, ``width = int(sqrt(num))`` (at least 1) and
            ``height`` is just large enough to hold every image.

    Returns:
        An array of shape ``(height * rows, width * cols)`` with the input's
        dtype. Grid cells that receive no image stay zero.

    Raises:
        ValueError: If an explicitly given ``height`` x ``width`` grid has
            fewer cells than there are images.
    """
    num = generated_images.shape[0]
    if width is None and height is None:
        # max(1, ...) avoids a division by zero for an empty batch.
        width = max(1, int(math.sqrt(num)))
        height = int(math.ceil(float(num) / width))
    elif height is None:  # only width given
        height = int(math.ceil(float(num) / width))
    elif width is None:  # only height given
        width = int(math.ceil(float(num) / height))
    if height * width < num:
        raise ValueError(
            'a %d x %d grid cannot hold %d images' % (height, width, num))

    rows, cols = generated_images.shape[1:3]
    image = np.zeros((height * rows, width * cols),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        if img.ndim == 3:  # (rows, cols, channels): keep the first channel
            img = img[:, :, 0]
        i, j = divmod(index, width)  # grid row, grid column
        image[i * rows:(i + 1) * rows, j * cols:(j + 1) * cols] = img
    return image
if __name__ == "__main__":
    # Running this module directly plots the default training log.
    plot_log(filename='result/log.csv')