From 8ee6172c7f075882cbaeef76e1ad4ca5707938bb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 06:43:06 +0200 Subject: [PATCH] subclasses needs to be passed down --- .../_array_api/_elementwise_functions.py | 130 +++++++++--------- 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 4585c668bc1..4158a595bab 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -22,21 +22,21 @@ def abs(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.abs(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def acos(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.acos(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def acosh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.acosh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def add(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: @@ -44,28 +44,28 @@ def add(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.add(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def asin(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.asin(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def asinh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.asinh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def atan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.atan(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def atan2( @@ -75,14 +75,14 @@ def atan2( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.atan2(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def atanh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.atanh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def bitwise_and( @@ -92,14 +92,14 @@ def bitwise_and( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_and(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_invert(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.bitwise_invert(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def bitwise_left_shift( @@ -109,7 +109,7 @@ def bitwise_left_shift( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_left_shift(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_or( @@ -119,7 +119,7 @@ def bitwise_or( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_or(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_right_shift( @@ -129,7 +129,7 @@ def bitwise_right_shift( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_right_shift(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_xor( @@ -139,14 +139,14 @@ def bitwise_xor( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_xor(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def ceil(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.ceil(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def clip( @@ -158,14 +158,14 @@ def clip( xp = _get_data_namespace(x) _dims = x.dims _data = xp.clip(x._data, min=min, max=max) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def conj(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.conj(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def copysign( @@ -175,21 +175,21 @@ def copysign( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.copysign(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def cos(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.cos(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def cosh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.cosh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def divide( @@ -199,21 +199,21 @@ def divide( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.divide(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def exp(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.exp(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def expm1(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.expm1(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def equal( @@ -223,14 +223,14 @@ def equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def floor(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.floor(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def floor_divide( @@ -240,7 +240,7 @@ def floor_divide( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.floor_divide(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def greater( @@ -250,7 +250,7 @@ def greater( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.greater(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def greater_equal( @@ -260,7 +260,7 @@ def greater_equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.greater_equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def hypot( @@ -270,7 +270,7 @@ def hypot( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.hypot(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def imag( @@ -303,28 +303,28 @@ def imag( xp = _get_data_namespace(x) _dims = x.dims _data = xp.imag(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def isfinite(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isfinite(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def isinf(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isinf(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def isnan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isnan(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def less(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: @@ -332,7 +332,7 @@ def less(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[An x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.less(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def less_equal( @@ -342,35 +342,35 @@ def less_equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.less_equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def log(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def log1p(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log1p(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def log2(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log2(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def log10(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log10(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def logaddexp( @@ -380,7 +380,7 @@ def logaddexp( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logaddexp(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def logical_and( @@ -390,14 +390,14 @@ def logical_and( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logical_and(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def logical_not(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.logical_not(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def logical_or( @@ -407,7 +407,7 @@ def logical_or( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logical_or(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def logical_xor( @@ -417,7 +417,7 @@ def logical_xor( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logical_xor(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def maximum( @@ -427,7 +427,7 @@ def maximum( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.maximum(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def minimum( @@ -437,7 +437,7 @@ def minimum( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.minimum(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def multiply( @@ -447,14 +447,14 @@ def multiply( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.multiply(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def negative(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.negative(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def not_equal( @@ -464,14 +464,14 @@ def not_equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.not_equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def positive(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.positive(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def pow(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: @@ -479,7 +479,7 @@ def pow(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.pow(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def real( @@ -512,7 +512,7 @@ def real( xp = _get_data_namespace(x) _dims = x.dims _data = xp.real(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def remainder( @@ -522,56 +522,56 @@ def remainder( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.remainder(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def round(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.round(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sign(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sign(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def signbit(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.signbit(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sin(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sin(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sinh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sinh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sqrt(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sqrt(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def square(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.square(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def subtract( @@ -581,25 +581,25 @@ def subtract( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.subtract(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def tan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.tan(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def tanh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.tanh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def trunc(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.trunc(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data)