-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
114 lines (104 loc) · 3.92 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
def extractValues(M, N, layer):
    """Split one Q-table layer into four per-direction grids.

    Columns 0 through 3 of ``layer`` are each reshaped to ``(M, N)``.

    Args:
        M: Number of rows in each returned grid.
        N: Number of columns in each returned grid.
        layer: 2-D array indexed as ``layer[:, k]``; each of the first four
            columns must hold exactly ``M * N`` entries.

    Returns:
        A list ``[north, east, south, west]`` of ``(M, N)`` arrays, taken
        from columns 0, 1, 2 and 3 respectively.
    """
    # Column order in the Q-table is north, east, south, west.
    return [np.reshape(layer[:, column], (M, N)) for column in range(4)]
def triangulation_for_triheatmap(M, N):
    """Build four triangulations that split every grid cell into N/E/S/W wedges.

    The grid has ``M`` cells along x and ``N`` cells along y. Each unit cell
    is centred on integer coordinates and cut into four triangles that share
    the cell centre as their apex.

    Args:
        M: Number of cells along the x axis.
        N: Number of cells along the y axis.

    Returns:
        A list of four ``Triangulation`` objects in the order north, east,
        south, west. All four share the same vertex arrays; each one holds
        ``M * N`` triangles listed row by row (y index outer, x index inner).
    """
    # Vertex list: every cell corner first, then every cell centre.
    corner_x, corner_y = np.meshgrid(np.arange(-0.5, M), np.arange(-0.5, N))
    centre_x, centre_y = np.meshgrid(np.arange(0, M), np.arange(0, N))
    x = np.concatenate([corner_x.ravel(), centre_x.ravel()])
    y = np.concatenate([corner_y.ravel(), centre_y.ravel()])

    corners_per_row = M + 1
    first_centre = corners_per_row * (N + 1)  # centres start after all corners

    north, east, south, west = [], [], [], []
    for row in range(N):
        for col in range(M):
            # "n" is the cell edge at the smaller y, "w" the one at the smaller x.
            nw = col + row * corners_per_row
            ne = nw + 1
            sw = col + (row + 1) * corners_per_row
            se = sw + 1
            centre = first_centre + col + row * M
            north.append((nw, ne, centre))
            east.append((ne, se, centre))
            south.append((se, sw, centre))
            west.append((sw, nw, centre))

    return [Triangulation(x, y, wedge) for wedge in (north, east, south, west)]
def _frame_description(frame):
    """Return the title line naming which pickups are empty or dropoffs are full.

    Frames below 4 are read as a 2-bit pickup mask (``frame`` itself); frames
    4 and up are read as a 4-bit dropoff mask (``frame - 4``). Bit ``k``,
    counting from the least significant bit, refers to the (k+1)-th location.
    """
    if frame >= 4:
        flags = frame - 4
        ordinals = ('first', 'second', 'third', 'fourth')
        suffix = ' dropoff(s) full'
    else:
        flags = frame
        ordinals = ('first', 'second')
        suffix = ' pickup(s) empty'
    named = [name for bit, name in enumerate(ordinals) if (flags >> bit) & 1]
    # Only the leading ordinal is capitalised: "First, third dropoff(s) full".
    return (', '.join(named).capitalize() or 'No') + suffix


def plotQTable(QTable, frame, display=True, save=False, path=''):
    """Draw one 5x5 frame of a Q-table as a heatmap with four wedges per cell.

    Each cell is split into north/east/south/west triangles coloured by the
    corresponding Q-value, and every wedge is labelled with its value to two
    decimals. If every value in the frame is identical nothing is drawn and
    no figure is created.

    Args:
        QTable: 2-D array; rows ``frame * 25`` to ``(frame + 1) * 25`` are
            passed to ``extractValues``, which reads columns 0-3 as the
            north, east, south and west values.
        frame: Zero-based frame index. It selects the row slice and the
            title text (frames 0-3 describe pickups, 4-19 dropoffs).
        display: When true, show the figure with ``plt.show()``.
        save: When true, write the figure to ``<path>/frame<frame + 1>``.
        path: Directory for the saved image; an empty string means the
            current working directory.

    Returns:
        None.
    """
    size = 5  # the grid is fixed at 5x5 cells, i.e. 25 Q-table rows per frame
    cells = size * size
    values = extractValues(size, size, QTable[frame * cells:(frame + 1) * cells])

    # One colour scale is shared by all four directions.
    vmin, vmax = np.min(values), np.max(values)
    if vmin == vmax:
        return

    fig, ax = plt.subplots(figsize=(7, 7))
    for wedge, grid in zip(triangulation_for_triheatmap(size, size), values):
        ax.tripcolor(wedge, grid.ravel(), cmap='RdYlGn',
                     vmin=vmin, vmax=vmax, ec='white')

    span = vmax - vmin  # strictly positive thanks to the early return above
    label_shifts = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # (dy, dx) for N, E, S, W
    for grid, (dy, dx) in zip(values, label_shifts):
        for col in range(size):
            for row in range(size):
                value = grid[row, col]
                # Mid-range colours of the map are light, so use black text
                # there and white text towards either end of the scale.
                relative = (value - vmin) / span
                ax.text(col + 0.3 * dx, row + 0.3 * dy, f'{value:.2f}',
                        color='k' if 0.2 < relative < 0.8 else 'w',
                        ha='center', va='center')

    # Label cells 1..5 instead of 0..4 and put row 1 at the top.
    ax.set_xticks(range(size))
    ax.set_xticklabels(range(1, size + 1))
    ax.set_yticks(range(size))
    ax.set_yticklabels(range(1, size + 1))
    ax.invert_yaxis()
    ax.margins(x=0, y=0)
    ax.set_aspect('equal', 'box')  # square cells

    ax.set(title=f'QTable Frame {frame + 1}\n{_frame_description(frame)}')
    fig.tight_layout()

    # Save before showing: with an interactive backend the figure may already
    # be torn down once the window opened by show() is closed, which would
    # leave savefig writing an empty canvas.
    if save:
        # os.path.join keeps an empty ``path`` relative instead of producing
        # "/frameN" at the filesystem root.
        fig.savefig(os.path.join(path, f'frame{frame + 1}'))
    if display:
        plt.show()
    plt.close(fig)
def plotLineGraph(arr, env, display=True, save=False, path='', title=''):
    """Plot a sequence of values as a red line graph with dashed horizontal grid lines.

    Args:
        arr: Sequence of numbers to plot against their index. May be empty.
        env: Unused; kept so existing callers continue to work.
        display: When true, show the figure with ``plt.show()``.
        save: When true, write the figure to ``path``.
        path: Target file passed to ``savefig``.
        title: Title placed above the axes.

    Returns:
        None.
    """
    fig, ax = plt.subplots()
    ax.plot(arr, linestyle='-', color='red', marker='.')
    # Keep zero in view unless the data goes below it; ``default`` avoids a
    # ValueError from min() when ``arr`` is empty.
    ax.set_ylim(min(0, min(arr, default=0)))
    ax.set(title=title)
    # The flag is passed positionally because its keyword was renamed from
    # ``b`` to ``visible`` and the old spelling no longer works on current
    # matplotlib releases.
    ax.grid(True, axis='y', linestyle='--')

    # Save before showing: with an interactive backend the figure may already
    # be torn down once the window opened by show() is closed.
    if save:
        fig.savefig(path)
    if display:
        plt.show()
    plt.close(fig)