diff --git a/source/pytorch.js b/source/pytorch.js
index f8fd7b5fba..9ba770ea16 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -18,7 +18,19 @@ pytorch.ModelFactory = class {
     }
 
     filter(context, type) {
-        return (context.type !== 'pytorch.export' && context.type !== 'pytorch.index') || type !== 'pytorch.zip';
+        if (context.type === 'pytorch.export' && type === 'pytorch.zip') {
+            return false;
+        }
+        if (context.type === 'pytorch.index' && type === 'pytorch.zip') {
+            return false;
+        }
+        if (context.type === 'pytorch.model.json' && type === 'pytorch.data.pkl') {
+            return false;
+        }
+        if (context.type === 'pytorch.model.json' && type === 'pickle') {
+            return false;
+        }
+        return true;
     }
 
     async open(context) {
@@ -28,6 +40,9 @@ pytorch.ModelFactory = class {
             context.error(new pytorch.Error(`Unknown type name '${name}'.`), false);
         });
         await target.read(metadata);
+        if (!target.format || !target.modules) {
+            throw new pytorch.Error("Container not implemented.");
+        }
         return new pytorch.Model(metadata, target);
     }
 };
@@ -719,6 +734,7 @@ pytorch.Container = class {
             pytorch.Container.data_pkl,
             pytorch.Container.torch_utils,
             pytorch.Container.Mobile,
+            pytorch.Container.ModelJson,
             pytorch.Container.Index,
             pytorch.Container.ExportedProgram,
             pytorch.Container.ExecuTorch,
@@ -742,14 +758,6 @@ pytorch.Container = class {
     on(event, callback) {
         this._events.push([event, callback]);
     }
-
-    get format() {
-        throw new pytorch.Error('Container format not implemented.');
-    }
-
-    get modules() {
-        throw new pytorch.Error('Container modules not implemented.');
-    }
 };
 
 pytorch.Container.Tar = class extends pytorch.Container {
@@ -769,6 +777,7 @@ pytorch.Container.Tar = class extends pytorch.Container {
     }
 
     async read() {
+        this.format = 'PyTorch v0.1.1';
         const execution = new pytorch.Execution();
         for (const event of this._events) {
             execution.on(event[0], event[1]);
@@ -776,19 +785,11 @@ pytorch.Container.Tar = class extends pytorch.Container {
         const torch = execution.__import__('torch');
         const obj = torch.load(this.entries);
         delete this.entries;
-        this._modules = pytorch.Utility.findWeights(obj);
-        if (!this._modules) {
+        this.modules = pytorch.Utility.findWeights(obj);
+        if (!this.modules) {
             throw new pytorch.Error('File does not contain root module or state dictionary.');
         }
     }
-
-    get format() {
-        return 'PyTorch v0.1.1';
-    }
-
-    get modules() {
-        return this._modules;
-    }
 };
 
 pytorch.Container.Pickle = class extends pytorch.Container {
@@ -809,6 +810,7 @@ pytorch.Container.Pickle = class extends pytorch.Container {
     }
 
     async read() {
+        this.format = 'PyTorch v0.1.10';
         const data = this.stream.length < 0x7ffff000 ? this.stream.peek() : this.stream;
         delete this.stream;
         const execution = new pytorch.Execution();
@@ -817,15 +819,7 @@ pytorch.Container.Pickle = class extends pytorch.Container {
         }
         const torch = execution.__import__('torch');
         const obj = torch.load(data);
-        this._modules = pytorch.Utility.find(obj);
-    }
-
-    get format() {
-        return 'PyTorch v0.1.10';
-    }
-
-    get modules() {
-        return this._modules;
+        this.modules = pytorch.Utility.find(obj);
     }
 };
 
@@ -843,7 +837,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
             if (pytorch.Utility.isTensor(obj)) {
                 return new pytorch.Container.data_pkl('tensor', obj);
             }
-            if (Array.isArray(obj) && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) {
+            if (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) {
                 return new pytorch.Container.data_pkl('tensor', obj);
             }
             if (obj instanceof Map) {
@@ -869,44 +863,40 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
 
     constructor(type, data) {
         super();
-        this.type = 'pytorch.data_pkl';
+        this.type = 'pytorch.data.pkl';
        this._type = type;
        this._data = data;
    }
 
-    get format() {
-        return 'PyTorch Pickle';
-    }
-
-    get modules() {
+    async read() {
+        this.format = 'PyTorch Pickle';
         switch (this._type) {
             case 'module': {
                 if (this._data) {
-                    this._modules = pytorch.Utility.findModule(this._data);
+                    this.modules = pytorch.Utility.findModule(this._data);
                     delete this._data;
                 }
-                if (!this._modules) {
+                if (!this.modules) {
                     throw new pytorch.Error('File does not contain root module or state dictionary.');
                 }
-                return this._modules;
+                return this.modules;
             }
             case 'tensor':
             case 'tensor[]':
             case 'tensor<>': {
                 if (this._data) {
-                    this._modules = pytorch.Utility.findWeights(this._data);
+                    this.modules = pytorch.Utility.findWeights(this._data);
                     delete this._data;
                 }
-                if (!this._modules) {
+                if (!this.modules) {
                     throw new pytorch.Error('File does not contain root module or state dictionary.');
                 }
-                return this._modules;
+                return this.modules;
             }
             default: {
                 throw new pytorch.Error("PyTorch standalone 'data.pkl' not supported.");
             }
         }
-
     }
 };
 
@@ -936,17 +926,10 @@ pytorch.Container.torch_utils = class extends pytorch.Container {
     }
 
     async read() {
-        this._modules = pytorch.Utility.find(this.obj);
+        this.format = 'PyTorch torch_utils';
+        this.modules = pytorch.Utility.find(this.obj);
         delete this.obj;
     }
-
-    get format() {
-        return 'PyTorch torch_utils';
-    }
-
-    get modules() {
-        return this._modules;
-    }
 };
 
 pytorch.Container.Mobile = class extends pytorch.Container {
@@ -968,7 +951,7 @@ pytorch.Container.Mobile = class extends pytorch.Container {
     async read(metadata) {
         pytorch.mobile = await this.context.require('./pytorch-schema');
         pytorch.mobile = pytorch.mobile.torch.jit.mobile;
-        this._modules = new Map();
+        this.modules = new Map();
         const execution = new pytorch.jit.Execution(null, metadata);
         for (const event in this._events) {
             execution.on(event[0], event[1]);
@@ -977,22 +960,14 @@ pytorch.Container.Mobile = class extends pytorch.Container {
         const torch = execution.__import__('torch');
         const module = torch.jit.jit_module_from_flatbuffer(stream);
         const version = module._c._bytecode_version.toString();
-        this._format = pytorch.Utility.format('PyTorch Mobile', version);
+        this.format = pytorch.Utility.format('PyTorch Mobile', version);
         if (module && module.forward) {
-            this._modules = new Map([['', module]]);
+            this.modules = new Map([['', module]]);
         } else {
-            this._modules = pytorch.Utility.find(module);
+            this.modules = pytorch.Utility.find(module);
         }
         delete this.context;
     }
-
-    get format() {
-        return this._format;
-    }
-
-    get modules() {
-        return this._modules;
-    }
 };
 
 pytorch.Container.ExecuTorch = class extends pytorch.Container {
@@ -1034,18 +1009,7 @@ pytorch.Container.Zip = class extends pytorch.Container {
         }
         const records = new Map(Array.from(entries).map(([name, value]) => [name.substring(prefix), value]));
         if (records.has('model.json')) {
-            try {
-                const stream = records.get('model.json');
-                const buffer = stream.peek();
-                const decoder = new TextDecoder('utf-8');
-                const content = decoder.decode(buffer);
-                const model = JSON.parse(content);
-                if (model.mainModule) {
-                    return new pytorch.Container.Zip(entries, model);
-                }
-            } catch {
-                // continue regardless of error
-            }
+            return null;
         }
         if (records.has('data.pkl')) {
             return new pytorch.Container.Zip(entries);
@@ -1057,12 +1021,11 @@ pytorch.Container.Zip = class extends pytorch.Container {
         return null;
     }
 
-    constructor(entries, model) {
+    constructor(entries) {
         super();
         this.type = 'pytorch.zip'; // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
         this._entries = entries;
-        this._model = model;
     }
 
     async read(metadata) {
@@ -1072,43 +1035,91 @@ pytorch.Container.Zip = class extends pytorch.Container {
         }
         const torch = execution.__import__('torch');
         const reader = new torch.PyTorchFileReader(this._entries);
-        const torchscript = this._model ? true : reader.has_record('constants.pkl');
-        if (this._model) {
-            this._producer = this._model && this._model.producerName ? this._model.producerName + (this._model.producerVersion ? ` v${this._model.producerVersion}` : '') : '';
-            this._format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
-        } else {
-            const name = torchscript ? 'TorchScript' : 'PyTorch';
-            const version = reader.version();
-            this._format = pytorch.Utility.format(name, version);
-        }
+        const torchscript = reader.has_record('constants.pkl');
+        const name = torchscript ? 'TorchScript' : 'PyTorch';
+        const version = reader.version();
+        this.format = pytorch.Utility.format(name, version);
         if (torchscript) {
             const module = torch.jit.load(reader);
             execution.trace = true;
             if (module.data && module.data.forward) {
-                this._modules = new Map([['', module]]);
+                this.modules = new Map([['', module]]);
             } else {
-                this._modules = pytorch.Utility.find(module.data);
+                this.modules = pytorch.Utility.find(module.data);
             }
         } else {
             const records = reader.get_all_records().map((key) => [key, reader.get_record(key)]);
             const entries = new Map(records);
             const module = torch.load(entries);
-            this._modules = pytorch.Utility.find(module);
+            this.modules = pytorch.Utility.find(module);
         }
         delete this._model;
         delete this._entries;
     }
 
-    get format() {
-        return this._format;
+    get producer() {
+        return this._producer || '';
     }
+};
 
-    get modules() {
-        return this._modules;
+pytorch.Container.ModelJson = class extends pytorch.Container {
+
+    static open(context) {
+        const identifier = context.identifier;
+        if (identifier === 'model.json') {
+            const model = context.peek('json');
+            if (model && model.mainModule) {
+                const entries = new Map();
+                entries.set('model.json', context.stream);
+                return new pytorch.Container.ModelJson(context, entries, model);
+            }
+        }
+        return null;
     }
 
-    get producer() {
-        return this._producer || '';
+    constructor(context, entries, model) {
+        super();
+        this.type = 'pytorch.model.json';
+        this._context = context;
+        this._entries = entries;
+        this._model = model;
+    }
+
+    async read(metadata) {
+        const keys = [
+            'attributes.pkl',
+            'version',
+            ...this._model.tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key)
+        ];
+        if (this._model.mainModule.torchscriptArena && this._model.mainModule.torchscriptArena.key) {
+            keys.push(this._model.mainModule.torchscriptArena.key);
+        }
+        const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null)));
+        for (let i = 0; i < keys.length; i++) {
+            if (values[i]) {
+                this._entries.set(keys[i], values[i]);
+            }
+        }
+        const execution = new pytorch.jit.Execution(null, metadata);
+        for (const event of this._events) {
+            execution.on(event[0], event[1]);
+        }
+        const torch = execution.__import__('torch');
+        const reader = new torch.PyTorchFileReader(this._entries);
+        if (this._model && this._model.producerName) {
+            this.producer = this._model.producerName + (this._model.producerVersion ? ` v${this._model.producerVersion}` : '');
+        }
+        this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
+        const module = torch.jit.load(reader);
+        execution.trace = true;
+        if (module.data && module.data.forward) {
+            this.modules = new Map([['', module]]);
+        } else {
+            this.modules = pytorch.Utility.find(module.data);
+        }
+        delete this._context;
+        delete this._model;
+        delete this._entries;
     }
 };
 
@@ -1130,10 +1141,10 @@ pytorch.Container.Index = class extends pytorch.Container {
         this.type = 'pytorch.index';
         this.context = context;
         this._entries = entries;
-        this._format = 'PyTorch';
     }
 
     async read(metadata) {
+        this.format = 'PyTorch';
         const weight_map = new Map(this._entries);
         const keys = new Set(weight_map.keys());
         const files = Array.from(new Set(weight_map.values()));
@@ -1152,7 +1163,7 @@ pytorch.Container.Index = class extends pytorch.Container {
             return pytorch.Utility.format('PyTorch', version);
         }));
         if (formats.size === 1) {
-            this._format = formats.values().next().value;
+            this.format = formats.values().next().value;
         }
         const shards = archives.map((entries) => {
             return torch.load(entries);
@@ -1165,18 +1176,10 @@ pytorch.Container.Index = class extends pytorch.Container {
                 }
             }
         }
-        this._modules = pytorch.Utility.findWeights(entries);
+        this.modules = pytorch.Utility.findWeights(entries);
         delete this.context;
         delete this._entries;
     }
-
-    get format() {
-        return this._format;
-    }
-
-    get modules() {
-        return this._modules;
-    }
 };
 
 pytorch.Container.ExportedProgram = class extends pytorch.Container {
@@ -1197,7 +1200,7 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
     }
 
     async read() {
-        this._format = 'PyTorch Export';
+        this.format = 'PyTorch Export';
         const serialized_state_dict = await this._fetch('serialized_state_dict.pt') || await this._fetch('serialized_state_dict.json');
         const serialized_constants = await this._fetch('serialized_constants.pt') || await this._fetch('serialized_constants.json');
         const f = new Map();
@@ -1229,14 +1232,6 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
         throw new pytorch.Error(`'torch.export' not supported.`);
     }
 
-    get format() {
-        return this._format;
-    }
-
-    get modules() {
-        return this._modules;
-    }
-
     async _fetch(name) {
         try {
             const context = await this._context.fetch(name);
@@ -1692,8 +1687,8 @@ pytorch.Execution = class extends python.Execution {
             __setattr__(name, value) {
                 if (this._initializing) {
                     super.__setattr__(name, value);
-                } else if (this._modules.has(name)) {
-                    this._modules.set(name, value);
+                } else if (this.modules.has(name)) {
+                    this.modules.set(name, value);
                 } else if (this._c.hasattr(name)) {
                     this._c.setattr(name, value);
                 } else {
@@ -1704,8 +1699,8 @@ pytorch.Execution = class extends python.Execution {
                 if (this._initializing) {
                     return super.__getattr__(name);
                 }
-                if (this._modules.has(name)) {
-                    return this._modules.get(name);
+                if (this.modules.has(name)) {
+                    return this.modules.get(name);
                 }
                 if (this._c.hasattr(name)) {
                     return this._c.getattr(name);
@@ -3039,10 +3034,12 @@ pytorch.jit.ScriptModuleDeserializer = class {
             const storage = new storage_type(size);
             const itemsize = storage.dtype.itemsize();
             const stream = this._reader.get_record(key);
-            const buffer = stream.peek();
-            const length = size * itemsize;
-            const data = buffer.slice(offset, offset + length);
-            storage._set_cdata(data);
+            if (stream) {
+                const buffer = stream.peek();
+                const length = size * itemsize;
+                const data = buffer.slice(offset, offset + length);
+                storage._set_cdata(data);
+            }
             const tensor = execution.invoke('torch._utils._rebuild_tensor', [storage, 0, shape, strides]);
             tensor.name = constant.data.key;
             return tensor;
@@ -3094,6 +3091,9 @@ pytorch.jit.ScriptModuleDeserializer = class {
         }
         const arena = data.torchscriptArena;
         if (arena && arena.key && arena.key.startsWith('code/')) {
+            if (!this._reader.has_record(arena.key)) {
+                throw new pytorch.Error(`File '${arena.key}' not found.`);
+            }
             const file = arena.key.substring('code/'.length);
             const name = file.replace(/\.py$/, '').split('/').join('.');
             const module = execution.import(name);
@@ -3115,6 +3115,9 @@ pytorch.jit.ScriptModuleDeserializer = class {
     readArchiveAndTensors(archive_name, pickle_prefix, tensor_prefix, type_resolver, obj_loader, device, stream_reader, type_parser, storage_context) {
         const picklename = `${pickle_prefix + archive_name}.pkl`;
         const stream = stream_reader.get_record(picklename);
+        if (!stream) {
+            throw new pytorch.Error(`File '${picklename}' is not found.`);
+        }
         const buffer = stream.peek();
         const tensor_dir_path = tensor_prefix ? tensor_prefix : `${archive_name}/`;
         const read_record = (name) => {
@@ -3292,8 +3295,8 @@ pytorch.Container.Package = class extends pytorch.Container {
         const torch = execution.__import__('torch');
         const reader = new torch.PyTorchFileReader(this.entries);
         const version = reader.version();
-        this._format = pytorch.Utility.format('PyTorch Package', version);
-        this._modules = new Map();
+        this.format = pytorch.Utility.format('PyTorch Package', version);
+        this.modules = new Map();
         const records = reader.get_all_records().filter((name) => {
             if (!name.startsWith('.data/') && !name.endsWith('.py')) {
                 const stream = reader.get_record(name);
@@ -3324,19 +3327,11 @@ pytorch.Container.Package = class extends pytorch.Container {
             for (const entry of entries) {
                 const module = importer.load_pickle(entry[0], entry[1]);
                 const key = `${entry[0].replace(/\./, '/')}/${entry[1]}`;
-                this._modules.set(key, module);
+                this.modules.set(key, module);
             }
         }
         delete this.entries;
     }
-
-    get format() {
-        return this._format;
-    }
-
-    get modules() {
-        return this._modules;
-    }
 };
 
 pytorch.MemoryFormat = {
diff --git a/source/view.js b/source/view.js
index fc6a3608ad..b07540f11e 100644
--- a/source/view.js
+++ b/source/view.js
@@ -5459,7 +5459,7 @@ view.ModelFactoryService = class {
         this._patterns = new Set(['.zip', '.tar', '.tar.gz', '.tgz', '.gz']);
         this._factories = [];
         this.register('./server', ['.netron']);
-        this.register('./pytorch', ['.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte', '.bin.index.json', 'serialized_exported_program.json'], ['.model', '.pt2']);
+        this.register('./pytorch', ['.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte', '.bin.index.json', 'serialized_exported_program.json', 'model.json'], ['.model', '.pt2']);
         this.register('./onnx', ['.onnx', '.onnx.data', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', '.ngf', '.json', '.bin', 'onnxmodel']);
         this.register('./tflite', ['.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json', '.txt', '.dat', '.nb', '.ckpt']);
         this.register('./mxnet', ['.json', '.params'], ['.mar']);
diff --git a/test/models.json b/test/models.json
index 1b59774eac..1cd61625d4 100644
--- a/test/models.json
+++ b/test/models.json
@@ -5460,30 +5460,16 @@
     {
         "type": "pytorch",
         "target": "pedestrian_interaction_position_embedding.pt",
-        "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/pedestrian_interaction_position_embedding.pt",
+        "source": "https://github.com/user-attachments/files/16120032/pedestrian_interaction_position_embedding.pt.zip[pedestrian_interaction_position_embedding.pt]",
         "format": "TorchScript v1.0",
-        "link": "https://github.com/ApolloAuto/apollo"
-    },
-    {
-        "type": "pytorch",
-        "target": "pedestrian_interaction_prediction_layer.pt",
-        "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/pedestrian_interaction_prediction_layer.pt",
-        "format": "TorchScript v1.0",
-        "link": "https://github.com/ApolloAuto/apollo"
+        "link": "https://github.com/lutzroeder/netron/issues/842"
     },
     {
         "type": "pytorch",
        "target": "pedestrian_interaction_single_lstm.pt",
-        "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/pedestrian_interaction_single_lstm.pt",
+        "source": "https://github.com/user-attachments/files/16120033/pedestrian_interaction_single_lstm.pt.zip[pedestrian_interaction_single_lstm.pt]",
         "format": "TorchScript v1.0",
-        "link": "https://github.com/ApolloAuto/apollo"
-    },
-    {
-        "type": "pytorch",
-        "target": "pedestrian_interaction_social_embedding.pt",
-        "source": "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/pedestrian_interaction_social_embedding.pt",
-        "format": "TorchScript v1.1",
-        "link": "https://github.com/ApolloAuto/apollo"
+        "link": "https://github.com/lutzroeder/netron/issues/842"
     },
     {
         "type": "pytorch",