Skip to content

Commit

Permalink
(fix): fix extension array + dataarray roundtrip
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Sep 19, 2024
1 parent 7750c00 commit d18a74c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
9 changes: 9 additions & 0 deletions properties/test_pandas_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,12 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
roundtripped.columns.name = "cols" # why?
pd.testing.assert_frame_equal(df, roundtripped)
xr.testing.assert_identical(dataset, roundtripped.to_xarray())


def test_roundtrip_1d_pandas_extension_array() -> None:
    """Round-trip a 1-D categorical column: pandas -> DataArray -> pandas.

    Checks that the values and the extension dtype survive ``to_pandas()``,
    and that converting the result back with ``to_xarray()`` reproduces the
    original DataArray.
    """
    frame = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
    expected = frame["cat"]
    data_array = xr.Dataset.from_dataframe(frame)["cat"]
    series = data_array.to_pandas()
    # Element-wise equality first, then the dtype itself.
    assert (expected == series).all()
    assert expected.dtype == series.dtype
    xr.testing.assert_identical(data_array, series.to_xarray())
5 changes: 4 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3866,7 +3866,10 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame:
"pandas objects. Requires 2 or fewer dimensions."
) from err
indexes = [self.get_index(dim) for dim in self.dims]
return constructor(self.values, *indexes) # type: ignore[operator]
pandas_object = constructor(self.values, *indexes)
if isinstance(pandas_object, pd.Series):
pandas_object.name = self.name
return pandas_object

def to_dataframe(
self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None
Expand Down
11 changes: 9 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,13 @@ def as_compatible_data(
if isinstance(data, DataArray):
return cast("T_DuckArray", data._variable._data)

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
def convert_non_numpy_type(data):
    """Normalize and wrap ``data`` for storage as variable data.

    Passes ``data`` through ``_possibly_convert_datetime_or_timedelta_index``
    and then ``_maybe_wrap_data``; the result is cast to ``T_DuckArray`` for
    the type checker only.
    """
    converted = _possibly_convert_datetime_or_timedelta_index(data)
    wrapped = _maybe_wrap_data(converted)
    return cast("T_DuckArray", wrapped)

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
return convert_non_numpy_type(data)

if isinstance(data, tuple):
data = utils.to_0d_object_array(data)

Expand All @@ -303,7 +306,9 @@ def as_compatible_data(

# we don't want nested self-described arrays
if isinstance(data, pd.Series | pd.DataFrame):
data = data.values # type: ignore[assignment]
data = data.values
if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
return convert_non_numpy_type(data)

if isinstance(data, np.ma.MaskedArray):
mask = np.ma.getmaskarray(data)
Expand Down Expand Up @@ -542,6 +547,8 @@ def _dask_finalize(self, results, array_func, *args, **kwargs):
@property
def values(self):
    """The variable's data as a numpy.ndarray.

    When the underlying data is a ``PandasExtensionArray``, its wrapped
    ``.array`` is returned as-is instead of being converted.
    """
    data = self._data
    if not isinstance(data, PandasExtensionArray):
        return _as_array_or_item(data)
    # Hand back the wrapped array so its dtype is not lost in a conversion.
    return data.array

@values.setter
Expand Down

0 comments on commit d18a74c

Please sign in to comment.