# model.py
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# These lightly helpers are re-implemented locally on TiCoModel below:
# from lightly.models.utils import deactivate_requires_grad
# from lightly.models.utils import update_momentum

class SimSiamModel(nn.Module):
    """ResNet-50 backbone with a two-layer projection MLP (2048 -> 2048 -> 512)."""

    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()  # expose the 2048-d pooled features
        self.projector = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 512),
        )

    def forward(self, x):
        return self.projector(self.backbone(x))

class SimCLRModel(nn.Module):
    """ResNet-50 backbone with a SimCLR-style projection head (2048 -> 512 -> 128)."""

    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )

    def forward(self, x):
        # L2-normalize so the contrastive loss operates on unit vectors.
        return F.normalize(self.projector(self.backbone(x)), dim=1)

class DCLWModel(nn.Module):
    """Same architecture as SimCLRModel, intended for the decoupled contrastive (DCLW) loss."""

    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )

    def forward(self, x):
        return F.normalize(self.projector(self.backbone(x)), dim=1)

class VICRegModel(nn.Module):
    """ResNet-50 backbone with a VICReg-style expander (2048 -> 4096 -> 4096); output is not normalized."""

    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
        )

    def forward(self, x):
        return self.projector(self.backbone(x))

class BarlowModel(nn.Module):
    """Barlow Twins-style model: projector output is standardized by a non-affine BatchNorm."""

    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
        )
        # Non-affine BatchNorm standardizes features for the cross-correlation loss.
        self.bn = nn.BatchNorm1d(4096, affine=False)

    def forward(self, x):
        return self.bn(self.projector(self.backbone(x)))
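
# A minimal shape-check sketch for the single-branch models above. The
# 224x224 input size and the batch size are illustrative assumptions;
# wrapping it in a function keeps the module import side-effect free.
def _smoke_test_projectors(batch_size=4):
    x = torch.randn(batch_size, 3, 224, 224)
    for model_cls, out_dim in [
        (SimSiamModel, 512),
        (SimCLRModel, 128),
        (DCLWModel, 128),
        (VICRegModel, 4096),
        (BarlowModel, 4096),
    ]:
        model = model_cls().eval()  # eval mode so BatchNorm uses running stats
        with torch.no_grad():
            z = model(x)
        assert z.shape == (batch_size, out_dim), (model_cls.__name__, z.shape)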

class TiCoModel(nn.Module):
    """TiCo-style model with an online branch and a momentum (EMA) branch."""

    def __init__(self):
        super().__init__()
        # Train from scratch (note: pretrained=False is deprecated in newer torchvision).
        self.backbone = torchvision.models.resnet50(pretrained=False)
        self.backbone.fc = nn.Identity()
        self.projection_head = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
        )
        # Momentum copies are updated only by EMA, never by the optimizer.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        self.deactivate_requires_grad(self.backbone_momentum)
        self.deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        # Online (gradient) branch produces the query embedding.
        query = self.backbone(x).flatten(start_dim=1)
        query = self.projection_head(query)
        return query

    def forward_momentum(self, x):
        # Momentum branch produces the key embedding; detached so no
        # gradients flow into the EMA copies.
        key = self.backbone_momentum(x).flatten(start_dim=1)
        key = self.projection_head_momentum(key).detach()
        return key

    def deactivate_requires_grad(self, module):
        # Local stand-in for lightly.models.utils.deactivate_requires_grad.
        for param in module.parameters():
            param.requires_grad = False

    def update_momentum(self, model, model_ema, m):
        # Local stand-in for lightly.models.utils.update_momentum:
        # ema <- m * ema + (1 - m) * online.
        for param_ema, param in zip(model_ema.parameters(), model.parameters()):
            param_ema.data = param_ema.data * m + param.data * (1.0 - m)

    def schedule_momentum(self, step, max_steps, m=0.99):
        # Quarter-sine ramp of the EMA coefficient from m toward 1.0 over training.
        return m + (1 - m) * np.sin((np.pi / 2) * step / (max_steps - 1))
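
# A minimal training-step sketch for the TiCo momentum branches. The two
# augmented views, their shapes, and the step counts are illustrative
# assumptions; the TiCo loss itself is not defined in this file.
if __name__ == "__main__":
    model = TiCoModel()
    x1 = torch.randn(4, 3, 224, 224)  # first augmented view (assumed shape)
    x2 = torch.randn(4, 3, 224, 224)  # second augmented view (assumed shape)

    # Ramp the EMA coefficient, then update both momentum copies in place.
    m = model.schedule_momentum(step=0, max_steps=100)
    model.update_momentum(model.backbone, model.backbone_momentum, m)
    model.update_momentum(model.projection_head, model.projection_head_momentum, m)

    query = model(x1)                 # online branch, shape (4, 4096)
    key = model.forward_momentum(x2)  # EMA branch, detached, shape (4, 4096)
    print(query.shape, key.shape)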