diff --git a/source/pytorch.js b/source/pytorch.js index e1d17b9228..1542927d9e 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1561,11 +1561,17 @@ pytorch.Execution = class extends python.Execution { return left === right; } if (typeof left === 'number' && typeof right === 'number') { + if (isNaN(left) && isNaN(right)) { + return true; + } return left === right; } if (left === undefined || right === undefined) { return true; } + if (Array.isArray(left) && Array.isArray(right)) { + return left.length === right.length && left.every((item, index) => item === right[index]); + } throw new pytorch.Error("Unknown 'torch.eq' expression type."); }); this.registerFunction('torch.floor', function(value) { @@ -2995,6 +3001,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { case 'torch.batch_norm': case 'torch.gelu': case 'torch.relu': + case 'torch.clamp_': case 'torch.hardswish_': { const input = this.expression(args[0], context); if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { @@ -3320,6 +3327,20 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { tensor.resize_([ NaN, NaN, NaN, NaN ]); } } + if (statement.type === '=' && + statement.expression.type === 'call' && statement.expression.arguments.length > 0 && + pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.size', 2)) { + const tensor = this.expression(statement.expression.arguments[0].arguments[0], context); + const dim = this.expression(statement.expression.arguments[0].arguments[1], context); + if (pytorch.Utility.isTensor(tensor) && Number.isInteger(dim)) { + if (tensor.shape === undefined) { + tensor.resize_(Array(dim + 1).fill(NaN)); + } + else if (Array.isArray(tensor.shape) && tensor.shape.length <= dim) { + tensor.resize_(tensor.shape.concat(Array(dim + 1 - tensor.shape.length).fill(NaN))); + } + } + } const value = this.statement(statement, context); if (value !== undefined) { return value; diff --git a/test/models.json b/test/models.json index 3918fba639..058a18d824 100644 --- a/test/models.json +++ b/test/models.json @@ -4922,6 +4922,13 @@ "format": "TorchScript v1.5", "link": "https://github.com/lutzroeder/netron/issues/546" }, + { + "type": "pytorch", + "target": "torchscript_resnet50_fp32.pth", + "source": "https://github.com/lutzroeder/netron/files/7688572/torchscript_resnet50_fp32.pth.zip[torchscript_resnet50_fp32.pth]", + "format": "TorchScript v1.6", + "link": "https://github.com/lutzroeder/netron/issues/842" + }, { "type": "pytorch", "target": "tutorial_bidirectional_recurrent_neural_network.pth",