diff --git a/src/torchscript-metadata.json b/src/torchscript-metadata.json index f189d8f374b..1673972d971 100644 --- a/src/torchscript-metadata.json +++ b/src/torchscript-metadata.json @@ -44,6 +44,27 @@ ] } }, + { + "name": "conv2d_relu", + "schema": { + "category": "Layer", + "inputs": [ + { "name": "input" }, + { "name": "packed_weight" } + ], + "attributes": [ + { "name": "stride", "type": "int64[]", "default": 1 }, + { "name": "padding", "type": "int64[]", "default": 0 }, + { "name": "dilation", "type": "int64[]", "default": 1 }, + { "name": "groups", "type": "int64", "default": 1 }, + { "name": "output_scale", "type": "float64" }, + { "name": "output_zero_point", "type": "int64" } + ], + "outputs": [ + { "name": "output" } + ] + } + }, { "name": "max_pool2d", "schema": { diff --git a/src/torchscript.js b/src/torchscript.js index a62697fd7b2..f948541c987 100644 --- a/src/torchscript.js +++ b/src/torchscript.js @@ -1181,6 +1181,12 @@ torchscript.Container = class { } throw new torchscript.Error('Unknown expression type.'); }); + this._functionTable.set('torch.jit._pickle.build_boollist', function(data) { + return data; + }); + this._functionTable.set('torch.jit._pickle.build_doublelist', function(data) { + return data; + }); this._functionTable.set('torch.jit._pickle.build_intlist', function(data) { return data; }); @@ -1852,8 +1858,7 @@ torchscript.GraphContext = class { _nodeExpression(expression, target) { if (expression.type == 'call' && (target.type == 'id' || target.type == 'tuple')) { let name = torchscript.Utility.target(expression.target); - let namespace = 'torch.'; - if (name.startsWith(namespace)) { + if (name.startsWith('torch.') || name.startsWith('ops.quantized')) { let inputs = []; let outputs = []; let args = expression.arguments; @@ -1956,12 +1961,30 @@ torchscript.GraphContext = class { let attributes = []; while (args.length > 0) { let attributeExpression = args[0]; - if (attributeExpression.type == 'list') { - for (let i = 0; i < attributeExpression.value.length; i++) { - attributeExpression.value[i] = this._attributeExpression(attributeExpression.value[i]); + if (this._isCall(attributeExpression, 'int', [ {} ]) || + this._isCall(attributeExpression, 'float', [ {} ])) { + const tensor = this._evaluateExpression(attributeExpression.arguments[0]); + if (tensor && tensor.size && tensor.size.length === 1 && tensor.size[0] === 1 && + tensor.storage && tensor.storage.data) { + const dataView = new DataView(tensor.storage.data.buffer, tensor.storage.byteOffset, tensor.storage.byteLength); + switch (tensor.dataType) { + case 'float32': { + attributes.push(dataView.getFloat32(0, true)); + break; + } + case 'int32': { + attributes.push(dataView.getInt32(0, true)); + break; + } + } + args.shift(); + continue; } } let intExpression = this._attributeExpression(attributeExpression); + if (intExpression.type == 'list' && intExpression.value.every((item) => item.type === 'number')) { + intExpression = intExpression.value.map((item) => item.value); + } if (intExpression) { attributeExpression = intExpression; } @@ -1977,7 +2000,7 @@ torchscript.GraphContext = class { } } this._nodes.push({ - name: name.substring(namespace.length), + name: name.split('.').pop(), attributes: attributes, inputs: inputs, outputs: outputs @@ -2019,7 +2042,7 @@ torchscript.GraphContext = class { _attributeExpression(expression) { if (expression.type == 'id') { if (this._state[expression.value]) { - return this._state[expression.value]; + return this._evaluateExpression(this._state[expression.value]); } } return this._evaluateExpression(expression); @@ -2074,6 +2097,11 @@ torchscript.GraphContext = class { this._state[target.value] = expression; return true; } + // _0 = torch.len(...) + if (this._isCall(expression, 'torch.len', [ {} ])) { + this._state[target.value] = expression; + return true; + } // _output_size = torch.list_with_default([7, 7], torch.size(x0)) if (this._isCall(expression, 'torch.list_with_default', [ {}, {} ])) { this._state[target.value] = expression; @@ -2108,8 +2136,10 @@ torchscript.GraphContext = class { return true; } const valueExpression = this._evaluateExpression(expression); - if (valueExpression.type === 'number' || this._isBooleanLiteral(valueExpression)) { - this._state[target.value] = expression; + if (valueExpression.type === 'number' || + this._isBooleanLiteral(valueExpression) || + (valueExpression.type === 'list' && valueExpression.value.every((item) => item.type == 'number'))) { + this._state[target.value] = valueExpression; return true; } // _aux = None @@ -2143,14 +2173,19 @@ torchscript.GraphContext = class { } } } - /* - if (target.type === 'tuple' && target.value.every((item) => item.type === 'id')) { + if (target.type === 'tuple' && + target.value.every((item) => item.type === 'id')) { // _30, _31, = _24 if (expression.type === 'id' && this._state[expression.value]) { - debugger; + const valueExpression = this._state[expression.value]; + if (valueExpression.type === 'list' && target.value.length === valueExpression.value.length) { + for (let i = 0; i < target.value.length; i++) { + this._state[target.value[i].value] = valueExpression.value[i]; + } + return true; + } } } - */ } return false; } @@ -2355,11 +2390,19 @@ torchscript.GraphContext = class { if (typeof value === 'number') { return { type: 'number', value: value }; } + if (Array.isArray(value) && value.every((item) => typeof item === 'number')) { + const array = value; + return { type: 'list', value: array.map((item) => { return { type: 'number', value: item }; }) }; + } if (torchscript.Utility.isTensor(value)) { return value; } } } + if (expression.type === 'list') { + const value = expression.value.map((item) => this._evaluateExpression(item)); + return { type: 'list', value: value }; + } // int(x) if (this._isCall(expression, 'int', [ {} ])) { return this._evaluateExpression(expression.arguments[0]); diff --git a/test/models.json b/test/models.json index 4f0f8af6546..60608f63d6d 100644 --- a/test/models.json +++ b/test/models.json @@ -5143,75 +5143,74 @@ }, { "type": "torchscript", - "target": "mnist_linear_torchscript_1.pt", - "source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_1.pt]", + "target": "inception_v3.pt", + "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", - "link": "https://github.com/lutzroeder/netron/issues/281" + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "torchscript", - "target": "mnist_linear_torchscript_2.pt", - "source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_2.pt]", + "target": "inception_v3_traced.pt", + "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", - "link": "https://github.com/lutzroeder/netron/issues/281" + "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "torchscript", - "target": "mobilenet_quantized_scripted_925.pt", - "source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/mobilenet_quantized_scripted_925.pt?raw=true", - "error": "Unknown statement '{\"location\":\" at 374:8\",\"type\":\"=\",\"target\":{\"type\":\"id\",\"value\":\"_24\",\"keyword\":false},\"expression\":{\"location\":\" at 374:14\",\"type\":\".\",\"target\":{\"type\":\"id\",\"value\":\"_21\",\"keyword\":false},\"member\":{\"type\":\"id\",\"value\":\"stride\",\"keyword\":false}}}' in 'mobilenet_quantized_scripted_925.pt'.", + "target": "junction_mlp_vehicle_model.pt", + "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/junction_mlp_vehicle_model.pt?raw=true", "format": "TorchScript v1", - "link": "https://github.com/pytorch/android-demo-app" + "link": "https://github.com/ApolloAuto/apollo" }, { "type": "torchscript", - "target": "model-reddit16-f140225004_2.pt1", - "source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/model-reddit16-f140225004_2.pt1?raw=true", - "error": "Unknown statement '{\"location\":\" at 132:17\",\"type\":\"=\",\"target\":{\"location\":\" at 132:11\",\"type\":\"tuple\",\"value\":[{\"type\":\"id\",\"value\":\"_32\",\"keyword\":false},{\"type\":\"id\",\"value\":\"_33\",\"keyword\":false}]},\"expression\":{\"type\":\"id\",\"value\":\"hx\",\"keyword\":false}}' in 'model-reddit16-f140225004_2.pt1'.", + "target": "lane_scanning_vehicle_model.pt", + "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/lane_scanning_vehicle_model.pt?raw=true", "format": "TorchScript v1", - "link": "https://github.com/pytorch/android-demo-app" + "link": "https://github.com/ApolloAuto/apollo" }, { "type": "torchscript", - "target": "inception_v3.pt", + "target": "mobilenet_v2.pt", "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "torchscript", - "target": "inception_v3_traced.pt", - "script": [ "${root}/tools/pytorch", "sync install zoo" ], + "target": "mnist_linear_torchscript_1.pt", + "source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_1.pt]", "format": "TorchScript v1", - "link": "https://pytorch.org/docs/stable/torchvision/models.html" + "link": "https://github.com/lutzroeder/netron/issues/281" }, { "type": "torchscript", - "target": "junction_mlp_vehicle_model.pt", - "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/junction_mlp_vehicle_model.pt?raw=true", + "target": "mnist_linear_torchscript_2.pt", + "source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_2.pt]", "format": "TorchScript v1", - "link": "https://github.com/ApolloAuto/apollo" + "link": "https://github.com/lutzroeder/netron/issues/281" }, { "type": "torchscript", - "target": "lane_scanning_vehicle_model.pt", - "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/lane_scanning_vehicle_model.pt?raw=true", + "target": "mobilenet_quantized_scripted_925.pt", + "source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/mobilenet_quantized_scripted_925.pt?raw=true", "format": "TorchScript v1", - "link": "https://github.com/ApolloAuto/apollo" + "link": "https://github.com/pytorch/android-demo-app" }, { "type": "torchscript", - "target": "mobilenet_v2.pt", + "target": "mobilenet_v2_traced.pt", "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { "type": "torchscript", - "target": "mobilenet_v2_traced.pt", - "script": [ "${root}/tools/pytorch", "sync install zoo" ], + "target": "model-reddit16-f140225004_2.pt1", + "source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/model-reddit16-f140225004_2.pt1?raw=true", + "error": "Unknown statement '{\"location\":\" at 132:17\",\"type\":\"=\",\"target\":{\"location\":\" at 132:11\",\"type\":\"tuple\",\"value\":[{\"type\":\"id\",\"value\":\"_32\",\"keyword\":false},{\"type\":\"id\",\"value\":\"_33\",\"keyword\":false}]},\"expression\":{\"type\":\"id\",\"value\":\"hx\",\"keyword\":false}}' in 'model-reddit16-f140225004_2.pt1'.", "format": "TorchScript v1", - "link": "https://pytorch.org/docs/stable/torchvision/models.html" + "link": "https://github.com/pytorch/android-demo-app" }, { "type": "torchscript",