Add TorchScript test file (#1061)
lutzroeder committed Dec 1, 2024
1 parent f17e406 commit 9096922
Showing 4 changed files with 316 additions and 171 deletions.
198 changes: 184 additions & 14 deletions source/python.js
@@ -6755,10 +6755,10 @@ python.Execution = class {
return this.kind() === rhs.kind();
}
isSubtypeOf(rhs) {
if (rhs.kind() === 'OptionalType') {
if (rhs.kind() === 'OptionalType' && this.kind() !== 'OptionalType') {
return rhs.getElementType().equals(this);
}
return false;
return this.equals(rhs);
}
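For reference, the revised base check means a plain type now counts as a subtype of both itself and its Optional wrapper. A minimal standalone sketch of the rule, using a made-up FakeType class rather than the real torch.Type hierarchy:

// Illustrative stand-in for the torch.Type base class; FakeType is not part of the codebase.
class FakeType {
    constructor(kind, elem) {
        this._kind = kind;
        this._elem = elem || null;
    }
    kind() { return this._kind; }
    getElementType() { return this._elem; }
    equals(rhs) { return this.kind() === rhs.kind(); }
    isSubtypeOf(rhs) {
        // T <: Optional[T] (an Optional is not re-wrapped), and T <: T.
        if (rhs.kind() === 'OptionalType' && this.kind() !== 'OptionalType') {
            return rhs.getElementType().equals(this);
        }
        return this.equals(rhs);
    }
}
const int = new FakeType('IntType');
const optionalInt = new FakeType('OptionalType', int);
console.log(int.isSubtypeOf(optionalInt)); // true
console.log(int.isSubtypeOf(int));         // true
console.log(optionalInt.isSubtypeOf(int)); // false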
str() {
if (this._kind === 'VarType' && this._annotation_str) {
@@ -6814,6 +6814,9 @@ python.Execution = class {
findAttribute(name) {
return this._attributes.get(name);
}
getAttribute(name) {
return this._attributes.get(name);
}
hasConstant(/* name */) {
}
methods() {
@@ -6848,7 +6851,7 @@ python.Execution = class {
super('ListType');
this._elem = elem;
}
static get(elem) {
static create(elem) {
return new torch.ListType(elem);
}
getElementType() {
@@ -6981,6 +6984,9 @@ python.Execution = class {
equals(rhs) {
return this.kind() === rhs.kind();
}
isSubtypeOf(/* rhs */) {
return true;
}
str() {
return 'NoneType';
}
@@ -7144,7 +7150,7 @@ python.Execution = class {
this._key = key;
this._value = value;
}
static get(key, value) {
static create(key, value) {
return new torch.DictType(key, value);
}
getKeyType() {
@@ -7415,7 +7421,7 @@ python.Execution = class {
const value_type = this.parseType().first;
L.expect(')');
alias_info = this.parseAliasAnnotation();
real_value = torch.DictType.get(key_type, value_type);
real_value = torch.DictType.create(key_type, value_type);
fake_value = real_value;
} else if (L.eat('Union')) {
L.next();
@@ -7454,8 +7460,8 @@ python.Execution = class {
while (true) {
if (L.kind === '[]') {
L.expect('[]');
fake_value = torch.ListType.get(fake_value);
real_value = torch.ListType.get(real_value);
fake_value = torch.ListType.create(fake_value);
real_value = torch.ListType.create(real_value);
let container = this.parseAliasAnnotation();
if (alias_info) {
if (!container) {
@@ -7524,8 +7530,8 @@ python.Execution = class {
L.whitespace(0);
let N = null;
if (L.eat('[')) {
fake_type = torch.ListType.get(fake_type);
real_type = torch.ListType.get(real_type);
fake_type = torch.ListType.create(fake_type);
real_type = torch.ListType.create(real_type);
if (L.kind === '#') {
N = Number(L.value);
L.next();
@@ -7932,8 +7938,28 @@ python.Execution = class {
this._block = new torch.Block(this);
this._insert_before = this.return_node();
}
create(kind) {
return new torch.Node(this, kind);
create(kind, ...args) {
let inputs = null;
let num_outputs = 1;
if (args.length === 2 && Array.isArray(args[0]) && typeof args[1] === 'number') {
[inputs, num_outputs] = args;
} else if (args.length === 1) {
if (typeof args[0] === 'number') {
[num_outputs] = args;
} else if (Array.isArray(args[0])) {
[inputs] = args;
}
}
const n = new torch.Node(this, kind);
if (inputs) {
for (const i of inputs) {
n.addInput(i);
}
}
for (let i = 0; i < num_outputs; i++) {
n.addOutput();
}
return n;
}
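In isolation, the argument-shape dispatch above behaves as in the following hedged sketch; parseCreateArgs is a hypothetical helper used only for illustration and is not part of the Graph API.

// Mirrors how create(kind, ...args) interprets its optional arguments.
function parseCreateArgs(...args) {
    let inputs = null;
    let num_outputs = 1;
    if (args.length === 2 && Array.isArray(args[0]) && typeof args[1] === 'number') {
        [inputs, num_outputs] = args;
    } else if (args.length === 1) {
        if (typeof args[0] === 'number') {
            [num_outputs] = args;
        } else if (Array.isArray(args[0])) {
            [inputs] = args;
        }
    }
    return { inputs, num_outputs };
}
console.log(parseCreateArgs());           // { inputs: null, num_outputs: 1 }
console.log(parseCreateArgs(0));          // { inputs: null, num_outputs: 0 }
console.log(parseCreateArgs(['a', 'b'])); // { inputs: [ 'a', 'b' ], num_outputs: 1 }
console.log(parseCreateArgs(['a'], 2));   // { inputs: [ 'a' ], num_outputs: 2 }

This is also why the Block constructor further down can pass 0 to build prim::Param and prim::Return nodes with no outputs.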
inputs() {
return this._block.inputs();
@@ -7954,7 +7980,123 @@ python.Execution = class {
return this._block.addInput(name);
}
insertNode(node) {
node.insertBefore(this._insert_before);
return node.insertBefore(this._insert_before);
}
insertConstant(val) {
const n = this.create('prim::Constant');
this.insertNode(n);
let type = null;
if (val === null) {
n.ival_('value', val);
type = torch.NoneType.get();
} else if (typeof val === 'string') {
n.s_('value', val);
type = torch.StringType.get();
} else if (Array.isArray(val) && val.every((item) => typeof item === 'string')) {
n.ss_('value', val);
type = torch.ListType.create(torch.StringType.get());
} else if (typeof val === 'boolean') {
// return value;
n.i_('value', val === true ? 1 : 0);
type = torch.BoolType.get();
} else if (Number.isInteger(val)) {
n.i_('value', val);
type = torch.IntType.get();
} else if (typeof val === 'number') {
// return value;
n.f_('value', val);
type = torch.FloatType.get();
} else {
throw new python.Error(`Unsupported value type '${typeof val}'.`);
}
if (type) {
n.output().setType(type);
}
return n.output();
}
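The value-to-type mapping used by insertConstant can be summarized in a standalone sketch; constantKind is a hypothetical helper name and the returned strings are informal labels, not real torch type objects.

// How insertConstant classifies a JavaScript value: which attribute setter it
// calls on the prim::Constant node and which type the node's output receives.
function constantKind(val) {
    if (val === null) {
        return { attr: 'ival_', type: 'NoneType' };
    } else if (typeof val === 'string') {
        return { attr: 's_', type: 'StringType' };
    } else if (Array.isArray(val) && val.every((item) => typeof item === 'string')) {
        return { attr: 'ss_', type: 'ListType[StringType]' };
    } else if (typeof val === 'boolean') {
        return { attr: 'i_', type: 'BoolType' };
    } else if (Number.isInteger(val)) {
        return { attr: 'i_', type: 'IntType' };
    } else if (typeof val === 'number') {
        return { attr: 'f_', type: 'FloatType' };
    }
    throw new Error(`Unsupported value type '${typeof val}'.`);
}
console.log(constantKind(42));         // { attr: 'i_', type: 'IntType' }
console.log(constantKind(0.5));        // { attr: 'f_', type: 'FloatType' }
console.log(constantKind(['a', 'b'])); // { attr: 'ss_', type: 'ListType[StringType]' }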
createList(contained_type, values) {
const n = this.create('prim::ListConstruct', values);
for (const v of values) {
if (!v.type().isSubtypeOf(contained_type)) {
throw new python.Error('Invalid list item.');
}
}
n.output().setType(torch.ListType.create(contained_type));
return n;
}
createDict(key_type, value_type, keys, values) {
if (keys.length !== values.length) {
throw new python.Error('Invalid dictionary size.');
}
const n = this.create('prim::DictConstruct');
const length = keys.length;
for (let i = 0; i < length; i++) {
if (!keys[i].type().isSubtypeOf(key_type)) {
throw new python.Error('Invalid key.');
}
if (!values[i].type().isSubtypeOf(value_type)) {
throw new python.Error('Invalid value.');
}
n.addInput(keys[i]);
n.addInput(values[i]);
}
n.output().setType(torch.DictType.create(key_type, value_type));
return n;
}
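createDict stores keys and values interleaved on the node's input list after checking both sides against the declared key and value types. A small sketch of just the interleaving, with plain arrays standing in for graph values (interleave is a hypothetical helper):

// Produces the [k0, v0, k1, v1, ...] input order used by prim::DictConstruct.
function interleave(keys, values) {
    if (keys.length !== values.length) {
        throw new Error('Invalid dictionary size.');
    }
    const inputs = [];
    for (let i = 0; i < keys.length; i++) {
        inputs.push(keys[i], values[i]);
    }
    return inputs;
}
console.log(interleave(['a', 'b'], [1, 2])); // [ 'a', 1, 'b', 2 ]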
createObject(type) {
const node = this.create('prim::CreateObject');
node.output().setType(type);
return node;
}
createIsInstance(v, types) {
const n = this.create('prim::isinstance', [v], 1);
n.tys_('types', types);
n.output().setType(torch.BoolType.get());
return n;
}
createSetAttr(obj, field, newValue) {
const n = this.create('prim::SetAttr', [obj, newValue], 0);
n.s_('name', field);
return n;
}
createGetAttr(obj, field) {
const n = this.create('prim::GetAttr', [obj]);
n.s_('name', field);
const classType = obj.type();
const outputType = classType.getAttribute(field);
n.output().setType(outputType);
n.output().setDebugName(/^[0-9]+$/.test(field) ? `_${field}` : field);
return n;
}
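createGetAttr looks the output type up on the owning class and sanitizes the debug name when the field is purely numeric. The naming rule in isolation (debugNameFor is a hypothetical helper):

// A purely numeric field gets an underscore prefix so the debug name stays a valid identifier.
function debugNameFor(field) {
    return /^[0-9]+$/.test(field) ? `_${field}` : field;
}
console.log(debugNameFor('weight')); // 'weight'
console.log(debugNameFor('0'));      // '_0'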
insertUncheckedCast(v, type) {
const n = this.insertNode(this.create('prim::unchecked_cast', [v]));
n.output().setType(type);
return n.output();
}
insertToList(v, type) {
let dim = 0;
let ptr = type;
while (ptr instanceof torch.ListType) {
ptr = ptr.getElementType();
dim += 1;
}
let elem_ty = 0;
if (ptr instanceof torch.IntType) {
elem_ty = 0;
} else if (ptr instanceof torch.FloatType) {
elem_ty = 1;
} else if (ptr instanceof torch.BoolType) {
elem_ty = 2;
} else if (ptr instanceof torch.ComplexType) {
elem_ty = 3;
} else {
throw new python.Error(`Unsupported list type '${type.kind()}'.`);
}
const dim_val = this.insertConstant(dim);
const elem_ty_val = this.insertConstant(elem_ty);
const n = this.insertNode(this.create('prim::tolist', [v, dim_val, elem_ty_val]));
n.output().setType(type);
return n.output();
}
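insertToList passes two extra constant inputs derived from the annotated type: the nesting depth and an element-type code (0 int, 1 float, 2 bool, 3 complex). A self-contained sketch of that derivation, with a nested descriptor such as ['List', ['List', 'Int']] standing in for torch.ListType:

// Computes the (dim, elem_ty) pair that insertToList turns into prim::tolist inputs.
function toListArgs(type) {
    let dim = 0;
    let ptr = type;
    while (Array.isArray(ptr) && ptr[0] === 'List') {
        ptr = ptr[1];
        dim += 1;
    }
    const codes = { Int: 0, Float: 1, Bool: 2, Complex: 3 };
    if (!(ptr in codes)) {
        throw new Error(`Unsupported list type '${ptr}'.`);
    }
    return { dim, elem_ty: codes[ptr] };
}
console.log(toListArgs(['List', ['List', 'Float']])); // { dim: 2, elem_ty: 1 }
console.log(toListArgs(['List', 'Int']));             // { dim: 1, elem_ty: 0 }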
insertPoint() {
return this._insert_before;
@@ -7994,8 +8136,8 @@ python.Execution = class {
this.registerType('torch.Block', class {
constructor(graph) {
this._graph = graph;
this._input = graph.create('prim::Param');
this._output = graph.create('prim::Return');
this._input = graph.create('prim::Param', 0);
this._output = graph.create('prim::Return', 0);
this._input.next = this._output;
this._input.prev = this._output;
this._output.next = this._input;
@@ -8091,6 +8233,12 @@ python.Execution = class {
outputs() {
return this._outputs;
}
output() {
if (this._outputs.length !== 1) {
throw new python.Error('Node has multiple outputs.');
}
return this._outputs[0];
}
blocks() {
return this._blocks;
}
@@ -8214,6 +8362,12 @@ python.Execution = class {
f(name) {
return this._values.get(name)[0];
}
tys_(name, value) {
this._values.set(name, [value, 'tys']);
}
tys(name) {
return this._values.get(name)[0];
}
ival_(name, value) {
this._values.set(name, [value, 'ival']);
}
@@ -8927,7 +9081,23 @@ python.Execution = class {
}
}
}
execution.purge = new Set();
const result = this.data.forward.__call__(args);
const queue = Array.from(execution.purge);
const visited = new Set();
while (queue.length > 0) {
const node = queue.shift();
if (visited.has(node)) {
continue;
}
visited.add(node);
if (node.outputs().every((output) => output.uses().length === 0)) {
for (const input of node.inputs()) {
queue.push(input.node());
}
node.destroy();
}
}
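The sweep above walks backwards from the nodes recorded in execution.purge during the forward call: a node whose outputs have no remaining uses is destroyed, and its producers are queued for re-examination. A simplified, self-contained sketch of the same worklist on mock nodes (the object shapes are invented for illustration, and destroy is modelled as a flag plus removal from the uses lists):

// Dead-node sweep: destroy nodes whose outputs are unused, then revisit their producers.
function purgeDead(candidates) {
    const queue = Array.from(candidates);
    const visited = new Set();
    while (queue.length > 0) {
        const node = queue.shift();
        if (visited.has(node)) {
            continue;
        }
        visited.add(node);
        if (node.outputs.every((output) => output.uses.length === 0)) {
            for (const input of node.inputs) {
                input.uses = input.uses.filter((use) => use !== node);
                queue.push(input.node);
            }
            node.destroyed = true;
        }
    }
}
// 'b' consumes the single output of 'a'; neither result is used elsewhere.
const a = { name: 'a', inputs: [], outputs: [] };
const aOut = { uses: [], node: a };
a.outputs.push(aOut);
const b = { name: 'b', inputs: [aOut], outputs: [{ uses: [] }] };
aOut.uses.push(b);
purgeDead([b]);
console.log(Boolean(a.destroyed), Boolean(b.destroyed)); // true true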
if (Array.isArray(result)) {
for (const output of result) {
if (isTensor(output)) {
3 changes: 3 additions & 0 deletions source/pytorch-metadata.json
@@ -6228,6 +6228,9 @@
{
"name": "prim::shape(Tensor self) -> int[]"
},
{
"name": "prim::tolist(...) -> ..."
},
{
"name": "prim::type(Device self) -> str"
},