forked from Stonesjtu/Pytorch-NCE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
41 lines (30 loc) · 1.24 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Main container for common language model"""
import torch
import torch.nn as nn
from utils import get_mask
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a criterion (decoder and loss function).

    Args:
        ntoken: vocabulary size; number of rows in the embedding table.
        ninp: embedding dimension (the LSTM input size).
        nhid: hidden size of the LSTM.
        nlayers: number of stacked LSTM layers.
        criterion: module/callable invoked as ``criterion(target, rnn_output)``;
            it plays the role of decoder and loss function.
        dropout: dropout probability applied to the embeddings and to the
            LSTM output, and between stacked LSTM layers when ``nlayers > 1``.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, criterion, dropout=0.5):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        # nn.LSTM's own `dropout` only acts between stacked layers. With a
        # single layer it has no effect and PyTorch emits a UserWarning, so
        # pass 0 in that case; the computation is unchanged.
        rnn_dropout = dropout if nlayers > 1 else 0.0
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=rnn_dropout, batch_first=True)
        self.nhid = nhid
        self.nlayers = nlayers
        self.criterion = criterion
        self.reset_parameters()

    def reset_parameters(self, init_range=0.1):
        """Re-initialize the embedding weights uniformly in ``[-init_range, init_range]``.

        Only the encoder is touched here; the LSTM keeps the parameters it was
        given at construction time.

        Args:
            init_range: half-width of the uniform distribution (default 0.1,
                the previously hard-coded value).
        """
        self.encoder.weight.data.uniform_(-init_range, init_range)

    def _rnn(self, input):
        """Serve as the encoder and recurrent layer.

        Embeds ``input``, applies dropout, runs the LSTM, and applies dropout
        to the LSTM output again.

        Args:
            input: integer token indices, batch-first (presumably
                ``(batch, seq_len)`` -- TODO confirm against the data loader).

        Returns:
            Top-layer LSTM outputs for every time step (last dimension
            ``nhid``). The final hidden/cell states are discarded.
        """
        emb = self.drop(self.encoder(input))
        output, _ = self.rnn(emb)
        return self.drop(output)

    def forward(self, input, target, length):
        """Return the mean loss over the positions selected by the length mask.

        Args:
            input: batch-first token indices fed to the encoder.
            target: targets forwarded unchanged to ``criterion``.
            length: per-sequence lengths, converted to a mask by ``get_mask``.

        Returns:
            Scalar tensor: mean of the loss entries kept by the mask. If the
            mask keeps nothing, the mean of the empty selection is NaN.
        """
        # NOTE(review): assumes get_mask(length) yields a boolean mask that is
        # broadcastable to the criterion's per-position loss -- confirm in utils.
        mask = get_mask(length)
        rnn_output = self._rnn(input)
        loss = self.criterion(target, rnn_output)
        loss = torch.masked_select(loss, mask)
        return loss.mean()