diff --git a/source/python.js b/source/python.js
index 7cb730a93a..56ddf6283e 100644
--- a/source/python.js
+++ b/source/python.js
@@ -4202,11 +4202,14 @@ python.Execution = class {
             return tensor.dtype.scalar_type();
         });
         this.registerFunction('ops.prim.is_quantized', function(tensor) {
-            return tensor && tensor.__quantized__ === true;
+            return tensor.is_quantized;
         });
         this.registerFunction('ops.prim.is_cuda', function(/* tensor */) {
             return false;
         });
+        this.registerFunction('ops.prim.is_nested', function(tensor) {
+            return tensor.is_nested;
+        });
         this.registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
             return value;
         });
@@ -5209,7 +5212,12 @@ python.Execution = class {
                 }
                 throw new python.Error("Unsupported indices in layout'" + this._indices.__str__() + "'.");
             }
-
+            get is_quantized() {
+                return this.__quantized__ === true;
+            }
+            get is_nested() {
+                return this.__nested__ === true;
+            }
             size() {
                 return this._shape;
             }
diff --git a/source/pytorch.js b/source/pytorch.js
index d28a93faad..10e7f97358 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -1210,6 +1210,9 @@ pytorch.Execution = class extends python.Execution {
                 const qualified_name = this.data && this.data.__class__ && this.data.__class__.__module__ && this.data.__class__.__name__ ? this.data.__class__.__module__ + '.' + this.data.__class__.__name__ : '';
                 return execution.invoke('torch.ClassType', [ qualified_name ]);
             }
+            get qualified_name() {
+                return this._type().qualified_name();
+            }
             get code_with_constants() {
                 const const_map = {};
                 const_map.const_mapping = new Map(Object.entries(execution.builtins.CONSTANTS));
@@ -1517,6 +1520,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                 this._kind = kind;
                 this._inputs = [];
                 this._outputs = [];
+                this._blocks = [];
             }
             kind() {
                 return this._kind;
@@ -1527,6 +1531,9 @@
             outputs() {
                 return this._outputs;
             }
+            blocks() {
+                return this._blocks;
+            }
             addInput(value) {
                 const use = execution.invoke('torch.Use', [ this ]);
                 value._uses.push(use);
@@ -1538,6 +1545,11 @@
                 this._outputs.push(value);
                 return value;
             }
+            addBlock() {
+                const block = execution.invoke('torch.Block', [ this._graph, this ]);
+                this._blocks.push(block);
+                return block;
+            }
         });
         this.registerType('torch.Value', class {
             constructor(node) {
@@ -2094,7 +2106,7 @@
                 if (type.startsWith('torch.')) {
                     overloads = this._types.get('aten::' + type.substring(6));
                 }
-                else if (type.startsWith('ops.')) {
+                else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) {
                     const path = type.split('.');
                     if (path.length === 3) {
                         overloads = this._types.get(path[1] + '::' + path[2]);
diff --git a/test/backend.py b/test/backend.py
index d9c0ea9bcc..a27bed8178 100755
--- a/test/backend.py
+++ b/test/backend.py
@@ -35,9 +35,10 @@ def _test_onnx_iterate():
 def _test_torchscript_transformer():
     torch = __import__('torch')
     model = torch.nn.Transformer(nhead=16, num_encoder_layers=12)
-    trace = torch.jit.trace(model, (torch.rand(10, 32, 512), torch.rand(20, 32, 512)))
-    torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access
-    netron.serve('transformer', trace)
+    module = torch.jit.trace(model, (torch.rand(10, 32, 512), torch.rand(20, 32, 512)))
+    # module = torch.jit.script(model)
+    torch._C._jit_pass_inline(module.graph) # pylint: disable=protected-access
+    netron.serve('transformer', module)
 
 def _test_torchscript_resnet34():
     torch = __import__('torch')
diff --git a/test/models.json b/test/models.json
index 79f72896af..8c7e173e1c 100644
--- a/test/models.json
+++ b/test/models.json
@@ -4985,6 +4985,13 @@
         "format": "TorchScript v1.0",
         "link": "https://github.com/KinglittleQ/SuperPoint_SLAM"
     },
+    {
+        "type": "pytorch",
+        "target": "test.8bit.pth",
+        "source": "https://github.com/lutzroeder/netron/files/5238524/test.8bit.pth.zip[test.8bit.pth]",
+        "format": "TorchScript v1.6",
+        "link": "https://github.com/lutzroeder/netron/issues/546"
+    },
     {
         "type": "pytorch",
         "target": "traced_fft.pt",
@@ -5022,10 +5029,18 @@
     },
     {
         "type": "pytorch",
-        "target": "test.8bit.pth",
-        "source": "https://github.com/lutzroeder/netron/files/5238524/test.8bit.pth.zip[test.8bit.pth]",
+        "target": "transformer.pt",
+        "source": "https://github.com/lutzroeder/netron/files/10271969/transformer.pt.zip[transformer.pt]",
         "format": "TorchScript v1.6",
-        "link": "https://github.com/lutzroeder/netron/issues/546"
+        "error": "AssertionError: was expecting embedding dimension of 512, but got ?",
+        "link": "https://github.com/lutzroeder/netron/issues/842"
+    },
+    {
+        "type": "pytorch",
+        "target": "transformer_traced.pt",
+        "source": "https://github.com/lutzroeder/netron/files/10271968/transformer_traced.pt.zip[transformer_traced.pt]",
+        "format": "TorchScript v1.6",
+        "link": "https://github.com/lutzroeder/netron/issues/842"
     },
     {
         "type": "pytorch",