From 64adbf762e9a0e4097c0f9a3ac347147a144e650 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Tue, 21 Dec 2021 13:07:41 -0500 Subject: [PATCH] Update onnx.js --- source/onnx.js | 335 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 225 insertions(+), 110 deletions(-) diff --git a/source/onnx.js b/source/onnx.js index 468253382a..8c66d5da75 100644 --- a/source/onnx.js +++ b/source/onnx.js @@ -343,7 +343,7 @@ onnx.Model = class { this._graphs = []; if (model && model.graph) { const graphMetadata = new onnx.GraphMetadata(metadata, imports); - const context = new onnx.Context(graphMetadata, imageFormat); + const context = new onnx.ModelContext(graphMetadata, imageFormat); for (const func of model.functions || []) { context.metadata.add(new onnx.Function(context, func)); } @@ -441,18 +441,18 @@ onnx.Graph = class { this._name = graph.name || null; this._description = graph.doc_string || ''; - const tensors = onnx.Utility.createTensors(graph.node); + context = new onnx.GraphContext(context, graph.node); for (const initializer of graph.initializer) { - const tensor = tensors.map(initializer.name); + const tensor = context.tensor(initializer.name); tensor.initializer = new onnx.Tensor(context, initializer, 'Initializer'); } for (const sparse_initializer of graph.sparse_initializer) { - const tensor = tensors.map(sparse_initializer.values.name); + const tensor = context.tensor(sparse_initializer.values.name); tensor.initializer = new onnx.Tensor(context, sparse_initializer, 'Sparse Initializer'); } for (const tensor_annotation of graph.quantization_annotation || []) { - const tensor = tensors.map(tensor_annotation.tensor_name); + const tensor = context.tensor(tensor_annotation.tensor_name); const annotation = {}; for (const pair of tensor_annotation.quant_parameter_tensor_names) { annotation[pair.key] = pair.value; @@ -460,41 +460,33 @@ onnx.Graph = class { tensor.annotation = annotation; } for (const valueInfo of graph.value_info) { - const tensor = tensors.map(valueInfo.name); + const tensor = context.tensor(valueInfo.name); tensor.type = context.createType(valueInfo.type); tensor.description = valueInfo.doc_string; } graph.input = graph.input.map((valueInfo) => { - const tensor = tensors.map(valueInfo.name); + const tensor = context.tensor(valueInfo.name); tensor.type = context.createType(valueInfo.type); tensor.description = valueInfo.doc_string; return tensor; }); graph.output = graph.output.map((valueInfo) => { - const tensor = tensors.map(valueInfo.name); + const tensor = context.tensor(valueInfo.name); tensor.type = context.createType(valueInfo.type); tensor.description = valueInfo.doc_string; return tensor; }); new onnx.Inference(graph.node, graph.output); - const args = new Map(); - args.map = function(name) { - if (!this.has(name)) { - const tensor = tensors.map(name); - const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; - this.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description)); - } - return this.get(name); - }; - this._nodes = onnx.Utility.createNodes(context, graph.node, graph.input, graph.output, tensors, args); + context.push(graph.node, graph.input, graph.output); + this._nodes = context.pop(); for (const input of graph.input) { - const argument = args.map(input.name); + const argument = context.argument(input.name); if (!argument.initializer) { this._inputs.push(new onnx.Parameter(input.name, [ argument ])); } } for (const output of graph.output) { - const argument = args.map(output.name); + const argument = context.argument(output.name); if (!argument.initializer) { this._outputs.push(new onnx.Parameter(output.name, [ argument ])); } @@ -741,6 +733,78 @@ onnx.Attribute = class { } }; +onnx.Group = class { + + constructor(name, groups) { + this._type = { name: 'Scope' }; + this._name = name; + this._nodes = []; + for (const entry of groups) { + const key = entry[0]; + if (key === '') { + for (const node of entry[1]) { + this._nodes.push(node); + } + } + else { + this._nodes.push(new onnx.Group(name === '' ? key : name + '/' + key, entry[1])); + } + } + const set = new Set(); + const inputs = new Array(); + const outputs = new Array(); + for (const node of this._nodes) { + if (node instanceof onnx.Group) { + node.freeze(); + } + for (const parameter of node.outputs) { + for (const argument of parameter.arguments) { + if (!argument.initializer) { + outputs.push(argument); + set.add(argument.name); + } + } + } + } + for (const node of this._nodes) { + for (const parameter of node.inputs) { + for (const argument of parameter.arguments) { + if (!set.has(argument.name) && !argument.initializer) { + inputs.push(argument); + } + } + } + } + this._inputs = [ new onnx.Parameter('inputs', inputs) ]; + this._outputs = [ new onnx.Parameter('outputs', outputs) ]; + this._attributes = []; + } + + get name() { + return this._name; + } + + get type() { + return this._type; + } + + get inputs() { + return this._inputs; + } + + get outputs() { + return this._outputs; + } + + get attributes() { + return this._attributes; + } + + get nodes() { + return this._nodes; + } +}; + onnx.Tensor = class { constructor(context, tensor, kind) { @@ -1175,27 +1239,19 @@ onnx.Function = class { this._inputs = []; this._outputs = []; this._attributes = func.attribute.map((attribtue) => { return { name: attribtue }; }); - const tensors = onnx.Utility.createTensors(func.node); - func.input = func.input.map((input) => tensors.map(input)); - func.output = func.output.map((output) => tensors.map(output)); - const args = new Map(); - args.map = function(name) { - if (!this.has(name)) { - const tensor = tensors.map(name); - const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; - this.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description)); - } - return this.get(name); - }; - this._nodes = onnx.Utility.createNodes(context, func.node, func.input, func.output, tensors, args); + context = new onnx.GraphContext(context, func.node); + func.input = func.input.map((input) => context.tensor(input)); + func.output = func.output.map((output) => context.tensor(output)); + context.push(func.node, func.input, func.output); + this._nodes = context.pop(); for (const input of func.input) { - const argument = args.map(input.name); + const argument = context.argument(input.name); if (!argument.initializer) { this._inputs.push(new onnx.Parameter(input.name, [ argument ])); } } for (const output of func.output) { - const argument = args.map(output.name); + const argument = context.argument(output.name); if (!argument.initializer) { this._outputs.push(new onnx.Parameter(output.name, [ argument ])); } @@ -1421,29 +1477,123 @@ onnx.AttributeType = { TYPE_PROTOS: 14 }; -onnx.Context = class { +onnx.ModelContext = class { constructor(metadata, imageFormat) { this._metadata = metadata; this._imageFormat = imageFormat; this._graphs = new Map(); + } + + get metadata() { + return this._metadata; + } + + get imageFormat() { + return this._imageFormat; + } + + graph(value) { + if (!this._graphs.has(value)) { + this._graphs.set(value, new onnx.Graph(this, value)); + } + return this._graphs.get(value); + } +}; + +onnx.GraphContext = class { + + constructor(context, nodes) { + this._context = context; this._decoder = new TextDecoder('utf-8'); this._dataTypes = new Map(Object.entries(onnx.DataType).map((entry) => [ entry[1], entry[0].toLowerCase() ])); this._dataTypes.set(onnx.DataType.UNDEFINED, 'UNDEFINED'); this._dataTypes.set(onnx.DataType.BOOL, 'boolean'); this._dataTypes.set(onnx.DataType.FLOAT, 'float32'); this._dataTypes.set(onnx.DataType.DOUBLE, 'float64'); + this._tensors = new Map(); + this._arguments = new Map(); + this._groups = new Map(); + this._nodes = []; + for (const node of nodes) { + node.input = node.input.map((name) => this.tensor(name)); + node.output = node.output.map((name) => this.tensor(name)); + node.param = {}; + for (const attribute of node.attribute) { + if (attribute.type) { + continue; + } + if (attribute.ints && attribute.ints.length > 0) { + attribute.type = onnx.AttributeType.INTS; + } + else if (attribute.floats && attribute.floats.length > 0) { + attribute.type = onnx.AttributeType.FLOATS; + } + else if (attribute.strings && attribute.strings.length > 0) { + attribute.type = onnx.AttributeType.STRINGS; + } + else if (attribute.graphs && attribute.graphs.length > 0) { + attribute.type = onnx.AttributeType.GRAPHS; + } + else if (attribute.s && attribute.s.length > 0) { + attribute.type = onnx.AttributeType.STRING; + } + else if (Object.prototype.hasOwnProperty.call(attribute, 'f')) { + attribute.type = onnx.AttributeType.FLOAT; + } + else if (Object.prototype.hasOwnProperty.call(attribute, 'i')) { + attribute.type = onnx.AttributeType.INT; + } + else if (Object.prototype.hasOwnProperty.call(attribute, 't')) { + attribute.type = onnx.AttributeType.TENSOR; + } + else if (Object.prototype.hasOwnProperty.call(attribute, 'g')) { + attribute.type = onnx.AttributeType.GRAPH; + } + else if (Object.prototype.hasOwnProperty.call(attribute, 'sparse_tensor')) { + attribute.type =onnx.AttributeType.SPARSE_TENSOR; + } + else { + attribute.type = onnx.AttributeType.UNDEFINED; + } + } + } } get metadata() { - return this._metadata; + return this._context.metadata; } - graph(value) { - if (!this._graphs.has(value)) { - this._graphs.set(value, new onnx.Graph(this, value)); + graph(name) { + return this._context.graph(name); + } + + tensor(name) { + if (!this._tensors.has(name)) { + this._tensors.set(name, { name: name }); } - return this._graphs.get(value); + return this._tensors.get(name); + } + + group(name) { + if (!this._groups.has(name)) { + const path = name.split('/'); + if (path.length > 1) { + path.pop(); + return this.group(path.join('/')); + } + this._groups.set(name, new Map([ [ '', [] ]])); + } + return this._groups.get(name); + } + + argument(name) { + if (!this._arguments.has(name)) { + const tensor = this.tensor(name); + const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; + this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description)); + } + return this._arguments.get(name); } createType(type) { @@ -1456,7 +1606,7 @@ onnx.Context = class { denotation = 'Tensor'; break; case 'IMAGE': - denotation = 'Image' + (this._imageFormat ? '(' + this._imageFormat.join(',') + ')' : ''); + denotation = 'Image' + (this._context.imageFormat ? '(' + this._context.imageFormat.join(',') + ')' : ''); break; case 'AUDIO': denotation = 'Audio'; @@ -1523,65 +1673,8 @@ onnx.Context = class { } return this._decoder.decode(value); } -}; - -onnx.Utility = class { - - static createTensors(nodes) { - const tensors = new Map(); - tensors.map = function(name) { - if (!this.has(name)) { - this.set(name, { name: name }); - } - return this.get(name); - }; - for (const node of nodes) { - node.input = node.input.map((name) => tensors.map(name)); - node.output = node.output.map((name) => tensors.map(name)); - node.param = {}; - for (const attribute of node.attribute) { - if (attribute.type) { - continue; - } - if (attribute.ints && attribute.ints.length > 0) { - attribute.type = onnx.AttributeType.INTS; - } - else if (attribute.floats && attribute.floats.length > 0) { - attribute.type = onnx.AttributeType.FLOATS; - } - else if (attribute.strings && attribute.strings.length > 0) { - attribute.type = onnx.AttributeType.STRINGS; - } - else if (attribute.graphs && attribute.graphs.length > 0) { - attribute.type = onnx.AttributeType.GRAPHS; - } - else if (attribute.s && attribute.s.length > 0) { - attribute.type = onnx.AttributeType.STRING; - } - else if (Object.prototype.hasOwnProperty.call(attribute, 'f')) { - attribute.type = onnx.AttributeType.FLOAT; - } - else if (Object.prototype.hasOwnProperty.call(attribute, 'i')) { - attribute.type = onnx.AttributeType.INT; - } - else if (Object.prototype.hasOwnProperty.call(attribute, 't')) { - attribute.type = onnx.AttributeType.TENSOR; - } - else if (Object.prototype.hasOwnProperty.call(attribute, 'g')) { - attribute.type = onnx.AttributeType.GRAPH; - } - else if (Object.prototype.hasOwnProperty.call(attribute, 'sparse_tensor')) { - attribute.type =onnx.AttributeType.SPARSE_TENSOR; - } - else { - attribute.type = onnx.AttributeType.UNDEFINED; - } - } - } - return tensors; - } - static createNodes(context, nodes, inputs, outputs, tensors, args) { + push(nodes, inputs, outputs) { const inputMap = new Map(); const outputMap = new Map(); for (const node of nodes) { @@ -1598,25 +1691,25 @@ onnx.Utility = class { node.output.length === 1 && node.output[0] && inputMap.get(node.output[0].name) === 1 && outputMap.get(node.output[0].name) === 1; const attribute = constant ? node.attribute[0] : null; if (attribute && attribute.name === 'value' && attribute.type === onnx.AttributeType.TENSOR && attribute.t) { - const tensor = tensors.map(node.output[0].name); - tensor.initializer = new onnx.Tensor(context, attribute.t, 'Constant'); + const tensor = this.tensor(node.output[0].name); + tensor.initializer = new onnx.Tensor(this, attribute.t, 'Constant'); return false; } else if (attribute && attribute.name === 'sparse_value' && attribute.type === onnx.AttributeType.SPARSE_TENSOR && attribute.sparse_tensor) { - const tensor = tensors.map(node.output[0].name); - tensor.initializer = new onnx.Tensor(context, attribute.sparse_tensor, 'Sparse Constant'); + const tensor = this.tensor(node.output[0].name); + tensor.initializer = new onnx.Tensor(this, attribute.sparse_tensor, 'Sparse Constant'); return false; } return true; }); - return nodes.map((node) => { - const schema = context.metadata.type(node.op_type, node.domain); + for (let node of nodes) { + const schema = this._context.metadata.type(node.op_type, node.domain); const inputs = []; node.input = node.input || []; for (let i = 0; i < node.input.length; ) { const input = schema && schema.inputs && i < schema.inputs.length ? schema.inputs[i] : { name: i.toString() }; const count = input.list ? node.input.length - i : 1; - const list = node.input.slice(i, i + count).map((input) => args.map(input.name)); + const list = node.input.slice(i, i + count).map((input) => this.argument(input.name)); inputs.push(new onnx.Parameter(input.name, list)); i += count; } @@ -1625,12 +1718,34 @@ onnx.Utility = class { for (let i = 0; i < node.output.length; ) { const output = schema && schema.outputs && i < schema.outputs.length ? schema.outputs[i] : { name: i.toString() }; const count = output.list ? node.output.length - i : 1; - const list = node.output.slice(i, i + count).map((output) => args.map(output.name)); + const list = node.output.slice(i, i + count).map((output) => this.argument(output.name)); outputs.push(new onnx.Parameter(output.name, list)); i += count; } - return new onnx.Node(context, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs); - }); + node = new onnx.Node(this, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs); + this._nodes.push(node); + + // const path = (node.name || '').split('/'); + // path.pop(); + // this.group(path.join('/')).get('').push(node); + } + } + + pop() { + /* + const nodes = []; + for (const entry of this._groups) { + if (entry[0] === '') { + for (const node of entry[1].get('')) { + nodes.push(node); + } + continue; + } + nodes.push(new onnx.Group(entry[0], entry[1])); + } + return nodes; + */ + return this._nodes; } };