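# cifar-mtl.py
# Trains a LeNet on CIFAR-10 in a multi-task fashion: each task adds one
# class, the model is retrained on all classes seen so far, and the best
# checkpoint's average AUROC is recorded per task.
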
from __future__ import print_function

import os
import random

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim

from models import LeNet
from utils import *
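
# load_CIFAR, train, save_checkpoint and calc_avg_AUROC used below are
# assumed to be provided by the local utils module via this star import.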

# PATHS
CHECKPOINT = "./checkpoints/cifar-mtl"

# BATCH
BATCH_SIZE = 256
NUM_WORKERS = 4

# SGD
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4

# Step Decay
LR_DROP = 0.5
EPOCHS_DROP = 20

# MISC
EPOCHS = 100
CUDA = True

# Manual seed
SEED = 20

random.seed(SEED)
torch.manual_seed(SEED)
if CUDA:
    torch.cuda.manual_seed_all(SEED)
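
# the ten CIFAR-10 class indices, introduced one per task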
ALL_CLASSES = range(10)


def main():

    if not os.path.isdir(CHECKPOINT):
        os.makedirs(CHECKPOINT)

    print('==> Preparing dataset')
    trainloader, testloader = load_CIFAR(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    CLASSES = []
    AUROCs = []

    for t, cls in enumerate(ALL_CLASSES):

        print('\nTask: [%d | %d]\n' % (t + 1, len(ALL_CLASSES)))

        # each task adds one class; the model output grows with it
        CLASSES.append(cls)

        print("==> Creating model")
        model = LeNet(num_classes=len(CLASSES))

        if CUDA:
            model = model.cuda()
            model = nn.DataParallel(model)
            cudnn.benchmark = True

        print(' Total params: %.2fK' % (sum(p.numel() for p in model.parameters()) / 1000))

        criterion = nn.BCELoss()
        optimizer = optim.SGD(model.parameters(),
                              lr=LEARNING_RATE,
                              momentum=MOMENTUM,
                              weight_decay=WEIGHT_DECAY)
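
        # NOTE: nn.BCELoss expects probabilities in [0, 1]; LeNet is assumed
        # to end in a sigmoid and the train/calc_avg_AUROC helpers to supply
        # one-hot targets. With raw logits, nn.BCEWithLogitsLoss would apply.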
print("==> Learning")
best_loss = 1e10
learning_rate = LEARNING_RATE
for epoch in range(EPOCHS):
# decay learning rate
if (epoch + 1) % EPOCHS_DROP == 0:
learning_rate *= LR_DROP
for param_group in optimizer.param_groups:
param_group['lr'] = learning_rate
print('Epoch: [%d | %d]' % (epoch + 1, EPOCHS))

            # the same train routine evaluates when test=True
            train_loss = train(trainloader, model, criterion, CLASSES, CLASSES, optimizer=optimizer, use_cuda=CUDA)
            test_loss = train(testloader, model, criterion, CLASSES, CLASSES, test=True, use_cuda=CUDA)

            # save model, keeping a copy of the weights whenever test loss improves
            is_best = test_loss < best_loss
            best_loss = min(test_loss, best_loss)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'loss': test_loss,
                'optimizer': optimizer.state_dict()
            }, CHECKPOINT, is_best)
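
            # save_checkpoint is assumed to write best.pt under CHECKPOINT
            # when is_best is True; that file is reloaded after the loop.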
print("==> Calculating AUROC")
filepath_best = os.path.join(CHECKPOINT, "best.pt")
checkpoint = torch.load(filepath_best)
model.load_state_dict(checkpoint['state_dict'])
auroc = calc_avg_AUROC(model, testloader, CLASSES, CLASSES, CUDA)
print( 'AUROC: {}'.format(auroc) )
AUROCs.append(auroc)

    print('\nAverage Per-task Performance over number of tasks')
    for i, p in enumerate(AUROCs):
        print("%d: %f" % (i + 1, p))


if __name__ == '__main__':
    main()