Skip to content

Commit

Permalink
Update TorchScript test files (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 8, 2024
1 parent b4e6833 commit a805f35
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 36 deletions.
232 changes: 228 additions & 4 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
[
{
"name": "__torch__.torch.classes.rnn.CellParamsBase",
"inputs": [
{ "name": "type", "type": "string" },
{ "name": "tensors", "type": "Tensor[]" },
{ "name": "doubles", "type": "float64[]" },
{ "name": "longs", "type": "int64[]" },
{ "name": "packed_params", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase[]" }
]
},
{
"name": "__torch__.torch.classes.xnnpack.Conv2dOpContext",
"inputs": [
Expand Down Expand Up @@ -10400,17 +10410,94 @@
]
},
{
    "name": "aten::quantized_gru.data",
"category": "Layer",
"inputs": [
{ "name": "data", "type": "Tensor" },
{ "name": "batch_sizes", "type": "Tensor" },
{ "name": "hx", "type": "Tensor" },
{ "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
{ "name": "has_biases", "type": "boolean" },
{ "name": "num_layers", "type": "int64" },
{ "name": "dropout", "type": "float32" },
{ "name": "train", "type": "boolean" },
{ "name": "bidirectional", "type": "boolean" }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" }
]
},
{
"name": "aten::quantized_gru.data_legacy",
"category": "Layer",
"inputs": [
{ "name": "data", "type": "Tensor" },
{ "name": "batch_sizes", "type": "Tensor" },
{ "name": "hx", "type": "Tensor" },
{ "name": "params", "type": "Tensor[]" },
{ "name": "has_biases", "type": "boolean" },
{ "name": "num_layers", "type": "int64" },
{ "name": "dropout", "type": "float32" },
{ "name": "train", "type": "boolean" },
{ "name": "bidirectional", "type": "boolean" }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" }
]
},
{
"name": "aten::quantized_gru.input",
"category": "Layer",
"inputs": [
{ "name": "input", "type": "Tensor" },
      { "name": "hx", "type": "Tensor" },
{ "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
{ "name": "has_biases", "type": "boolean" },
{ "name": "num_layers", "type": "int64" },
{ "name": "dropout", "type": "float32" },
{ "name": "train", "type": "boolean" },
{ "name": "bidirectional", "type": "boolean" },
{ "name": "batch_first", "type": "boolean" }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" }
]
},
{
"name": "aten::quantized_gru.input_legacy",
"category": "Layer",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor" },
{ "name": "params", "type": "Tensor[]" },
{ "name": "has_biases", "type": "boolean" },
{ "name": "num_layers", "type": "int64" },
{ "name": "dropout", "type": "float32" },
{ "name": "train", "type": "boolean" },
{ "name": "bidirectional", "type": "boolean" },
      { "name": "batch_first", "type": "boolean" }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" }
]
},
{
"name": "aten::quantized_lstm.data",
"category": "Layer",
"inputs": [
{ "name": "data", "type": "Tensor" },
{ "name": "batch_sizes", "type": "Tensor" },
{ "name": "hx", "type": "Tensor[]" },
{ "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
{ "name": "has_biases", "type": "boolean" },
{ "name": "num_layers", "type": "int64" },
{ "name": "dropout", "type": "float32" },
{ "name": "train", "type": "boolean" },
{ "name": "bidirectional", "type": "boolean" },
{ "name": "dtype", "type": "ScalarType", "optional": true, "default": null },
{ "name": "use_dynamic", "type": "boolean", "default": false }
],
Expand All @@ -10421,7 +10508,8 @@
]
},
{
    "name": "aten::quantized_lstm.data_legacy",
"category": "Layer",
"inputs": [
{ "name": "data", "type": "Tensor" },
{ "name": "batch_sizes", "type": "Tensor" },
Expand All @@ -10441,6 +10529,50 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::quantized_lstm.input",
"category": "Layer",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor[]" },
{ "name": "params", "type": "__torch__.torch.classes.rnn.CellParamsBase[]" },
{ "name": "has_biases", "type": "boolean" },
{ "name": "num_layers", "type": "int64" },
{ "name": "dropout", "type": "float32" },
{ "name": "train", "type": "boolean" },
{ "name": "bidirectional", "type": "boolean" },
{ "name": "batch_first", "type": "boolean" },
{ "name": "dtype", "type": "ScalarType", "optional": true, "default": null },
{ "name": "use_dynamic", "type": "boolean", "default": false }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" },
{ "type": "Tensor" }
]
},
{
"name": "aten::quantized_lstm.input_legacy",
"category": "Layer",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor[]" },
{ "name": "params", "type": "Tensor[]" },
{ "name": "has_biases", "type": "boolean" },
{ "name": "num_layers", "type": "int64" },
{ "name": "dropout", "type": "float32" },
{ "name": "train", "type": "boolean" },
{ "name": "bidirectional", "type": "boolean" },
{ "name": "batch_first", "type": "boolean" },
{ "name": "dtype", "type": "ScalarType", "optional": true, "default": null },
{ "name": "use_dynamic", "type": "boolean", "default": false }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" },
{ "type": "Tensor" }
]
},
{
"name": "aten::quantized_lstm_cell",
"inputs": [
Expand Down Expand Up @@ -15742,6 +15874,41 @@
{ "name": "Y", "type": "Tensor" }
]
},
{
"name": "quantized::make_quantized_cell_params",
"inputs": [
{ "name": "w_ih", "type": "Tensor" },
{ "name": "w_hh", "type": "Tensor" },
{ "name": "b_ih", "type": "Tensor" },
{ "name": "b_hh", "type": "Tensor" }
],
"outputs": [
{ "type": "__torch__.torch.classes.rnn.CellParamsBase" }
]
},
{
"name": "quantized::make_quantized_cell_params_dynamic",
"inputs": [
{ "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "bias_ih", "type": "Tensor" },
{ "name": "bias_hh", "type": "Tensor" },
{ "name": "reduce_range", "type": "boolean", "default": false }
],
"outputs": [
{ "type": "__torch__.torch.classes.rnn.CellParamsBase" }
]
},
{
"name": "quantized::make_quantized_cell_params_fp16",
"inputs": [
{ "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" }
],
"outputs": [
{ "type": "__torch__.torch.classes.rnn.CellParamsBase" }
]
},
{
"name": "quantized::mul",
"inputs": [
Expand Down Expand Up @@ -15968,6 +16135,63 @@
{ "type": "Tensor" }
]
},
{
"name": "quantized::quantized_gru_cell_dynamic",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor" },
{ "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "b_ih", "type": "Tensor" },
{ "name": "b_hh", "type": "Tensor" }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "quantized::quantized_lstm_cell_dynamic",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor[]" },
{ "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "bias_ih", "type": "Tensor" },
{ "name": "bias_hh", "type": "Tensor" }
],
"outputs": [
{ "type": "Tensor" },
{ "type": "Tensor" }
]
},
{
"name": "quantized::quantized_rnn_relu_cell_dynamic",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor" },
{ "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "b_ih", "type": "Tensor" },
{ "name": "b_hh", "type": "Tensor" }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "quantized::quantized_rnn_tanh_cell_dynamic",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor" },
{ "name": "w_ih", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "w_hh", "type": "__torch__.torch.classes.quantized.LinearPackedParamsBase" },
{ "name": "b_ih", "type": "Tensor" },
{ "name": "b_hh", "type": "Tensor" }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "quantized::relu6",
"category": "Activation",
Expand Down
53 changes: 23 additions & 30 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pytorch.Graph = class {
return submodules;
};
const loadScriptModule = (module, initializers) => {
if (module) {
if (module && !pytorch.Utility.isObject(module)) {
if (pytorch.Graph._getParameters(module).size > 0 && !module.__hide__) {
const item = { module };
this.nodes.push(new pytorch.Node(metadata, '', item, initializers, values));
Expand Down Expand Up @@ -527,13 +527,21 @@ pytorch.Node = class {
const input = inputs[i];
const schema = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null;
const name = schema && schema.name ? schema.name : i.toString();
const type = schema && schema.type ? schema.type : null;
let type = schema && schema.type ? schema.type : null;
let array = false;
if (type && type.endsWith('[]')) {
array = true;
type = type.slice(0, -2);
}
let argument = null;
if (pytorch.Utility.isObjectType(type)) {
const obj = input.value;
if (initializers.has(obj)) {
if (!array && initializers.has(obj)) {
const node = new pytorch.Node(metadata, group, { name, type, obj }, initializers, values);
argument = new pytorch.Argument(name, node, 'object');
} else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
const node = obj.map((obj) => new pytorch.Node(metadata, group, { name, type, obj }, initializers, values));
argument = new pytorch.Argument(name, node, 'object[]');
} else {
const identifier = input.unique().toString();
const value = values.map(identifier);
Expand Down Expand Up @@ -1799,6 +1807,11 @@ pytorch.jit.Execution = class extends pytorch.Execution {
[this.weight, this.bias] = state;
}
});
this.registerType('__torch__.torch.classes.rnn.CellParamsBase', class {
__setstate__(state) {
[this.type, this.tensors, this.doubles, this.longs, this.packed_params] = state;
}
});
this.registerType('__torch__.torch.classes.xnnpack.Conv2dOpContext', class {
__setstate__(state) {
[this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
Expand Down Expand Up @@ -2137,33 +2150,9 @@ pytorch.jit.Execution = class extends pytorch.Execution {
} else {
copyArgs.shift();
copyEvalArgs.shift();
switch (parameter.type) {
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': {
const value = this.variable(argument);
value.value = argument;
node.addInput(value);
/*
for (const [, value] of Object.entries(argument)) {
if (pytorch.Utility.isTensor(value)) {
const tensor = value;
referencedParameters.push(tensor);
}
}
*/
break;
}
default: {
const value = this.variable(argument);
node.addInput(value);
value.value = argument;
break;
}
}
const value = this.variable(argument);
node.addInput(value);
value.value = argument;
}
}
}
Expand Down Expand Up @@ -2416,6 +2405,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.rnn.CellParamsBase':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': {
Expand Down Expand Up @@ -2607,6 +2597,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
case '__torch__.torch.classes.rnn.CellParamsBase':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
Expand Down Expand Up @@ -3390,6 +3381,8 @@ pytorch.Utility = class {
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
case '__torch__.torch.classes.rnn.CellParamsBase':
case '__torch__.torch.classes.rnn.CellParamsBase[]':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
Expand Down
Loading

0 comments on commit a805f35

Please sign in to comment.