diff --git a/src/onnx.js b/src/onnx.js index b5134fe304..0a76472b9b 100644 --- a/src/onnx.js +++ b/src/onnx.js @@ -136,7 +136,7 @@ onnx.Model = class { this._graphs = []; if (model && model.graph) { var graphMetadata = new onnx.GraphMetadata(metadata, this._opsetImport); - var graph = new onnx.Graph(graphMetadata, model.graph, 0, this._imageFormat); + var graph = new onnx.Graph(graphMetadata, this._imageFormat, model.graph); this._graphs.push(graph); } } @@ -223,18 +223,16 @@ onnx.Model = class { onnx.Graph = class { - constructor(metadata, graph, index, imageFormat) { - this._metadata = metadata; + constructor(metadata, imageFormat, graph) { this._node = ''; this._description = ''; this._nodes = []; this._inputs = []; this._outputs = []; this._operators = {}; - this._imageFormat = imageFormat; if (graph) { - this._name = graph.name || ('(' + index.toString() + ')'); + this._name = graph.name || null; this._description = graph.doc_string || ''; var initializers = {}; @@ -284,7 +282,7 @@ onnx.Graph = class { nodes.forEach((node) => { var inputs = []; if (node.input) { - inputs = this._metadata.getInputs(node.op_type, node.input); + inputs = metadata.getInputs(node.op_type, node.input); inputs = inputs.map((input) => { return new onnx.Argument(input.name, input.connections.map((connection) => { return this._connection(connections, connection.id, null, null, initializers[connection.id]); @@ -293,14 +291,14 @@ onnx.Graph = class { } var outputs = []; if (node.output) { - outputs = this._metadata.getOutputs(node.op_type, node.output); + outputs = metadata.getOutputs(node.op_type, node.output); outputs = outputs.map((output) => { return new onnx.Argument(output.name, output.connections.map((connection) => { return this._connection(connections, connection.id, null, null, initializers[connection.id]); })); }); } - this._nodes.push(new onnx.Node(this, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs)); + this._nodes.push(new onnx.Node(metadata, imageFormat, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs)); }); } } @@ -404,8 +402,8 @@ onnx.Connection = class { onnx.Node = class { - constructor(graph, operator, domain, name, description, attributes, inputs, outputs) { - this._graph = graph; + constructor(metadata, imageFormat, operator, domain, name, description, attributes, inputs, outputs) { + this._metadata = metadata; this._operator = operator; if (domain) { this._domain = domain; @@ -419,7 +417,7 @@ onnx.Node = class { this._attributes = []; if (attributes && attributes.length > 0) { attributes.forEach((attribute) => { - this._attributes.push(new onnx.Attribute(this.graph.metadata, this.operator, attribute)); + this._attributes.push(new onnx.Attribute(this._metadata, imageFormat, this.operator, attribute)); }); } this._inputs = inputs; @@ -443,7 +441,7 @@ onnx.Node = class { } get documentation() { - var schema = this._graph.metadata.getSchema(this._operator); + var schema = this._metadata.getSchema(this._operator); if (schema) { var options = { baseUrl: 'https://github.com/onnx/onnx/blob/master/docs/' }; schema = JSON.parse(JSON.stringify(schema)); @@ -498,7 +496,7 @@ onnx.Node = class { } get category() { - var schema = this._graph.metadata.getSchema(this._operator); + var schema = this._metadata.getSchema(this._operator); return (schema && schema.category) ? schema.category : null; } @@ -521,15 +519,11 @@ onnx.Node = class { get dependencies() { return []; } - - get graph() { - return this._graph; - } }; onnx.Attribute = class { - constructor(metadata, operator, attribute) { + constructor(metadata, imageFormat, operator, attribute) { this._name = attribute.name; this._type = null; this._value = null; @@ -570,7 +564,7 @@ onnx.Attribute = class { } } else if (attribute.graphs && attribute.graphs.length > 0) { - this._value = arg.graphs.map((graph) => new onnx.Graph(metadata, graph)); + this._value = arg.graphs.map((graph) => new onnx.Graph(metadata, imageFormat, graph)); this._type = 'graph[]'; } else if (attribute.s && attribute.s.length > 0) { @@ -593,7 +587,7 @@ onnx.Attribute = class { } else if (attribute.hasOwnProperty('g')) { this._type = 'graph'; - this._value = new onnx.Graph(metadata, attribute.g); + this._value = new onnx.Graph(metadata, imageFormat, attribute.g); } var attributeSchema = metadata.getAttributeSchema(operator, attribute.name);