forked from gcalbertini/273Kelvin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
backbone2.py
33 lines (25 loc) · 818 Bytes
/
backbone2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torchvision.models as models
import torch.nn as nn
from lightning import train_backbone
class Backbone(nn.Module):
    """Wrap a feature extractor so its flat output looks like a 1x1 feature map.

    The wrapped model is expected to return a ``(N, C)`` tensor (e.g. the
    ResNet-18 built by ``get_backbone``, whose ``fc`` head is replaced by
    ``nn.Identity``). ``forward`` appends two singleton dimensions, giving
    ``(N, C, 1, 1)``.

    Args:
        backbone: the ``nn.Module`` to wrap; stored as ``self.premodel``.
        out_channels: value exposed as ``self.out_channels``; should equal
            ``C`` above. Defaults to 512, the originally hard-coded value.
        trainable: ``requires_grad`` flag applied to every parameter of
            ``backbone``. Defaults to True (the original behaviour); pass
            False to freeze the wrapped model.
    """

    def __init__(self, backbone, *, out_channels=512, trainable=True):
        super().__init__()
        # NOTE(review): presumably read by a downstream head (torchvision
        # detection models look up `backbone.out_channels`) — confirm callers.
        self.out_channels = out_channels
        self.premodel = backbone
        # Explicitly set trainability: a model handed in may have had some
        # parameters frozen elsewhere, so this makes the state deterministic.
        for p in self.premodel.parameters():
            p.requires_grad_(trainable)

    def forward(self, x):
        """Run the wrapped model and reshape its output to ``(N, C, 1, 1)``."""
        out = self.premodel(x)
        # Insert singleton "height" and "width" dimensions after the channel dim.
        return out.unsqueeze(2).unsqueeze(3)
def get_backbone(train=False, checkpoint_path='./resnet18_backbone_weights.ckpt'):
    """Build a ResNet-18 feature extractor from a checkpoint and wrap it in ``Backbone``.

    Args:
        train: if true, call ``train_backbone()`` before loading. Presumably
            that (re)writes the checkpoint at ``checkpoint_path`` — TODO
            confirm, since this function only reads it.
        checkpoint_path: checkpoint file to load. Relative paths resolve
            against the current working directory. Defaults to the original
            hard-coded location.

    Returns:
        A ``Backbone`` wrapping a ResNet-18 whose ``fc`` head is replaced by
        ``nn.Identity`` and whose weights come from
        ``checkpoint['model_state_dict']``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: if the state dict does not match the model (strict load).
    """
    if train:
        train_backbone()
    # No pretrained weights requested; the parameters are overwritten by the
    # checkpoint below anyway.
    backbone = models.resnet18(pretrained=None)
    # Drop the classification head so the model returns pooled features.
    backbone.fc = nn.Identity()
    # map_location="cpu": the model above lives on the CPU, so load there too.
    # Without it, a checkpoint saved from CUDA fails on GPU-less machines.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    backbone.load_state_dict(checkpoint['model_state_dict'])
    return Backbone(backbone)