-
Notifications
You must be signed in to change notification settings - Fork 8
/
cos_tr_mod.py
48 lines (39 loc) · 2.28 KB
/
cos_tr_mod.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import ride # isort:skip
from collections import OrderedDict
import continual as co
from datasets import datasets
from models.base import CoModelBase, CoSpatioTemporalBlock
from models.s_tr.s_tr import GcnUnitAttention
class CoSTrMod(
    ride.RideModule,
    ride.TopKAccuracyMetric(1, 3, 5),
    ride.optimizers.SgdOneCycleOptimizer,
    datasets.GraphDatasets,
    CoModelBase,
):
    """Model made of ten stacked ``CoSpatioTemporalBlock`` layers.

    Layers 1-3 rely on the block's default graph convolution, while layers
    4-10 are given ``GcnUnitAttention`` wrapped in ``co.forward_stepping``.
    The remaining layers are defined in ``CoModelBase.on_init_end``.
    """

    def __init__(self, hparams):
        """Assemble ``self.layers`` from the dataset shape and graph.

        Args:
            hparams: Not read in this method; presumably consumed by the
                ride mixins -- TODO confirm.
        """
        # Shape supplied by the dataset:
        # (num_channels, num_frames, num_vertices, num_skeletons)
        num_channels, num_frames, num_vertices, _num_skeletons = self.input_shape
        graph_A = self.graph.A

        def CoGcnUnitAttention(in_channels, out_channels, A, bn_momentum=0.1):
            # Graph-convolution factory handed to the blocks below; closes over
            # the vertex count read from the dataset shape.
            attention = GcnUnitAttention(
                in_channels, out_channels, A, bn_momentum, num_point=num_vertices
            )
            return co.forward_stepping(attention)

        def default_block(c_in, c_out, window, **extra):
            # Block that keeps CoSpatioTemporalBlock's default graph convolution.
            return CoSpatioTemporalBlock(
                c_in, c_out, graph_A, padding=0, window_size=window, **extra
            )

        def attention_block(c_in, c_out, window, **extra):
            # Block that uses the attention-based graph convolution.
            return CoSpatioTemporalBlock(
                c_in,
                c_out,
                graph_A,
                CoGraphConv=CoGcnUnitAttention,
                padding=0,
                window_size=window,
                **extra,
            )

        # Within a stage, every block is given a window 8 frames shorter than
        # the block before it; between stages the remaining window is halved.
        step = 8
        stage1 = num_frames
        # NOTE(review): "/" is true division, so the stage-2 and stage-3 window
        # sizes are floats under Python 3 -- confirm that CoSpatioTemporalBlock
        # accepts non-integer window sizes.
        stage2 = (num_frames - 4 * step) / 2
        stage3 = (stage2 - 3 * step) / 2

        blocks = OrderedDict()
        blocks["layer1"] = default_block(num_channels, 64, stage1, residual=False)
        blocks["layer2"] = default_block(64, 64, stage1 - 1 * step)
        blocks["layer3"] = default_block(64, 64, stage1 - 2 * step)
        blocks["layer4"] = attention_block(64, 64, stage1 - 3 * step)
        blocks["layer5"] = attention_block(64, 128, stage1 - 4 * step, stride=1)
        blocks["layer6"] = attention_block(128, 128, stage2 - 1 * step)
        blocks["layer7"] = attention_block(128, 128, stage2 - 2 * step)
        blocks["layer8"] = attention_block(128, 256, stage2 - 3 * step, stride=1)
        blocks["layer9"] = attention_block(256, 256, stage3 - 1 * step)
        blocks["layer10"] = attention_block(256, 256, stage3 - 2 * step)
        self.layers = co.Sequential(blocks)

        # Other layers defined in CoModelBase.on_init_end
if __name__ == "__main__":  # pragma: no cover
    # Script entry point: wrap the model class in ride.Main and hand control
    # to its argparse() method.
    entry_point = ride.Main(CoSTrMod)
    entry_point.argparse()