Skip to content

Commit

Permalink
Add TorchScript test file (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 1, 2024
1 parent f17e406 commit 741f491
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 229 deletions.
244 changes: 228 additions & 16 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6755,10 +6755,16 @@ python.Execution = class {
return this.kind() === rhs.kind();
}
isSubtypeOf(rhs) {
if (rhs.kind() === 'OptionalType') {
if (rhs.kind() === 'OptionalType' && this.kind() !== 'OptionalType') {
return rhs.getElementType().equals(this);
}
return false;
return this.equals(rhs);
}
expect(type) {
if (this instanceof type === false) {
throw new python.Error(`Expected '${type.kind()}' but got '${this.kind()}'.`);
}
return this;
}
str() {
if (this._kind === 'VarType' && this._annotation_str) {
Expand Down Expand Up @@ -6814,6 +6820,9 @@ python.Execution = class {
findAttribute(name) {
return this._attributes.get(name);
}
getAttribute(name) {
return this._attributes.get(name);
}
hasConstant(/* name */) {
}
methods() {
Expand Down Expand Up @@ -6848,7 +6857,7 @@ python.Execution = class {
super('ListType');
this._elem = elem;
}
static get(elem) {
static create(elem) {
return new torch.ListType(elem);
}
getElementType() {
Expand Down Expand Up @@ -6924,7 +6933,7 @@ python.Execution = class {
this._elements = elements;
this._schema = schema;
}
static get(elements) {
static create(elements) {
return new torch.TupleType(elements);
}
static createNamed(qualified_name, field_names, field_types /*, field_defaults */) {
Expand Down Expand Up @@ -6981,6 +6990,9 @@ python.Execution = class {
equals(rhs) {
return this.kind() === rhs.kind();
}
isSubtypeOf(/* rhs */) {
return true;
}
str() {
return 'NoneType';
}
Expand Down Expand Up @@ -7144,7 +7156,7 @@ python.Execution = class {
this._key = key;
this._value = value;
}
static get(key, value) {
static create(key, value) {
return new torch.DictType(key, value);
}
getKeyType() {
Expand Down Expand Up @@ -7372,7 +7384,7 @@ python.Execution = class {
L.eat(',');
L.whitespace(0);
}
real_value = torch.TupleType.get(types);
real_value = torch.TupleType.create(types);
fake_value = real_value;
} else if (L.value === 'Future') {
L.next();
Expand Down Expand Up @@ -7415,7 +7427,7 @@ python.Execution = class {
const value_type = this.parseType().first;
L.expect(')');
alias_info = this.parseAliasAnnotation();
real_value = torch.DictType.get(key_type, value_type);
real_value = torch.DictType.create(key_type, value_type);
fake_value = real_value;
} else if (L.eat('Union')) {
L.next();
Expand Down Expand Up @@ -7454,8 +7466,8 @@ python.Execution = class {
while (true) {
if (L.kind === '[]') {
L.expect('[]');
fake_value = torch.ListType.get(fake_value);
real_value = torch.ListType.get(real_value);
fake_value = torch.ListType.create(fake_value);
real_value = torch.ListType.create(real_value);
let container = this.parseAliasAnnotation();
if (alias_info) {
if (!container) {
Expand Down Expand Up @@ -7524,8 +7536,8 @@ python.Execution = class {
L.whitespace(0);
let N = null;
if (L.eat('[')) {
fake_type = torch.ListType.get(fake_type);
real_type = torch.ListType.get(real_type);
fake_type = torch.ListType.create(fake_type);
real_type = torch.ListType.create(real_type);
if (L.kind === '#') {
N = Number(L.value);
L.next();
Expand Down Expand Up @@ -7932,8 +7944,118 @@ python.Execution = class {
this._block = new torch.Block(this);
this._insert_before = this.return_node();
}
create(kind) {
return new torch.Node(this, kind);
create(kind, ...args) {
let inputs = null;
let num_outputs = 1;
if (args.length === 2 && Array.isArray(args[0]) && typeof args[1] === 'number') {
[inputs, num_outputs] = args;
} else if (args.length === 1) {
if (typeof args[0] === 'number') {
[num_outputs] = args;
} else if (Array.isArray(args[0])) {
[inputs] = args;
}
}
const n = new torch.Node(this, kind);
if (inputs) {
for (const i of inputs) {
n.addInput(i);
}
}
for (let i = 0; i < num_outputs; i++) {
n.addOutput();
}
return n;
}
createUninitialized(typ) {
const n = this.create('prim::Uninitialized');
n.output().setType(typ);
return n;
}
createList(contained_type, values) {
const n = this.create('prim::ListConstruct', values);
for (const v of values) {
if (!v.type().isSubtypeOf(contained_type)) {
throw new python.Error('Invalid list item.');
}
}
n.output().setType(torch.ListType.create(contained_type));
return n;
}
createListUnpack(v, size) {
const list_type = v.type().expect(torch.ListType);
const elem_type = list_type.getElementType();
const n = this.create('prim::ListUnpack', [v], 0);
for (let i = 0; i < size; i++) {
n.addOutput().setType(elem_type);
}
return n;
}
createTuple(values, tuple_type) {
if (!tuple_type) {
const types = values.map((v) => v.type());
tuple_type = torch.TupleType.create(types);
}
const n = this.create('prim::TupleConstruct', values);
n.output().setType(tuple_type);
return n;
}
createTupleUnpack(v) {
const tt = v.type().expect(torch.TupleType);
const n = this.create('prim::TupleUnpack', [v], 0);
for (const element of tt.elements()) {
n.addOutput().setType(element);
}
return n;
}
createTupleIndex(tup, idx, output_type) {
const n = this.create('prim::TupleIndex', [tup, idx]);
n.output().setType(output_type);
return n;
}
createDict(key_type, value_type, keys, values) {
if (keys.length !== values.length) {
throw new python.Error('Invalid dictionary size.');
}
const n = this.create('prim::DictConstruct');
const length = keys.length;
for (let i = 0; i < length; i++) {
if (!keys[i].type().isSubtypeOf(key_type)) {
throw new python.Error('Invalid key.');
}
if (!values[i].type().isSubtypeOf(value_type)) {
throw new python.Error('Invalid value.');
}
n.addInput(keys[i]);
n.addInput(values[i]);
}
n.output().setType(torch.DictType.create(key_type, value_type));
return n;
}
createObject(type) {
const node = this.create('prim::CreateObject');
node.output().setType(type);
return node;
}
createIsInstance(v, types) {
const n = this.create('prim::isinstance', [v], 1);
n.tys_('types', types);
n.output().setType(torch.BoolType.get());
return n;
}
createSetAttr(obj, field, newValue) {
const n = this.create('prim::SetAttr', [obj, newValue], 0);
n.s_('name', field);
return n;
}
createGetAttr(obj, field) {
const n = this.create('prim::GetAttr', [obj]);
n.s_('name', field);
const classType = obj.type();
const outputType = classType.getAttribute(field);
n.output().setType(outputType);
n.output().setDebugName(/^[0-9]+$/.test(field) ? `_${field}` : field);
return n;
}
inputs() {
return this._block.inputs();
Expand All @@ -7954,7 +8076,69 @@ python.Execution = class {
return this._block.addInput(name);
}
insertNode(node) {
node.insertBefore(this._insert_before);
return node.insertBefore(this._insert_before);
}
insertConstant(val) {
const n = this.create('prim::Constant');
this.insertNode(n);
let type = null;
if (val === null) {
n.ival_('value', val);
type = torch.NoneType.get();
} else if (typeof val === 'string') {
n.s_('value', val);
type = torch.StringType.get();
} else if (Array.isArray(val) && val.every((item) => typeof item === 'string')) {
n.ss_('value', val);
type = torch.ListType.create(torch.StringType.get());
} else if (typeof val === 'boolean') {
// return value;
n.i_('value', val === true ? 1 : 0);
type = torch.BoolType.get();
} else if (Number.isInteger(val)) {
n.i_('value', val);
type = torch.IntType.get();
} else if (typeof val === 'number') {
// return value;
n.f_('value', val);
type = torch.FloatType.get();
} else {
throw new python.Error(`Unsupported value type '${typeof value}'.`);
}
if (type) {
n.output().setType(type);
}
return n.output();
}
insertUncheckedCast(v, type) {
const n = this.insertNode(this.create('prim::unchecked_cast', [v]));
n.output().setType(type);
return n.output();
}
insertToList(v, type) {
let dim = 0;
let ptr = type;
while (ptr instanceof torch.ListType) {
ptr = ptr.getElementType();
dim += 1;
}
let elem_ty = 0;
if (ptr instanceof torch.IntType) {
elem_ty = 0;
} else if (ptr instanceof torch.FloatType) {
elem_ty = 1;
} else if (ptr instanceof torch.BoolType) {
elem_ty = 2;
} else if (ptr instanceof torch.ComplexType) {
elem_ty = 3;
} else {
throw new python.Error(`Unsupported list type '${type.kind()}'.`);
}
const dim_val = this.insertConstant(dim);
const elem_ty_val = this.insertConstant(elem_ty);
const n = this.insertNode(this.create('prim::tolist', [v, dim_val, elem_ty_val]));
n.output().setType(type);
return n.output();
}
insertPoint() {
return this._insert_before;
Expand Down Expand Up @@ -7994,8 +8178,8 @@ python.Execution = class {
this.registerType('torch.Block', class {
constructor(graph) {
this._graph = graph;
this._input = graph.create('prim::Param');
this._output = graph.create('prim::Return');
this._input = graph.create('prim::Param', 0);
this._output = graph.create('prim::Return', 0);
this._input.next = this._output;
this._input.prev = this._output;
this._output.next = this._input;
Expand Down Expand Up @@ -8091,6 +8275,12 @@ python.Execution = class {
outputs() {
return this._outputs;
}
output() {
if (this._outputs.length !== 1) {
throw new python.Error('Node has multiple outputs.');
}
return this._outputs[0];
}
blocks() {
return this._blocks;
}
Expand Down Expand Up @@ -8214,6 +8404,12 @@ python.Execution = class {
f(name) {
return this._values.get(name)[0];
}
tys_(name, value) {
this._values.set(name, [value, 'tys']);
}
tys(name) {
return this._values.get(name)[0];
}
ival_(name, value) {
this._values.set(name, [value, 'ival']);
}
Expand Down Expand Up @@ -8927,7 +9123,23 @@ python.Execution = class {
}
}
}
execution.purge = new Set();
const result = this.data.forward.__call__(args);
const queue = Array.from(execution.purge);
const visited = new Set();
while (queue.length > 0) {
const node = queue.shift();
if (visited.has(node)) {
continue;
}
visited.add(node);
if (node.outputs().every((output) => output.uses().length === 0)) {
for (const input of node.inputs()) {
queue.push(input.node());
}
node.destroy();
}
}
if (Array.isArray(result)) {
for (const output of result) {
if (isTensor(output)) {
Expand Down
3 changes: 3 additions & 0 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -6228,6 +6228,9 @@
{
"name": "prim::shape(Tensor self) -> int[]"
},
{
"name": "prim::tolist(...) -> ..."
},
{
"name": "prim::type(Device self) -> str"
},
Expand Down
Loading

0 comments on commit 741f491

Please sign in to comment.