Skip to content

Commit

Permalink
Dispatch to Dask if nanquantile is available (#9719)
Browse files Browse the repository at this point in the history
* Dispatch to Dask is nanquantile is available

* Fixup

* Change test
  • Loading branch information
phofl authored Nov 9, 2024
1 parent 2619c0b commit 6df8bd6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
3 changes: 2 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
from xarray.namedarray.utils import module_available
from xarray.util.deprecation_helpers import deprecate_dims

NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
Expand Down Expand Up @@ -1948,7 +1949,7 @@ def _wrapper(npa, **kwargs):
output_core_dims=[["quantile"]],
output_dtypes=[np.float64],
dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
dask="parallelized",
dask="allowed" if module_available("dask", "2024.11.0") else "parallelized",
kwargs=kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _importorskip(
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
has_dask_ge_2024_11_0, requires_dask_ge_2024_11_0 = _importorskip("dask", "2024.11.0")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
14 changes: 11 additions & 3 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
assert_equal,
assert_identical,
assert_no_warnings,
has_dask_ge_2024_11_0,
has_pandas_3,
raise_if_dask_computes,
requires_bottleneck,
Expand Down Expand Up @@ -1871,9 +1872,16 @@ def test_quantile_interpolation_deprecation(self, method) -> None:
def test_quantile_chunked_dim_error(self):
v = Variable(["x", "y"], self.d).chunk({"x": 2})

# this checks for ValueError in dask.array.apply_gufunc
with pytest.raises(ValueError, match=r"consists of multiple chunks"):
v.quantile(0.5, dim="x")
if has_dask_ge_2024_11_0:
# Dask rechunks
np.testing.assert_allclose(
v.compute().quantile(0.5, dim="x"), v.quantile(0.5, dim="x")
)

else:
# this checks for ValueError in dask.array.apply_gufunc
with pytest.raises(ValueError, match=r"consists of multiple chunks"):
v.quantile(0.5, dim="x")

@pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True)
@pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]])
Expand Down

0 comments on commit 6df8bd6

Please sign in to comment.