Skip to content

Commit

Permalink
Fix CoreML tensor shape (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 9, 2018
1 parent 3b751dc commit fd466b3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
17 changes: 10 additions & 7 deletions src/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ coreml.Graph = class {
var result = '';
switch (type.Type) {
case 'multiArrayType':
var shape = new coreml.TensorShape(null);
var shape = new coreml.TensorShape([]);
if (type.multiArrayType.shape && type.multiArrayType.shape.length > 0) {
shape = new coreml.TensorShape(type.multiArrayType.shape);
}
Expand Down Expand Up @@ -630,11 +630,11 @@ coreml.Node = class {
case 'bias':
this._initializers.push(new coreml.Tensor('Weights', 'bias', data.shapeBias, data.bias));
return { 'bias': true };
case 'simpleRecurrentLayer':
this._initializers.push(new coreml.Tensor('Weights', 'weights', null, data.weightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'recurrent', null, data.recursionMatrix));
case 'simpleRecurrent':
this._initializers.push(new coreml.Tensor('Weights', 'weights', [ data.outputVectorSize, data.inputVectorSize ], data.weightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'recurrent', [ data.outputVectorSize, data.inputVectorSize ], data.recursionMatrix));
if (data.hasBiasVectors) {
this._initializers.push(new coreml.Tensor('Weights', 'bias', null, data.biasVector));
this._initializers.push(new coreml.Tensor('Weights', 'bias', [ data.outputVectorSize ], data.biasVector));
}
return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
case 'gru':
Expand Down Expand Up @@ -894,7 +894,7 @@ coreml.TensorType = class {

constructor(dataType, shape) {
this._dataType = dataType;
this._shape = shape || new coreml.TensorShape(null);
this._shape = shape || new coreml.TensorShape([]);
}

get dataType() {
Expand All @@ -921,7 +921,10 @@ coreml.TensorShape = class {
}

toString() {
return this._dimensions ? ('[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']') : '';
if (!this._dimensions || this._dimensions.length == 0) {
return '';
}
return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
}
};

Expand Down
5 changes: 4 additions & 1 deletion src/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,10 @@ onnx.TensorShape = class {
}

toString() {
return (this._dimensions && this._dimensions.length) ? ('[' + this._dimensions.join(',') + ']') : '';
if (!this._dimensions || this._dimensions.length == 0) {
return '';
}
return '[' + this._dimensions.join(',') + ']';
}
};

Expand Down

0 comments on commit fd466b3

Please sign in to comment.