diff --git a/src/coreml.js b/src/coreml.js index 0c8c8ec5ca..17810f356a 100644 --- a/src/coreml.js +++ b/src/coreml.js @@ -165,7 +165,7 @@ coreml.Graph = class { predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?'; let labelProbabilityInput = this._updateOutput(labelProbabilityLayerName, labelProbabilityLayerName + ':labelProbabilityLayerName'); let operator = classifier.ClassLabels; - this._nodes.push(new coreml.Node(this._metadata, this._group, operator, null, classifier[operator], [ labelProbabilityInput ], [ predictedProbabilitiesName, predictedFeatureName ])); + this._nodes.push(new coreml.Node(this._metadata, this._group, operator, null, '', classifier[operator], [ labelProbabilityInput ], [ predictedProbabilitiesName, predictedFeatureName ])); } } @@ -183,7 +183,7 @@ coreml.Graph = class { for (const p of preprocessing) { let input = p.featureName ? p.featureName : preprocessorOutput; preprocessorOutput = preprocessingInput + ':' + preprocessorIndex.toString(); - this._createNode(scope, group, p.preprocessor, null, p[p.preprocessor], [ input ], [ preprocessorOutput ]); + this._createNode(scope, group, p.preprocessor, null, '', p[p.preprocessor], [ input ], [ preprocessorOutput ]); preprocessorIndex++; } for (const node of inputNodes) { @@ -200,10 +200,11 @@ coreml.Graph = class { _loadModel(model, scope, group) { this._groups = this._groups | (group.length > 0 ? true : false); + const description = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : ''; if (model.neuralNetworkClassifier) { const neuralNetworkClassifier = model.neuralNetworkClassifier; for (const layer of neuralNetworkClassifier.layers) { - this._createNode(scope, group, layer.layer, layer.name, layer[layer.layer], layer.input, layer.output); + this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output); } this._updateClassifierOutput(group, neuralNetworkClassifier); this._updatePreprocessing(scope, group, neuralNetworkClassifier.preprocessing); @@ -212,7 +213,7 @@ coreml.Graph = class { else if (model.neuralNetwork) { const neuralNetwork = model.neuralNetwork; for (const layer of neuralNetwork.layers) { - this._createNode(scope, group, layer.layer, layer.name, layer[layer.layer], layer.input, layer.output); + this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output); } this._updatePreprocessing(scope, group, neuralNetwork.preprocessing); return 'Neural Network'; @@ -220,7 +221,7 @@ coreml.Graph = class { else if (model.neuralNetworkRegressor) { const neuralNetworkRegressor = model.neuralNetworkRegressor; for (const layer of neuralNetworkRegressor.layers) { - this._createNode(scope, group, layer.layer, layer.name, layer[layer.layer], layer.input, layer.output); + this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output); } this._updatePreprocessing(scope, group, neuralNetworkRegressor); return 'Neural Network Regressor'; @@ -244,7 +245,7 @@ coreml.Graph = class { return 'Pipeline Regressor'; } else if (model.glmClassifier) { - this._createNode(scope, group, 'glmClassifier', null, + this._createNode(scope, group, 'glmClassifier', null, description, { classEncoding: model.glmClassifier.classEncoding, offset: model.glmClassifier.offset, @@ -256,39 +257,43 @@ coreml.Graph = class { return 'Generalized Linear Classifier'; } else if (model.glmRegressor) { - this._createNode(scope, group, 'glmRegressor', null, + this._createNode(scope, group, 'glmRegressor', null, description, model.glmRegressor, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Generalized Linear Regressor'; } else if (model.dictVectorizer) { - this._createNode(scope, group, 'dictVectorizer', null, model.dictVectorizer, + this._createNode(scope, group, 'dictVectorizer', null, description, + model.dictVectorizer, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Dictionary Vectorizer'; } else if (model.featureVectorizer) { - this._createNode(scope, group, 'featureVectorizer', null, model.featureVectorizer, + this._createNode(scope, group, 'featureVectorizer', null, description, + model.featureVectorizer, coreml.Graph._formatFeatureDescriptionList(model.description.input), [ model.description.output[0].name ]); return 'Feature Vectorizer'; } else if (model.treeEnsembleClassifier) { - this._createNode(scope, group, 'treeEnsembleClassifier', null, model.treeEnsembleClassifier.treeEnsemble, + this._createNode(scope, group, 'treeEnsembleClassifier', null, description, + model.treeEnsembleClassifier.treeEnsemble, [ model.description.input[0].name ], [ model.description.output[0].name ]); this._updateClassifierOutput(group, model.treeEnsembleClassifier); return 'Tree Ensemble Classifier'; } else if (model.treeEnsembleRegressor) { - this._createNode(scope, group, 'treeEnsembleRegressor', null, model.treeEnsembleRegressor.treeEnsemble, + this._createNode(scope, group, 'treeEnsembleRegressor', null, description, + model.treeEnsembleRegressor.treeEnsemble, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Tree Ensemble Regressor'; } else if (model.supportVectorClassifier) { - this._createNode(scope, group, 'supportVectorClassifier', null, + this._createNode(scope, group, 'supportVectorClassifier', null, description, { coefficients: model.supportVectorClassifier.coefficients, denseSupportVectors: model.supportVectorClassifier.denseSupportVectors, @@ -305,7 +310,7 @@ coreml.Graph = class { return 'Support Vector Classifier'; } else if (model.supportVectorRegressor) { - this._createNode(scope, group, 'supportVectorRegressor', null, + this._createNode(scope, group, 'supportVectorRegressor', null, description, { coefficients: model.supportVectorRegressor.coefficients, kernel: model.supportVectorRegressor.kernel, @@ -317,7 +322,7 @@ coreml.Graph = class { return 'Support Vector Regressor'; } else if (model.arrayFeatureExtractor) { - this._createNode(scope, group, 'arrayFeatureExtractor', null, + this._createNode(scope, group, 'arrayFeatureExtractor', null, description, { extractIndex: model.arrayFeatureExtractor.extractIndex }, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -327,7 +332,7 @@ coreml.Graph = class { const categoryType = model.oneHotEncoder.CategoryType; const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse }; oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType]; - this._createNode(scope, group, 'oneHotEncoder', null, + this._createNode(scope, group, 'oneHotEncoder', null, description, oneHotEncoderParams, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -339,7 +344,7 @@ coreml.Graph = class { let imputerParams = {}; imputerParams[imputedValue] = model.imputer[imputedValue]; imputerParams[replaceValue] = model.imputer[replaceValue]; - this._createNode(scope, group, 'oneHotEncoder', null, + this._createNode(scope, group, 'oneHotEncoder', null, description, imputerParams, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -347,14 +352,14 @@ coreml.Graph = class { } else if (model.normalizer) { - this._createNode(scope, group, 'normalizer', null, + this._createNode(scope, group, 'normalizer', null, description, model.normalizer, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Normalizer'; } else if (model.wordTagger) { - this._createNode(scope, group, 'wordTagger', null, + this._createNode(scope, group, 'wordTagger', null, description, model.wordTagger, [ model.description.input[0].name ], [ @@ -366,7 +371,7 @@ coreml.Graph = class { return 'Word Tagger'; } else if (model.textClassifier) { - this._createNode(scope, group, 'textClassifier', null, + this._createNode(scope, group, 'textClassifier', null, description, model.textClassifier, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -379,7 +384,7 @@ coreml.Graph = class { iouThreshold: model.nonMaximumSuppression.iouThreshold, confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold }; - this._createNode(scope, group, 'nonMaximumSuppression', null, + this._createNode(scope, group, 'nonMaximumSuppression', null, description, nonMaximumSuppressionParams, [ model.nonMaximumSuppression.confidenceInputFeatureName, @@ -397,21 +402,21 @@ coreml.Graph = class { const visionFeaturePrintParams = { scene: model.visionFeaturePrint.scene } - this._createNode(scope, group, 'visionFeaturePrint', null, + this._createNode(scope, group, 'visionFeaturePrint', null, description, visionFeaturePrintParams, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Vision Feature Print'; } else if (model.soundAnalysisPreprocessing) { - this._createNode(scope, group, 'soundAnalysisPreprocessing', null, + this._createNode(scope, group, 'soundAnalysisPreprocessing', null, description, model.soundAnalysisPreprocessing, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Sound Analysis Preprocessing'; } else if (model.kNearestNeighborsClassifier) { - this._createNode(scope, group, 'kNearestNeighborsClassifier', null, + this._createNode(scope, group, 'kNearestNeighborsClassifier', null, description, model.kNearestNeighborsClassifier, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -419,18 +424,24 @@ coreml.Graph = class { return 'kNearestNeighborsClassifier'; } else if (model.itemSimilarityRecommender) { - const itemSimilarityRecommenderParams = { - itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector, - itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities - } - this._createNode(scope, group, 'itemSimilarityRecommender', null, - itemSimilarityRecommenderParams, + this._createNode(scope, group, 'itemSimilarityRecommender', null, description, + { + itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector, + itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities + }, model.description.input.map((feature) => feature.name), model.description.output.map((feature) => feature.name)); return 'itemSimilarityRecommender' } + else if (model.linkedModel) { + this._createNode(scope, group, 'linkedModel', null, description, + model.linkedModel.linkedModelFile, + [ model.description.input[0].name ], + [ model.description.output[0].name ]); + return 'customModel'; + } else if (model.customModel) { - this._createNode(scope, group, 'customModel', null, + this._createNode(scope, group, 'customModel', null, description, { className: model.customModel.className, parameters: model.customModel.parameters }, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -439,7 +450,7 @@ coreml.Graph = class { throw new coreml.Error("Unknown model type '" + JSON.stringify(Object.keys(model)) + "'."); } - _createNode(scope, group, operator, name, data, inputs, outputs) { + _createNode(scope, group, operator, name, description, data, inputs, outputs) { inputs = inputs.map((input) => scope[input] ? scope[input].argument : input); outputs = outputs.map((output) => { if (scope[output]) { @@ -455,7 +466,7 @@ coreml.Graph = class { return output; }); - const node = new coreml.Node(this._metadata, group, operator, name, data, inputs, outputs); + const node = new coreml.Node(this._metadata, group, operator, name, description, data, inputs, outputs); this._nodes.push(node); return node; } @@ -583,13 +594,14 @@ coreml.Argument = class { coreml.Node = class { - constructor(metadata, group, operator, name, data, inputs, outputs) { + constructor(metadata, group, operator, name, description, data, inputs, outputs) { this._metadata = metadata; if (group) { this._group = group; } this._operator = operator; this._name = name || ''; + this._description = description || ''; this._attributes = []; let initializers = []; if (data) { @@ -620,6 +632,10 @@ coreml.Node = class { return this._name; } + get description() { + return this._description; + } + get metadata() { return this._metadata.type(this.operator); } diff --git a/test/models.json b/test/models.json index 488ee1eeee..42c1778f19 100644 --- a/test/models.json +++ b/test/models.json @@ -1459,6 +1459,13 @@ "format": "Core ML v1", "link": "https://github.com/gavi/Iris" }, + { + "type": "coreml", + "target": "LinkedUpdatableTinyDrawingClassifier.mlmodel", + "source": "https://github.com/lutzroeder/netron/files/4500539/LinkedUpdatableTinyDrawingClassifier.zip[LinkedUpdatableTinyDrawingClassifier.mlmodel]", + "format": "Core ML v4", + "link": "https://github.com/lutzroeder/netron/issues/193" + }, { "type": "coreml", "target": "MessageClassifier.mlmodel",