Skip to content

Commit

Permalink
Seems successfull do it on the rest
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 25, 2024
1 parent 7c278dc commit c59fb54
Showing 1 changed file with 18 additions and 30 deletions.
48 changes: 18 additions & 30 deletions xarray/namedarray/_array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from xarray.namedarray._array_api._utils import (
_dims_to_axis,
_get_data_namespace,
_get_remaining_dims,
_reduce_dims,
)
from xarray.namedarray._typing import (
Expand Down Expand Up @@ -84,10 +83,9 @@ def max(
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.max(x._data, axis=_axis, keepdims=False)
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
return x._new(dims=dims_, data=data_)
_data = xp.max(x._data, axis=_axis, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def mean(
Expand Down Expand Up @@ -166,11 +164,9 @@ def min(
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.min(x._data, axis=_axis, keepdims=False)
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out
_data = xp.min(x._data, axis=_axis, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def prod(
Expand All @@ -184,11 +180,9 @@ def prod(
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.prod(x._data, axis=_axis, keepdims=False)
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out
_data = xp.prod(x._data, axis=_axis, dtype=dtype, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def std(
Expand All @@ -202,11 +196,9 @@ def std(
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.std(x._data, axis=_axis, correction=correction, keepdims=False)
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out
_data = xp.std(x._data, axis=_axis, correction=correction, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def sum(
Expand All @@ -220,11 +212,9 @@ def sum(
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.sum(x._data, axis=_axis, keepdims=False)
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out
_data = xp.sum(x._data, axis=_axis, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def var(
Expand All @@ -238,8 +228,6 @@ def var(
) -> NamedArray[Any, _DType]:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dims, axis)
_data = xp.var(x._data, axis=_axis, correction=correction, keepdims=False)
# TODO: Why do we need to do the keepdims ourselves?
dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out
_data = xp.var(x._data, axis=_axis, correction=correction, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)

0 comments on commit c59fb54

Please sign in to comment.