Skip to content

Commit

Permalink
Add TorchScript test file (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 9, 2021
1 parent 4aede5f commit 6b0ac6a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
21 changes: 21 additions & 0 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1561,11 +1561,17 @@ pytorch.Execution = class extends python.Execution {
return left === right;
}
if (typeof left === 'number' && typeof right === 'number') {
if (isNaN(left) && isNaN(right)) {
return true;
}
return left === right;
}
if (left === undefined || right === undefined) {
return true;
}
if (Array.isArray(left) && Array.isArray(right)) {
return left.length === right.length && left.every((item, index) => item === right[index]);
}
throw new pytorch.Error("Unknown 'torch.eq' expression type.");
});
this.registerFunction('torch.floor', function(value) {
Expand Down Expand Up @@ -2995,6 +3001,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
case 'torch.batch_norm':
case 'torch.gelu':
case 'torch.relu':
case 'torch.clamp_':
case 'torch.hardswish_': {
const input = this.expression(args[0], context);
if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
Expand Down Expand Up @@ -3320,6 +3327,20 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
tensor.resize_([ NaN, NaN, NaN, NaN ]);
}
}
if (statement.type === '=' &&
statement.expression.type === 'call' && statement.expression.arguments.length > 0 &&
pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.size', 2)) {
const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
const dim = this.expression(statement.expression.arguments[0].arguments[1], context);
if (pytorch.Utility.isTensor(tensor) && Number.isInteger(dim)) {
if (tensor.shape === undefined) {
tensor.resize_(Array(dim + 1).fill(NaN));
}
else if (Array.isArray(tensor.shape) && tensor.shape.length <= dim) {
tensor.resize_(tensor.shape.concat(Array(dim + 1 - tensor.shape.length).fill(NaN)));
}
}
}
const value = this.statement(statement, context);
if (value !== undefined) {
return value;
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4922,6 +4922,13 @@
"format": "TorchScript v1.5",
"link": "https://github.com/lutzroeder/netron/issues/546"
},
{
"type": "pytorch",
"target": "torchscript_resnet50_fp32.pth",
"source": "https://github.com/lutzroeder/netron/files/7688572/torchscript_resnet50_fp32.pth.zip[torchscript_resnet50_fp32.pth]",
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/842"
},
{
"type": "pytorch",
"target": "tutorial_bidirectional_recurrent_neural_network.pth",
Expand Down

0 comments on commit 6b0ac6a

Please sign in to comment.