Skip to content

Commit

Permalink
Update to TorchScript 1.3 prototype (lutzroeder#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder authored and Tee Jung committed Nov 7, 2019
1 parent c2fe7a1 commit 561b174
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 53 deletions.
21 changes: 21 additions & 0 deletions src/torchscript-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,27 @@
]
}
},
{
"name": "conv2d_relu",
"schema": {
"category": "Layer",
"inputs": [
{ "name": "input" },
{ "name": "packed_weight" }
],
"attributes": [
{ "name": "stride", "type": "int64[]", "default": 1 },
{ "name": "padding", "type": "int64[]", "default": 0 },
{ "name": "dilation", "type": "int64[]", "default": 1 },
{ "name": "groups", "type": "int64", "default": 1 },
{ "name": "output_scale", "type": "float64" },
{ "name": "output_zero_point", "type": "int64" }
],
"outputs": [
{ "name": "output" }
]
}
},
{
"name": "max_pool2d",
"schema": {
Expand Down
118 changes: 93 additions & 25 deletions src/torchscript.js
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,12 @@ torchscript.Container = class {
}
throw new torchscript.Error('Unknown expression type.');
});
this._functionTable.set('torch.jit._pickle.build_boollist', function(data) {
return data;
});
this._functionTable.set('torch.jit._pickle.build_doublelist', function(data) {
return data;
});
this._functionTable.set('torch.jit._pickle.build_intlist', function(data) {
return data;
});
Expand Down Expand Up @@ -1852,8 +1858,7 @@ torchscript.GraphContext = class {
_nodeExpression(expression, target) {
if (expression.type == 'call' && (target.type == 'id' || target.type == 'tuple')) {
let name = torchscript.Utility.target(expression.target);
let namespace = 'torch.';
if (name.startsWith(namespace)) {
if (name.startsWith('torch.') || name.startsWith('ops.quantized')) {
let inputs = [];
let outputs = [];
let args = expression.arguments;
Expand Down Expand Up @@ -1956,12 +1961,30 @@ torchscript.GraphContext = class {
let attributes = [];
while (args.length > 0) {
let attributeExpression = args[0];
if (attributeExpression.type == 'list') {
for (let i = 0; i < attributeExpression.value.length; i++) {
attributeExpression.value[i] = this._attributeExpression(attributeExpression.value[i]);
if (this._isCall(attributeExpression, 'int', [ {} ]) ||
this._isCall(attributeExpression, 'float', [ {} ])) {
const tensor = this._evaluateExpression(attributeExpression.arguments[0]);
if (tensor && tensor.size && tensor.size.length === 1 && tensor.size[0] === 1 &&
tensor.storage && tensor.storage.data) {
const dataView = new DataView(tensor.storage.data.buffer, tensor.storage.byteOffset, tensor.storage.byteLength);
switch (tensor.dataType) {
case 'float32': {
attributes.push(dataView.getFloat32(0, true));
break;
}
case 'int32': {
attributes.push(dataView.getInt32(0, true));
break;
}
}
args.shift();
continue;
}
}
let intExpression = this._attributeExpression(attributeExpression);
if (intExpression.type == 'list' && intExpression.value.every((item) => item.type === 'number')) {
intExpression = intExpression.value.map((item) => item.value);
}
if (intExpression) {
attributeExpression = intExpression;
}
Expand All @@ -1977,7 +2000,7 @@ torchscript.GraphContext = class {
}
}
this._nodes.push({
name: name.substring(namespace.length),
name: name.split('.').pop(),
attributes: attributes,
inputs: inputs,
outputs: outputs
Expand Down Expand Up @@ -2019,7 +2042,7 @@ torchscript.GraphContext = class {
_attributeExpression(expression) {
if (expression.type == 'id') {
if (this._state[expression.value]) {
return this._state[expression.value];
return this._evaluateExpression(this._state[expression.value]);
}
}
return this._evaluateExpression(expression);
Expand All @@ -2045,6 +2068,7 @@ torchscript.GraphContext = class {
}
}
}

// _stride_3 = torch._unwrap_optional(_3)
// _stride_3 = ops.prim.unchecked_unwrap_optional(_127)
if (this._isCall(expression, 'torch._unwrap_optional', [ {} ]) ||
Expand Down Expand Up @@ -2074,6 +2098,11 @@ torchscript.GraphContext = class {
this._state[target.value] = expression;
return true;
}
// _0 = torch.len(...)
if (this._isCall(expression, 'torch.len', [ {} ])) {
this._state[target.value] = expression;
return true;
}
// _output_size = torch.list_with_default([7, 7], torch.size(x0))
if (this._isCall(expression, 'torch.list_with_default', [ {}, {} ])) {
this._state[target.value] = expression;
Expand Down Expand Up @@ -2102,20 +2131,20 @@ torchscript.GraphContext = class {
this._state[target.value] = expression;
return true;
}
// _6 = [torch.mul(self.lstm_depth, 2), torch.size(token_emb, 0), self.lstm_width]
if (expression.type === '[]') {
this._state[target.value] = expression;
return true;
}
const valueExpression = this._evaluateExpression(expression);
if (valueExpression.type === 'number' || this._isBooleanLiteral(valueExpression)) {
this._state[target.value] = expression;
if (valueExpression.type === 'number' ||
this._isBooleanLiteral(valueExpression) ||
valueExpression.type === 'tuple' ||
(valueExpression.type === 'list' && valueExpression.value.every((item) => item.type == 'number'))) {
this._state[target.value] = valueExpression;
return true;
}
// _aux = None
if (expression.type === 'id' && expression.value === 'None') {
this._state[target.value] = expression;
return true;
if (expression.type === 'id') {
// _aux = None
if (expression.value === 'None') {
this._state[target.value] = expression;
return true;
}
}
// _0 = <boolean expression>
const booleanExpression = this._evaluateBooleanExpression(expression);
Expand Down Expand Up @@ -2143,14 +2172,20 @@ torchscript.GraphContext = class {
}
}
}
/*
if (target.type === 'tuple' && target.value.every((item) => item.type === 'id')) {
if (target.type === 'tuple' &&
target.value.every((item) => item.type === 'id')) {
// _30, _31, = _24
if (expression.type === 'id' && this._state[expression.value]) {
debugger;
const valueExpression = this._state[expression.value];
if ((valueExpression.type === 'list' || valueExpression.type === 'tuple') &&
target.value.length === valueExpression.value.length) {
for (let i = 0; i < target.value.length; i++) {
this._state[target.value[i].value] = valueExpression.value[i];
}
return true;
}
}
}
*/
}
return false;
}
Expand Down Expand Up @@ -2281,6 +2316,18 @@ torchscript.GraphContext = class {
this._state[target.value] = expression;
return true;
}
// output = (result)[0]
if (expression.type === '[]' &&
expression.target.type === 'id' &&
expression.arguments.value.length === 1 &&
expression.arguments.value[0].type === 'number') {
const arrayExpression = this._state[expression.target.value];
if (arrayExpression.type === 'tuple') {
const index = Number(expression.arguments.value[0].value);
this._state[target.value] = arrayExpression.value[index];
return true;
}
}
}
// _4, _5 = False, _3
if (statement.type === '=' &&
Expand All @@ -2290,9 +2337,22 @@ torchscript.GraphContext = class {
for (let i = 0; i < statement.target.value.length; i++) {
const target = statement.target.value[i];
const expression = statement.expression.value[i];
if (target.type == 'id' && expression.type == 'id') {
this._state[target.value] = expression;
continue;
if (target.type == 'id') {
if (this._isBooleanLiteral(expression)) {
this._state[target.value] = expression;
continue;
}
if (expression.type === 'id') {
const tensorExpression = this._state[expression.value];
if (torchscript.Utility.isTensor(tensorExpression)) {
this._state[target.value] = tensorExpression;
continue;
}
if (tensorExpression.type === 'tuple' && tensorExpression.value.every((item) => item.type === 'id' && (item.value === 'zeros'|| item.value === 'empty'))) {
this._state[target.value] = tensorExpression;
continue;
}
}
}
if (this._argumentExpression(expression, target)) {
continue;
Expand Down Expand Up @@ -2355,11 +2415,19 @@ torchscript.GraphContext = class {
if (typeof value === 'number') {
return { type: 'number', value: value };
}
if (Array.isArray(value) && value.every((item) => typeof item === 'number')) {
const array = value;
return { type: 'list', value: array.map((item) => { return { type: 'number', value: item }; }) };
}
if (torchscript.Utility.isTensor(value)) {
return value;
}
}
}
if (expression.type === 'list') {
const value = expression.value.map((item) => this._evaluateExpression(item));
return { type: 'list', value: value };
}
// int(x)
if (this._isCall(expression, 'int', [ {} ])) {
return this._evaluateExpression(expression.arguments[0]);
Expand Down
55 changes: 27 additions & 28 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5143,75 +5143,74 @@
},
{
"type": "torchscript",
"target": "mnist_linear_torchscript_1.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_1.pt]",
"target": "inception_v3.pt",
"script": [ "${root}/tools/pytorch", "sync install zoo" ],
"format": "TorchScript v1",
"link": "https://github.com/lutzroeder/netron/issues/281"
"link": "https://pytorch.org/docs/stable/torchvision/models.html"
},
{
"type": "torchscript",
"target": "mnist_linear_torchscript_2.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_2.pt]",
"target": "inception_v3_traced.pt",
"script": [ "${root}/tools/pytorch", "sync install zoo" ],
"format": "TorchScript v1",
"link": "https://github.com/lutzroeder/netron/issues/281"
"link": "https://pytorch.org/docs/stable/torchvision/models.html"
},
{
"type": "torchscript",
"target": "mobilenet_quantized_scripted_925.pt",
"source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/mobilenet_quantized_scripted_925.pt?raw=true",
"error": "Unknown statement '{\"location\":\" at 374:8\",\"type\":\"=\",\"target\":{\"type\":\"id\",\"value\":\"_24\",\"keyword\":false},\"expression\":{\"location\":\" at 374:14\",\"type\":\".\",\"target\":{\"type\":\"id\",\"value\":\"_21\",\"keyword\":false},\"member\":{\"type\":\"id\",\"value\":\"stride\",\"keyword\":false}}}' in 'mobilenet_quantized_scripted_925.pt'.",
"target": "junction_mlp_vehicle_model.pt",
"source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/junction_mlp_vehicle_model.pt?raw=true",
"format": "TorchScript v1",
"link": "https://github.com/pytorch/android-demo-app"
"link": "https://github.com/ApolloAuto/apollo"
},
{
"type": "torchscript",
"target": "model-reddit16-f140225004_2.pt1",
"source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/model-reddit16-f140225004_2.pt1?raw=true",
"error": "Unknown statement '{\"location\":\" at 132:17\",\"type\":\"=\",\"target\":{\"location\":\" at 132:11\",\"type\":\"tuple\",\"value\":[{\"type\":\"id\",\"value\":\"_32\",\"keyword\":false},{\"type\":\"id\",\"value\":\"_33\",\"keyword\":false}]},\"expression\":{\"type\":\"id\",\"value\":\"hx\",\"keyword\":false}}' in 'model-reddit16-f140225004_2.pt1'.",
"target": "lane_scanning_vehicle_model.pt",
"source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/lane_scanning_vehicle_model.pt?raw=true",
"format": "TorchScript v1",
"link": "https://github.com/pytorch/android-demo-app"
"link": "https://github.com/ApolloAuto/apollo"
},
{
"type": "torchscript",
"target": "inception_v3.pt",
"target": "mobilenet_v2.pt",
"script": [ "${root}/tools/pytorch", "sync install zoo" ],
"format": "TorchScript v1",
"link": "https://pytorch.org/docs/stable/torchvision/models.html"
},
{
"type": "torchscript",
"target": "inception_v3_traced.pt",
"script": [ "${root}/tools/pytorch", "sync install zoo" ],
"target": "mnist_linear_torchscript_1.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_1.pt]",
"format": "TorchScript v1",
"link": "https://pytorch.org/docs/stable/torchvision/models.html"
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "torchscript",
"target": "junction_mlp_vehicle_model.pt",
"source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/junction_mlp_vehicle_model.pt?raw=true",
"target": "mnist_linear_torchscript_2.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_2.pt]",
"format": "TorchScript v1",
"link": "https://github.com/ApolloAuto/apollo"
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "torchscript",
"target": "lane_scanning_vehicle_model.pt",
"source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/lane_scanning_vehicle_model.pt?raw=true",
"target": "mobilenet_quantized_scripted_925.pt",
"source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/mobilenet_quantized_scripted_925.pt?raw=true",
"format": "TorchScript v1",
"link": "https://github.com/ApolloAuto/apollo"
"link": "https://github.com/pytorch/android-demo-app"
},
{
"type": "torchscript",
"target": "mobilenet_v2.pt",
"target": "mobilenet_v2_traced.pt",
"script": [ "${root}/tools/pytorch", "sync install zoo" ],
"format": "TorchScript v1",
"link": "https://pytorch.org/docs/stable/torchvision/models.html"
},
{
"type": "torchscript",
"target": "mobilenet_v2_traced.pt",
"script": [ "${root}/tools/pytorch", "sync install zoo" ],
"target": "model-reddit16-f140225004_2.pt1",
"source": "https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/assets/model-reddit16-f140225004_2.pt1?raw=true",
"error": "Unknown function argument in 'model-reddit16-f140225004_2.pt1'.",
"format": "TorchScript v1",
"link": "https://pytorch.org/docs/stable/torchvision/models.html"
"link": "https://github.com/pytorch/android-demo-app"
},
{
"type": "torchscript",
Expand Down

0 comments on commit 561b174

Please sign in to comment.