-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_add.py
34 lines (26 loc) · 882 Bytes
/
test_add.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import sys

import torch
import torch.jit
import torch.nn as nn
# Load the compiled extension so its operators are registered under
# torch.ops.my_ops (used by MyAddModule and the checks below). The path is
# relative to the current working directory. Choose the shared-library suffix
# for the host platform instead of hard-coding the macOS-only ".dylib" name.
# NOTE(review): assumes the build emits libmyadd.so on non-macOS hosts (the
# usual Linux naming) — confirm against the build script.
torch.ops.load_library(
    "build/libmyadd.dylib" if sys.platform == "darwin" else "build/libmyadd.so"
)
class MyAddModule(nn.Module):
    """Thin ``nn.Module`` wrapper around the custom ``my_ops::my_add`` operator.

    Wrapping the op in a module lets it be composed into containers such as
    ``nn.Sequential`` and compiled with ``torch.jit.script``. The module
    defines no parameters of its own.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``torch.ops.my_ops.my_add`` to *input* and return the result.

        The operator must already be registered (via ``torch.ops.load_library``)
        before this is called.
        """
        return torch.ops.my_ops.my_add(input)
# Exercise the custom op three ways: called directly through torch.ops,
# wrapped in MyAddModule, and as the last stage of an nn.Sequential pipeline.
# `input` and `model` stay module-level names because the export steps later
# in this script read them.
input = torch.ones(1, 10)

# 1) Raw operator call.
print("op:", torch.ops.my_ops.my_add(input))

# 2) Wrapped in a standalone module.
model = MyAddModule()
print("model:", model(input))

# 3) Chained after a Linear layer; `model` is rebound to this pipeline.
model = nn.Sequential(nn.Linear(10, 10), MyAddModule())
print("sequential:", model(input))
# Now convert to ONNX
def sym_add(g, input):
    """ONNX symbolic function for ``my_ops::my_add``.

    Registered with ``torch.onnx.register_custom_op_symbolic`` so the exporter
    can translate ``my_add`` calls. It emits a node in the custom ``my_ops``
    domain with *input* as its only operand. The node's output is given the
    same type as *input*.

    Args:
        g: Graph context supplied by the ONNX exporter.
        input: Graph value feeding the op.

    Returns:
        The emitted node's output value, as returned by ``setType``.
    """
    return g.op("my_ops::my_add", input).setType(input.type())
# Teach the ONNX exporter how to emit my_ops::my_add, then export the
# Sequential pipeline with named graph input/output.
# NOTE(review): register_custom_op_symbolic targets the TorchScript-based
# exporter; confirm it is the one used by the installed torch.onnx.export.
torch.onnx.register_custom_op_symbolic(
    symbolic_name="my_ops::my_add",
    symbolic_fn=sym_add,
    opset_version=1,
)
torch.onnx.export(
    model,
    input,
    'build/model.onnx',
    input_names=["input"],
    output_names=["output"],
)
# Compile the same Sequential pipeline with TorchScript and serialize it to
# build/model.torchscript.
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'build/model.torchscript')