Skip to content

Commit

Permalink
Add TorchScript test files (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 20, 2022
1 parent 9ffbdf8 commit 716674a
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 9 deletions.
12 changes: 10 additions & 2 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4202,11 +4202,14 @@ python.Execution = class {
return tensor.dtype.scalar_type();
});
this.registerFunction('ops.prim.is_quantized', function(tensor) {
return tensor && tensor.__quantized__ === true;
return tensor.is_quantized;
});
// Registered builtin ops.prim.is_cuda: always reports false — CUDA device
// placement is not modeled by this execution environment, so the tensor
// argument is intentionally ignored.
this.registerFunction('ops.prim.is_cuda', function(/* tensor */) {
return false;
});
// Registered builtin ops.prim.is_nested: delegates to the tensor's
// 'is_nested' accessor.
this.registerFunction('ops.prim.is_nested', function(tensor) {
return tensor.is_nested;
});
// Registered builtin ops.prim.unchecked_unwrap_optional: identity —
// the optional is assumed to hold a value, so it is returned unchanged.
this.registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
return value;
});
Expand Down Expand Up @@ -5209,7 +5212,12 @@ python.Execution = class {
}
throw new python.Error("Unsupported indices in layout'" + this._indices.__str__() + "'.");
}

// True when this tensor carries the internal '__quantized__' marker.
get is_quantized() {
return this.__quantized__ === true;
}
// True when this tensor carries the internal '__nested__' marker.
get is_nested() {
return this.__nested__ === true;
}
// Returns the tensor shape as stored in '_shape'.
size() {
return this._shape;
}
Expand Down
14 changes: 13 additions & 1 deletion source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,9 @@ pytorch.Execution = class extends python.Execution {
const qualified_name = this.data && this.data.__class__ && this.data.__class__.__module__ && this.data.__class__.__name__ ? this.data.__class__.__module__ + '.' + this.data.__class__.__name__ : '';
return execution.invoke('torch.ClassType', [ qualified_name ]);
}
// Fully qualified class name of the wrapped data, obtained from the
// synthesized class type. NOTE(review): assumes '_type' is a method
// returning a torch.ClassType with a 'qualified_name()' method — confirm
// against the '_type' definition above this hunk.
get qualified_name() {
return this._type().qualified_name();
}
get code_with_constants() {
const const_map = {};
const_map.const_mapping = new Map(Object.entries(execution.builtins.CONSTANTS));
Expand Down Expand Up @@ -1517,6 +1520,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
this._kind = kind;
this._inputs = [];
this._outputs = [];
this._blocks = [];
}
kind() {
return this._kind;
Expand All @@ -1527,6 +1531,9 @@ pytorch.jit.Execution = class extends pytorch.Execution {
// Output values produced by this node.
outputs() {
return this._outputs;
}
// Nested sub-blocks attached to this node via addBlock().
blocks() {
return this._blocks;
}
addInput(value) {
const use = execution.invoke('torch.Use', [ this ]);
value._uses.push(use);
Expand All @@ -1538,6 +1545,11 @@ pytorch.jit.Execution = class extends pytorch.Execution {
this._outputs.push(value);
return value;
}
addBlock() {
const block = execution.invoke('torch.Block' [ this._graph, this ]);
this._blocks.push(block);
return block;
}
});
this.registerType('torch.Value', class {
constructor(node) {
Expand Down Expand Up @@ -2094,7 +2106,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
if (type.startsWith('torch.')) {
overloads = this._types.get('aten::' + type.substring(6));
}
else if (type.startsWith('ops.')) {
else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) {
const path = type.split('.');
if (path.length === 3) {
overloads = this._types.get(path[1] + '::' + path[2]);
Expand Down
7 changes: 4 additions & 3 deletions test/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def _test_onnx_iterate():
def _test_torchscript_transformer():
torch = __import__('torch')
model = torch.nn.Transformer(nhead=16, num_encoder_layers=12)
trace = torch.jit.trace(model, (torch.rand(10, 32, 512), torch.rand(20, 32, 512)))
torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access
netron.serve('transformer', trace)
module = torch.jit.trace(model, (torch.rand(10, 32, 512), torch.rand(20, 32, 512)))
# module = torch.jit.script(model)
torch._C._jit_pass_inline(module.graph) # pylint: disable=protected-access
netron.serve('transformer', module)

def _test_torchscript_resnet34():
torch = __import__('torch')
Expand Down
21 changes: 18 additions & 3 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4985,6 +4985,13 @@
"format": "TorchScript v1.0",
"link": "https://github.com/KinglittleQ/SuperPoint_SLAM"
},
{
"type": "pytorch",
"target": "test.8bit.pth",
"source": "https://github.com/lutzroeder/netron/files/5238524/test.8bit.pth.zip[test.8bit.pth]",
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/546"
},
{
"type": "pytorch",
"target": "traced_fft.pt",
Expand Down Expand Up @@ -5022,10 +5029,18 @@
},
{
"type": "pytorch",
"target": "test.8bit.pth",
"source": "https://github.com/lutzroeder/netron/files/5238524/test.8bit.pth.zip[test.8bit.pth]",
"target": "transformer.pt",
"source": "https://github.com/lutzroeder/netron/files/10271969/transformer.pt.zip[transformer.pt]",
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/546"
"error": "AssertionError: was expecting embedding dimension of 512, but got ?",
"link": "https://github.com/lutzroeder/netron/issues/842"
},
{
"type": "pytorch",
"target": "transformer_traced.pt",
"source": "https://github.com/lutzroeder/netron/files/10271968/transformer_traced.pt.zip[transformer_traced.pt]",
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/842"
},
{
"type": "pytorch",
Expand Down

0 comments on commit 716674a

Please sign in to comment.