Skip to content

Commit

Permalink
Update python.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 16, 2024
1 parent 2545ad7 commit 7ff178f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
34 changes: 17 additions & 17 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -9100,28 +9100,28 @@ python.Execution = class {
}
LEGACY_deserialize() {
const execution = this._compilation_unit.execution;
const caffe2 = execution.proto.caffe2;
const torch = execution.import('torch');
const stream = this._reader.get_record('model.json');
const buffer = stream.peek();
const decoder = new TextDecoder('utf-8');
const content = decoder.decode(buffer);
const model = JSON.parse(content);
const data = model.mainModule || {};
const queue = [data];
const obj = JSON.parse(content);
const model = execution.proto.torch.ModelDef.decodeJson(obj);
const tensorTypeMap = new Map([
['FLOAT', 'Float'],
['FLOAT16', 'Half'],
['DOUBLE', 'Double'],
['INT8', 'Char'],
['INT32', 'Int'],
['INT64', 'Long']
[caffe2.TensorProto.DataType.FLOAT, 'Float'],
[caffe2.TensorProto.DataType.FLOAT16, 'Half'],
[caffe2.TensorProto.DataType.DOUBLE, 'Double'],
[caffe2.TensorProto.DataType.INT8, 'Char'],
[caffe2.TensorProto.DataType.INT32, 'Int'],
[caffe2.TensorProto.DataType.INT64, 'Long']
]);
const tensor_table = (model.tensors || []).map((constant) => {
const key = constant.data.key;
if (!tensorTypeMap.has(constant.dataType)) {
throw new python.Error(`Unsupported tensor data type '${constant.dataType}'.`);
if (!tensorTypeMap.has(constant.data_type)) {
throw new python.Error(`Unsupported tensor data type '${constant.data_type}'.`);
}
const type = tensorTypeMap.get(constant.dataType);
const type = tensorTypeMap.get(constant.data_type);
const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
const strides = constant.strides ? constant.strides.map((dim) => parseInt(dim, 10)) : null;
const storage_type = execution.resolve(`torch.${type}Storage`);
Expand All @@ -9137,7 +9137,7 @@ python.Execution = class {
storage._set_cdata(data);
}
const tensor = torch._utils._rebuild_tensor(storage, 0, shape, strides);
tensor.name = constant.data.key;
tensor.name = key;
return tensor;
});
execution.builtins.CONSTANTS = {};
Expand All @@ -9152,14 +9152,14 @@ python.Execution = class {
const obj = unpickler.load();
attributes.push(...obj);
}

this._LEGACY_moduleStack = ['__torch__'];
// const module_def = model.mainModule;
const module_def = model.main_module;
for (const tensor of tensor_table) {
this._constant_table.push(tensor);
}
// this.LEGACY_convertModule(module_def);

this.LEGACY_convertModule(module_def);
const data = obj.mainModule || {};
const queue = [data];
while (queue.length > 0) {
const module = queue.shift();
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
Expand Down
4 changes: 4 additions & 0 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,9 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
}

async read(metadata) {
if (this._entries.has('model.json')) {
pytorch.proto = await this._context.require('./pytorch-proto');
}
const keys = [
'attributes.pkl',
'version',
Expand All @@ -1360,6 +1363,7 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
}
}
this.execution = new pytorch.Execution(null, metadata);
this.execution.proto = pytorch.proto;
for (const event of this._events) {
this.execution.on(event[0], event[1]);
}
Expand Down

0 comments on commit 7ff178f

Please sign in to comment.