From 8dccbd87421973a3bf26ea207b8fc3e144d1d334 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Wed, 23 Oct 2024 19:45:30 -0700 Subject: [PATCH] Update pytorch.js (#1061) --- source/python.js | 7 +++- source/pytorch.js | 96 +++++++++++++++++++++++++++++++++++------------ source/view.js | 10 ++--- 3 files changed, 82 insertions(+), 31 deletions(-) diff --git a/source/python.js b/source/python.js index fb8c3d21ec..05ba396541 100644 --- a/source/python.js +++ b/source/python.js @@ -6168,7 +6168,7 @@ python.Execution = class { kind() { return this._kind; } - annotation_str() { + get annotation_str() { return this._annotation_str; } equals(/* rhs */) { @@ -6357,7 +6357,10 @@ python.Execution = class { return torch.AnyType.value; } str() { - return 'AnyType'; + return 'Any'; + } + __str__() { + return 'Any'; } }); this.registerType('torch.NoneType', class extends torch.Type { diff --git a/source/pytorch.js b/source/pytorch.js index ca8839daf1..df3c76160b 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -447,6 +447,20 @@ pytorch.Node = class { module = null; } } + const mapTensor = (input) => { + let initializer = null; + let identifier = input.unique().toString(); + if (input.value) { + const value = input.value; + const hide = value.__parent__ ? value.__parent__.__hide__ : true; + initializer = hide ? initializers.get(value) : null; + identifier = initializer ? initializer.name : identifier; + } + if (initializer) { + return new pytorch.Value(identifier, null, null, initializer); + } + return values.map(identifier); + }; for (let i = 0; i < inputs.length; i++) { const input = inputs[i]; const arg = schema && schema.arguments && i < schema.arguments.length ? schema.arguments[i] : null; @@ -480,16 +494,23 @@ pytorch.Node = class { argument = new pytorch.Argument(name, input.value, type || 'attribute'); } else if (pytorch.Utility.isInstance(input.type(), 'torch.ListType')) { if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && - input.node().inputs().every((value) => pytorch.Utility.isInstance(value, 'torch.Value') || pytorch.Utility.isInstance(value.type(), 'torch.IntType') || pytorch.Utility.isInstance(value.type(), 'torch.FloatType') || pytorch.Utility.isInstance(value.type(), 'torch.StringType') || pytorch.Utility.isInstance(value.type(), 'torch.ComplexType') || pytorch.Utility.isInstance(value.type(), 'torch.TensorType'))) { + input.node().inputs().every((value) => pytorch.Utility.isInstance(value, 'torch.Value') || pytorch.Utility.isInstance(value.type(), 'torch.IntType') || pytorch.Utility.isInstance(value.type(), 'torch.FloatType') || pytorch.Utility.isInstance(value.type(), 'torch.StringType') || pytorch.Utility.isInstance(value.type(), 'torch.ComplexType') || pytorch.Utility.isInstance(value.type(), 'torch.TensorType'))) { const list = input.node().inputs(); const args = list.map((value) => { + if (pytorch.Utility.isTensor(value.value)) { + return mapTensor(value); + } + if (value.uses().length === 1 && value.node().kind() === 'prim::Constant') { + return getAttribute(value.node(), 'value')[1]; + } if (value.uses().length === 1 && value.node() === input.node() && value.value !== undefined) { return value.value; } const identifier = value.unique().toString(); return values.map(identifier); }); - argument = new pytorch.Argument(name, args, pytorch.Utility.toType(input.type())); + const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? null : pytorch.Utility.toType(input.type()); + argument = new pytorch.Argument(name, args, type); } else { const identifier = input.unique().toString(); argument = new pytorch.Argument(name, [values.map(identifier)]); @@ -497,7 +518,14 @@ pytorch.Node = class { } else if (pytorch.Utility.isInstance(input.type(), 'torch.StringType') && typeof input.value === 'string') { argument = new pytorch.Argument(name, input.value, 'string'); } else if (input.node() && input.uses().length === 1 && input.node().kind() === 'prim::Constant') { - const [type, value] = getAttribute(input.node(), 'value'); + let [type, value] = getAttribute(input.node(), 'value'); + const valueType = input.node().outputs()[0].type(); + if (valueType) { + type = pytorch.Utility.toType(valueType); + if (type === 'boolean') { + value = Boolean(value); + } + } argument = new pytorch.Argument(name, value, type || 'attribute'); } else { const identifier = input.unique().toString(); @@ -2357,6 +2385,14 @@ pytorch.jit.Execution = class extends pytorch.Execution { } const torch = this.torch; switch (expression.type) { + case 'id': { + switch (expression.value) { + case 'True': return this.constant(true); + case 'False': return this.constant(false); + default: break; + } + return super.expression(expression, context); + } case '=': { const target = expression.target; if (target.type === 'id') { @@ -2559,15 +2595,11 @@ pytorch.jit.Execution = class extends pytorch.Execution { node.addInput(item); output.setType(torch.ListType.get(item.type())); } else if (Number.isInteger(item)) { - const value = new torch.Value(node); - value.value = item; - value.setType(torch.IntType.get()); + const value = this.constant(item); node.addInput(value); output.setType(torch.ListType.get(torch.IntType.get())); } else if (typeof item === 'string') { - const value = new torch.Value(node); - value.value = item; - value.setType(torch.StringType.get()); + const value = this.constant(item); node.addInput(value); output.setType(torch.ListType.get(torch.StringType.get())); } else if (pytorch.Utility.isTensor(item)) { @@ -2681,16 +2713,27 @@ pytorch.jit.Execution = class extends pytorch.Execution { return undefined; } case 'if': { - if (!this.traceIf) { - return super.statement(statement, context); - } - /* const test = this.expression(statement.test, context); - const n = this._graph.create('prim::If'); - const true_block = n.addBlock(); - const false_block = n.addBlock(); - */ - return undefined; + if (test instanceof torch.Value) { + const node = this._graph.create('prim::If'); + node.addInput(test); + } + if (test === true || test) { + const value = this.block(statement.body.statements, context); + if (value !== undefined) { + return value; + } + return undefined; + } else if (test === false) { + if (statement.orelse) { + const value = this.block(statement.orelse.statements, context); + if (value !== undefined) { + return value; + } + } + return undefined; + } + throw new python.Error("Unsupported condition."); } default: { break; @@ -2806,7 +2849,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { const varTypes = new Map(); varTypes.map = function(type) { if (type.kind() === 'VarType') { - const key = type.annotation_str(); + const key = type.annotation_str; if (!varTypes.has(key)) { throw new pytorch.Error(`Unknown var type '${key}'.`); } @@ -2895,7 +2938,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { if (match) { node.addInput(input); if (type.kind() === 'VarType') { - const key = type.annotation_str(); + const key = type.annotation_str; if (input instanceof torch.Value && input.type()) { varTypes.set(key, input.type()); } else if (input instanceof torch.Value && Number.isInteger(input.value)) { @@ -2904,7 +2947,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { // throw new pytorch.Error("Unknown value type 't'."); } if (type instanceof torch.ListType && type.getElementType().kind() === 'VarType') { - const key = type.getElementType().annotation_str(); + const key = type.getElementType().annotation_str; if (input instanceof torch.Value && input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.ListType) { varTypes.set(key, input.type().getElementType().getElementType()); } else if (input instanceof torch.Value && input.type() instanceof torch.ListType) { @@ -2921,7 +2964,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { } } if (type instanceof torch.DictType && type.getValueType().kind() === 'VarType') { - const key = type.getValueType().annotation_str(); + const key = type.getValueType().annotation_str; if (input instanceof torch.Value && input.type() instanceof torch.DictType) { varTypes.set(key, input.type().getValueType()); } else if (input.value && Object.values(input.value).every((item) => pytorch.Utility.isTensor(item))) { @@ -2931,7 +2974,7 @@ pytorch.jit.Execution = class extends pytorch.Execution { } } if (type instanceof torch.ListType && type.getElementType() instanceof torch.TupleType && type.getElementType().elements().length === 2 && type.getElementType().elements()[1].kind() === 'VarType') { - const key = type.getElementType().elements()[1].annotation_str(); + const key = type.getElementType().elements()[1].annotation_str; if (input instanceof torch.Value && input.type() instanceof torch.ListType && input.type().getElementType() instanceof torch.TupleType) { const elements = input.type().getElementType().elements(); if (elements.length === 2) { @@ -3341,6 +3384,9 @@ pytorch.jit.Execution = class extends pytorch.Execution { case 'Tensor[]': return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) || (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType); + case 'Tensor?[]': + return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) || + (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.OptionalType && obj.type().getElementType().getElementType() instanceof torch.TensorType); case 'Scalar': return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) || @@ -3424,6 +3470,8 @@ pytorch.jit.Execution = class extends pytorch.Execution { return false; case 'complex': return obj instanceof torch.Value && obj.type() instanceof torch.ComplexType; + case 'Any': + return true; case 'Any[]': if (Array.isArray(obj)) { return true; @@ -4450,7 +4498,7 @@ pytorch.Utility = class { return `boolean`; } if (pytorch.Utility.isInstance(type, 'torch.TensorType')) { - return `Tensor`; + return `tensor`; } if (pytorch.Utility.isInstance(type, 'torch.TupleType')) { return `(${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')})`; diff --git a/source/view.js b/source/view.js index 145cb19534..80715b896f 100644 --- a/source/view.js +++ b/source/view.js @@ -2913,13 +2913,13 @@ view.ArgumentView = class extends view.Control { if (argument.type === 'attribute') { this._source = 'attribute'; } - if (argument.type === 'tensor') { - value = [{ type: value.type, initializer: value }]; - } else if (argument.type === 'tensor[]') { + if (argument.type === 'tensor' || argument.type === 'tensor?') { + value = [value === null ? value : { type: value.type, initializer: value }]; + } else if (argument.type === 'tensor[]' || argument.type === 'tensor?[]') { value = value.map((value) => value === null ? value : { type: value.type, initializer: value }); } this._source = typeof type === 'string' && !type.endsWith('*') ? 'attribute' : this._source; - if (this._source === 'attribute' && type !== 'tensor' && type !== 'tensor[]') { + if (this._source === 'attribute' && type !== 'tensor' && type !== 'tensor?' && type !== 'tensor[]' && type !== 'tensor?[]') { this._source = 'attribute'; const item = new view.PrimitiveView(context, argument); this._items.push(item); @@ -2929,7 +2929,7 @@ view.ArgumentView = class extends view.Control { } else { const values = value; for (const value of values) { - const emit = values.length === 1 && value.initializer; + const emit = values.length === 1 && value && value.initializer; const target = emit ? argument : value; if (value === null) { const item = new view.TextView(this._view, null);