From 96a5eeef45b8fb78fd2eae2e32ee7fa40389e0ca Mon Sep 17 00:00:00 2001
From: Lutz Roeder
Date: Wed, 8 Dec 2021 12:38:02 -0500
Subject: [PATCH] Add TorchScript test file (#842) (#851)

---
 source/python.js  |  2 +-
 source/pytorch.js | 40 ++++++++++++++++++++++++++++++++++++++--
 test/models.json  | 11 +++++++++--
 3 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/source/python.js b/source/python.js
index 2a3d3648541..4d0c5f7c5b1 100644
--- a/source/python.js
+++ b/source/python.js
@@ -2800,7 +2800,7 @@ python.Execution = class {
                 break;
             }
             case 'var': {
-                context.set(statement.name, undefined);
+                context.set(statement.name, statement.initializer ? this.expression(statement.initializer, context) : undefined);
                 break;
             }
             case '=': {
diff --git a/source/pytorch.js b/source/pytorch.js
index c6b5e06e2e4..bfb3e538e5d 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -1604,6 +1604,10 @@ pytorch.Execution = class extends python.Execution {
             }
             throw new pytorch.Error("Unknown 'torch.ge' expression type.");
         });
+        this.registerFunction('torch.is_floating_point', function(tensor) {
+            const type = tensor.dtype.scalar_type();
+            return (type === 5 || type === 6 || type === 7);
+        });
         this.registerFunction('torch.jit._pickle.build_boollist', function(data) {
             return data;
         });
@@ -1748,7 +1752,7 @@ pytorch.Execution = class extends python.Execution {
                 throw new pytorch.Error('Slicing only supports step=1');
             }
             start = Math.max(0, start >= 0 ? start : l.length + start);
-            end = Math.min(l.length, end);
+            end = Math.min(l.length, end || Number.MAX_SAFE_INTEGER);
             return l.slice(start, end);
         });
         this.registerFunction('torch.sub', function(left, right) {
@@ -2973,6 +2977,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                             }
                             case 'torch.mean':
                             case 'torch.mul':
+                            case 'torch.div':
                             case 'torch.batch_norm':
                             case 'torch.gelu':
                             case 'torch.relu':
@@ -2983,7 +2988,8 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                 }
                                 break;
                             }
-                            case 'torch.add': {
+                            case 'torch.add':
+                            case 'torch.sub': {
                                 const input = this.expression(args[0], context);
                                 if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                     parameter.resize_(input.size());
@@ -2996,6 +3002,13 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                 }
                                 break;
                             }
+                            case 'torch.select': {
+                                const input = this.expression(args[0], context);
+                                if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
+                                    parameter.resize_(Array(input.size().length - 1).fill(NaN));
+                                }
+                                break;
+                            }
                             case 'torch.layer_norm': {
                                 const input = this.expression(args[0], context);
                                 const normalized_shape = this.expression(args[1], context);
@@ -3176,6 +3189,29 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                             tensor.resize_(Array(number).fill(NaN));
                         }
                     }
+                    // val = torch.slice(torch.size(img), -2)
+                    // if torch.eq(torch.len(val), 2):
+                    //   pass
+                    // else:
+                    //   ops.prim.RaiseException("AssertionError: ")
+                    if (assign.type === '=' &&
+                        condition.type === 'if' &&
+                        pytorch.Utility.isCall(assign.expression, 'torch.slice', 2) &&
+                        pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.size', 1) &&
+                        pytorch.Utility.isCall(condition.condition, 'torch.eq', 2) &&
+                        pytorch.Utility.isCall(condition.condition.arguments[0], 'torch.len', 1) &&
+                        pytorch.Utility.isEqual(condition.condition.arguments[0].arguments[0], assign.target) &&
+                        condition.else.statements.length == 1 &&
+                        pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
+                        const tensor = this.expression(assign.expression.arguments[0].arguments[0], context);
+                        if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
+                            const start = this.expression(assign.expression.arguments[1], context);
+                            const value = this.expression(condition.condition.arguments[1], context);
+                            if (Number.isInteger(start) && Number.isInteger(value)) {
+                                tensor.resize_(Array(value - start).fill(NaN));
+                            }
+                        }
+                    }
                 }
                 if (statements.length > 1) {
                     const size = statements[0];
diff --git a/test/models.json b/test/models.json
index 7b2cd363a18..a6986917408 100644
--- a/test/models.json
+++ b/test/models.json
@@ -4326,8 +4326,8 @@
     {
         "type": "pytorch",
         "target": "fasterrcnn_resnet50_fpn.pt",
-        "source": "https://github.com/lutzroeder/netron/files/6040364/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]",
-        "error": "Unsupported function 'torch.full' in 'fasterrcnn_resnet50_fpn.pt'.",
+        "source": "https://github.com/lutzroeder/netron/files/7677467/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]",
+        "error": "Unknown torch.add expression type in 'fasterrcnn_resnet50_fpn.pt'.",
         "link": "https://github.com/lutzroeder/netron/issues/689"
     },
     {
@@ -4859,6 +4859,13 @@
         "source": "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
         "format": "PyTorch v0.1.1"
     },
+    {
+        "type": "pytorch",
+        "target": "ssdlite320_mobilenet_v3_large.pt",
+        "source": "https://github.com/lutzroeder/netron/files/7677468/ssdlite320_mobilenet_v3_large.pt.zip[ssdlite320_mobilenet_v3_large.pt]",
+        "error": "l.slice is not a function in 'ssdlite320_mobilenet_v3_large.pt'.",
+        "link": "https://github.com/lutzroeder/netron/issues/842"
+    },
     {
         "type": "pytorch",
         "target": "superpoint_v1.pth",