Skip to content

Commit

Permalink
Core ML nonMaximumSuppression support (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 8, 2019
1 parent d18d4eb commit 63bb6ba
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 8 deletions.
19 changes: 19 additions & 0 deletions src/coreml-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -441,5 +441,24 @@
{ "name": "revision", "visible": false }
]
}
},
{
"name": "nonMaximumSuppression",
"schema": {
"attributes": [
{ "name": "iouThreshold" },
{ "name": "confidenceThreshold" }
],
"inputs": [
{ "name": "confidence" },
{ "name": "coordinates" },
{ "name": "iouThreshold" },
{ "name": "confidenceThreshold" }
],
"outputs": [
{ "name": "confidence" },
{ "name": "coordinates" }
]
}
}
]
41 changes: 33 additions & 8 deletions src/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ coreml.ModelFactory = class {
host.exception(error, false);
var message = error && error.message ? error.message : error.toString();
message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
return new coreml.Error(message + " in '" + identifier + "'.");
throw new coreml.Error(message + " in '" + identifier + "'.");
}
});
});
Expand Down Expand Up @@ -187,6 +187,7 @@ coreml.Graph = class {
}

_loadModel(model, scope, group) {
var i;
this._groups = this._groups | (group.length > 0 ? true : false);
var layer;
if (model.neuralNetworkClassifier) {
Expand Down Expand Up @@ -215,20 +216,20 @@ coreml.Graph = class {
return 'Neural Network Regressor';
}
else if (model.pipeline) {
for (layer of model.pipeline.models) {
this._loadModel(layer, scope, (group ? (group + '/') : '') + 'pipeline');
for (i = 0; i < model.pipeline.models.length; i++) {
this._loadModel(model.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']');
}
return 'Pipeline';
}
else if (model.pipelineClassifier) {
for (layer of model.pipelineClassifier.pipeline.models) {
this._loadModel(layer, scope, (group ? (group + '/') : '') + 'pipelineClassifier');
for (i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) {
this._loadModel(model.pipelineClassifier.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']');
}
return 'Pipeline Classifier';
}
else if (model.pipelineRegressor) {
for (layer of model.pipelineRegressor.pipeline.models) {
this._loadModel(layer, scope, (group ? (group + '/') : '') + 'pipelineRegressor');
for (i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) {
this._loadModel(model.pipelineRegressor.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']');
}
return 'Pipeline Regressor';
}
Expand Down Expand Up @@ -361,7 +362,28 @@ coreml.Graph = class {
[ model.description.output[0].name ]);
return 'Text Classifier';
}
return 'Unknown';
else if (model.nonMaximumSuppression) {
var nonMaximumSuppressionParams = {
pickTop: model.nonMaximumSuppression.pickTop,
stringClassLabels: model.nonMaximumSuppression.stringClassLabels,
iouThreshold: model.nonMaximumSuppression.iouThreshold,
confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
};
this._createNode(scope, group, 'nonMaximumSuppression', null,
nonMaximumSuppressionParams,
[
model.nonMaximumSuppression.confidenceInputFeatureName,
model.nonMaximumSuppression.coordinatesInputFeatureName,
model.nonMaximumSuppression.iouThresholdInputFeatureName,
model.nonMaximumSuppression.confidenceThresholdInputFeatureName,
],
[
model.nonMaximumSuppression.confidenceOutputFeatureName,
model.nonMaximumSuppression.coordinatesOutputFeatureName
]);
return 'Non Maximum Suppression';
}
throw new coreml.Error("Unknown model type '" + Object.keys(model).filter(k => k != 'specificationVersion' && k != 'description').join(',') + "'.");
}

_createNode(scope, group, operator, name, data, inputs, outputs) {
Expand Down Expand Up @@ -696,6 +718,9 @@ coreml.Node = class {
data.modelParameterData = Array.from(data.modelParameterData);
data.stringClassLabels = this._convertVector(data.stringClassLabels);
return {};
case 'nonMaximumSuppression':
data.stringClassLabels = this._convertVector(data.stringClassLabels);
return {};
}
return {};
}
Expand Down
14 changes: 14 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -1738,6 +1738,20 @@
"format": "Core ML v1",
"link": "https://github.com/UnusualWolf/coreML"
},
{
"type": "coreml",
"target": "YOLOv3.mlmodel",
"source": "https://docs-assets.developer.apple.com/coreml/models/Image/ObjectDetection/YOLOv3/YOLOv3.mlmodel",
"format": "Core ML v3",
"link": "https://developer.apple.com/machine-learning/models"
},
{
"type": "coreml",
"target": "YOLOv3Tiny.mlmodel",
"source": "https://docs-assets.developer.apple.com/coreml/models/Image/ObjectDetection/YOLOv3Tiny/YOLOv3Tiny.mlmodel",
"format": "Core ML v3",
"link": "https://developer.apple.com/machine-learning/models"
},
{
"type": "darknet",
"target": "alexnet.cfg,alexnet.weights",
Expand Down

0 comments on commit 63bb6ba

Please sign in to comment.