models.py (forked from deepinsight/insightface)
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import timm


class GazeModel(pl.LightningModule):
    def __init__(self, backbone, epoch):
        super().__init__()
        self.save_hyperparameters()
        # Regress 481 landmarks per eye, two eyes, 3 coordinates each.
        self.backbone = timm.create_model(backbone, num_classes=481*2*3)
        self.epoch = epoch
        # Unused in practice: cal_loss below computes the weighted L1 directly.
        #self.loss = nn.MSELoss(reduction='mean')
        self.loss = nn.L1Loss(reduction='mean')
        self.hard_mining = False
        # Mesh constants; not referenced elsewhere in this file.
        self.num_face = 1103
        self.num_eye = 481*2
    def forward(self, x):
        # use forward for inference/predictions
        y = self.backbone(x)
        return y

    def cal_loss(self, y_hat, y, hm=False):
        # Per-landmark L1 loss over (B, K, 3); the z coordinate is
        # down-weighted by 0.5.
        bs = y.size(0)
        y_hat = y_hat.view((bs, -1, 3))
        loss = torch.abs(y_hat - y)  # (B, K, 3)
        loss[:, :, 2] *= 0.5
        if hm:
            # Online hard-example mining: average per sample, then keep only
            # the hardest 25% of the batch.
            loss = torch.mean(loss, dim=(1, 2))  # (B,)
            loss, _ = torch.topk(loss, k=int(bs*0.25), largest=True)
        loss = torch.mean(loss) * 20.0
        return loss
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = self.cal_loss(y_hat, y, self.hard_mining)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = self.cal_loss(y_hat, y)
        self.log('val_loss', loss, on_step=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = self.cal_loss(y_hat, y)
        self.log('test_loss', loss)
    def configure_optimizers(self):
        #return torch.optim.Adam(self.parameters(), lr=0.0002)
        opt = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
        # Step decay: divide the LR by 10 at 40%, 70% and 90% of training.
        epoch_steps = [int(self.epoch*0.4), int(self.epoch*0.7), int(self.epoch*0.9)]
        print('epoch_steps:', epoch_steps)

        def lr_step_func(epoch):
            return 0.1 ** len([m for m in epoch_steps if m <= epoch])

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer=opt, lr_lambda=lr_step_func)
        lr_scheduler = {
            'scheduler': scheduler,
            'name': 'learning_rate',  # key under which LearningRateMonitor logs the LR
            'interval': 'epoch',
            'frequency': 1}
        return [opt], [lr_scheduler]
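
The callbacks and logger imported at the top are not used inside this file; they belong in a training entry point. Below is a minimal, self-contained sketch of how such an entry point might wire up GazeModel, using random tensors in place of a real gaze dataset. The backbone name, batch size, epoch count, and log paths are illustrative assumptions, not the upstream training script.

from torch.utils.data import DataLoader, TensorDataset


def _train_example():
    # Hypothetical entry point, not part of the original file.
    epochs = 30
    model = GazeModel(backbone='resnet34d', epoch=epochs)

    # Synthetic stand-in data: 224x224 RGB crops and 481*2 eye landmarks in 3D.
    images = torch.randn(256, 3, 224, 224)
    targets = torch.randn(256, 481*2, 3)
    train_loader = DataLoader(TensorDataset(images, targets), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(images[:64], targets[:64]), batch_size=32)

    trainer = pl.Trainer(
        max_epochs=epochs,
        logger=TensorBoardLogger('logs/', name='gaze'),
        callbacks=[
            ModelCheckpoint(),  # saves the most recent epoch's checkpoint
            LearningRateMonitor(logging_interval='epoch'),  # tracks the stepped LR
        ],
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == '__main__':
    _train_example()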