-
Notifications
You must be signed in to change notification settings - Fork 172
/
cdac.py
275 lines (219 loc) · 9.06 KB
/
cdac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import numpy as np
from functools import partial
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR
from dassl.data import DataManager
from dassl.optim import build_optimizer
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms.transforms import build_transform
def custom_scheduler(iter, max_iter=None, alpha=10, beta=0.75, init_lr=0.001):
    """Compute the annealing factor ``(1 + alpha * iter / max_iter) ** -beta``.

    Schedule from https://arxiv.org/pdf/1409.7495.pdf

    Args:
        iter: current iteration index.
        max_iter: total number of iterations; when ``None`` no annealing is
            computed and ``init_lr`` is returned as-is.
        alpha: scale applied to the training progress ``iter / max_iter``.
        beta: decay exponent.
        init_lr: value returned when ``max_iter`` is ``None``.

    Returns:
        A float: the annealing factor, or ``init_lr`` if ``max_iter`` is None.
    """
    # NOTE(review): elsewhere in this file the function is handed to LambdaLR,
    # which treats the return value as a multiplier of the base LR; returning
    # ``init_lr`` in the fallback branch may be unintended -- confirm.
    if max_iter is not None:
        progress = float(iter / max_iter)
        exponent = -1.0 * beta
        return (1 + progress * alpha)**exponent
    return init_lr
class AAC(nn.Module):
    """Binary cross-entropy between a pairwise target and pairwise predictions.

    The pairwise prediction for samples ``i`` and ``j`` is the inner product
    of row ``i`` of ``prob_u`` with row ``j`` of ``prob_us``.
    """

    def forward(self, sim_mat, prob_u, prob_us):
        """Return the mean pairwise binary cross-entropy.

        Args:
            sim_mat: pairwise target matrix; must broadcast against
                ``prob_u @ prob_us.T``.
            prob_u: 2-D tensor, one row per sample (transposed partner is
                ``prob_us``).
            prob_us: 2-D tensor with the same number of columns as ``prob_u``.

        Returns:
            Scalar tensor: the element-wise loss averaged over all pairs.
        """
        eps = 1e-7  # keeps both logarithms finite at 0 and 1
        pair_prob = prob_u.matmul(prob_us.t())
        similar_term = sim_mat * torch.log(pair_prob + eps)
        dissimilar_term = (1. - sim_mat) * torch.log(1. - pair_prob + eps)
        pairwise_bce = -(similar_term + dissimilar_term)
        return pairwise_bce.mean()
class Prototypes(nn.Module):
    """Bias-free linear classifier over L2-normalised features.

    Outputs are divided by a temperature. Optionally the input is first
    passed through a ``ReverseGrad`` layer.
    """

    def __init__(self, fdim, num_classes, temp=0.05):
        """
        Args:
            fdim: dimensionality of the incoming features.
            num_classes: number of output logits (one weight row each).
            temp: temperature the logits are divided by.
        """
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp
        self.revgrad = ReverseGrad()

    def forward(self, x, reverse=False):
        """Return temperature-scaled logits for ``x``.

        Args:
            x: feature tensor; normalised along ``dim=1``.
            reverse: when true, ``x`` goes through ``self.revgrad`` first.
        """
        feats = self.revgrad(x) if reverse else x
        unit_feats = F.normalize(feats, p=2, dim=1)
        return self.prototypes(unit_feats) / self.temp
@TRAINER_REGISTRY.register()
class CDAC(TrainerXU):
    """Cross Domain Adaptive Clustering.

    https://arxiv.org/pdf/2104.09415.pdf

    Trains a feature extractor ``F`` and a prototype classifier ``C`` with a
    supervised loss on labeled data plus three losses on unlabeled data:
    an adversarial adaptive clustering loss, a pseudo-label loss and a
    consistency loss.
    """

    def __init__(self, cfg):
        # These attributes are assigned before the base constructor runs;
        # presumably it invokes build_model(), which reads self.lr_multi
        # -- TODO confirm against TrainerXU.
        self.rampup_coef = cfg.TRAINER.CDAC.RAMPUP_COEF
        self.rampup_iters = cfg.TRAINER.CDAC.RAMPUP_ITRS
        self.lr_multi = cfg.TRAINER.CDAC.CLASS_LR_MULTI
        self.topk = cfg.TRAINER.CDAC.TOPK_MATCH
        self.p_thresh = cfg.TRAINER.CDAC.P_THRESH
        self.aac_criterion = AAC()
        super().__init__(cfg)

    def check_cfg(self, cfg):
        """Assert that the config provides strong transforms and K_TRANSFORMS == 2."""
        assert len(
            cfg.TRAINER.CDAC.STRONG_TRANSFORMS
        ) > 0, "Strong augmentations are necessary to run CDAC"
        assert cfg.DATALOADER.K_TRANSFORMS == 2, "CDAC needs two strong augmentations of the same image."

    def build_data_loader(self):
        """Create the DataManager with a [default, strong] pair of train transforms."""
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.CDAC.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes
        self.lab2cname = self.dm.lab2cname

    def build_model(self):
        """Build ``F`` and ``C`` with their optimizers and per-iteration schedulers.

        Raises:
            ValueError: if ``cfg.TRAIN.COUNT_ITER`` is not one of
                ``"train_x"``, ``"train_u"`` or ``"smaller_one"``.
        """
        cfg = self.cfg

        # Custom LR Scheduler for CDAC
        count_iter = cfg.TRAIN.COUNT_ITER
        if count_iter == "train_x":
            self.num_batches = len(self.train_loader_x)
        elif count_iter == "train_u":
            # Fixed: the original read the non-existent attribute
            # ``self.len_train_loader_u`` and crashed with AttributeError.
            self.num_batches = len(self.train_loader_u)
        elif count_iter == "smaller_one":
            self.num_batches = min(
                len(self.train_loader_x), len(self.train_loader_u)
            )
        else:
            # Fail here rather than later on a missing ``self.num_batches``.
            raise ValueError(
                "Unsupported cfg.TRAIN.COUNT_ITER: {!r}".format(count_iter)
            )

        self.max_iter = self.max_epoch * self.num_batches
        print("Max Iterations: %d" % self.max_iter)

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        custom_lr_F = partial(
            custom_scheduler, max_iter=self.max_iter, init_lr=cfg.OPTIM.LR
        )
        self.sched_F = LambdaLR(self.optim_F, custom_lr_F)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)

        # Multiply the learning rate of C by lr_multi
        # (done before LambdaLR is constructed on this optimizer).
        for group_param in self.optim_C.param_groups:
            group_param['lr'] *= self.lr_multi

        custom_lr_C = partial(
            custom_scheduler,
            max_iter=self.max_iter,
            init_lr=cfg.OPTIM.LR * self.lr_multi
        )
        self.sched_C = LambdaLR(self.optim_C, custom_lr_C)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        """Summarise pseudo-label quality.

        Args:
            y_pred: predicted class indices.
            y_true: reference class indices, same shape as ``y_pred``.
            mask: float tensor, non-zero where a prediction is kept.

        Returns:
            dict with ``acc_thre`` (accuracy over kept predictions),
            ``acc_raw`` (accuracy over all predictions) and ``keep_rate``
            (fraction of predictions kept).
        """
        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
        # 1e-5 guards against division by zero when nothing passes the mask.
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy
        keep_rate = mask.sum() / mask.numel()
        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output

    def forward_backward(self, batch_x, batch_u):
        """Run one training iteration and return a dict of logged values.

        Two separate backward/update steps are performed: one for the
        supervised loss and one for the combined unlabeled losses. The
        learning rate is updated at the end of every call.
        """
        current_itr = self.epoch * self.num_batches + self.batch_idx

        input_x, label_x, input_u, input_us, input_us2, label_u = self.parse_batch_train(
            batch_x, batch_u
        )

        # Paper Reference Eq. 2 - Supervised Loss
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        feat_u = self.F(input_u)
        feat_us = self.F(input_us)
        feat_us2 = self.F(input_us2)

        # Paper Reference Eq.3 - Adversarial Adaptive Loss
        logit_u = self.C(feat_u, reverse=True)
        logit_us = self.C(feat_us, reverse=True)
        prob_u, prob_us = F.softmax(logit_u, dim=1), F.softmax(logit_us, dim=1)

        # Get similarity matrix s_ij
        sim_mat = self.get_similarity_matrix(feat_u, self.topk, self.device)
        # Sign is flipped here; the features additionally pass through the
        # reverse-gradient path of C above.
        aac_loss = (-1. * self.aac_criterion(sim_mat, prob_u, prob_us))

        # Paper Reference Eq. 4 - Pseudo label Loss
        logit_u = self.C(feat_u)
        logit_us = self.C(feat_us)
        logit_us2 = self.C(feat_us2)
        prob_u, prob_us, prob_us2 = F.softmax(
            logit_u, dim=1
        ), F.softmax(
            logit_us, dim=1
        ), F.softmax(
            logit_us2, dim=1
        )
        # Pseudo-labels are treated as fixed targets (no gradient).
        prob_u = prob_u.detach()
        max_probs, max_idx = torch.max(prob_u, dim=-1)
        mask = max_probs.ge(self.p_thresh).float()
        p_u_stats = self.assess_y_pred_quality(max_idx, label_u, mask)

        pl_loss = (
            F.cross_entropy(logit_us2, max_idx, reduction='none') * mask
        ).mean()

        # Paper Reference Eq. 8 - Consistency Loss
        cons_multi = self.sigmoid_rampup(
            current_itr=current_itr, rampup_itr=self.rampup_iters
        ) * self.rampup_coef
        cons_loss = cons_multi * F.mse_loss(prob_us, prob_us2)

        loss_u = aac_loss + pl_loss + cons_loss
        self.model_backward_and_update(loss_u)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "aac_loss": aac_loss.item(),
            "pl_loss": pl_loss.item(),
            "cons_loss": cons_loss.item(),
            "p_u_pred_acc": p_u_stats["acc_raw"],
            "p_u_pred_acc_thre": p_u_stats["acc_thre"],
            "p_u_pred_keep": p_u_stats["keep_rate"]
        }

        # Update LR after every iteration as mentioned in the paper
        self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        """Extract inputs and labels from both batches and move them to ``self.device``.

        Returns:
            ``(input_x, label_x, input_u, input_us, input_us2, label_u)``
            where ``input_us`` / ``input_us2`` are the two entries of
            ``batch_u["img2"]`` (presumably the two strongly augmented
            views -- verify against the DataManager).
        """
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]

        input_u = batch_u["img"][0]
        input_us = batch_u["img2"][0]
        input_us2 = batch_u["img2"][1]
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_us = input_us.to(self.device)
        input_us2 = input_us2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, label_x, input_u, input_us, input_us2, label_u

    def model_inference(self, input):
        """Return the logits ``C(F(input))``."""
        return self.C(self.F(input))

    @staticmethod
    def get_similarity_matrix(feat, topk, device):
        """Build a 0/1 matrix marking samples that share the same top-k feature dims.

        Args:
            feat: 2-D feature tensor, one row per sample.
            topk: number of largest feature dimensions compared per sample.
            device: device the returned matrix is placed on.

        Returns:
            ``(N, N)`` tensor with 1 at ``[i, j]`` when rows ``i`` and ``j``
            have identical sets of top-k indices, 0 otherwise.
        """
        # Indices of the k largest activations per row, sorted so the
        # comparison below ignores their ranking order.
        topk_idx = torch.sort(
            torch.argsort(feat.detach(), dim=1, descending=True)[:, :topk],
            dim=1
        )[0]
        # Broadcast (N, 1, k) against (1, N, k): rows match when every index agrees.
        same_topk = (topk_idx.unsqueeze(1) == topk_idx.unsqueeze(0)).all(dim=2)
        return same_topk.to(device=device, dtype=torch.get_default_dtype())

    @staticmethod
    def sigmoid_rampup(current_itr, rampup_itr):
        """Exponential Rampup
        https://arxiv.org/abs/1610.02242

        Returns ``exp(-5 * (1 - t)^2)`` with ``t = clip(current_itr, 0,
        rampup_itr) / rampup_itr``, or ``1.0`` when ``rampup_itr`` is 0.
        """
        if rampup_itr == 0:
            return 1.0
        else:
            var = np.clip(current_itr, 0.0, rampup_itr)
            phase = 1.0 - var/rampup_itr
            return float(np.exp(-5.0 * phase * phase))