Skip to content

Commit

Permalink
Add Core ML test file (#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 13, 2023
1 parent 7b9a6be commit 31a487d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 45 deletions.
124 changes: 79 additions & 45 deletions source/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ coreml.Model = class {
constructor(metadata, format, model, weights) {
this.format = (format || 'Core ML') + ' v' + model.specificationVersion.toString();
this.metadata = [];
this.graphs = [ new coreml.Graph(metadata, model, weights) ];
const context = new coreml.Context(metadata, model, weights);
const graph = new coreml.Graph(context);
this.graphs = [ graph ];
if (model.description && model.description.metadata) {
const properties = model.description.metadata;
if (properties.versionString) {
Expand All @@ -194,29 +196,34 @@ coreml.Model = class {

coreml.Graph = class {

constructor(metadata, model, weights) {
const transformer = new coreml.Transformer(metadata, weights);
constructor(context) {
this.name = '';
this.type = transformer.transform(model, '');
this.groups = transformer.groups;
for (const value of transformer.values.values()) {
this.type = context.type;
this.groups = context.groups;
for (const value of context.values.values()) {
const name = value.name || '';
const type = value.type || null;
const description = value.description || null;
const initializer = value.initializer || null;
value.obj = new coreml.Value(name, type, description, initializer);
}
this.inputs = transformer.inputs.map((argument) => {
this.inputs = context.inputs.map((argument) => {
const values = argument.value.map((value) => value.obj);
return new coreml.Argument(argument.name, argument.visible, values);
});
this.outputs = transformer.outputs.map((argument) => {
this.outputs = context.outputs.map((argument) => {
const values = argument.value.map((value) => value.obj);
return new coreml.Argument(argument.name, argument.visible, values);
});
this.nodes = transformer.nodes.map((obj) => {
return new coreml.Node(metadata, obj);
});
/*
for (const obj of context.nodes) {
if (obj.type === 'loop') {
obj.attributes.conditionNetwork = new coreml.Graph(obj.attributes.conditionNetwork);
obj.attributes.bodyNetwork = new coreml.Graph(obj.attributes.bodyNetwork);
}
}
*/
this.nodes = context.nodes.map((obj) => new coreml.Node(context, obj));
}
};

Expand Down Expand Up @@ -245,14 +252,14 @@ coreml.Value = class {

coreml.Node = class {

constructor(metadata, obj) {
constructor(context, obj) {
if (!obj.type) {
throw new Error('Undefined node type.');
}
if (obj.group) {
this.group = obj.group || null;
}
this.type = Object.assign({}, metadata.type(obj.type) || { name: obj.type });
this.type = Object.assign({}, context.metadata.type(obj.type) || { name: obj.type });
this.type.name = obj.type.split(':').pop();
this.name = obj.name || '';
this.description = obj.description || '';
Expand All @@ -265,7 +272,8 @@ coreml.Node = class {
return new coreml.Argument(argument.name, argument.visible, values);
});
this.attributes = Object.entries(obj.attributes).map(([name, value]) => {
return new coreml.Attribute(metadata.attribute(obj.type, name), name, value);
const metadata = context.metadata.attribute(obj.type, name);
return new coreml.Attribute(metadata, name, value);
});
}
};
Expand Down Expand Up @@ -296,6 +304,9 @@ coreml.Attribute = class {
}
}
}
if (this.value instanceof coreml.Graph) {
this.type = 'graph';
}
}
};

Expand Down Expand Up @@ -481,25 +492,52 @@ coreml.ImageType = class {
coreml.OptionalType = class {

constructor(type) {
this._type = type;
}

get type() {
return this._type;
this.type = type;
}

toString() {
return 'optional<' + this._type.toString() + '>';
return 'optional<' + this.type.toString() + '>';
}
};

coreml.Transformer = class {
coreml.Context = class {

constructor(metadata, weights) {
constructor(metadata, model, weights, values) {
this.metadata = metadata;
this.weights = weights;
this.values = new Map();
this.values = values || new Map();
this.nodes = [];
this.inputs = [];
this.outputs = [];
if (model) {
const description = model.description;
const inputs = description && Array.isArray(description.input) ? description.input : [];
for (const description of inputs) {
const value = this.output(description.name);
this.update(value, description);
this.inputs.push({ name: description.name, visible: true, value: [ value ] });
}
this.type = this.model(model, '', description);
const outputs = description && Array.isArray(description.output) ? description.output : [];
for (const description of outputs) {
const value = this.input(description.name);
this.update(value, description);
this.outputs.push({ name: description.name, visible: true, value: [ value ] });
}
}
}

context() {
return new coreml.Context(this.metadata, null, this.weights, this.values);
}

network(obj) {
for (const layer of obj.layers) {
const type = layer.layer;
this.node(this.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
}
this.updatePreprocessing('', obj.preprocessing, null);
this.type = 'Neural Network';
}

input(name) {
Expand All @@ -513,12 +551,12 @@ coreml.Transformer = class {
if (!this.values.has(name)) {
const value = { counter: 0, name: name, to: [], from: [] };
this.values.set(name, value);
const key = name + '\n' + value.counter.toString();
const key = name + '|' + value.counter.toString();
this.values.set(key, value);
} else {
const value = Object.assign({}, this.values.get(name));
value.counter++;
value.name = name + '\n' + value.counter.toString(); // custom argument id
value.name = name + '|' + value.counter.toString(); // custom argument id
this.values.set(name, value);
this.values.set(value.name, value);
}
Expand Down Expand Up @@ -723,6 +761,17 @@ coreml.Transformer = class {
obj.attributes[name] = value;
}
}
/*
if (obj.type === 'loop') {
const network = (context, obj) => {
context = context.context();
context.network(obj);
return context;
};
obj.attributes.bodyNetwork = network(this, obj.attributes.bodyNetwork);
obj.attributes.conditionNetwork = network(this, obj.attributes.conditionNetwork);
}
*/
}
const metadata = this.metadata.type(type);
for (let i = 0; i < inputs.length;) {
Expand All @@ -744,32 +793,15 @@ coreml.Transformer = class {
return obj;
}

transform(model, group) {
const description = model.description;
const inputs = description && Array.isArray(description.input) ? description.input : [];
this.inputs = inputs.map((description) => {
const value = this.output(description.name);
this.update(value, description);
return { name: description.name, visible: true, value: [ value ] };
});
const type = this.model(model, group, description);
const outputs = description && Array.isArray(description.output) ? description.output : [];
this.outputs = outputs.map((output) => {
const value = this.input(output.name);
this.update(value, output);
return { name: output.name, visible: true, value: [ value ] };
});
return type;
}

model(model, group, description) {
this.groups = this.groups | (group.length > 0 ? true : false);
const shortDescription = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
switch (model.Type) {
case 'neuralNetworkClassifier': {
const neuralNetworkClassifier = model.neuralNetworkClassifier;
for (const layer of neuralNetworkClassifier.layers) {
this.node(group, layer.layer, layer.name, group === '' ? '' : shortDescription, layer[layer.layer], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
const type = layer.layer;
this.node(group, type, layer.name, group === '' ? '' : shortDescription, layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
}
this.updateClassifierOutput(group, neuralNetworkClassifier, description);
this.updatePreprocessing(group, neuralNetworkClassifier.preprocessing, description);
Expand Down Expand Up @@ -817,7 +849,7 @@ coreml.Transformer = class {
weights: model.glmClassifier.weights
},
[ model.description.input[0].name ],
[ model.description.predictedProbabilitiesName ]);
[ model.description.output[0].name ]);
this.updateClassifierOutput(group, model.glmClassifier, description);
return 'Generalized Linear Classifier';
}
Expand Down Expand Up @@ -1066,6 +1098,8 @@ coreml.Transformer = class {
}
}
}
this.values.set(labelProbabilityInput, this.values.get(labelProbabilityLayerName));
this.values.delete(labelProbabilityLayerName);
const type = classifier.ClassLabels;
const node = {
// group: this._group,
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,13 @@
"format": "Core ML v2",
"link": "https://github.com/kingreza/quantization"
},
{
"type": "coreml",
"target": "FrameGRUModel.mlpackage.zip",
"source": "https://github.com/lutzroeder/netron/files/13643385/FrameGRUModel.mlpackage.zip",
"format": "Core ML Package v4",
"link": "https://github.com/lutzroeder/netron/issues/193"
},
{
"type": "coreml",
"target": "GoogLeNetPlaces.mlmodel",
Expand Down

0 comments on commit 31a487d

Please sign in to comment.