From b83cad31b72c962886f73e66e303d8f5fab3e951 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 3 Feb 2022 06:42:40 -0800 Subject: [PATCH] Add TorchScript test file (#842) --- source/python.js | 7 +++++++ source/pytorch.js | 5 +++++ test/models.json | 10 ++++++++++ 3 files changed, 22 insertions(+) diff --git a/source/python.js b/source/python.js index ecd7a75c2f..acb1419e2b 100644 --- a/source/python.js +++ b/source/python.js @@ -3022,7 +3022,9 @@ python.Execution = class { return; } else if (target.type === 'tuple') { + context.target.push(target.value); const value = this.expression(expression.expression, context); + context.target.pop(); if (target.value.every((item) => item.type === 'id')) { if (target.value.length < value.length) { throw new python.Error('ValueError: too many values to unpack (expected ' + target.value.length + ', actual ' + value.length + ').'); @@ -3287,6 +3289,11 @@ python.Execution.Context = class { } return undefined; } + + get target() { + this._target = this._target || []; + return this._target; + } }; python.Utility = class { diff --git a/source/pytorch.js b/source/pytorch.js index 0247191068..547a4c7a4e 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -3146,6 +3146,11 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { case 'torch.unbind': count = args[0].__tuple__ || count; break; + case 'torch.split': + if (context.target.length > 0) { + count = context.target[context.target.length - 1].length; + } + break; } const tensors = []; const outputs = []; diff --git a/test/models.json b/test/models.json index 6870f1a75e..6230028cc7 100644 --- a/test/models.json +++ b/test/models.json @@ -5057,6 +5057,16 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/827" }, + { + "type": "pytorch", + "target": "yolo4_tiny.pt", + "source": "https://github.com/lutzroeder/netron/files/7995416/yolo4_tiny.pt.zip[yolo4_tiny.pt]", + "format": "TorchScript v1.5", + "link": "https://github.com/lutzroeder/netron/issues/842" + }, + + + { "type": "rknn", "target": "autopilot.rknn",