diff --git a/src/mlnet.js b/src/mlnet.js index 74a50e6dce..8779dd4690 100644 --- a/src/mlnet.js +++ b/src/mlnet.js @@ -272,7 +272,7 @@ mlnet.TensorType = class { constructor(codec) { - mlnet.TensorType._map = mlnet.TensorType.map || new Map([ + mlnet.TensorType._map = mlnet.TensorType._map || new Map([ [ 'Boolean', 'boolean' ], [ 'Single', 'float32' ], [ 'Double', 'float64' ], diff --git a/src/torchscript-metadata.json b/src/torchscript-metadata.json index 184055e9b9..53a137ccce 100644 --- a/src/torchscript-metadata.json +++ b/src/torchscript-metadata.json @@ -689,5 +689,21 @@ { "name": "output" } ] } + }, + { + "name": "flatten", + "schema": { + "category": "Shape", + "attributes": [ + { "name": "start_dim", "type": "int64" }, + { "name": "end_dim", "type": "int64" } + ], + "inputs": [ + { "name": "inputs" } + ], + "outputs": [ + { "name": "output" } + ] + } } ] diff --git a/src/torchscript.js b/src/torchscript.js index e738676617..476949ba25 100644 --- a/src/torchscript.js +++ b/src/torchscript.js @@ -12,11 +12,11 @@ var zip = zip || require('./zip'); torchscript.ModelFactory = class { match(context) { - var identifier = context.identifier; - var extension = identifier.split('.').pop().toLowerCase(); + let identifier = context.identifier; + let extension = identifier.split('.').pop().toLowerCase(); if (extension == 'pt' || extension == 'pth' || extension == 'pkl' || extension == 'h5' || extension == 't7' || extension == 'dms' || extension == 'model' || extension == 'ckpt' || identifier.endsWith('.pth.tar')) { - if (torchscript.ModelFactory._openContainer(context)) { + if (torchscript.ModelFactory._openContainer(context.entries)) { return true; } } @@ -26,22 +26,28 @@ torchscript.ModelFactory = class { open(context, host) { return host.require('./python').then((python) => { return host.require('./pickle').then((pickle) => { - var identifier = context.identifier; + let identifier = context.identifier; try { - var container = torchscript.ModelFactory._openContainer(context); - if (container.attributes) { - container.attributes = new pickle.Unpickler(container.attributes.data).load((name, args) => { - return { type: name, args: args[0] }; - }); - } + let container = torchscript.ModelFactory._openContainer(context.entries); container.identifier = identifier; + container.constants = torchscript.ModelFactory._unpickle(host, identifier, pickle, container.constants, torchscript.ModelFactory._storage(container, 'constants')); + container.constants = (container.constants || []).map((tensor) => new torchscript.Tensor('pickle', tensor)); + container.data = torchscript.ModelFactory._unpickle(host, identifier, pickle, container.data, torchscript.ModelFactory._storage(container, 'data')); + container.attributes = torchscript.ModelFactory._unpickle(host, identifier, pickle, container.attributes, null); + let textDecoder = new TextDecoder('utf-8'); + if (container.version) { + container.version = JSON.parse(textDecoder.decode(container.version)); + } + if (container.model) { + container.model = JSON.parse(textDecoder.decode(container.model)); + } return torchscript.Metadata.open(host).then((metadata) => { try { return new torchscript.Model(metadata, host, python, container); } catch (error) { host.exception(error, false); - var message = error && error.message ? error.message : error.toString(); + let message = error && error.message ? error.message : error.toString(); message = message.endsWith('.') ? message.substring(0, message.length - 1) : message; throw new torchscript.Error(message + " in '" + identifier + "'."); } @@ -49,7 +55,7 @@ torchscript.ModelFactory = class { } catch (error) { host.exception(error, false); - var message = error && error.message ? error.message : error.toString(); + let message = error && error.message ? error.message : error.toString(); message = message.endsWith('.') ? message.substring(0, message.length - 1) : message; return Promise.reject(new torchscript.Error(message + " in '" + identifier + "'.")); } @@ -57,40 +63,170 @@ torchscript.ModelFactory = class { }); } - static _openContainer(context) { - let entries = context.entries; + static _openContainer(entries) { if (entries && entries.length > 0) { - var container = { }; - container.version = entries.find((entry) => entry.name == 'version' || entry.name.endsWith('/version')); - if (container.version) { - container.prefix = container.version.name.substring(0, container.version.name.length - 7); - container.attributes = entries.find((entry) => entry.name == container.prefix + 'attributes.pkl'); - container.model = entries.find((entry) => entry.name == container.prefix + 'model.json'); + let container = { }; + let version = entries.find((entry) => entry.name == 'version' || entry.name.endsWith('/version')); + if (version) { container.entries = entries; - if (container.version && container.model) { + container.prefix = version.name.substring(0, version.name.length - 7); + let find = (name) => { + let entry = container.entries.find((entry) => entry.name == container.prefix + name); + if (entry) { + return entry.data; + } + return null; + } + container.version = version.data; + container.attributes = find('attribtues.pkl'); + container.constants = find('constants.pkl'); + container.data = find('data.pkl'); + container.model = find('model.json'); + if (container.version && (container.model || container.data)) { return container; } } } return null; } + + static _storage(container, dirname) { + let map = new Map(); + let prefix = container.prefix + dirname + '/'; + for (let entry of container.entries) { + if (entry.name.startsWith(prefix)) { + let key = entry.name.substring(prefix.length); + map.set(key, entry.data); + } + } + return map; + } + + static _unpickle(host, identifier, pickle, data, storage_map) { + if (!data) { + return null; + } + let functionTable = {}; + functionTable['collections.OrderedDict'] = function(args) { + let obj = []; + obj.__setitem__ = function(key, value) { + obj.push({ key: key, value: value }); + }; + if (args) { + for (let arg of args) { + obj.__setitem__(arg[0], arg[1]); + } + } + return obj; + }; + functionTable['torch._utils._rebuild_tensor_v2'] = function (storage, storage_offset, size, stride, requires_grad, backward_hooks) { + return { + __type__: storage.__type__.replace('Storage', 'Tensor'), + storage: storage, + storage_offset: storage_offset, + size: size, + stride: stride, + requires_grad:requires_grad, + backward_hooks: backward_hooks + }; + }; + let constructorTable = {}; + constructorTable['torch.ByteStorage'] = function (size) { + this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8'; + }; + constructorTable['torch.CharStorage'] = function (size) { + this.size = size; this.dataTypeSize = 1; this.dataType = 'int8'; + }; + constructorTable['torch.ShortStorage'] = function (size) { + this.size = size; this.dataTypeSize = 2; this.dataType = 'int16'; + }; + constructorTable['torch.IntStorage'] = function (size) { + this.size = size; this.dataTypeSize = 4; this.dataType = 'int32'; + }; + constructorTable['torch.LongStorage'] = function (size) { + this.size = size; this.dataTypeSize = 8; this.dataType = 'int64'; + }; + constructorTable['torch.HalfStorage'] = function (size) { + this.size = size; this.dataTypeSize = 2; this.dataType = 'float16'; + }; + constructorTable['torch.FloatStorage'] = function (size) { + this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; + }; + constructorTable['torch.DoubleStorage'] = function (size) { + this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; + }; + let function_call = (name, args) => { + let func = functionTable[name]; + if (func) { + return func.apply(null, args); + } + let obj = { __type__: name }; + let constructor = constructorTable[name]; + if (constructor) { + constructor.apply(obj, args); + } + else if (!name.startsWith('__torch__.')) { + host.exception(new torchscript.Error("Unknown function '" + name + "' in '" + identifier + "'."), false); + } + return obj; + }; + let deserialized_objects = new Map(); + let persistent_load = (saved_id) => { + let typename = saved_id.shift(); + if (typename !== 'storage') { + throw new torchscript.Error("Unknown persistent load type '" + typename + "'."); + } + let data_type = saved_id.shift(); + let root_key = saved_id.shift(); + saved_id.shift(); // location + let size = saved_id.shift(); + let storage = null; + if (deserialized_objects.has(root_key)) { + storage = deserialized_objects.get(root_key); + } + else { + storage = function_call(data_type, [ size ]); + storage.data = storage_map.get(root_key); + deserialized_objects[root_key] = storage; + } + let view_metadata = saved_id.shift(); + if (view_metadata) { + let view_key = view_metadata.shift(); + view_metadata.shift(); // view_offset + view_metadata.shift(); // view_size + let view = deserialized_objects[view_key]; + if (!view) { + view = null; // storage.slice(view_offset, view_offset + view_size); + deserialized_objects[view_key] = view; + } + return view; + } + return storage; + }; + return new pickle.Unpickler(data).load(function_call, persistent_load); + } }; torchscript.Model = class { constructor(metadata, host, python, container) { - var textDecoder = new TextDecoder('utf-8'); - var model = JSON.parse(textDecoder.decode(container.model.data)); - var version = JSON.parse(textDecoder.decode(container.version.data)); - this._format = 'TorchScript v' + version.toString(); - if (model.producerName) { - this._producer = model.producerName; - if (model.producerVersion) { - this._producer = this._producer + ' v' + model.producerVersion; + this._format = 'TorchScript v' + container.version.toString(); + if (container.model) { + if (container.producerName) { + this._producer = container.producerName; + if (container.producerVersion) { + this._producer = this._producer + ' v' + container.producerVersion; + } } + container.tensors = container.model.tensors.map((tensor) => { + let key = container.prefix + tensor.data.key; + let entry = container.entries.find((entry) => entry.name == key); + return new torchscript.Tensor('json', { tensor: tensor, data: entry.data }); + }); + container.constants = container.tensors; } this._graphs = []; - this._graphs.push(new torchscript.Graph(metadata, host, python, container, model.mainModule, model.tensors)); + this._graphs.push(new torchscript.Graph(metadata, host, python, container)); } get format() { @@ -108,73 +244,118 @@ torchscript.Model = class { torchscript.Graph = class { - constructor(metadata, host, python, container, mainModule, tensors) { - this._name = mainModule.name; + constructor(metadata, host, python, container) { + if (container.model && container.model.mainModule) { + this._name = container.model.mainModule.name; + } this._inputs = []; this._outputs = []; this._nodes = []; - container.tensors = tensors.map((tensor) => new torchscript.Tensor(tensor, container)); - - var context = null; + let mainModule = null; + let context = null; try { - context = new torchscript.GraphContext(container, python, mainModule); + let script = ''; + let className = null; + if (container.model && container.model.mainModule) { + mainModule = container.model.mainModule; + script = mainModule.torchscriptArena.key; + } + else if (container.data) { + mainModule = container.data; + let typeName = mainModule.__type__.split('.'); + className = typeName.pop(); + script = 'code/' + typeName.join('/') + '.py'; + } + context = new torchscript.GraphContext(container, python, mainModule, script, className); } catch (error) { - var message = error && error.message ? error.message : error.toString(); + let message = error && error.message ? error.message : error.toString(); message = message.endsWith('.') ? message.substring(0, message.length - 1) : message; host.exception(new torchscript.Error(message + " in '" + container.identifier + "'."), false); } container.parameters = {}; - var queue = [ mainModule ]; - while (queue.length > 0) { - var module = queue.shift(); - if (module.parameters) { - for (var parameter of module.parameters) { - if (parameter.tensorId) { - var tensorId = parseInt(parameter.tensorId, 10); - parameter.initializer = container.tensors[tensorId]; - if (parameter.outputs && parameter.outputs.length == 1) { - container.parameters[parameter.outputs[0]] = parameter; + if (container.model && container.model.mainModule) { + let queue = [ container.model.mainModule ]; + while (queue.length > 0) { + let module = queue.shift(); + if (module.parameters) { + for (let parameter of module.parameters) { + if (parameter.tensorId) { + let tensorId = parseInt(parameter.tensorId, 10); + parameter.initializer = container.tensors[tensorId]; + if (parameter.outputs && parameter.outputs.length == 1) { + container.parameters[parameter.outputs[0]] = parameter; + } } } } + if (module.submodules) { + for (let submodule of module.submodules) { + submodule.__parent__ = module; + queue.push(submodule); + } + } } - if (module.submodules) { - for (var submodule of module.submodules) { - submodule.parent = module; - queue.push(submodule); + } + /* + if (container.data) { + let queue = [ container.data ]; + while (queue.length > 0) { + let module = queue.shift(); + if (module.parameters) { + for (let parameter of module.parameters) { + if (parameter.tensorId) { + let tensorId = parseInt(parameter.tensorId, 10); + parameter.initializer = container.tensors[tensorId]; + if (parameter.outputs && parameter.outputs.length == 1) { + container.parameters[parameter.outputs[0]] = parameter; + } + } + } + } + for (let key of Object.keys(module)) { + if (key !== '__type__' && key !== '__parent__') { + let submodule = module[key]; + if (submodule === Object(submodule)) { + submodule.__parent__ = module; + queue.push(submodule); + } + } } } } + */ if (context) { - for (var input of context.inputs) { + for (let input of context.inputs) { this._inputs.push(new torchscript.Parameter(input, true, [ new torchscript.Argument(input, null, null) ])); } - for (var output of context.outputs) { + for (let output of context.outputs) { this._outputs.push(new torchscript.Parameter(output, true, [ new torchscript.Argument(output, null, null) ])); } - for (var node of context.nodes) { + for (let node of context.nodes) { this._nodes.push(new torchscript.Node(metadata, container, null, node)); } } - this._loadModule(metadata, container, mainModule); + if (container.model) { + this._loadModule(metadata, container, mainModule); + } } _loadModule(metadata, container, module) { if (module.parameters && module.parameters.length > 0 && !module.hide) { - var node = new torchscript.Node(metadata, container, module, null); + let node = new torchscript.Node(metadata, container, module, null); this._nodes.push(node); } if (module.submodules) { - for (var submodule of module.submodules) { + for (let submodule of module.submodules) { this._loadModule(metadata, container, submodule); } } @@ -258,14 +439,10 @@ torchscript.Node = class { this._inputs = []; this._outputs = []; - var input = null; - var argument = null; - var parameter = null; - if (module) { this._operator = 'Module'; if (module.parameters) { - for (parameter of module.parameters) { + for (let parameter of module.parameters) { this._inputs.push(new torchscript.Parameter(parameter.name, true, [ new torchscript.Argument('', null, parameter.initializer || null) ])); @@ -282,14 +459,14 @@ torchscript.Node = class { this._operator = node.name; this._name = ''; - var schema = metadata.getSchema(this._operator); + let schema = metadata.getSchema(this._operator); module = null; - var match = true; - var count = 0; - for (input of node.inputs) { - for (argument of input) { - parameter = container.parameters[argument.id]; + let match = true; + let count = 0; + for (let input of node.inputs) { + for (let argument of input) { + let parameter = container.parameters[argument.id]; if (parameter) { if (parameter.module && (module == null || module == parameter.module)) { module = parameter.module; @@ -307,9 +484,9 @@ torchscript.Node = class { } if (module && module.parameters.length == count && match) { module.hide = true; - for (input of node.inputs) { - for (argument of input) { - parameter = container.parameters[argument.id]; + for (let input of node.inputs) { + for (let argument of input) { + let parameter = container.parameters[argument.id]; if (parameter && parameter.initializer) { argument.initializer = parameter.initializer; } @@ -320,8 +497,8 @@ torchscript.Node = class { module = null; } - for (var inputIndex = 0; inputIndex < node.inputs.length; inputIndex++) { - var inputName = inputIndex.toString(); + for (let inputIndex = 0; inputIndex < node.inputs.length; inputIndex++) { + let inputName = inputIndex.toString(); if (schema && schema.inputs && schema.inputs.length > inputIndex) { inputName = schema.inputs[inputIndex].name; } @@ -330,8 +507,8 @@ torchscript.Node = class { )); } - for (var outputIndex = 0; outputIndex < node.outputs.length; outputIndex++) { - var outputName = outputIndex.toString(); + for (let outputIndex = 0; outputIndex < node.outputs.length; outputIndex++) { + let outputName = outputIndex.toString(); if (schema && schema.outputs && schema.outputs.length > outputIndex) { outputName = schema.outputs[outputIndex].name; } @@ -340,10 +517,10 @@ torchscript.Node = class { ])); } - for (var attributeIndex = 0; attributeIndex < node.attributes.length; attributeIndex++) { - var attributeSchema = null; - var attributeName = attributeIndex.toString(); - var attributeValue = node.attributes[attributeIndex]; + for (let attributeIndex = 0; attributeIndex < node.attributes.length; attributeIndex++) { + let attributeSchema = null; + let attributeName = attributeIndex.toString(); + let attributeValue = node.attributes[attributeIndex]; if (attributeValue && attributeValue.type === '=' && attributeValue.target.type == 'identifier') { attributeName = attributeValue.target.value; attributeValue = attributeValue.expression; @@ -363,10 +540,10 @@ torchscript.Node = class { if (module) { if (module.name) { - var current = module; + let current = module; this._name = current.name; - while (current.parent != null) { - current = current.parent; + while (current.__parent__ != null) { + current = current.__parent__; this._name = [ current.name, this._name ].join('.') } } @@ -386,12 +563,12 @@ torchscript.Node = class { } get category() { - var schema = this._metadata.getSchema(this._operator); + let schema = this._metadata.getSchema(this._operator); return (schema && schema.category) ? schema.category : ''; } get documentation() { - var schema = this._metadata.getSchema(this._operator); + let schema = this._metadata.getSchema(this._operator); if (schema) { schema = JSON.parse(JSON.stringify(schema)); schema.name = this._operator; @@ -399,21 +576,21 @@ torchscript.Node = class { schema.description = marked(schema.description); } if (schema.attributes) { - for (var attribute of schema.attributes) { + for (let attribute of schema.attributes) { if (attribute.description) { attribute.description = marked(attribute.description); } } } if (schema.inputs) { - for (var input of schema.inputs) { + for (let input of schema.inputs) { if (input.description) { input.description = marked(input.description); } } } if (schema.outputs) { - for (var output of schema.outputs) { + for (let output of schema.outputs) { if (output.description) { output.description = marked(output.description); } @@ -491,7 +668,7 @@ torchscript.Attribute = class { case 'int64[]': if (this._value.type == 'list' && this._value.value.every((item) => item.type === 'number')) { this._value = this._value.value.map((item) => { - var number = parseInt(item.value, 10); + let number = parseInt(item.value, 10); if (!Number.isNaN(item.value - number)) { return number; } @@ -536,12 +713,27 @@ torchscript.Attribute = class { torchscript.Tensor = class { - constructor(tensor, container) { - this._type = new torchscript.TensorType(tensor.dataType, new torchscript.TensorShape(tensor.dims)); - var key = container.prefix + tensor.data.key; - var entry = container.entries.find((entry) => entry.name == key); - this._name = tensor.data.key; - this._data = entry.data; + constructor(format, data) { + switch (format) { + case 'json': + torchscript.Tensor._dataTypeMap = torchscript.Tensor._dataTypeMap || new Map([ + [ 'FLOAT', 'float32' ], + [ 'DOUBLE', 'float64' ], + [ 'INT32', 'int32' ], + [ 'INT64', 'int64' ] + ]); + if (!torchscript.Tensor._dataTypeMap.has(data.tensor.dataType)) { + throw new torchscript.Error("Unknown tensor data type '" + data.tensor.dataType + "'."); + } + this._type = new torchscript.TensorType(torchscript.Tensor._dataTypeMap.get(data.tensor.dataType), new torchscript.TensorShape(data.tensor.dims)); + this._name = data.tensor.data.key; + this._data = data.data; + break; + case 'pickle': + this._type = new torchscript.TensorType(data.storage.dataType, new torchscript.TensorShape(data.size)); + this._data = data.storage.data; + break; + } this._littleEndian = true; } @@ -562,7 +754,7 @@ torchscript.Tensor = class { } get value() { - var context = this._context(); + let context = this._context(); if (context.state) { return null; } @@ -571,17 +763,17 @@ torchscript.Tensor = class { } toString() { - var context = this._context(); + let context = this._context(); if (context.state) { return ''; } context.limit = 10000; - var value = this._decode(context, 0); + let value = this._decode(context, 0); return torchscript.Tensor._stringify(value, '', ' '); } _context() { - var context = {}; + let context = {}; context.state = null; context.index = 0; context.count = 0; @@ -607,14 +799,14 @@ torchscript.Tensor = class { } _decode(context, dimension) { - var results = []; - var dimensions = context.dimensions; + let results = []; + let dimensions = context.dimensions; if (dimensions.length == 0) { dimensions = [ 1 ]; } - var size = dimensions[dimension]; + let size = dimensions[dimension]; if (dimension == dimensions.length - 1) { - for (var i = 0; i < size; i++) { + for (let i = 0; i < size; i++) { if (context.count > context.limit) { results.push('...'); return results; @@ -665,7 +857,7 @@ torchscript.Tensor = class { } } else { - for (var j = 0; j < size; j++) { + for (let j = 0; j < size; j++) { if (context.count > context.limit) { results.push('...'); return results; @@ -681,9 +873,9 @@ torchscript.Tensor = class { static _stringify(value, indentation, indent) { if (Array.isArray(value)) { - var result = []; + let result = []; result.push(indentation + '['); - var items = value.map((item) => torchscript.Tensor._stringify(item, indentation + indent, indent)); + let items = value.map((item) => torchscript.Tensor._stringify(item, indentation + indent, indent)); if (items.length > 0) { result.push(items.join(',\n')); } @@ -712,13 +904,7 @@ torchscript.Tensor = class { torchscript.TensorType = class { constructor(dataType, shape) { - switch(dataType) { - case 'FLOAT': this._dataType = 'float32'; break; - case 'DOUBLE': this._dataType = 'float64'; break; - case 'INT32': this._dataType = 'int32'; break; - case 'INT64': this._dataType = 'int64'; break; - default: throw new torchscript.Error("Unknown tensor data type '" + dataType + "'."); - } + this._dataType = dataType; this._shape = shape; } @@ -774,9 +960,9 @@ torchscript.Metadata = class { this._map = {}; this._attributeCache = {}; if (data) { - var items = JSON.parse(data); + let items = JSON.parse(data); if (items) { - for (var item of items) { + for (let item of items) { if (item.name && item.schema) { this._map[item.name] = item.schema; } @@ -790,12 +976,12 @@ torchscript.Metadata = class { } getAttributeSchema(operator, name) { - var map = this._attributeCache[operator]; + let map = this._attributeCache[operator]; if (!map) { map = {}; - var schema = this.getSchema(operator); + let schema = this.getSchema(operator); if (schema && schema.attributes && schema.attributes.length > 0) { - for (var attribute of schema.attributes) { + for (let attribute of schema.attributes) { map[attribute.name] = attribute; } } @@ -807,7 +993,7 @@ torchscript.Metadata = class { torchscript.GraphContext = class { - constructor(container, python, mainModule) { + constructor(container, python, mainModule, script, className) { this._container = container; this._mainModule = mainModule; @@ -820,29 +1006,34 @@ torchscript.GraphContext = class { this._argumentMap = {}; this._numToTensorMap = {}; - if (mainModule.torchscriptArena && mainModule.torchscriptArena.key) { - var codeKey = container.prefix + mainModule.torchscriptArena.key; - var codeEntries = container.entries.filter((e) => e.name === codeKey); + if (script) { + let codeKey = container.prefix + script; + let codeEntries = container.entries.filter((e) => e.name === codeKey); if (codeEntries.length == 1) { - var codeEntry = codeEntries[0]; - var textDecoder = new TextDecoder('utf-8'); - var code = textDecoder.decode(codeEntry.data); - var reader = new python.Parser(code); - var program = reader.parse(); - var method = program.body.find((statement) => statement.type == 'def' && statement.name == 'forward'); + let codeEntry = codeEntries[0]; + let textDecoder = new TextDecoder('utf-8'); + let code = textDecoder.decode(codeEntry.data); + let reader = new python.Parser(code); + let program = reader.parse(); + let statements = program.body; + if (className) { + let block = statements.find((statment) => statment.type == 'class' && statment.name == className); + statements = block.body.statements; + } + let method = statements.find((statement) => statement.type == 'def' && statement.name == 'forward'); if (method) { this._body = method.body.statements; - var methodParameters = method.parameters; + let methodParameters = method.parameters; if (methodParameters.length > 0 && methodParameters[0].name == 'self') { methodParameters.shift(); } - for (var parameter of methodParameters) { + for (let parameter of methodParameters) { this._parameter(parameter); } if (this._body.length >= 2) { - var returnStatement = this._body[this._body.length - 1]; - var assignStatement = this._body[this._body.length - 2]; + let returnStatement = this._body[this._body.length - 1]; + let assignStatement = this._body[this._body.length - 2]; if (returnStatement.type == 'return' && returnStatement.expression.type == 'identifier' && assignStatement.target.type == 'identifier' && @@ -855,7 +1046,7 @@ torchscript.GraphContext = class { } while (this._body.length > 0) { - var statement = this._body.shift(); + let statement = this._body.shift(); if (this._attributeStatement(statement)) { continue; } @@ -891,13 +1082,13 @@ torchscript.GraphContext = class { } _parameter(parameter) { - var type = parameter.parameterType; + let type = parameter.parameterType; if (type.type == 'type' && type.value == 'Tuple' && type.arguments && type.arguments.length > 0) { if (this._body.length > 0) { - var statement = this._body[0]; + let statement = this._body[0]; if (statement.expression.type == 'identifier' && statement.expression.value == parameter.name) { if (statement.type === '=' && statement.target.type === 'tuple') { - for (var input of statement.target.value) { + for (let input of statement.target.value) { if (input) { this._inputs.push(input.value); } @@ -914,7 +1105,7 @@ torchscript.GraphContext = class { _returnStatement(statement) { if (statement.type == 'return') { - var variable = this._variable(); + let variable = this._variable(); if (this._nodeExpression(statement.expression, variable)) { this._outputs.push(variable.value); return true; @@ -924,8 +1115,8 @@ torchscript.GraphContext = class { return true; } if (statement.expression.type == 'tuple') { - var outputs = []; - for (var expression of statement.expression.value) { + let outputs = []; + for (let expression of statement.expression.value) { variable = this._variable(); if (this._nodeExpression(expression, variable)) { outputs.push(variable.value); @@ -946,17 +1137,17 @@ torchscript.GraphContext = class { _nodeExpression(expression, target) { if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'tuple')) { - var name = this._name(expression.target); - var namespace = 'torch.'; + let name = this._name(expression.target); + let namespace = 'torch.'; if (name.startsWith(namespace)) { - var node = {}; + let node = {}; node.name = name.substring(namespace.length); node.inputs = []; node.outputs = []; node.attributes = []; - var args = expression.arguments; + let args = expression.arguments; while (args.length > 0) { - var argument = args[0]; + let argument = args[0]; argument = this._moduleTensor(argument); if (argument.type == 'identifier' && this._argumentMap[argument.value]) { argument = this._argumentMap[argument.value]; @@ -971,9 +1162,9 @@ torchscript.GraphContext = class { continue; } if (argument.type == 'list') { - var list = []; - for (var input of argument.value) { - var variable = this._variable(); + let list = []; + for (let input of argument.value) { + let variable = this._variable(); if (this._nodeExpression(input, variable)) { list.push({ id: variable.value }); } @@ -1003,7 +1194,7 @@ torchscript.GraphContext = class { if (argument.type == '=') { break; } - variable = this._variable(); + let variable = this._variable(); if (this._nodeExpression(argument, variable)) { node.inputs.push([ { id: variable.value } ]); args.shift(); @@ -1019,9 +1210,9 @@ torchscript.GraphContext = class { argument.target.value == 'CONSTANTS' && argument.member.type == 'identifier' && argument.member.value.startsWith('c')) { - var constantId = [ argument.target.value, argument.member.value ].join('.'); - var constantIndex = parseInt(argument.member.value.substring(1), 10); - var constantTensor = this._container.tensors[constantIndex]; + let constantId = [ argument.target.value, argument.member.value ].join('.'); + let constantIndex = parseInt(argument.member.value.substring(1), 10); + let constantTensor = this._container.constants[constantIndex]; node.inputs.push([ { id: constantId, initializer: constantTensor } ]); args.shift(); continue; @@ -1030,11 +1221,11 @@ torchscript.GraphContext = class { } while (args.length > 0) { if (args[0].type == 'list') { - for (var i = 0; i < args[0].value.length; i++) { + for (let i = 0; i < args[0].value.length; i++) { args[0].value[i] = this._attributeExpression(args[0].value[i]); } } - var intExpression = this._attributeExpression(args[0]); + let intExpression = this._attributeExpression(args[0]); if (intExpression) { args[0] = intExpression; } @@ -1045,7 +1236,7 @@ torchscript.GraphContext = class { node.outputs.push(target.value); } if (target.type == 'tuple') { - for (var identifier of target.value) { + for (let identifier of target.value) { node.outputs.push(identifier.value); } } @@ -1076,7 +1267,7 @@ torchscript.GraphContext = class { expression.target.value == 'int' && expression.arguments.length == 1) { - var replace = this._attributeExpression(expression.arguments[0]); + let replace = this._attributeExpression(expression.arguments[0]); if (replace) { return replace; } @@ -1090,7 +1281,7 @@ torchscript.GraphContext = class { if (statement.expression.type == 'call' && this._name(statement.expression.target) == 'ops.prim.NumToTensor' && statement.expression.arguments.length == 1) { - var size = statement.expression.arguments[0]; + let size = statement.expression.arguments[0]; if (size.type == 'call' && size.arguments.length == 2 && this._name(size.target) == 'torch.size' && @@ -1100,7 +1291,7 @@ torchscript.GraphContext = class { return true; } if (size.type == 'identifier') { - var duplicate1 = this._numToTensorMap[size.value]; + let duplicate1 = this._numToTensorMap[size.value]; if (duplicate1) { this._numToTensorMap[statement.target.value] = duplicate1; return true; @@ -1120,7 +1311,7 @@ torchscript.GraphContext = class { statement.expression.target.value == 'int' && statement.expression.arguments.length == 1 && statement.expression.arguments[0].type == 'identifier') { - var duplicate2 = this._numToTensorMap[statement.expression.arguments[0].value]; + let duplicate2 = this._numToTensorMap[statement.expression.arguments[0].value]; if (duplicate2) { this._numToTensorMap[statement.target.value] = duplicate2; return true; @@ -1131,8 +1322,8 @@ torchscript.GraphContext = class { } _module(expression) { - var module; - var submodule; + let module; + let submodule; if (expression.type === '.') { module = this._module(expression.target); if (module && module.submodules) { @@ -1142,6 +1333,9 @@ torchscript.GraphContext = class { } } } + if (module[expression.member.value]) { + return module[expression.member.value]; + } } if (expression.type == 'call' && expression.target.type == 'identifier' && expression.target.value == 'getattr' && expression.arguments.length == 2) { @@ -1149,11 +1343,14 @@ torchscript.GraphContext = class { if (!module) { return null; } - var name = null; + let name = null; if (expression.arguments[1].type == 'string') { name = expression.arguments[1].value.substring(1, expression.arguments[1].value.length - 1); } if (module) { + if (module[name]) { + return module[name]; + } for (submodule of module.submodules) { if (submodule.name === name) { return submodule; @@ -1176,8 +1373,8 @@ torchscript.GraphContext = class { _moduleStatement(statement) { if (statement.type == '=' && statement.target.type === 'identifier') { - var moduleName = statement.target.value; - var module = this._module(statement.expression); + let moduleName = statement.target.value; + let module = this._module(statement.expression); if (module) { this._moduleMap[moduleName] = module; return true; @@ -1189,9 +1386,9 @@ torchscript.GraphContext = class { _argumentExpression(expression, target) { expression = this._moduleTensor(expression); if (expression.type === '.' && expression.member.type == 'identifier') { - var targetModule = this._module(expression.target); + let targetModule = this._module(expression.target); if (targetModule && targetModule.parameters) { - for (var parameter of targetModule.parameters) { + for (let parameter of targetModule.parameters) { parameter.module = targetModule; if (parameter.name === expression.member.value) { parameter.outputs = parameter.outputs || []; @@ -1200,7 +1397,7 @@ torchscript.GraphContext = class { } } targetModule.unresolvedParameters = targetModule.unresolvedParameters || []; - for (var unresolvedParameter of targetModule.unresolvedParameters) { + for (let unresolvedParameter of targetModule.unresolvedParameters) { unresolvedParameter.module = targetModule; if (unresolvedParameter.name === expression.member.value) { unresolvedParameter.outputs = unresolvedParameter.outputs || []; diff --git a/test/models.json b/test/models.json index 34ef41122e..300e17757e 100644 --- a/test/models.json +++ b/test/models.json @@ -4980,7 +4980,6 @@ "target": "alexnet.pt", "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", - "error": "Unsupported file content for extension '.pt' in 'alexnet.pt'.", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { @@ -5003,7 +5002,6 @@ "link": "https://pytorch.org/docs/stable/torchvision/models.html", "render": "skip", "script": [ "${root}/tools/pytorch", "sync install zoo" ], - "error": "Unsupported file content for extension '.pt' in 'densenet121.pt'.", "status": "script" }, { @@ -5018,7 +5016,6 @@ "target": "inception_v3.pt", "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", - "error": "Unsupported file content for extension '.pt' in 'inception_v3.pt'.", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { @@ -5040,7 +5037,6 @@ "target": "mobilenet_v2.pt", "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", - "error": "Unsupported file content for extension '.pt' in 'mobilenet_v2.pt'.", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { @@ -5084,7 +5080,6 @@ "target": "resnet18.pt", "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", - "error": "Unsupported file content for extension '.pt' in 'resnet18.pt'.", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { @@ -5092,7 +5087,6 @@ "target": "resnet50.pt", "script": [ "${root}/tools/pytorch", "sync install zoo" ], "format": "TorchScript v1", - "error": "Unsupported file content for extension '.pt' in 'resnet50.pt'.", "link": "https://pytorch.org/docs/stable/torchvision/models.html" }, { @@ -5100,7 +5094,6 @@ "target": "squeezenet1_1.pt", "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": [ "${root}/tools/pytorch", "sync install zoo" ], - "error": "Unsupported file content for extension '.pt' in 'squeezenet1_1.pt'.", "status": "script" }, { @@ -5129,7 +5122,6 @@ "target": "vgg16.pt", "link": "https://pytorch.org/docs/stable/torchvision/models.html", "script": [ "${root}/tools/pytorch", "sync install zoo" ], - "error": "Unsupported file content for extension '.pt' in 'vgg16.pt'.", "status": "script" } ] \ No newline at end of file