-
Notifications
You must be signed in to change notification settings - Fork 0
/
draw.py
102 lines (88 loc) · 3.09 KB
/
draw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os, sys, pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from sklearn import metrics
def draw_lhist(tag, llist, tlist=None, save=None):
    """Draw one histogram per array in *llist*, side by side in a single row.

    Args:
        tag: Histogram preset. ``'l'`` uses 4 bins over the range (0, 4);
            ``'y'`` uses 50 bins over each array's own data range.
        llist: Sequence of array-likes; each one gets its own subplot.
        tlist: Optional sequence of subplot titles, matched to *llist* by
            position. Missing or falsy entries leave that subplot untitled.
        save: If truthy, path passed to ``plt.savefig``; otherwise the
            figure is shown with ``plt.show()``.

    Raises:
        ValueError: If *tag* is not ``'l'`` or ``'y'``.
    """
    presets = {'l': ((0, 4), 4), 'y': (None, 50)}
    if tag not in presets:
        raise ValueError("tag must be 'l' or 'y', got {!r}".format(tag))
    ran, bins = presets[tag]
    # None instead of a shared mutable [] default; treat "no titles" as empty.
    titles = tlist if tlist is not None else []
    n = len(llist)
    if n == 0:
        # Nothing to draw; a zero-width figure would make matplotlib raise.
        return
    plt.figure(figsize=(4 * n, 3))
    for i, values in enumerate(llist):
        plt.subplot(1, n, i + 1)
        plt.hist(values, color='#9467bd', rwidth=0.9,
                 bins=bins, range=ran, alpha=0.6)
        # Titles are optional and may be fewer than the histograms.
        if i < len(titles) and titles[i]:
            plt.title(titles[i])
    if save:
        plt.savefig(save)
    else:
        plt.show()
    plt.close()
def _plot_history_curve(history, key, title, ylabel, short, legend_loc,
                        save, suffix):
    """Plot the training curve of one history metric and its 'val_' twin.

    The validation curve is drawn only when ``'val_' + key`` is present in
    ``history.history``. The figure is saved to ``save + suffix`` when
    *save* is not None, otherwise shown, and is closed afterwards.
    """
    plt.figure()
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.xticks(history.epoch)
    plt.plot(history.history[key], 'y', label='t ' + short, marker='.')
    val_key = 'val_' + key
    # Tolerate histories recorded without validation data.
    if val_key in history.history:
        plt.plot(history.history[val_key], 'r', label='v ' + short,
                 marker='.')
    plt.legend(loc=legend_loc)
    if save is not None:
        plt.savefig(save + suffix)
    else:
        plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def draw_lprocess(history, save=None):
    """Plot per-epoch loss (and accuracy, if recorded) from a training history.

    *history* must expose ``.epoch`` (used for x ticks) and ``.history``
    (a mapping with a ``'loss'`` series, optionally ``'val_loss'``,
    ``'accuracy'`` and ``'val_accuracy'``). This presumably matches a Keras
    ``History`` object; confirm against callers.

    Args:
        history: Training-history object as described above.
        save: Path prefix. When not None, figures are written to
            ``save + '_loss'`` and ``save + '_acc'``. Otherwise they are shown.
    """
    _plot_history_curve(history, 'loss', 'Loss', 'loss', 'loss',
                        'upper right', save, '_loss')
    if 'accuracy' in history.history:
        _plot_history_curve(history, 'accuracy', 'Accuracy', 'acc', 'acc',
                            'lower right', save, '_acc')
def draw_multi_hist(data_list, n_row=1, save=None, bins=3):
    """Draw one histogram per entry of *data_list* on a grid of subplots.

    *data_list* is iterated for its keys and indexed by them
    (``data_list[key]``), so a dict of arrays (or a DataFrame) works. Each
    key becomes its subplot's title.

    Args:
        data_list: Mapping from subplot title to array-like of values.
        n_row: Number of subplot rows. The column count is
            ``ceil(len(data_list) / n_row)`` so every entry gets a cell.
        save: If not None, path passed to ``plt.savefig``; otherwise the
            figure is shown.
        bins: Number of histogram bins (defaults to the former fixed 3).

    Raises:
        ValueError: If *n_row* is less than 1.
    """
    if n_row < 1:
        raise ValueError("n_row must be >= 1, got {!r}".format(n_row))
    n_data = len(data_list)
    print('ndata: ', n_data)
    # plt.subplot needs an integer column count; round up so a count not
    # divisible by n_row still fits, and keep at least one column.
    n_col = max(1, -(-n_data // n_row))
    plt.figure(figsize=(n_col * 3, n_row * 3))
    for i, key in enumerate(data_list):
        plt.subplot(n_row, n_col, i + 1)
        plt.title(key)
        plt.hist(data_list[key], color='#9467bd', rwidth=0.9, bins=bins,
                 alpha=0.6)
    if save is not None:
        plt.savefig(save)
    else:
        plt.show()
    plt.close()
def draw_cm(y_true, y_pred, ax, title='Confusion matrix'):
    """Draw a row-normalized confusion matrix as an annotated heatmap on *ax*.

    Rows are true labels and columns are predictions, following sklearn's
    ``confusion_matrix`` convention. Each row is divided by its total, so a
    cell shows the fraction of that true class predicted as each label. A
    row with no samples (a label that appears only in *y_pred*) is left at 0
    instead of becoming NaN.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        ax: Matplotlib axes to draw on.
        title: Axes title.
    """
    cm = metrics.confusion_matrix(y_true, y_pred).astype('float')
    row_sums = cm.sum(axis=1, keepdims=True)
    cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
    sns.heatmap(cm, annot=True, ax=ax)
    # Title the axes we drew on, not whatever pyplot considers current.
    ax.set_title(title)
    ax.set_xlabel("Pred")
    ax.set_ylabel("True")
    # Pin y-limits to the heatmap's full extent (rows 0..n, y inverted).
    # This keeps the outer rows from being clipped, which some matplotlib
    # releases did, without adding blank half-row margins elsewhere.
    ax.set_ylim(cm.shape[0], 0)
def draw_response(sig, bkg, title='Response'):
    """Overlay density-normalized histograms of two samples on the current axes.

    *sig* is drawn first with a blue edge and labelled 'awake'. *bkg* is
    drawn second with a red edge and labelled 'anaesthesia'. Both use 25
    semi-transparent filled bins.
    """
    plt.title(title)
    shared = {'histtype': 'stepfilled', 'alpha': 0.4, 'density': True,
              'bins': 25}
    # Plot order matters: legend labels are assigned in drawing order.
    for sample, edge in ((sig, 'b'), (bkg, 'r')):
        plt.hist(sample, edgecolor=edge, **shared)
    plt.legend(['awake', 'anaesthesia'], loc='upper center')
def draw_roc(fpr, tpr, ax, title='ROC Curve', name='RNN'):
    """Plot a ROC curve as signal efficiency vs. background rejection on *ax*.

    The curve uses TPR on the x axis and TNR = 1 - FPR on the y axis. The
    legend reports the area under this plotted curve (``metrics.auc``).

    Args:
        fpr: False-positive rates (array-like, e.g. from sklearn's
            ``roc_curve``).
        tpr: True-positive rates, same length as *fpr*.
        ax: Matplotlib axes to draw on.
        title: Axes title.
        name: Model name shown in the legend label (default ``'RNN'``).
    """
    # Accept plain lists as well as arrays; ``1 - list`` would raise.
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)
    tnr = 1 - fpr
    auc = metrics.auc(x=tpr, y=tnr)
    curve = Line2D(
        xdata=tpr, ydata=tnr,
        label="{} (AUC = {:.3f})".format(name, auc),
        color='darkorange', alpha=0.8, lw=3)
    ax.add_line(curve)
    ax.set_xlabel('True Positive Rates (Signal Efficiency)', fontsize=12)
    ax.set_ylabel('True Negative Rates (Background Rejection)', fontsize=12)
    ax.grid()
    ax.legend()
    ax.set_title(title, fontsize=15)