diff --git a/source/python.js b/source/python.js index 64f9e0a7ff..3581e2d474 100644 --- a/source/python.js +++ b/source/python.js @@ -1695,44 +1695,57 @@ python.Execution = class { this.registerFunction('builtins.__import__', (name, globals, locals, fromlist, level) => { return execution.__import__(name, globals, locals, fromlist, level); }); - this.registerFunction('builtins.bool', (value) => { - if (value) { - if (value.__bool__) { - return value.__bool__(); - } - if (value.__len__) { - return value.__len__() > 0; + this.registerType('builtins.bool', class extends Boolean { + constructor(value) { + if (value && value.__bool__) { + value = value.__bool__(); + } else if (value && value.__len__) { + value = value.__len__() > 0; + } else { + value = value ? true : false; } + super(value); } - return false; }); - this.registerFunction('builtins.int', (value) => { - if (value) { - if (value.__int__) { - return value.__int__(); - } - if (Number.isInteger(value)) { - return value; + this.registerType('builtins.int', class extends Number { + constructor(value) { + if (value && value.__int__) { + value = value.__int__(); + } else if (!Number.isInteger(value)) { + value = NaN; } + super(value); } - return NaN; }); - this.registerFunction('builtins.float', (value) => { - if (value) { - if (value.__float__) { - return value.__float__(); + this.registerType('builtins.float', class extends Number { + constructor(value) { + if (value && value.__float__) { + value = value.__float__(); + } else if (Number(value) !== value) { + value = NaN; } - if (Number(value) === value) { - return value; + super(value); + } + }); + this.registerType('builtins.long', class extends Number { + constructor(value) { + if (value && value.__int__) { + value = value.__int__(); + } else if (!Number.isInteger(value)) { + value = NaN; } + super(value); } - return NaN; }); - this.registerFunction('builtins.str', (value) => { - if (value && value.__str__) { - return value.__str__(); + this.registerType('builtins.str', class extends String { + constructor(value) { + if (value && value.__str__) { + value = value.__str__(); + } else if (typeof value !== 'string') { + value = JSON.stringify(value); + } + super(value); } - return JSON.stringify(value); }); this.registerType('builtins.complex', class { constructor(real, imaginary) { @@ -1763,7 +1776,6 @@ python.Execution = class { this.registerType('builtins.Exception', class extends builtins.BaseException {}); this.registerType('builtins.AttributeError', class extends builtins.Exception {}); this.registerType('builtins.SyntaxError', class extends builtins.Exception {}); - this.registerFunction('builtins.long', this.builtins.int); this.registerFunction('builtins.print', () => {}); this.registerFunction('builtins.unicode'); builtins.Ellipsis = new builtins.ellipsis(); @@ -3613,8 +3625,7 @@ python.Execution = class { for (const name of ['__builtin__', 'types']) { const module = self.register(name); for (const [name, obj] of Object.entries(module)) { - if (obj.__module__ === 'builtins' && - obj.__class__ === builtins.type) { + if (obj.__module__ === 'builtins' && obj.__class__ === builtins.type) { _dill._reverse_typemap.set(name, obj); } } @@ -4971,8 +4982,8 @@ python.Execution = class { return tensor; }); this.registerFunction('torch.add', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { - return left * right; + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { + return left + right; } if (Array.isArray(left) && Array.isArray(right)) { return left.concat(right); @@ -5039,7 +5050,7 @@ python.Execution = class { if (typeof left === 'string' && typeof right === 'string') { return left === right; } - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (isNaN(left) && isNaN(right)) { return true; } @@ -5076,7 +5087,7 @@ python.Execution = class { }).join(''); }); this.registerFunction('torch.gt', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (!isNaN(left) && !isNaN(right)) { return left > right; } @@ -5087,7 +5098,7 @@ python.Execution = class { throw new python.Error("Unsupported 'torch.gt' expression type."); }); this.registerFunction('torch.ge', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (!isNaN(left) && !isNaN(right)) { return left > right; } @@ -5145,7 +5156,7 @@ python.Execution = class { return NaN; }); this.registerFunction('torch.le', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (isNaN(left) || isNaN(right)) { return false; } @@ -5352,25 +5363,25 @@ python.Execution = class { }); this.registerFunction('torch.log10'); this.registerFunction('torch.lt', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left < right; } throw new python.Error("Unsupported 'torch.lt' expression type."); }); this.registerFunction('torch.mul', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left * right; } if (isNaN(left) || isNaN(right)) { return NaN; } - if (Array.isArray(left) && left.every((value) => typeof value === 'number') && typeof right === 'number') { + if (Array.isArray(left) && left.every((value) => typeof value === 'number' || value instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left.map((value) => value * right); } throw new python.Error("Unsupported 'torch.mul' expression type."); }); this.registerFunction('torch.div', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left / right; } if (isNaN(left) || isNaN(right)) { @@ -5379,7 +5390,7 @@ python.Execution = class { throw new python.Error("Unsupported 'torch.div' expression type."); }); this.registerFunction('torch.round', (value) => { - if (typeof value === 'number') { + if (typeof value === 'number' || value instanceof Number) { return Math.round(value); } if (isNaN(value)) { @@ -5388,7 +5399,7 @@ python.Execution = class { throw new python.Error("Unsupported 'torch.round' expression type."); }); this.registerFunction('torch.remainder', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left % right; } if (isNaN(left) || isNaN(right)) { @@ -5400,7 +5411,7 @@ python.Execution = class { if (typeof left === 'boolean' && typeof right === 'boolean') { return left !== right; } - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (isNaN(left) || isNaN(right)) { return false; } @@ -5424,7 +5435,7 @@ python.Execution = class { throw new python.Error("Unsupported 'torch.neg' expression type."); }); this.registerFunction('torch.pow', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return Math.pow(left, right); } throw new python.Error("Unsupported 'torch.pow' expression type."); @@ -5474,7 +5485,7 @@ python.Execution = class { return l.slice(start, end); }); this.registerFunction('torch.sub', (left, right) => { - if (typeof left === 'number' && typeof right === 'number') { + if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left - right; } throw new python.Error("Unsupported 'torch.sub' expression type."); @@ -6909,7 +6920,7 @@ python.Execution = class { const func = name ? callTarget[name] : callTarget; if (func.__class__ === this._builtins.type) { if (func.prototype && func.prototype.__class__ === func) { - return Reflect.construct(func, args); + return Reflect.construct(func, callArguments); } const obj = Object.create(func); obj.__class__ = func; diff --git a/source/pytorch.js b/source/pytorch.js index b9fd1623ee..087194d17c 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -3419,32 +3419,34 @@ pytorch.Utility = class { case 'Tensor[]': return Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null); case 'Scalar': - return (obj !== null && obj !== Object(obj)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0); + return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0); case 'boolean': return obj === true || obj === false; case 'string': return obj === null || typeof obj === 'string'; case 'SymInt': case 'int64': - return Number.isInteger(obj) || typeof obj === 'bigint' || (typeof obj === 'number' && isNaN(obj)); + return Number.isInteger(obj) || typeof obj === 'bigint' || + (typeof obj === 'number' && isNaN(obj)) || (obj instanceof Number); case 'SymInt[]': case 'SymInt[2]': case 'SymInt[3]': case 'SymInt[4]': case 'SymInt[5]': case 'SymInt[6]': + return Array.isArray(obj) && obj.every((item) => pytorch.Utility.isType(item, 'SymInt') || item === undefined || (item.__class__ === 'number' && isNaN(item))); case 'int64[]': case 'int64[2]': case 'int64[3]': - return Array.isArray(obj) && obj.every((item) => Number.isInteger(item) || (typeof item === 'number' && isNaN(item)) || item === undefined); + return Array.isArray(obj) && obj.every((item) => pytorch.Utility.isType(item, 'int64') || item === undefined || (item.__class__ === 'number' && isNaN(item))); case 'int64[1]': case 'SymInt[1]': return pytorch.Utility.isType(obj, 'int64') || pytorch.Utility.isType(obj, 'int64[]'); case 'float32': case 'float64': - return obj !== null && obj !== Object(obj); + return obj !== null && (typeof obj === 'number' || obj instanceof Number); case 'float32[]': - return Array.isArray(obj) && obj.every((item) => typeof item === 'number' && !isNaN(item)); + return Array.isArray(obj) && obj.every((item) => (typeof item === 'number' || item instanceof Number) && !isNaN(item)); case 'string[][]': return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string')); case 'Layout': @@ -3452,7 +3454,7 @@ pytorch.Utility = class { case 'MemoryFormat': return Number.isInteger(obj) || obj === null; case 'Dimname': - return obj === null || typeof obj === 'string'; + return obj === null || (typeof obj === 'string' || obj instanceof String); case 'Dimname[]': return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string'); case 'Device':