Skip to content

Commit

Permalink
Update coreml.js (#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 13, 2023
1 parent 060aa4b commit 0b978dc
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions source/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,13 @@ coreml.Graph = class {
this.type = context.type;
this.groups = context.groups;
for (const value of context.values.values()) {
const name = value.name || '';
const type = value.type || null;
const description = value.description || null;
const initializer = value.initializer || null;
value.obj = new coreml.Value(name, type, description, initializer);
const name = value.name;
const type = value.type;
const description = value.description;
const initializer = value.initializer;
if (!value.obj) {
value.obj = new coreml.Value(name, type, description, initializer);
}
}
this.inputs = context.inputs.map((argument) => {
const values = argument.value.map((value) => value.obj);
Expand All @@ -237,14 +239,21 @@ coreml.Graph = class {
const values = argument.value.map((value) => value.obj);
return new coreml.Argument(argument.name, argument.visible, values);
});
/*
for (const obj of context.nodes) {
if (obj.type === 'loop') {
obj.attributes.conditionNetwork = new coreml.Graph(obj.attributes.conditionNetwork);
obj.attributes.bodyNetwork = new coreml.Graph(obj.attributes.bodyNetwork);
const attributes = obj.attributes;
switch (obj.type) {
case 'loop':
attributes.conditionNetwork = new coreml.Graph(attributes.conditionNetwork);
attributes.bodyNetwork = new coreml.Graph(attributes.bodyNetwork);
break;
case 'branch':
attributes.ifBranch = new coreml.Graph(attributes.ifBranch);
attributes.elseBranch = new coreml.Graph(attributes.elseBranch);
break;
default:
break;
}
}
*/
this.nodes = context.nodes.map((obj) => new coreml.Node(context, obj));
}
};
Expand Down Expand Up @@ -554,12 +563,14 @@ coreml.Context = class {
}

network(obj) {
const context = this.context();
for (const layer of obj.layers) {
const type = layer.layer;
this.node(this.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
context.node(context.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
}
this.updatePreprocessing('', obj.preprocessing, null);
this.type = 'Neural Network';
context.updatePreprocessing('', obj.preprocessing, null);
context.type = 'Neural Network';
return context;
}

input(name) {
Expand Down Expand Up @@ -777,23 +788,25 @@ coreml.Context = class {
}
};
if (data) {
const attributes = obj.attributes;
const map = weights(type, data, initializers);
for (const [name, value] of Object.entries(data)) {
if (!map[name]) {
obj.attributes[name] = value;
attributes[name] = value;
}
}
/*
if (obj.type === 'loop') {
const network = (context, obj) => {
context = context.context();
context.network(obj);
return context;
};
obj.attributes.bodyNetwork = network(this, obj.attributes.bodyNetwork);
obj.attributes.conditionNetwork = network(this, obj.attributes.conditionNetwork);
switch (obj.type) {
case 'loop':
attributes.bodyNetwork = this.network(attributes.bodyNetwork);
attributes.conditionNetwork = this.network(attributes.conditionNetwork);
break;
case 'branch':
attributes.ifBranch = this.network(attributes.ifBranch);
attributes.elseBranch = this.network(attributes.elseBranch);
break;
default:
break;
}
*/
}
const metadata = this.metadata.type(type);
for (let i = 0; i < inputs.length;) {
Expand Down

0 comments on commit 0b978dc

Please sign in to comment.