-
Notifications
You must be signed in to change notification settings - Fork 2
/
plot.py
342 lines (289 loc) · 12.9 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import os.path as osp
import numpy as np
DIV_LINE_WIDTH = 50
# Global vars for tracking and labeling data at load time.
exp_idx = 0
units = dict()
def plot_data(data, xaxis='Epoch', value="AverageEpRet",
condition="Condition1", smooth=1, paper=False,
hidelegend=False, title=None, savedir=None,
clear_xticks=False, **kwargs):
# special handling for plotting a horizontal line
splits = value.split(',')
value = splits[0]
#if len(splits) > 1:
# y_horiz = float(splits[1])
#else:
# y_horiz = None
y_horiz, ymin, ymax = None, None, None
if len(splits) > 1:
for split in splits[1:]:
if split[0]=='h':
y_horiz = float(split[1:])
elif split[0]=='l':
ymin = float(split[1:])
elif split[0]=='u':
ymax = float(split[1:])
if isinstance(data, list):
# Seive data so only data with value column is present
data = [x for x in data if value in x.columns]
if smooth > 1:
"""
smooth data with moving window average.
that is,
smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
where the "smooth" param is width of that window (2k+1)
"""
y = np.ones(smooth)
for datum in data:
x = np.asarray(datum[value])
z = np.ones(len(x))
smoothed_x = np.convolve(x,y,'same') / np.convolve(z,y,'same')
datum[value] = smoothed_x
if isinstance(data, list):
data = pd.concat(data, ignore_index=True)
font_scale = 1. if paper else 1.5
sns.set(style="darkgrid", font_scale=font_scale)
"""
#sns.set_palette(sns.color_palette('muted'))
sns.set_palette([(0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
#(1.0, 0.4980392156862745, 0.054901960784313725),
(0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
(0.5803921568627451, 0.403921568627451, 0.7411764705882353),
(0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
#(0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
#(0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
#(0.4980392156862745, 0.4980392156862745, 0.4980392156862745),
(0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
(0.09019607843137255, 0.7450980392156863, 0.8117647058823529)]
)
#"""
sns.tsplot(data=data, time=xaxis, value=value, unit="Unit", condition=condition, ci='sd', **kwargs)
"""
If you upgrade to any version of Seaborn greater than 0.8.1, switch from
tsplot to lineplot replacing L29 with:
sns.lineplot(data=data, x=xaxis, y=value, hue=condition, ci='sd', **kwargs)
Changes the colorscheme and the default legend style, though.
"""
plt.legend(loc='best')#.draggable()
"""
For the version of the legend used in the Spinning Up benchmarking page,
swap L38 with:
plt.legend(loc='upper center', ncol=6, handlelength=1,
mode="expand", borderaxespad=0., prop={'size': 13})
"""
xmax = np.max(np.asarray(data[xaxis]))
xscale = xmax > 5e3
if xscale:
# Just some formatting niceness: x-axis scale in scientific notation if max x is large
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
old_ymin, old_ymax = plt.ylim()
if ymin:
plt.ylim(bottom=min(ymin, old_ymin))
if ymax:
plt.ylim(top=max(ymax, old_ymax))
#if title:
# plt.title(title)
if paper:
plt.gcf().set_size_inches(3.85,2.75)
plt.tight_layout(pad=0.5)
else:
plt.tight_layout(pad=0.5)
if y_horiz:
# y, xmin, xmax, colors='k', linestyles='solid', label='',
plt.hlines(y_horiz, 0, xmax, colors='red', linestyles='dashed', label='limit')
fname = osp.join(savedir, title+'_'+value).lower()
if clear_xticks:
x, _ = plt.xticks()
plt.xticks(x, [])
plt.xlabel('')
fname += '_nox'
if savedir is not '':
os.makedirs(savedir, exist_ok=True)
plt.savefig(fname+'.pdf', format='pdf')
if hidelegend:
plt.legend().remove()
if savedir is not '':
plt.savefig(fname + '_nolegend.pdf', format='pdf')
if savedir is not '':
# Separately save legend
h, l = plt.axes().get_legend_handles_labels()
legfig, legax = plt.subplots(figsize=(7.5,0.75))
legax.set_facecolor('white')
leg = legax.legend(h, l, loc='center', ncol=5, handlelength=1.5,
mode="expand", borderaxespad=0., prop={'size': 13})
legax.xaxis.set_visible(False)
legax.yaxis.set_visible(False)
for line in leg.get_lines():
line.set_linewidth(4.0)
plt.tight_layout(pad=0.5)
plt.savefig(osp.join(savedir, title+'_legend.pdf'), format='pdf')
def get_datasets(logdir, condition=None):
"""
Recursively look through logdir for output files produced by
spinup.logx.Logger.
Assumes that any file "progress.txt" is a valid hit.
"""
global exp_idx
global units
datasets = []
for root, _, files in os.walk(logdir):
if 'progress.txt' in files:
exp_name = None
try:
config_path = open(os.path.join(root,'config.json'))
config = json.load(config_path)
if 'exp_name' in config:
exp_name = config['exp_name']
except:
print('No file named config.json')
condition1 = condition or exp_name or 'exp'
condition2 = condition1 + '-' + str(exp_idx)
exp_idx += 1
if condition1 not in units:
units[condition1] = 0
unit = units[condition1]
units[condition1] += 1
try:
exp_data = pd.read_table(os.path.join(root,'progress.txt'))
except:
print('Could not read from %s'%os.path.join(root,'progress.txt'))
continue
performance = 'AverageTestEpRet' if 'AverageTestEpRet' in exp_data else 'AverageEpRet'
exp_data.insert(len(exp_data.columns),'Unit',unit)
exp_data.insert(len(exp_data.columns),'Condition1',condition1)
exp_data.insert(len(exp_data.columns),'Condition2',condition2)
exp_data.insert(len(exp_data.columns),'Performance',exp_data[performance])
datasets.append(exp_data)
return datasets
def get_all_datasets(all_logdirs, legend=None, select=None, exclude=None):
"""
For every entry in all_logdirs,
1) check if the entry is a real directory and if it is,
pull data from it;
2) if not, check to see if the entry is a prefix for a
real directory, and pull data from that.
"""
logdirs = []
for logdir in all_logdirs:
if osp.isdir(logdir) and logdir[-1]=='/':
logdirs += [logdir]
else:
basedir = osp.dirname(logdir)
fulldir = lambda x : osp.join(basedir, x)
prefix = logdir.split('/')[-1]
listdir= os.listdir(basedir)
logdirs += sorted([fulldir(x) for x in listdir if prefix in x])
"""
Enforce selection rules, which check logdirs for certain substrings.
Makes it easier to look at graphs from particular ablations, if you
launch many jobs at once with similar names.
"""
if select is not None:
logdirs = [log for log in logdirs if all(x in log for x in select)]
if exclude is not None:
logdirs = [log for log in logdirs if all(not(x in log) for x in exclude)]
# Verify logdirs
print('Plotting from...\n' + '='*DIV_LINE_WIDTH + '\n')
for logdir in logdirs:
print(logdir)
print('\n' + '='*DIV_LINE_WIDTH)
# Make sure the legend is compatible with the logdirs
assert not(legend) or (len(legend) == len(logdirs)), \
"Must give a legend title for each set of experiments."
# Load data from logdirs
data = []
if legend:
for log, leg in zip(logdirs, legend):
data += get_datasets(log, leg)
else:
for log in logdirs:
data += get_datasets(log)
return data
def make_plots(all_logdirs, legend=None, xaxis=None, values=None, count=False,
font_scale=1.5, smooth=1, select=None, exclude=None, estimator='mean',
paper=False, hidelegend=False, title=None, savedir=None, show=True,
clear_xticks=False):
data = get_all_datasets(all_logdirs, legend, select, exclude)
values = values if isinstance(values, list) else [values]
condition = 'Condition2' if count else 'Condition1'
estimator = getattr(np, estimator) # choose what to show on main curve: mean? max? min?
for value in values:
plt.figure()
plot_data(data, xaxis=xaxis, value=value, condition=condition,
smooth=smooth, estimator=estimator,
paper=paper, hidelegend=hidelegend,
title=title, savedir=savedir,
clear_xticks=clear_xticks)
if show:
plt.show()
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('logdir', nargs='*')
parser.add_argument('--legend', '-l', nargs='*')
parser.add_argument('--xaxis', '-x', default='TotalEnvInteracts')
parser.add_argument('--value', '-y', default='Performance', nargs='*')
parser.add_argument('--count', action='store_true')
parser.add_argument('--smooth', '-s', type=int, default=1)
parser.add_argument('--select', nargs='*')
parser.add_argument('--exclude', nargs='*')
parser.add_argument('--est', default='mean')
parser.add_argument('--paper', action='store_true')
parser.add_argument('--hidelegend', '-hl', action='store_true')
parser.add_argument('--title', type=str, default='')
parser.add_argument('--savedir', type=str, default='')
parser.add_argument('--dont_show', action='store_true')
parser.add_argument('--clearx', action='store_true')
args = parser.parse_args()
"""
Args:
logdir (strings): As many log directories (or prefixes to log
directories, which the plotter will autocomplete internally) as
you'd like to plot from.
legend (strings): Optional way to specify legend for the plot. The
plotter legend will automatically use the ``exp_name`` from the
config.json file, unless you tell it otherwise through this flag.
This only works if you provide a name for each directory that
will get plotted. (Note: this may not be the same as the number
of logdir args you provide! Recall that the plotter looks for
autocompletes of the logdir args: there may be more than one
match for a given logdir prefix, and you will need to provide a
legend string for each one of those matches---unless you have
removed some of them as candidates via selection or exclusion
rules (below).)
xaxis (string): Pick what column from data is used for the x-axis.
Defaults to ``TotalEnvInteracts``.
value (strings): Pick what columns from data to graph on the y-axis.
Submitting multiple values will produce multiple graphs. Defaults
to ``Performance``, which is not an actual output of any algorithm.
Instead, ``Performance`` refers to either ``AverageEpRet``, the
correct performance measure for the on-policy algorithms, or
``AverageTestEpRet``, the correct performance measure for the
off-policy algorithms. The plotter will automatically figure out
which of ``AverageEpRet`` or ``AverageTestEpRet`` to report for
each separate logdir.
count: Optional flag. By default, the plotter shows y-values which
are averaged across all results that share an ``exp_name``,
which is typically a set of identical experiments that only vary
in random seed. But if you'd like to see all of those curves
separately, use the ``--count`` flag.
smooth (int): Smooth data by averaging it over a fixed window. This
parameter says how wide the averaging window will be.
select (strings): Optional selection rule: the plotter will only show
curves from logdirs that contain all of these substrings.
exclude (strings): Optional exclusion rule: plotter will only show
curves from logdirs that do not contain these substrings.
"""
make_plots(args.logdir, args.legend, args.xaxis, args.value, args.count,
smooth=args.smooth, select=args.select, exclude=args.exclude,
estimator=args.est, paper=args.paper, hidelegend=args.hidelegend,
title=args.title, savedir=args.savedir, show=not(args.dont_show),
clear_xticks=args.clearx)
if __name__ == "__main__":
main()