Commit
Update view.js
lutzroeder committed Nov 23, 2024
1 parent 68b31e0 commit a2ecc16
Showing 3 changed files with 140 additions and 110 deletions.
31 changes: 23 additions & 8 deletions source/keras.js
@@ -42,7 +42,7 @@ keras.ModelFactory = class {
         }
         const container = tfjs.Container.open(context);
         if (container) {
-            context.type = 'tfjs.json';
+            context.type = 'tfjs';
             context.target = container;
             return;
         }
@@ -83,6 +83,9 @@ keras.ModelFactory = class {
         if (context.type === 'keras.config.json' && (type === 'keras.model.weights.h5' || type === 'keras.model.weights.npz')) {
             return false;
         }
+        if (context.type === 'tfjs' && type === 'tf.tfjs.weights') {
+            return false;
+        }
         return true;
     }
 
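Note: filter(context, type) is consulted when several model factories claim the same file; returning false suppresses the competing match. The new clause keeps a file already identified as 'tfjs' from also being surfaced as a TensorFlow.js weights shard by tf.js. A standalone restatement of the logic as a sketch (the real code is a method on keras.ModelFactory):

    const filter = (context, type) => {
        if (context.type === 'keras.config.json' && (type === 'keras.model.weights.h5' || type === 'keras.model.weights.npz')) {
            return false;
        }
        if (context.type === 'tfjs' && type === 'tf.tfjs.weights') {
            return false;
        }
        return true;
    };
    filter({ type: 'tfjs' }, 'tf.tfjs.weights'); // false: duplicate match suppressed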
@@ -460,10 +463,10 @@ keras.ModelFactory = class {
                 walk(weights_group);
                 return open_model(format, '', '', null, weights);
             }
-            case 'tfjs.json': {
-                const container = tfjs.Container.open(context);
-                await container.open();
-                return open_model(container.format, container.producer, container.backend, container.config, container.weights);
+            case 'tfjs': {
+                const target = context.target;
+                await target.read();
+                return open_model(target.format, target.producer, target.backend, target.config, target.weights);
             }
             case 'keras.pickle': {
                 const obj = context.target;
@@ -1332,12 +1335,16 @@ tfjs.Container = class {
                 return new tfjs.Container(context, '');
             }
             if (Array.isArray(json) && json.every((item) => item.weights && item.paths)) {
-                return new tfjs.Container(context, 'weights');
+                return new tfjs.Container(context, 'weights.json');
             }
             if (json.tfjsVersion) {
                 return new tfjs.Container(context, 'metadata');
             }
         }
+        const identifier = context.identifier;
+        if (/^.*group\d+-shard\d+of\d+(\.bin)?$/.test(identifier)) {
+            return new tfjs.Container(context, 'weights.bin');
+        }
         return null;
     }
 
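Note: the new filename test lets tfjs.Container claim TensorFlow.js weight shards even when no JSON manifest has been parsed yet. The pattern accepts names of the form groupN-shardMofK with an optional .bin extension (sample names are illustrative):

    const pattern = /^.*group\d+-shard\d+of\d+(\.bin)?$/;
    pattern.test('group1-shard1of3.bin'); // true
    pattern.test('group1-shard1of3');     // true, the extension is optional
    pattern.test('model.json');           // false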
@@ -1346,13 +1353,13 @@
         this.type = type;
     }
 
-    async open() {
+    async read() {
         switch (this.type) {
             case '': {
                 const obj = this.context.peek('json');
                 return this._openModelJson(obj);
             }
-            case 'weights': {
+            case 'weights.json': {
                 this.format = 'TensorFlow.js Weights';
                 this.config = null;
                 const obj = this.context.peek('json');
@@ -1366,6 +1373,11 @@
                 }
                 return this._openManifests(manifests);
             }
+            case 'weights.bin': {
+                const content = await this.context.fetch('model.json');
+                const obj = content.read('json');
+                return this._openModelJson(obj);
+            }
             case 'metadata': {
                 const content = await this.context.fetch('model.json');
                 const obj = content.read('json');
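Note: the 'weights.bin' case means a shard file can be opened directly: read() fetches the sibling model.json and parses it as a layers-model. A sketch of the resulting flow, assuming the shard sits next to its manifest:

    // user opens 'group1-shard1of3.bin'
    const container = tfjs.Container.open(context); // filename matches, type 'weights.bin'
    await container.read();                         // fetches 'model.json', then _openModelJson(obj)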
@@ -1441,6 +1453,9 @@
     }
 
     _openModelJson(obj) {
+        if (!obj || !obj.modelTopology || (obj.format !== 'layers-model' && !obj.modelTopology.model_config && !obj.modelTopology.config)) {
+            throw new tfjs.Error('File format is not TensorFlow.js layers-model.');
+        }
         const modelTopology = obj.modelTopology;
         this.format = `TensorFlow.js ${obj.format ? obj.format : `Keras${modelTopology.keras_version ? (` v${modelTopology.keras_version}`) : ''}`}`;
         this.producer = obj.convertedBy || obj.generatedBy || '';
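Note: the added guard rejects anything that is not a layers-model before _openModelJson touches obj.modelTopology. An equivalent acceptance predicate with illustrative inputs (the objects are made up):

    const accepts = (obj) => Boolean(obj && obj.modelTopology &&
        (obj.format === 'layers-model' || obj.modelTopology.model_config || obj.modelTopology.config));
    accepts({ format: 'layers-model', modelTopology: {} }); // true
    accepts({ modelTopology: { model_config: {} } });       // true
    accepts({ format: 'graph-model', modelTopology: {} });  // false -> tfjs.Error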
213 changes: 114 additions & 99 deletions source/tf.js
@@ -237,13 +237,23 @@ tf.ModelFactory = class {
                 const offset = reader.uint64().toNumber();
                 if (offset < stream.length) {
                     context.type = 'tf.pb.mmap';
+                    return;
                 }
             }
         }
+        if (/^.*group\d+-shard\d+of\d+(\.bin)?$/.test(identifier)) {
+            context.type = 'tf.tfjs.weights';
+        }
     }
 
     filter(context, type) {
-        return context.type !== 'tf.bundle' || type !== 'tf.data';
+        if (context.type === 'tf.bundle' && type === 'tf.data') {
+            return false;
+        }
+        if ((context.type === 'tf.json' || context.type === 'tf.json.gz') && type === 'tf.tfjs.weights') {
+            return false;
+        }
+        return true;
    }
 
     async open(context) {
@@ -399,121 +409,124 @@ tf.ModelFactory = class {
             return openSavedModel(context, saved_model, format, producer);
         };
         const openJson = async (context, type) => {
-            try {
-                const obj = context.peek(type);
-                const format = `TensorFlow.js ${obj.format || 'graph-model'}`;
-                const producer = obj.convertedBy || obj.generatedBy || '';
-                const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
-                meta_graph.graph_def = tf.JsonReader.decodeGraphDef(obj.modelTopology);
-                const saved_model = new tf.proto.tensorflow.SavedModel();
-                saved_model.meta_graphs.push(meta_graph);
-                const nodes = new Map();
-                for (const node of meta_graph.graph_def.node) {
-                    node.input = node.input || [];
-                    if (node.op === 'Const') {
-                        nodes.set(node.name, node);
-                    }
-                }
-                const shards = new Map();
-                const manifests = Array.isArray(obj.weightsManifest) ? obj.weightsManifest : [];
-                for (const manifest of manifests) {
-                    for (const path of manifest.paths) {
-                        if (!shards.has(path)) {
-                            shards.set(path, context.fetch(path));
-                        }
-                    }
-                }
-                const openShards = (shards) => {
-                    const dtype_size_map = new Map([
-                        ['float16', 2], ['float32', 4], ['float64', 8],
-                        ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8],
-                        ['uint8', 1], ['uint16', 2], ['uint32', 4], ['uint64', 8],
-                        ['bool', 1]
-                    ]);
-                    for (const manifest of manifests) {
-                        let buffer = null;
-                        if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
-                            const list = manifest.paths.map((path) => shards.get(path));
-                            const size = list.reduce((a, b) => a + b.length, 0);
-                            buffer = new Uint8Array(size);
-                            let offset = 0;
-                            for (const item of list) {
-                                buffer.set(item, offset);
-                                offset += item.length;
-                            }
-                        }
-                        let offset = 0;
-                        for (const weight of manifest.weights) {
-                            const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
-                            const size = weight.shape.reduce((a, b) => a * b, 1);
-                            switch (dtype) {
-                                case 'string': {
-                                    const data = [];
-                                    if (buffer && size > 0) {
-                                        const reader = new tf.BinaryReader(buffer.subarray(offset));
-                                        for (let i = 0; i < size; i++) {
-                                            data[i] = reader.string();
-                                        }
-                                        offset += reader.position;
-                                    }
-                                    if (nodes.has(weight.name)) {
-                                        const node = nodes.get(weight.name);
-                                        node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
-                                        node.attr.value.tensor.string_val = data;
-                                    }
-                                    break;
-                                }
-                                default: {
-                                    if (!dtype_size_map.has(dtype)) {
-                                        throw new tf.Error(`Unsupported weight data type size '${dtype}'.`);
-                                    }
-                                    const itemsize = dtype_size_map.get(dtype);
-                                    const length = itemsize * size;
-                                    const tensor_content = buffer ? buffer.slice(offset, offset + length) : null;
-                                    offset += length;
-                                    if (nodes.has(weight.name)) {
-                                        const node = nodes.get(weight.name);
-                                        node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
-                                        node.attr.value.tensor.tensor_content = tensor_content;
-                                    }
-                                    break;
-                                }
-                            }
-                        }
-                    }
-                    return openSavedModel(context, saved_model, format, producer);
-                };
-                try {
-                    const contexts = await Promise.all(shards.values());
-                    for (const key of shards.keys()) {
-                        const context = contexts.shift();
-                        const buffer = context.stream.peek();
-                        shards.set(key, buffer);
-                    }
-                    if (type === 'json.gz') {
-                        try {
-                            for (const key of shards.keys()) {
-                                const stream = shards.get(key);
-                                const archive = zip.Archive.open(stream, 'gzip');
-                                if (archive && archive.entries.size === 1) {
-                                    const stream = archive.entries.values().next().value;
-                                    const buffer = stream.peek();
-                                    shards.set(key, buffer);
-                                }
-                            }
-                        } catch {
-                            // continue regardless of error
-                        }
-                    }
-                    return openShards(shards);
-                } catch {
-                    shards.clear();
-                    return openShards(shards);
-                }
-            } catch (error) {
-                throw new tf.Error(`File text format is not TensorFlow.js graph-model (${error.message}).`);
-            }
+            const obj = context.peek(type);
+            if (!obj || !obj.modelTopology || (obj.format !== 'graph-model' && !Array.isArray(obj.modelTopology.node))) {
+                throw new tf.Error('File format is not TensorFlow.js graph-model.');
+            }
+            const format = `TensorFlow.js ${obj.format || 'graph-model'}`;
+            const producer = obj.convertedBy || obj.generatedBy || '';
+            const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
+            meta_graph.graph_def = tf.JsonReader.decodeGraphDef(obj.modelTopology);
+            const saved_model = new tf.proto.tensorflow.SavedModel();
+            saved_model.meta_graphs.push(meta_graph);
+            const nodes = new Map();
+            for (const node of meta_graph.graph_def.node) {
+                node.input = node.input || [];
+                if (node.op === 'Const') {
+                    nodes.set(node.name, node);
+                }
+            }
+            const shards = new Map();
+            const manifests = Array.isArray(obj.weightsManifest) ? obj.weightsManifest : [];
+            for (const manifest of manifests) {
+                for (const path of manifest.paths) {
+                    if (!shards.has(path)) {
+                        shards.set(path, context.fetch(path));
+                    }
+                }
+            }
+            const openShards = (shards) => {
+                const dtype_size_map = new Map([
+                    ['float16', 2], ['float32', 4], ['float64', 8],
+                    ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8],
+                    ['uint8', 1], ['uint16', 2], ['uint32', 4], ['uint64', 8],
+                    ['bool', 1]
+                ]);
+                for (const manifest of manifests) {
+                    let buffer = null;
+                    if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
+                        const list = manifest.paths.map((path) => shards.get(path));
+                        const size = list.reduce((a, b) => a + b.length, 0);
+                        buffer = new Uint8Array(size);
+                        let offset = 0;
+                        for (const item of list) {
+                            buffer.set(item, offset);
+                            offset += item.length;
+                        }
+                    }
+                    let offset = 0;
+                    for (const weight of manifest.weights) {
+                        const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
+                        const size = weight.shape.reduce((a, b) => a * b, 1);
+                        switch (dtype) {
+                            case 'string': {
+                                const data = [];
+                                if (buffer && size > 0) {
+                                    const reader = new tf.BinaryReader(buffer.subarray(offset));
+                                    for (let i = 0; i < size; i++) {
+                                        data[i] = reader.string();
+                                    }
+                                    offset += reader.position;
+                                }
+                                if (nodes.has(weight.name)) {
+                                    const node = nodes.get(weight.name);
+                                    node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
+                                    node.attr.value.tensor.string_val = data;
+                                }
+                                break;
+                            }
+                            default: {
+                                if (!dtype_size_map.has(dtype)) {
+                                    throw new tf.Error(`Unsupported weight data type size '${dtype}'.`);
+                                }
+                                const itemsize = dtype_size_map.get(dtype);
+                                const length = itemsize * size;
+                                const tensor_content = buffer ? buffer.slice(offset, offset + length) : null;
+                                offset += length;
+                                if (nodes.has(weight.name)) {
+                                    const node = nodes.get(weight.name);
+                                    node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
+                                    node.attr.value.tensor.tensor_content = tensor_content;
+                                }
+                                break;
+                            }
+                        }
+                    }
+                }
+                return openSavedModel(context, saved_model, format, producer);
+            };
+            try {
+                const contexts = await Promise.all(shards.values());
+                for (const key of shards.keys()) {
+                    const context = contexts.shift();
+                    const buffer = context.stream.peek();
+                    shards.set(key, buffer);
+                }
+                if (type === 'json.gz') {
+                    try {
+                        for (const key of shards.keys()) {
+                            const stream = shards.get(key);
+                            const archive = zip.Archive.open(stream, 'gzip');
+                            if (archive && archive.entries.size === 1) {
+                                const stream = archive.entries.values().next().value;
+                                const buffer = stream.peek();
+                                shards.set(key, buffer);
+                            }
+                        }
+                    } catch {
+                        // continue regardless of error
+                    }
+                }
+                return openShards(shards);
+            } catch {
+                shards.clear();
+                return openShards(shards);
+            }
         };
+        const openJsonWeights = async (context) => {
+            const content = await context.fetch('model.json');
+            return openJson(content, 'json');
+        };
         const openTextGraphDef = (context) => {
             try {
                 const reader = context.read('protobuf.text');
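Note: openShards concatenates all fetched shard buffers of a manifest into one contiguous buffer, then walks it weight by weight, taking size(shape) × sizeof(dtype) bytes per weight (strings are instead length-prefixed and read via tf.BinaryReader). A reduced sketch of the layout math (helper names are made up, error handling omitted):

    const sizeofDtype = new Map([['float32', 4], ['int32', 4], ['uint8', 1]]); // subset
    const sliceWeights = (buffer, weights) => {
        const out = new Map();
        let offset = 0;
        for (const w of weights) {
            const count = w.shape.reduce((a, b) => a * b, 1);   // elements in the tensor
            const length = count * sizeofDtype.get(w.dtype);    // bytes it occupies
            out.set(w.name, buffer.slice(offset, offset + length));
            offset += length;
        }
        return out;
    };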
@@ -673,6 +686,8 @@ tf.ModelFactory = class {
                 return openJson(context, 'json');
             case 'tf.json.gz':
                 return openJson(context, 'json.gz');
+            case 'tf.tfjs.weights':
+                return await openJsonWeights(context);
             case 'tf.pbtxt.GraphDef':
                 return openTextGraphDef(context);
             case 'tf.pbtxt.MetaGraphDef':
6 changes: 3 additions & 3 deletions source/view.js
@@ -5322,7 +5322,7 @@ view.Context = class {
 
     async fetch(file) {
         const stream = await this._context.request(file, null, this._base);
-        return new view.Context(this, file, stream);
+        return new view.Context(this._context, file, stream);
     }
 
     async require(id) {
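Note: this is the one-line fix the commit is named for. fetch() previously wrapped the current view.Context in a new view.Context, so every fetched file appears to have nested one more delegation layer; passing this._context instead keeps all derived contexts pointing at the same underlying host context. A simplified sketch of the two shapes (constructor reduced to the relevant fields):

    class Context {
        constructor(context, identifier, stream) {
            this._context = context; // the object that actually implements request()
            this._identifier = identifier;
            this._stream = stream;
        }
        async fetch(file) {
            const stream = await this._context.request(file, null, this._base);
            return new Context(this._context, file, stream); // flat; was `new Context(this, ...)`
        }
    }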
@@ -5764,9 +5764,9 @@ view.ModelFactoryService = class {
         this.register('./caffe', ['.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt']);
         this.register('./caffe2', ['.pb', '.pbtxt', '.prototxt']);
         this.register('./torch', ['.t7', '.net']);
-        this.register('./tf', ['.pb', '.meta', '.pbtxt', '.prototxt', '.txt', '.pt', '.json', '.index', '.ckpt', '.graphdef', '.pbmm', /.data-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]$/, /^events.out.tfevents./], ['.zip']);
+        this.register('./tf', ['.pb', '.meta', '.pbtxt', '.prototxt', '.txt', '.pt', '.json', '.index', '.ckpt', '.graphdef', '.pbmm', /.data-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]$/, /^events.out.tfevents./, /^.*group\d+-shard\d+of\d+(\.bin)?$/], ['.zip']);
         this.register('./tensorrt', ['.trt', '.trtmodel', '.engine', '.model', '.txt', '.uff', '.pb', '.tmfile', '.onnx', '.pth', '.dnn', '.plan', '.pt', '.dat', '.bin']);
-        this.register('./keras', ['.h5', '.hd5', '.hdf5', '.keras', '.json', '.cfg', '.model', '.pb', '.pth', '.weights', '.pkl', '.lite', '.tflite', '.ckpt', '.pb', 'model.weights.npz'], ['.zip']);
+        this.register('./keras', ['.h5', '.hd5', '.hdf5', '.keras', '.json', '.cfg', '.model', '.pb', '.pth', '.weights', '.pkl', '.lite', '.tflite', '.ckpt', '.pb', 'model.weights.npz', /^.*group\d+-shard\d+of\d+(\.bin)?$/], ['.zip']);
         this.register('./numpy', ['.npz', '.npy', '.pkl', '.pickle', '.model', '.model2', '.mge', '.joblib']);
         this.register('./lasagne', ['.pkl', '.pickle', '.joblib', '.model', '.pkl.z', '.joblib.z']);
         this.register('./lightgbm', ['.txt', '.pkl', '.model']);
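Note: register() takes both plain extensions and regular expressions; the shard pattern is now registered for the tf and keras factories alike, and the filter() changes above decide which factory keeps the match. A hypothetical matcher showing how such a mixed pattern list could be evaluated (not the actual service code):

    const matches = (identifier, patterns) => patterns.some((pattern) =>
        pattern instanceof RegExp ? pattern.test(identifier) : identifier.endsWith(pattern));
    matches('group1-shard1of3.bin', ['.json', /^.*group\d+-shard\d+of\d+(\.bin)?$/]); // true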
