weighted_gat.py
# -*- coding: utf-8 -*-
"""
Created on 5/5/2019
@author: RuihongQiu
"""
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros
from add_self_loops_partial import add_self_loops_partial


class WeightedGATConv(MessagePassing):
r"""The graph attentional operator from the `"Graph Attention Networks"
<https://arxiv.org/abs/1710.10903>`_ paper
.. math::
\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
\sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
where the attention coefficients :math:`\alpha_{i,j}` are computed as
.. math::
\alpha_{i,j} =
\frac{
\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
[\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
\right)\right)}
{\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
[\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
\right)\right)}.
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
heads (int, optional): Number of multi-head-attentions.
(default: :obj:`1`)
concat (bool, optional): If set to :obj:`False`, the multi-head
attentions are averaged instead of concatenated. (default: :obj:`True`)
negative_slope (float, optional): LeakyReLU angle of the negative
slope. (default: :obj:`0.2`)
dropout (float, optional): Dropout probability of the normalized
attention coefficients which exposes each node to a stochastically
sampled neighborhood during training. (default: :obj:`0`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
"""

    def __init__(self,
                 in_channels,
                 out_channels,
                 heads=1,
                 concat=True,
                 negative_slope=0.2,
                 dropout=0,
                 bias=True,
                 weighted=True):
        super(WeightedGATConv, self).__init__('add')

        self.weighted = weighted
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        # Linear transformation (Theta) and per-head attention vector; the
        # attention vector has one extra slot for the scalar edge weight.
        self.weight = Parameter(
            torch.Tensor(in_channels, heads * out_channels))
        self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels + 1))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_attr=None):
        """"""
        # add_self_loops_partial is expected to add the missing self-loops and
        # to extend edge_attr so that every edge still carries a weight.
        edge_index, edge_attr = add_self_loops_partial(edge_index, edge_attr)
        x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
        return self.propagate(edge_index, x=x, num_nodes=edge_index.max() + 1,
                              edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_index, num_nodes, edge_attr):
        # Compute attention coefficients; if edge weights are given, append the
        # scalar weight of each edge (repeated per head) to the concatenated
        # features before the dot product with the attention vector.
        if edge_attr is not None:
            # alpha = ((torch.cat([x_i, x_j], dim=-1) * self.att) * edge_attr.view(-1, 1, 1)).sum(dim=-1)
            alpha = (torch.cat([x_i, x_j,
                                edge_attr.view(-1, 1).repeat(1, x_i.shape[1]).view(-1, x_i.shape[1], 1)],
                               dim=-1) * self.att).sum(dim=-1)
        else:
            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalize the coefficients over edges that share the same entry in
        # edge_index[0].
        alpha = softmax(alpha, edge_index[0], num_nodes)

        # Sample attention coefficients stochastically.
        if self.training and self.dropout > 0:
            alpha = F.dropout(alpha, p=self.dropout, training=True)

        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        # Concatenate or average the per-head outputs, then add the bias.
        if self.concat is True:
            aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
        else:
            aggr_out = aggr_out.mean(dim=1)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
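

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file: it only illustrates
    # the expected tensor shapes and assumes that add_self_loops_partial.py
    # from this repository is importable. The node features, edges and edge
    # weights below are made up for the example.
    num_nodes, in_channels, out_channels, heads = 4, 8, 16, 2

    x = torch.rand(num_nodes, in_channels)
    # Toy directed edges; edge_index[0] is the index that the attention
    # softmax in message() normalizes over.
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)
    edge_attr = torch.tensor([0.5, 1.0, 2.0, 1.0])  # one scalar weight per edge

    conv = WeightedGATConv(in_channels, out_channels, heads=heads, concat=True)
    out = conv(x, edge_index, edge_attr)

    print(conv)       # e.g. WeightedGATConv(8, 16, heads=2)
    print(out.shape)  # expected: torch.Size([4, 32]) since concat=True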