From 519d2c650e395132d6cbd7988564145b0048091e Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Mon, 23 Sep 2024 13:59:39 +0200 Subject: [PATCH] cast `numpy` scalars to arrays in `as_compatible_data` (#9403) * also call `np.asarray` on numpy scalars * check that numpy scalars are properly casted to arrays * don't allow `numpy.ndarray` subclasses * comment on the purpose of the explicit isinstance and `np.asarray` --- xarray/core/variable.py | 6 ++++-- xarray/tests/test_variable.py | 7 ++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9b9239cc042..d8cf0fe7550 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -320,12 +320,14 @@ def convert_non_numpy_type(data): else: data = np.asarray(data) - if not isinstance(data, np.ndarray) and ( + # immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars + if not isinstance(data, np.ndarray | np.generic) and ( hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") ): return cast("T_DuckArray", data) - # validate whether the data is valid data types. + # validate whether the data is valid data types. Also, explicitly cast `numpy` + # subclasses and `numpy` scalars to `numpy.ndarray` data = np.asarray(data) if data.dtype.kind in "OMm": diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 0a55b42f228..1d430b6b27e 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2585,7 +2585,12 @@ def test_unchanged_types(self): assert source_ndarray(x) is source_ndarray(as_compatible_data(x)) def test_converted_types(self): - for input_array in [[[0, 1, 2]], pd.DataFrame([[0, 1, 2]])]: + for input_array in [ + [[0, 1, 2]], + pd.DataFrame([[0, 1, 2]]), + np.float64(1.4), + np.str_("abc"), + ]: actual = as_compatible_data(input_array) assert_array_equal(np.asarray(input_array), actual) assert np.ndarray is type(actual)