Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 24, 2024
1 parent adef0de commit 8dccbd8
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 31 deletions.
7 changes: 5 additions & 2 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6168,7 +6168,7 @@ python.Execution = class {
kind() {
return this._kind;
}
annotation_str() {
get annotation_str() {
return this._annotation_str;
}
equals(/* rhs */) {
Expand Down Expand Up @@ -6357,7 +6357,10 @@ python.Execution = class {
return torch.AnyType.value;
}
str() {
return 'AnyType';
return 'Any';
}
__str__() {
return 'Any';
}
});
this.registerType('torch.NoneType', class extends torch.Type {
Expand Down
96 changes: 72 additions & 24 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,20 @@ pytorch.Node = class {
module = null;
}
}
const mapTensor = (input) => {
let initializer = null;
let identifier = input.unique().toString();
if (input.value) {
const value = input.value;
const hide = value.__parent__ ? value.__parent__.__hide__ : true;
initializer = hide ? initializers.get(value) : null;
identifier = initializer ? initializer.name : identifier;
}
if (initializer) {
return new pytorch.Value(identifier, null, null, initializer);
}
return values.map(identifier);
};
for (let i = 0; i < inputs.length; i++) {
const input = inputs[i];
const arg = schema && schema.arguments && i < schema.arguments.length ? schema.arguments[i] : null;
Expand Down Expand Up @@ -480,24 +494,38 @@ pytorch.Node = class {
argument = new pytorch.Argument(name, input.value, type || 'attribute');
} else if (pytorch.Utility.isInstance(input.type(), 'torch.ListType')) {
if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 &&
input.node().inputs().every((value) => pytorch.Utility.isInstance(value, 'torch.Value') || pytorch.Utility.isInstance(value.type(), 'torch.IntType') || pytorch.Utility.isInstance(value.type(), 'torch.FloatType') || pytorch.Utility.isInstance(value.type(), 'torch.StringType') || pytorch.Utility.isInstance(value.type(), 'torch.ComplexType') || pytorch.Utility.isInstance(value.type(), 'torch.TensorType'))) {
input.node().inputs().every((value) => pytorch.Utility.isInstance(value, 'torch.Value') || pytorch.Utility.isInstance(value.type(), 'torch.IntType') || pytorch.Utility.isInstance(value.type(), 'torch.FloatType') || pytorch.Utility.isInstance(value.type(), 'torch.StringType') || pytorch.Utility.isInstance(value.type(), 'torch.ComplexType') || pytorch.Utility.isInstance(value.type(), 'torch.TensorType'))) {
const list = input.node().inputs();
const args = list.map((value) => {
if (pytorch.Utility.isTensor(value.value)) {
return mapTensor(value);
}
if (value.uses().length === 1 && value.node().kind() === 'prim::Constant') {
return getAttribute(value.node(), 'value')[1];
}
if (value.uses().length === 1 && value.node() === input.node() && value.value !== undefined) {
return value.value;
}
const identifier = value.unique().toString();
return values.map(identifier);
});
argument = new pytorch.Argument(name, args, pytorch.Utility.toType(input.type()));
const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? null : pytorch.Utility.toType(input.type());
argument = new pytorch.Argument(name, args, type);
} else {
const identifier = input.unique().toString();
argument = new pytorch.Argument(name, [values.map(identifier)]);
}
} else if (pytorch.Utility.isInstance(input.type(), 'torch.StringType') && typeof input.value === 'string') {
argument = new pytorch.Argument(name, input.value, 'string');
} else if (input.node() && input.uses().length === 1 && input.node().kind() === 'prim::Constant') {
const [type, value] = getAttribute(input.node(), 'value');
let [type, value] = getAttribute(input.node(), 'value');
const valueType = input.node().outputs()[0].type();
if (valueType) {
type = pytorch.Utility.toType(valueType);
if (type === 'boolean') {
value = Boolean(value);
}
}
argument = new pytorch.Argument(name, value, type || 'attribute');
} else {
const identifier = input.unique().toString();
Expand Down Expand Up @@ -2357,6 +2385,14 @@ pytorch.jit.Execution = class extends pytorch.Execution {
}
const torch = this.torch;
switch (expression.type) {
case 'id': {
switch (expression.value) {
case 'True': return this.constant(true);
case 'False': return this.constant(false);
default: break;
}
return super.expression(expression, context);
}
case '=': {
const target = expression.target;
if (target.type === 'id') {
Expand Down Expand Up @@ -2559,15 +2595,11 @@ pytorch.jit.Execution = class extends pytorch.Execution {
node.addInput(item);
output.setType(torch.ListType.get(item.type()));
} else if (Number.isInteger(item)) {
const value = new torch.Value(node);
value.value = item;
value.setType(torch.IntType.get());
const value = this.constant(item);
node.addInput(value);
output.setType(torch.ListType.get(torch.IntType.get()));
} else if (typeof item === 'string') {
const value = new torch.Value(node);
value.value = item;
value.setType(torch.StringType.get());
const value = this.constant(item);
node.addInput(value);
output.setType(torch.ListType.get(torch.StringType.get()));
} else if (pytorch.Utility.isTensor(item)) {
Expand Down Expand Up @@ -2681,16 +2713,27 @@ pytorch.jit.Execution = class extends pytorch.Execution {
return undefined;
}
case 'if': {
if (!this.traceIf) {
return super.statement(statement, context);
}
/*
const test = this.expression(statement.test, context);
const n = this._graph.create('prim::If');
const true_block = n.addBlock();
const false_block = n.addBlock();
*/
return undefined;
if (test instanceof torch.Value) {
const node = this._graph.create('prim::If');
node.addInput(test);
}
if (test === true || test) {
const value = this.block(statement.body.statements, context);
if (value !== undefined) {
return value;
}
return undefined;
} else if (test === false) {
if (statement.orelse) {
const value = this.block(statement.orelse.statements, context);
if (value !== undefined) {
return value;
}
}
return undefined;
}
throw new python.Error("Unsupported condition.");
}
default: {
break;
Expand Down Expand Up @@ -2806,7 +2849,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
const varTypes = new Map();
varTypes.map = function(type) {
if (type.kind() === 'VarType') {
const key = type.annotation_str();
const key = type.annotation_str;
if (!varTypes.has(key)) {
throw new pytorch.Error(`Unknown var type '${key}'.`);
}
Expand Down Expand Up @@ -2895,7 +2938,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
if (match) {
node.addInput(input);
if (type.kind() === 'VarType') {
const key = type.annotation_str();
const key = type.annotation_str;
if (input instanceof torch.Value && input.type()) {
varTypes.set(key, input.type());
} else if (input instanceof torch.Value && Number.isInteger(input.value)) {
Expand All @@ -2904,7 +2947,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
// throw new pytorch.Error("Unknown value type 't'.");
}
if (type instanceof torch.ListType && type.getElementType().kind() === 'VarType') {
const key = type.getElementType().annotation_str();
const key = type.getElementType().annotation_str;
if (input instanceof torch.Value && input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.ListType) {
varTypes.set(key, input.type().getElementType().getElementType());
} else if (input instanceof torch.Value && input.type() instanceof torch.ListType) {
Expand All @@ -2921,7 +2964,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
}
}
if (type instanceof torch.DictType && type.getValueType().kind() === 'VarType') {
const key = type.getValueType().annotation_str();
const key = type.getValueType().annotation_str;
if (input instanceof torch.Value && input.type() instanceof torch.DictType) {
varTypes.set(key, input.type().getValueType());
} else if (input.value && Object.values(input.value).every((item) => pytorch.Utility.isTensor(item))) {
Expand All @@ -2931,7 +2974,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
}
}
if (type instanceof torch.ListType && type.getElementType() instanceof torch.TupleType && type.getElementType().elements().length === 2 && type.getElementType().elements()[1].kind() === 'VarType') {
const key = type.getElementType().elements()[1].annotation_str();
const key = type.getElementType().elements()[1].annotation_str;
if (input instanceof torch.Value && input.type() instanceof torch.ListType && input.type().getElementType() instanceof torch.TupleType) {
const elements = input.type().getElementType().elements();
if (elements.length === 2) {
Expand Down Expand Up @@ -3341,6 +3384,9 @@ pytorch.jit.Execution = class extends pytorch.Execution {
case 'Tensor[]':
return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) ||
(obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType);
case 'Tensor?[]':
return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) ||
(obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.OptionalType && obj.type().getElementType().getElementType() instanceof torch.TensorType);
case 'Scalar':
return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) ||
(pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) ||
Expand Down Expand Up @@ -3424,6 +3470,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
return false;
case 'complex':
return obj instanceof torch.Value && obj.type() instanceof torch.ComplexType;
case 'Any':
return true;
case 'Any[]':
if (Array.isArray(obj)) {
return true;
Expand Down Expand Up @@ -4450,7 +4498,7 @@ pytorch.Utility = class {
return `boolean`;
}
if (pytorch.Utility.isInstance(type, 'torch.TensorType')) {
return `Tensor`;
return `tensor`;
}
if (pytorch.Utility.isInstance(type, 'torch.TupleType')) {
return `(${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')})`;
Expand Down
10 changes: 5 additions & 5 deletions source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -2913,13 +2913,13 @@ view.ArgumentView = class extends view.Control {
if (argument.type === 'attribute') {
this._source = 'attribute';
}
if (argument.type === 'tensor') {
value = [{ type: value.type, initializer: value }];
} else if (argument.type === 'tensor[]') {
if (argument.type === 'tensor' || argument.type === 'tensor?') {
value = [value === null ? value : { type: value.type, initializer: value }];
} else if (argument.type === 'tensor[]' || argument.type === 'tensor?[]') {
value = value.map((value) => value === null ? value : { type: value.type, initializer: value });
}
this._source = typeof type === 'string' && !type.endsWith('*') ? 'attribute' : this._source;
if (this._source === 'attribute' && type !== 'tensor' && type !== 'tensor[]') {
if (this._source === 'attribute' && type !== 'tensor' && type !== 'tensor?' && type !== 'tensor[]' && type !== 'tensor?[]') {
this._source = 'attribute';
const item = new view.PrimitiveView(context, argument);
this._items.push(item);
Expand All @@ -2929,7 +2929,7 @@ view.ArgumentView = class extends view.Control {
} else {
const values = value;
for (const value of values) {
const emit = values.length === 1 && value.initializer;
const emit = values.length === 1 && value && value.initializer;
const target = emit ? argument : value;
if (value === null) {
const item = new view.TextView(this._view, null);
Expand Down

0 comments on commit 8dccbd8

Please sign in to comment.