Skip to content

Commit

Permalink
[ENH] Improve performance for polars' pivot_longer (#1402)
Browse files Browse the repository at this point in the history
* shortcut for .value only

* fastpath if others is just a single column and a string dtype

* fix parameters for unpivot

---------

Co-authored-by: samuel.oranyeli <[email protected]>
  • Loading branch information
samukweku and samuel.oranyeli authored Sep 14, 2024
1 parent dabccdb commit 772b7dc
Showing 1 changed file with 110 additions and 35 deletions.
145 changes: 110 additions & 35 deletions janitor/polars/pivot_longer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ def pivot_longer_spec(
A declarative interface to pivot a Polars Frame
from wide to long form,
where you describe how the data will be unpivoted,
using a DataFrame. This gives you, the user,
using a DataFrame.
It is modeled after tidyr's `pivot_longer_spec`.
This gives you, the user,
more control over the transformation to long form,
using a *spec* DataFrame that describes exactly
how data stored in the column names
Expand Down Expand Up @@ -108,41 +112,56 @@ def pivot_longer_spec(
corresponding to columns pivoted from the wide format.
Note that these additional columns should not already exist
in the source DataFrame.
If there are additional columns, the combination of these columns
and the `.value` column must be unique.
Raises:
KeyError: If `.name` or `.value` is missing from the spec's columns.
ValueError: If the labels in `spec['.name']` is not unique.
ValueError: If the labels in spec's `.name` column is not unique.
Returns:
A polars DataFrame/LazyFrame.
"""
check("spec", spec, [pl.DataFrame])
if ".name" not in spec.columns:
spec_columns = spec.collect_schema().names()
if ".name" not in spec_columns:
raise KeyError(
"Kindly ensure the spec DataFrame has a `.name` column."
)
if ".value" not in spec.columns:
if ".value" not in spec_columns:
raise KeyError(
"Kindly ensure the spec DataFrame has a `.value` column."
)
if spec.select(pl.col(".name").is_duplicated().any()).item():
if spec.get_column(".name").is_duplicated().any():
raise ValueError("The labels in the `.name` column should be unique.")

exclude = set(df.columns).intersection(spec.columns)
df_columns = df.collect_schema().names()
exclude = set(df_columns).intersection(spec_columns)
if exclude:
raise ValueError(
f"Labels {*exclude, } in the spec dataframe already exist "
"as column labels in the source dataframe. "
"Kindly ensure the spec DataFrame's columns "
"are not present in the source DataFrame."
)

index = [
label for label in df.columns if label not in spec.get_column(".name")
label for label in df_columns if label not in spec.get_column(".name")
]
others = [
label for label in spec.columns if label not in {".name", ".value"}
label for label in spec_columns if label not in {".name", ".value"}
]
variable_name = "".join(df.columns + spec.columns)

if (len(others) == 1) & (spec.get_column(others[0]).dtype == pl.String):
# shortcut that avoids the implode/explode approach - and is faster
# if the requirements are met
# inspired by https://github.com/pola-rs/polars/pull/18519#issue-2500860927
return _pivot_longer_dot_value_string(
df=df,
index=index,
spec=spec,
variable_name=others[0],
)
variable_name = "".join(df_columns + spec_columns)
variable_name = f"{variable_name}_"
if others:
dot_value_only = False
Expand Down Expand Up @@ -219,7 +238,7 @@ def pivot_longer(
│ 5.9 ┆ 3.0 ┆ 5.1 ┆ 1.8 ┆ virginica │
└──────────────┴─────────────┴──────────────┴─────────────┴───────────┘
Replicate polars' [melt](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.melt.html#polars-dataframe-melt):
Replicate polars' [melt](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.unpivot.html#polars-dataframe-melt):
>>> df.pivot_longer(index = 'Species').sort(by=pl.all())
shape: (8, 3)
┌───────────┬──────────────┬───────┐
Expand Down Expand Up @@ -375,8 +394,8 @@ def pivot_longer(
specification as polars' `str.split` method.
names_pattern: Determines how the column name is broken up.
It can be a regular expression containing matching groups.
It takes the same
specification as polars' `str.extract_groups` method.
It takes the same specification as
polars' `str.extract_groups` method.
names_transform: Use this option to change the types of columns that
have been transformed to rows.
This does not applies to the values' columns.
Expand Down Expand Up @@ -440,7 +459,7 @@ def _pivot_longer(
names_pattern=names_pattern,
)

variable_name = "".join(df.columns)
variable_name = "".join(df.collect_schema().names())
variable_name = f"{variable_name}_"
spec = _pivot_longer_create_spec(
column_names=column_names,
Expand All @@ -461,8 +480,25 @@ def _pivot_longer(
variable_name=variable_name,
names_transform=names_transform,
)

if {".name", ".value"}.symmetric_difference(spec.columns):
if {".name", ".value"}.symmetric_difference(spec.collect_schema().names()):
# shortcut that avoids the implode/explode approach - and is faster
# if the requirements are met
# inspired by https://github.com/pola-rs/polars/pull/18519#issue-2500860927
data = spec.get_column(variable_name)
others = data.struct.fields
data = data.struct[others[0]]
if (
(len(others) == 1)
& (data.dtype == pl.String)
& (names_transform is None)
):
spec = spec.unnest(variable_name)
return _pivot_longer_dot_value_string(
df=df,
index=index,
spec=spec,
variable_name=others[0],
)
dot_value_only = False
else:
dot_value_only = True
Expand Down Expand Up @@ -552,7 +588,7 @@ def _pivot_longer_create_spec(
return spec.select(".name", ".value")
_spec = spec.get_column(variable_name)
_spec = _spec.struct.unnest()
fields = _spec.columns
fields = _spec.collect_schema().names()

if len(set(names_to)) == 1:
expression = pl.concat_str(fields).alias(".value")
Expand Down Expand Up @@ -591,7 +627,7 @@ def _pivot_longer_no_dot_value(
# do the operation on a smaller size
# and then blow it up after
# it is usually much faster
# than running on the actual data
# than unpivoting and running the string operations after
outcome = (
df.select(pl.all().implode())
.unpivot(
Expand All @@ -606,11 +642,44 @@ def _pivot_longer_no_dot_value(
outcome = outcome.unnest(variable_name)
if names_transform is not None:
outcome = outcome.with_columns(names_transform)
columns = [name for name in outcome.columns if name not in names_to]
columns = [
name
for name in outcome.collect_schema().names()
if name not in names_to
]
outcome = outcome.explode(columns=columns)
return outcome


def _pivot_longer_dot_value_string(
df: pl.DataFrame | pl.LazyFrame,
spec: pl.DataFrame,
index: ColumnNameOrSelector,
variable_name: str,
) -> pl.DataFrame | pl.LazyFrame:
"""
fastpath for .value - does not require implode/explode approach.
"""
spec = spec.group_by(variable_name)
spec = spec.agg(pl.all())
expressions = []
for names, fields, header in zip(
spec.get_column(".name").to_list(),
spec.get_column(".value").to_list(),
spec.get_column(variable_name).to_list(),
):
expression = pl.struct(names).struct.rename_fields(names=fields)
expression = expression.alias(header)
expressions.append(expression)
expressions = [*index, *expressions]
df = (
df.select(expressions)
.unpivot(index=index, variable_name=variable_name, value_name=".value")
.unnest(".value")
)
return df


def _pivot_longer_dot_value(
df: pl.DataFrame | pl.LazyFrame,
spec: pl.DataFrame,
Expand All @@ -621,7 +690,7 @@ def _pivot_longer_dot_value(
) -> pl.DataFrame | pl.LazyFrame:
"""
flip polars Frame to long form,
if names_sep and .value in names_to.
if .value in names_to.
"""
spec = spec.group_by(variable_name)
spec = spec.agg(pl.all())
Expand All @@ -634,25 +703,31 @@ def _pivot_longer_dot_value(
expressions.append(expression)
expressions = [*index, *expressions]
spec = spec.get_column(variable_name)
if dot_value_only:
outcome = (
df.select(expressions)
.unpivot(
index=index, variable_name=variable_name, value_name=".value"
)
.select(pl.exclude(variable_name))
.unnest(".value")
)
return outcome

outcome = (
df.select(expressions)
.select(pl.all().implode())
.unpivot(index=index, variable_name=variable_name, value_name=".value")
.with_columns(spec)
)

if dot_value_only:
columns = [
label for label in outcome.columns if label != variable_name
]
outcome = outcome.explode(columns).unnest(".value")
outcome = outcome.select(pl.exclude(variable_name))
return outcome
outcome = outcome.unnest(variable_name)
if names_transform is not None:
outcome = outcome.with_columns(names_transform)
columns = [
label for label in outcome.columns if label not in spec.struct.fields
label
for label in outcome.collect_schema().names()
if label not in spec.struct.fields
]
outcome = outcome.explode(columns)
outcome = outcome.unnest(".value")
Expand Down Expand Up @@ -710,17 +785,17 @@ def _data_checks_pivot_longer(
check("values_to", values_to, [str])

if (index is None) and (column_names is None):
column_names = df.columns
column_names = df.collect_schema().names()
index = []
elif (index is None) and (column_names is not None):
column_names = df.select(column_names).columns
index = df.select(pl.exclude(column_names)).columns
column_names = df.select(column_names).collect_schema().names()
index = df.select(pl.exclude(column_names)).collect_schema().names()
elif (index is not None) and (column_names is None):
index = df.select(index).columns
column_names = df.select(pl.exclude(index)).columns
index = df.select(index).collect_schema().names()
column_names = df.select(pl.exclude(index)).collect_schema().names()
else:
index = df.select(index).columns
column_names = df.select(column_names).columns
index = df.select(index).collect_schema().names()
column_names = df.select(column_names).collect_schema().names()

return (
df,
Expand Down

0 comments on commit 772b7dc

Please sign in to comment.