diff --git a/test/backend.py b/test/backend.py index 3537a17b30..d9c0ea9bcc 100755 --- a/test/backend.py +++ b/test/backend.py @@ -19,7 +19,7 @@ def _test_onnx(): model = onnx.load(file) netron.serve(None, model) -def _test_onnx_list(): +def _test_onnx_iterate(): folder = os.path.join(test_data_dir, 'onnx') for item in os.listdir(folder): file = os.path.join(folder, item) @@ -32,29 +32,67 @@ def _test_onnx_list(): address = netron.serve(file, model, verbosity='quiet') netron.stop(address) -def _test_torchscript(): +def _test_torchscript_transformer(): + torch = __import__('torch') + model = torch.nn.Transformer(nhead=16, num_encoder_layers=12) + trace = torch.jit.trace(model, (torch.rand(10, 32, 512), torch.rand(20, 32, 512))) + torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access + netron.serve('transformer', trace) + +def _test_torchscript_resnet34(): torch = __import__('torch') torchvision = __import__('torchvision') + model = torchvision.models.resnet34() # model = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT) # model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.DEFAULT) - model = torchvision.models.resnet34() state_dict = torch.load(os.path.join(test_data_dir, 'pytorch', 'resnet34-333f7ec4.pth')) model.load_state_dict(state_dict) - args = torch.zeros([1, 3, 224, 224]) - trace = torch.jit.trace(model, args, strict=True) - # graph, _ = torch.jit._get_trace_graph(model, args) # pylint: disable=protected-access - # torch.onnx._optimize_trace(graph, torch.onnx.OperatorExportTypes.ONNX) - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'fasterrcnn_resnet50_fpn.pt')) - # torch.backends.quantized.engine = 'qnnpack' - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'd2go.pt')) - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'mobilenetv2-quant_full-nnapi.pt')) - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'inception_v3_traced.pt')) - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'netron_issue_920.pt')) - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'bert-base-uncased.pt')) - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'UNet.pt')) - torch._C._jit_pass_inline(trace.graph) + trace = torch.jit.trace(model, torch.zeros([1, 3, 224, 224]), strict=True) + torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access netron.serve('resnet34', trace) +def _test_torchscript_quantized(): + torch = __import__('torch') + __import__('torchvision') + torch.backends.quantized.engine = 'qnnpack' + trace = torch.jit.load(os.path.join(test_data_dir, 'pytorch', 'd2go.pt')) + torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access + netron.serve('d2go', trace) + +def _test_torchscript_inception_v3(): + torch = __import__('torch') + trace = torch.jit.load(os.path.join(test_data_dir, 'pytorch', 'inception_v3_traced.pt')) + torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access + netron.serve('inception_v3', trace) + +def _test_torchscript_scalar(): + torch = __import__('torch') + trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'netron_issue_920.pt')) + # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'UNet.pt')) + torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access + netron.serve('inception_v3', trace) + +def _test_torchscript_tuple(): + torch = __import__('torch') + __import__('torchvision') + trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'fasterrcnn_resnet50_fpn.pt')) + torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access + netron.serve('inception_v3', trace) + +def _test_torchscript_nnapi(): + torch = __import__('torch') + trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'mobilenetv2-quant_full-nnapi.pt')) + torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access + netron.serve('inception_v3', trace) + # _test_onnx() -_test_torchscript() -# _test_onnx_list() +# _test_onnx_iterate() + +# _test_torchscript() +# _test_torchscript_quantized() +# _test_torchscript_resnet34() +# _test_torchscript_inception_v3() +# _test_torchscript_scalar() +# _test_torchscript_tuple() +# _test_torchscript_nnapi() +_test_torchscript_transformer()