From aed22ba4539b3e1c7121ff65524789a80f1c0c26 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 6 Oct 2022 20:01:23 -0700 Subject: [PATCH] Update coreml.js (#987) --- source/coreml-metadata.json | 4 +- source/coreml.js | 343 +++++++++++++++--------------------- source/view.js | 15 ++ 3 files changed, 162 insertions(+), 200 deletions(-) diff --git a/source/coreml-metadata.json b/source/coreml-metadata.json index 81d93d68a8..f6b13684da 100644 --- a/source/coreml-metadata.json +++ b/source/coreml-metadata.json @@ -82,7 +82,7 @@ "category": "Tensor", "description": "A layer that concatenates along the channel axis (default) or sequence axis.", "inputs": [ - { "name": "inputs", "option": "variadic" } + { "name": "inputs", "type": "Tensor[]" } ] }, { @@ -127,7 +127,7 @@ { "name": "featureVectorizer", "inputs": [ - { "name": "inputs", "option": "variadic" } + { "name": "inputs", "type": "Tensor[]" } ] }, { diff --git a/source/coreml.js b/source/coreml.js index e6b75c8167..07713ff8e1 100644 --- a/source/coreml.js +++ b/source/coreml.js @@ -66,7 +66,7 @@ coreml.ModelFactory = class { open(context, match) { return context.require('./coreml-proto').then(() => { - return coreml.Metadata.open(context).then((metadata) => { + return context.metadata('coreml-metadata.json').then((metadata) => { const openModel = (stream, context, path, format) => { let model = null; try { @@ -237,19 +237,52 @@ coreml.Graph = class { this._outputs = []; this._nodes = []; + const args = new Map(); + args.input = (name) => { + if (!args.has(name)) { + const argument = new coreml.Argument(name); + args.set(name, { counter: 0, argument: argument }); + } + return args.get(name).argument; + }; + args.output = (name) => { + if (args.has(name)) { + const value = args.get(name); + value.counter++; + const next = name + '\n' + value.counter.toString(); // custom argument id + value.argument = new coreml.Argument(next); + } + else { + const argument = new coreml.Argument(name); + const value = { counter: 0, argument: argument }; + args.set(name, value); + } + return args.get(name).argument; + }; + const update = (argument, description) => { + if (!argument.type) { + argument.type = coreml.Utility.featureType(description.type); + } + if (!argument.description && description.shortDescription) { + argument.description = description.shortDescription; + } + return argument; + }; if (this._description) { this._inputs = this._description.input.map((input) => { - const argument = new coreml.Argument(input.name, coreml.Utility.featureType(input.type), input.shortDescription, null); + const argument = args.output(input.name); + update(argument, input); return new coreml.Parameter(input.name, true, [ argument ]); }); - + } + this._type = this._loadModel(model, args, '', weights); + if (this._description) { this._outputs = this._description.output.map((output) => { - const argument = new coreml.Argument(output.name, coreml.Utility.featureType(output.type), output.shortDescription, null); + const argument = args.input(output.name); + update(argument, output); return new coreml.Parameter(output.name, true, [ argument ]); }); } - - this._type = this._loadModel(model, {}, '', weights); } get name() { @@ -289,7 +322,7 @@ coreml.Graph = class { return newName; } - _updateClassifierOutput(group, classifier) { + _updateClassifierOutput(args, group, classifier) { let labelProbabilityLayerName = classifier.labelProbabilityLayerName; if (!labelProbabilityLayerName && this._nodes.length > 0) { const node = this._nodes.slice(-1).pop(); @@ -308,15 +341,15 @@ coreml.Graph = class { new coreml.Parameter('input', true, [ new coreml.Argument(labelProbabilityInput) ]) ]; const outputs = [ - new coreml.Parameter('probabilities', true, [ new coreml.Argument(predictedProbabilitiesName) ]), - new coreml.Parameter('feature', true, [ new coreml.Argument(predictedFeatureName) ]) + new coreml.Parameter('probabilities', true, [ args.output(predictedProbabilitiesName) ]), + new coreml.Parameter('feature', true, [ args.output(predictedFeatureName) ]) ]; const node = new coreml.Node(this._metadata, this._group, type, null, '', classifier[type], inputs, outputs); this._nodes.push(node); } } - _updatePreprocessing(scope, group, preprocessing) { + _updatePreprocessing(args, group, preprocessing) { if (preprocessing && preprocessing.length > 0) { const preprocessingInput = this._description.input[0].name; const inputNodes = []; @@ -325,19 +358,21 @@ coreml.Graph = class { inputNodes.push(node); } } - let preprocessorOutput = preprocessingInput; + let currentOutput = preprocessingInput; + let preprocessorOutput = null; let preprocessorIndex = 0; for (const p of preprocessing) { - const input = p.featureName ? p.featureName : preprocessorOutput; - preprocessorOutput = preprocessingInput + ':' + preprocessorIndex.toString(); - this._createNode(scope, group, p.preprocessor, null, '', p[p.preprocessor], [ input ], [ preprocessorOutput ]); + const input = p.featureName ? p.featureName : currentOutput; + currentOutput = preprocessingInput + ':' + preprocessorIndex.toString(); + const node = this._createNode(args, group, p.preprocessor, null, '', p[p.preprocessor], [ input ], [ currentOutput ]); + preprocessorOutput = node.outputs[0].arguments[0]; preprocessorIndex++; } for (const node of inputNodes) { for (const input of node.inputs) { - for (const arg of input.arguments) { - if (arg.name === preprocessingInput) { - arg.name = preprocessorOutput; + for (let i = 0; i < input.arguments.length; i++) { + if (input.arguments[i].name === preprocessingInput) { + input.arguments[i] = preprocessorOutput; } } } @@ -345,55 +380,55 @@ coreml.Graph = class { } } - _loadModel(model, scope, group, weights) { + _loadModel(model, args, group, weights) { this._groups = this._groups | (group.length > 0 ? true : false); const description = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : ''; switch (model.Type) { case 'neuralNetworkClassifier': { const neuralNetworkClassifier = model.neuralNetworkClassifier; for (const layer of neuralNetworkClassifier.layers) { - this._createNode(scope, group, layer.layer, layer.name, group === '' ? '' : description, layer[layer.layer], layer.input, layer.output); + this._createNode(args, group, layer.layer, layer.name, group === '' ? '' : description, layer[layer.layer], layer.input, layer.output, layer.inputTensor, layer.outputTensor); } - this._updateClassifierOutput(group, neuralNetworkClassifier); - this._updatePreprocessing(scope, group, neuralNetworkClassifier.preprocessing); + this._updateClassifierOutput(args, group, neuralNetworkClassifier); + this._updatePreprocessing(args, group, neuralNetworkClassifier.preprocessing); return 'Neural Network Classifier'; } case 'neuralNetwork': { const neuralNetwork = model.neuralNetwork; for (const layer of neuralNetwork.layers) { - this._createNode(scope, group, layer.layer, layer.name, group === '' ? '' : description, layer[layer.layer], layer.input, layer.output); + this._createNode(args, group, layer.layer, layer.name, group === '' ? '' : description, layer[layer.layer], layer.input, layer.output, layer.inputTensor, layer.outputTensor); } - this._updatePreprocessing(scope, group, neuralNetwork.preprocessing); + this._updatePreprocessing(args, group, neuralNetwork.preprocessing); return 'Neural Network'; } case 'neuralNetworkRegressor': { const neuralNetworkRegressor = model.neuralNetworkRegressor; for (const layer of neuralNetworkRegressor.layers) { - this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output); + this._createNode(args, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output); } - this._updatePreprocessing(scope, group, neuralNetworkRegressor); + this._updatePreprocessing(args, group, neuralNetworkRegressor); return 'Neural Network Regressor'; } case 'pipeline': { for (let i = 0; i < model.pipeline.models.length; i++) { - this._loadModel(model.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']'); + this._loadModel(model.pipeline.models[i], args, (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']'); } return 'Pipeline'; } case 'pipelineClassifier': { for (let i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) { - this._loadModel(model.pipelineClassifier.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']'); + this._loadModel(model.pipelineClassifier.pipeline.models[i], args, (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']'); } return 'Pipeline Classifier'; } case 'pipelineRegressor': { for (let i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) { - this._loadModel(model.pipelineRegressor.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']'); + this._loadModel(model.pipelineRegressor.pipeline.models[i], args, (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']'); } return 'Pipeline Regressor'; } case 'glmClassifier': { - this._createNode(scope, group, 'glmClassifier', null, description, + this._createNode(args, group, 'glmClassifier', null, description, { classEncoding: model.glmClassifier.classEncoding, offset: model.glmClassifier.offset, @@ -401,33 +436,33 @@ coreml.Graph = class { }, [ model.description.input[0].name ], [ model.description.predictedProbabilitiesName ]); - this._updateClassifierOutput(group, model.glmClassifier); + this._updateClassifierOutput(args, group, model.glmClassifier); return 'Generalized Linear Classifier'; } case 'glmRegressor': { - this._createNode(scope, group, 'glmRegressor', null, description, + this._createNode(args, group, 'glmRegressor', null, description, model.glmRegressor, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Generalized Linear Regressor'; } case 'treeEnsembleClassifier': { - this._createNode(scope, group, 'treeEnsembleClassifier', null, description, + this._createNode(args, group, 'treeEnsembleClassifier', null, description, model.treeEnsembleClassifier.treeEnsemble, [ model.description.input[0].name ], [ model.description.output[0].name ]); - this._updateClassifierOutput(group, model.treeEnsembleClassifier); + this._updateClassifierOutput(args, group, model.treeEnsembleClassifier); return 'Tree Ensemble Classifier'; } case 'treeEnsembleRegressor': { - this._createNode(scope, group, 'treeEnsembleRegressor', null, description, + this._createNode(args, group, 'treeEnsembleRegressor', null, description, model.treeEnsembleRegressor.treeEnsemble, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Tree Ensemble Regressor'; } case 'supportVectorClassifier': { - this._createNode(scope, group, 'supportVectorClassifier', null, description, + this._createNode(args, group, 'supportVectorClassifier', null, description, { coefficients: model.supportVectorClassifier.coefficients, denseSupportVectors: model.supportVectorClassifier.denseSupportVectors, @@ -440,11 +475,11 @@ coreml.Graph = class { }, [ model.description.input[0].name ], [ model.description.output[0].name ]); - this._updateClassifierOutput(group, model.supportVectorClassifier); + this._updateClassifierOutput(args, group, model.supportVectorClassifier); return 'Support Vector Classifier'; } case 'supportVectorRegressor': { - this._createNode(scope, group, 'supportVectorRegressor', null, description, + this._createNode(args, group, 'supportVectorRegressor', null, description, { coefficients: model.supportVectorRegressor.coefficients, kernel: model.supportVectorRegressor.kernel, @@ -459,7 +494,7 @@ coreml.Graph = class { const categoryType = model.oneHotEncoder.CategoryType; const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse }; oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType]; - this._createNode(scope, group, 'oneHotEncoder', null, description, + this._createNode(args, group, 'oneHotEncoder', null, description, oneHotEncoderParams, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -471,49 +506,49 @@ coreml.Graph = class { const imputerParams = {}; imputerParams[imputedValue] = model.imputer[imputedValue]; imputerParams[replaceValue] = model.imputer[replaceValue]; - this._createNode(scope, group, 'oneHotEncoder', null, description, + this._createNode(args, group, 'oneHotEncoder', null, description, imputerParams, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Imputer'; } case 'featureVectorizer': { - this._createNode(scope, group, 'featureVectorizer', null, description, + this._createNode(args, group, 'featureVectorizer', null, description, model.featureVectorizer, coreml.Graph._formatFeatureDescriptionList(model.description.input), [ model.description.output[0].name ]); return 'Feature Vectorizer'; } case 'dictVectorizer': { - this._createNode(scope, group, 'dictVectorizer', null, description, + this._createNode(args, group, 'dictVectorizer', null, description, model.dictVectorizer, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Dictionary Vectorizer'; } case 'scaler': { - this._createNode(scope, group, 'scaler', null, description, + this._createNode(args, group, 'scaler', null, description, model.scaler, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Scaler'; } case 'categoricalMapping': { - this._createNode(scope, group, 'categoricalMapping', null, description, + this._createNode(args, group, 'categoricalMapping', null, description, model.categoricalMapping, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Categorical Mapping'; } case 'normalizer': { - this._createNode(scope, group, 'normalizer', null, description, + this._createNode(args, group, 'normalizer', null, description, model.normalizer, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Normalizer'; } case 'arrayFeatureExtractor': { - this._createNode(scope, group, 'arrayFeatureExtractor', null, description, + this._createNode(args, group, 'arrayFeatureExtractor', null, description, { extractIndex: model.arrayFeatureExtractor.extractIndex }, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -526,7 +561,7 @@ coreml.Graph = class { iouThreshold: model.nonMaximumSuppression.iouThreshold, confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold }; - this._createNode(scope, group, 'nonMaximumSuppression', null, description, + this._createNode(args, group, 'nonMaximumSuppression', null, description, nonMaximumSuppressionParams, [ model.nonMaximumSuppression.confidenceInputFeatureName, @@ -541,7 +576,7 @@ coreml.Graph = class { return 'Non Maximum Suppression'; } case 'wordTagger': { - this._createNode(scope, group, 'wordTagger', null, description, + this._createNode(args, group, 'wordTagger', null, description, model.wordTagger, [ model.description.input[0].name ], [ @@ -553,7 +588,7 @@ coreml.Graph = class { return 'Word Tagger'; } case 'textClassifier': { - this._createNode(scope, group, 'textClassifier', null, description, + this._createNode(args, group, 'textClassifier', null, description, model.textClassifier, [ model.description.input[0].name ], [ model.description.output[0].name ]); @@ -563,29 +598,29 @@ coreml.Graph = class { const visionFeaturePrintParams = { scene: model.visionFeaturePrint.scene }; - this._createNode(scope, group, 'visionFeaturePrint', null, description, + this._createNode(args, group, 'visionFeaturePrint', null, description, visionFeaturePrintParams, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Vision Feature Print'; } case 'soundAnalysisPreprocessing': { - this._createNode(scope, group, 'soundAnalysisPreprocessing', null, description, + this._createNode(args, group, 'soundAnalysisPreprocessing', null, description, model.soundAnalysisPreprocessing, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Sound Analysis Preprocessing'; } case 'kNearestNeighborsClassifier': { - this._createNode(scope, group, 'kNearestNeighborsClassifier', null, description, + this._createNode(args, group, 'kNearestNeighborsClassifier', null, description, model.kNearestNeighborsClassifier, [ model.description.input[0].name ], [ model.description.output[0].name ]); - this._updateClassifierOutput(group, model.kNearestNeighborsClassifier); + this._updateClassifierOutput(args, group, model.kNearestNeighborsClassifier); return 'Nearest Neighbors Classifier'; } case 'itemSimilarityRecommender': { - this._createNode(scope, group, 'itemSimilarityRecommender', null, description, + this._createNode(args, group, 'itemSimilarityRecommender', null, description, { itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector, itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities @@ -595,28 +630,28 @@ coreml.Graph = class { return 'Item Similarity Recommender'; } case 'audioFeaturePrint': { - this._createNode(scope, group, 'audioFeaturePrint', null, description, + this._createNode(args, group, 'audioFeaturePrint', null, description, model.audioFeaturePrint, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Audio Feature Print'; } case 'linkedModel': { - this._createNode(scope, group, 'linkedModel', null, description, + this._createNode(args, group, 'linkedModel', null, description, model.linkedModel.linkedModelFile, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'Linked Model'; } case 'customModel': { - this._createNode(scope, group, 'customModel', null, description, + this._createNode(args, group, 'customModel', null, description, { className: model.customModel.className, parameters: model.customModel.parameters }, [ model.description.input[0].name ], [ model.description.output[0].name ]); return 'customModel'; } case 'mlProgram': { - return this._loadProgram(model.mlProgram, scope, group, weights); + return this._loadProgram(model.mlProgram, args, group, weights); } default: { throw new coreml.Error("Unsupported model type '" + JSON.stringify(Object.keys(model)) + "'."); @@ -624,7 +659,7 @@ coreml.Graph = class { } } - _loadProgram(program, scope, group, weights) { + _loadProgram(program, _, group, weights) { // TODO: need to handle functions other than main? const main = program.functions.main; // TODO: need to handle more than one block specialization? @@ -832,20 +867,26 @@ coreml.Graph = class { return 'ML Program'; } - _createNode(scope, group, type, name, description, data, inputs, outputs, outputTypes) { - inputs = inputs.map((input) => scope[input] ? scope[input].argument : input); - outputs = outputs.map((output) => { - if (scope[output]) { - scope[output].counter++; - const next = output + '\n' + scope[output].counter.toString(); // custom argument id - scope[output].argument = next; - return next; + _createNode(args, group, type, name, description, data, inputs, outputs, inputTensors, outputTensors) { + + inputs = inputs.map((input, index) => { + const argument = args.input(input); + if (!argument.type && inputTensors && index < inputTensors.length) { + const tensor = inputTensors[index]; + const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null; + argument.type = new coreml.TensorType('?', shape); } - scope[output] = { - argument: output, - counter: 0 - }; - return output; + return argument; + }); + + outputs = outputs.map((output, index) => { + const argument = args.output(output); + if (!argument.type && outputTensors && index < outputTensors.length) { + const tensor = outputTensors[index]; + const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null; + argument.type = new coreml.TensorType('?', shape); + } + return argument; }); const initializers = []; @@ -859,19 +900,29 @@ coreml.Graph = class { attributes[key] = data[key]; } } - const inputParameters = this._metadata.getInputs(type, inputs).map((input) => { - return new coreml.Parameter(input.name, true, input.arguments.map((argument) => { - return new coreml.Argument(argument.name, argument.type, null, null); - })); - }); - inputParameters.push(...initializers); - const outputParameters = outputs.map((output, index) => { - const name = this._metadata.getOutputName(type, index); - const outputType = outputTypes ? outputTypes[index] : null; - return new coreml.Parameter(name, true, [ new coreml.Argument(output, outputType, null, null) ]); - }); - const node = new coreml.Node(this._metadata, group, type, name, description, attributes, inputParameters, outputParameters); + const metadata = this._metadata.type(type); + const inputParams = []; + for (let i = 0; i < inputs.length; ) { + const input = metadata && metadata.inputs && i < metadata.inputs.length ? metadata.inputs[i] : { name: i === 0 ? 'input' : i.toString() }; + const count = input.type === 'Tensor[]' ? inputs.length - i : 1; + const args = inputs.slice(i, i + count); + inputParams.push(new coreml.Parameter(input.name, true, args)); + i += count; + } + + inputParams.push(...initializers); + + const outputParams = []; + for (let i = 0; i < outputs.length; ) { + const output = metadata && metadata.outputs && i < metadata.outputs.length ? metadata.outputs[i] : { name: i === 0 ? 'output' : i.toString() }; + const count = output.type === 'Tensor[]' ? outputs.length - i : 1; + const args = outputs.slice(i, i + count); + outputParams.push(new coreml.Parameter(output.name, true, args)); + i += count; + } + + const node = new coreml.Node(this._metadata, group, type, name, description, attributes, inputParams, outputParams); this._nodes.push(node); return node; } @@ -903,7 +954,8 @@ coreml.Graph = class { const tensorType = new coreml.TensorType(dataType, new coreml.TensorShape(shape)); const tensor = new coreml.Tensor(kind, tensorType, values, quantization); const argument = new coreml.Argument('', null, null, tensor); - const visible = this._metadata.visible(type, name); + const input = this._metadata.input(type, name); + const visible = input && input.visible === false ? false : true; initializers.push(new coreml.Parameter(name, visible, [ argument ])); } @@ -1070,7 +1122,7 @@ coreml.Argument = class { throw new coreml.Error("Invalid argument identifier '" + JSON.stringify(name) + "'."); } this._name = name; - this._type = type; + this._type = type || null; this._description = description || null; this._initializer = initializer || null; } @@ -1090,10 +1142,18 @@ coreml.Argument = class { return this._type; } + set type(value) { + this._type = value; + } + get description() { return this._description; } + set description(value) { + this._description = value; + } + get quantization() { if (this._initializer) { return this._initializer.quantization; @@ -1505,119 +1565,6 @@ coreml.Utility = class { } }; -coreml.Metadata = class { - - static open(context) { - if (coreml.Metadata._metadata) { - return Promise.resolve(coreml.Metadata._metadata); - } - return context.request('coreml-metadata.json', 'utf-8', null).then((data) => { - coreml.Metadata._metadata = new coreml.Metadata(data); - return coreml.Metadata._metadata; - }).catch(() => { - coreml.Metadata._metadata = new coreml.Metadata(null); - return coreml.Metadata._metadata; - }); - } - - constructor(data) { - this._map = new Map(); - this._attributeCache = new Map(); - this._inputCache = new Map(); - if (data) { - const metadata = JSON.parse(data); - this._map = new Map(metadata.map((item) => [ item.name, item ])); - } - } - - type(name) { - return this._map.get(name); - } - - attribute(type, name) { - const key = type + ':' + name; - if (!this._attributeCache.has(key)) { - this._attributeCache.set(key, null); - const metadata = this.type(type); - if (metadata && Array.isArray(metadata.attributes) && metadata.attributes.length > 0) { - for (const attribute of metadata.attributes) { - this._attributeCache.set(type + ':' + attribute.name, attribute); - } - } - } - return this._attributeCache.get(key); - } - - visible(type, name) { - const key = type + ':' + name; - if (!this._inputCache.has(key)) { - this._inputCache.set(key, null); - const metadata = this.type(type); - if (metadata && Array.isArray(metadata.inputs) && metadata.inputs.length > 0) { - for (const input of metadata.inputs) { - this._inputCache.set(type + ':' + input.name, input); - } - } - } - const input = this._inputCache.get(key); - if (input) { - return input.visible === false ? false : true; - } - return true; - } - - getInputs(type, inputs) { - const results = []; - const schema = this._map.get(type); - let index = 0; - while (index < inputs.length) { - const result = { arguments: [] }; - let count = 1; - let name = null; - if (schema && schema.inputs) { - if (index < schema.inputs.length) { - const input = schema.inputs[index]; - name = input.name; - if (schema.inputs[index].option == 'variadic') { - count = inputs.length - index; - } - } - } - else if (index == 0) { - name = 'input'; - } - result.name = name ? name : '(' + index.toString() + ')'; - const array = inputs.slice(index, index + count); - for (let j = 0; j < array.length; j++) { - result.arguments.push({ name: array[j] }); - } - index += count; - results.push(result); - } - return results; - } - - getOutputName(type, index) { - const schema = this._map.get(type); - if (schema) { - const outputs = schema.outputs; - if (outputs && index < outputs.length) { - const output = outputs[index]; - if (output) { - const name = output.name; - if (name) { - return name; - } - } - } - } - if (index == 0) { - return 'output'; - } - return '(' + index.toString() + ')'; - } -}; - coreml.Error = class extends Error { constructor(message) { super(message); diff --git a/source/view.js b/source/view.js index 12537d1ecc..699c6c1e7d 100644 --- a/source/view.js +++ b/source/view.js @@ -2143,6 +2143,7 @@ view.Metadata = class { constructor(data) { this._types = new Map(); this._attributes = new Map(); + this._inputs = new Map(); if (data) { const metadata = JSON.parse(data); for (const entry of metadata) { @@ -2174,6 +2175,20 @@ view.Metadata = class { } return this._attributes.get(key); } + + input(type, name) { + const key = type + ':' + name; + if (!this._inputs.has(key)) { + this._inputs.set(key, null); + const metadata = this.type(type); + if (metadata && Array.isArray(metadata.inputs)) { + for (const input of metadata.inputs) { + this._inputs.set(type + ':' + input.name, input); + } + } + } + return this._inputs.get(key); + } }; view.Error = class extends Error {