-
Notifications
You must be signed in to change notification settings - Fork 42
/
models.py
47 lines (34 loc) · 1.5 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
__author__ = "Stefan Weißenberger and Johannes Gasteiger"
__license__ = "MIT"
from typing import List, Optional

import torch
from torch.nn import Dropout, ModuleList, ReLU

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    """Multi-layer graph convolutional network.

    Stacks ``GCNConv`` layers sized ``in_features -> hidden[0] -> ... ->
    num_classes``, applying ReLU and dropout between layers (but not after
    the last one) and returning per-node log-probabilities.
    """

    def __init__(self,
                 dataset: InMemoryDataset,
                 hidden: Optional[List[int]] = None,
                 dropout: float = 0.5):
        """
        Args:
            dataset: dataset whose node-feature width and class count size
                the first and last convolution layers.
            hidden: widths of the hidden layers; ``None`` (the default) is
                equivalent to the previous default of a single 64-unit layer.
            dropout: dropout probability applied between layers.
        """
        super().__init__()
        # Use a None sentinel instead of a mutable list default argument.
        if hidden is None:
            hidden = [64]
        num_features = [dataset.data.x.shape[1]] + hidden + [dataset.num_classes]
        layers = []
        for in_features, out_features in zip(num_features[:-1], num_features[1:]):
            layers.append(GCNConv(in_features, out_features))
        self.layers = ModuleList(layers)
        # First layer's parameters are kept separate from the rest —
        # presumably so the optimizer can weight-decay only the first layer,
        # as in the Kipf & Welling reference GCN setup; confirm at call site.
        self.reg_params = list(layers[0].parameters())
        self.non_reg_params = [p for layer in layers[1:] for p in layer.parameters()]
        self.dropout = Dropout(p=dropout)
        self.act_fn = ReLU()

    def reset_parameters(self):
        """Re-initialize every convolution layer in place."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, data: Data) -> torch.Tensor:
        """Return log-softmax class scores of shape ``(num_nodes, num_classes)``.

        ``data.edge_attr`` is forwarded as ``edge_weight`` — assumes it is a
        1-D per-edge weight tensor (or None); TODO confirm against callers.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        for i, layer in enumerate(self.layers):
            x = layer(x, edge_index, edge_weight=edge_attr)
            # No activation or dropout after the final layer.
            if i == len(self.layers) - 1:
                break
            x = self.act_fn(x)
            x = self.dropout(x)
        return torch.nn.functional.log_softmax(x, dim=1)