Skip to content

Commit

Permalink
Add PyTorch .index.json support (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 30, 2023
1 parent 5979389 commit bce23f3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
76 changes: 73 additions & 3 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pytorch.Graph = class {
return values.get(name);
};
const createNode = (groups, key, obj, args, output) => {
let type = obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : '?';
let type = obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : 'Module';
if (type === 'torch.jit._script.RecursiveScriptModule' && obj._c && obj._c.qualified_name) {
type = obj._c.qualified_name;
}
Expand Down Expand Up @@ -768,6 +768,10 @@ pytorch.Container = class {
if (mobile) {
return mobile;
}
const index = pytorch.Container.Index.open(context);
if (index) {
return index;
}
const executorch = pytorch.Container.ExecuTorch.open(context);
if (executorch) {
return executorch;
Expand Down Expand Up @@ -1144,6 +1148,73 @@ pytorch.Container.Zip = class extends pytorch.Container {
}
};

pytorch.Container.Index = class extends pytorch.Container {

static open(context) {
const obj = context.peek('json');
if (obj && obj.weight_map) {
const entries = Object.entries(obj.weight_map);
if (entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.bin'))) {
return new pytorch.Container.Index(context, entries);
}
}
return null;
}

constructor(context, entries) {
super();
this._context = context;
this._entries = entries;
this._format = 'PyTorch';
}

async read(metadata) {
const weight_map = new Map(this._entries);
const keys = new Set(weight_map.keys());
const files = Array.from(new Set(weight_map.values()));
const contexts = await Promise.all(files.map((name) => this._context.fetch(name)));
const execution = new pytorch.jit.Execution(null, metadata);
for (const event of this._events) {
execution.on(event[0], event[1]);
}
const torch = execution.__import__('torch');
const readers = contexts.map((context) => {
const entries = context.peek('zip');
return new pytorch.jit.StreamReader(entries);
});
const formats = new Set(readers.map((reader) => {
const version = reader.version();
return pytorch.Utility.format('PyTorch', version);
}));
if (formats.size === 1) {
this._format = formats.values().next().value;
}
const shards = readers.map((reader) => {
const entries = new Map(reader.getAllRecords().map((key) => [ key, reader.getRecord(key) ]));
return torch.load(entries);
});
const entries = new Map();
for (const shard of shards) {
for (const [key, value] of Object.entries(shard)) {
if (keys.has(key)) {
entries.set(key, value);
}
}
}
this._modules = pytorch.Utility.findWeights(entries);
delete this._context;
delete this._entries;
}

get format() {
return this._format;
}

get modules() {
return this._modules;
}
};

pytorch.Execution = class extends python.Execution {

constructor(sources) {
Expand Down Expand Up @@ -2858,8 +2929,7 @@ pytorch.jit.ScriptModuleDeserializer = class {
this._tensor_dir_prefix = tensor_dir_prefix || '';
this._source_importer = new pytorch.jit.SourceImporter(
this._compilation_unit, this._constants_table,
new pytorch.jit.SourceLoader(this._reader, this._code_prefix),
reader.version());
new pytorch.jit.SourceLoader(this._reader, this._code_prefix), reader.version());
}

deserialize() {
Expand Down
2 changes: 1 addition & 1 deletion source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -5166,7 +5166,7 @@ view.ModelFactoryService = class {
this._extensions = new Set([ '.zip', '.tar', '.tar.gz', '.tgz', '.gz' ]);
this._factories = [];
this.register('./server', [ '.netron']);
this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte' ], [ '.model' ]);
this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte', '.bin.index.json' ], [ '.model' ]);
this.register('./onnx', [ '.onnx', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', 'onnxmodel', 'ngf', 'json' ]);
this.register('./mxnet', [ '.json', '.params' ], [ '.mar']);
this.register('./coreml', [ '.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb', '.pbtxt' ], [ '.mlpackage' ]);
Expand Down

0 comments on commit bce23f3

Please sign in to comment.