diff --git a/source/pytorch.js b/source/pytorch.js index 6a89bee168..b15a5ce5cc 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -57,7 +57,7 @@ pytorch.Graph = class { return values.get(name); }; const createNode = (groups, key, obj, args, output) => { - let type = obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : '?'; + let type = obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : 'Module'; if (type === 'torch.jit._script.RecursiveScriptModule' && obj._c && obj._c.qualified_name) { type = obj._c.qualified_name; } @@ -768,6 +768,10 @@ pytorch.Container = class { if (mobile) { return mobile; } + const index = pytorch.Container.Index.open(context); + if (index) { + return index; + } const executorch = pytorch.Container.ExecuTorch.open(context); if (executorch) { return executorch; @@ -1144,6 +1148,73 @@ pytorch.Container.Zip = class extends pytorch.Container { } }; +pytorch.Container.Index = class extends pytorch.Container { + + static open(context) { + const obj = context.peek('json'); + if (obj && obj.weight_map) { + const entries = Object.entries(obj.weight_map); + if (entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.bin'))) { + return new pytorch.Container.Index(context, entries); + } + } + return null; + } + + constructor(context, entries) { + super(); + this._context = context; + this._entries = entries; + this._format = 'PyTorch'; + } + + async read(metadata) { + const weight_map = new Map(this._entries); + const keys = new Set(weight_map.keys()); + const files = Array.from(new Set(weight_map.values())); + const contexts = await Promise.all(files.map((name) => this._context.fetch(name))); + const execution = new pytorch.jit.Execution(null, metadata); + for (const event of this._events) { + execution.on(event[0], event[1]); + } + const torch = execution.__import__('torch'); + const readers = contexts.map((context) => { + const entries = context.peek('zip'); + return new pytorch.jit.StreamReader(entries); + }); + const formats = new Set(readers.map((reader) => { + const version = reader.version(); + return pytorch.Utility.format('PyTorch', version); + })); + if (formats.size === 1) { + this._format = formats.values().next().value; + } + const shards = readers.map((reader) => { + const entries = new Map(reader.getAllRecords().map((key) => [ key, reader.getRecord(key) ])); + return torch.load(entries); + }); + const entries = new Map(); + for (const shard of shards) { + for (const [key, value] of Object.entries(shard)) { + if (keys.has(key)) { + entries.set(key, value); + } + } + } + this._modules = pytorch.Utility.findWeights(entries); + delete this._context; + delete this._entries; + } + + get format() { + return this._format; + } + + get modules() { + return this._modules; + } +}; + pytorch.Execution = class extends python.Execution { constructor(sources) { @@ -2858,8 +2929,7 @@ pytorch.jit.ScriptModuleDeserializer = class { this._tensor_dir_prefix = tensor_dir_prefix || ''; this._source_importer = new pytorch.jit.SourceImporter( this._compilation_unit, this._constants_table, - new pytorch.jit.SourceLoader(this._reader, this._code_prefix), - reader.version()); + new pytorch.jit.SourceLoader(this._reader, this._code_prefix), reader.version()); } deserialize() { diff --git a/source/view.js b/source/view.js index 3bb8b8ea97..2d7c277181 100644 --- a/source/view.js +++ b/source/view.js @@ -5166,7 +5166,7 @@ view.ModelFactoryService = class { this._extensions = new Set([ '.zip', '.tar', '.tar.gz', '.tgz', '.gz' ]); this._factories = []; this.register('./server', [ '.netron']); - this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte' ], [ '.model' ]); + this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte', '.bin.index.json' ], [ '.model' ]); this.register('./onnx', [ '.onnx', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', 'onnxmodel', 'ngf', 'json' ]); this.register('./mxnet', [ '.json', '.params' ], [ '.mar']); this.register('./coreml', [ '.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb', '.pbtxt' ], [ '.mlpackage' ]);