From 741f49143e9b805a540cc34c360809ec6f5e8d92 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 30 Nov 2024 16:59:27 -0800 Subject: [PATCH] Add TorchScript test file (#1061) --- source/python.js | 244 ++++++++++++++++++++++-- source/pytorch-metadata.json | 3 + source/pytorch.js | 356 ++++++++++++++--------------------- test/models.json | 18 +- 4 files changed, 392 insertions(+), 229 deletions(-) diff --git a/source/python.js b/source/python.js index 042af0c494..a812bb7817 100644 --- a/source/python.js +++ b/source/python.js @@ -6755,10 +6755,16 @@ python.Execution = class { return this.kind() === rhs.kind(); } isSubtypeOf(rhs) { - if (rhs.kind() === 'OptionalType') { + if (rhs.kind() === 'OptionalType' && this.kind() !== 'OptionalType') { return rhs.getElementType().equals(this); } - return false; + return this.equals(rhs); + } + expect(type) { + if (this instanceof type === false) { + throw new python.Error(`Expected '${type.kind()}' but got '${this.kind()}'.`); + } + return this; } str() { if (this._kind === 'VarType' && this._annotation_str) { @@ -6814,6 +6820,9 @@ python.Execution = class { findAttribute(name) { return this._attributes.get(name); } + getAttribute(name) { + return this._attributes.get(name); + } hasConstant(/* name */) { } methods() { @@ -6848,7 +6857,7 @@ python.Execution = class { super('ListType'); this._elem = elem; } - static get(elem) { + static create(elem) { return new torch.ListType(elem); } getElementType() { @@ -6924,7 +6933,7 @@ python.Execution = class { this._elements = elements; this._schema = schema; } - static get(elements) { + static create(elements) { return new torch.TupleType(elements); } static createNamed(qualified_name, field_names, field_types /*, field_defaults */) { @@ -6981,6 +6990,9 @@ python.Execution = class { equals(rhs) { return this.kind() === rhs.kind(); } + isSubtypeOf(/* rhs */) { + return true; + } str() { return 'NoneType'; } @@ -7144,7 +7156,7 @@ python.Execution = class { this._key = key; this._value = value; } - static get(key, value) { + static create(key, value) { return new torch.DictType(key, value); } getKeyType() { @@ -7372,7 +7384,7 @@ python.Execution = class { L.eat(','); L.whitespace(0); } - real_value = torch.TupleType.get(types); + real_value = torch.TupleType.create(types); fake_value = real_value; } else if (L.value === 'Future') { L.next(); @@ -7415,7 +7427,7 @@ python.Execution = class { const value_type = this.parseType().first; L.expect(')'); alias_info = this.parseAliasAnnotation(); - real_value = torch.DictType.get(key_type, value_type); + real_value = torch.DictType.create(key_type, value_type); fake_value = real_value; } else if (L.eat('Union')) { L.next(); @@ -7454,8 +7466,8 @@ python.Execution = class { while (true) { if (L.kind === '[]') { L.expect('[]'); - fake_value = torch.ListType.get(fake_value); - real_value = torch.ListType.get(real_value); + fake_value = torch.ListType.create(fake_value); + real_value = torch.ListType.create(real_value); let container = this.parseAliasAnnotation(); if (alias_info) { if (!container) { @@ -7524,8 +7536,8 @@ python.Execution = class { L.whitespace(0); let N = null; if (L.eat('[')) { - fake_type = torch.ListType.get(fake_type); - real_type = torch.ListType.get(real_type); + fake_type = torch.ListType.create(fake_type); + real_type = torch.ListType.create(real_type); if (L.kind === '#') { N = Number(L.value); L.next(); @@ -7932,8 +7944,118 @@ python.Execution = class { this._block = new torch.Block(this); this._insert_before = this.return_node(); } - create(kind) { - return new torch.Node(this, kind); + create(kind, ...args) { + let inputs = null; + let num_outputs = 1; + if (args.length === 2 && Array.isArray(args[0]) && typeof args[1] === 'number') { + [inputs, num_outputs] = args; + } else if (args.length === 1) { + if (typeof args[0] === 'number') { + [num_outputs] = args; + } else if (Array.isArray(args[0])) { + [inputs] = args; + } + } + const n = new torch.Node(this, kind); + if (inputs) { + for (const i of inputs) { + n.addInput(i); + } + } + for (let i = 0; i < num_outputs; i++) { + n.addOutput(); + } + return n; + } + createUninitialized(typ) { + const n = this.create('prim::Uninitialized'); + n.output().setType(typ); + return n; + } + createList(contained_type, values) { + const n = this.create('prim::ListConstruct', values); + for (const v of values) { + if (!v.type().isSubtypeOf(contained_type)) { + throw new python.Error('Invalid list item.'); + } + } + n.output().setType(torch.ListType.create(contained_type)); + return n; + } + createListUnpack(v, size) { + const list_type = v.type().expect(torch.ListType); + const elem_type = list_type.getElementType(); + const n = this.create('prim::ListUnpack', [v], 0); + for (let i = 0; i < size; i++) { + n.addOutput().setType(elem_type); + } + return n; + } + createTuple(values, tuple_type) { + if (!tuple_type) { + const types = values.map((v) => v.type()); + tuple_type = torch.TupleType.create(types); + } + const n = this.create('prim::TupleConstruct', values); + n.output().setType(tuple_type); + return n; + } + createTupleUnpack(v) { + const tt = v.type().expect(torch.TupleType); + const n = this.create('prim::TupleUnpack', [v], 0); + for (const element of tt.elements()) { + n.addOutput().setType(element); + } + return n; + } + createTupleIndex(tup, idx, output_type) { + const n = this.create('prim::TupleIndex', [tup, idx]); + n.output().setType(output_type); + return n; + } + createDict(key_type, value_type, keys, values) { + if (keys.length !== values.length) { + throw new python.Error('Invalid dictionary size.'); + } + const n = this.create('prim::DictConstruct'); + const length = keys.length; + for (let i = 0; i < length; i++) { + if (!keys[i].type().isSubtypeOf(key_type)) { + throw new python.Error('Invalid key.'); + } + if (!values[i].type().isSubtypeOf(value_type)) { + throw new python.Error('Invalid value.'); + } + n.addInput(keys[i]); + n.addInput(values[i]); + } + n.output().setType(torch.DictType.create(key_type, value_type)); + return n; + } + createObject(type) { + const node = this.create('prim::CreateObject'); + node.output().setType(type); + return node; + } + createIsInstance(v, types) { + const n = this.create('prim::isinstance', [v], 1); + n.tys_('types', types); + n.output().setType(torch.BoolType.get()); + return n; + } + createSetAttr(obj, field, newValue) { + const n = this.create('prim::SetAttr', [obj, newValue], 0); + n.s_('name', field); + return n; + } + createGetAttr(obj, field) { + const n = this.create('prim::GetAttr', [obj]); + n.s_('name', field); + const classType = obj.type(); + const outputType = classType.getAttribute(field); + n.output().setType(outputType); + n.output().setDebugName(/^[0-9]+$/.test(field) ? `_${field}` : field); + return n; } inputs() { return this._block.inputs(); @@ -7954,7 +8076,69 @@ python.Execution = class { return this._block.addInput(name); } insertNode(node) { - node.insertBefore(this._insert_before); + return node.insertBefore(this._insert_before); + } + insertConstant(val) { + const n = this.create('prim::Constant'); + this.insertNode(n); + let type = null; + if (val === null) { + n.ival_('value', val); + type = torch.NoneType.get(); + } else if (typeof val === 'string') { + n.s_('value', val); + type = torch.StringType.get(); + } else if (Array.isArray(val) && val.every((item) => typeof item === 'string')) { + n.ss_('value', val); + type = torch.ListType.create(torch.StringType.get()); + } else if (typeof val === 'boolean') { + // return value; + n.i_('value', val === true ? 1 : 0); + type = torch.BoolType.get(); + } else if (Number.isInteger(val)) { + n.i_('value', val); + type = torch.IntType.get(); + } else if (typeof val === 'number') { + // return value; + n.f_('value', val); + type = torch.FloatType.get(); + } else { + throw new python.Error(`Unsupported value type '${typeof value}'.`); + } + if (type) { + n.output().setType(type); + } + return n.output(); + } + insertUncheckedCast(v, type) { + const n = this.insertNode(this.create('prim::unchecked_cast', [v])); + n.output().setType(type); + return n.output(); + } + insertToList(v, type) { + let dim = 0; + let ptr = type; + while (ptr instanceof torch.ListType) { + ptr = ptr.getElementType(); + dim += 1; + } + let elem_ty = 0; + if (ptr instanceof torch.IntType) { + elem_ty = 0; + } else if (ptr instanceof torch.FloatType) { + elem_ty = 1; + } else if (ptr instanceof torch.BoolType) { + elem_ty = 2; + } else if (ptr instanceof torch.ComplexType) { + elem_ty = 3; + } else { + throw new python.Error(`Unsupported list type '${type.kind()}'.`); + } + const dim_val = this.insertConstant(dim); + const elem_ty_val = this.insertConstant(elem_ty); + const n = this.insertNode(this.create('prim::tolist', [v, dim_val, elem_ty_val])); + n.output().setType(type); + return n.output(); } insertPoint() { return this._insert_before; @@ -7994,8 +8178,8 @@ python.Execution = class { this.registerType('torch.Block', class { constructor(graph) { this._graph = graph; - this._input = graph.create('prim::Param'); - this._output = graph.create('prim::Return'); + this._input = graph.create('prim::Param', 0); + this._output = graph.create('prim::Return', 0); this._input.next = this._output; this._input.prev = this._output; this._output.next = this._input; @@ -8091,6 +8275,12 @@ python.Execution = class { outputs() { return this._outputs; } + output() { + if (this._outputs.length !== 1) { + throw new python.Error('Node has multiple outputs.'); + } + return this._outputs[0]; + } blocks() { return this._blocks; } @@ -8214,6 +8404,12 @@ python.Execution = class { f(name) { return this._values.get(name)[0]; } + tys_(name, value) { + this._values.set(name, [value, 'tys']); + } + tys(name) { + return this._values.get(name)[0]; + } ival_(name, value) { this._values.set(name, [value, 'ival']); } @@ -8927,7 +9123,23 @@ python.Execution = class { } } } + execution.purge = new Set(); const result = this.data.forward.__call__(args); + const queue = Array.from(execution.purge); + const visited = new Set(); + while (queue.length > 0) { + const node = queue.shift(); + if (visited.has(node)) { + continue; + } + visited.add(node); + if (node.outputs().every((output) => output.uses().length === 0)) { + for (const input of node.inputs()) { + queue.push(input.node()); + } + node.destroy(); + } + } if (Array.isArray(result)) { for (const output of result) { if (isTensor(output)) { diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index 3c18cc047f..de0d60d439 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -6228,6 +6228,9 @@ { "name": "prim::shape(Tensor self) -> int[]" }, + { + "name": "prim::tolist(...) -> ..." + }, { "name": "prim::type(Device self) -> str" }, diff --git a/source/pytorch.js b/source/pytorch.js index 8eab62fe9b..fd11050fd8 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -405,6 +405,7 @@ pytorch.Node = class { case 'i': value = node.i(name); type = 'int64'; break; case 'f': value = node.f(name); type = 'float32'; break; case 'ss': value = node.ss(name); type = 'string[]'; break; + case 'tys': value = node.tys(name).map((ty) => pytorch.Utility.toType(ty)); type = 'type[]'; break; case 'ival': value = node.ival(name); break; default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`); } @@ -1570,7 +1571,7 @@ pytorch.Execution = class extends python.Execution { execution.variable(this.serialized_model_tensor); this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1; const type = new pytorch.nnapi.Graph(this.serialized_model); - const node = execution.graph.create(type); + const node = execution.graph.create(type, 0); execution.graph.insertNode(node); for (const tensor of inputs) { const value = execution.variable(tensor); @@ -1674,41 +1675,6 @@ pytorch.Execution = class extends python.Execution { return this._graph; } - constant(value) { - const torch = this.torch; - const node = this.graph.create('prim::Constant'); - this.graph.insertNode(node); - let type = null; - if (value === null) { - node.ival_('value', value); - type = torch.NoneType.get(); - } else if (typeof value === 'string') { - node.s_('value', value); - type = torch.StringType.get(); - } else if (Array.isArray(value) && value.every((item) => typeof item === 'string')) { - node.ss_('value', value); - type = torch.ListType.get(torch.StringType.get()); - } else if (typeof value === 'boolean') { - // return value; - node.i_('value', value === true ? 1 : 0); - type = torch.BoolType.get(); - } else if (Number.isInteger(value)) { - node.i_('value', value); - type = torch.IntType.get(); - } else if (typeof value === 'number') { - // return value; - node.f_('value', value); - type = torch.FloatType.get(); - } else { - throw new pytorch.Error(`Unsupported value type '${typeof value}'.`); - } - if (type) { - value = node.addOutput(); - value.setType(type); - } - return value; - } - variable(obj, node) { const torch = this.torch; if (this._values.has(obj)) { @@ -1767,7 +1733,7 @@ pytorch.Execution = class extends python.Execution { const value = this.builtins[expr.id]; const entries = Object.entries(value).map(([name, value]) => { if (Array.isArray(value) && value.length > 0 && value.every((item) => typeof item === 'string')) { - value = this.constant(value); + value = this._graph.insertConstant(value); return [name, value]; } return [name, value]; @@ -1815,7 +1781,7 @@ pytorch.Execution = class extends python.Execution { return super.target(expr, context); } - expression(expr, context) { + expression(expr, context, typehint) { if (!this.trace) { return super.expression(expr, context); } @@ -1824,7 +1790,7 @@ pytorch.Execution = class extends python.Execution { switch (expr.type) { case 'Constant': { if (expr.value === true || expr.value === false) { - return this.constant(expr.value); + return this._graph.insertConstant(expr.value); } break; } @@ -1833,7 +1799,7 @@ pytorch.Execution = class extends python.Execution { if (target instanceof ast.Name) { let value = this.expression(expr.value, context); if (typeof value === 'string' || typeof value === 'boolean' || typeof value === 'number') { - value = this.constant(value); + value = this._graph.insertConstant(value); } else if (typeof value !== 'object' && value !== undefined) { throw new pytorch.Error(`Unsupported assignment value type '${typeof value}'.`); } @@ -1848,25 +1814,27 @@ pytorch.Execution = class extends python.Execution { context.target.pop(); if (target.elts.every((item) => item instanceof ast.Name)) { if (value instanceof torch.Value) { - const node = this._graph.create('prim::TupleUnpack'); - node.setSourceRange(expr.location); - this.graph.insertNode(node); - node.addInput(value); - const outputs = []; + let outputs = null; + if (value.type() instanceof torch.TupleType) { + const node = this._graph.createTupleUnpack(value); + node.setSourceRange(expr.location); + this.graph.insertNode(node); + outputs = node.outputs(); + } else if (value.type() instanceof torch.ListType) { + const size = target.elts.length; + const node = this._graph.createListUnpack(value, size); + node.setSourceRange(expr.location); + this.graph.insertNode(node); + outputs = node.outputs(); + } + if (outputs === null) { + throw new pytorch.Error(`Unsupported unpack type '${value.type().kind()}'.`); + } for (let i = 0; i < target.elts.length; i++) { const item = target.elts[i]; - const output = node.addOutput(); - const type = value.type(); - if (type instanceof torch.ListType) { - output.setType(value.type().getElementType()); - } else if (type instanceof torch.TupleType) { - output.setType(type.elements()[i]); - } else { - throw new pytorch.Error(`Unsupported tuple unpack type '${type.kind()}'.`); - } + const output = outputs[i]; output.setDebugName(item.id); context.set(item.id, output); - outputs.push(output); } return outputs; } @@ -1894,8 +1862,10 @@ pytorch.Execution = class extends python.Execution { const func = expr.func; if (func instanceof ast.Name && func.id === 'annotate') { const type = this.type(expr.args[0]); - let value = this.expression(expr.args[1], context); - if (value instanceof torch.Tensor) { + const [, obj] = expr.args; + let value = this.expression(obj, context, type); + if (value instanceof torch.Tensor || + (value instanceof torch.Value && value.type() instanceof torch.TensorType)) { let name = null; if (type instanceof torch.IntType) { name = 'IntImplicit'; @@ -1914,43 +1884,46 @@ pytorch.Execution = class extends python.Execution { const target = new ast.Name('torch'); return this.call(target, name, expr.args.slice(1), context); } - if (value instanceof torch.Value) { - value.setType(type); + if (value instanceof torch.Value && !type.equals(value.type())) { + throw new pytorch.Error('Invalid annotation type hint.'); } if (value === null) { - value = this.constant(value); + value = this._graph.insertConstant(value); value.setType(type); } return value; } if (func instanceof ast.Name && func.id === 'uninitialized') { const type = this.type(expr.args[0]); - const node = this._graph.create('prim::Uninitialized'); + const node = this._graph.createUninitialized(type); node.setSourceRange(expr.location); this.graph.insertNode(node); - const value = node.addOutput(); - value.setType(type); - return value; + return node.output(); } if (func instanceof ast.Name && func.id === 'unchecked_cast') { let value = this.expression(expr.args[1], context); + if (value instanceof torch.Value === false) { // remove + value = this.variable(value); + } const type = this.type(expr.args[0]); - const node = this._graph.create('prim::unchecked_cast'); - this.graph.insertNode(node); - node.addInput(this.variable(value)); - value = node.addOutput(); - value.setType(type); - return value; + return this.graph.insertUncheckedCast(value, type); } if (func instanceof ast.Name && func.id === 'isinstance') { - let value = this.expression(expr.args[1], context); - // const type = this.type(expression.args[0]); - const node = this._graph.create('prim::isinstance'); + const value = this.expression(expr.args[0], context); + let [, types] = expr.args; + if (types instanceof ast.Tuple) { + types = types.elts.map((expr) => this.type(expr)); + } else { + types = [this.type(types)]; + } + const v = this.variable(value); // remove + const node = this._graph.createIsInstance(v, types); this.graph.insertNode(node); - node.addInput(this.variable(value)); - value = node.addOutput(); - value.setType(torch.BoolType.get()); - return value; + return node.output(); + } + if (func.attr === 'tolist' && expr.args.length === 0) { + const target = this.target(func.value, context); + return this.graph.insertToList(target, typehint); } return super.expression(expr, context); } @@ -1965,22 +1938,18 @@ pytorch.Execution = class extends python.Execution { } if (type instanceof torch.ListType) { let index = this.expression(elt, context); - const node = this._graph.create('aten::__getitem__.t'); - this.graph.insertNode(node); - node.addInput(value); if (Number.isInteger(index)) { - index = this.constant(index); + index = this._graph.insertConstant(index); } - node.addInput(index); - const output = node.addOutput(); - output.setType(type.getElementType()); - return output; + const node = this._graph.create('aten::__getitem__.t', [value, index]); + this.graph.insertNode(node); + node.output().setType(type.getElementType()); + return node.output(); } if (type instanceof torch.DictType) { let key = this.expression(elt, context); - const node = this._graph.create('aten::__getitem__.t'); + const node = this._graph.create('aten::__getitem__.t', [value]); this.graph.insertNode(node); - node.addInput(value); if (type.getKeyType() instanceof torch.StringType && typeof key === 'string') { const value = new torch.Value(node); value.value = key; @@ -1991,26 +1960,19 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error(`Unsupported dictionary key type.`); } node.addInput(key); - const output = node.addOutput(); - output.setType(type.getValueType()); - return output; + node.output().setType(type.getValueType()); + return node.output(); } if (type instanceof torch.TupleType) { - const index = this.expression(elt, context); - const node = this._graph.create('prim::TupleIndex'); - this.graph.insertNode(node); - node.addInput(value); - if (index instanceof torch.Value) { - node.addInput(index); - } else if (Number.isInteger(index)) { - const value = this.constant(index); - node.addInput(value); - } else { + let index = this.expression(elt, context); + if (!Number.isInteger(index)) { throw new pytorch.Error(`Unsupported tuple index type.`); } - const output = node.addOutput(); - output.setType(type.elements()[index]); - return output; + const output_type = type.elements()[index]; + index = this._graph.insertConstant(index); + const node = this._graph.createTupleIndex(value, index, output_type); + this.graph.insertNode(node); + return node.output(); } } } @@ -2020,100 +1982,84 @@ pytorch.Execution = class extends python.Execution { const target = this.target(expr.value, context); const attr = expr.attr; if (target instanceof torch.Value && target.type() instanceof torch.ClassType) { - const type = target.type().findAttribute(attr); - const node = this.graph.create('prim::GetAttr'); + const node = this._graph.createGetAttr(target, attr); this.graph.insertNode(node); - node.s_(attr); - node.addInput(target); - const value = node.addOutput(); - value.setType(type); - return value; + return node.output(); } return target[attr]; } case 'List': { const list = expr.elts.map((item) => this.expression(item, context)); if (/* list.length > 0 && */ list.every((item) => item instanceof torch.Value || pytorch.Utility.isTensor(item) || Number.isInteger(item) || typeof item === 'string' || item === null)) { - const node = this._graph.create('prim::ListConstruct'); - this.graph.insertNode(node); - const output = node.addOutput(); + const values = []; + let item_type = null; for (const item of list) { + let value = null; if (item instanceof torch.Value) { - node.addInput(item); - output.setType(torch.ListType.get(item.type())); - } else if (Number.isInteger(item)) { - const value = this.constant(item); - node.addInput(value); - output.setType(torch.ListType.get(torch.IntType.get())); - } else if (typeof item === 'string') { - const value = this.constant(item); - node.addInput(value); - output.setType(torch.ListType.get(torch.StringType.get())); + value = item; + } else if (Number.isInteger(item) || typeof item === 'string' || item === null) { + value = this._graph.insertConstant(item); } else if (pytorch.Utility.isTensor(item)) { - const value = this.variable(item, null); - node.addInput(value); - output.setType(torch.ListType.get(torch.TensorType.get())); + value = this.variable(item, null); } else { - const value = new torch.Value(node); - value.value = item; - node.addInput(value); + throw new pytorch.Error('Unsupported list item type.'); + } + values.push(value); + const type = value.type(); + if (!item_type || item_type.isSubtypeOf(type)) { + item_type = type; } } - return output; + const contained_type = typehint ? typehint.getElementType() : item_type; + const node = this._graph.createList(contained_type, values); + this.graph.insertNode(node); + return node.output(); } break; } case 'Tuple': { const elts = expr.elts.map((expr) => this.expression(expr, context)); - const node = this._graph.create('prim::TupleConstruct'); - node.setSourceRange(expr.location); - this.graph.insertNode(node); - const types = []; - const elements = []; - for (const item of elts) { - if (item instanceof torch.Value) { - node.addInput(item); - types.push(item.type()); - } else if (pytorch.Utility.isTensor(item)) { - const value = this.variable(item, node); - node.addInput(value); - types.push(value.type()); - } else if (item === null || Number.isInteger(item) || typeof item === 'number' || typeof item === 'boolean' || typeof item === 'string') { - const value = this.constant(item); - node.addInput(value); - types.push(value.type()); + const values = []; + for (const elt of elts) { + let value = null; + if (elt instanceof torch.Value) { + value = elt; + } else if (elt === null || Number.isInteger(elt) || typeof elt === 'number' || typeof elt === 'boolean' || typeof elt === 'string') { + value = this._graph.insertConstant(elt); } else { - const value = new torch.Value(node); - value.value = item; - node.addInput(value); - types.push(torch.Type.get()); + throw new pytorch.Error('Unsupported tuple element.'); } - elements.push(item); + values.push(value); } - const value = node.addOutput(); - value.setType(torch.TupleType.get(types)); - return value; + const node = this._graph.createTuple(values); + node.setSourceRange(expr.location); + this.graph.insertNode(node); + return node.output(); } case 'Dict': { - const node = this._graph.create('prim::DictConstruct'); - this.graph.insertNode(node); + const keys = []; + const values = []; let keyType = null; let valueType = null; for (let i = 0; i < expr.keys.length; i++) { const key = this.expression(expr.keys[i], context); const keyValue = this.variable(key, null); - keyType = keyValue.type(); - node.addInput(keyValue); + if (!keyType || keyType.isSubtypeOf(keyValue.type())) { + keyType = keyValue.type(); + } + keys.push(keyValue); const value = this.expression(expr.values[i], context); const valueValue = this.variable(value, null); - valueType = valueValue.type(); - node.addInput(valueValue); - } - const output = node.addOutput(); - if (keyType && valueType) { - output.setType(torch.DictType.get(keyType, valueType)); + if (!valueType || valueType.isSubtypeOf(valueValue.type())) { + valueType = valueValue.type(); + } + values.push(valueValue); } - return output; + const key_type = typehint ? typehint.getKeyType() : keyType; + const value_type = typehint ? typehint.getValueType() : valueType; + const node = this._graph.createDict(key_type, value_type, keys, values); + this.graph.insertNode(node); + return node.output(); } default: { break; @@ -2359,18 +2305,8 @@ pytorch.Execution = class extends python.Execution { } else if (test === false) { statements.splice(i, 1, ...condition.orelse.statements); } - const count = new Map(); for (const node of state) { - if (count.has(node)) { - count.set(node, count.get(node) + 1); - } else { - count.set(node, 1); - } - } - if (count.size > 0 && Array.from(count).every(([node, count]) => node.outputs().length === 1 && node.outputs()[0].uses().length <= count)) { - for (const node of state) { - node.destroy(); - } + this.purge.add(node); } if (test === true || test === false) { continue; @@ -2429,7 +2365,7 @@ pytorch.Execution = class extends python.Execution { return value.type(); }; this.variables(condition, condition); - const node = this._graph.create('prim::If'); + const node = this._graph.create('prim::If', 0); node.setSourceRange(stmt.location); this.graph.insertNode(node); node.addInput(test); @@ -2490,7 +2426,7 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error("Unsupported condition."); } if (stmt instanceof ast.For) { - const node = this._graph.create('prim::Loop'); + const node = this._graph.create('prim::Loop', 0); node.setSourceRange(stmt.location); this.graph.insertNode(node); const loop = stmt; @@ -2498,7 +2434,9 @@ pytorch.Execution = class extends python.Execution { const range = this.expression(loop.iter, context); const variable = loop.target; for (const current of range) { - this.statement({ type: '=', target: variable, expression: { type: 'number', value: current } }, context); + const constant = new ast.Constant(current); + const stmt = new ast.Assign(variable, constant); + this.statement(stmt, context); const value = this.block(loop.body.statements, context); if (value !== undefined) { return value; @@ -2509,7 +2447,7 @@ pytorch.Execution = class extends python.Execution { } } if (stmt instanceof ast.While) { - const node = this._graph.create('prim::Loop'); + const node = this._graph.create('prim::Loop', 0); node.setSourceRange(stmt.location); this.graph.insertNode(node); const test = this.expression(stmt.test, context); @@ -2571,7 +2509,7 @@ pytorch.Execution = class extends python.Execution { switch (expr.value.id) { case 'List': { const type = this.type(elts[0]); - return torch.ListType.get(type); + return torch.ListType.create(type); } case 'Optional': { const type = this.type(elts[0]); @@ -2579,12 +2517,12 @@ pytorch.Execution = class extends python.Execution { } case 'Tuple': { const types = elts.map((expr) => this.type(expr)); - return torch.TupleType.get(types); + return torch.TupleType.create(types); } case 'Dict': { const key = this.type(elts[0]); const value = this.type(elts[1]); - return torch.DictType.get(key, value); + return torch.DictType.create(key, value); } case 'Final': { return this.type(elts[0]); @@ -2602,6 +2540,8 @@ pytorch.Execution = class extends python.Execution { case 'float': return torch.FloatType.get(); case 'number': return torch.NumberType.get(); case 'bool': return torch.BoolType.get(); + case 'list': return torch.Type.get('AnyListType'); + case 'tuple': return torch.Type.get('AnyTupleType'); case 'None': return torch.NoneType.get(); case 'NoneType': return torch.NoneType.get(); default: throw new pytorch.Error(`Unsupported type expression '${expr.value}'.`); @@ -2635,12 +2575,10 @@ pytorch.Execution = class extends python.Execution { if (identifier) { const type = this._resolver.resolveType(identifier); if (type) { - const node = this.graph.create('prim::CreateObject'); + const node = this.graph.createObject(type); node.setSourceRange(location); this.graph.insertNode(node); - const value = node.addOutput(); - value.setType(type); - return value; + return node.output(); } } } @@ -2650,7 +2588,7 @@ pytorch.Execution = class extends python.Execution { if (args.length === 0) { return obj; } - const node = this.graph.create('prim::CallMethod'); + const node = this.graph.create('prim::CallMethod', 0); node.setSourceRange(location); this.graph.insertNode(node); node.s_('name', name); @@ -2676,24 +2614,19 @@ pytorch.Execution = class extends python.Execution { const value = this.variable(arg); node.addInput(value); } - return node.addOutput(); + return node.output(); } const prefix = this.identifier(target); if (prefix && prefix !== 'self' && !prefix.startsWith('self.') && prefix.indexOf('.') !== -1) { const identifier = `${prefix}.${name}`; const type = this._resolver.resolveType(identifier); if (type instanceof torch.TupleType) { - const node = this._graph.create('prim::TupleConstruct'); + const evalArgs = args.map((expression) => this.expression(expression, context)); + const values = evalArgs.map((arg) => this.variable(arg)); + const node = this._graph.createTuple(values, type); node.setSourceRange(location); this.graph.insertNode(node); - const evalArgs = args.map((expression) => this.expression(expression, context)); - for (const arg of evalArgs) { - const value = this.variable(arg); - node.addInput(value); - } - const output = node.addOutput(); - output.setType(type); - return output; + return node.output(); } if (type instanceof torch.ClassType) { const node = this.graph.create('prim::CallMethod'); @@ -2704,14 +2637,14 @@ pytorch.Execution = class extends python.Execution { const value = this.variable(arg); node.addInput(value); } - return node.addOutput(); + return node.output(); } } return super.call(target, name, args, context); } const [schema, evalArgs] = overload; const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name; - const node = this._graph.create(op); + const node = this._graph.create(op, 0); node.setSourceRange(location); this.graph.insertNode(node); const referencedParameters = []; @@ -2776,9 +2709,8 @@ pytorch.Execution = class extends python.Execution { value.setType(torch.TensorType.get()); list.addInput(value); } - const output = list.addOutput(); - output.setType(torch.ListType.get(torch.TensorType.get())); - input = output; + list.output().setType(torch.ListType.create(torch.TensorType.get())); + input = list.output(); match = true; } } else { @@ -2925,20 +2857,20 @@ pytorch.Execution = class extends python.Execution { if (!type) { throw new pytorch.Error(); } - type = torch.ListType.get(type); + type = torch.ListType.create(type); break; } default: { if (type instanceof torch.DictType) { const keyType = varTypes.map(type.getKeyType()); const valueType = varTypes.map(type.getValueType()); - type = torch.DictType.get(keyType, valueType); + type = torch.DictType.create(keyType, valueType); } else if (type instanceof torch.TupleType && type.elements().length === 2) { const elements = type.elements().map((type) => varTypes.map(type)); - type = torch.ListType.get(torch.TupleType.get(elements)); + type = torch.ListType.create(torch.TupleType.create(elements)); } else if (type instanceof torch.ListType && type.getElementType() instanceof torch.TupleType) { const elements = type.getElementType().elements().map((type) => varTypes.map(type)); - type = torch.ListType.get(torch.TupleType.get(elements)); + type = torch.ListType.create(torch.TupleType.create(elements)); } else { throw new pytorch.Error(`Unsupported return type '${type.str()}'.`); } @@ -3014,7 +2946,7 @@ pytorch.Execution = class extends python.Execution { (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType)) || (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType && obj.type().getElementType().getElementType() instanceof torch.IntType); case 'SymInt[1]': - return this.isType(obj, torch.IntType.get()) || this.isType(obj, torch.ListType.get(torch.IntType.get())); + return this.isType(obj, torch.IntType.get()) || this.isType(obj, torch.ListType.create(torch.IntType.get())); case 'float': { return obj !== null && (typeof obj === 'number' || obj instanceof Number) || (obj instanceof torch.Value && (obj.type() instanceof torch.FloatType || obj.type() instanceof torch.IntType)); } @@ -3129,9 +3061,9 @@ pytorch.Execution = class extends python.Execution { } else if (Number(value) === value) { return torch.FloatType.get(); } else if (Array.isArray(value) && value.every((item) => Number(item) === item && item % 1 === 0)) { - return torch.ListType.get(torch.IntType.get()); + return torch.ListType.create(torch.IntType.get()); } else if (Array.isArray(value) && value.every((item) => Number(item) === item)) { - return torch.ListType.get(torch.FloatType.get()); + return torch.ListType.create(torch.FloatType.get()); } else if (value instanceof torch.Value) { return value.type(); } @@ -3428,6 +3360,8 @@ pytorch.Utility = class { case 'Layout': return 'Layout'; case 'VarType': return type.annotation_str; case 'NoneType': return 'None'; + case 'AnyListType': return 'list'; + case 'AnyTupleType': return 'tuple'; default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); } } diff --git a/test/models.json b/test/models.json index 1fd7133912..b6d782d73c 100644 --- a/test/models.json +++ b/test/models.json @@ -5547,7 +5547,7 @@ "target": "fasterrcnn_resnet50_fpn.pt", "source": "https://github.com/lutzroeder/netron/files/7677467/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]", "format": "TorchScript v1.7", - "link": "https://github.com/lutzroeder/netron/issues/689" + "link": "https://github.com/lutzroeder/netron/issues/1061" }, { "type": "pytorch", @@ -5698,6 +5698,13 @@ "format": "PyTorch Package v1.9", "link": "https://github.com/lutzroeder/netron/issues/928" }, + { + "type": "pytorch", + "target": "m4-sWE-0.1B.script.pt", + "source": "https://github.com/user-attachments/files/17967188/m4-sWE-0.1B.script.pt.zip[m4-sWE-0.1B.script.pt]", + "format": "TorchScript v1.6", + "link": "https://github.com/lutzroeder/netron/issues/1061" + }, { "type": "pytorch", "target": "mask_depthwise_conv.pt", @@ -5720,6 +5727,13 @@ "assert": "model.graphs[0].nodes[0].inputs.length == 1", "link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed" }, + { + "type": "pytorch", + "target": "mask_rcnn.pt", + "source": "https://github.com/user-attachments/files/17966950/mask_rcnn.pt.zip[mask_rcnn.pt]", + "format": "TorchScript v1.7", + "link": "https://github.com/lutzroeder/netron/issues/1061" + }, { "type": "pytorch", "target": "mcunet-5fps.pkl", @@ -6070,7 +6084,7 @@ { "type": "pytorch", "target": "pyg_model.pt", - "source": "https://github.com/lutzroeder/netron/files/10369483/pyg_model.zip[pyg_model.pt]", + "source": "https://github.com/user-attachments/files/17969647/pyg_model.pt.zip[pyg_model.pt]", "format": "TorchScript v1.7", "link": "https://github.com/lutzroeder/netron/issues/546" },