-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
197 lines (153 loc) · 6.88 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import torch
import torch.nn.functional as F
from torch.nn import Linear as Lin
from torch_geometric.utils import scatter
from torch_geometric.nn import (
MLP,
PointTransformerConv,
fps,
knn,
knn_graph,
knn_interpolate
)
class TransformerBlock(torch.nn.Module):
    """Point Transformer convolution sandwiched between two linear layers.

    Features are first projected by ``lin_in`` (width kept at
    ``in_channels``) and passed through ReLU. They are then mixed over the
    neighbourhood described by ``edge_index`` with a ``PointTransformerConv``
    (``in_channels`` -> ``out_channels``). Finally they are projected by
    ``lin_out`` and passed through ReLU again.

    Args:
        in_channels: Width of the incoming per-point features.
        out_channels: Width of the returned per-point features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden = 64  # hidden width shared by both small MLPs below

        self.lin_in = Lin(in_channels, in_channels)
        self.lin_out = Lin(out_channels, out_channels)

        # Position network handed to PointTransformerConv. Its first layer
        # takes 3 inputs, so ``pos`` is presumably 3-D coordinates.
        self.pos_nn = MLP([3, hidden, out_channels],
                          norm=None, plain_last=False)
        # Attention network handed to PointTransformerConv.
        self.attn_nn = MLP([out_channels, hidden, out_channels],
                           norm=None, plain_last=False)

        self.transformer = PointTransformerConv(
            in_channels, out_channels,
            pos_nn=self.pos_nn, attn_nn=self.attn_nn,
        )

    def forward(self, x, pos, edge_index):
        """Transform the features of every point.

        Args:
            x: Per-point features of width ``in_channels``.
            pos: Per-point coordinates, forwarded to the convolution.
            edge_index: Neighbourhood graph, e.g. from ``knn_graph``.

        Returns:
            Per-point features of width ``out_channels`` after the final
            ReLU.
        """
        projected = torch.relu(self.lin_in(x))
        mixed = self.transformer(projected, pos, edge_index)
        return torch.relu(self.lin_out(mixed))
class TransitionDown(torch.nn.Module):
    """Subsample a point cloud and widen its features.

    Points are subsampled with farthest point sampling at the given
    ``ratio``. Each sampled point then max-pools the MLP-transformed
    features of its ``k`` nearest neighbours in the original cloud. This
    reduces cardinality while changing the feature width from
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Width of the incoming per-point features.
        out_channels: Width of the pooled features.
        ratio: Sampling ratio passed to ``fps``.
        k: Number of neighbours pooled into each sampled point.
    """

    def __init__(self, in_channels, out_channels, ratio=0.25, k=16):
        super().__init__()
        self.k = k
        self.ratio = ratio
        self.mlp = MLP([in_channels, out_channels], plain_last=False)

    def forward(self, x, pos, batch):
        """Pool ``x`` onto a farthest-point subsample of ``pos``.

        Args:
            x: Per-point features of width ``in_channels``.
            pos: Per-point coordinates.
            batch: Per-point example index, or ``None`` for a single cloud.

        Returns:
            ``(features, positions, batch)`` for the sampled points. The
            returned ``batch`` is ``None`` when the input ``batch`` was
            ``None``.
        """
        centres = fps(pos, ratio=self.ratio, batch=batch)
        centre_pos = pos[centres]
        centre_batch = None if batch is None else batch[centres]

        # The centres are drawn from ``pos`` itself, so every centre finds
        # itself among its own k nearest neighbours (a "self loop").
        pairs = knn(pos, centre_pos, k=self.k, batch_x=batch,
                    batch_y=centre_batch)
        # Row 0 indexes the centres, row 1 indexes the original points.
        owner, member = pairs[0], pairs[1]

        feats = self.mlp(x)
        pooled = scatter(feats[member], owner, dim=0,
                         dim_size=centres.size(0), reduce='max')
        return pooled, centre_pos, centre_batch
class TransitionUp(torch.nn.Module):
    """Bring coarse features back onto a denser point set.

    Coarse features are mapped by ``mlp_sub`` from ``in_channels`` to
    ``out_channels``. They are then interpolated from the 3 nearest coarse
    points onto every dense point. The result is added to the dense skip
    features after those pass through ``mlp``.

    Args:
        in_channels: Width of the coarse (low-resolution) features.
        out_channels: Width of the dense skip features and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp_sub = MLP([in_channels, out_channels], plain_last=False)
        self.mlp = MLP([out_channels, out_channels], plain_last=False)

    def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):
        """Fuse interpolated coarse features with dense skip features.

        Args:
            x: Dense skip features of width ``out_channels``.
            x_sub: Coarse features of width ``in_channels``.
            pos: Dense point coordinates (interpolation targets).
            pos_sub: Coarse point coordinates (interpolation sources).
            batch: Example index per dense point, or ``None``.
            batch_sub: Example index per coarse point, or ``None``.

        Returns:
            Dense features of width ``out_channels``.
        """
        coarse = self.mlp_sub(x_sub)
        upsampled = knn_interpolate(coarse, pos_sub, pos, k=3,
                                    batch_x=batch_sub, batch_y=batch)
        skip = self.mlp(x)
        return skip + upsampled
class BodySeg(torch.nn.Module):
    """U-Net-style Point Transformer producing per-point class log-probs.

    The model has three parts:

    * The encoder alternates ``TransitionDown`` and ``TransformerBlock``. It
      subsamples the cloud and widens the features along ``dim_model``.
    * A summit MLP and transformer block operate at the coarsest level.
    * The decoder alternates ``TransitionUp`` and ``TransformerBlock``. It
      interpolates back level by level, fusing the skip features saved at
      each resolution.

    ``mlp_output`` then scores every input point.

    Args:
        in_channels: Width of the per-point input features. ``0`` means no
            features will be given; a constant dummy feature of width 1 is
            then used.
        out_channels: Number of output classes.
        dim_model: Feature width per resolution level. ``dim_model[0]`` is
            the full-resolution width, and each further entry adds one
            down/up stage.
        k: Neighbourhood size for every k-NN graph and k-NN pooling.
    """

    def __init__(self, in_channels, out_channels, dim_model, k=16):
        super().__init__()
        self.k = k

        # A dummy feature is created in ``forward`` when none is given, so
        # the input MLP needs at least one input channel.
        in_channels = max(in_channels, 1)
        # Remembered so the dummy feature matches ``mlp_input``'s width.
        self.in_channels = in_channels

        # first block
        self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)
        self.transformer_input = TransformerBlock(
            in_channels=dim_model[0],
            out_channels=dim_model[0],
        )

        # backbone layers; entry i links level i (dense) and i + 1 (coarse)
        self.transformers_up = torch.nn.ModuleList()
        self.transformers_down = torch.nn.ModuleList()
        self.transition_up = torch.nn.ModuleList()
        self.transition_down = torch.nn.ModuleList()

        for dense, coarse in zip(dim_model[:-1], dim_model[1:]):
            # Transition Down block followed by a Point Transformer block
            self.transition_down.append(
                TransitionDown(in_channels=dense, out_channels=coarse,
                               k=self.k))
            self.transformers_down.append(
                TransformerBlock(in_channels=coarse, out_channels=coarse))
            # Transition Up block followed by a Point Transformer block
            self.transition_up.append(
                TransitionUp(in_channels=coarse, out_channels=dense))
            self.transformers_up.append(
                TransformerBlock(in_channels=dense, out_channels=dense))

        # summit layers
        self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], norm=None,
                              plain_last=False)
        self.transformer_summit = TransformerBlock(
            in_channels=dim_model[-1],
            out_channels=dim_model[-1],
        )

        # class score computation
        self.mlp_output = MLP([dim_model[0], 64, out_channels], norm=None)

    def forward(self, x, pos, batch=None):
        """Score every input point.

        Args:
            x: Per-point features of width ``in_channels``, or ``None`` to
                use a constant dummy feature.
            pos: Per-point coordinates (presumably ``(N, 3)``, matching the
                3-input position MLP in ``TransformerBlock``).
            batch: Per-point example index, or ``None`` for a single cloud.

        Returns:
            ``log_softmax`` over the last dimension of ``mlp_output``. There
            is one row per input point, with ``out_channels`` columns.
        """
        if x is None:
            # Build the dummy on pos's device. ``pos.get_device()`` is -1
            # for CPU tensors, which ``.to()`` rejects. The width must match
            # the width ``mlp_input`` was built for.
            x = torch.ones((pos.size(0), self.in_channels),
                           device=pos.device)

        # Skip-connection stacks: entry i holds the outputs of level i.
        out_x = []
        out_pos = []
        out_batch = []

        # first block
        x = self.mlp_input(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_input(x, pos, edge_index)
        out_x.append(x)
        out_pos.append(pos)
        out_batch.append(batch)

        # backbone down: reduce cardinality and augment dimensionality
        for down, transformer in zip(self.transition_down,
                                     self.transformers_down):
            x, pos, batch = down(x, pos, batch=batch)
            edge_index = knn_graph(pos, k=self.k, batch=batch)
            x = transformer(x, pos, edge_index)
            out_x.append(x)
            out_pos.append(pos)
            out_batch.append(batch)

        # summit
        x = self.mlp_summit(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x = self.transformer_summit(x, pos, edge_index)

        # backbone up: augment cardinality and reduce dimensionality,
        # walking from the coarsest stage back to full resolution.
        for i in reversed(range(len(self.transition_up))):
            x = self.transition_up[i](x=out_x[i], x_sub=x,
                                      pos=out_pos[i],
                                      pos_sub=out_pos[i + 1],
                                      batch=out_batch[i],
                                      batch_sub=out_batch[i + 1])
            edge_index = knn_graph(out_pos[i], k=self.k, batch=out_batch[i])
            x = self.transformers_up[i](x, out_pos[i], edge_index)

        # class score
        out = self.mlp_output(x)
        return F.log_softmax(out, dim=-1)