main_full_cifar10.py
import argparse

import apex.amp as amp
import torch
import torch.nn as nn
import torchattacks

from config import hyperparameters
from utils import evaluate_standard, evaluate_adv, get_loaders, save_best
from utils_cifar10 import getModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='standard', type=str, choices=['adv', 'standard'],
                        help="perform adversarial training or standard training")
    # main() reads args.dataName below, but the original script never defined the flag;
    # '--dataName' with a 'cifar10' default is an assumption based on the file name.
    parser.add_argument('--dataName', default='cifar10', type=str,
                        help="dataset name used to look up hyperparameters and data loaders")
    parser.add_argument('--lr-schedule', default='cyclic', type=str, choices=['cyclic', 'multistep'])
    parser.add_argument('--lr-min', default=0., type=float)
    parser.add_argument('--lr-max', default=0.2, type=float)
    parser.add_argument('--weight-decay', default=5e-4, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--opt-level', default='O2', type=str, choices=['O0', 'O1', 'O2'],
                        help='O0 is FP32 training, O1 is Mixed Precision, and O2 is "Almost FP16" Mixed Precision')
    parser.add_argument('--loss-scale', default='1.0', type=str, choices=['1.0', 'dynamic'],
                        help='If loss_scale is "dynamic", adaptively adjust the loss scale over time')
    parser.add_argument('--master-weights', action='store_true',
                        help='Maintain FP32 master weights to accompany any FP16 model weights, not applicable for O1 opt level')
    parser.add_argument('--ite', default=0, type=int, help="The repetition ID of experiments")
    parser.add_argument('--modelID', default=0, type=int, help="The ID of models")
    return parser.parse_args()
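# Example: `python main_full_cifar10.py --train adv --opt-level O1` parses to
# args.train == 'adv' and args.opt_level == 'O1', with every other flag at the
# default declared above (argparse maps '--lr-schedule' to args.lr_schedule).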
def main():
    args = get_args()
    dataName = args.dataName
    modelID = args.modelID
    model, modelName = getModel(modelID)
    # Move the model onto the GPU before amp.initialize (a no-op if getModel
    # already did this); apex expects a CUDA model for the O1/O2 opt levels.
    model = model.to(device)
    save_model = f"full-{args.train}-{args.ite}.pt"
    parameters = hyperparameters(dataName, modelName)
    train_loader, _, _, _, val_loader = get_loaders(parameters.data_dir, parameters.batch_size, dataName)
    save_model_name = parameters.save_model_root + save_model

    opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum, weight_decay=args.weight_decay)
    # Wrap model and optimizer for mixed-precision training with apex AMP.
    amp_args = dict(opt_level=args.opt_level, loss_scale=args.loss_scale, verbosity=False)
    if args.opt_level == 'O2':
        amp_args['master_weights'] = args.master_weights
    model, opt = amp.initialize(model, opt, **amp_args)
    criterion = nn.CrossEntropyLoss()

    lr_steps = parameters.epochs * len(train_loader)
    if args.lr_schedule == 'cyclic':
        scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=args.lr_min, max_lr=args.lr_max,
                                                      step_size_up=lr_steps / 2, step_size_down=lr_steps / 2)
    elif args.lr_schedule == 'multistep':
        # MultiStepLR requires integer milestones; the original passed floats,
        # which never match the scheduler's integer step counter.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[lr_steps // 2, lr_steps * 3 // 4], gamma=0.1)
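    # Schedule shapes, as a reading aid (both schedulers are stepped once per batch,
    # so all lengths are in optimizer steps rather than epochs):
    #   cyclic    - lr rises linearly from lr_min to lr_max over the first half of
    #               all steps, then falls back to lr_min (one triangular cycle).
    #   multistep - lr starts at lr_max and is multiplied by gamma=0.1 after 50%
    #               and 75% of the total steps.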
    # Training loop
    best_acc = 0
    for epoch in range(parameters.epochs):
        model.train()
        if args.train == "standard":
            for X, y in train_loader:
                X, y = X.to(device), y.to(device)
                output = model(X)
                loss = criterion(output, y)
                opt.zero_grad()
                # Scale the loss so FP16 gradients do not underflow.
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
                opt.step()
                scheduler.step()
        elif args.train == "adv":
            # Adversarial training: craft a fresh PGD perturbation for every
            # batch and train on the perturbed inputs.
            atk = torchattacks.PGD(model, eps=parameters.epsilon, alpha=parameters.alpha,
                                   steps=parameters.attack_iters, random_start=True)
            for X, y in train_loader:
                X, y = X.to(device), y.to(device)
                adv_data = atk(X, y)
                output_adv = model(adv_data)
                loss = criterion(output_adv, y)
                opt.zero_grad()
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
                opt.step()
                scheduler.step()
        else:
            print("wrong training type")  # unreachable: argparse restricts --train
        # Validation: clean accuracy for standard training, PGD accuracy for
        # adversarial training; checkpoint whenever the tracked metric improves.
        if args.train == "standard":
            val_acc = evaluate_standard(val_loader, model)
        else:
            val_acc = evaluate_adv(val_loader, model, "pgd", dataName)
        if val_acc >= best_acc:
            checkpoint = {
                'state_dict': model.state_dict(),
                'optimizer': opt.state_dict(),
                'metric_best': val_acc,
                'epoch_best': epoch,
                'current_stage': 0,
                'candidate_index': 0
            }
            best_acc = val_acc
            save_best(checkpoint, save_model_name)

if __name__ == "__main__":
    main()
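# Example invocations (a sketch; data paths, epochs, and attack hyperparameters
# come from config.hyperparameters):
#   python main_full_cifar10.py --train standard --lr-schedule cyclic --modelID 0
#   python main_full_cifar10.py --train adv --loss-scale dynamic --ite 1
#
# Reloading the best checkpoint later (assuming save_best wraps torch.save):
#   ckpt = torch.load(parameters.save_model_root + "full-adv-1.pt")
#   model.load_state_dict(ckpt['state_dict'])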