Skip to content

Commit

Permalink
perf: Generalize the arg_sort fast path onto Column.
Browse files Browse the repository at this point in the history
This makes the implementation of `arg_sort` take the fast path in `Column`
before ever going into a specific `SeriesTrait` implementor. As a result, every
type now shares the same fast path, which is taken when the data is already
sorted. It also now allows taking the fast path when `maintain_order=True`.

This PR needed to adjust some tests because they wrongly assumed things about
the output order of the sorting when `maintain_order=False`.
  • Loading branch information
coastalwhite committed Dec 24, 2024
1 parent 93ceacc commit 399a2c1
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 38 deletions.
156 changes: 153 additions & 3 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -894,8 +894,155 @@ impl Column {
}

/// Returns the indices that sort this column according to `options`.
///
/// Several cases are answered directly, without dispatching to the underlying series:
/// - an empty column yields an empty index array;
/// - an all-null column yields `0..len`, reversed when `options.descending` is set;
/// - a column whose sorted flag is `Ascending` or `Descending` is answered by emitting
///   index ranges: the null run is placed according to `options.nulls_last` and always
///   keeps its original order, while the non-null run is emitted as-is or inverted
///   depending on whether the current direction matches `options.descending`.
///
/// When the non-null run has to be inverted and `options.maintain_order` is set, only
/// the order of the groups of equal values is reversed; indices inside each group keep
/// their original relative order.
///
/// A column flagged `IsSorted::Not` falls back to the materialized series' `arg_sort`.
pub fn arg_sort(&self, options: SortOptions) -> IdxCa {
    if self.is_empty() {
        return IdxCa::from_vec(self.name().clone(), Vec::new());
    }

    if self.null_count() == self.len() {
        // We might need to maintain order so just respect the descending parameter.
        let values = if options.descending {
            (0..self.len() as IdxSize).rev().collect()
        } else {
            (0..self.len() as IdxSize).collect()
        };

        return IdxCa::from_vec(self.name().clone(), values);
    }

    // Without a sorted flag there is no shortcut: defer to the series implementation.
    let is_sorted = self.is_sorted_flag();
    if matches!(is_sorted, IsSorted::Not) {
        return self.as_materialized_series().arg_sort(options);
    }

    // Fast path: the data is sorted.
    let is_sorted_dsc = matches!(is_sorted, IsSorted::Descending);
    // The requested direction is the opposite of the current one.
    let invert = options.descending != is_sorted_dsc;

    let mut values = Vec::with_capacity(self.len());

    /// Appends the indices `start..end` of `slf` to `values` in output order.
    ///
    /// `is_only_nulls` marks a range that consists solely of nulls; such a range is
    /// always appended in its original order.
    #[inline(never)]
    fn extend(
        start: IdxSize,
        end: IdxSize,
        slf: &Column,
        values: &mut Vec<IdxSize>,
        is_only_nulls: bool,
        invert: bool,
        maintain_order: bool,
    ) {
        debug_assert!(start <= end);
        debug_assert!(start as usize <= slf.len());
        debug_assert!(end as usize <= slf.len());

        if !invert || is_only_nulls {
            values.extend(start..end);
            return;
        }

        // If we don't have to maintain order but we have to invert. Just flip it around.
        if !maintain_order {
            values.extend((start..end).rev());
            return;
        }

        // If we want to maintain order but we also need to invert, we need to invert
        // per group of items.
        //
        // @NOTE: Since the column is sorted, arg_unique can also take a fast path and
        // just do a single traversal.
        let arg_unique = slf
            .slice(start as i64, (end - start) as usize)
            .arg_unique()
            .unwrap();

        assert!(!arg_unique.has_nulls());

        let num_unique = arg_unique.len();

        // Fast path: all items are unique.
        if num_unique == (end - start) as usize {
            values.extend((start..end).rev());
            return;
        }

        // A single group: inverting the group order changes nothing.
        if num_unique == 1 {
            values.extend(start..end);
            return;
        }

        // `arg_unique` holds the (slice-relative) first index of every group. Walk the
        // groups from last to first and emit each group's indices in ascending order.
        // The chunks are traversed in reverse as well, so the walk stays strictly
        // back-to-front even if the index array consists of more than one chunk.
        let mut prev_idx = end - start;
        for chunk in arg_unique.downcast_iter().rev() {
            for &idx in chunk.values().as_slice().iter().rev() {
                values.extend(start + idx..start + prev_idx);
                prev_idx = idx;
            }
        }
    }
    // Shorthand that threads the shared state (`self`, `values`, `invert`,
    // `maintain_order`) into `extend`.
    macro_rules! extend {
        ($start:expr, $end:expr) => {
            extend!($start, $end, is_only_nulls = false);
        };
        ($start:expr, $end:expr, is_only_nulls = $is_only_nulls:expr) => {
            extend(
                $start,
                $end,
                self,
                &mut values,
                $is_only_nulls,
                invert,
                options.maintain_order,
            );
        };
    }

    let length = self.len() as IdxSize;
    let null_count = self.null_count() as IdxSize;

    if null_count == 0 {
        extend!(0, length);
    } else {
        // NOTE(review): assumes the nulls of a sorted column form one contiguous run at
        // either the start or the end, so inspecting the last element locates the run.
        let has_nulls_last = self.get(self.len() - 1).unwrap().is_null();
        match (options.nulls_last, has_nulls_last) {
            (true, true) => {
                // Current: Nulls last, Wanted: Nulls last
                extend!(0, length - null_count);
                extend!(length - null_count, length, is_only_nulls = true);
            },
            (true, false) => {
                // Current: Nulls first, Wanted: Nulls last
                extend!(null_count, length);
                extend!(0, null_count, is_only_nulls = true);
            },
            (false, true) => {
                // Current: Nulls last, Wanted: Nulls first
                extend!(length - null_count, length, is_only_nulls = true);
                extend!(0, length - null_count);
            },
            (false, false) => {
                // Current: Nulls first, Wanted: Nulls first
                extend!(0, null_count, is_only_nulls = true);
                extend!(null_count, length);
            },
        }
    }

    IdxCa::from_vec(self.name().clone(), values)
}

/// Computes the multi-column argsort of this column and the columns in `by` by
/// delegating to the materialized series' `arg_sort_multiple` with `options`.
///
/// No column-level fast path is applied here; the column is materialized first.
pub fn arg_sort_multiple(
    &self,
    by: &[Column],
    options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
    // @scalar-opt
    self.as_materialized_series().arg_sort_multiple(by, options)
}

/// Returns the index of the first occurrence of every unique value in this column.
///
/// A non-empty scalar column is answered directly with the single index `0`. Every
/// other column, including an empty scalar column (which has no index `0` to report),
/// is delegated to the materialized series' `arg_unique`.
pub fn arg_unique(&self) -> PolarsResult<IdxCa> {
    match self {
        // Guard against the empty case: index 0 would be out of bounds there.
        Column::Scalar(s) if !self.is_empty() => {
            Ok(IdxCa::new_vec(s.name().clone(), vec![0]))
        },
        _ => self.as_materialized_series().arg_unique(),
    }
}

pub fn bit_repr(&self) -> Option<BitRepr> {
Expand Down Expand Up @@ -986,8 +1133,11 @@ impl Column {
}

/// Returns the sortedness flag of this column without materializing it.
///
/// - `Column::Series`: the flag stored on the series.
/// - `Column::Partitioned`: the flag of the column's `partitions()`.
/// - `Column::Scalar`: always `IsSorted::Ascending` (presumably because a scalar
///   column repeats one value and is therefore trivially sorted).
pub fn is_sorted_flag(&self) -> IsSorted {
    match self {
        Column::Series(s) => s.is_sorted_flag(),
        Column::Partitioned(s) => s.partitions().is_sorted_flag(),
        Column::Scalar(_) => IsSorted::Ascending,
    }
}

pub fn unique(&self) -> PolarsResult<Column> {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_interpolate_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_interpolate_by_leading_nulls() -> None:
result = (
df.sort("times", descending=True)
.with_columns(pl.col("values").interpolate_by("times"))
.sort("times")
.sort("times", maintain_order=True)
.drop("times")
)
assert_frame_equal(result, expected)
Expand Down
74 changes: 40 additions & 34 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_sort_by(
assert out["a"].to_list() == expected[2]

# by can also be a single column
out = df.select(pl.col("a").sort_by("b", descending=[False]))
out = df.select(pl.col("a").sort_by("b", descending=[False], maintain_order=True))
assert out["a"].to_list() == expected[3]


Expand All @@ -143,17 +143,21 @@ def test_expr_sort_by_nulls_last(
df = sort_function(df)

# nulls last
expected = pl.DataFrame({"a": [1, 2, 5, None, None], "b": [None, 1, None, 1, 2]})
out = df.select(pl.all().sort_by("a", nulls_last=True))
assert_frame_equal(out, expected)
assert out["a"].to_list() == [1, 2, 5, None, None]
# We don't maintain order so there are two possibilities
assert out["b"].to_list()[:3] == [None, 1, None]
assert out["b"].to_list()[3:] in [[1, 2], [2, 1]]

# nulls first (default)
expected = pl.DataFrame({"a": [None, None, 1, 2, 5], "b": [1, 2, None, 1, None]})
for out in (
df.select(pl.all().sort_by("a", nulls_last=False)),
df.select(pl.all().sort_by("a")),
):
assert_frame_equal(out, expected)
assert out["a"].to_list() == [None, None, 1, 2, 5]
# We don't maintain order so there are two possibilities
assert out["b"].to_list()[2:] == [None, 1, None]
assert out["b"].to_list()[:2] in [[1, 2], [2, 1]]


def test_expr_sort_by_multi_nulls_last() -> None:
Expand Down Expand Up @@ -423,21 +427,23 @@ def test_sorted_fast_paths() -> None:
(
pl.DataFrame({"Idx": [0, 1, 2, 3, 4, 5, 6], "Val": [0, 1, 2, 3, 4, 5, 6]}),
(
[0, 1, 2, 3, 4, 5, 6],
[6, 5, 4, 3, 2, 1, 0],
[0, 1, 2, 3, 4, 5, 6],
[6, 5, 4, 3, 2, 1, 0],
[[0, 1, 2, 3, 4, 5, 6]],
[[6, 5, 4, 3, 2, 1, 0]],
[[0, 1, 2, 3, 4, 5, 6]],
[[6, 5, 4, 3, 2, 1, 0]],
),
),
(
pl.DataFrame(
{"Idx": [0, 1, 2, 3, 4, 5, 6], "Val": [0, 1, None, 3, None, 5, 6]}
),
# We don't use maintain order here, so it might as well do anything
# with the None elements.
(
[0, 1, 3, 5, 6, 2, 4],
[6, 5, 3, 1, 0, 2, 4],
[2, 4, 0, 1, 3, 5, 6],
[2, 4, 6, 5, 3, 1, 0],
[[0, 1, 3, 5, 6, 2, 4], [0, 1, 3, 5, 6, 4, 2]],
[[6, 5, 3, 1, 0, 2, 4], [6, 5, 3, 1, 0, 4, 2]],
[[2, 4, 0, 1, 3, 5, 6], [4, 2, 0, 1, 3, 5, 6]],
[[2, 4, 6, 5, 3, 1, 0], [4, 2, 6, 5, 3, 1, 0]],
),
),
],
Expand All @@ -454,7 +460,7 @@ def test_sorted_fast_paths() -> None:
def test_sorted_arg_sort_fast_paths(
sort_function: Callable[[pl.DataFrame], pl.DataFrame],
df: pl.DataFrame,
expected: tuple[list[int], list[int], list[int], list[int]],
expected: tuple[list[list[int]], list[list[int]], list[list[int]], list[list[int]]],
) -> None:
# Test that an already sorted df is correctly sorted (by a single column)
# In certain cases below we will not go through fast path; this test
Expand All @@ -467,34 +473,34 @@ def test_sorted_arg_sort_fast_paths(
# Test dataframe.sort
assert (
df.sort("Val", descending=False, nulls_last=True)["Idx"].to_list()
== expected[0]
in expected[0]
)
assert (
df.sort("Val", descending=True, nulls_last=True)["Idx"].to_list() == expected[1]
df.sort("Val", descending=True, nulls_last=True)["Idx"].to_list() in expected[1]
)
assert (
df.sort("Val", descending=False, nulls_last=False)["Idx"].to_list()
== expected[2]
in expected[2]
)
assert (
df.sort("Val", descending=True, nulls_last=False)["Idx"].to_list()
== expected[3]
in expected[3]
)
# Test series.arg_sort
assert (
df["Idx"][s.arg_sort(descending=False, nulls_last=True)].to_list()
== expected[0]
in expected[0]
)
assert (
df["Idx"][s.arg_sort(descending=True, nulls_last=True)].to_list() == expected[1]
df["Idx"][s.arg_sort(descending=True, nulls_last=True)].to_list() in expected[1]
)
assert (
df["Idx"][s.arg_sort(descending=False, nulls_last=False)].to_list()
== expected[2]
in expected[2]
)
assert (
df["Idx"][s.arg_sort(descending=True, nulls_last=False)].to_list()
== expected[3]
in expected[3]
)


Expand Down Expand Up @@ -896,30 +902,30 @@ def test_sort_with_null_12139(
}
)
df = sort_function(df)
assert df.sort("bool", descending=False, nulls_last=False).to_dict(
as_series=False
) == {
assert df.sort(
"bool", descending=False, nulls_last=False, maintain_order=True
).to_dict(as_series=False) == {
"bool": [None, False, False, True, True],
"float": [3.0, 2.0, 5.0, 1.0, 4.0],
}

assert df.sort("bool", descending=False, nulls_last=True).to_dict(
as_series=False
) == {
assert df.sort(
"bool", descending=False, nulls_last=True, maintain_order=True
).to_dict(as_series=False) == {
"bool": [False, False, True, True, None],
"float": [2.0, 5.0, 1.0, 4.0, 3.0],
}

assert df.sort("bool", descending=True, nulls_last=True).to_dict(
as_series=False
) == {
assert df.sort(
"bool", descending=True, nulls_last=True, maintain_order=True
).to_dict(as_series=False) == {
"bool": [True, True, False, False, None],
"float": [1.0, 4.0, 2.0, 5.0, 3.0],
}

assert df.sort("bool", descending=True, nulls_last=False).to_dict(
as_series=False
) == {
assert df.sort(
"bool", descending=True, nulls_last=False, maintain_order=True
).to_dict(as_series=False) == {
"bool": [None, True, True, False, False],
"float": [3.0, 1.0, 4.0, 2.0, 5.0],
}
Expand Down

0 comments on commit 399a2c1

Please sign in to comment.