From 2414e0668df34df4fbd6a0caf32c9d6875566bd2 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Fri, 16 Dec 2022 10:26:16 -0500 Subject: [PATCH] Update pytorch.js --- source/pytorch.js | 264 +++++++++++++++++++++++----------------------- 1 file changed, 131 insertions(+), 133 deletions(-) diff --git a/source/pytorch.js b/source/pytorch.js index c6f6f527b5..21ce0a4c4b 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -55,146 +55,146 @@ pytorch.Graph = class { this._outputs = []; this._groups = true; this._name = name || ''; - switch (module.__type__) { - case 'script': { - const traced = module.trace(); - const initializers = new Map(); - const constants = module.execution.builtins.CONSTANTS; - if (constants) { - for (const entry of Object.entries(constants)) { - const name = 'CONSTANTS.' + entry[0]; - const value = entry[1]; - if (pytorch.Utility.isTensor(value)) { - const initializer = new pytorch.Tensor(name, value); - initializers.set(value, initializer); - } - else if (value && value.__class__ && value.__class__.__module__ && value.__class__.__name__) { - const type = value.__class__.__module__ + '.' + value.__class__.__name__; - switch (type) { - case '__torch__.torch.classes.xnnpack.LinearOpContext': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - for (const entry of Object.entries(value)) { - const key = entry[0]; - const value = entry[1]; - if (pytorch.Utility.isTensor(value)) { - initializers.set(value, new pytorch.Tensor(name + '.' + key, value)); - } + if (module instanceof pytorch.jit.ScriptModule) { + const traced = module.trace(); + const initializers = new Map(); + const constants = module.execution.builtins.CONSTANTS; + if (constants) { + for (const entry of Object.entries(constants)) { + const name = 'CONSTANTS.' + entry[0]; + const value = entry[1]; + if (pytorch.Utility.isTensor(value)) { + initializers.set(value, new pytorch.Tensor(name, value)); + } + else if (value && value.__class__ && value.__class__.__module__ && value.__class__.__name__) { + const type = value.__class__.__module__ + '.' + value.__class__.__name__; + switch (type) { + case '__torch__.torch.classes.xnnpack.LinearOpContext': + case '__torch__.torch.classes.xnnpack.Conv2dOpContext': + case '__torch__.torch.classes.quantized.LinearPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': + for (const entry of Object.entries(value)) { + const key = entry[0]; + const value = entry[1]; + if (pytorch.Utility.isTensor(value)) { + initializers.set(value, new pytorch.Tensor(name + '.' + key, value)); } - break; - default: - throw new pytorch.Error("Unsupported constant context '" + type + "'."); - } - } - else { - throw new pytorch.Error('Unsupported constant.'); + } + break; + default: + throw new pytorch.Error("Unsupported constant context '" + type + "'."); } } - } - const queue = [ module.data ]; - while (queue.length > 0) { - const module = queue.shift(); - if (module.__class__ && module.__class__.__module__ === '__torch__.torch.classes._nnapi' && module.__class__.__name__ === 'Compilation') { - continue; + else { + throw new pytorch.Error('Unsupported constant.'); } - for (const entry of Object.entries(module)) { - const key = entry[0]; - if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') { - const obj = entry[1]; - if (!Array.isArray(obj) && obj === Object(obj)) { - if (pytorch.Utility.isTensor(obj)) { - const parameter = obj; - parameter.__parent__ = module; - if (!parameter.initializer && parameter.storage()) { - if (parameter.__count__ === undefined || parameter.__count__ === 1) { - initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); - } + } + } + const queue = [ module.data ]; + while (queue.length > 0) { + const module = queue.shift(); + if (module.__class__ && module.__class__.__module__ === '__torch__.torch.classes._nnapi' && module.__class__.__name__ === 'Compilation') { + continue; + } + for (const entry of Object.entries(module)) { + const key = entry[0]; + if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') { + const obj = entry[1]; + if (!Array.isArray(obj) && obj === Object(obj)) { + if (pytorch.Utility.isTensor(obj)) { + const parameter = obj; + parameter.__parent__ = module; + if (!parameter.initializer && parameter.storage()) { + if (parameter.__count__ === undefined || parameter.__count__ === 1) { + initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); } } - else if (obj && obj.__class__) { - obj.__parent__ = module; - if (!obj.__id__) { - obj.__id__ = key; - } - queue.push(obj); + } + else if (obj && obj.__class__) { + obj.__parent__ = module; + if (!obj.__id__) { + obj.__id__ = key; } + queue.push(obj); } } } } - if (traced) { - const graph = module.graph; - for (const value of graph.inputs()) { - const identifier = value.unique().toString(); - const name = value.debugName() || identifier; - this._inputs.push(new pytorch.Parameter(name, true, [ - new pytorch.Argument(identifier, null, null) - ])); + } + if (traced) { + const graph = module.graph; + for (const value of graph.inputs()) { + const identifier = value.unique().toString(); + const name = value.debugName() || identifier; + this._inputs.push(new pytorch.Parameter(name, true, [ + new pytorch.Argument(identifier, null, null) + ])); + } + for (const value of graph.outputs()) { + const identifier = value.unique().toString(); + this._outputs.push(new pytorch.Parameter(identifier, true, [ + new pytorch.Argument(identifier, null, null) + ])); + } + for (const node of graph.nodes()) { + if (node === graph.param_node() || + node === graph.return_node()) { + continue; } - for (const value of graph.outputs()) { - const identifier = value.unique().toString(); - this._outputs.push(new pytorch.Parameter(identifier, true, [ - new pytorch.Argument(identifier, null, null) - ])); + if (node.kind() === 'prim::ListConstruct' && + node.outputs().length === 1 && + node.outputs().every((output) => output.uses().length === 1) && + node.inputs().every((input) => pytorch.Utility.isTensor(input.value))) { + continue; } - for (const node of graph.nodes()) { - if (node === graph.param_node() || - node === graph.return_node()) { - continue; - } - if (node.kind() === 'prim::ListConstruct' && - node.outputs().length === 1 && - node.outputs().every((output) => output.uses().length === 1) && - node.inputs().every((input) => pytorch.Utility.isTensor(input.value))) { - continue; - } - if (node.kind() === 'prim::ListUnpack' && - node.inputs().length === 1 && - node.inputs().every((input) => input.uses().length === 1) && - node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) { - continue; - } - const item = { - type: node.kind(), - node: node - }; - this._nodes.push(new pytorch.Node(metadata, '', item, initializers)); + if (node.kind() === 'prim::ListUnpack' && + node.inputs().length === 1 && + node.inputs().every((input) => input.uses().length === 1) && + node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) { + continue; } + const item = { + type: node.kind(), + node: node + }; + this._nodes.push(new pytorch.Node(metadata, '', item, initializers)); } - if (module) { - this._loadScriptModule(metadata, module.data, initializers); - } - break; } - case 'module': { - this._type = (module.__module__ && module.__name__) ? (module.__module__ + '.' + module.__name__) : ''; - this._loadModule(metadata, module, [], []); - break; + if (module) { + this._loadScriptModule(metadata, module.data, initializers); } - case 'weights': { - for (const state_group of module) { - const attributes = state_group.attributes || []; - const inputs = state_group.states.map((parameter) => { - return new pytorch.Parameter(parameter.name, true, - parameter.arguments.map((state) => { - const tensor = new pytorch.Tensor(state.id, pytorch.Utility.toTensor(state.value)); - return new pytorch.Argument(state.id, null, tensor); - })); - }); - const obj = { - name: state_group.name, - type: state_group.type || 'torch.nn.Module', - attributes: attributes, - inputs: inputs, - outputs: [] - }; - this._nodes.push(new pytorch.Node(metadata, '', obj, null)); + } + else { + switch (module.__type__) { + case 'module': { + this._type = (module.__module__ && module.__name__) ? (module.__module__ + '.' + module.__name__) : ''; + this._loadModule(metadata, module, [], []); + break; + } + case 'weights': { + for (const state_group of module) { + const attributes = state_group.attributes || []; + const inputs = state_group.states.map((parameter) => { + return new pytorch.Parameter(parameter.name, true, + parameter.arguments.map((state) => { + const tensor = new pytorch.Tensor(state.id, pytorch.Utility.toTensor(state.value)); + return new pytorch.Argument(state.id, null, tensor); + })); + }); + const obj = { + name: state_group.name, + type: state_group.type || 'torch.nn.Module', + attributes: attributes, + inputs: inputs, + outputs: [] + }; + this._nodes.push(new pytorch.Node(metadata, '', obj, null)); + } + break; + } + default: { + throw new pytorch.Error("Unsupported container type '" + module.__type__ + "'."); } - break; - } - default: { - throw new pytorch.Error("Unsupported container type '" + module.__type__ + "'."); } } } @@ -264,7 +264,9 @@ pytorch.Graph = class { } if (value) { const initializer = new pytorch.Tensor('', value); - inputs.push(new pytorch.Parameter(inputName || key, visible, [ new pytorch.Argument('', null, initializer) ])); + inputs.push(new pytorch.Parameter(inputName || key, visible, [ + new pytorch.Argument('', null, initializer) + ])); } } @@ -1359,7 +1361,9 @@ pytorch.Container.Zip = class extends pytorch.Container { } }; -pytorch.Container.Zip.Script = class { +pytorch.jit = {}; + +pytorch.jit.ScriptModule = class { constructor(entries, execution, location, name) { this.__type__ = 'script'; @@ -1601,7 +1605,7 @@ pytorch.Container.Zip.Json = class extends pytorch.Container.Zip { } }; -pytorch.Container.Zip.Json.Script = class extends pytorch.Container.Zip.Script { +pytorch.Container.Zip.Json.Script = class extends pytorch.jit.ScriptModule { constructor(entries, execution, model) { super(entries, execution); @@ -1657,12 +1661,7 @@ pytorch.Container.Zip.Json.Script = class extends pytorch.Container.Zip.Script { } while (queue.length > 0) { const module = queue.shift(); - if (!module.__class__) { - module.__class__ = { - __module__: 'torch.nn.modules.module', - __name__: 'Module' - }; - } + module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' }; if (module.name) { module.__id__ = module.name; } @@ -1702,7 +1701,6 @@ pytorch.Container.Zip.Json.Script = class extends pytorch.Container.Zip.Script { this._data.forward = module.forward; } } - delete this._model; } get name() { @@ -1748,7 +1746,7 @@ pytorch.Container.Zip.Pickle = class extends pytorch.Container.Zip { } }; -pytorch.Container.Zip.Pickle.Script = class extends pytorch.Container.Zip.Script { +pytorch.Container.Zip.Pickle.Script = class extends pytorch.jit.ScriptModule { constructor(entries, execution, location) { super(entries, execution, location);