Skip to content

Commit

Permalink
Fix ONNX graph attribute names (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 15, 2018
1 parent f0a3b01 commit b470ca6
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions src/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ onnx.Model = class {
this._graphs = [];
if (model && model.graph) {
var graphMetadata = new onnx.GraphMetadata(metadata, this._opsetImport);
var graph = new onnx.Graph(graphMetadata, model.graph, 0, this._imageFormat);
var graph = new onnx.Graph(graphMetadata, this._imageFormat, model.graph);
this._graphs.push(graph);
}
}
Expand Down Expand Up @@ -223,18 +223,16 @@ onnx.Model = class {

onnx.Graph = class {

constructor(metadata, graph, index, imageFormat) {
this._metadata = metadata;
constructor(metadata, imageFormat, graph) {
this._node = '';
this._description = '';
this._nodes = [];
this._inputs = [];
this._outputs = [];
this._operators = {};
this._imageFormat = imageFormat;

if (graph) {
this._name = graph.name || ('(' + index.toString() + ')');
this._name = graph.name || null;
this._description = graph.doc_string || '';

var initializers = {};
Expand Down Expand Up @@ -284,7 +282,7 @@ onnx.Graph = class {
nodes.forEach((node) => {
var inputs = [];
if (node.input) {
inputs = this._metadata.getInputs(node.op_type, node.input);
inputs = metadata.getInputs(node.op_type, node.input);
inputs = inputs.map((input) => {
return new onnx.Argument(input.name, input.connections.map((connection) => {
return this._connection(connections, connection.id, null, null, initializers[connection.id]);
Expand All @@ -293,14 +291,14 @@ onnx.Graph = class {
}
var outputs = [];
if (node.output) {
outputs = this._metadata.getOutputs(node.op_type, node.output);
outputs = metadata.getOutputs(node.op_type, node.output);
outputs = outputs.map((output) => {
return new onnx.Argument(output.name, output.connections.map((connection) => {
return this._connection(connections, connection.id, null, null, initializers[connection.id]);
}));
});
}
this._nodes.push(new onnx.Node(this, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs));
this._nodes.push(new onnx.Node(metadata, imageFormat, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs));
});
}
}
Expand Down Expand Up @@ -404,8 +402,8 @@ onnx.Connection = class {

onnx.Node = class {

constructor(graph, operator, domain, name, description, attributes, inputs, outputs) {
this._graph = graph;
constructor(metadata, imageFormat, operator, domain, name, description, attributes, inputs, outputs) {
this._metadata = metadata;
this._operator = operator;
if (domain) {
this._domain = domain;
Expand All @@ -419,7 +417,7 @@ onnx.Node = class {
this._attributes = [];
if (attributes && attributes.length > 0) {
attributes.forEach((attribute) => {
this._attributes.push(new onnx.Attribute(this.graph.metadata, this.operator, attribute));
this._attributes.push(new onnx.Attribute(this._metadata, imageFormat, this.operator, attribute));
});
}
this._inputs = inputs;
Expand All @@ -443,7 +441,7 @@ onnx.Node = class {
}

get documentation() {
var schema = this._graph.metadata.getSchema(this._operator);
var schema = this._metadata.getSchema(this._operator);
if (schema) {
var options = { baseUrl: 'https://github.com/onnx/onnx/blob/master/docs/' };
schema = JSON.parse(JSON.stringify(schema));
Expand Down Expand Up @@ -498,7 +496,7 @@ onnx.Node = class {
}

get category() {
var schema = this._graph.metadata.getSchema(this._operator);
var schema = this._metadata.getSchema(this._operator);
return (schema && schema.category) ? schema.category : null;
}

Expand All @@ -521,15 +519,11 @@ onnx.Node = class {
get dependencies() {
return [];
}

get graph() {
return this._graph;
}
};

onnx.Attribute = class {

constructor(metadata, operator, attribute) {
constructor(metadata, imageFormat, operator, attribute) {
this._name = attribute.name;
this._type = null;
this._value = null;
Expand Down Expand Up @@ -570,7 +564,7 @@ onnx.Attribute = class {
}
}
else if (attribute.graphs && attribute.graphs.length > 0) {
this._value = arg.graphs.map((graph) => new onnx.Graph(metadata, graph));
this._value = arg.graphs.map((graph) => new onnx.Graph(metadata, imageFormat, graph));
this._type = 'graph[]';
}
else if (attribute.s && attribute.s.length > 0) {
Expand All @@ -593,7 +587,7 @@ onnx.Attribute = class {
}
else if (attribute.hasOwnProperty('g')) {
this._type = 'graph';
this._value = new onnx.Graph(metadata, attribute.g);
this._value = new onnx.Graph(metadata, imageFormat, attribute.g);
}

var attributeSchema = metadata.getAttributeSchema(operator, attribute.name);
Expand Down

0 comments on commit b470ca6

Please sign in to comment.