# -*- coding: utf-8 -*-
"""
Created on 3/5/2019
@author: RuihongQiu
"""
import torch
import torch.nn as nn
from torch_scatter import scatter_add
from torch_geometric.utils import softmax


class GRUSet2Set(torch.nn.Module):
    r"""The global pooling operator based on iterative content-based attention
    from the `"Order Matters: Sequence to sequence for sets"
    <https://arxiv.org/abs/1511.06391>`_ paper, with the LSTM replaced by a GRU:

    .. math::
        \mathbf{q}_t &= \mathrm{GRU}(\mathbf{q}^{*}_{t-1})

        \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)

        \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i

        \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,

    where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice
    the dimensionality of the input.

    Args:
        in_channels (int): Size of each input sample.
        processing_steps (int): Number of iterations :math:`T`.
        num_layers (int, optional): Number of recurrent layers, *e.g.*, setting
            :obj:`num_layers=2` would mean stacking two GRUs together to form
            a stacked GRU, with the second GRU taking in outputs of the first
            GRU and computing the final results. (default: :obj:`1`)
    """
    def __init__(self, in_channels, processing_steps, num_layers=1):
        super(GRUSet2Set, self).__init__()
        self.in_channels = in_channels
        self.out_channels = 2 * in_channels
        self.processing_steps = processing_steps
        self.num_layers = num_layers
        self.rnn = nn.GRU(self.out_channels, self.in_channels, num_layers)
        # Only needed by the (commented-out) linear attention variant in forward().
        self.linear = nn.Linear(in_channels * 3, in_channels)
        self.reset_parameters()

    def reset_parameters(self):
        # Orthogonal initialization for the GRU weight matrices; biases keep
        # their default initialization.
        for weight in self.rnn.parameters():
            if len(weight.size()) > 1:
                torch.nn.init.orthogonal_(weight.data)
    def forward(self, x, batch):
        """"""
        batch_size = batch.max().item() + 1

        # Initial GRU hidden state (zeros) and initial query q*_0.
        h = x.new_zeros((self.num_layers, batch_size, self.in_channels))
        q_star = x.new_zeros(batch_size, self.out_channels)

        # Split x back into its per-graph chunks G_i and repeat each graph's
        # last node embedding |V_i| times (only used by the commented-out
        # linear attention variant below).
        sections = torch.bincount(batch)
        v_i = torch.split(x, tuple(sections.cpu().numpy()))
        v_n_repeat = tuple(nodes[-1].view(1, -1).repeat(nodes.shape[0], 1)
                           for nodes in v_i)
        # x = x * torch.cat(v_n_repeat, dim=0)

        for _ in range(self.processing_steps):
            q, h = self.rnn(q_star.unsqueeze(0), h)
            q = q.view(batch_size, self.in_channels)

            # Content-based attention score of every node against its graph's query.
            # e = self.linear(torch.cat((x, q[batch], torch.cat(v_n_repeat, dim=0)), dim=-1)).sum(dim=-1, keepdim=True)
            e = (x * q[batch]).sum(dim=-1, keepdim=True)
            a = softmax(e, batch, num_nodes=batch_size)

            # Readout: attention-weighted sum of node embeddings per graph.
            r = scatter_add(a * x, batch, dim=0, dim_size=batch_size)
            q_star = torch.cat([q, r], dim=-1)

        return q_star

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
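

# A minimal smoke-test sketch, not part of the original module: the graph sizes,
# feature dimension, and number of processing steps below are arbitrary example
# values. It builds a two-graph batch, pools it with GRUSet2Set, and prints the
# output shape, which should be (num_graphs, 2 * in_channels).
if __name__ == '__main__':
    in_channels = 16                       # hypothetical node feature size
    batch = torch.tensor([0, 0, 0, 1, 1])  # graph 0 has 3 nodes, graph 1 has 2
    x = torch.randn(batch.size(0), in_channels)

    pool = GRUSet2Set(in_channels=in_channels, processing_steps=3)
    out = pool(x, batch)
    print(pool)        # GRUSet2Set(16, 32)
    print(out.shape)   # torch.Size([2, 32])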