From 6b20f08a5675e08f93a62b02e471efa6b6f9247f Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 28 Oct 2018 03:20:19 -0700 Subject: [PATCH] Caffe2 ONNXWhile fixes (#168) --- src/caffe2-metadata.json | 25 ++++++++++ src/caffe2-model.js | 104 ++++++++++++++++++++++++++++++--------- 2 files changed, 105 insertions(+), 24 deletions(-) diff --git a/src/caffe2-metadata.json b/src/caffe2-metadata.json index 4f5b1f5171..fee7f899a0 100644 --- a/src/caffe2-metadata.json +++ b/src/caffe2-metadata.json @@ -951,5 +951,30 @@ ], "support_level": "default" } + }, + { + "name": "ONNXWhile", + "schema": { + "attributes": [ + ], + "inputs": [ + { + "name": "max_trip_count" + }, + { + "name": "condition" + }, + { + "name": "initial", + "option": "variadic" + } + ], + "outputs": [ + { + "name": "final_and_scan_outputs", + "option": "variadic" + } + ] + } } ] diff --git a/src/caffe2-model.js b/src/caffe2-model.js index 12bb6ab6ac..c233a19953 100644 --- a/src/caffe2-model.js +++ b/src/caffe2-model.js @@ -105,9 +105,36 @@ class Caffe2Graph { }); if (init) { init.op.forEach((op) => { - if (op.type == 'GivenTensorFill' && op.output && op.output.length == 1) { + if (op.output && op.output.length == 1) { var name = op.output[0]; - initializers[name] = op; + var dataType = null; + switch (op.type) { + case 'GivenTensorFill': + dataType = 'float32'; + break; + case 'GivenTensorBoolFill': + dataType = 'boolean'; + break; + case 'GivenTensorByteStringToUInt8Fill': + dataType = 'uint8'; + break; + case 'GivenTensorIntFill': + dataType = 'int32'; + break; + case 'GivenTensorInt64Fill': + dataType = 'int64'; + break; + case 'GivenTensorStringFill': + dataType = 'string'; + break; + default: + debugger; + break; + } + if (dataType) { + op.dataType = dataType; + initializers[name] = op; + } } }); } @@ -298,7 +325,10 @@ class Caffe2Attribute { this._value = arg.ints; } else if (arg.nets && arg.nets.length > 0) { - this._value = () => '...'; + this._value = () => '{ NefDef[] }'; + } + else if (arg.n) { + this._value = () => '{ NefDef }'; } else if (arg.i != 0) { this._value = arg.i; @@ -345,18 +375,14 @@ class Caffe2Tensor { args[arg.name] = arg; }); } + var shape = null; if (args.shape && args.shape.ints) { - this._shape = args.shape.ints; + shape = args.shape.ints; } if (args.values) { this._values = args.values; - if (this._values.floats || this._values.floats == -1) { - this._dataType = 'float32'; - } - else { - debugger; - } } + this._type = new Caffe2TensorType(tensor.dataType, shape ? new Caffe2TensorShape(shape) : null); } get name() { @@ -364,7 +390,7 @@ class Caffe2Tensor { } get type() { - return new Caffe2TensorType(this._dataType, this._shape); + return this._type; } get kind() { @@ -403,33 +429,48 @@ class Caffe2Tensor { context.state = 'Tensor data is empty.'; return context; } - if (!this._dataType) { - context.state = 'Unknown data type.'; - return context; - } if (this._values.floats == -1) { context.state = 'Tensor data is too large to load in Chrome.'; return context; } - if (this._values.floats) { - context.data = this._values.floats; - } - else { - context.state = 'Unknown data format.'; + switch (this._type.dataType) { + case 'float32': + context.data = this._values.floats; + break; + case 'boolean': + context.data = this._values.ints; + break; + default: + context.state = 'Unknown data type.'; + debugger; + return context; } + context.shape = this._type.shape.dimensions; + context.dataType = this._type.dataType; return context; } _decode(context, dimension) { var results = []; - var size = this._shape[dimension]; - if (dimension == this._shape.length - 1) { + var size = context.shape[dimension]; + if (dimension == context.shape.length - 1) { for (var i = 0; i < size; i++) { if (context.count > context.limit) { results.push('...'); return results; } - results.push(context.data[context.index]); + switch (context.dataType) { + case 'float32': + results.push(context.data[context.index]); + break; + case 'boolean': + results.push(context.data[context.index] == 0 ? false : true); + break; + default: + context.state = 'Unknown data type.'; + debugger; + break; + } context.index++; context.count++; } @@ -463,7 +504,22 @@ class Caffe2TensorType { } toString() { - return this.dataType + (this._shape ? ('[' + this._shape.map((dimension) => dimension.toString()).join(',') + ']') : ''); + return this.dataType + this._shape.toString(); + } +} + +class Caffe2TensorShape { + + constructor(dimensions) { + this._dimensions = dimensions; + } + + get dimensions() { + return this._dimensions; + } + + toString() { + return this._dimensions ? ('[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']') : ''; } }