-
Notifications
You must be signed in to change notification settings - Fork 107
/
model.py
57 lines (42 loc) · 2.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch import nn
from torch.nn import init
from transformers import BertModel, RobertaModel
class MatchSum(nn.Module):
    """Score candidate summaries by cosine similarity to their document.

    The document, the reference summary and every candidate summary are
    encoded with one shared pretrained encoder. Each sequence is represented
    by the encoder's last-layer hidden state at position 0, and candidates /
    the reference summary are scored by cosine similarity to the document
    representation.

    Args:
        candidate_num: Number of candidates per document. Stored on the module;
            ``forward`` reads the actual count from ``candidate_id`` instead.
        encoder: ``'bert'`` loads ``bert-base-uncased``; any other value loads
            ``roberta-base``.
        hidden_size: Width of the encoder's hidden states. Must match the
            loaded encoder; used to reshape candidate embeddings and in shape
            checks.
    """

    def __init__(self, candidate_num, encoder, hidden_size=768):
        super().__init__()
        self.hidden_size = hidden_size
        self.candidate_num = candidate_num
        if encoder == 'bert':
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
        else:
            # Every value other than 'bert' (not only 'roberta') selects RoBERTa.
            self.encoder = RobertaModel.from_pretrained('roberta-base')

    def _pad_id(self, text_id):
        """Return the padding token id used to build attention masks.

        Prefers the encoder config's ``pad_token_id`` (0 for BERT, 1 for
        RoBERTa), so the mask no longer depends on what the first token of the
        batch happens to be. If the encoder exposes no such value, falls back
        to the original heuristic: a batch whose first token is id 0 is treated
        as RoBERTa input (pad 1), anything else as BERT input (pad 0).
        """
        config = getattr(self.encoder, 'config', None)
        pad_id = getattr(config, 'pad_token_id', None)
        if pad_id is not None:
            return pad_id
        return 1 if text_id[0][0] == 0 else 0

    def _first_token_embedding(self, input_ids, pad_id):
        """Encode ``input_ids`` and return the last-layer state at position 0.

        Positions equal to ``pad_id`` are excluded from attention. The result
        has one row per row of ``input_ids``.
        """
        attention_mask = input_ids != pad_id
        last_hidden = self.encoder(input_ids, attention_mask=attention_mask)[0]
        return last_hidden[:, 0, :]

    def forward(self, text_id, candidate_id, summary_id):
        """Score the reference summary and every candidate against the document.

        Args:
            text_id: Document token ids, ``[batch_size, doc_len]``.
            candidate_id: Candidate token ids,
                ``[batch_size, candidate_num, cand_len]``.
            summary_id: Reference-summary token ids,
                ``[batch_size, summary_len]``.

        Returns:
            A dict with
            ``'score'``: ``[batch_size, candidate_num]`` cosine similarity of
            each candidate to its document, and
            ``'summary_score'``: ``[batch_size]`` cosine similarity of the
            reference summary to its document.
        """
        batch_size = text_id.size(0)
        pad_id = self._pad_id(text_id)

        # Encoder call order (document, summary, candidates) is kept as-is so
        # that dropout consumes the RNG in the same order during training.
        doc_emb = self._first_token_embedding(text_id, pad_id)
        assert doc_emb.size() == (batch_size, self.hidden_size)

        summary_emb = self._first_token_embedding(summary_id, pad_id)
        assert summary_emb.size() == (batch_size, self.hidden_size)
        summary_score = torch.cosine_similarity(summary_emb, doc_emb, dim=-1)

        # Flatten [batch, candidates, len] -> [batch * candidates, len] so all
        # candidates go through the encoder in a single call; reshape (unlike
        # view) also accepts a non-contiguous candidate_id.
        candidate_num = candidate_id.size(1)
        flat_candidates = candidate_id.reshape(-1, candidate_id.size(-1))
        candidate_emb = self._first_token_embedding(flat_candidates, pad_id).reshape(
            batch_size, candidate_num, self.hidden_size)
        assert candidate_emb.size() == (batch_size, candidate_num, self.hidden_size)

        # Repeat each document embedding once per candidate before comparing.
        doc_emb = doc_emb.unsqueeze(1).expand_as(candidate_emb)
        score = torch.cosine_similarity(candidate_emb, doc_emb, dim=-1)
        assert score.size() == (batch_size, candidate_num)
        return {'score': score, 'summary_score': summary_score}