
Update pytorch.js (#1061)
lutzroeder committed Oct 27, 2024
1 parent b9d7b8a commit ee6fce1
Showing 3 changed files with 449 additions and 50 deletions.
170 changes: 157 additions & 13 deletions source/python.js
@@ -2444,9 +2444,7 @@ python.Execution = class {
}
flatten() {
const size = this.shape.reduce((a, b) => a * b, 1);
- const value = execution.invoke('numpy.ndarray', [
-     [size], this.dtype, this.data, this.offset, this.strides, this.order
- ]);
+ const value = new numpy.ndarray([size], this.dtype, this.data, this.offset, this.strides, this.order);
value.flags = this.flags;
return value;
}
@@ -5535,6 +5533,9 @@ python.Execution = class {
throw new python.Error("Unsupported 'torch.__isnot__' expression type.");
});
this.registerFunction('torch.__not__', (value) => {
+ if (Number.isInteger(value)) {
+     value = Boolean(value);
+ }
if (typeof value === 'boolean') {
return !value;
}
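
For review context, a standalone sketch of what this hunk changes (a mock of the registered torch.__not__ handler, not netron's API; the trailing error branch is an assumption based on the neighboring torch.__isnot__ registration):

// Mock handler; the final throw is assumed from the surrounding pattern (not shown in full above).
const torch__not__ = (value) => {
    if (Number.isInteger(value)) {
        value = Boolean(value); // new: 0 -> false, any non-zero integer -> true
    }
    if (typeof value === 'boolean') {
        return !value;
    }
    throw new Error("Unsupported 'torch.__not__' expression type.");
};

console.log(torch__not__(0));    // true  (previously unsupported)
console.log(torch__not__(2));    // false (previously unsupported)
console.log(torch__not__(true)); // false
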
@@ -7311,11 +7312,14 @@ python.Execution = class {
this.registerType('torch.Graph', class {
constructor() {
this._unique = 1;
- this._nodes = [];
- this._block = execution.invoke('torch.Block', [this]);
+ this._all_nodes = [];
+ this._all_values = [];
+ this._all_blocks = [];
+ this._block = new torch.Block(this);
+ this._insert_before = this.return_node();
}
create(kind) {
- return execution.invoke('torch.Node', [this, kind]);
+ return new torch.Node(this, kind);
}
inputs() {
return this._block.inputs();
@@ -7324,8 +7328,7 @@ python.Execution = class {
return this._block.outputs();
}
nodes() {
- return this._nodes;
- // return this._block.nodes();
+ return this._block.nodes();
}
param_node() {
return this._block.param_node();
@@ -7336,13 +7339,51 @@ python.Execution = class {
addInput(name) {
return this._block.addInput(name);
}
+ insertNode(node) {
+     node.insertBefore(this._insert_before);
+ }
+ insertPoint() {
+     return this._insert_before;
+ }
+ setInsertPoint(node) {
+     if (node instanceof torch.Block) {
+         node = node.return_node();
+     }
+     this._insert_before = node;
+ }
+ get all_nodes() {
+     return this._all_nodes;
+ }
+ freeNode(n) {
+     const index = this._all_nodes.indexOf(n);
+     if (index !== -1) {
+         this._all_nodes.splice(index, 1);
+     }
+ }
+ freeValue(v) {
+     v.setDebugName('');
+     const index = this._all_values.indexOf(v);
+     if (index !== -1) {
+         this._all_values.splice(index, 1);
+     }
+ }
+ freeBlock(b) {
+     const index = this._all_blocks.indexOf(b);
+     if (index !== -1) {
+         this._all_blocks.splice(index, 1);
+     }
+ }
});
this.registerType('torch.Block', class {
constructor(graph) {
this._unique = 1;
this._graph = graph;
this._input = graph.create('prim::Param');
this._output = graph.create('prim::Return');
+ this._input.next = this._output;
+ this._input.prev = this._output;
+ this._output.next = this._input;
+ this._output.prev = this._input;
}
param_node() {
return this._input;
@@ -7356,6 +7397,15 @@ python.Execution = class {
outputs() {
return this._output.inputs();
}
+ nodes() {
+     const nodes = [];
+     let current = this._input;
+     do {
+         nodes.push(current);
+         current = current.next;
+     } while (current !== this._input);
+     return nodes;
+ }
addInput(name) {
const value = this._input.addOutput();
value.setDebugName(name || '');
@@ -7365,16 +7415,30 @@ python.Execution = class {
this._output.addInput(value);
return this.outputs().length - 1;
}
+ destroy() {
+     this._output.removeAllInputs();
+     for (const n of this.nodes()) {
+         n.destroy();
+     }
+     this._output.destroy();
+     this._input.destroy();
+     this._graph.freeBlock(this);
+ }
});
this.registerType('torch.Node', class {
constructor(graph, kind) {
this._graph = graph;
- this._graph._nodes.push(this);
this._kind = kind;
this._values = new Map();
this._inputs = [];
this._outputs = [];
this._blocks = [];
+ this._graph.all_nodes.push(this);
+ this._prev = null;
+ this._next = null;
}
+ owningGraph() {
+     return this._graph;
+ }
kind() {
return this._kind;
@@ -7421,6 +7485,86 @@ python.Execution = class {
this._blocks.push(block);
return block;
}
+ get prev() {
+     return this._prev;
+ }
+ set prev(value) {
+     this._prev = value;
+ }
+ get next() {
+     return this._next;
+ }
+ set next(value) {
+     this._next = value;
+ }
+ insertBefore(n) {
+     this.insertAfter(n.prev);
+     return this;
+ }
+ insertAfter(n) {
+     // this.owning_block_ = n->owningBlock();
+     const next = n.next;
+     n.next = this;
+     this.prev = n;
+     this.next = next;
+     next.prev = this;
+     // assignTopoPosition();
+ }
+ dropInput(i) {
+     const input = this._inputs[i];
+     const uses = this._inputs[i].uses();
+     for (let i = uses.length - 1; i >= 0; i--) {
+         const use = uses[i];
+         if (use.user === this) {
+             uses.splice(i, 1);
+         }
+     }
+     this._inputs[i] = null;
+     return input;
+ }
+ eraseOutput(i) {
+     this._op = null;
+     const v = this._outputs[i];
+     this._outputs.splice(i, 1);
+     this.owningGraph().freeValue(v);
+ }
+ eraseBlock(i) {
+     this._op = null;
+     const n = this._blocks[i];
+     this._blocks.splice(i, 1);
+     n.destroy();
+ }
+ removeAllInputs() {
+     for (let i = this._inputs.length - 1; i >= 0; i--) {
+         this.dropInput(i);
+     }
+     this._inputs.splice(0, this._inputs.length);
+ }
+ inBlockList() {
+     return this.next !== null;
+ }
+ removeFromList() {
+     this._owning_block = null;
+     const next = this.next;
+     const prev = this.prev;
+     prev.next = next;
+     next.prev = prev;
+     this.next = null;
+     this.prev = null;
+ }
+ destroy() {
+     while (this.outputs().length > 0) {
+         this.eraseOutput(this.outputs().length - 1);
+     }
+     while (this.blocks().length > 0) {
+         this.eraseBlock(this.blocks().length - 1);
+     }
+     this.removeAllInputs();
+     if (this.inBlockList()) {
+         this.removeFromList();
+     }
+     this._graph.freeNode(this);
+ }
s_(name, value) {
this._values.set(name, [value, 's']);
}
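
For review context, a minimal standalone sketch of the circular, sentinel-based node list these additions implement (a mock, not the netron types; node kinds are illustrative):

// Mock of the doubly-linked node ring: prim::Param and prim::Return act as sentinels,
// and inserting before the return sentinel appends to the block.
class MockNode {
    constructor(kind) {
        this.kind = kind;
        this.prev = null;
        this.next = null;
    }
    insertAfter(n) { // splice this node in right after n, as Node.insertAfter above
        const next = n.next;
        n.next = this;
        this.prev = n;
        this.next = next;
        next.prev = this;
    }
    insertBefore(n) {
        this.insertAfter(n.prev);
        return this;
    }
    removeFromList() { // unlink, as Node.removeFromList above
        this.prev.next = this.next;
        this.next.prev = this.prev;
        this.next = null;
        this.prev = null;
    }
}

// Wire the sentinels into a ring, as the torch.Block constructor above does.
const param = new MockNode('prim::Param');
const ret = new MockNode('prim::Return');
param.next = ret;
param.prev = ret;
ret.next = param;
ret.prev = param;

// Graph.insertNode(node) resolves to node.insertBefore(return_node()), i.e. an append.
new MockNode('prim::Constant').insertBefore(ret);
const add = new MockNode('aten::add').insertBefore(ret);

// Walk the ring from the param sentinel, as Block.nodes() above does; note that
// both sentinels appear in the traversal.
const kinds = [];
let current = param;
do {
    kinds.push(current.kind);
    current = current.next;
} while (current !== param);
console.log(kinds.join(' -> ')); // prim::Param -> prim::Constant -> aten::add -> prim::Return

add.removeFromList(); // a destroyed node is unlinked the same way
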
@@ -7622,7 +7766,7 @@ python.Execution = class {
const data = buffer.slice(offset, offset + length);
storage._set_cdata(data);
}
- const tensor = execution.invoke('torch._utils._rebuild_tensor', [storage, 0, shape, strides]);
+ const tensor = torch._utils._rebuild_tensor(storage, 0, shape, strides);
tensor.name = constant.data.key;
return tensor;
});
@@ -7634,7 +7778,7 @@ python.Execution = class {
if (this._reader.has_record('attributes.pkl')) {
const stream = this._reader.get_record('attributes.pkl');
const buffer = stream.peek();
- const unpickler = execution.invoke('pickle.Unpickler', [buffer]);
+ const unpickler = new pickle.Unpickler(buffer);
const obj = unpickler.load();
attributes.push(...obj);
}
@@ -9360,7 +9504,7 @@ python.Execution = class {
this.registerFunction('torch._inductor.compile_fx.compile_fx');
this.registerFunction('torch_utils.persistence._reconstruct_persistent_obj', (meta) => {
const name = `_imported_module_${Math.floor(Math.random() * 10000)}`;
- const module = execution.invoke('types.ModuleType', [name]);
+ const module = new types.ModuleType(name);
execution.register('sys').modules.set(name, module);
const context = new python.Execution.Context(module, null);
execution.exec(meta.get('module_src'), context);
@@ -9903,7 +10047,7 @@ python.Execution = class {
this.registerType('fastai.basic_train.Learner', class {});
this.registerType('fastai.basic_train.Recorder', class {});
this.registerFunction('fastai.torch_core._fa_rebuild_tensor', (cls, ...args) => {
- const tensor = self.invoke('torch._utils._rebuild_tensor_v2', args);
+ const tensor = torch._utils._rebuild_tensor_v2(...args);
return self.invoke(cls, tensor);
});
this.registerFunction('fastai.torch_core.trainable_params');
27 changes: 18 additions & 9 deletions source/pytorch-metadata.json
@@ -423,13 +423,16 @@
"name": "aten::_dim_arange(Tensor like, int dim) -> Tensor"
},
{
"name": "aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> Tensor"
"name": "aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> Tensor",
"category": "Quantization"
},
{
"name": "aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1., *, Tensor(a!) out) -> Tensor(a!)"
"name": "aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1., *, Tensor(a!) out) -> Tensor(a!)",
"category": "Quantization"
},
{
"name": "aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> (Tensor, Tensor, Tensor)"
"name": "aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.) -> (Tensor, Tensor, Tensor)",
"category": "Quantization"
},
{
"name": "aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor"
@@ -2159,22 +2162,28 @@
"name": "aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)"
},
{
"name": "aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor"
"name": "aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor",
"category": "Quantization"
},
{
"name": "aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"
"name": "aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor",
"category": "Quantization"
},
{
"name": "aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"
"name": "aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor",
"category": "Quantization"
},
{
"name": "aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)"
"name": "aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)",
"category": "Quantization"
},
{
"name": "aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))"
"name": "aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
"category": "Quantization"
},
{
"name": "aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor"
"name": "aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor",
"category": "Quantization"
},
{
"name": "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
302 changes: 274 additions & 28 deletions source/pytorch.js
Large diffs are not rendered by default.
