From f08b346ae80bf8e29535297dba43612aeb4e0bad Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 7 Dec 2024 16:55:56 -0800 Subject: [PATCH] Update pytorch.js (#1061) --- source/python.js | 1 + source/pytorch.js | 94 +++++++++++++++++++++++++---------------------- 2 files changed, 51 insertions(+), 44 deletions(-) diff --git a/source/python.js b/source/python.js index e1665c2f3a..9c5493b7d8 100644 --- a/source/python.js +++ b/source/python.js @@ -8432,6 +8432,7 @@ python.Execution = class { } setSourceRange(r) { this._source_range = r; + return this; } sourceRange() { return this._source_range; diff --git a/source/pytorch.js b/source/pytorch.js index 196d93a600..50e73c1ec2 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1686,7 +1686,7 @@ pytorch.Execution = class extends python.Execution { } else if (obj instanceof torch.Value) { value = obj; } else { - value = new torch.Value(node ? node : this._graph); + value = new torch.Value(node ? node : this.graph); } if (pytorch.Utility.isTensor(obj)) { value.value = obj; @@ -1781,6 +1781,10 @@ pytorch.Execution = class extends python.Execution { return super.target(expr, context); } + create(kind, loc, n_outputs) { + return this._graph.create(kind, n_outputs).setSourceRange(loc); + } + expression(expr, context, typehint) { if (!this.trace) { return super.expression(expr, context); @@ -1818,13 +1822,13 @@ pytorch.Execution = class extends python.Execution { if (value.type() instanceof torch.TupleType) { const node = this._graph.createTupleUnpack(value); node.setSourceRange(expr.location); - this.graph.insertNode(node); + this._graph.insertNode(node); outputs = node.outputs(); } else if (value.type() instanceof torch.ListType) { const size = target.elts.length; const node = this._graph.createListUnpack(value, size); node.setSourceRange(expr.location); - this.graph.insertNode(node); + this._graph.insertNode(node); outputs = node.outputs(); } if (outputs === null) { @@ -1897,7 +1901,7 @@ pytorch.Execution = class extends python.Execution { const type = this.type(expr.args[0]); const node = this._graph.createUninitialized(type); node.setSourceRange(expr.location); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } if (func instanceof ast.Name && func.id === 'unchecked_cast') { @@ -1906,7 +1910,7 @@ pytorch.Execution = class extends python.Execution { value = this.variable(value); } const type = this.type(expr.args[0]); - return this.graph.insertUncheckedCast(value, type); + return this._graph.insertUncheckedCast(value, type); } if (func instanceof ast.Name && func.id === 'isinstance') { const value = this.expression(expr.args[0], context); @@ -1918,12 +1922,12 @@ pytorch.Execution = class extends python.Execution { } const v = this.variable(value); // remove const node = this._graph.createIsInstance(v, types); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } if (func.attr === 'tolist' && expr.args.length === 0) { const target = this.target(func.value, context); - return this.graph.insertToList(target, typehint); + return this._graph.insertToList(target, typehint); } return super.expression(expr, context); } @@ -1942,14 +1946,14 @@ pytorch.Execution = class extends python.Execution { index = this._graph.insertConstant(index); } const node = this._graph.create('aten::__getitem__.t', [value, index]); - this.graph.insertNode(node); + this._graph.insertNode(node); node.output().setType(type.getElementType()); return node.output(); } if (type instanceof torch.DictType) { let key = this.expression(elt, context); const node = this._graph.create('aten::__getitem__.t', [value]); - this.graph.insertNode(node); + this._graph.insertNode(node); if (type.getKeyType() instanceof torch.StringType && typeof key === 'string') { const value = new torch.Value(node); value.value = key; @@ -1971,7 +1975,7 @@ pytorch.Execution = class extends python.Execution { const output_type = type.elements()[index]; index = this._graph.insertConstant(index); const node = this._graph.createTupleIndex(value, index, output_type); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } } @@ -1983,7 +1987,7 @@ pytorch.Execution = class extends python.Execution { const attr = expr.attr; if (target instanceof torch.Value && target.type() instanceof torch.ClassType) { const node = this._graph.createGetAttr(target, attr); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } return target[attr]; @@ -2012,7 +2016,7 @@ pytorch.Execution = class extends python.Execution { } const contained_type = typehint ? typehint.getElementType() : item_type; const node = this._graph.createList(contained_type, values); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } break; @@ -2033,7 +2037,7 @@ pytorch.Execution = class extends python.Execution { } const node = this._graph.createTuple(values); node.setSourceRange(expr.location); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } case 'Dict': { @@ -2058,7 +2062,7 @@ pytorch.Execution = class extends python.Execution { const key_type = typehint ? typehint.getKeyType() : keyType; const value_type = typehint ? typehint.getValueType() : valueType; const node = this._graph.createDict(key_type, value_type, keys, values); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } default: { @@ -2367,9 +2371,8 @@ pytorch.Execution = class extends python.Execution { return value.type(); }; this.variables(condition, condition); - const node = this._graph.create('prim::If', 0); - node.setSourceRange(stmt.location); - this.graph.insertNode(node); + const node = this.create('prim::If', stmt.location, 0); + this._graph.insertNode(node); node.addInput(test); const prev = this._graph.insertPoint(); const true_block = node.addBlock(); @@ -2428,9 +2431,14 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error("Unsupported condition."); } if (stmt instanceof ast.For) { - const node = this._graph.create('prim::Loop', 0); - node.setSourceRange(stmt.location); - this.graph.insertNode(node); + const range = stmt.location; + const node = this.create('prim::Loop', range, 0); + this._graph.insertNode(node); + const itrs = stmt.iter instanceof ast.Tuple ? stmt.iter.elts : [stmt.iter]; + // const targets = stmt.target instanceof ast.Tuple ? stmt.target.elts : [stmt.target]; + if (itrs.length !== 1) { + throw new pytorch.Error('List of iterables is not supported currently.'); + } const loop = stmt; if (loop.target instanceof ast.Name && loop.iter instanceof ast.Tuple === false) { const range = this.expression(loop.iter, context); @@ -2449,9 +2457,8 @@ pytorch.Execution = class extends python.Execution { } } if (stmt instanceof ast.While) { - const node = this._graph.create('prim::Loop', 0); - node.setSourceRange(stmt.location); - this.graph.insertNode(node); + const node = this._graph.create('prim::Loop', stmt.location, 0); + this._graph.insertNode(node); const test = this.expression(stmt.test, context); if (test) { const value = this.block(stmt.body, context); @@ -2577,9 +2584,9 @@ pytorch.Execution = class extends python.Execution { if (identifier) { const type = this._resolver.resolveType(identifier); if (type) { - const node = this.graph.createObject(type); + const node = this._graph.createObject(type); node.setSourceRange(location); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } } @@ -2590,9 +2597,9 @@ pytorch.Execution = class extends python.Execution { if (args.length === 0) { return obj; } - const node = this.graph.create('prim::CallMethod', 0); + const node = this._graph.create('prim::CallMethod', 0); node.setSourceRange(location); - this.graph.insertNode(node); + this._graph.insertNode(node); node.s_('name', name); node.addInput(obj); const evalArgs = args.map((arg) => this.expression(arg, context)); @@ -2608,8 +2615,8 @@ pytorch.Execution = class extends python.Execution { if (!overload) { const moduleTarget = this.target(target, context); if (moduleTarget instanceof torch.Value && moduleTarget.type() instanceof torch.ClassType) { - const node = this.graph.create('prim::CallMethod'); - this.graph.insertNode(node); + const node = this._graph.create('prim::CallMethod'); + this._graph.insertNode(node); node.s_('name', name); const evalArgs = args.map((expression) => this.expression(expression, context)); for (const arg of evalArgs) { @@ -2627,12 +2634,12 @@ pytorch.Execution = class extends python.Execution { const values = evalArgs.map((arg) => this.variable(arg)); const node = this._graph.createTuple(values, type); node.setSourceRange(location); - this.graph.insertNode(node); + this._graph.insertNode(node); return node.output(); } if (type instanceof torch.ClassType) { - const node = this.graph.create('prim::CallMethod'); - this.graph.insertNode(node); + const node = this._graph.create('prim::CallMethod'); + this._graph.insertNode(node); node.s_('name', name); const evalArgs = args.map((expression) => this.expression(expression, context)); for (const arg of evalArgs) { @@ -2646,9 +2653,8 @@ pytorch.Execution = class extends python.Execution { } const [schema, evalArgs] = overload; const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name; - const node = this._graph.create(op, 0); - node.setSourceRange(location); - this.graph.insertNode(node); + const node = this.create(op, location, 0); + this._graph.insertNode(node); const referencedParameters = []; const parameters = schema.arguments; const varTypes = new Map(); @@ -2666,7 +2672,7 @@ pytorch.Execution = class extends python.Execution { let index = 0; while (position < evalArgs.length) { if (index >= parameters.length) { - if (schema.name.startsWith('_caffe2::') || schema.is_vararg) { + if (schema.is_vararg) { break; } throw new pytorch.Error('Invalid parameter length.'); @@ -2694,25 +2700,25 @@ pytorch.Execution = class extends python.Execution { } else if (type instanceof torch.ListType && type.getElementType() instanceof torch.TensorType) { const v = evalArgs[position]; if ((v instanceof torch.Value && v.type() instanceof torch.ListType && v.type().getElementType() instanceof torch.TensorType) || - (Array.isArray(v) && v.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) { + (v === null || Array.isArray(v) && v.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) { position++; if (v instanceof torch.Value) { input = v; match = true; } else { - const list = this._graph.create('prim::ListConstruct'); - this.graph.insertNode(node); - for (const arg of v) { + const values = []; + for (const arg of v || []) { const tensor = arg; if (tensor) { tensor.__count__ = (tensor.__count__ || 0) + 1; } const value = this.variable(tensor); value.setType(torch.TensorType.get()); - list.addInput(value); + values.push(value); } - list.output().setType(torch.ListType.create(torch.TensorType.get())); - input = list.output(); + const node = this._graph.createList(torch.TensorType.get(), values); + this._graph.insertNode(node); + input = node.output(); match = true; } } else { @@ -3146,7 +3152,7 @@ pytorch.Execution = class extends python.Execution { let index = 0; while (position < evalArgs.length) { if (index >= parameters.length) { - next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg; + next = !schema.is_vararg; break; } const arg = parameters[index];