Skip to content

Commit

Permalink
Update pytorch.js (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jan 9, 2023
1 parent fd65c7d commit 5052781
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 76 deletions.
8 changes: 5 additions & 3 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -3862,7 +3862,9 @@ python.Execution = class {
if (value instanceof torch.nn.modules.module.Module) {
this._modules.set(name, value);
}
this[name] = value;
else {
this[name] = value;
}
}
__getattr__(name) {
if (this._modules.has(name)) {
Expand Down Expand Up @@ -5465,8 +5467,8 @@ python.Execution = class {
if (buffer) {
const debug = this.debug(file);
const code = this._utf8Decoder.decode(buffer);
const reader = new python.Parser(code, file, debug);
const program = reader.parse();
const parser = new python.Parser(code, file, debug);
const program = parser.parse();
if (!program) {
throw new python.Error("Module '" + file + "' parse error.");
}
Expand Down
178 changes: 105 additions & 73 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ pytorch.Execution = class extends python.Execution {
super(sources);
const execution = this;
const torch = this.register('torch');
const pickle = this.register('pickle');
this.register('torch.jit._script');
this.register('torch.jit._trace');
this.registerType('torch.package.PackageImporter', class {
Expand All @@ -1144,7 +1145,7 @@ pytorch.Execution = class extends python.Execution {
const stream = this.zip_reader.getRecord(name);
const loaded_reduces = new Map();
this.storage_context = new torch._C.DeserializationStorageContext();
const unpickler = execution.invoke('pickle.Unpickler', [ stream ]);
const unpickler = new pickle.Unpickler(stream);
unpickler.persistent_load = (saved_id) => {
const typename = saved_id.shift();
switch (typename) {
Expand Down Expand Up @@ -1241,7 +1242,7 @@ pytorch.Execution = class extends python.Execution {
});
this.registerFunction('torch.jit._script.wrap_cpp_module', function(cpp_module) {
const init_fn = (script_module) => {
for (const entry of new torch.ModuleDict(script_module._c)) {
for (const entry of new torch.ModuleDict(script_module._c).items()) {
script_module.__setattr__(entry[0], torch.jit._script.wrap_cpp_module(entry[1]));
}
};
Expand Down Expand Up @@ -1286,6 +1287,9 @@ pytorch.Execution = class extends python.Execution {
hasConstant(/* name */) {
// TODO
}
methods() {
// TODO
}
});
this.registerType('torch.TupleType', class extends torch.Type {
constructor(/* elements, name, schema */) {
Expand All @@ -1300,40 +1304,36 @@ pytorch.Execution = class extends python.Execution {
}
});
this.registerType('torch.ScriptMethod', class {
constructor(owner, fn) {
this._function = fn;
constructor(owner, value) {
this._owner = owner;
this._function = value;
}
get name() {
return this._function.name();
}
get owner() {
return this._owner;
}
__call__(/* args, kwargs */) {
throw new pytorch.Error();
}
graph() {
get graph() {
return this._function.graph();
}
get schema() {
// return this.function().getSchema();
throw new pytorch.Error();
}
get name() {
return this._function.name();
}
get code() {
throw new pytorch.Error();
}
get code_with_constants() {
throw new pytorch.Error();
}
get owner() {
throw new pytorch.Error();
}
});
this.registerType('torch.ScriptObject', class {
constructor(obj) {
if (obj instanceof torch.ClassType) {
this._type = obj;
}
else {
this.data = obj;
}
constructor(type) {
this._type = type;
}
static create(type) {
if (type.is_module()) {
Expand All @@ -1342,21 +1342,19 @@ pytorch.Execution = class extends python.Execution {
return new torch.ScriptObject(type);
}
_type() {
if (!this._type) {
const qualified_name = this.data && this.data.__class__ && this.data.__class__.__module__ && this.data.__class__.__name__ ? this.data.__class__.__module__ + '.' + this.data.__class__.__name__ : '';
return new torch.ClassType(qualified_name);
}
return this._type;
}
_get_method(/* name */) {
throw new pytorch.Error();
_get_method(name) {
for (const method of this._type.methods()) {
if (name == method.name) {
return method;
}
}
return null;
}
_has_method(/* name */) {
throw new pytorch.Error();
}
_method_names(/* name */) {
throw new pytorch.Error();
}
__setattr__(name, value) {
// TODO if (this._type.hasContant(name))
this[name] = value;
Expand All @@ -1372,8 +1370,8 @@ pytorch.Execution = class extends python.Execution {
}
});
this.registerType('torch.ScriptModule', class extends torch.ScriptObject {
constructor(obj) {
super(obj);
constructor(type) {
super(type);
}
get qualified_name() {
return this._type.qualified_name();
Expand Down Expand Up @@ -1494,14 +1492,12 @@ pytorch.Execution = class extends python.Execution {
return this._graph;
}
});
this.registerType('torch.ModuleDict', class extends Map {
this.registerType('torch.ModuleDict', class {
constructor(module) {
super();
for (const entry of Object.entries(module)) {
if (entry[1] instanceof torch.ScriptModule) {
this.set(entry[0], entry[1]);
}
}
this._items = Object.entries(module).filter((entry) => entry[1] instanceof torch.ScriptModule);
}
items() {
return this._items;
}
});
this.registerType('torch.jit.CompilationUnit', class {
Expand Down Expand Up @@ -2778,6 +2774,30 @@ pytorch.jit.Execution = class extends pytorch.Execution {
}
};

pytorch.jit.Source = class {

constructor(text) {
this._text = text;
}
};

pytorch.jit.SourceLoader = class {

constructor(reader, code_prefix) {
this._reader = reader;
this._code_prefix = code_prefix;
}

loadSource(qualifier) {
const path = this._code_prefix + '/' + qualifier + '.py';
if (this._reader.hasRecord(path)) {
const data = this._reader.getRecord(path);
return new pytorch.jit.Source(data);
}
return null;
}
};

pytorch.jit.SourceImporter = class {

constructor(cu, constant_table, source_loader, version) {
Expand All @@ -2790,6 +2810,19 @@ pytorch.jit.SourceImporter = class {
loadType(/* name */) {
// TODO;
}

resolveType(name) {
return this.findNamedType(new pytorch.jit.QualifiedName(name));
}

findNamedType(name) {
// TODO
this.parseSourceIfNeeded(name.prefix());
}

parseSourceIfNeeded(/* qualifier */) {
// TODO
}
};

pytorch.jit.ScriptModuleDeserializer = class {
Expand All @@ -2798,23 +2831,22 @@ pytorch.jit.ScriptModuleDeserializer = class {
this._compilation_unit = cu;
this._reader = reader;
this._storage_context = storage_context;
this._code_prefix = (!pickle_dir_prefix && !tensor_dir_prefix) ? 'code/' : '.data/ts_code/code/';
this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
this._pickle_dir_prefix = pickle_dir_prefix || '';
this._tensor_dir_prefix = tensor_dir_prefix || '';
this._source_importer = new pytorch.jit.SourceImporter(
this._compilation_unit, this._constants_table,
null, // (qualifier) => findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier),
new pytorch.jit.SourceLoader(this._reader, this._code_prefix),
reader.version());
}

deserialize() {
const execution = this._compilation_unit.execution;
const reader = this._reader;
const code_prefix = this._code_prefix;
for (const name of reader.getAllRecords()) {
for (const name of this._reader.getAllRecords()) {
if (name.startsWith(code_prefix) && name.endsWith('.py')) {
const file = name.substring(code_prefix.length);
const stream = reader.getRecord(name);
const stream = this._reader.getRecord(name);
const buffer = stream.peek();
execution.add(file, buffer);
}
Expand All @@ -2825,39 +2857,23 @@ pytorch.jit.ScriptModuleDeserializer = class {
execution.builtins.ops = torch.ops;
execution.builtins.inf = torch.inf;
execution.builtins.CONSTANTS = {};
if (reader.hasRecord('constants.pkl')) {
const stream = reader.getRecord('constants.pkl');
const buffer = stream.peek();
const read_record = (name) => {
const tensor_dir = 'constants/';
const stream = reader.getRecord(tensor_dir + name);
return stream.length <= 0x40000 ? stream.peek() : stream;
};
const constants = this._unpickle(buffer, read_record, this._storage_context);
for (let i = 0; i < constants.length; i++) {
execution.builtins.CONSTANTS['c' + i.toString()] = constants[i];
}
}
if (this._reader.hasRecord('model.json')) {
return this.LEGACY_deserialize();
}

const pickle_dir_prefix = this._pickle_dir_prefix || '';
const tensor_dir_prefix = this._tensor_dir_prefix || 'data/';
const stream = reader.getRecord(pickle_dir_prefix + 'data.pkl');
const buffer = stream.peek();
const read_record = (name) => {
const stream = reader.getRecord(tensor_dir_prefix + name);
return stream.length <= 0x40000 ? stream.peek() : stream;
};
const data = this._unpickle(buffer, read_record, this._storage_context);
return execution.invoke('torch.ScriptModule', [ data ]);
const constants = this.readArchive('constants');
for (let i = 0; i < constants.length; i++) {
execution.builtins.CONSTANTS['c' + i.toString()] = constants[i];
}
const module = this.readArchive('data');
const result = new torch.ScriptModule();
result.data = module;
return result;
}

LEGACY_deserialize() {
const execution = this._compilation_unit.execution;
const reader = this._reader;
const stream = reader.getRecord('model.json');
const torch = execution.import('torch');
const stream = this._reader.getRecord('model.json');
const buffer = stream.peek();
const decoder = new TextDecoder('utf-8');
const content = decoder.decode(buffer);
Expand All @@ -2884,7 +2900,7 @@ pytorch.jit.ScriptModuleDeserializer = class {
const offset = parseInt(constant.offset, 10) || 0;
const storage = new storage_type([ size ]);
const itemsize = storage.dtype.itemsize();
const stream = reader.getRecord(key);
const stream = this._reader.getRecord(key);
const buffer = stream.peek();
const length = size * itemsize;
const data = buffer.slice(offset, offset + length);
Expand All @@ -2898,8 +2914,8 @@ pytorch.jit.ScriptModuleDeserializer = class {
execution.builtins.CONSTANTS['c' + i.toString()] = constants[i];
}
const attributes = [];
if (reader.hasRecord('attributes.pkl')) {
const stream = reader.getRecord('attributes.pkl');
if (this._reader.hasRecord('attributes.pkl')) {
const stream = this._reader.getRecord('attributes.pkl');
const buffer = stream.peek();
const unpickler = execution.invoke('pickle.Unpickler', [ buffer ]);
const obj = unpickler.load();
Expand Down Expand Up @@ -2947,18 +2963,34 @@ pytorch.jit.ScriptModuleDeserializer = class {
data.forward = module.forward;
}
}
return execution.invoke('torch.ScriptModule', [ data ]);
const result = new torch.ScriptModule();
result.data = data;
return result;
}

readArchive(archive_name) {
const type_resolver = null;
const obj_loader = null;
return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, obj_loader, this._device, this._reader, null, this._storage_context);
}

_unpickle(data, read_record, storage_context) {
readArchiveAndTensors(archive_name, pickle_prefix, tensor_prefix, type_resolver, obj_loader, device, stream_reader, type_parser, storage_context) {
const picklename = pickle_prefix + archive_name + ".pkl";
const stream = stream_reader.getRecord(picklename);
const buffer = stream.peek();
const tensor_dir_path = tensor_prefix ? tensor_prefix : archive_name + '/';
const read_record = (name) => {
const stream = stream_reader.getRecord(tensor_dir_path + name);
return stream.length <= 0x40000 ? stream.peek() : stream;
};
const execution = this._compilation_unit.execution;
const pickle = execution.__import__('pickle');
const Unpickler = class extends pickle.Unpickler {
find_class(module, name) {
return super.find_class(module, name);
}
};
const unpickler = new Unpickler(data);
const unpickler = new Unpickler(buffer);
unpickler.persistent_load = (saved_id) => {
const typename = saved_id[0];
if (typename !== 'storage') {
Expand Down

0 comments on commit 5052781

Please sign in to comment.