fix(python): Respect schema_overrides in batched csv reader (#19755)
ritchie46 authored Nov 14, 2024
1 parent 97c82d0 commit 5f11dd9
Showing 3 changed files with 16 additions and 7 deletions.
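
For context, a minimal usage sketch of the behavior this commit fixes, adapted from the test added below; the file name is a placeholder for any CSV that has a "calories" column:

    import polars as pl

    # Placeholder path; any CSV with a "calories" column illustrates the same point.
    batched = pl.read_csv_batched(
        "foods.csv",
        schema_overrides={"calories": pl.String},  # request String instead of the inferred dtype
    )

    # next_batches(n) returns a list of DataFrames, or None once the file is exhausted.
    batches = batched.next_batches(1)
    if batches is not None:
        first = batches[0]
        # With this fix the override is applied by the batched reader as well.
        assert first["calories"].dtype == pl.String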
11 changes: 5 additions & 6 deletions crates/polars-python/src/batched_csv.rs
@@ -23,7 +23,7 @@ impl PyBatchedCsv
     #[staticmethod]
     #[pyo3(signature = (
         infer_schema_length, chunk_size, has_header, ignore_errors, n_rows, skip_rows,
-        projection, separator, rechunk, columns, encoding, n_threads, path, overwrite_dtype,
+        projection, separator, rechunk, columns, encoding, n_threads, path, schema_overrides,
         overwrite_dtype_slice, low_memory, comment_prefix, quote_char, null_values,
         missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header, row_index,
         eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma)
@@ -42,7 +42,7 @@ impl PyBatchedCsv
         encoding: Wrap<CsvEncoding>,
         n_threads: Option<usize>,
         path: PathBuf,
-        overwrite_dtype: Option<Vec<(PyBackedStr, Wrap<DataType>)>>,
+        schema_overrides: Option<Vec<(PyBackedStr, Wrap<DataType>)>>,
         overwrite_dtype_slice: Option<Vec<Wrap<DataType>>>,
         low_memory: bool,
         comment_prefix: Option<&str>,
@@ -73,7 +73,7 @@ impl PyBatchedCsv
             None
         };

-        let overwrite_dtype = overwrite_dtype.map(|overwrite_dtype| {
+        let schema_overrides = schema_overrides.map(|overwrite_dtype| {
             overwrite_dtype
                 .iter()
                 .map(|(name, dtype)| {
@@ -105,6 +105,7 @@ impl PyBatchedCsv
             .with_n_threads(n_threads)
             .with_dtype_overwrite(overwrite_dtype_slice.map(Arc::new))
             .with_low_memory(low_memory)
+            .with_schema_overwrite(schema_overrides.map(Arc::new))
             .with_skip_rows_after_header(skip_rows_after_header)
             .with_row_index(row_index)
             .with_raise_if_empty(raise_if_empty)
@@ -123,9 +124,7 @@
             )
             .into_reader_with_file_handle(reader);

-        let reader = reader
-            .batched(overwrite_dtype.map(Arc::new))
-            .map_err(PyPolarsErr::from)?;
+        let reader = reader.batched(None).map_err(PyPolarsErr::from)?;

         Ok(PyBatchedCsv {
             reader: Mutex::new(reader),
2 changes: 1 addition & 1 deletion py-polars/polars/io/csv/batched_reader.py
@@ -89,7 +89,7 @@ def __init__(
             encoding=encoding,
             n_threads=n_threads,
             path=path,
-            overwrite_dtype=dtype_list,
+            schema_overrides=dtype_list,
             overwrite_dtype_slice=dtype_slice,
             low_memory=low_memory,
             comment_prefix=comment_prefix,
10 changes: 10 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
@@ -2346,3 +2346,13 @@ def test_csv_read_time_dtype_overwrite(tmp_path: Path) -> None:
         ),
         df,
     )
+
+
+def test_batched_csv_schema_overrides(io_files_path: Path) -> None:
+    foods = io_files_path / "foods1.csv"
+    batched = pl.read_csv_batched(foods, schema_overrides={"calories": pl.String})
+    res = batched.next_batches(1)
+    assert res is not None
+    b = res[0]
+    assert b["calories"].dtype == pl.String
+    assert b.width == 4
