Skip to content

Commit

Permalink
Add TorchScript test file (#842) (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 8, 2021
1 parent 941048a commit 96a5eee
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
2 changes: 1 addition & 1 deletion source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -2800,7 +2800,7 @@ python.Execution = class {
break;
}
case 'var': {
context.set(statement.name, undefined);
context.set(statement.name, statement.initializer ? this.expression(statement.initializer, context) : undefined);
break;
}
case '=': {
Expand Down
40 changes: 38 additions & 2 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,10 @@ pytorch.Execution = class extends python.Execution {
}
throw new pytorch.Error("Unknown 'torch.ge' expression type.");
});
this.registerFunction('torch.is_floating_point', function(tensor) {
const type = tensor.dtype.scalar_type();
return (type === 5 || type === 6 || type === 7);
});
this.registerFunction('torch.jit._pickle.build_boollist', function(data) {
return data;
});
Expand Down Expand Up @@ -1748,7 +1752,7 @@ pytorch.Execution = class extends python.Execution {
throw new pytorch.Error('Slicing only supports step=1');
}
start = Math.max(0, start >= 0 ? start : l.length + start);
end = Math.min(l.length, end);
end = Math.min(l.length, end || Number.MAX_SAFE_INTEGER);
return l.slice(start, end);
});
this.registerFunction('torch.sub', function(left, right) {
Expand Down Expand Up @@ -2973,6 +2977,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
}
case 'torch.mean':
case 'torch.mul':
case 'torch.div':
case 'torch.batch_norm':
case 'torch.gelu':
case 'torch.relu':
Expand All @@ -2983,7 +2988,8 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
}
break;
}
case 'torch.add': {
case 'torch.add':
case 'torch.sub': {
const input = this.expression(args[0], context);
if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
parameter.resize_(input.size());
Expand All @@ -2996,6 +3002,13 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
}
break;
}
case 'torch.select': {
const input = this.expression(args[0], context);
if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
parameter.resize_(Array(input.size().length - 1).fill(NaN));
}
break;
}
case 'torch.layer_norm': {
const input = this.expression(args[0], context);
const normalized_shape = this.expression(args[1], context);
Expand Down Expand Up @@ -3176,6 +3189,29 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
tensor.resize_(Array(number).fill(NaN));
}
}
// val = torch.slice(torch.size(img), -2)
// if torch.eq(torch.len(val), 2):
// pass
// else:
// ops.prim.RaiseException("AssertionError: ")
if (assign.type === '=' &&
condition.type === 'if' &&
pytorch.Utility.isCall(assign.expression, 'torch.slice', 2) &&
pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.size', 1) &&
pytorch.Utility.isCall(condition.condition, 'torch.eq', 2) &&
pytorch.Utility.isCall(condition.condition.arguments[0], 'torch.len', 1) &&
pytorch.Utility.isEqual(condition.condition.arguments[0].arguments[0], assign.target) &&
condition.else.statements.length == 1 &&
pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
const tensor = this.expression(assign.expression.arguments[0].arguments[0], context);
if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
const start = this.expression(assign.expression.arguments[1], context);
const value = this.expression(condition.condition.arguments[1], context);
if (Number.isInteger(start) && Number.isInteger(value)) {
tensor.resize_(Array(value - start).fill(NaN));
}
}
}
}
if (statements.length > 1) {
const size = statements[0];
Expand Down
11 changes: 9 additions & 2 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4326,8 +4326,8 @@
{
"type": "pytorch",
"target": "fasterrcnn_resnet50_fpn.pt",
"source": "https://github.com/lutzroeder/netron/files/6040364/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]",
"error": "Unsupported function 'torch.full' in 'fasterrcnn_resnet50_fpn.pt'.",
"source": "https://github.com/lutzroeder/netron/files/7677467/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]",
"error": "Unknown torch.add expression type in 'fasterrcnn_resnet50_fpn.pt'.",
"link": "https://github.com/lutzroeder/netron/issues/689"
},
{
Expand Down Expand Up @@ -4859,6 +4859,13 @@
"source": "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
"format": "PyTorch v0.1.1"
},
{
"type": "pytorch",
"target": "ssdlite320_mobilenet_v3_large.pt",
"source": "https://github.com/lutzroeder/netron/files/7677468/ssdlite320_mobilenet_v3_large.pt.zip[ssdlite320_mobilenet_v3_large.pt]",
"error": "l.slice is not a function in 'ssdlite320_mobilenet_v3_large.pt'.",
"link": "https://github.com/lutzroeder/netron/issues/842"
},
{
"type": "pytorch",
"target": "superpoint_v1.pth",
Expand Down

0 comments on commit 96a5eee

Please sign in to comment.