diff --git a/source/ggml.js b/source/ggml.js index fb9fd77674..24a4b3447c 100644 --- a/source/ggml.js +++ b/source/ggml.js @@ -20,19 +20,22 @@ ggml.Model = class { constructor(target) { this.format = target.format; + this.metadata = new Map(); const layers = new Map(); - layers.map = (key) => { + for (const [name, tensor] of target.tensors) { + const [key, param] = name.match(/^(.*)\.(.*?)$/).slice(1); if (!layers.has(key)) { - layers.set(key, { metadata: new Map(), weights: new Map() }); + layers.set(key, { name: key, type: 'weights', metadata: new Map(), weights: new Map() }); } - return layers.get(key); - }; - this.metadata = new Map(); + const layer = layers.get(key); + layer.weights.set(param, tensor); + } const metadata = new Map(); + let architecture = '?'; for (const [name, value] of target.metadata) { switch (name) { case 'general.name': this.name = value; break; - case 'general.architecture': this.runtime = value; break; + case 'general.architecture': architecture = value; break; case 'general.description': this.description = value; break; case 'general.author': this.metadata.set('author', value); break; case 'general.license': this.metadata.set('license', value); break; @@ -44,38 +47,35 @@ ggml.Model = class { break; } } + const tokenizer = { type: 'tokenizer', metadata: new Map(), layers: [] }; + const model = { type: architecture, metadata: new Map(), layers: Array.from(layers.values()) }; for (const [name, value] of metadata) { if (name.startsWith('tokenizer.')) { - const [key, param] = name.match(/^(.*)\.(.*?)$/).slice(1); - const layer = layers.map(key); - layer.type = 'Tokenizer'; - layer.metadata.set(param, value); - } else if (this.runtime && name.startsWith(this.runtime + '.')) { - const layer = layers.map(''); - layer.type = 'Parameters'; - layer.metadata.set(name, value); + const [, param] = name.match(/^(.*)\.(.*?)$/).slice(1); + tokenizer.metadata.set(param, value); + } else if (architecture && name.startsWith(architecture + '.')) { + model.metadata.set(name, value); } else { this.metadata.set(name, value); } } - for (const [name, tensor] of target.tensors) { - const [key, param] = name.match(/^(.*)\.(.*?)$/).slice(1); - const layer = layers.map(key); - layer.type = 'Weights'; - layer.weights.set(param, tensor); + const graph = { layers: [ model ] }; + if (tokenizer.metadata.size > 0) { + graph.layers.push(tokenizer); } - this.graphs = [ new ggml.Graph(target.metadata, layers) ]; + this.graphs = [ new ggml.Graph(graph) ]; } }; ggml.Graph = class { - constructor(metadata, layers) { + constructor(graph) { + this.name = graph.type; this.nodes = []; this.inputs = []; this.outputs = []; - for (const [name, layer] of layers) { - const node = new ggml.Node(name, layer); + for (const layer of graph.layers) { + const node = new ggml.Node(layer); this.nodes.push(node); } } @@ -101,21 +101,25 @@ ggml.Value = class { ggml.Node = class { - constructor(name, layer) { - this.type = { name: layer.type }; - this.name = name; + constructor(layer) { + this.type = Array.isArray(layer.layers) && layer.layers.length > 0 ? new ggml.Graph(layer) : { name: layer.type }; + this.name = layer.name || ''; this.inputs = []; this.outputs = []; this.attributes = []; - for (const [name, weight] of layer.weights) { - const tensor = new ggml.Tensor(weight); - const value = new ggml.Value(weight.name, tensor); - const argument = new ggml.Argument(name, [ value ]); - this.inputs.push(argument); + if (layer.weights) { + for (const [name, weight] of layer.weights) { + const tensor = new ggml.Tensor(weight); + const value = new ggml.Value(weight.name, tensor); + const argument = new ggml.Argument(name, [ value ]); + this.inputs.push(argument); + } } - for (const [name, value] of layer.metadata) { - const attribute = new ggml.Attribute(name, value); - this.attributes.push(attribute); + if (layer.metadata) { + for (const [name, value] of layer.metadata) { + const attribute = new ggml.Attribute(name, value); + this.attributes.push(attribute); + } } } };