Skip to content

Commit

Permalink
Update pytorch.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 16, 2022
1 parent 463b272 commit 2414e06
Showing 1 changed file with 131 additions and 133 deletions.
264 changes: 131 additions & 133 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -55,146 +55,146 @@ pytorch.Graph = class {
this._outputs = [];
this._groups = true;
this._name = name || '';
switch (module.__type__) {
case 'script': {
const traced = module.trace();
const initializers = new Map();
const constants = module.execution.builtins.CONSTANTS;
if (constants) {
for (const entry of Object.entries(constants)) {
const name = 'CONSTANTS.' + entry[0];
const value = entry[1];
if (pytorch.Utility.isTensor(value)) {
const initializer = new pytorch.Tensor(name, value);
initializers.set(value, initializer);
}
else if (value && value.__class__ && value.__class__.__module__ && value.__class__.__name__) {
const type = value.__class__.__module__ + '.' + value.__class__.__name__;
switch (type) {
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
for (const entry of Object.entries(value)) {
const key = entry[0];
const value = entry[1];
if (pytorch.Utility.isTensor(value)) {
initializers.set(value, new pytorch.Tensor(name + '.' + key, value));
}
if (module instanceof pytorch.jit.ScriptModule) {
const traced = module.trace();
const initializers = new Map();
const constants = module.execution.builtins.CONSTANTS;
if (constants) {
for (const entry of Object.entries(constants)) {
const name = 'CONSTANTS.' + entry[0];
const value = entry[1];
if (pytorch.Utility.isTensor(value)) {
initializers.set(value, new pytorch.Tensor(name, value));
}
else if (value && value.__class__ && value.__class__.__module__ && value.__class__.__name__) {
const type = value.__class__.__module__ + '.' + value.__class__.__name__;
switch (type) {
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
for (const entry of Object.entries(value)) {
const key = entry[0];
const value = entry[1];
if (pytorch.Utility.isTensor(value)) {
initializers.set(value, new pytorch.Tensor(name + '.' + key, value));
}
break;
default:
throw new pytorch.Error("Unsupported constant context '" + type + "'.");
}
}
else {
throw new pytorch.Error('Unsupported constant.');
}
break;
default:
throw new pytorch.Error("Unsupported constant context '" + type + "'.");
}
}
}
const queue = [ module.data ];
while (queue.length > 0) {
const module = queue.shift();
if (module.__class__ && module.__class__.__module__ === '__torch__.torch.classes._nnapi' && module.__class__.__name__ === 'Compilation') {
continue;
else {
throw new pytorch.Error('Unsupported constant.');
}
for (const entry of Object.entries(module)) {
const key = entry[0];
if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') {
const obj = entry[1];
if (!Array.isArray(obj) && obj === Object(obj)) {
if (pytorch.Utility.isTensor(obj)) {
const parameter = obj;
parameter.__parent__ = module;
if (!parameter.initializer && parameter.storage()) {
if (parameter.__count__ === undefined || parameter.__count__ === 1) {
initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter));
}
}
}
const queue = [ module.data ];
while (queue.length > 0) {
const module = queue.shift();
if (module.__class__ && module.__class__.__module__ === '__torch__.torch.classes._nnapi' && module.__class__.__name__ === 'Compilation') {
continue;
}
for (const entry of Object.entries(module)) {
const key = entry[0];
if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') {
const obj = entry[1];
if (!Array.isArray(obj) && obj === Object(obj)) {
if (pytorch.Utility.isTensor(obj)) {
const parameter = obj;
parameter.__parent__ = module;
if (!parameter.initializer && parameter.storage()) {
if (parameter.__count__ === undefined || parameter.__count__ === 1) {
initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter));
}
}
else if (obj && obj.__class__) {
obj.__parent__ = module;
if (!obj.__id__) {
obj.__id__ = key;
}
queue.push(obj);
}
else if (obj && obj.__class__) {
obj.__parent__ = module;
if (!obj.__id__) {
obj.__id__ = key;
}
queue.push(obj);
}
}
}
}
if (traced) {
const graph = module.graph;
for (const value of graph.inputs()) {
const identifier = value.unique().toString();
const name = value.debugName() || identifier;
this._inputs.push(new pytorch.Parameter(name, true, [
new pytorch.Argument(identifier, null, null)
]));
}
if (traced) {
const graph = module.graph;
for (const value of graph.inputs()) {
const identifier = value.unique().toString();
const name = value.debugName() || identifier;
this._inputs.push(new pytorch.Parameter(name, true, [
new pytorch.Argument(identifier, null, null)
]));
}
for (const value of graph.outputs()) {
const identifier = value.unique().toString();
this._outputs.push(new pytorch.Parameter(identifier, true, [
new pytorch.Argument(identifier, null, null)
]));
}
for (const node of graph.nodes()) {
if (node === graph.param_node() ||
node === graph.return_node()) {
continue;
}
for (const value of graph.outputs()) {
const identifier = value.unique().toString();
this._outputs.push(new pytorch.Parameter(identifier, true, [
new pytorch.Argument(identifier, null, null)
]));
if (node.kind() === 'prim::ListConstruct' &&
node.outputs().length === 1 &&
node.outputs().every((output) => output.uses().length === 1) &&
node.inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
continue;
}
for (const node of graph.nodes()) {
if (node === graph.param_node() ||
node === graph.return_node()) {
continue;
}
if (node.kind() === 'prim::ListConstruct' &&
node.outputs().length === 1 &&
node.outputs().every((output) => output.uses().length === 1) &&
node.inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
continue;
}
if (node.kind() === 'prim::ListUnpack' &&
node.inputs().length === 1 &&
node.inputs().every((input) => input.uses().length === 1) &&
node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
continue;
}
const item = {
type: node.kind(),
node: node
};
this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
if (node.kind() === 'prim::ListUnpack' &&
node.inputs().length === 1 &&
node.inputs().every((input) => input.uses().length === 1) &&
node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
continue;
}
const item = {
type: node.kind(),
node: node
};
this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
}
if (module) {
this._loadScriptModule(metadata, module.data, initializers);
}
break;
}
case 'module': {
this._type = (module.__module__ && module.__name__) ? (module.__module__ + '.' + module.__name__) : '';
this._loadModule(metadata, module, [], []);
break;
if (module) {
this._loadScriptModule(metadata, module.data, initializers);
}
case 'weights': {
for (const state_group of module) {
const attributes = state_group.attributes || [];
const inputs = state_group.states.map((parameter) => {
return new pytorch.Parameter(parameter.name, true,
parameter.arguments.map((state) => {
const tensor = new pytorch.Tensor(state.id, pytorch.Utility.toTensor(state.value));
return new pytorch.Argument(state.id, null, tensor);
}));
});
const obj = {
name: state_group.name,
type: state_group.type || 'torch.nn.Module',
attributes: attributes,
inputs: inputs,
outputs: []
};
this._nodes.push(new pytorch.Node(metadata, '', obj, null));
}
else {
switch (module.__type__) {
case 'module': {
this._type = (module.__module__ && module.__name__) ? (module.__module__ + '.' + module.__name__) : '';
this._loadModule(metadata, module, [], []);
break;
}
case 'weights': {
for (const state_group of module) {
const attributes = state_group.attributes || [];
const inputs = state_group.states.map((parameter) => {
return new pytorch.Parameter(parameter.name, true,
parameter.arguments.map((state) => {
const tensor = new pytorch.Tensor(state.id, pytorch.Utility.toTensor(state.value));
return new pytorch.Argument(state.id, null, tensor);
}));
});
const obj = {
name: state_group.name,
type: state_group.type || 'torch.nn.Module',
attributes: attributes,
inputs: inputs,
outputs: []
};
this._nodes.push(new pytorch.Node(metadata, '', obj, null));
}
break;
}
default: {
throw new pytorch.Error("Unsupported container type '" + module.__type__ + "'.");
}
break;
}
default: {
throw new pytorch.Error("Unsupported container type '" + module.__type__ + "'.");
}
}
}
Expand Down Expand Up @@ -264,7 +264,9 @@ pytorch.Graph = class {
}
if (value) {
const initializer = new pytorch.Tensor('', value);
inputs.push(new pytorch.Parameter(inputName || key, visible, [ new pytorch.Argument('', null, initializer) ]));
inputs.push(new pytorch.Parameter(inputName || key, visible, [
new pytorch.Argument('', null, initializer)
]));
}
}

Expand Down Expand Up @@ -1359,7 +1361,9 @@ pytorch.Container.Zip = class extends pytorch.Container {
}
};

pytorch.Container.Zip.Script = class {
pytorch.jit = {};

pytorch.jit.ScriptModule = class {

constructor(entries, execution, location, name) {
this.__type__ = 'script';
Expand Down Expand Up @@ -1601,7 +1605,7 @@ pytorch.Container.Zip.Json = class extends pytorch.Container.Zip {
}
};

pytorch.Container.Zip.Json.Script = class extends pytorch.Container.Zip.Script {
pytorch.Container.Zip.Json.Script = class extends pytorch.jit.ScriptModule {

constructor(entries, execution, model) {
super(entries, execution);
Expand Down Expand Up @@ -1657,12 +1661,7 @@ pytorch.Container.Zip.Json.Script = class extends pytorch.Container.Zip.Script {
}
while (queue.length > 0) {
const module = queue.shift();
if (!module.__class__) {
module.__class__ = {
__module__: 'torch.nn.modules.module',
__name__: 'Module'
};
}
module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
if (module.name) {
module.__id__ = module.name;
}
Expand Down Expand Up @@ -1702,7 +1701,6 @@ pytorch.Container.Zip.Json.Script = class extends pytorch.Container.Zip.Script {
this._data.forward = module.forward;
}
}
delete this._model;
}

get name() {
Expand Down Expand Up @@ -1748,7 +1746,7 @@ pytorch.Container.Zip.Pickle = class extends pytorch.Container.Zip {
}
};

pytorch.Container.Zip.Pickle.Script = class extends pytorch.Container.Zip.Script {
pytorch.Container.Zip.Pickle.Script = class extends pytorch.jit.ScriptModule {

constructor(entries, execution, location) {
super(entries, execution, location);
Expand Down

0 comments on commit 2414e06

Please sign in to comment.