Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Refactor compute kernels in polars-arrow to avoid using gather #19669

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions crates/polars-arrow/src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ pub use binview_to::utf8view_to_utf8;
pub use boolean_to::*;
pub use decimal_to::*;
use dictionary_to::*;
use growable::make_growable;
use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult};
use polars_utils::IdxSize;
pub use primitive_to::*;
pub use utf8_to::*;

Expand Down Expand Up @@ -170,9 +170,13 @@ fn cast_fixed_size_list_to_list<O: Offset>(
fn cast_list_to_fixed_size_list<O: Offset>(
list: &ListArray<O>,
inner: &Field,
size: usize,
size: usize, // width
options: CastOptionsImpl,
) -> PolarsResult<FixedSizeListArray> {
) -> PolarsResult<FixedSizeListArray>
where
ListArray<O>: crate::array::StaticArray
+ ArrayFromIter<std::option::Option<Box<dyn crate::array::Array>>>,
{
let null_cnt = list.null_count();
let new_values = if null_cnt == 0 {
let start_offset = list.offsets().first().to_usize();
Expand All @@ -190,7 +194,8 @@ fn cast_list_to_fixed_size_list<O: Offset>(
.sliced(start_offset, list.offsets().range().to_usize());
cast(sliced_values.as_ref(), inner.dtype(), options)?
} else {
let offsets = list.offsets().as_slice();
let offsets = list.offsets();

// Check the lengths of each list are equal to the fixed size.
// SAFETY: we know the index is in bound.
let mut expected_offset = unsafe { *offsets.get_unchecked(0) } + O::from_as_usize(size);
Expand All @@ -206,27 +211,28 @@ fn cast_list_to_fixed_size_list<O: Offset>(
}
}

// Build take indices for the values. This is used to fill in the null slots.
let mut indices =
MutablePrimitiveArray::<IdxSize>::with_capacity(list.values().len() + null_cnt * size);
for i in 0..list.len() {
if list.is_null(i) {
indices.extend_constant(size, None)
} else {
// SAFETY: we know the index is in bound.
let current_offset = unsafe { *offsets.get_unchecked(i) };
for j in 0..size {
indices.push(Some(
(current_offset + O::from_as_usize(j)).to_usize() as IdxSize
));
let list_validity = list.validity().unwrap();
let mut growable = make_growable(&[list.values().as_ref()], true, list.len());

for (outer_idx, x) in offsets.windows(2).enumerate() {
let [i, j] = x else { unreachable!() };
let i = i.to_usize();
let j = j.to_usize();

unsafe {
let outer_is_valid = list_validity.get_bit_unchecked(outer_idx);

if outer_is_valid {
growable.extend(0, i, j - i);
} else {
growable.extend_validity(size)
}
}
};
}
let take_values = unsafe {
crate::compute::take::take_unchecked(list.values().as_ref(), &indices.freeze())
Copy link
Collaborator Author

@nameexhaustion nameexhaustion Nov 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Casting a nullable List to a FixedSizeList previously used a gather to ensure that null slots had the full width; I have updated this to use Growable instead.

};

cast(take_values.as_ref(), inner.dtype(), options)?
let values = growable.as_box();

cast(values.as_ref(), inner.dtype(), options)?
};

FixedSizeListArray::try_new(
Expand Down
193 changes: 141 additions & 52 deletions crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,162 @@
use polars_error::{polars_bail, PolarsResult};
use polars_utils::index::NullCount;
use polars_utils::IdxSize;

use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray};
use crate::compute::take::take_unchecked;
use crate::legacy::prelude::*;
use crate::legacy::utils::CustomIterTools;

/// Builds the gather indices that pick element `index` out of each of `len`
/// consecutive rows of `width` values.
///
/// A row's slot is `None` when `index >= width` or when `negative_to_usize`
/// rejects the index (presumably an out-of-range negative index; confirm
/// against `polars_utils`). Otherwise it is the resolved in-row offset plus
/// `row * width`.
fn sub_fixed_size_list_get_indexes_literal(width: usize, len: usize, index: i64) -> IdxArr {
    let flat_position = |row: usize| -> Option<IdxSize> {
        if index < width as i64 {
            index
                .negative_to_usize(width)
                .map(|offset| (offset + row * width) as IdxSize)
        } else {
            None
        }
    };

    (0..len).map(flat_position).collect_trusted()
}
use polars_error::{polars_bail, PolarsError, PolarsResult};

/// Builds the gather indices for a per-row `index` array over rows of `width`
/// values each.
///
/// Row `i` yields `None` when its index is null, when the index is
/// `>= width`, or when `negative_to_usize` rejects it (presumably an
/// out-of-range negative index; confirm against `polars_utils`). Otherwise it
/// yields the resolved in-row offset plus `i * width`.
fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray<i64>) -> IdxArr {
    index
        .iter()
        .enumerate()
        .map(|(row, slot)| -> Option<IdxSize> {
            // A null index propagates straight through as a null output slot.
            let raw = *slot?;
            if raw >= width as i64 {
                return None;
            }

            raw.negative_to_usize(width)
                .map(|offset| (offset + row * width) as IdxSize)
        })
        .collect_trusted()
}
use crate::array::growable::make_growable;
use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray};
use crate::bitmap::BitmapBuilder;
use crate::compute::utils::combine_validities_and3;
use crate::datatypes::ArrowDataType;

pub fn sub_fixed_size_list_get_literal(
arr: &FixedSizeListArray,
index: i64,
null_on_oob: bool,
) -> PolarsResult<ArrayRef> {
let take_by = sub_fixed_size_list_get_indexes_literal(arr.size(), arr.len(), index);
if !null_on_oob && take_by.null_count() > 0 {
polars_bail!(ComputeError: "get index is out of bounds");
let ArrowDataType::FixedSizeList(_, width) = arr.dtype() else {
unreachable!();
};

let width = *width;

let orig_index = index;

let index = if index < 0 {
if index.unsigned_abs() as usize > width {
width
} else {
(width as i64 + index) as usize
}
} else {
usize::try_from(index).unwrap()
};

let index_is_oob = index >= width;

if !null_on_oob && index >= width {
polars_bail!(
ComputeError:
"get index {} is out of bounds for array(width={})",
orig_index,
width
);
}

let values = arr.values();
// SAFETY:
// the indices we generate are in bounds
unsafe { Ok(take_unchecked(&**values, &take_by)) }

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list.get() / array.get() were building selection indices and then calling gather with them - I've re-written them to use loops instead.

let mut growable = make_growable(
&[values.as_ref()],
values.validity().is_some() | arr.validity().is_some() | index_is_oob,
arr.len(),
);

if index_is_oob {
unsafe { growable.extend_validity(arr.len()) }
let out = growable.as_box();
return Ok(out);
}

if let Some(arr_validity) = arr.validity() {
for i in 0..arr.len() {
unsafe {
if arr_validity.get_bit_unchecked(i) {
growable.extend(0, i * width + index, 1)
} else {
growable.extend_validity(1)
}
}
}
} else {
for i in 0..arr.len() {
unsafe { growable.extend(0, i * width + index, 1) }
}
}

Ok(growable.as_box())
}

pub fn sub_fixed_size_list_get(
arr: &FixedSizeListArray,
index: &PrimitiveArray<i64>,
null_on_oob: bool,
) -> PolarsResult<ArrayRef> {
let take_by = sub_fixed_size_list_get_indexes(arr.size(), index);
if !null_on_oob && take_by.null_count() > 0 {
polars_bail!(ComputeError: "get index is out of bounds");
assert_eq!(arr.len(), index.len());

fn idx_oob_err(index: i64, width: usize) -> PolarsError {
PolarsError::ComputeError(
format!(
"get index {} is out of bounds for array(width={})",
index, width
)
.into(),
)
}

let ArrowDataType::FixedSizeList(_, width) = arr.dtype() else {
unreachable!();
};

let width = *width;

if arr.is_empty() {
let values = arr.values();
assert!(values.is_empty());
return Ok(values.clone());
}

if !null_on_oob && width == 0 {
if let Some(i) = index.non_null_values_iter().next() {
return Err(idx_oob_err(i, width));
}
}

// Array is non-empty and has non-zero width at this point
let values = arr.values();
// SAFETY:
// the indices we generate are in bounds
unsafe { Ok(take_unchecked(&**values, &take_by)) }

let mut growable = make_growable(&[values.as_ref()], values.validity().is_some(), arr.len());
let mut idx_oob_validity = BitmapBuilder::with_capacity(arr.len());
let opt_index_validity = index.validity();
let mut exceeded_width_idx = 0;
let mut current_index_i64 = 0;

for i in 0..arr.len() {
let index = index.value(i);
current_index_i64 = index;

let idx = if index < 0 {
if index.unsigned_abs() as usize > width {
width
} else {
(width as i64 + index) as usize
}
} else {
usize::try_from(index).unwrap()
};

let idx_is_oob = idx >= width;
let idx_is_valid = opt_index_validity.map_or(true, |x| unsafe { x.get_bit_unchecked(i) });

if idx_is_oob && idx_is_valid && exceeded_width_idx < width {
exceeded_width_idx = idx;
}

let idx = if idx_is_oob { 0 } else { idx };

unsafe {
growable.extend(0, i * width + idx, 1);
let output_is_valid = idx_is_valid & !idx_is_oob;
idx_oob_validity.push_unchecked(output_is_valid);
}
}

if !null_on_oob && exceeded_width_idx >= width {
return Err(idx_oob_err(current_index_i64, width));
}

let output = growable.as_box();
let output_validity = combine_validities_and3(
output.validity(), // inner validity
Some(&idx_oob_validity.freeze()), // validity for OOB idx
arr.validity(), // outer validity
);

Ok(output.with_validity(output_validity))
}
Loading
Loading