From 63bb6bae5988298b81c17988c82885111fa419c0 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 8 Jun 2019 00:30:29 -0700 Subject: [PATCH] Core ML nonMaximumSuppression support (#193) --- src/coreml-metadata.json | 19 +++++++++++++++++++ src/coreml.js | 41 ++++++++++++++++++++++++++++++++-------- test/models.json | 14 ++++++++++++++ 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/src/coreml-metadata.json b/src/coreml-metadata.json index 66bebf5230e..3169dd4248f 100644 --- a/src/coreml-metadata.json +++ b/src/coreml-metadata.json @@ -441,5 +441,24 @@ { "name": "revision", "visible": false } ] } + }, + { + "name": "nonMaximumSuppression", + "schema": { + "attributes": [ + { "name": "iouThreshold" }, + { "name": "confidenceThreshold" } + ], + "inputs": [ + { "name": "confidence" }, + { "name": "coordinates" }, + { "name": "iouThreshold" }, + { "name": "confidenceThreshold" } + ], + "outputs": [ + { "name": "confidence" }, + { "name": "coordinates" } + ] + } } ] diff --git a/src/coreml.js b/src/coreml.js index 9aaf530a5dd..cb41c1c7f58 100644 --- a/src/coreml.js +++ b/src/coreml.js @@ -33,7 +33,7 @@ coreml.ModelFactory = class { host.exception(error, false); var message = error && error.message ? error.message : error.toString(); message = message.endsWith('.') ? message.substring(0, message.length - 1) : message; - return new coreml.Error(message + " in '" + identifier + "'."); + throw new coreml.Error(message + " in '" + identifier + "'."); } }); }); @@ -187,6 +187,7 @@ coreml.Graph = class { } _loadModel(model, scope, group) { + var i; this._groups = this._groups | (group.length > 0 ? true : false); var layer; if (model.neuralNetworkClassifier) { @@ -215,20 +216,20 @@ coreml.Graph = class { return 'Neural Network Regressor'; } else if (model.pipeline) { - for (layer of model.pipeline.models) { - this._loadModel(layer, scope, (group ? (group + '/') : '') + 'pipeline'); + for (i = 0; i < model.pipeline.models.length; i++) { + this._loadModel(model.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']'); } return 'Pipeline'; } else if (model.pipelineClassifier) { - for (layer of model.pipelineClassifier.pipeline.models) { - this._loadModel(layer, scope, (group ? (group + '/') : '') + 'pipelineClassifier'); + for (i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) { + this._loadModel(model.pipelineClassifier.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']'); } return 'Pipeline Classifier'; } else if (model.pipelineRegressor) { - for (layer of model.pipelineRegressor.pipeline.models) { - this._loadModel(layer, scope, (group ? (group + '/') : '') + 'pipelineRegressor'); + for (i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) { + this._loadModel(model.pipelineRegressor.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']'); } return 'Pipeline Regressor'; } @@ -361,7 +362,28 @@ coreml.Graph = class { [ model.description.output[0].name ]); return 'Text Classifier'; } - return 'Unknown'; + else if (model.nonMaximumSuppression) { + var nonMaximumSuppressionParams = { + pickTop: model.nonMaximumSuppression.pickTop, + stringClassLabels: model.nonMaximumSuppression.stringClassLabels, + iouThreshold: model.nonMaximumSuppression.iouThreshold, + confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold + }; + this._createNode(scope, group, 'nonMaximumSuppression', null, + nonMaximumSuppressionParams, + [ + model.nonMaximumSuppression.confidenceInputFeatureName, + model.nonMaximumSuppression.coordinatesInputFeatureName, + model.nonMaximumSuppression.iouThresholdInputFeatureName, + model.nonMaximumSuppression.confidenceThresholdInputFeatureName, + ], + [ + model.nonMaximumSuppression.confidenceOutputFeatureName, + model.nonMaximumSuppression.coordinatesOutputFeatureName + ]); + return 'Non Maximum Suppression'; + } + throw new coreml.Error("Unknown model type '" + Object.keys(model).filter(k => k != 'specificationVersion' && k != 'description').join(',') + "'."); } _createNode(scope, group, operator, name, data, inputs, outputs) { @@ -696,6 +718,9 @@ coreml.Node = class { data.modelParameterData = Array.from(data.modelParameterData); data.stringClassLabels = this._convertVector(data.stringClassLabels); return {}; + case 'nonMaximumSuppression': + data.stringClassLabels = this._convertVector(data.stringClassLabels); + return {}; } return {}; } diff --git a/test/models.json b/test/models.json index c8827131e44..54507e6becd 100644 --- a/test/models.json +++ b/test/models.json @@ -1738,6 +1738,20 @@ "format": "Core ML v1", "link": "https://github.com/UnusualWolf/coreML" }, + { + "type": "coreml", + "target": "YOLOv3.mlmodel", + "source": "https://docs-assets.developer.apple.com/coreml/models/Image/ObjectDetection/YOLOv3/YOLOv3.mlmodel", + "format": "Core ML v3", + "link": "https://developer.apple.com/machine-learning/models" + }, + { + "type": "coreml", + "target": "YOLOv3Tiny.mlmodel", + "source": "https://docs-assets.developer.apple.com/coreml/models/Image/ObjectDetection/YOLOv3Tiny/YOLOv3Tiny.mlmodel", + "format": "Core ML v3", + "link": "https://developer.apple.com/machine-learning/models" + }, { "type": "darknet", "target": "alexnet.cfg,alexnet.weights",