diff --git a/source/python.js b/source/python.js index ad8d2aff3a..47c69a3060 100644 --- a/source/python.js +++ b/source/python.js @@ -2444,9 +2444,7 @@ python.Execution = class { } flatten() { const size = this.shape.reduce((a, b) => a * b, 1); - const value = execution.invoke('numpy.ndarray', [ - [size], this.dtype, this.data, this.offset, this.strides, this.order - ]); + const value = new numpy.ndarray([size], this.dtype, this.data, this.offset, this.strides, this.order); value.flags = this.flags; return value; } @@ -5535,6 +5533,9 @@ python.Execution = class { throw new python.Error("Unsupported 'torch.__isnot__' expression type."); }); this.registerFunction('torch.__not__', (value) => { + if (Number.isInteger(value)) { + value = Boolean(value); + } if (typeof value === 'boolean') { return !value; } @@ -7311,11 +7312,14 @@ python.Execution = class { this.registerType('torch.Graph', class { constructor() { this._unique = 1; - this._nodes = []; - this._block = execution.invoke('torch.Block', [this]); + this._all_nodes = []; + this._all_values = []; + this._all_blocks = []; + this._block = new torch.Block(this); + this._insert_before = this.return_node(); } create(kind) { - return execution.invoke('torch.Node', [this, kind]); + return new torch.Node(this, kind); } inputs() { return this._block.inputs(); @@ -7324,8 +7328,7 @@ python.Execution = class { return this._block.outputs(); } nodes() { - return this._nodes; - // return this._block.nodes(); + return this._block.nodes(); } param_node() { return this._block.param_node(); @@ -7336,6 +7339,40 @@ python.Execution = class { addInput(name) { return this._block.addInput(name); } + insertNode(node) { + node.insertBefore(this._insert_before); + } + insertPoint() { + return this._insert_before; + } + setInsertPoint(node) { + if (node instanceof torch.Block) { + node = node.return_node(); + } + this._insert_before = node; + } + get all_nodes() { + return this._all_nodes; + } + freeNode(n) { + const index = this._all_nodes.indexOf(n); + if (index !== -1) { + this._all_nodes.splice(index, 1); + } + } + freeValue(v) { + v.setDebugName(''); + const index = this._all_values.indexOf(v); + if (index !== -1) { + this._all_values.splice(index, 1); + } + } + freeBlock(b) { + const index = this._all_blocks.indexOf(b); + if (index !== -1) { + this._all_blocks.splice(index, 1); + } + } }); this.registerType('torch.Block', class { constructor(graph) { @@ -7343,6 +7380,10 @@ python.Execution = class { this._graph = graph; this._input = graph.create('prim::Param'); this._output = graph.create('prim::Return'); + this._input.next = this._output; + this._input.prev = this._output; + this._output.next = this._input; + this._output.prev = this._input; } param_node() { return this._input; @@ -7356,6 +7397,15 @@ python.Execution = class { outputs() { return this._output.inputs(); } + nodes() { + const nodes = []; + let current = this._input; + do { + nodes.push(current); + current = current.next; + } while (current !== this._input); + return nodes; + } addInput(name) { const value = this._input.addOutput(); value.setDebugName(name || ''); @@ -7365,16 +7415,30 @@ python.Execution = class { this._output.addInput(value); return this.outputs().length - 1; } + destroy() { + this._output.removeAllInputs(); + for (const n of this.nodes()) { + n.destroy(); + } + this._output.destroy(); + this._input.destroy(); + this._graph.freeBlock(this); + } }); this.registerType('torch.Node', class { constructor(graph, kind) { this._graph = graph; - this._graph._nodes.push(this); this._kind = kind; this._values = new Map(); this._inputs = []; this._outputs = []; this._blocks = []; + this._graph.all_nodes.push(this); + this._prev = null; + this._next = null; + } + owningGraph() { + return this._graph; } kind() { return this._kind; @@ -7421,6 +7485,86 @@ python.Execution = class { this._blocks.push(block); return block; } + get prev() { + return this._prev; + } + set prev(value) { + this._prev = value; + } + get next() { + return this._next; + } + set next(value) { + this._next = value; + } + insertBefore(n) { + this.insertAfter(n.prev); + return this; + } + insertAfter(n) { + // this.owning_block_ = n->owningBlock(); + const next = n.next; + n.next = this; + this.prev = n; + this.next = next; + next.prev = this; + // assignTopoPosition(); + } + dropInput(i) { + const input = this._inputs[i]; + const uses = this._inputs[i].uses(); + for (let i = uses.length - 1; i >= 0; i--) { + const use = uses[i]; + if (use.user === this) { + uses.splice(i, 1); + } + } + this._inputs[i] = null; + return input; + } + eraseOutput(i) { + this._op = null; + const v = this._outputs[i]; + this._outputs.splice(i, 1); + this.owningGraph().freeValue(v); + } + eraseBlock(i) { + this._op = null; + const n = this._blocks[i]; + this._blocks.splice(i, 1); + n.destroy(); + } + removeAllInputs() { + for (let i = this._inputs.length - 1; i >= 0; i--) { + this.dropInput(i); + } + this._inputs.splice(0, this._inputs.length); + } + inBlockList() { + return this.next !== null; + } + removeFromList() { + this._owning_block = null; + const next = this.next; + const prev = this.prev; + prev.next = next; + next.prev = prev; + this.next = null; + this.prev = null; + } + destroy() { + while (this.outputs().length > 0) { + this.eraseOutput(this.outputs().length - 1); + } + while (this.blocks().length > 0) { + this.eraseBlock(this.blocks().length - 1); + } + this.removeAllInputs(); + if (this.inBlockList()) { + this.removeFromList(); + } + this._graph.freeNode(this); + } s_(name, value) { this._values.set(name, [value, 's']); } @@ -7622,7 +7766,7 @@ python.Execution = class { const data = buffer.slice(offset, offset + length); storage._set_cdata(data); } - const tensor = execution.invoke('torch._utils._rebuild_tensor', [storage, 0, shape, strides]); + const tensor = torch._utils._rebuild_tensor(storage, 0, shape, strides); tensor.name = constant.data.key; return tensor; }); @@ -7634,7 +7778,7 @@ python.Execution = class { if (this._reader.has_record('attributes.pkl')) { const stream = this._reader.get_record('attributes.pkl'); const buffer = stream.peek(); - const unpickler = execution.invoke('pickle.Unpickler', [buffer]); + const unpickler = new pickle.Unpickler(buffer); const obj = unpickler.load(); attributes.push(...obj); } @@ -9360,7 +9504,7 @@ python.Execution = class { this.registerFunction('torch._inductor.compile_fx.compile_fx'); this.registerFunction('torch_utils.persistence._reconstruct_persistent_obj', (meta) => { const name = `_imported_module_${Math.floor(Math.random() * 10000)}`; - const module = execution.invoke('types.ModuleType', [name]); + const module = new types.ModuleType(name); execution.register('sys').modules.set(name, module); const context = new python.Execution.Context(module, null); execution.exec(meta.get('module_src'), context); @@ -9903,7 +10047,7 @@ python.Execution = class { this.registerType('fastai.basic_train.Learner', class {}); this.registerType('fastai.basic_train.Recorder', class {}); this.registerFunction('fastai.torch_core._fa_rebuild_tensor', (cls, ...args) => { - const tensor = self.invoke('torch._utils._rebuild_tensor_v2', args); + const tensor = torch._utils._rebuild_tensor_v2(...args); return self.invoke(cls, tensor); }); this.registerFunction('fastai.torch_core.trainable_params'); diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index e94bd65e27..840e3fbb32 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -423,13 +423,16 @@ "name": "aten::_dim_arange(Tensor like, int dim) -> Tensor" }, { - "name": "aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> Tensor" + "name": "aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> Tensor", + "category": "Quantization" }, { - "name": "aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1., *, Tensor(a!) out) -> Tensor(a!)" + "name": "aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1., *, Tensor(a!) out) -> Tensor(a!)", + "category": "Quantization" }, { - "name": "aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> (Tensor, Tensor, Tensor)" + "name": "aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> (Tensor, Tensor, Tensor)", + "category": "Quantization" }, { "name": "aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor" @@ -2159,22 +2162,28 @@ "name": "aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)" }, { - "name": "aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor" + "name": "aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor", + "category": "Quantization" }, { - "name": "aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor" + "name": "aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor", + "category": "Quantization" }, { - "name": "aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor" + "name": "aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor", + "category": "Quantization" }, { - "name": "aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)" + "name": "aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", + "category": "Quantization" }, { - "name": "aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))" + "name": "aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", + "category": "Quantization" }, { - "name": "aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor" + "name": "aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor", + "category": "Quantization" }, { "name": "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor", diff --git a/source/pytorch.js b/source/pytorch.js index 001a94f1c4..459b6560a3 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -176,7 +176,7 @@ pytorch.Graph = class { node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) { continue; } - if (node.kind() === 'prim::Constant' && node.outputs().length === 1 && node.outputs()[0].uses().length === 1) { + if (node.kind() === 'prim::Constant' && node.outputs().length <= 1 && node.outputs()[0].uses().length <= 1) { continue; } this.nodes.push(new pytorch.Node(metadata, null, null, node, initializers, values)); @@ -507,7 +507,7 @@ pytorch.Node = class { argument = new pytorch.Argument(name, input.value, 'string'); } else if (input.node() && input.uses().length === 1 && input.node().kind() === 'prim::Constant') { let [type, value] = getAttribute(input.node(), 'value'); - const valueType = input.node().outputs()[0].type(); + const valueType = input.type(); if (valueType) { type = pytorch.Utility.toType(valueType); if (type === 'boolean') { @@ -576,6 +576,13 @@ pytorch.Node = class { const argument = new pytorch.Argument(name, args); this.outputs.push(argument); } + const blocks = node.blocks(); + for (let i = 0; i < blocks.length; i++) { + const name = `block${i.toString()}`; + const graph = { name: '', nodes: [] }; // new pytorch.Graph(metadata, null, name, blocks[i]); + const argument = new pytorch.Argument(name, graph, 'graph'); + this.inputs.push(argument); + } } else if (pytorch.Utility.isInstance(obj, 'torch.fx.node.Node')) { if (obj.op === 'call_function') { const name = obj.target.name; @@ -1540,7 +1547,8 @@ pytorch.Execution = class extends python.Execution { execution.variable(this.serialized_model_tensor); this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1; const type = new pytorch.nnapi.Graph(this.serialized_model); - const node = execution._graph.create(type); + const node = execution.graph.create(type); + execution.graph.insertNode(node); for (const tensor of inputs) { const value = execution.variable(tensor); node.addInput(value); @@ -1647,6 +1655,7 @@ pytorch.Execution = class extends python.Execution { constant(value) { const torch = this.torch; const node = this.graph.create('prim::Constant'); + this.graph.insertNode(node); let type = null; if (value === null) { node.ival_('value', value); @@ -1815,6 +1824,7 @@ pytorch.Execution = class extends python.Execution { if (target.value.every((item) => item.type === 'id')) { if (value instanceof torch.Value) { const node = this._graph.create('prim::TupleUnpack'); + this.graph.insertNode(node); node.addInput(value); const outputs = []; for (let i = 0; i < target.value.length; i++) { @@ -1835,10 +1845,10 @@ pytorch.Execution = class extends python.Execution { return outputs; } if (target.value.length < value.length) { - throw new python.Error(`ValueError: too many values to unpack (expected ${target.value.length}, actual ${value.length}).`); + throw new pytorch.Error(`ValueError: too many values to unpack (expected ${target.value.length}, actual ${value.length}).`); } if (target.value.length > value.length) { - throw new python.Error(`ValueError: not enough values to unpack (expected ${target.value.length}, actual ${value.length}).`); + throw new pytorch.Error(`ValueError: not enough values to unpack (expected ${target.value.length}, actual ${value.length}).`); } for (let i = 0; i < value.length; i++) { context.set(target.value[i].value, value[i]); @@ -1883,6 +1893,7 @@ pytorch.Execution = class extends python.Execution { if (expression.target.type === 'id' && expression.target.value === 'uninitialized') { const type = this.type(expression.args[0], context); const node = this._graph.create('prim::Uninitialized'); + this.graph.insertNode(node); const value = node.addOutput(); value.setType(type); return value; @@ -1891,6 +1902,7 @@ pytorch.Execution = class extends python.Execution { let value = this.expression(expression.args[1], context); const type = this.type(expression.args[0], context); const node = this._graph.create('prim::unchecked_cast'); + this.graph.insertNode(node); node.addInput(this.variable(value)); value = node.addOutput(); value.setType(type); @@ -1900,6 +1912,7 @@ pytorch.Execution = class extends python.Execution { let value = this.expression(expression.args[1], context); // const type = this.type(expression.args[0]); const node = this._graph.create('prim::isinstance'); + this.graph.insertNode(node); node.addInput(this.variable(value)); value = node.addOutput(); value.setType(torch.BoolType.get()); @@ -1910,6 +1923,7 @@ pytorch.Execution = class extends python.Execution { const target = this.target(expression.target.target, context); // this.expression(expression.target.target, context); if (target instanceof torch.Value && target.type() instanceof torch.ClassType) { const node = this._graph.create('prim::CallMethod'); + this.graph.insertNode(node); const name = this.variable(expression.target.member.value, node); node.addInput(name); const args = expression.args.map((expression) => this.expression(expression, context)); @@ -1934,6 +1948,7 @@ pytorch.Execution = class extends python.Execution { if (type instanceof torch.ListType) { let index = this.expression(expression.arguments.value[0], context); const node = this._graph.create('aten::__getitem__.t'); + this.graph.insertNode(node); node.addInput(target); if (Number.isInteger(index)) { index = this.constant(index); @@ -1946,6 +1961,7 @@ pytorch.Execution = class extends python.Execution { if (type instanceof torch.DictType) { let key = this.expression(expression.arguments.value[0], context); const node = this._graph.create('aten::__getitem__.t'); + this.graph.insertNode(node); node.addInput(target); if (type.getKeyType() instanceof torch.StringType && typeof key === 'string') { const value = new torch.Value(node); @@ -1964,6 +1980,7 @@ pytorch.Execution = class extends python.Execution { if (type instanceof torch.TupleType) { let index = this.expression(expression.arguments.value[0], context); const node = this._graph.create('prim::TupleIndex'); + this.graph.insertNode(node); const value = node.addOutput(); value.setType(type.elements()[index]); node.addInput(target); @@ -1983,6 +2000,7 @@ pytorch.Execution = class extends python.Execution { if (typeof expression.member.value === 'string' && target instanceof torch.Value && target.type() instanceof torch.ClassType) { const type = target.type().findAttribute(expression.member.value); const node = this.graph.create('prim::GetAttr'); + this.graph.insertNode(node); node.s_(expression.member.value); node.addInput(target); const value = node.addOutput(); @@ -1991,12 +2009,13 @@ pytorch.Execution = class extends python.Execution { } return target[expression.member.value]; } - throw new python.Error("Unsupported field expression."); + throw new pytorch.Error('Unsupported field expression.'); } case 'list': { const list = expression.value.map((item) => this.expression(item, context)); if (/* list.length > 0 && */ list.every((item) => pytorch.Utility.isInstance(item, 'torch.Value') || pytorch.Utility.isTensor(item) || Number.isInteger(item) || typeof item === 'string' || item === null)) { const node = this._graph.create('prim::ListConstruct'); + this.graph.insertNode(node); const output = node.addOutput(); for (const item of list) { if (item instanceof torch.Value) { @@ -2027,6 +2046,7 @@ pytorch.Execution = class extends python.Execution { case 'tuple': { const args = expression.value.map((expression) => this.expression(expression, context)); const node = this._graph.create('prim::TupleConstruct'); + this.graph.insertNode(node); const types = []; const elements = []; for (const item of args) { @@ -2073,11 +2093,12 @@ pytorch.Execution = class extends python.Execution { } case 'dict': { const node = this._graph.create('prim::DictConstruct'); + this.graph.insertNode(node); let keyType = null; let valueType = null; for (const pair of expression.value) { if (pair.type !== 'pair') { - throw new python.Error(`Unsupported dict item type '${pair.type}'.`); + throw new pytorch.Error(`Unsupported dict item type '${pair.type}'.`); } const key = this.expression(pair.key, context); const keyValue = this.variable(key, null); @@ -2101,6 +2122,187 @@ pytorch.Execution = class extends python.Execution { return super.expression(expression, context); } + static(expression, context, state) { + const torch = this.torch; + switch (expression.type) { + case 'id': { + switch (expression.value) { + case 'None': return null; + case 'True': return true; + case 'False': return false; + default: { + const value = context.get(expression.value); + if (typeof value === 'number' || typeof value === 'boolean' || typeof value === 'string') { + return value; + } + if (value instanceof torch.Tensor && value.storage() && value.storage().size() !== undefined) { + return value; + } + if (value instanceof torch.Value) { + const node = value.node(); + if (node.kind() === 'prim::Constant') { + state.push(node); + return pytorch.Utility.constant(node, 'value'); + } else if (node.kind() === 'prim::ListConstruct' && node.inputs().every((value) => value instanceof torch.Value && value.node().kind() === 'prim::Constant')) { + state.push(node); + for (const value of node.inputs()) { + state.push(value.node()); + } + return node.inputs().map((value) => pytorch.Utility.constant(value.node(), 'value')); + } else if (node.kind() === 'prim::TupleUnpack') { + const index = node.outputs().indexOf(value); + const input = node.inputs()[0].node(); + if (input.kind() === 'prim::TupleConstruct') { + const value = input.inputs()[index]; + const node = value.node(); + if (node.kind() === 'prim::Constant') { + return pytorch.Utility.constant(node, 'value'); + } + } + } + state.splice(0, state.length); + } + break; + } + } + break; + } + case 'list': { + return expression.value.map((expression) => this.static(expression, context)); + } + case 'string': { + return expression.value.substring(1, expression.value.length - 1); + } + case 'number': { + return Number(expression.value); + } + case '.': { + if (expression.member.type === 'id') { + const target = this.target(expression.target, context); + return target[expression.member.value]; + } + break; + } + case 'call': { + const args = expression.args.map((expression) => this.static(expression, context, state)); + if (args.every((arg) => arg !== undefined)) { + const target = this.target(expression.target, context); + if (typeof target === 'function') { + return target(...args); + } + } + state.splice(0, state.length); + break; + } + default: { + break; + } + } + return undefined; + } + + block(statements, context) { + this.traceIf = false; + if (!this.traceIf) { + return super.block(statements, context); + } + statements = Array.prototype.slice.call(statements); + while (statements.length > 0) { + if (statements.length > 1) { + const containsVariableReference = (queue, value) => { + while (queue.length > 0) { + const obj = queue.shift(); + if (obj && obj.type === 'id' && obj.value === value) { + return true; + } else if (Array.isArray(obj)) { + for (const item of obj) { + if (Array.isArray(item) || (Object(item) === item && item.type)) { + queue.push(item); + } + } + } else if (Object(obj) === obj) { + for (const [key, value] of Object.entries(obj)) { + if (key !== 'identifier') { + if (Array.isArray(value)) { + for (const item of value) { + if (Array.isArray(item) || (Object(item) === item && item.type)) { + queue.push(item); + } + } + } else if (Object(value) === value && value.type) { + queue.push(value); + } + } + } + } + } + return false; + }; + const [assign, condition] = statements; + // _x = + // if _x: + // ... + if (assign.type === '=' && condition.type === 'if' && + assign.target.type === 'id' && condition.test.type === 'id' && + assign.target.value === condition.test.value && + !containsVariableReference(statements.slice(2), condition.test.value)) { + statements.shift(); + statements[0] = { + type: 'if', + test: assign.expression, + body: condition.body, + orelse: condition.orelse + }; + } + } + const [condition] = statements; + if (condition.type === 'if') { + const state = []; + let test = this.static(condition.test, context, state); + const count = new Map(); + for (const node of state) { + if (count.has(node)) { + count.set(node, count.get(node) + 1); + } else { + count.set(node, 1); + } + } + if (count.size > 0 && Array.from(count).every(([node, count]) => node.outputs().length === 1 && node.outputs()[0].uses().length <= count)) { + for (const node of state) { + node.destroy(); + } + } + if (test === null) { + test = false; + } else if (typeof test === 'boolean') { + test = test === true; + } else if (Number.isInteger(test)) { + test = test !== 0; + } else if (typeof test === 'string') { + test = test && test.length > 0; + } + if (test === true) { + statements.shift(); + statements = condition.body.statements.concat(statements); + continue; + } + if (test === false) { + statements.shift(); + statements = condition.orelse.statements.concat(statements); + continue; + } + } + if (statements.length > 0) { + const statement = statements.shift(); + const value = this.statement(statement, context); + if (value !== undefined) { + return value; + } + } + } + return undefined; + } + statement(statement, context) { const torch = this.torch; if (!this.trace) { @@ -2121,27 +2323,54 @@ pytorch.Execution = class extends python.Execution { return undefined; } case 'if': { - const test = this.expression(statement.test, context); - if (test === true || (!this.traceIf && test)) { - const value = this.block(statement.body.statements, context); - if (value !== undefined) { - return value; + if (this.traceIf) { + const test = this.expression(statement.test, context); + if (test instanceof torch.Value && test.type() instanceof torch.BoolType) { + const node = this._graph.create('prim::If'); + this.graph.insertNode(node); + node.addInput(test); + const prev = this._graph.insertPoint(); + const true_block = node.addBlock(); + this._graph.setInsertPoint(true_block); + this.block(statement.body.statements, context); + const false_block = node.addBlock(); + this._graph.setInsertPoint(false_block); + this.block(statement.orelse.statements, context); + this._graph.setInsertPoint(prev); + return undefined; } - return undefined; - } else if (test === false) { - if (statement.orelse) { - const value = this.block(statement.orelse.statements, context); + } else { + const test = this.expression(statement.test, context); + if (test === true || (!this.traceIf && test)) { + const value = this.block(statement.body.statements, context); if (value !== undefined) { return value; } + return undefined; + } else if (test === false) { + if (statement.orelse) { + const value = this.block(statement.orelse.statements, context); + if (value !== undefined) { + return value; + } + } + return undefined; + } else if (test instanceof torch.Value && test.type() instanceof torch.BoolType) { + const node = this._graph.create('prim::If'); + this.graph.insertNode(node); + node.addInput(test); + const prev = this._graph.insertPoint(); + const true_block = node.addBlock(); + this._graph.setInsertPoint(true_block); + this.block(statement.body.statements, context); + const false_block = node.addBlock(); + this._graph.setInsertPoint(false_block); + this.block(statement.orelse.statements, context); + this._graph.setInsertPoint(prev); + return undefined; } - return undefined; - } else if (test instanceof torch.Value && test.type() instanceof torch.BoolType) { - const node = this._graph.create('prim::If'); - node.addInput(test); - return undefined; } - throw new python.Error("Unsupported condition."); + throw new pytorch.Error("Unsupported condition."); } default: { break; @@ -2212,6 +2441,7 @@ pytorch.Execution = class extends python.Execution { const type = this.resolve(identifier); if (type && type.__type__) { const node = this.graph.create('prim::CreateObject'); + this.graph.insertNode(node); const value = node.addOutput(); value.setType(type.__type__); return value; @@ -2224,6 +2454,7 @@ pytorch.Execution = class extends python.Execution { return obj; } const node = this.graph.create('prim::CallMethod'); + this.graph.insertNode(node); node.s_('name', name); node.addInput(obj); const evalArgs = args.map((arg) => this.expression(arg, context)); @@ -2236,9 +2467,10 @@ pytorch.Execution = class extends python.Execution { } const overload = this._overload(target, name, args, context); if (!overload) { - const moduleTarget = this.target(target, context); // this.expression(expression.target.target, context); + const moduleTarget = this.target(target, context); if (moduleTarget instanceof torch.Value && moduleTarget.type() instanceof torch.ClassType) { const node = this.graph.create('prim::CallMethod'); + this.graph.insertNode(node); node.s_('name', name); const evalArgs = args.map((expression) => this.expression(expression, context)); for (const arg of evalArgs) { @@ -2252,6 +2484,7 @@ pytorch.Execution = class extends python.Execution { const [schema, evalArgs] = overload; const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name; const node = this._graph.create(op); + this.graph.insertNode(node); const referencedParameters = []; const parameters = schema.arguments; const varTypes = new Map(); @@ -2304,6 +2537,7 @@ pytorch.Execution = class extends python.Execution { match = true; } else { const list = this._graph.create('prim::ListConstruct'); + this.graph.insertNode(node); for (const arg of v) { const tensor = arg; if (tensor) { @@ -2694,12 +2928,12 @@ pytorch.Execution = class extends python.Execution { } _overload(target, name, args, context) { - const moduleName = pytorch.Utility.target(target); - if (!moduleName) { + const torch = this.torch; + const prefix = pytorch.Utility.target(target); + if (!prefix) { return null; } - const torch = this.torch; - const type = name ? `${moduleName}.${name}` : moduleName; + const type = name ? `${prefix}.${name}` : prefix; // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml let op_name = null; if (type.startsWith('torch.')) { @@ -2721,7 +2955,7 @@ pytorch.Execution = class extends python.Execution { let evalArgs = null; overloads = torch._C._jit_get_schemas_for_operator(op_name); if ((!overloads || overloads.length === 0) && type.startsWith('ops.') && !type.startsWith('ops.prim')) { - const module = this.import(moduleName); + const module = this.import(prefix); if (!module || !module[name]) { const schema = new torch.FunctionSchema(op_name, null, [], [], false, false); for (let i = 0; i < args.length; i++) { @@ -2993,6 +3227,18 @@ pytorch.Utility = class { } } + static constant(node, name) { + const kind = node.kindOf(name); + switch (kind) { + case 's': return node.s(name); + case 'i': return node.i(name); + case 'f': return node.f(name); + case 'ss': return node.ss(name); + case 'ival': return node.ival(name); + default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`); + } + } + static isObjectType(type) { switch (type) { case '__torch__.torch.classes.xnnpack.LinearOpContext':