diff --git a/source/app.js b/source/app.js index fc411d11337..1901383060e 100644 --- a/source/app.js +++ b/source/app.js @@ -62,7 +62,7 @@ class Application { const paths = data.paths.filter((path) => { if (fs.existsSync(path)) { const stat = fs.statSync(path); - return stat.isFile() /* || stat.isDirectory() */; + return stat.isFile() || stat.isDirectory(); } return false; }); @@ -109,7 +109,7 @@ class Application { const extension = arg.split('.').pop().toLowerCase(); if (extension != '' && extension != 'js' && fs.existsSync(arg)) { const stat = fs.statSync(arg); - if (stat.isFile() /* || stat.isDirectory() */) { + if (stat.isFile() || stat.isDirectory()) { this._openPath(arg); open = true; } @@ -191,7 +191,7 @@ class Application { } if (path && path.length > 0 && fs.existsSync(path)) { const stat = fs.statSync(path); - if (stat.isFile() /* || stat.isDirectory() */) { + if (stat.isFile() || stat.isDirectory()) { // find existing view for this file let view = this._views.find(path); // find empty welcome window @@ -390,7 +390,7 @@ class Application { const path = recent.path; if (fs.existsSync(path)) { const stat = fs.statSync(path); - if (stat.isFile() /* || stat.isDirectory() */) { + if (stat.isFile() || stat.isDirectory()) { return true; } } diff --git a/source/dl4j.js b/source/dl4j.js index acd31f3ee15..2de9f2a3788 100644 --- a/source/dl4j.js +++ b/source/dl4j.js @@ -26,14 +26,14 @@ dl4j.ModelFactory = class { } static _openContainer(entries) { - const configurationEntries = entries.filter((entry) => entry.name === 'configuration.json'); - const coefficientsEntries = entries.filter((entry) => entry.name === 'coefficients.bin'); - if (configurationEntries.length === 1 && coefficientsEntries.length <= 1) { + const configurationStream = entries.get('configuration.json'); + const coefficientsStream = entries.get('coefficients.bin'); + if (configurationStream) { try { - const reader = json.TextReader.create(configurationEntries[0].data); + const reader = json.TextReader.create(configurationStream.peek()); const configuration = reader.read(); if (configuration && (configuration.confs || configuration.vertices)) { - const coefficients = coefficientsEntries.length == 1 ? coefficientsEntries[0].data : []; + const coefficients = coefficientsStream ? coefficientsStream.peek() : []; return { configuration: configuration, coefficients: coefficients }; } } diff --git a/source/dlc.js b/source/dlc.js index f0f485a477d..b1febce2d71 100644 --- a/source/dlc.js +++ b/source/dlc.js @@ -6,7 +6,7 @@ dlc.ModelFactory = class { match(context) { const entries = context.entries('zip'); - if (entries.find((entry) => entry.name === 'model')) { + if (entries.has('model')) { return true; } return false; diff --git a/source/electron.js b/source/electron.js index ac750d4260d..4636c7dc429 100644 --- a/source/electron.js +++ b/source/electron.js @@ -56,7 +56,7 @@ host.ElectronHost = class { initialize(view) { this._view = view; electron.ipcRenderer.on('open', (_, data) => { - this._openFile(data.path); + this._openPath(data.path); }); return new Promise((resolve /*, reject */) => { const accept = () => { @@ -117,8 +117,8 @@ host.ElectronHost = class { const queue = this._queue; delete this._queue; if (queue.length > 0) { - const file = queue.pop(); - this._openFile(file); + const path = queue.pop(); + this._openPath(path); } } @@ -281,17 +281,17 @@ host.ElectronHost = class { request(file, encoding, base) { return new Promise((resolve, reject) => { const pathname = path.join(base || __dirname, file); - fs.stat(pathname, (err, stats) => { + fs.stat(pathname, (err, stat) => { if (err && err.code === 'ENOENT') { reject(new Error("The file '" + file + "' does not exist.")); } else if (err) { reject(err); } - else if (!stats.isFile()) { + else if (!stat.isFile()) { reject(new Error("The path '" + file + "' is not a file.")); } - else if (stats && stats.size < 0x7ffff000) { + else if (stat && stat.size < 0x7ffff000) { fs.readFile(pathname, encoding, (err, data) => { if (err) { reject(err); @@ -302,10 +302,10 @@ host.ElectronHost = class { }); } else if (encoding) { - reject(new Error("The file '" + file + "' size (" + stats.size.toString() + ") for encoding '" + encoding + "' is greater than 2 GB.")); + reject(new Error("The file '" + file + "' size (" + stat.size.toString() + ") for encoding '" + encoding + "' is greater than 2 GB.")); } else { - resolve(new host.ElectronHost.FileStream(pathname, 0, stats.size, stats.mtimeMs)); + resolve(new host.ElectronHost.FileStream(pathname, 0, stat.size, stat.mtimeMs)); } }); }); @@ -366,21 +366,49 @@ host.ElectronHost = class { } } - _openFile(file) { + _context(location) { + const basename = path.basename(location); + const stat = fs.statSync(location); + if (stat.isFile()) { + const dirname = path.dirname(location); + return this.request(basename, null, dirname).then((stream) => { + return new host.ElectronHost.ElectonContext(this, dirname, basename, stream); + }); + } + else if (stat.isDirectory()) { + const entries = new Map(); + const walk = (dir) => { + for (const item of fs.readdirSync(dir)) { + const pathname = path.join(dir, item); + const stat = fs.statSync(pathname); + if (stat.isDirectory()) { + walk(pathname); + } + else if (stat.isFile()) { + const stream = new host.ElectronHost.FileStream(pathname, 0, stat.size, stat.mtimeMs); + const name = pathname.split(path.sep).join(path.posix.sep); + entries.set(name, stream); + } + } + }; + walk(location); + return Promise.resolve(new host.ElectronHost.ElectonContext(this, location, basename, null, entries)); + } + throw new Error("Unsupported path stat '" + JSON.stringify(stat) + "'."); + } + + _openPath(path) { if (this._queue) { - this._queue.push(file); + this._queue.push(path); return; } - if (file && this._view.accept(file)) { + if (path && this._view.accept(path)) { this._view.show('welcome spinner'); - const dirname = path.dirname(file); - const basename = path.basename(file); - this.request(basename, null, dirname).then((stream) => { - const context = new host.ElectronHost.ElectonContext(this, dirname, basename, stream); + this._context(path).then((context) => { this._view.open(context).then((model) => { this._view.show(null); if (model) { - this._update('path', file); + this._update('path', path); } this._update('show-attributes', this._view.showAttributes); this._update('show-initializers', this._view.showInitializers); @@ -683,11 +711,12 @@ host.ElectronHost.FileStream = class { host.ElectronHost.ElectonContext = class { - constructor(host, folder, identifier, stream) { + constructor(host, folder, identifier, stream, entries) { this._host = host; this._folder = folder; this._identifier = identifier; this._stream = stream; + this._entries = entries || new Map(); } get identifier() { @@ -698,6 +727,10 @@ host.ElectronHost.ElectonContext = class { return this._stream; } + get entries() { + return this._entries; + } + request(file, encoding, base) { return this._host.request(file, encoding, base === undefined ? this._folder : base); } diff --git a/source/gzip.js b/source/gzip.js index d36a9df9199..793826f009a 100644 --- a/source/gzip.js +++ b/source/gzip.js @@ -74,10 +74,6 @@ gzip.Entry = class { get stream() { return this._stream; } - - get data() { - return this.stream.peek(); - } }; gzip.InflaterStream = class { diff --git a/source/mlnet.js b/source/mlnet.js index 30cc5197a03..63ec5d1c9ca 100644 --- a/source/mlnet.js +++ b/source/mlnet.js @@ -9,9 +9,9 @@ mlnet.ModelFactory = class { match(context) { const entries = context.entries('zip'); - if (entries.length > 0) { + if (entries.size > 0) { const root = new Set([ 'TransformerChain', 'Predictor']); - if (entries.some((e) => root.has(e.name.split('\\').shift().split('/').shift()))) { + if (Array.from(entries.keys()).some((name) => root.has(name.split('\\').shift().split('/').shift()))) { return true; } } @@ -20,7 +20,8 @@ mlnet.ModelFactory = class { open(context) { return mlnet.Metadata.open(context).then((metadata) => { - const reader = new mlnet.ModelReader(context.entries('zip')); + const entries = context.entries('zip'); + const reader = new mlnet.ModelReader(entries); return new mlnet.Model(metadata, reader); }); } @@ -614,10 +615,11 @@ mlnet.ModelHeader = class { open(name) { const dir = this._directory.length > 0 ? this._directory + '/' : this._directory; name = dir + name; - const entryName = name + '/Model.key'; - const entry = this._entries.find((entry) => entry.name == entryName || entry.name == entryName.replace(/\//g, '\\')); - if (entry) { - const context = new mlnet.ModelHeader(this._catalog, this._entries, name, entry.data); + const key = name + '/Model.key'; + const stream = this._entries.get(key) || this._entries.get(key.replace(/\//g, '\\')); + if (stream) { + const buffer = stream.peek(); + const context = new mlnet.ModelHeader(this._catalog, this._entries, name, buffer); const value = this._catalog.create(context.loaderSignature, context); value.__type__ = value.__type__ || context.loaderSignature; value.__name__ = name; @@ -629,16 +631,22 @@ mlnet.ModelHeader = class { openBinary(name) { const dir = this._directory.length > 0 ? this._directory + '/' : this._directory; name = dir + name; - const entry = this._entries.find((entry) => entry.name == name || entry.name == name.replace(/\//g, '\\')); - return entry ? new mlnet.Reader(entry.data) : null; + const stream = this._entries.get(name) || this._entries.get(name.replace(/\//g, '\\')); + if (stream) { + const buffer = stream.peek(); + return new mlnet.Reader(buffer); + } + return null; } openText(name) { const dir = this._directory.length > 0 ? this._directory + '/' : this._directory; name = dir + name; - const entry = this._entries.find((entry) => entry.name.split('\\').join('/') == name); - if (entry) { - return new TextDecoder().decode(entry.data); + const stream = this._entries.get(name) || this._entries.get(name.replace(/\//g, '\\')); + if (stream) { + const buffer = stream.peek(); + const decoder = new TextDecoder(); + return decoder.decode(buffer); } return null; } diff --git a/source/npz.js b/source/npz.js index cd05607d68f..6f2f56834d5 100644 --- a/source/npz.js +++ b/source/npz.js @@ -52,11 +52,12 @@ npz.ModelFactory = class { case 'npz': { format = 'NumPy Zip'; const execution = new python.Execution(null); - for (const entry of context.entries('zip')) { - if (!entry.name.endsWith('.npy')) { + const entries = context.entries('zip'); + for (const entry of entries) { + if (!entry[0].endsWith('.npy')) { throw new npz.Error("Invalid file name '" + entry.name + "'."); } - const name = entry.name.replace(/\.npy$/, ''); + const name = entry[0].replace(/\.npy$/, ''); const parts = name.split('/'); const parameterName = parts.pop(); const groupName = parts.join('/'); @@ -64,8 +65,9 @@ npz.ModelFactory = class { groups.set(groupName, { name: groupName, parameters: [] }); } const group = groups.get(groupName); - const data = entry.data; - let array = new numpy.Array(data); + const stream = entry[1]; + const buffer = stream.peek(); + let array = new numpy.Array(buffer); if (array.byteOrder === '|' && array.dataType !== 'u1' && array.dataType !== 'i1') { if (array.dataType !== 'O') { throw new npz.Error("Invalid data type '" + array.dataType + "'."); @@ -494,7 +496,7 @@ npz.Utility = class { return 'npy'; } const entries = context.entries('zip'); - if (entries.length > 0 && entries.every((entry) => entry.name.endsWith('.npy'))) { + if (entries.size > 0 && Array.from(entries.keys()).every((name) => name.endsWith('.npy'))) { return 'npz'; } const obj = context.open('pkl'); diff --git a/source/paddle.js b/source/paddle.js index 1a7edaa39e9..7302eb279b2 100644 --- a/source/paddle.js +++ b/source/paddle.js @@ -728,10 +728,10 @@ paddle.Utility = class { paddle.Container = class { static open(context) { - const extension = [ 'zip', 'tar' ].find((extension) => context.entries(extension).length > 0); + const extension = [ 'zip', 'tar' ].find((extension) => context.entries(extension).size > 0); if (extension) { - const entries = context.entries(extension).filter((entry) => !entry.name.endsWith('/') && !entry.name.split('/').pop().startsWith('.')).slice(); - if (entries.length > 2 && entries.every((entry) => entry.name.split('_').length > 0 && entry.data.slice(0, 16).every((value) => value === 0x00))) { + const entries = new Map(Array.from(context.entries(extension)).filter((entry) => !entry[0].endsWith('/') && !entry[0].split('/').pop().startsWith('.')).slice()); + if (entries.size > 2 && Array.from(entries).every((entry) => entry[0].split('_').length > 0 && entry[1].peek(16).every((value) => value === 0x00))) { return new paddle.Container('entries', entries); } } @@ -773,7 +773,7 @@ paddle.Container = class { case 'entries': { let rootFolder = null; for (const entry of this._data) { - const name = entry.name; + const name = entry[0]; if (name.startsWith('.') && !name.startsWith('./')) { continue; } @@ -783,9 +783,11 @@ paddle.Container = class { } this._weights = new Map(); for (const entry of this._data) { - if (entry.name.startsWith(rootFolder)) { - const name = entry.name.substring(rootFolder.length); - this._weights.set(name, new paddle.Tensor(null, entry.stream)); + if (entry[0].startsWith(rootFolder)) { + const name = entry[0].substring(rootFolder.length); + const stream = entry[1]; + const tensor = new paddle.Tensor(null, stream); + this._weights.set(name, tensor); } } break; diff --git a/source/pytorch.js b/source/pytorch.js index 34f8e2963f4..87beeef3b3d 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1911,8 +1911,9 @@ pytorch.Container = class { if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) { return new pytorch.Container.Pickle(stream, exception); } - if (context.entries('tar').some((entry) => entry.name == 'pickle')) { - return new pytorch.Container.Tar(context.entries('tar'), exception); + const entries = context.entries('tar'); + if (entries.has('pickle')) { + return new pytorch.Container.Tar(entries, exception); } return null; } @@ -1957,10 +1958,10 @@ pytorch.Container.Tar = class { const entries = {}; for (const entry of this._entries) { switch (entry.name) { - case 'sys_info': entries.sys_info = entry.data; break; - case 'pickle': entries.pickle = entry.data; break; - case 'storages': entries.storages = entry.data; break; - case 'tensors': entries.tensors = entry.data; break; + case 'sys_info': entries.sys_info = entry.stream.peek(); break; + case 'pickle': entries.pickle = entry.stream.peek(); break; + case 'storages': entries.storages = entry.stream.peek(); break; + case 'tensors': entries.tensors = entry.stream.peek(); break; } } @@ -2164,15 +2165,17 @@ pytorch.Container.Pickle = class { pytorch.Container.Zip = class { static open(entries, metadata, exception) { - const entry = entries.find((entry) => entry.name == 'model.json' || entry.name == 'data.pkl' || entry.name.endsWith('/model.json') || entry.name.endsWith('/data.pkl')); - if (!entry) { + const name = Array.from(entries.keys()).find((name) => name == 'model.json' || name == 'data.pkl' || name.endsWith('/model.json') || name.endsWith('/data.pkl')); + if (!name) { return null; } let model = null; - if (entry.name.endsWith('.json')) { + if (name.endsWith('.json')) { try { + const stream = entries.get(name); + const buffer = stream.peek(); const decoder = new TextDecoder('utf-8'); - const text = decoder.decode(entry.data); + const text = decoder.decode(buffer); model = JSON.parse(text); if (!model.mainModule) { return null; @@ -2182,17 +2185,17 @@ pytorch.Container.Zip = class { return null; } } - return new pytorch.Container.Zip(entries, entry, model, metadata, exception); + return new pytorch.Container.Zip(entries, name, model, metadata, exception); } - constructor(entries, entry, model, metadata, exception) { + constructor(entries, name, model, metadata, exception) { this._entries = entries; this._metadata = metadata; this._exceptionCallback = exception; // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md this._model = model; - const lastIndex = entry.name.lastIndexOf('/'); - this._prefix = lastIndex === -1 ? '' : entry.name.substring(0, lastIndex + 1); + const lastIndex = name.lastIndexOf('/'); + this._prefix = lastIndex === -1 ? '' : name.substring(0, lastIndex + 1); } get format() { @@ -2201,9 +2204,9 @@ pytorch.Container.Zip = class { this._format = this._entry('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0'; } else if (this._entry('data.pkl')) { - const versionEntry = this._entry('version'); + const stream = this._entry('version'); const decoder = new TextDecoder('utf-8'); - const versionNumber = versionEntry ? decoder.decode(versionEntry.data).split('\n').shift() : ''; + const versionNumber = stream ? decoder.decode(stream.peek()).split('\n').shift() : ''; // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h // kProducedFileFormatVersion const versionTable = { @@ -2249,9 +2252,10 @@ pytorch.Container.Zip = class { get constants() { if (this._constants === undefined) { this._constants = []; - const entry = this._entry('constants.pkl'); - if (entry && entry.data) { - this._constants = this._unpickle(entry.data, this._storage('constants')); + const stream = this._entry('constants.pkl'); + if (stream) { + const buffer = stream.peek(); + this._constants = this._unpickle(buffer, this._storage('constants')); for (let i = 0; i < this._constants.length; i++) { const constant = this._constants[i]; const variable = 'CONSTANTS.c' + i.toString(); @@ -2286,12 +2290,15 @@ pytorch.Container.Zip = class { if (this._execution === undefined) { const sources = new Map(); for (const entry of this._entries) { - if (entry.name.startsWith(this._prefix + 'code')) { - const file = entry.name.substring(this._prefix.length); + const name = entry[0]; + if (name.startsWith(this._prefix + 'code')) { + const file = name.substring(this._prefix.length); if (sources.has(file)) { throw new pytorch.Error("Duplicate source file '" + file + "'."); } - sources.set(file, entry.data); + const stream = entry[1]; + const buffer = stream.peek(); + sources.set(file, buffer); } } this._execution = new pytorch.Container.Zip.Execution(sources, this._exceptionCallback, this._metadata); @@ -2305,15 +2312,16 @@ pytorch.Container.Zip = class { } _entry(name) { - return this._entries.find((entry) => entry.name == this._prefix + name); + return this._entries.get(this._prefix + name); } _load() { if (this._data === undefined) { this._data = null; - const dataEntry = this._entry('data.pkl'); - if (dataEntry && dataEntry.data) { - this._data = this._unpickle(dataEntry.data, this._storage('data')); + const stream = this._entry('data.pkl'); + if (stream) { + const buffer = stream.peek(); + this._data = this._unpickle(buffer, this._storage('data')); } else { if (this._model) { @@ -2326,7 +2334,10 @@ pytorch.Container.Zip = class { const queue = [ this._data ]; const entries = new Map(); for (const entry of this._entries) { - entries.set(entry.name, entry.data); + const name = entry[0]; + const stream = entry[1]; + const buffer = stream.peek(); + entries.set(name, buffer); } const tensorTypeMap = new Map([ [ 'FLOAT', 'Float' ], @@ -2358,9 +2369,11 @@ pytorch.Container.Zip = class { return tensor; }); this._attributes = []; - const attributesEntry = this._entry('attributes.pkl'); - if (attributesEntry && attributesEntry.data) { - this._attributes.push(...new python.Unpickler(attributesEntry.data).load((name, args) => this.execution.invoke(name, args))); + const stream = this._entry('attributes.pkl'); + if (stream) { + const buffer = stream.peek(); + const unpickler = new python.Unpickler(buffer); + this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args))); } while (queue.length > 0) { const module = queue.shift(); @@ -2477,9 +2490,10 @@ pytorch.Container.Zip = class { const map = new Map(); const prefix = this._prefix + dirname + '/'; for (const entry of this._entries) { - if (entry.name.startsWith(prefix)) { - const key = entry.name.substring(prefix.length); - map.set(key, entry.data); + if (entry[0].startsWith(prefix)) { + const key = entry[0].substring(prefix.length); + const buffer = entry[1].peek(); + map.set(key, buffer); } } return map; diff --git a/source/tar.js b/source/tar.js index fcd4c14f0cc..ba51969c286 100644 --- a/source/tar.js +++ b/source/tar.js @@ -88,10 +88,6 @@ tar.Entry = class { get stream() { return this._stream; } - - get data() { - return this.stream.peek(); - } }; tar.BinaryReader = class { diff --git a/source/view.js b/source/view.js index c4fef80e4a8..2f5aec89c85 100644 --- a/source/view.js +++ b/source/view.js @@ -419,7 +419,7 @@ view.View = class { } open(context) { - this._host.event('Model', 'Open', 'Size', context.stream.length); + this._host.event('Model', 'Open', 'Size', context.stream ? context.stream.length : 0); this._sidebar.close(); return this._timeout(2).then(() => { return this._modelFactoryService.open(context).then((model) => { @@ -1186,11 +1186,11 @@ view.Edge = class extends grapher.Edge { view.ModelContext = class { - constructor(context, entries) { + constructor(context, formats) { this._context = context; this._tags = new Map(); this._content = new Map(); - this._entries = entries || new Map(); + this._formats = formats || new Map(); } get identifier() { @@ -1214,7 +1214,7 @@ view.ModelContext = class { } entries(format) { - return this._entries.get(format) || []; + return this._formats.get(format) || new Map(); } open(type) { @@ -1289,73 +1289,75 @@ view.ModelContext = class { [ 0x50, 0x4b ] ]; const stream = this.stream; - if (!signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value))) { - try { - switch (type) { - case 'pbtxt': { - reset = true; - const buffer = stream.peek(); - const decoder = base.TextDecoder.create(buffer); - let count = 0; - for (let i = 0; i < 0x100; i++) { - const c = decoder.decode(); - switch (c) { - case '\n': case '\r': case '\t': case '\0': break; - case undefined: i = 0x100; break; - default: count += c < ' ' ? 1 : 0; break; - } - } - if (count < 4) { + if (stream) { + if (!signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value))) { + try { + switch (type) { + case 'pbtxt': { + reset = true; const buffer = stream.peek(); - const reader = protobuf.TextReader.create(buffer); - reader.start(false); - while (!reader.end(false)) { - const tag = reader.tag(); - tags.set(tag, true); - if (reader.token() === '{') { - reader.start(); - while (!reader.end()) { - const subtag = reader.tag(); - tags.set(tag + '.' + subtag, true); + const decoder = base.TextDecoder.create(buffer); + let count = 0; + for (let i = 0; i < 0x100; i++) { + const c = decoder.decode(); + switch (c) { + case '\n': case '\r': case '\t': case '\0': break; + case undefined: i = 0x100; break; + default: count += c < ' ' ? 1 : 0; break; + } + } + if (count < 4) { + const buffer = stream.peek(); + const reader = protobuf.TextReader.create(buffer); + reader.start(false); + while (!reader.end(false)) { + const tag = reader.tag(); + tags.set(tag, true); + if (reader.token() === '{') { + reader.start(); + while (!reader.end()) { + const subtag = reader.tag(); + tags.set(tag + '.' + subtag, true); + reader.skip(); + reader.match(','); + } + } + else { reader.skip(); - reader.match(','); } } - else { - reader.skip(); - } } + break; } - break; - } - case 'pb': { - reset = true; - const buffer = stream.peek(); - const reader = protobuf.Reader.create(buffer); - const length = reader.length; - while (reader.position < length) { - const tag = reader.uint32(); - const number = tag >>> 3; - const type = tag & 7; - if (type > 5 || number === 0) { - tags = new Map(); - break; - } - tags.set(number, type); - try { - reader.skipType(type); - } - catch (err) { - tags = new Map(); - break; + case 'pb': { + reset = true; + const buffer = stream.peek(); + const reader = protobuf.Reader.create(buffer); + const length = reader.length; + while (reader.position < length) { + const tag = reader.uint32(); + const number = tag >>> 3; + const type = tag & 7; + if (type > 5 || number === 0) { + tags = new Map(); + break; + } + tags.set(number, type); + try { + reader.skipType(type); + } + catch (err) { + tags = new Map(); + break; + } } + break; } - break; } } - } - catch (error) { - tags = new Map(); + catch (error) { + tags = new Map(); + } } } if (reset) { @@ -1371,12 +1373,12 @@ view.ArchiveContext = class { constructor(host, entries, rootFolder, identifier, stream) { this._host = host; - this._entries = {}; + this._entries = new Map(); if (entries) { for (const entry of entries) { - if (entry.name.startsWith(rootFolder)) { - const name = entry.name.substring(rootFolder.length); - this._entries[name] = entry; + if (entry[0].startsWith(rootFolder)) { + const name = entry[0].substring(rootFolder.length); + this._entries.set(name, entry[1]); } } } @@ -1394,11 +1396,17 @@ view.ArchiveContext = class { request(file, encoding, base) { if (base === undefined) { - const entry = this._entries[file]; - if (!entry) { + const stream = this._entries.get(file); + if (!stream) { return Promise.reject(new Error('File not found.')); } - return Promise.resolve(encoding ? new TextDecoder(encoding).decode(entry.data) : entry.stream); + if (encoding) { + const decoder = new TextDecoder(encoding); + const buffer = stream.peek(); + const value = decoder.decode(buffer); + return Promise.resolve(value); + } + return Promise.resolve(stream); } return this._host.request(file, encoding, base); } @@ -1428,7 +1436,7 @@ view.ModelFactoryService = class { this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel' ]); this.register('./onnx', [ '.onnx', '.onn', '.pb', '.pbtxt', '.prototxt', '.model', '.pt', '.pth', '.pkl' ]); this.register('./mxnet', [ '.json', '.params' ]); - this.register('./coreml', [ '.mlmodel', 'manifest.json', 'metadata.json', 'featuredescriptions.json' ]); + this.register('./coreml', [ '.mlmodel', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.mlpackage' ]); this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt' ]); this.register('./caffe2', [ '.pb', '.pbtxt', '.prototxt' ]); this.register('./torch', [ '.t7' ]); @@ -1475,14 +1483,56 @@ view.ModelFactoryService = class { open(context) { return this._openSignature(context).then((context) => { - const entries = this._openArchive(context); - const modelContext = new view.ModelContext(context, entries); + const containers = new Map(); + let stream = context.stream; + const entries = context.entries; + if (!stream && entries && entries.size > 0) { + containers.set('', entries); + } + else { + const identifier = context.identifier; + try { + const archive = gzip.Archive.open(stream); + if (archive) { + const entries = new Map(archive.entries.map((entry) => [ entry.name, entry.stream ])); + containers.set('gzip', entries); + if (archive.entries.length === 1) { + const entry = archive.entries[0]; + stream = entry.stream; + } + } + } + catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'."); + } + try { + const formats = new Map([ [ 'zip', zip ], [ 'tar', tar ] ]); + for (const pair of formats) { + const format = pair[0]; + const module = pair[1]; + const archive = module.Archive.open(stream); + if (archive) { + const entries = new Map(archive.entries.map((entry) => [ entry.name, entry.stream ])); + containers.set(format, entries); + containers.delete('gzip'); + break; + } + } + } + catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'."); + } + } + + const modelContext = new view.ModelContext(context, containers); return this._openContext(modelContext).then((model) => { if (model) { return model; } - if (entries.size > 0) { - return this._openEntries(entries.values().next().value).then((context) => { + if (containers.size > 0) { + return this._openEntries(containers.values().next().value).then((context) => { if (context) { return this._openContext(context); } @@ -1597,44 +1647,6 @@ view.ModelFactoryService = class { throw new view.Error("Unsupported file content " + content + " for extension '." + extension + "' in '" + identifier + "'.", !skip); } - _openArchive(context) { - const entries = new Map(); - let stream = context.stream; - const identifier = context.identifier; - try { - const archive = gzip.Archive.open(stream); - if (archive) { - entries.set('gzip', archive.entries); - if (archive.entries.length === 1) { - const entry = archive.entries[0]; - stream = entry.stream; - } - } - } - catch (error) { - const message = error && error.message ? error.message : error.toString(); - throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'."); - } - try { - const formats = new Map([ [ 'zip', zip ], [ 'tar', tar ] ]); - for (const pair of formats) { - const format = pair[0]; - const module = pair[1]; - const archive = module.Archive.open(stream); - if (archive) { - entries.set(format, archive.entries); - entries.delete('gzip'); - break; - } - } - } - catch (error) { - const message = error && error.message ? error.message : error.toString(); - throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'."); - } - return entries; - } - _openContext(context) { const modules = this._filter(context).filter((module) => module && module.length > 0); const errors = []; @@ -1750,7 +1762,10 @@ view.ModelFactoryService = class { }; return nextEntry(); }; - const files = entries.filter((entry) => { + const list = Array.from(entries).map((entry) => { + return { name: entry[0], stream: entry[1] }; + }); + const files = list.filter((entry) => { if (entry.name.endsWith('/')) { return false; } @@ -1811,45 +1826,47 @@ view.ModelFactoryService = class { _openSignature(context) { const stream = context.stream; - let empty = true; - let position = 0; - while (empty && position < stream.length) { - const buffer = stream.read(Math.min(4096, stream.length - position)); - position += buffer.length; - if (!buffer.every((value) => value === 0x00)) { - empty = false; - break; + if (stream) { + let empty = true; + let position = 0; + while (empty && position < stream.length) { + const buffer = stream.read(Math.min(4096, stream.length - position)); + position += buffer.length; + if (!buffer.every((value) => value === 0x00)) { + empty = false; + break; + } } - } - stream.seek(0); - if (empty) { - return Promise.reject(new view.Error('File has no content.', true)); - } - /* eslint-disable no-control-regex */ - const entries = [ - { name: 'ELF executable', value: /^\x7FELF/ }, - { name: 'PNG image', value: /^\x89PNG/ }, - { name: 'Git LFS header', value: /^version https:\/\/git-lfs.github.com/ }, - { name: 'Git LFS header', value: /^\s*oid sha256:/ }, - { name: 'HTML markup', value: /^\s*/ }, - { name: 'HTML markup', value: /^\s*/ }, - { name: 'HTML markup', value: /^\s*/ }, - { name: 'HTML markup', value: /^\s*/ }, - { name: 'HTML markup', value: /^\s*/ }, + { name: 'HTML markup', value: /^\s*/ }, + { name: 'HTML markup', value: /^\s*/ }, + { name: 'HTML markup', value: /^\s*/ }, + { name: 'HTML markup', value: /^\s* entry.name == file)[0]; - if (!entry) { - throw new Error("Entry not found '" + file + '. Archive contains entries: ' + JSON.stringify(archive.entries.map((entry) => entry.name)) + " ."); + process.stdout.write(' write ' + item + '\n'); + if (item !== '.') { + const entry = archive.entries.filter((entry) => entry.name == item)[0]; + if (!entry) { + throw new Error("Entry not found '" + item + '. Archive contains entries: ' + JSON.stringify(archive.entries.map((entry) => entry.name)) + " ."); + } + const target = targets.shift(); + const buffer = entry.stream.peek(); + const file = path.join(folder, target); + fs.writeFileSync(file, buffer, null); + } + else { + const target = targets.shift(); + const dir = path.join(folder, target); + if (!fs.existsSync(dir)) { + fs.mkdirSync(dir); + } } - const target = targets.shift(); - fs.writeFileSync(folder + '/' + target, entry.data, null); } } else { @@ -547,15 +568,35 @@ function loadModel(target, item) { host.on('exception', (_, data) => { exceptions.push(data.exception); }); - const folder = path.dirname(target); const identifier = path.basename(target); - const size = fs.statSync(target).size; - const buffer = new Uint8Array(size); - const fd = fs.openSync(target, 'r'); - fs.readSync(fd, buffer, 0, size, 0); - fs.closeSync(fd); - const reader = new TestBinaryStream(buffer); - const context = new TestContext(host, folder, identifier, reader); + const stat = fs.statSync(target); + let context = null; + if (stat.isFile()) { + const buffer = fs.readFileSync(target, null); + const reader = new TestBinaryStream(buffer); + const dirname = path.dirname(target); + context = new TestContext(host, dirname, identifier, reader); + } + else if (stat.isDirectory()) { + const entries = new Map(); + const walk = (dir) => { + for (const item of fs.readdirSync(dir)) { + const pathname = path.join(dir, item); + const stat = fs.statSync(pathname); + if (stat.isDirectory()) { + walk(pathname); + } + else if (stat.isFile()) { + const buffer = fs.readFileSync(pathname, null); + const stream = new TestBinaryStream(buffer); + const name = pathname.split(path.sep).join(path.posix.sep); + entries.set(name, stream); + } + } + }; + walk(target); + context = new TestContext(host, target, identifier, null, entries); + } const modelFactoryService = new view.ModelFactoryService(host); let opened = false; return modelFactoryService.open(context).then((model) => { diff --git a/test/models.json b/test/models.json index 41046934042..7c4891d0d22 100644 --- a/test/models.json +++ b/test/models.json @@ -1100,6 +1100,13 @@ "format": "Core ML v3", "link": "https://developer.apple.com/machine-learning/models" }, + { + "type": "coreml", + "target": "EfficientNetB0.mlpackage,EfficientNetB0.mlpackage/Manifest.json,EfficientNetB0.mlpackage/Data/com.apple.CoreML/model.mlmodel,EfficientNetB0.mlpackage/Data/com.apple.CoreML/weights/weight.bin", + "source": "https://github.com/lutzroeder/netron/files/6636195/EfficientNetB0.mlpackage.zip[.,EfficientNetB0.mlpackage/Manifest.json,EfficientNetB0.mlpackage/Data/com.apple.CoreML/model.mlmodel,EfficientNetB0.mlpackage/Data/com.apple.CoreML/weights/weight.bin]", + "format": "Core ML Package v6", + "link": "https://github.com/lutzroeder/netron/issues/751" + }, { "type": "coreml", "target": "EfficientNetB0.mlpackage.zip",