-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
54 lines (45 loc) · 1.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: utf-8 -*-
"""
Created on 5/4/2019
@author: RuihongQiu
"""
import logging

import numpy as np
import torch
import torch.nn as nn
def forward(model, loader, device, writer, epoch, top_k=20, optimizer=None, train_flag=True):
    """Run one epoch of training or evaluation over ``loader``.

    Args:
        model: network exposing ``loss_function(scores, targets)``; called as
            ``model(batch)`` to produce per-item scores.
        loader: iterable of PyG-style batches (each has ``.to(device)``, ``.y``,
            ``.num_graphs``) — TODO confirm exact batch type against caller.
        device: torch device the batches are moved to.
        writer: TensorBoard-style writer with ``add_scalar(tag, value, step)``.
        epoch: current epoch index, used as the logging step.
        top_k: cutoff for Hit@K / MRR@K during evaluation.
        optimizer: required when ``train_flag`` is True; unused otherwise.
        train_flag: True = optimize; False = evaluate metrics under no_grad.

    Returns:
        (hit, mrr) as percentages. NOTE: in training mode no metrics are
        collected, so both are ``nan`` (``np.mean([])``) — matches the
        original behavior; callers should only use them after evaluation.
    """
    if train_flag:
        model.train()
    else:
        model.eval()

    hit, mrr = [], []
    mean_loss = 0.0
    updates_per_epoch = len(loader)

    # Evaluation does not need gradients; disabling autograd there saves
    # memory and time. Training keeps grad explicitly enabled.
    grad_ctx = torch.enable_grad() if train_flag else torch.no_grad()
    with grad_ctx:
        for i, batch in enumerate(loader):
            if train_flag:
                optimizer.zero_grad()
            scores = model(batch.to(device))
            # Labels are 1-based item ids; shift to 0-based class indices.
            targets = batch.y - 1
            loss = model.loss_function(scores, targets)
            if train_flag:
                loss.backward()
                optimizer.step()
                writer.add_scalar('loss/train_batch_loss', loss.item(),
                                  epoch * updates_per_epoch + i)
            else:
                sub_scores = scores.topk(top_k)[1]  # batch * top_k
                for score, target in zip(sub_scores.detach().cpu().numpy(),
                                         targets.detach().cpu().numpy()):
                    hit.append(np.isin(target, score))
                    # Reciprocal rank of the target within the top-k list,
                    # 0 when the target is absent.
                    rank = np.where(score == target)[0]
                    mrr.append(0 if rank.size == 0 else 1 / (rank[0] + 1))
            # Accumulate a Python float via .item(): summing the loss tensor
            # directly would retain every batch's autograd graph for the whole
            # epoch — an unbounded memory leak during training.
            mean_loss += loss.item() / batch.num_graphs

    # mean_loss is already a float now, so it is logged directly (the
    # original called .item() on the accumulated tensor here).
    if train_flag:
        writer.add_scalar('loss/train_loss', mean_loss, epoch)
    else:
        writer.add_scalar('loss/test_loss', mean_loss, epoch)

    hit = np.mean(hit) * 100
    mrr = np.mean(mrr) * 100
    writer.add_scalar('index/hit', hit, epoch)
    writer.add_scalar('index/mrr', mrr, epoch)
    return hit, mrr