Skip to content

Commit

Permalink
Folder support (#751)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 13, 2021
1 parent d2d86bf commit 43ebb5f
Show file tree
Hide file tree
Showing 15 changed files with 380 additions and 271 deletions.
8 changes: 4 additions & 4 deletions source/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Application {
const paths = data.paths.filter((path) => {
if (fs.existsSync(path)) {
const stat = fs.statSync(path);
return stat.isFile() /* || stat.isDirectory() */;
return stat.isFile() || stat.isDirectory();
}
return false;
});
Expand Down Expand Up @@ -109,7 +109,7 @@ class Application {
const extension = arg.split('.').pop().toLowerCase();
if (extension != '' && extension != 'js' && fs.existsSync(arg)) {
const stat = fs.statSync(arg);
if (stat.isFile() /* || stat.isDirectory() */) {
if (stat.isFile() || stat.isDirectory()) {
this._openPath(arg);
open = true;
}
Expand Down Expand Up @@ -191,7 +191,7 @@ class Application {
}
if (path && path.length > 0 && fs.existsSync(path)) {
const stat = fs.statSync(path);
if (stat.isFile() /* || stat.isDirectory() */) {
if (stat.isFile() || stat.isDirectory()) {
// find existing view for this file
let view = this._views.find(path);
// find empty welcome window
Expand Down Expand Up @@ -390,7 +390,7 @@ class Application {
const path = recent.path;
if (fs.existsSync(path)) {
const stat = fs.statSync(path);
if (stat.isFile() /* || stat.isDirectory() */) {
if (stat.isFile() || stat.isDirectory()) {
return true;
}
}
Expand Down
10 changes: 5 additions & 5 deletions source/dl4j.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ dl4j.ModelFactory = class {
}

static _openContainer(entries) {
const configurationEntries = entries.filter((entry) => entry.name === 'configuration.json');
const coefficientsEntries = entries.filter((entry) => entry.name === 'coefficients.bin');
if (configurationEntries.length === 1 && coefficientsEntries.length <= 1) {
const configurationStream = entries.get('configuration.json');
const coefficientsStream = entries.get('coefficients.bin');
if (configurationStream) {
try {
const reader = json.TextReader.create(configurationEntries[0].data);
const reader = json.TextReader.create(configurationStream.peek());
const configuration = reader.read();
if (configuration && (configuration.confs || configuration.vertices)) {
const coefficients = coefficientsEntries.length == 1 ? coefficientsEntries[0].data : [];
const coefficients = coefficientsStream ? coefficientsStream.peek() : [];
return { configuration: configuration, coefficients: coefficients };
}
}
Expand Down
2 changes: 1 addition & 1 deletion source/dlc.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dlc.ModelFactory = class {

match(context) {
const entries = context.entries('zip');
if (entries.find((entry) => entry.name === 'model')) {
if (entries.has('model')) {
return true;
}
return false;
Expand Down
67 changes: 50 additions & 17 deletions source/electron.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ host.ElectronHost = class {
initialize(view) {
this._view = view;
electron.ipcRenderer.on('open', (_, data) => {
this._openFile(data.path);
this._openPath(data.path);
});
return new Promise((resolve /*, reject */) => {
const accept = () => {
Expand Down Expand Up @@ -117,8 +117,8 @@ host.ElectronHost = class {
const queue = this._queue;
delete this._queue;
if (queue.length > 0) {
const file = queue.pop();
this._openFile(file);
const path = queue.pop();
this._openPath(path);
}
}

Expand Down Expand Up @@ -281,17 +281,17 @@ host.ElectronHost = class {
request(file, encoding, base) {
return new Promise((resolve, reject) => {
const pathname = path.join(base || __dirname, file);
fs.stat(pathname, (err, stats) => {
fs.stat(pathname, (err, stat) => {
if (err && err.code === 'ENOENT') {
reject(new Error("The file '" + file + "' does not exist."));
}
else if (err) {
reject(err);
}
else if (!stats.isFile()) {
else if (!stat.isFile()) {
reject(new Error("The path '" + file + "' is not a file."));
}
else if (stats && stats.size < 0x7ffff000) {
else if (stat && stat.size < 0x7ffff000) {
fs.readFile(pathname, encoding, (err, data) => {
if (err) {
reject(err);
Expand All @@ -302,10 +302,10 @@ host.ElectronHost = class {
});
}
else if (encoding) {
reject(new Error("The file '" + file + "' size (" + stats.size.toString() + ") for encoding '" + encoding + "' is greater than 2 GB."));
reject(new Error("The file '" + file + "' size (" + stat.size.toString() + ") for encoding '" + encoding + "' is greater than 2 GB."));
}
else {
resolve(new host.ElectronHost.FileStream(pathname, 0, stats.size, stats.mtimeMs));
resolve(new host.ElectronHost.FileStream(pathname, 0, stat.size, stat.mtimeMs));
}
});
});
Expand Down Expand Up @@ -366,21 +366,49 @@ host.ElectronHost = class {
}
}

_openFile(file) {
_context(location) {
const basename = path.basename(location);
const stat = fs.statSync(location);
if (stat.isFile()) {
const dirname = path.dirname(location);
return this.request(basename, null, dirname).then((stream) => {
return new host.ElectronHost.ElectonContext(this, dirname, basename, stream);
});
}
else if (stat.isDirectory()) {
const entries = new Map();
const walk = (dir) => {
for (const item of fs.readdirSync(dir)) {
const pathname = path.join(dir, item);
const stat = fs.statSync(pathname);
if (stat.isDirectory()) {
walk(pathname);
}
else if (stat.isFile()) {
const stream = new host.ElectronHost.FileStream(pathname, 0, stat.size, stat.mtimeMs);
const name = pathname.split(path.sep).join(path.posix.sep);
entries.set(name, stream);
}
}
};
walk(location);
return Promise.resolve(new host.ElectronHost.ElectonContext(this, location, basename, null, entries));
}
throw new Error("Unsupported path stat '" + JSON.stringify(stat) + "'.");
}

_openPath(path) {
if (this._queue) {
this._queue.push(file);
this._queue.push(path);
return;
}
if (file && this._view.accept(file)) {
if (path && this._view.accept(path)) {
this._view.show('welcome spinner');
const dirname = path.dirname(file);
const basename = path.basename(file);
this.request(basename, null, dirname).then((stream) => {
const context = new host.ElectronHost.ElectonContext(this, dirname, basename, stream);
this._context(path).then((context) => {
this._view.open(context).then((model) => {
this._view.show(null);
if (model) {
this._update('path', file);
this._update('path', path);
}
this._update('show-attributes', this._view.showAttributes);
this._update('show-initializers', this._view.showInitializers);
Expand Down Expand Up @@ -683,11 +711,12 @@ host.ElectronHost.FileStream = class {

host.ElectronHost.ElectonContext = class {

constructor(host, folder, identifier, stream) {
constructor(host, folder, identifier, stream, entries) {
this._host = host;
this._folder = folder;
this._identifier = identifier;
this._stream = stream;
this._entries = entries || new Map();
}

get identifier() {
Expand All @@ -698,6 +727,10 @@ host.ElectronHost.ElectonContext = class {
return this._stream;
}

get entries() {
return this._entries;
}

request(file, encoding, base) {
return this._host.request(file, encoding, base === undefined ? this._folder : base);
}
Expand Down
4 changes: 0 additions & 4 deletions source/gzip.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ gzip.Entry = class {
get stream() {
return this._stream;
}

get data() {
return this.stream.peek();
}
};

gzip.InflaterStream = class {
Expand Down
32 changes: 20 additions & 12 deletions source/mlnet.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ mlnet.ModelFactory = class {

match(context) {
const entries = context.entries('zip');
if (entries.length > 0) {
if (entries.size > 0) {
const root = new Set([ 'TransformerChain', 'Predictor']);
if (entries.some((e) => root.has(e.name.split('\\').shift().split('/').shift()))) {
if (Array.from(entries.keys()).some((name) => root.has(name.split('\\').shift().split('/').shift()))) {
return true;
}
}
Expand All @@ -20,7 +20,8 @@ mlnet.ModelFactory = class {

open(context) {
return mlnet.Metadata.open(context).then((metadata) => {
const reader = new mlnet.ModelReader(context.entries('zip'));
const entries = context.entries('zip');
const reader = new mlnet.ModelReader(entries);
return new mlnet.Model(metadata, reader);
});
}
Expand Down Expand Up @@ -614,10 +615,11 @@ mlnet.ModelHeader = class {
open(name) {
const dir = this._directory.length > 0 ? this._directory + '/' : this._directory;
name = dir + name;
const entryName = name + '/Model.key';
const entry = this._entries.find((entry) => entry.name == entryName || entry.name == entryName.replace(/\//g, '\\'));
if (entry) {
const context = new mlnet.ModelHeader(this._catalog, this._entries, name, entry.data);
const key = name + '/Model.key';
const stream = this._entries.get(key) || this._entries.get(key.replace(/\//g, '\\'));
if (stream) {
const buffer = stream.peek();
const context = new mlnet.ModelHeader(this._catalog, this._entries, name, buffer);
const value = this._catalog.create(context.loaderSignature, context);
value.__type__ = value.__type__ || context.loaderSignature;
value.__name__ = name;
Expand All @@ -629,16 +631,22 @@ mlnet.ModelHeader = class {
openBinary(name) {
const dir = this._directory.length > 0 ? this._directory + '/' : this._directory;
name = dir + name;
const entry = this._entries.find((entry) => entry.name == name || entry.name == name.replace(/\//g, '\\'));
return entry ? new mlnet.Reader(entry.data) : null;
const stream = this._entries.get(name) || this._entries.get(name.replace(/\//g, '\\'));
if (stream) {
const buffer = stream.peek();
return new mlnet.Reader(buffer);
}
return null;
}

openText(name) {
const dir = this._directory.length > 0 ? this._directory + '/' : this._directory;
name = dir + name;
const entry = this._entries.find((entry) => entry.name.split('\\').join('/') == name);
if (entry) {
return new TextDecoder().decode(entry.data);
const stream = this._entries.get(name) || this._entries.get(name.replace(/\//g, '\\'));
if (stream) {
const buffer = stream.peek();
const decoder = new TextDecoder();
return decoder.decode(buffer);
}
return null;
}
Expand Down
14 changes: 8 additions & 6 deletions source/npz.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,22 @@ npz.ModelFactory = class {
case 'npz': {
format = 'NumPy Zip';
const execution = new python.Execution(null);
for (const entry of context.entries('zip')) {
if (!entry.name.endsWith('.npy')) {
const entries = context.entries('zip');
for (const entry of entries) {
if (!entry[0].endsWith('.npy')) {
throw new npz.Error("Invalid file name '" + entry.name + "'.");
}
const name = entry.name.replace(/\.npy$/, '');
const name = entry[0].replace(/\.npy$/, '');
const parts = name.split('/');
const parameterName = parts.pop();
const groupName = parts.join('/');
if (!groups.has(groupName)) {
groups.set(groupName, { name: groupName, parameters: [] });
}
const group = groups.get(groupName);
const data = entry.data;
let array = new numpy.Array(data);
const stream = entry[1];
const buffer = stream.peek();
let array = new numpy.Array(buffer);
if (array.byteOrder === '|' && array.dataType !== 'u1' && array.dataType !== 'i1') {
if (array.dataType !== 'O') {
throw new npz.Error("Invalid data type '" + array.dataType + "'.");
Expand Down Expand Up @@ -494,7 +496,7 @@ npz.Utility = class {
return 'npy';
}
const entries = context.entries('zip');
if (entries.length > 0 && entries.every((entry) => entry.name.endsWith('.npy'))) {
if (entries.size > 0 && Array.from(entries.keys()).every((name) => name.endsWith('.npy'))) {
return 'npz';
}
const obj = context.open('pkl');
Expand Down
16 changes: 9 additions & 7 deletions source/paddle.js
Original file line number Diff line number Diff line change
Expand Up @@ -728,10 +728,10 @@ paddle.Utility = class {
paddle.Container = class {

static open(context) {
const extension = [ 'zip', 'tar' ].find((extension) => context.entries(extension).length > 0);
const extension = [ 'zip', 'tar' ].find((extension) => context.entries(extension).size > 0);
if (extension) {
const entries = context.entries(extension).filter((entry) => !entry.name.endsWith('/') && !entry.name.split('/').pop().startsWith('.')).slice();
if (entries.length > 2 && entries.every((entry) => entry.name.split('_').length > 0 && entry.data.slice(0, 16).every((value) => value === 0x00))) {
const entries = new Map(Array.from(context.entries(extension)).filter((entry) => !entry[0].endsWith('/') && !entry[0].split('/').pop().startsWith('.')).slice());
if (entries.size > 2 && Array.from(entries).every((entry) => entry[0].split('_').length > 0 && entry[1].peek(16).every((value) => value === 0x00))) {
return new paddle.Container('entries', entries);
}
}
Expand Down Expand Up @@ -773,7 +773,7 @@ paddle.Container = class {
case 'entries': {
let rootFolder = null;
for (const entry of this._data) {
const name = entry.name;
const name = entry[0];
if (name.startsWith('.') && !name.startsWith('./')) {
continue;
}
Expand All @@ -783,9 +783,11 @@ paddle.Container = class {
}
this._weights = new Map();
for (const entry of this._data) {
if (entry.name.startsWith(rootFolder)) {
const name = entry.name.substring(rootFolder.length);
this._weights.set(name, new paddle.Tensor(null, entry.stream));
if (entry[0].startsWith(rootFolder)) {
const name = entry[0].substring(rootFolder.length);
const stream = entry[1];
const tensor = new paddle.Tensor(null, stream);
this._weights.set(name, tensor);
}
}
break;
Expand Down
Loading

0 comments on commit 43ebb5f

Please sign in to comment.