Skip to content

Commit

Permalink
[SPARK-36771][PYTHON] Fix pop of Categorical Series
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Fix `pop` of Categorical Series to be consistent with the latest pandas (1.3.2) behavior.

### Why are the changes needed?
As databricks/koalas#2198, pandas API on Spark behaves differently from pandas on `pop` of Categorical Series.

### Does this PR introduce _any_ user-facing change?
Yes, results of `pop` of Categorical Series change.

#### From
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
0
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
0
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```

#### To
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
'a'
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
'a'
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']

```

### How was this patch tested?
Unit tests.

Closes #34052 from xinrong-databricks/cat_pop.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
  • Loading branch information
xinrong-meng authored and ueshin committed Sep 21, 2021
1 parent b7d99e3 commit 079a9c5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/pyspark/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import pandas as pd
from pandas.core.accessor import CachedAccessor
from pandas.io.formats.printing import pprint_thing
from pandas.api.types import is_list_like, is_hashable
from pandas.api.types import is_list_like, is_hashable, CategoricalDtype
from pandas.tseries.frequencies import DateOffset
from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
from pyspark.sql.types import (
Expand Down Expand Up @@ -4098,7 +4098,11 @@ def pop(self, item: Name) -> Union["Series", Scalar]:
pdf = sdf.limit(2).toPandas()
length = len(pdf)
if length == 1:
return pdf[internal.data_spark_column_names[0]].iloc[0]
val = pdf[internal.data_spark_column_names[0]].iloc[0]
if isinstance(self.dtype, CategoricalDtype):
return self.dtype.categories[val]
else:
return val

item_string = name_like_string(item)
sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string)))
Expand Down
25 changes: 25 additions & 0 deletions python/pyspark/pandas/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,6 +1669,31 @@ def test_pop(self):
with self.assertRaisesRegex(KeyError, msg):
psser.pop(("lama", "speed", "x"))

pser = pd.Series(["a", "b", "c", "a"], dtype="category")
psser = ps.from_pandas(pser)

if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
self.assert_eq(psser.pop(0), pser.pop(0))
self.assert_eq(psser, pser)

self.assert_eq(psser.pop(3), pser.pop(3))
self.assert_eq(psser, pser)
else:
# Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly.
self.assert_eq(psser.pop(0), "a")
self.assert_eq(
psser,
pd.Series(
pd.Categorical(["b", "c", "a"], categories=["a", "b", "c"]), index=[1, 2, 3]
),
)

self.assert_eq(psser.pop(3), "a")
self.assert_eq(
psser,
pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", "c"]), index=[1, 2]),
)

def test_replace(self):
pser = pd.Series([10, 20, 15, 30, np.nan], name="x")
psser = ps.Series(pser)
Expand Down

0 comments on commit 079a9c5

Please sign in to comment.