Skip to content

Commit

Permalink
Update onnx.js (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 17, 2024
1 parent cee0b8c commit ac10c28
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 224 deletions.
245 changes: 28 additions & 217 deletions source/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,11 @@ onnx.ProtoReader = class {
return new onnx.ProtoReader(context, 'text', 'model');
}
}
const obj = context.peek('json');
if (obj && (obj.ir_version === undefined && obj.producer_name === undefined && !Array.isArray(obj.opset_import) && !Array.isArray(obj.metadata_props)) &&
(obj.irVersion !== undefined || obj.producerName !== undefined || Array.isArray(obj.opsetImport) || Array.isArray(obj.metadataProps) || (Array.isArray(obj.graph) && Array.isArray(obj.graph.node)))) {
return new onnx.ProtoReader(context, 'json', 'model');
}
return undefined;
}

Expand All @@ -1588,6 +1593,17 @@ onnx.ProtoReader = class {
}
break;
}
case 'json': {
try {
const obj = this.context.read('json');
this.model = onnx.proto.ModelProto.decodeJson(obj);
this.format = `ONNX${this.model.ir_version ? ` v${this.model.ir_version}` : ''}`;
} catch (error) {
const message = error && error.message ? error.message : error.toString();
throw new onnx.Error(`File JSON format is not onnx.ModelProto (${message.replace(/\.$/, '')}).`);
}
break;
}
case 'binary': {
switch (this.type) {
case 'tensor': {
Expand Down Expand Up @@ -1877,7 +1893,8 @@ onnx.JsonReader = class {

static open(context) {
const obj = context.peek('json');
if (obj && (obj.irVersion !== undefined || obj.ir_Version !== undefined || (obj.graph && Array.isArray(obj.graph.node)))) {
if (obj && obj.framework === undefined && obj.graph &&
(obj.ir_version !== undefined || obj.producer_name !== undefined || Array.isArray(obj.opset_import))) {
return new onnx.JsonReader(obj);
}
return null;
Expand All @@ -1886,145 +1903,20 @@ onnx.JsonReader = class {
constructor(obj) {
this.name = 'onnx.json';
this.model = obj;
this._attributeTypes = new Map(Object.entries(onnx.AttributeType));
}

async read() {
this.model = this._model(this.model);
this.format = `ONNX JSON${this.model.ir_version ? ` v${this.model.ir_version}` : ''}`;
}

_tensor_shape(value) {
if (Array.isArray(value.dim)) {
for (const dimension of value.dim) {
if (dimension.dimValue !== undefined) {
dimension.dim_value = parseInt(dimension.dimValue, 10);
delete dimension.dimValue;
} else if (dimension.dimParam !== undefined) {
dimension.dim_param = dimension.dimParam;
delete dimension.dimParam;
}
}
}
return value;
}

_tensor_type(value) {
if (value.elemType !== undefined) {
value.elem_type = value.elemType;
delete value.elemType;
}
if (value.shape) {
value.shape = this._tensor_shape(value.shape);
}
return value;
}

_optional_type(value) {
if (value.elemType !== undefined) {
value.elem_type = this._type(value.elemType);
delete value.elemType;
}
return value;
}

_sequence_type(value) {
if (value.elemType !== undefined) {
value.elem_type = this._type(value.elemType);
delete value.elemType;
}
return value;
}

_map_type(value) {
if (value.keyType !== undefined) {
value.key_type = value.keyType;
delete value.keyType;
}
if (value.valueType !== undefined) {
value.value_type = this._type(value.valueType);
delete value.valueType;
}
return value;
}

_sparse_tensor_type(value) {
if (value.elemType !== undefined) {
value.elem_type = value.elemType;
delete value.elemType;
}
if (value.shape) {
value.shape = this._tensor_shape(value.shape);
}
return value;
}

_type(value) {
if (value.tensorType) {
value.tensor_type = this._tensor_type(value.tensorType);
delete value.tensorType;
} else if (value.tensor_type) {
value.tensor_type = this._tensor_type(value.tensor_type);
} else if (value.sequenceType) {
value.sequence_type = this._sequence_type(value.sequenceType);
delete value.sequenceType;
} else if (value.sequence_type) {
value.sequence_type = this._sequence_type(value.sequenceType);
} else if (value.optionalType !== undefined) {
value.optional_type = this._optional_type(value.optionalType);
delete value.optionalType;
} else if (value.optional_type) {
value.optional_type = this._optional_type(value.optionalType);
} else if (value.mapType) {
value.map_type = this._map_type(value.mapType);
delete value.mapType;
} else if (value.map_type) {
value.map_type = this._map_type(value.mapType);
} else if (value.sparseTensorType) {
value.sparse_tensor_type = this._sparse_tensor_type(value.sparseTensorType);
delete value.sparseTensorType;
} else if (value.sparse_tensor_type) {
value.sparse_tensor_type = this._sparse_tensor_type(value.sparseTensorType);
} else if (Object.keys(value).length > 0) {
throw new onnx.Error(`Unsupported ONNX JSON type '${JSON.stringify(Object.keys(value))}'.`);
}
return value;
}

_tensor(value) {
if (value.dataType !== undefined) {
value.data_type = value.dataType;
delete value.dataType;
}
value.dims = Array.isArray(value.dims) ? value.dims.map((dim) => parseInt(dim, 10)) : [];
value.dims = Array.isArray(value.dims) ? value.dims : [];
if (value.raw_data !== undefined) {
if (value.raw_data && value.raw_data instanceof Uint8Array === false &&
value.raw_data.type === 'Buffer' && Array.isArray(value.raw_data.data)) {
if (value.raw_data && value.raw_data instanceof Uint8Array === false && value.raw_data.type === 'Buffer' && Array.isArray(value.raw_data.data)) {
value.data_location = onnx.DataLocation.DEFAULT;
value.raw_data = new Uint8Array(value.raw_data.data);
}
} else if (value.rawData !== undefined) {
value.data_location = onnx.DataLocation.DEFAULT;
const data = atob(value.rawData);
const length = data.length;
const array = new Uint8Array(length);
for (let i = 0; i < length; i++) {
array[i] = data[i].charCodeAt(0);
}
value.raw_data = array;
delete value.rawData;
} else if (Array.isArray(value.floatData)) {
value.data_location = onnx.DataLocation.DEFAULT;
value.float_data = value.floatData;
delete value.floatData;
} else if (Array.isArray(value.int32Data)) {
value.data_location = onnx.DataLocation.DEFAULT;
value.int32_data = value.int32Data;
delete value.int32Data;
} else if (Array.isArray(value.int64Data)) {
value.data_location = onnx.DataLocation.DEFAULT;
value.int64_data = value.int64Data.map((value) => parseInt(value, 10));
delete value.int64Data;
} else if ((Array.isArray(value.float_data) && value.float_data.length > 0) ||
(Array.isArray(value.int32_data) && value.int32_data.length > 0) ||
(Array.isArray(value.int64_data) && value.int64_data.length > 0)) {
Expand All @@ -2042,13 +1934,7 @@ onnx.JsonReader = class {
}

_attribute(value) {
if (value.type && this._attributeTypes.has(value.type)) {
value.type = this._attributeTypes.get(value.type);
}
if (value.refAttrName) {
value.ref_attr_name = value.refAttrName;
delete value.refAttrName;
} else if (value.ref_attr_name) {
if (value.ref_attr_name) {
value.ref_attr_name = value.ref_attr_name.toString();
} else if (value.type === onnx.AttributeType.FLOATS || (Array.isArray(value.floats) && value.floats.length > 0)) {
value.floats = value.floats.map((value) => parseFloat(value));
Expand All @@ -2060,24 +1946,18 @@ onnx.JsonReader = class {
value.tensors = value.tensors.map((value) => this._tensor(value));
} else if (value.type === onnx.AttributeType.GRAPHS || (Array.isArray(value.graphs) && value.graphs.length > 0)) {
value.graphs = value.graphs.map((value) => this._graph(value));
} else if (value.type === onnx.AttributeType.SPARSE_TENSORS || (Array.isArray(value.sparseTensors) && value.sparseTensors.length > 0)) {
value.sparse_tensors = value.sparseTensors.map((item) => this._sparse_tensor(item));
delete value.sparseTensors;
} else if (value.type === onnx.AttributeType.SPARSE_TENSORS || (Array.isArray(value.sparse_tensors) && value.sparse_tensors.length > 0)) {
value.sparse_tensors = value.sparse_tensors.map((item) => this._sparse_tensor(item));
} else if (value.type === onnx.AttributeType.FLOAT || value.f !== undefined) {
value.f = parseFloat(value.f);
// continue
} else if (value.type === onnx.AttributeType.INT || value.i !== undefined) {
value.i = parseInt(value.i, 10);
// continue
} else if (value.type === onnx.AttributeType.STRING || value.s !== undefined) {
value.s = atob(value.s);
} else if (value.type === onnx.AttributeType.TENSOR || value.t !== undefined) {
value.t = this._tensor(value.t);
} else if (value.type === onnx.AttributeType.GRAPH || value.g !== undefined) {
value.g = this._graph(value.g);
} else if (value.type === onnx.AttributeType.SPARSE_TENSOR || value.sparseTensor !== undefined) {
value.sparse_tensor = this._sparse_tensor(value.sparseTensor);
delete value.sparseTensor;
} else if (value.type === onnx.AttributeType.SPARSE_TENSOR || value.sparse_tensor !== undefined) {
value.sparse_tensor = this._sparse_tensor(value.sparse_tensor);
} else {
Expand All @@ -2087,43 +1967,18 @@ onnx.JsonReader = class {
}

_node(value) {
if (value.opType !== undefined) {
value.op_type = value.opType;
delete value.opType;
}
value.input = Array.isArray(value.input) ? value.input : [];
value.output = Array.isArray(value.output) ? value.output : [];
value.attribute = Array.isArray(value.attribute) ? value.attribute.map((value) => this._attribute(value)) : [];
return value;
}

_value_info(value) {
value.type = this._type(value.type);
return value;
}

_operator_set(value) {
value.version = parseInt(value.version, 10);
return value;
}

_graph(value) {
value.node = value.node.map((value) => this._node(value));
value.initializer = Array.isArray(value.initializer) ? value.initializer.map((value) => this._tensor(value)) : [];
if (Array.isArray(value.sparseInitializer) && value.sparseInitializer.length > 0) {
value.sparse_initializer = value.sparseInitializer.map((item) => this._sparse_tensor(item));
delete value.sparseInitializer;
} else if (Array.isArray(value.sparse_initializer) && value.sparse_initializer.length > 0) {
value.sparse_initializer = value.sparseInitializer.map((item) => this._sparse_tensor(item));
}
if (Array.isArray(value.valueInfo) && value.valueInfo.length > 0) {
value.value_info = value.valueInfo.map((item) => this._value_info(item));
delete value.valueInfo;
} else if (Array.isArray(value.value_info) && value.value_info.length > 0) {
value.value_info = value.value_info.map((item) => this._value_info(item));
}
value.input = Array.isArray(value.input) ? value.input.map((value) => this._value_info(value)) : [];
value.output = Array.isArray(value.output) ? value.output.map((value) => this._value_info(value)) : [];
value.sparse_initializer = Array.isArray(value.sparse_initializer) ? value.sparse_initializer.map((item) => this._sparse_tensor(item)) : [];
value.input = Array.isArray(value.input) ? value.input : [];
value.output = Array.isArray(value.output) ? value.output : [];
return value;
}

Expand All @@ -2132,57 +1987,13 @@ onnx.JsonReader = class {
value.input = Array.isArray(value.input) ? value.input : [];
value.output = Array.isArray(value.output) ? value.output : [];
value.attribute = Array.isArray(value.attribute) ? value.attribute : [];
if (Array.isArray(value.attributeProto) && value.attributeProto.length > 0) {
value.attribute_proto = value.attributeProto.map((value) => this._attribute(value));
delete value.attributeProto;
} else if (Array.isArray(value.attribute_proto) && value.attribute_proto.length > 0) {
value.attribute_proto = value.attribute_proto.map((value) => this._attribute(value));
}
if (value.docString) {
value.doc_string = value.docString;
delete value.docString;
}
value.attribute_proto = Array.isArray(value.attribute_proto) ? value.attribute_proto.map((value) => this._attribute(value)) : [];
return value;
}

_model(value) {
if (value.irVersion !== undefined) {
value.ir_version = parseInt(value.irVersion, 10);
delete value.irVersion;
}
if (value.version !== undefined) {
value.version = parseInt(value.version, 10);
}
if (value.producerName) {
value.producer_name = value.producerName;
delete value.producerName;
}
if (value.producerVersion) {
value.producer_version = value.producerVersion;
delete value.producerVersion;
}
if (value.modelVersion) {
value.model_version = parseInt(value.modelVersion, 10);
delete value.modelVersion;
}
if (value.docString) {
value.doc_string = value.docString;
delete value.docString;
}
value.graph = this._graph(value.graph);
if (Array.isArray(value.opsetImport) && value.opsetImport.length > 0) {
value.opset_import = value.opsetImport.map((item) => this._operator_set(item));
delete value.opsetImport;
} else if (Array.isArray(value.opset_import) && value.opset_import.length > 0) {
value.opset_import = value.opset_import.map((item) => this._operator_set(item));
}
if (Array.isArray(value.metadataProps)) {
value.metadata_props = value.metadataProps;
delete value.metadataProps;
}
if (Array.isArray(value.functions)) {
value.functions = value.functions.map((item) => this._function(item));
}
value.functions = Array.isArray(value.functions) ? value.functions.map((item) => this._function(item)) : [];
return value;
}
};
Expand Down
14 changes: 7 additions & 7 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4074,7 +4074,7 @@
"type": "onnx",
"target": "candy.json.zip",
"source": "https://github.com/lutzroeder/netron/files/12329067/candy.json.zip",
"format": "ONNX JSON v3",
"format": "ONNX v3",
"assert": "model.graphs[0].nodes[2].attributes[1].visible == false",
"tags": "validation",
"link": "https://github.com/lutzroeder/netron/issues/6"
Expand Down Expand Up @@ -4254,7 +4254,7 @@
"type": "onnx",
"target": "gather.json",
"source": "https://github.com/lutzroeder/netron/files/12306625/gather.json.zip[gather.json]",
"format": "ONNX JSON v6",
"format": "ONNX v6",
"link": "https://github.com/lutzroeder/netron/issues/6"
},
{
Expand Down Expand Up @@ -4374,7 +4374,7 @@
"type": "onnx",
"target": "issue_1138.json",
"source": "https://github.com/lutzroeder/netron/files/12343742/issue_1138.json.zip[issue_1138.json]",
"format": "ONNX JSON v9",
"format": "ONNX v9",
"link": "https://github.com/lutzroeder/netron/issues/1138"
},
{
Expand Down Expand Up @@ -4484,7 +4484,7 @@
"type": "onnx",
"target": "nms_base_component.json",
"source": "https://github.com/lutzroeder/netron/files/12306626/nms_base_component.json.zip[nms_base_component.json]",
"format": "ONNX JSON v8",
"format": "ONNX v8",
"link": "https://github.com/lutzroeder/netron/issues/6"
},
{
Expand All @@ -4507,7 +4507,7 @@
"type": "onnx",
"target": "optional_type.json",
"source": "https://github.com/lutzroeder/netron/files/12329086/optional_type.json.zip[optional_type.json]",
"format": "ONNX JSON v8",
"format": "ONNX v8",
"link": "https://github.com/lutzroeder/netron/issues/6"
},
{
Expand Down Expand Up @@ -4600,7 +4600,7 @@
"type": "onnx",
"target": "sparse_initializer_as_output.json",
"source": "https://github.com/lutzroeder/netron/files/12444489/sparse_initializer_as_output.json.zip[sparse_initializer_as_output.json]",
"format": "ONNX JSON v7",
"format": "ONNX v7",
"assert": "model.graphs[0].outputs[0].value[0].type.layout == 'sparse'",
"tags": "validation",
"link": "https://github.com/lutzroeder/netron/issues/741"
Expand Down Expand Up @@ -4769,7 +4769,7 @@
"type": "onnx",
"target": "zipmap_int64float.json",
"source": "https://github.com/lutzroeder/netron/files/12329104/zipmap_int64float.json.zip[zipmap_int64float.json]",
"format": "ONNX JSON v3",
"format": "ONNX v3",
"link": "https://github.com/lutzroeder/netron/issues/6"
},
{
Expand Down

0 comments on commit ac10c28

Please sign in to comment.