Skip to content

Commit

Permalink
Replace type strings with tensor types (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Aug 19, 2018
1 parent 9dc7b85 commit badd7e6
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 98 deletions.
8 changes: 4 additions & 4 deletions src/caffe-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class CaffeGraph {
this._inputs.push({
id: input,
name: input,
type: 'T'
type: null
});
});
}
Expand Down Expand Up @@ -155,15 +155,15 @@ class CaffeGraph {
this._outputs.push({
id: keys[0],
name: keys[0],
type: 'T'
type: null
});
}
else if (outputs.length == 1) {
outputs[0]._outputs = [ 'output' ];
this._outputs.push({
id: 'output',
name: 'output',
type: 'T'
type: null
});
}
}
Expand Down Expand Up @@ -294,7 +294,7 @@ class CaffeNode {
input.connections.forEach((connection) => {
if (connection.id instanceof CaffeTensor) {
connection.initializer = connection.id;
connection.type = connection.initializer.type.toString();
connection.type = connection.initializer.type;
connection.id = '';
}
});
Expand Down
6 changes: 3 additions & 3 deletions src/caffe2-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class Caffe2Graph {
this._inputs.push({
id: input,
name: input,
type: 'T'
type: null
});
}
});
Expand All @@ -127,7 +127,7 @@ class Caffe2Graph {
this._outputs.push({
id: output,
name: output,
type: 'T'
type: null
});
});
}
Expand Down Expand Up @@ -214,7 +214,7 @@ class Caffe2Node {
var initializer = this._initializers[connection.id];
if (initializer) {
connection.initializer = initializer;
connection.type = initializer.type.toString();
connection.type = initializer.type;
}
});
});
Expand Down
6 changes: 3 additions & 3 deletions src/coreml-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -378,13 +378,13 @@ class CoreMLGraph {
result = 'image(' + CoreMLGraph.formatColorSpace(type.imageType.colorSpace) + ',' + type.imageType.width.toString() + 'x' + type.imageType.height.toString() + ')';
break;
case 'dictionaryType':
result = 'map<' + type.dictionaryType.KeyType.replace('KeyType', '') + ',double>';
result = 'map<' + type.dictionaryType.KeyType.replace('KeyType', '') + ',float64>';
break;
case 'stringType':
result = 'string';
break;
case 'doubleType':
result = 'double';
result = 'float64';
break;
case 'int64Type':
result = 'int64';
Expand Down Expand Up @@ -477,7 +477,7 @@ class CoreMLNode {
name: initializer.name,
connections: [ {
id: '',
type: initializer.type.toString(),
type: initializer.type,
initializer: initializer, } ]
};
if (!CoreMLOperatorMetadata.operatorMetadata.getInputVisible(this._operator, initializer.name)) {
Expand Down
17 changes: 9 additions & 8 deletions src/keras-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class KerasGraph {
if (addGraphOutput) {
this._outputs.push({
id: inputName,
type: 'T',
type: null,
name: inputName
});
}
Expand Down Expand Up @@ -340,7 +340,7 @@ class KerasGraph {
if (connection) {
this._outputs.push({
id: connection,
type: 'T',
type: null,
name: connection
});
}
Expand All @@ -363,19 +363,20 @@ class KerasGraph {
}

_loadInput(layer, input) {
input.type = '';
input.type = null;
if (layer && layer.config) {
var dataType = '?';
var shape = [];
var config = layer.config;
if (config.dtype) {
input.type = config.dtype;
dataType = config.dtype;
delete config.dtype;
}
if (config.batch_input_shape) {
var shape = config.batch_input_shape;
shape = shape.map(s => s == null ? '?' : s).join(',');
input.type = input.type + '[' + shape + ']';
shape = config.batch_input_shape.map(s => s == null ? '?' : s);
delete config.batch_input_shape;
}
input.type = new KerasTensorType(dataType, shape);
}
}
}
Expand Down Expand Up @@ -471,7 +472,7 @@ class KerasNode {
input.connections.forEach((connection) => {
var initializer = this._initializers[connection.id];
if (initializer) {
connection.type = initializer.type.toString();
connection.type = initializer.type;
connection.initializer = initializer;
}
});
Expand Down
12 changes: 6 additions & 6 deletions src/mxnet-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ class MXNetGraph {
var output = {};
output.id = MXNetGraph._updateOutput(nodes, head);
output.name = nodes[output.id[0]] ? nodes[output.id[0]].name : ('output' + ((index == 0) ? '' : (index + 1).toString()));
output.type = 'T';
output.type = null;
var outputSignature = outputs[output.name];
if (outputSignature && outputSignature.data_shape) {
output.type = '?' + '[' + outputSignature.data_shape.toString() + ']';
output.type = new MXNetTensorType(null, outputSignature.data_shape);
}
this._outputs.push(output);
});
Expand All @@ -315,10 +315,10 @@ class MXNetGraph {
var input = {};
input.id = argument.outputs[0];
input.name = argument.name;
input.type = 'T';
input.type = null;
var inputSignature = inputs[input.name];
if (inputSignature && inputSignature.data_shape) {
input.type = '?' + '[' + inputSignature.data_shape.toString() + ']';
input.type = new MXNetTensorType(null, inputSignature.data_shape);
}
this._inputs.push(input);
}
Expand Down Expand Up @@ -464,7 +464,7 @@ class MXNetNode {
var initializer = this._initializers[connection.id];
if (initializer) {
connection.id = initializer.name || connection.id;
connection.type = initializer.type.toString();
connection.type = initializer.type;
connection.initializer = initializer;
}
});
Expand Down Expand Up @@ -674,7 +674,7 @@ class MXNetTensorType {
}

toString() {
return this.dataType + (this._shape ? ('[' + this._shape.map((dimension) => dimension.toString()).join(',') + ']') : '');
return (this.dataType || '?') + (this._shape ? ('[' + this._shape.map((dimension) => dimension.toString()).join(',') + ']') : '');
}
}

Expand Down
115 changes: 72 additions & 43 deletions src/onnx-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class OnnxGraph {
var initializer = this._initializerMap[connection.id];
if (initializer) {
connection.initializer = initializer;
connection.type = connection.type || initializer.type.toString();
connection.type = connection.type || initializer.type;
}
return connection;
});
Expand Down Expand Up @@ -499,6 +499,7 @@ class OnnxTensor {
this._tensor = tensor;
this._id = id;
this._kind = kind || null;
this._type = new OnnxTensorType(this._tensor.dataType, this._tensor.dims.map((dim) => dim), null);
}

get id() {
Expand All @@ -514,7 +515,7 @@ class OnnxTensor {
}

get type() {
return new OnnxTensorType(this._tensor);
return this._type;
}

get value() {
Expand Down Expand Up @@ -744,37 +745,8 @@ class OnnxTensor {

static _formatType(type, imageFormat) {
if (!type) {
return { value: '?' };
return null;
}
var value = {};
switch (type.value) {
case 'tensorType':
var tensorType = type.tensorType;
var text = OnnxTensor._formatElementType(tensorType.elemType);
if (tensorType.shape && tensorType.shape.dim) {
text += '[' + tensorType.shape.dim.map((dimension) => {
if (dimension.dimParam) {
return dimension.dimParam;
}
return dimension.dimValue.toString();
}).join(',') + ']';
}
value = text;
break;
case 'mapType':
var keyType = OnnxTensor._formatElementType(type.mapType.keyType);
var valueType = OnnxTensor._formatType(type.mapType.valueType);
value = 'map<' + keyType + ',' + valueType.value + '>';
break;
case 'sequenceType':
var elemType = OnnxTensor._formatType(type.sequenceType.elemType);
value = 'sequence<' + elemType.value + '>';
break;
default:
// debugger
value = '?';
break;
}
var denotation = '';
switch (type.denotation) {
case 'TENSOR':
Expand All @@ -790,21 +762,30 @@ class OnnxTensor {
denotation = 'Text';
break;
}
return { value: value, denotation: denotation };
switch (type.value) {
case 'tensorType':
var shape = [];
if (type.tensorType.shape && type.tensorType.shape.dim) {
shape = type.tensorType.shape.dim.map((dim) => {
return dim.dimParam ? dim.dimParam : dim.dimValue;
});
}
return new OnnxTensorType(type.tensorType.elemType, shape, denotation);
case 'mapType':
return new OnnxMapType(type.mapType.keyType, OnnxTensor._formatType(type.mapType.valueType, imageFormat), denotation);
case 'sequenceType':
return new OnnxSequenceType(OnnxTensor._formatType(type.sequenceType.elemType, imageFormat), denotation);
}
return null;
}
}

class OnnxTensorType {

constructor(tensor) {
this._dataType = '?';
if (tensor.hasOwnProperty('dataType')) {
this._dataType = OnnxTensor._formatElementType(tensor.dataType);
}
this._shape = [];
if (tensor.hasOwnProperty('dims')) {
this._shape = tensor.dims.map((dimension) => dimension);
}
constructor(dataType, shape, denotation) {
this._dataType = OnnxTensor._formatElementType(dataType);
this._shape = shape;
this._denotation = denotation || null;
}

get dataType() {
Expand All @@ -815,10 +796,58 @@ class OnnxTensorType {
return this._shape;
}

get denotation() {
return this._denotation;
}

toString() {
return this.dataType + (this._shape ? ('[' + this._shape.map((dimension) => dimension.toString()).join(',') + ']') : '');
return this.dataType + ((this._shape && this._shape.length) ? ('[' + this._shape.join(',') + ']') : '');
}
}

class OnnxSequenceType {

constructor(elementType, denotation) {
this._elementType = elementType;
this._denotation = denotation;
}

get elementType() {
return this._elementType;
}

get dennotation() {
return this._dennotation;
}

toString() {
return 'sequence<' + this._elementType.toString() + '>';
}
}

class OnnxMapType {

constructor(keyType, valueType, denotation) {
this._keyType = OnnxTensor._formatElementType(keyType);
this._valueType = valueType;
this._denotation = denotation;
}

get keyType() {
return this._keyType;
}

get valueType() {
return this._valueType;
}

get denotation() {
return this._denotation;
}

toString() {
return 'map<' + this._keyType + ',' + this._valueType.toString() + '>';
}
}

class OnnxGraphOperatorMetadata {
Expand Down
8 changes: 3 additions & 5 deletions src/tf-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class TensorFlowNode {
input.connections.forEach((connection) => {
var initializer = this._graph._getInitializer(connection.id);
if (initializer) {
connection.type = initializer.type.toString();
connection.type = initializer.type;
connection.initializer = initializer;
}
});
Expand Down Expand Up @@ -503,9 +503,6 @@ class TensorFlowAttribute {
return TensorFlowTensor.formatTensorShape(value.shape);
}
else if (value.hasOwnProperty('s')) {
if (value.s.length == 0) {
return '';
}
if (value.s.filter(c => c <= 32 && c >= 128).length == 0) {
return '"' + TensorFlowOperatorMetadata.textDecoder.decode(value.s) + '"';
}
Expand Down Expand Up @@ -581,6 +578,7 @@ class TensorFlowTensor {
if (kind) {
this._kind = kind;
}
this._type = new TensorFlowTensorType(this._tensor.dtype, this._tensor.tensorShape);
}

get id() {
Expand All @@ -592,7 +590,7 @@ class TensorFlowTensor {
}

get type() {
return new TensorFlowTensorType(this._tensor.dtype, this._tensor.tensorShape);
return this._type;
}

get kind() {
Expand Down
Loading

0 comments on commit badd7e6

Please sign in to comment.