Skip to content

Commit

Permalink
TorchScript 1.3 prototype (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 6, 2019
1 parent a0bb598 commit 849bf65
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/pickle.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pickle.Unpickler = class {
break;
case pickle.OpCode.TUPLE:
items = stack;
stack = marker .pop();
stack = marker.pop();
stack.push(items);
break;
case pickle.OpCode.SETITEM:
Expand Down
51 changes: 26 additions & 25 deletions src/torchscript.js
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,9 @@ torchscript.ModelFactory = class {
container.prefix = version.name.substring(0, version.name.length - 7);
let find = (name) => {
let entry = container.entries.find((entry) => entry.name == container.prefix + name);
if (entry) {
return entry.data;
}
return null;
return entry ? entry.data : null;
}
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
container.version = version.data;
container.attributes = find('attribtues.pkl');
container.constants = find('constants.pkl');
Expand Down Expand Up @@ -299,34 +297,37 @@ torchscript.Graph = class {
}
}
}
/*
if (container.data) {
let queue = [ container.data ];
while (queue.length > 0) {
let module = queue.shift();
if (module.parameters) {
for (let parameter of module.parameters) {
if (parameter.tensorId) {
let tensorId = parseInt(parameter.tensorId, 10);
parameter.initializer = container.tensors[tensorId];
if (parameter.outputs && parameter.outputs.length == 1) {
container.parameters[parameter.outputs[0]] = parameter;
}
}
}
}
for (let key of Object.keys(module)) {
if (key !== '__type__' && key !== '__parent__') {
let submodule = module[key];
if (submodule === Object(submodule)) {
submodule.__parent__ = module;
queue.push(submodule);
let obj = module[key];
if (!Array.isArray(obj) && obj === Object(obj)) {
if (obj && obj.__type__ && obj.__type__.endsWith('Tensor')) {
// debugger;
}
else {
obj.__parent__ = module;
queue.push(obj);
}
}
/* if (module.parameters) {
for (let parameter of module.parameters) {
if (parameter.tensorId) {
let tensorId = parseInt(parameter.tensorId, 10);
parameter.initializer = container.tensors[tensorId];
if (parameter.outputs && parameter.outputs.length == 1) {
container.parameters[parameter.outputs[0]] = parameter;
}
}
}
*/
}
}
}
}
*/

if (context) {
for (let input of context.inputs) {
Expand Down Expand Up @@ -468,8 +469,8 @@ torchscript.Node = class {
for (let argument of input) {
let parameter = container.parameters[argument.id];
if (parameter) {
if (parameter.module && (module == null || module == parameter.module)) {
module = parameter.module;
if (parameter.__module__ && (module == null || module == parameter.__module__)) {
module = parameter.__module__;
count++;
}
else {
Expand Down Expand Up @@ -1389,7 +1390,7 @@ torchscript.GraphContext = class {
let targetModule = this._module(expression.target);
if (targetModule && targetModule.parameters) {
for (let parameter of targetModule.parameters) {
parameter.module = targetModule;
parameter.__module__ = targetModule;
if (parameter.name === expression.member.value) {
parameter.outputs = parameter.outputs || [];
parameter.outputs.push(target.value);
Expand All @@ -1398,7 +1399,7 @@ torchscript.GraphContext = class {
}
targetModule.unresolvedParameters = targetModule.unresolvedParameters || [];
for (let unresolvedParameter of targetModule.unresolvedParameters) {
unresolvedParameter.module = targetModule;
unresolvedParameter.__module__ = targetModule;
if (unresolvedParameter.name === expression.member.value) {
unresolvedParameter.outputs = unresolvedParameter.outputs || [];
unresolvedParameter.outputs.push(target.value);
Expand Down
13 changes: 10 additions & 3 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -3703,7 +3703,7 @@
{
"type": "pytorch",
"target": "mnist_linear.ckpt",
"source": "https://github.com/lutzroeder/netron/files/3269075/mnist_linear_torchscript.zip[mnist_linear.ckpt]",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear.ckpt]",
"format": "PyTorch",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
Expand Down Expand Up @@ -5006,8 +5006,15 @@
},
{
"type": "torchscript",
"target": "mnist_linear_torchscript.pt",
"source": "https://github.com/lutzroeder/netron/files/3269075/mnist_linear_torchscript.zip[mnist_linear_torchscript.pt]",
"target": "mnist_linear_torchscript_1.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_1.pt]",
"format": "TorchScript v1",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
{
"type": "torchscript",
"target": "mnist_linear_torchscript_2.pt",
"source": "https://github.com/lutzroeder/netron/files/3585288/mnist_linear_torchscript.zip[mnist_linear_torchscript_2.pt]",
"format": "TorchScript v1",
"link": "https://github.com/lutzroeder/netron/issues/281"
},
Expand Down

0 comments on commit 849bf65

Please sign in to comment.