Skip to content

Commit

Permalink
Update python.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 21, 2024
1 parent 91433ba commit 0c7cb76
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 240 deletions.
46 changes: 30 additions & 16 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -8809,8 +8809,8 @@ python.Execution = class {
});
this.registerType('torch.Value', class {
constructor(node) {
this._unique = node && node._next_unique ? node._next_unique++ : node._graph._next_unique++; // remove always node
this._node = node && node._next_unique ? null : node;
this._unique = node._graph._next_unique++;
this._node = node;
this._uses = [];
}
unique() {
Expand Down Expand Up @@ -9162,13 +9162,13 @@ python.Execution = class {
'init2(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers, int compilation_preference, bool relax_f32_to_f16) -> NoneType',
'run(__torch__.torch.classes._nnapi.Compilation self, Tensor[] inputs, Tensor[] outputs) -> NoneType'
] },
{ name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase' },
{ name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase' },
{ name: '__torch__.torch.classes.quantized.LinearPackedParamsBase' },
{ name: '__torch__.torch.classes.rnn.CellParamsBase' },
{ name: '__torch__.torch.classes.xnnpack.Conv2dOpContext' },
{ name: '__torch__.torch.classes.xnnpack.LinearOpContext' },
{ name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext' },
{ name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups' },
{ name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups' },
{ name: '__torch__.torch.classes.quantized.LinearPackedParamsBase', attributes: 'Tensor weight, Tensor? bias' },
{ name: '__torch__.torch.classes.rnn.CellParamsBase', attributes: 'str type, Tensor[] tensors, float[] doubles, int[] longs, __torch__.torch.classes.quantized.LinearPackedParamsBase[] packed_params' },
{ name: '__torch__.torch.classes.xnnpack.Conv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
{ name: '__torch__.torch.classes.xnnpack.LinearOpContext', attributes: 'Tensor weight, Tensor bias, int[] output_min, int[] output_max' },
{ name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
];
for (const known_type of known_types) {
const prefix = new torch.jit.QualifiedName(known_type.name);
Expand All @@ -9179,27 +9179,41 @@ python.Execution = class {
const fn = new torch.jit.BuiltinOpFunction(name, schema);
type.addMethod(fn);
}
if (known_type.attributes) {
const schema = new torch.FunctionSchema(`(${known_type.attributes}) -> ()`);
for (const arg of schema.arguments) {
type.addAttribute(arg.name, arg.real_type);
}
}

this._compilation_unit.register_type(type);
}
if (this._reader.has_record('model.json')) {
return this.LEGACY_deserialize();
}
const constants = this.readArchive('constants');
for (let i = 0; i < constants.length; i++) {
execution.builtins.CONSTANTS[`c${i}`] = constants[i];
let val = constants[i];
if (val && val.__class__ && val.__class__.__module__.startsWith('__torch__.torch.classes.')) {
const type = this._source_importer.resolveType(`${val.__class__.__module__}.${val.__class__.__name__}`);
const obj = torch.ScriptObject.create(type);
obj._ivalue = val;
val = obj;
}
execution.builtins.CONSTANTS[`c${i}`] = val;
}
const obj = this.readArchive('data');
const convertModule = (obj) => {
const convertObject = (obj) => {
if (obj.__class__) {
const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
const type = this._source_importer.loadType(new torch.jit.QualifiedName(name));
const module = new torch.ScriptModule(type, this._compilation_unit);
const module = type.is_module() ? new torch.ScriptModule(type, this._compilation_unit) : new torch.ScriptObject(type);
for (let i = 0; i < type.numAttributes(); i++) {
const k = type.getAttributeName(i);
const t = type.getAttribute(i);
const v = obj[k];
if (t.is_module()) {
module.__setattr__(k, convertModule(v));
if (t instanceof torch.ClassType) {
module.__setattr__(k, convertObject(v));
} else {
if (t instanceof torch.TensorType && v && v.__class__ && v instanceof torch.Tensor === false && v.__class__.__module__ === '__torch__.torch.classes.quantized') {
const name = `${v.__class__.__module__}.${v.__class__.__name__}`;
Expand All @@ -9217,7 +9231,7 @@ python.Execution = class {
}
throw new python.Error('Module class not found.');
};
return convertModule(obj);
return convertObject(obj);
}
LEGACY_deserialize() {
const execution = this._compilation_unit.execution;
Expand Down Expand Up @@ -9740,7 +9754,7 @@ python.Execution = class {
if (!this.forward) {
return null;
}
execution.traceAttr = false;
execution.traceAttr = true;
const args = [];
if (!execution.traceAttr) {
args.push(this); // self
Expand Down
Loading

0 comments on commit 0c7cb76

Please sign in to comment.