From c2403706bb8f389c893c426dae37ecae3d401e56 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 17 Jun 2018 01:45:39 -0700 Subject: [PATCH] ONNX loader towards shared edge objects (#71) --- src/onnx-model.js | 398 ++++++++++++++++++++++------------------------ 1 file changed, 191 insertions(+), 207 deletions(-) diff --git a/src/onnx-model.js b/src/onnx-model.js index e7febd9dd72..54f9861bb90 100644 --- a/src/onnx-model.js +++ b/src/onnx-model.js @@ -43,30 +43,26 @@ class OnnxModel { constructor(model) { this._model = model; - this._graphs = []; - this._activeGraph = null; - if (this._model.graph) { - var metadata = new OnnxGraphOperatorMetadata(this._model); - var graph = new OnnxGraph(this, metadata, this._model.graph, 0); - this._graphs.push(graph); - this._activeGraph = graph; - } + this._irVersion = model.irVersion; + this._opsetImport = model.opsetImport; + this._producerName = model.producerName; + this._producerVersion = model.producerVersion; + this._domain = model.domain; + this._modelVersion = model.modelVersion; + this._docString = model.docString; + this._metadataProps = model.metadataProps; } get properties() { var results = []; var format = 'ONNX'; - if (this._model.irVersion) { - format = format + ' v' + this._model.irVersion.toString(); - // var major = (this._model.irVersion >> 16) & 0x0f; - // var minor = (this._model.irVersion >> 8) & 0x0f; - // var revision = (this._model.irVersion) & 0x0f; - // format = format + ' v' + major.toString() + '.' + minor.toString() + '.' + revision.toString(); + if (this._irVersion) { + format = format + ' v' + this._irVersion.toString(); } results.push({ name: 'Format', value: format }); - if (this._model.opsetImport && this._model.opsetImport.length > 0) { + if (this._opsetImport && this._opsetImport.length > 0) { var opsetImports = []; - this._model.opsetImport.forEach((opsetImport) => { + this._opsetImport.forEach((opsetImport) => { var domain = opsetImport.domain ? opsetImport.domain : 'ai.onnx'; var result = domain + ' v' + opsetImport.version; if (!opsetImports.includes(result)) { @@ -76,28 +72,28 @@ class OnnxModel { results.push({ name: 'Imports', value: opsetImports.join(', ') }); } var producer = []; - if (this._model.producerName) { - producer.push(this._model.producerName); + if (this._producerName) { + producer.push(this._producerName); } - if (this._model.producerVersion && this._model.producerVersion.length > 0) { - producer.push(this._model.producerVersion); + if (this._producerVersion && this._producerVersion.length > 0) { + producer.push(this._producerVersion); } if (producer.length > 0) { results.push({ 'name': 'Producer', 'value': producer.join(' ') }); } - if (this._model.domain) { - results.push({ name: 'Domain', value: this._model.domain }); + if (this._domain) { + results.push({ name: 'Domain', value: this._domain }); } - if (this._model.modelVersion) { - results.push({ name: 'Version', value: this._model.modelVersion }); + if (this._modelVersion) { + results.push({ name: 'Version', value: this._modelVersion }); } if (this._model.docString) { - results.push({ name: 'Description', value: this._model.docString }); + results.push({ name: 'Description', value: this._docString }); } var metadata = {}; - if (this._model.metadataProps) + if (this._metadataProps) { - this._model.metadataProps.forEach((metadataProp) => { + this._metadataProps.forEach((metadataProp) => { metadata[metadataProp.key] = metadataProp.value; }); } @@ -125,68 +121,130 @@ class OnnxModel { } get graphs() { + if (this._model) { + this._graphs = []; + var metadata = new OnnxGraphOperatorMetadata(this._opsetImport); + this._graphs.push(new OnnxGraph(metadata, this._model.graph, 0)); + delete this._model; + } return this._graphs; } } class OnnxGraph { - constructor(model, metadata, graph, index) { - this._model = model; + constructor(metadata, graph, index) { this._metadata = metadata; - this._graph = graph; + this._node = ''; + this._description = ''; this._nodes = []; - this._initializerMap = []; - this._valueInfoMap = []; - this._outputMap = {}; - if (this._graph) { - this._name = this._graph.name ? this._graph.name : ('(' + index.toString() + ')'); + if (graph) { + var initializerMap = []; + var valueInfoMap = []; + + this._name = graph.name || ('(' + index.toString() + ')'); + this._description = graph.docString || ''; - this._graph.node.forEach((node) => { + var nodes = []; + var outputCountMap = {}; + graph.node.forEach((node) => { node.output.forEach((output) => { - this._outputMap[output] = (this._outputMap[output] || 0) + 1; + outputCountMap[output] = (outputCountMap[output] || 0) + 1; }); }); - - this._graph.initializer.forEach((tensor) => { - this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer'); - }); - this._graph.node.forEach((node) => { - var add = true; - if (node.opType == 'Constant' && node.output && node.output.length == 1 && this._outputMap[node.output[0]] == 1) { - node.attribute.forEach((attribute) => { - if (attribute.name == 'value' && attribute.t) { - var name = node.output[0]; - this._initializerMap[name] = new OnnxTensor(attribute.t, name, 'Constant'); - add = false; + graph.node.forEach((node) => { + var initializerNode = false; + if (node.opType == 'Constant' && node.output && node.output.length == 1) { + var name = node.output[0]; + if (outputCountMap[name] == 1) { + var attribute = node.attribute.find((attribute) => { return attribute.name == 'value' && attribute.t; }); + if (attribute) { + initializerMap[name] = new OnnxTensor(attribute.t, name, 'Constant'); + initializerNode = true; } - }); + } } - if (add) { - this._nodes.push(new OnnxNode(this, node)); + if (!initializerNode) { + nodes.push(node); } }); - this._graph.valueInfo.forEach((valueInfo) => { - this._valueInfoMap[valueInfo.name] = valueInfo; + graph.initializer.forEach((tensor) => { + initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer'); + }); + graph.valueInfo.forEach((valueInfo) => { + valueInfoMap[valueInfo.name] = valueInfo; }); - } - } - get model() { - return this._model; + this._inputs = []; + graph.input.forEach((valueInfo) => { + if (!initializerMap[valueInfo.name]) { + this._inputs.push({ + id: valueInfo.name, + name: valueInfo.name, + description: valueInfo.docString, + type: OnnxTensor.formatType(valueInfo.type) + }); + valueInfoMap[valueInfo.name] = valueInfo; + } + }); + + this._outputs = graph.output.map((valueInfo) => { + valueInfoMap[valueInfo.name] = valueInfo; + return { + id: valueInfo.name, + name: valueInfo.name, + description: valueInfo.docString, + type: OnnxTensor.formatType(valueInfo.type) + }; + }); + + nodes.forEach((node) => { + var inputs = []; + if (node.input) { + inputs = this._metadata.getInputs(node.opType, node.input); + inputs.forEach((input) => { + input.connections.forEach((connection) => { + var initializer = initializerMap[connection.id]; + if (initializer) { + connection.initializer = initializer; + connection.type = initializer.type; + } + else { + var valueInfo = valueInfoMap[connection.id]; + if (valueInfo) { + connection.type = OnnxTensor.formatType(valueInfo.type); + connection.description = valueInfo.docString; + } + } + }); + }); + } + var outputs = []; + if (node.output) { + outputs = this._metadata.getOutputs(node.opType, node.output); + outputs.forEach((output) => { + output.connections.forEach((connection) => { + var valueInfo = valueInfoMap[connection.id]; + if (valueInfo) { + connection.type = OnnxTensor.formatType(valueInfo.type); + connection.description = valueInfo.docString; + } + }); + }); + } + this._nodes.push(new OnnxNode(this, node.opType, node.domain, node.name, node.docString, node.attribute, inputs, outputs)); + }); + } } get name() { - return this._name || ''; + return this._name; } get description() { - if (this._graph && this._graph.docString) { - return this._graph.docString; - } - return ''; + return this._description; } get groups() { @@ -194,44 +252,10 @@ class OnnxGraph { } get inputs() { - if (!this._inputs) { - this._inputs = []; - if (this._graph) { - var initializerMap = {}; - this._graph.initializer.forEach((tensor) => { - initializerMap[tensor.name] = true; - }); - this._graph.input.forEach((valueInfo) => { - if (!initializerMap[valueInfo.name]) { - this._inputs.push({ - id: valueInfo.name, - name: valueInfo.name, - description: valueInfo.docString, - type: OnnxTensor.formatType(valueInfo.type) - }); - this._valueInfoMap[valueInfo.name] = valueInfo; - } - }); - } - } return this._inputs; } get outputs() { - if (!this._outputs) { - this._outputs = []; - if (this._graph) { - this._outputs = this._graph.output.map((valueInfo) => { - this._valueInfoMap[valueInfo.name] = valueInfo; - return { - id: valueInfo.name, - name: valueInfo.name, - description: valueInfo.docString, - type: OnnxTensor.formatType(valueInfo.type) - }; - }); - } - } return this._outputs; } @@ -239,16 +263,6 @@ class OnnxGraph { return this._nodes; } - getInitializer(input) { - var initializer = this._initializerMap[input]; - return initializer ? initializer : null; - } - - getValueInfo(input) { - var valueInfo = this._valueInfoMap[input]; - return valueInfo ? valueInfo : null; - } - get metadata() { return this._metadata; } @@ -256,21 +270,38 @@ class OnnxGraph { class OnnxNode { - constructor(graph, node) { + constructor(graph, operator, domain, name, description, attributes, inputs, outputs) { this._graph = graph; - this._node = node; + this._operator = operator; + if (domain) { + this._domain = domain; + } + if (name) { + this._name = name; + } + if (description) { + this._description = description; + } + this._attributes = []; + if (attributes && attributes.length > 0) { + attributes.forEach((attribute) => { + this._attributes.push(new OnnxAttribute(this, attribute)); + }); + } + this._inputs = inputs; + this._outputs = outputs; } get operator() { - return this._node.opType; + return this._operator; } get name() { - return this._node.name ? this._node.name : null; + return this._name || null; } get description() { - return this._node.docString ? this._node.docString : null; + return this._description || null; } get primitive() { @@ -278,86 +309,44 @@ class OnnxNode { } get documentation() { - return this._graph.metadata.getOperatorDocumentation(this); + return this._graph.metadata.getOperatorDocumentation(this._operator); } get domain() { - return this._node.domain ? this._node.domain : null; + return this._domain || null; } get category() { - return this._graph.metadata.getOperatorCategory(this); + return this._graph.metadata.getOperatorCategory(this._operator); } get group() { return null; } + get attributes() { + return this._attributes; + } + get inputs() { - if (this._node.input) { - var inputs = this._graph.metadata.getInputs(this); - inputs.forEach((input) => { - input.connections.forEach((connection) => { - var initializer = this._graph.getInitializer(connection.id); - if (initializer) { - connection.initializer = initializer; - connection.type = initializer.type; - } - else { - var valueInfo = this._graph.getValueInfo(connection.id); - if (valueInfo) { - connection.type = OnnxTensor.formatType(valueInfo.type); - } - } - }); - }); - return inputs; - } - return []; + return this._inputs; } get outputs() { - if (this._node.output) { - var outputs = this._graph.metadata.getOutputs(this); - outputs.forEach((output) => { - output.connections.forEach((connection) => { - var valueInfo = this._graph.getValueInfo(connection.id); - if (valueInfo) { - connection.type = OnnxTensor.formatType(valueInfo.type); - } - }); - }); - return outputs; - } - return []; + return this._outputs; } get dependencies() { return []; } - get attributes() { - var result = null; - var node = this._node; - if (node.attribute && node.attribute.length > 0) { - result = []; - node.attribute.forEach((attribute) => { - result.push(new OnnxAttribute(this, attribute)); - }); - } - return result; - } - get graph() { return this._graph; } - - get data() { - return this._node; - } } class OnnxAttribute { + constructor(node, attribute) { this._node = node; this._attribute = attribute; @@ -379,7 +368,7 @@ class OnnxAttribute { else if (this._attribute.hasOwnProperty('t')) { return OnnxTensor.formatTensorType(this._attribute.t); } - return this._node.graph.metadata.getAttributeType(this._node, this._attribute.name); + return this._node.graph.metadata.getAttributeType(this._node.operator, this._attribute.name); } get value() { @@ -430,7 +419,7 @@ class OnnxAttribute { } get visible() { - return this._node.graph.metadata.getAttributeVisible(this._node, this); + return this._node.graph.metadata.getAttributeVisible(this._node.operator, this); } get tensor() { @@ -443,9 +432,7 @@ class OnnxTensor { constructor(tensor, id, kind) { this._tensor = tensor; this._id = id; - if (kind) { - this._kind = kind; - } + this._kind = kind || null; } get id() { @@ -457,7 +444,7 @@ class OnnxTensor { } get kind() { - return this._kind ? this._kind : null; + return this._kind; } get type() { @@ -686,11 +673,11 @@ class OnnxTensor { class OnnxGraphOperatorMetadata { - constructor(model) { + constructor(opsetImport) { this._cache = {}; this._imports = {}; - if (model.opsetImport) { - model.opsetImport.forEach((opsetImport) => { + if (opsetImport) { + opsetImport.forEach((opsetImport) => { var domain = opsetImport.domain || ''; if (domain == 'ai.onnx') { domain = ''; @@ -706,8 +693,7 @@ class OnnxGraphOperatorMetadata { } } - getSchema(node) { - var operator = node.operator; + getSchema(operator) { var schema = this._cache[operator]; if (!schema) { schema = OnnxOperatorMetadata.operatorMetadata.getSchema(operator, this._imports); @@ -718,8 +704,8 @@ class OnnxGraphOperatorMetadata { return schema; } - getAttributeSchema(node, name) { - var schema = this.getSchema(node); + getAttributeSchema(operator, name) { + var schema = this.getSchema(operator); if (schema) { var attributeMap = schema.attributeMap; if (!attributeMap) { @@ -739,32 +725,31 @@ class OnnxGraphOperatorMetadata { return null; } - getInputs(node) { - var inputs = []; + getInputs(operator, inputs) { + var results = []; var index = 0; - var schema = this.getSchema(node); - var data = node.data; + var schema = this.getSchema(operator); if (schema && schema.inputs) { schema.inputs.forEach((inputDef) => { - if (index < data.input.length || inputDef.option != 'optional') { + if (index < inputs.length || inputDef.option != 'optional') { var input = {}; input.name = inputDef.name; input.type = inputDef.type; - var count = (inputDef.option == 'variadic') ? (data.input.length - index) : 1; + var count = (inputDef.option == 'variadic') ? (inputs.length - index) : 1; input.connections = []; - data.input.slice(index, index + count).forEach((id) => { + inputs.slice(index, index + count).forEach((id) => { if (id != '' || inputDef.option != 'optional') { input.connections.push({ id: id}); } }); index += count; - inputs.push(input); + results.push(input); } }); } else { - data.input.slice(index).forEach((input) => { - inputs.push({ + inputs.slice(index).forEach((input) => { + results.push({ name: '(' + index.toString() + ')', connections: [ { id: input } ] }); @@ -772,31 +757,30 @@ class OnnxGraphOperatorMetadata { }); } - return inputs; + return results; } - getOutputs(node) { - var outputs = []; + getOutputs(operator, outputs) { + var results = []; var index = 0; - var schema = this.getSchema(node); - var data = node.data; + var schema = this.getSchema(operator); if (schema && schema.outputs) { schema.outputs.forEach((outputDef) => { - if (index < data.output.length || outputDef.option != 'optional') { + if (index < outputs.length || outputDef.option != 'optional') { var output = {}; output.name = outputDef.name; var count = (outputDef.option == 'variadic') ? (data.output.length - index) : 1; - output.connections = data.output.slice(index, index + count).map((id) => { + output.connections = outputs.slice(index, index + count).map((id) => { return { id: id }; }); index += count; - outputs.push(output); + results.push(output); } }); } else { - data.output.slice(index).forEach((output) => { - outputs.push({ + outputs.slice(index).forEach((output) => { + results.push({ name: '(' + index.toString() + ')', connections: [ { id: output } ] }); @@ -804,19 +788,19 @@ class OnnxGraphOperatorMetadata { }); } - return outputs; + return results; } - getAttributeType(node, name) { - var schema = this.getAttributeSchema(node, name); + getAttributeType(operator, name) { + var schema = this.getAttributeSchema(operator, name); if (schema && schema.type) { return schema.type; } return ''; } - getAttributeVisible(node, attribute) { - var schema = this.getAttributeSchema(node, attribute.name); + getAttributeVisible(operator, attribute) { + var schema = this.getAttributeSchema(operator, attribute.name); if (schema && schema.hasOwnProperty('default') && schema.default) { if (attribute.value == schema.default.toString()) { return false; @@ -825,19 +809,19 @@ class OnnxGraphOperatorMetadata { return true; } - getOperatorCategory(node) { - var schema = this.getSchema(node); + getOperatorCategory(operator) { + var schema = this.getSchema(operator); if (schema && schema.category) { return schema.category; } return null; } - getOperatorDocumentation(node) { - var schema = this.getSchema(node); + getOperatorDocumentation(operator) { + var schema = this.getSchema(operator); if (schema) { schema = JSON.parse(JSON.stringify(schema)); - schema.name = node.operator; + schema.name = operator; if (schema.description) { var input = schema.description.split('\n'); var output = [];