forked from rwightman/efficientdet-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
export.py
74 lines (59 loc) · 2.53 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import json
import time
import logging
import torch
from pathlib import Path
from effdet import create_model
# Configure timm before any model is built: make its layers exportable and
# turn off its JIT optimizations -- otherwise the activation functions
# cannot be exported to ONNX.
import timm.models.layers.config as timm_layer_config
timm_layer_config.set_exportable(True)
timm_layer_config.set_no_jit(True)
# Command-line interface of the EfficientDet ONNX exporter.
parser = argparse.ArgumentParser(description='PyTorch EfficientDet ONNX Exporter')
parser.add_argument('--model', '-m', metavar='MODEL', default='tf_efficientdet_d1',
                    help='model architecture (default: tf_efficientdet_d1)')
# NOTE(review): --img-size is parsed but export() never reads args.img_size;
# the traced input size is taken from --ONNX_resolution. Wire it up or drop it.
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-o', '--out_file', default='tmp.onnx',
                    help='output file path')
def valid_tensor(s: str) -> list:
    """Parse a ``CxHxW`` resolution string (argparse ``type=`` callable).

    Args:
        s: Resolution string such as ``"3x512x512"``.

    Returns:
        The three dimensions as a list of ints, ``[C, H, W]``.

    Raises:
        argparse.ArgumentTypeError: If ``s`` does not consist of exactly
            three ``x``-separated integer fields, or if any field is zero
            or negative.
    """
    msg = "Not a valid resolution: '{0}' [CxHxW].".format(s)
    fields = s.split('x')
    if len(fields) != 3:
        raise argparse.ArgumentTypeError(msg)
    try:
        dims = [int(v) for v in fields]
    except ValueError:
        # argparse only needs the message; hide the int() parsing traceback.
        raise argparse.ArgumentTypeError(msg) from None
    # A zero or negative dimension cannot describe a real input tensor.
    if any(d <= 0 for d in dims):
        raise argparse.ArgumentTypeError(msg)
    return dims
# ONNX input resolution; valid_tensor turns "CxHxW" into [C, H, W].
# The help text must match the actual default below.
parser.add_argument('-r', '--ONNX_resolution', default="3x512x512", type=valid_tensor,
                    help='ONNX input resolution as CxHxW (default: 3x512x512)')
def export(args):
    """Create the requested EfficientDet model and export it to ONNX.

    Args:
        args: Parsed command-line namespace. Reads ``model``,
            ``checkpoint``, ``ONNX_resolution`` (``[C, H, W]`` ints) and
            ``out_file``.

    Side effects:
        Creates the parent directory of ``args.out_file`` if missing and
        writes the ONNX file there; prints progress to stdout.
    """
    # Create the output directory.
    Path(args.out_file).parent.mkdir(parents=True, exist_ok=True)

    # Create the model.
    # NOTE(review): bench_task='' presumably yields the bare network without a
    # train/predict bench wrapper -- confirm against effdet.create_model.
    bench = create_model(
        args.model,
        bench_task='',
        checkpoint_path=args.checkpoint,
    )
    bench.eval()

    # Dummy input at the requested resolution with batch size 1. Previously
    # this was hard-coded to 3x512x512, silently ignoring --ONNX_resolution.
    dummy_input = torch.randn([1] + list(args.ONNX_resolution))

    # Early sanity check that the model runs at this resolution. The exporter
    # traces the model itself; no_grad avoids building a throwaway autograd graph.
    with torch.no_grad():
        bench(dummy_input)

    # These are our standardized input/output names (required by the runtime).
    input_names = ["input:0"]
    output_names = ["output:0", "output:1"]
    print("Exporting ONNX with input resolution of {} to '{}'".format(args.ONNX_resolution, args.out_file))
    # Use the public exporter API rather than the private torch.onnx._export,
    # and pass input_names so the graph input gets the standardized name.
    torch.onnx.export(
        bench,
        dummy_input,
        args.out_file,
        opset_version=11,
        keep_initializers_as_inputs=True,
        input_names=input_names,
        output_names=output_names,
    )
    print("Saved to {}".format(args.out_file))
def main():
    """Command-line entry point: parse the arguments and run the export."""
    export(parser.parse_args())


if __name__ == "__main__":
    main()