feat: Always resolve dynamic types in schema (#20406)
ritchie46 authored Dec 23, 2024
1 parent 62ebbe5 commit f242871
Showing 12 changed files with 153 additions and 121 deletions.
31 changes: 20 additions & 11 deletions crates/polars-core/src/datatypes/dtype.rs
@@ -253,30 +253,39 @@ impl DataType {

/// Materialize this datatype if it is unknown. All other datatypes
/// are left unchanged.
- pub fn materialize_unknown(&self) -> PolarsResult<DataType> {
+ pub fn materialize_unknown(self, allow_unknown: bool) -> PolarsResult<DataType> {
match self {
- DataType::Unknown(u) => u
- .materialize()
- .ok_or_else(|| polars_err!(SchemaMismatch: "failed to materialize unknown type")),
- DataType::List(inner) => Ok(DataType::List(Box::new(inner.materialize_unknown()?))),
+ DataType::Unknown(u) => match u.materialize() {
+ Some(known) => Ok(known),
+ None => {
+ if allow_unknown {
+ Ok(DataType::Unknown(u))
+ } else {
+ polars_bail!(SchemaMismatch: "failed to materialize unknown type")
+ }
+ },
+ },
+ DataType::List(inner) => Ok(DataType::List(Box::new(
+ inner.materialize_unknown(allow_unknown)?,
+ ))),
#[cfg(feature = "dtype-array")]
DataType::Array(inner, size) => Ok(DataType::Array(
- Box::new(inner.materialize_unknown()?),
- *size,
+ Box::new(inner.materialize_unknown(allow_unknown)?),
+ size,
)),
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => Ok(DataType::Struct(
fields
- .iter()
+ .into_iter()
.map(|f| {
PolarsResult::Ok(Field::new(
- f.name().clone(),
- f.dtype().materialize_unknown()?,
+ f.name,
+ f.dtype.materialize_unknown(allow_unknown)?,
))
})
.try_collect_vec()?,
)),
- _ => Ok(self.clone()),
+ _ => Ok(self),
}
}
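
For reference, a minimal standalone sketch of the pattern this hunk introduces: recursively materialize unknown dtypes, with an `allow_unknown` flag that passes an unresolved `Unknown` through instead of raising a schema mismatch. The toy `DType` enum below is hypothetical and only mirrors the shape of the real `polars_core` `DataType`; it is not the Polars implementation.

```rust
// Toy mirror of the real DataType: Unknown may carry a materialization hint.
#[derive(Clone, Debug, PartialEq)]
enum DType {
    Int64,
    Unknown(Option<Box<DType>>), // Some(hint) can be materialized, None cannot
    List(Box<DType>),
}

// Recursively replace Unknown with its hint; with no hint, either keep the
// Unknown (allow_unknown = true) or fail (allow_unknown = false).
fn materialize_unknown(dt: DType, allow_unknown: bool) -> Result<DType, String> {
    match dt {
        DType::Unknown(Some(hint)) => Ok(*hint),
        DType::Unknown(None) if allow_unknown => Ok(DType::Unknown(None)),
        DType::Unknown(None) => Err("failed to materialize unknown type".to_string()),
        DType::List(inner) => Ok(DType::List(Box::new(materialize_unknown(
            *inner,
            allow_unknown,
        )?))),
        other => Ok(other),
    }
}

fn main() {
    let dt = DType::List(Box::new(DType::Unknown(Some(Box::new(DType::Int64)))));
    assert_eq!(
        materialize_unknown(dt, false),
        Ok(DType::List(Box::new(DType::Int64)))
    );
    assert!(materialize_unknown(DType::Unknown(None), false).is_err());
    assert_eq!(
        materialize_unknown(DType::Unknown(None), true),
        Ok(DType::Unknown(None))
    );
}
```

In the rest of this commit, call sites that must end up with concrete types (e.g. `into_reduction` below) pass `false`, while the schema-resolution paths pass `true` and keep any leftover `Unknown`.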

9 changes: 0 additions & 9 deletions crates/polars-core/src/schema.rs
@@ -20,8 +20,6 @@ pub trait SchemaExt {
fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_;

fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool>;

- fn materialize_unknown_dtypes(&self) -> PolarsResult<Schema>;
}

impl SchemaExt for Schema {
@@ -90,13 +88,6 @@ impl SchemaExt for Schema {
}
Ok(changed)
}

- /// Materialize all unknown dtypes in this schema.
- fn materialize_unknown_dtypes(&self) -> PolarsResult<Schema> {
- self.iter()
- .map(|(name, dtype)| Ok((name.clone(), dtype.materialize_unknown()?)))
- .collect()
- }
}

pub trait SchemaNamesAndDtypes {
2 changes: 1 addition & 1 deletion crates/polars-expr/src/reduce/convert.rs
@@ -19,7 +19,7 @@ pub fn into_reduction(
expr_arena
.get(node)
.to_dtype(schema, Context::Default, expr_arena)?
- .materialize_unknown()
+ .materialize_unknown(false)
};
let out = match expr_arena.get(node) {
AExpr::Agg(agg) => match agg {
13 changes: 12 additions & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
@@ -378,6 +378,17 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
convert_utils::convert_st_union(&mut inputs, ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(vertical concat)))?;
}

+ let first = *inputs.first().ok_or_else(
+ || polars_err!(InvalidOperation: "expected at least one input in 'union'/'concat'"),
+ )?;
+ let schema = ctxt.lp_arena.get(first).schema(ctxt.lp_arena);
+ for n in &inputs[1..] {
+ let schema_i = ctxt.lp_arena.get(*n).schema(ctxt.lp_arena);
+ polars_ensure!(schema == schema_i, InvalidOperation: "'union'/'concat' inputs should all have the same schema,\
+ got\n{:?} and \n{:?}", schema, schema_i)
+ }

let options = args.into();
IR::Union { inputs, options }
},
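
The validation added above can be read in isolation as: take the first input's schema and require every other input to match it exactly. Below is a standalone sketch of that check over a toy schema type (an ordered list of name/dtype pairs); the names and error strings are illustrative only, not the Polars IR code.

```rust
// A schema here is just an ordered list of (column name, dtype name) pairs.
type Schema = Vec<(String, String)>;

// Every input must have exactly the same schema as the first input.
fn ensure_same_schema(schemas: &[Schema]) -> Result<(), String> {
    let first = schemas
        .first()
        .ok_or_else(|| "expected at least one input in 'union'/'concat'".to_string())?;
    for s in &schemas[1..] {
        if s != first {
            return Err(format!(
                "'union'/'concat' inputs should all have the same schema, got\n{first:?} and\n{s:?}"
            ));
        }
    }
    Ok(())
}

fn main() {
    let a: Schema = vec![("x".into(), "i64".into()), ("y".into(), "null".into())];
    let b: Schema = vec![("y".into(), "i64".into())];
    assert!(ensure_same_schema(&[a.clone(), a.clone()]).is_ok());
    assert!(ensure_same_schema(&[a, b]).is_err());
}
```

This is the check that surfaces to Python as the `InvalidOperationError` exercised by `test_concat_invalid_schema_err_20355` further down.
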
@@ -976,7 +987,7 @@ fn resolve_with_columns(
);
polars_bail!(ComputeError: msg)
}
- new_schema.with_column(field.name().clone(), field.dtype().clone());
+ new_schema.with_column(field.name, field.dtype.materialize_unknown(true)?);
arena.clear();
}

8 changes: 7 additions & 1 deletion crates/polars-plan/src/utils.rs
@@ -262,7 +262,12 @@ pub fn expressions_to_schema(
) -> PolarsResult<Schema> {
let mut expr_arena = Arena::with_capacity(4 * expr.len());
expr.iter()
- .map(|expr| expr.to_field_amortized(schema, ctxt, &mut expr_arena))
+ .map(|expr| {
+ let mut field = expr.to_field_amortized(schema, ctxt, &mut expr_arena)?;

+ field.dtype = field.dtype.materialize_unknown(true)?;
+ Ok(field)
+ })
.collect()
}

@@ -336,6 +341,7 @@ pub(crate) fn expr_irs_to_schema<I: IntoIterator<Item = K>, K: AsRef<ExprIR>>(
if let Some(name) = e.get_alias() {
field.name = name.clone()
}
+ field.dtype = field.dtype.materialize_unknown(true).unwrap();
field
})
.collect()
6 changes: 5 additions & 1 deletion crates/polars-stream/src/physical_plan/lower_expr.rs
@@ -648,7 +648,11 @@ pub fn compute_output_schema(
.iter()
.map(|e| {
let name = e.output_name().clone();
- let dtype = e.dtype(input_schema, Context::Default, expr_arena)?.clone();
+ let dtype = e
+ .dtype(input_schema, Context::Default, expr_arena)?
+ .clone()
+ .materialize_unknown(true)
+ .unwrap();
PolarsResult::Ok(Field::new(name, dtype))
})
.try_collect()?;
17 changes: 7 additions & 10 deletions crates/polars-stream/src/physical_plan/to_graph.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;

use parking_lot::Mutex;
use polars_core::prelude::PlRandomState;
- use polars_core::schema::{Schema, SchemaExt};
+ use polars_core::schema::Schema;
use polars_error::PolarsResult;
use polars_expr::groups::new_hash_grouper;
use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState};
@@ -416,9 +416,8 @@ fn to_graph_rec<'a>(
let input_key = to_graph_rec(*input, ctx)?;

let input_schema = &ctx.phys_sm[*input].output_schema;
- let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)?
- .materialize_unknown_dtypes()?;
- let grouper = new_hash_grouper(Arc::new(key_schema));
+ let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)?;
+ let grouper = new_hash_grouper(key_schema);

let key_selectors = key
.iter()
@@ -521,11 +520,9 @@ fn to_graph_rec<'a>(
let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone();

let left_key_schema =
- compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)?
- .materialize_unknown_dtypes()?;
+ compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)?;
let right_key_schema =
- compute_output_schema(&right_input_schema, right_on, ctx.expr_arena)?
- .materialize_unknown_dtypes()?;
+ compute_output_schema(&right_input_schema, right_on, ctx.expr_arena)?;

let left_key_selectors = left_on
.iter()
@@ -540,8 +537,8 @@
nodes::joins::equi_join::EquiJoinNode::new(
left_input_schema,
right_input_schema,
- Arc::new(left_key_schema),
- Arc::new(right_key_schema),
+ left_key_schema,
+ right_key_schema,
left_key_selectors,
right_key_selectors,
args,
34 changes: 0 additions & 34 deletions py-polars/tests/unit/dataframe/test_df.py
@@ -706,34 +706,6 @@ def test_multiple_columns_drop() -> None:
assert out.columns == ["a"]


def test_concat() -> None:
df1 = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
df2 = pl.concat([df1, df1], rechunk=True)

assert df2.shape == (6, 3)
assert df2.n_chunks() == 1
assert df2.rows() == df1.rows() + df1.rows()
assert pl.concat([df1, df1], rechunk=False).n_chunks() == 2

# concat from generator of frames
df3 = pl.concat(items=(df1 for _ in range(2)))
assert_frame_equal(df2, df3)

# check that df4 is not modified following concat of itself
df4 = pl.from_records(((1, 2), (1, 2)))
_ = pl.concat([df4, df4, df4])

assert df4.shape == (2, 2)
assert df4.rows() == [(1, 1), (2, 2)]

# misc error conditions
with pytest.raises(ValueError):
_ = pl.concat([])

with pytest.raises(ValueError):
pl.concat([df1, df1], how="rubbish") # type: ignore[arg-type]


def test_arg_where() -> None:
s = pl.Series([True, False, True, False])
assert_series_equal(
@@ -2262,12 +2234,6 @@ def test_list_of_list_of_struct() -> None:
assert df.to_dicts() == [] # type: ignore[union-attr]


def test_concat_to_empty() -> None:
assert pl.concat([pl.DataFrame([]), pl.DataFrame({"a": [1]})]).to_dict(
as_series=False
) == {"a": [1]}


def test_fill_null_limits() -> None:
assert pl.DataFrame(
{
45 changes: 1 addition & 44 deletions py-polars/tests/unit/io/test_parquet.py
@@ -5,7 +5,7 @@
from datetime import datetime, time, timezone
from decimal import Decimal
from itertools import chain
- from typing import IO, TYPE_CHECKING, Any, Callable, Literal, cast
+ from typing import TYPE_CHECKING, Any, Callable, Literal, cast

import fsspec
import numpy as np
@@ -1896,49 +1896,6 @@ def test_row_index_projection_pushdown_18463(
)


def test_concat_multiple_inmem() -> None:
f = io.BytesIO()
g = io.BytesIO()

df1 = pl.DataFrame(
{
"a": [1, 2, 3],
"b": ["xyz", "abc", "wow"],
}
)
df2 = pl.DataFrame(
{
"a": [5, 6, 7],
"b": ["a", "few", "entries"],
}
)

dfs = pl.concat([df1, df2])

df1.write_parquet(f)
df2.write_parquet(g)

f.seek(0)
g.seek(0)

items: list[IO[bytes]] = [f, g]
assert_frame_equal(pl.read_parquet(items), dfs)

f.seek(0)
g.seek(0)

assert_frame_equal(pl.read_parquet(items, use_pyarrow=True), dfs)

f.seek(0)
g.seek(0)

fb = f.read()
gb = g.read()

assert_frame_equal(pl.read_parquet([fb, gb]), dfs)
assert_frame_equal(pl.read_parquet([fb, gb], use_pyarrow=True), dfs)


@pytest.mark.write_disk
def test_write_binary_open_file(tmp_path: Path) -> None:
df = pl.DataFrame({"a": [1, 2, 3]})
99 changes: 99 additions & 0 deletions py-polars/tests/unit/operations/test_concat.py
@@ -0,0 +1,99 @@
import io
from typing import IO

import pytest

import polars as pl
from polars.testing import assert_frame_equal


def test_concat_invalid_schema_err_20355() -> None:
lf1 = pl.LazyFrame({"x": [1], "y": [None]})
lf2 = pl.LazyFrame({"y": [1]})
with pytest.raises(pl.exceptions.InvalidOperationError):
pl.concat([lf1, lf2]).collect(streaming=True)


def test_concat_df() -> None:
df1 = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
df2 = pl.concat([df1, df1], rechunk=True)

assert df2.shape == (6, 3)
assert df2.n_chunks() == 1
assert df2.rows() == df1.rows() + df1.rows()
assert pl.concat([df1, df1], rechunk=False).n_chunks() == 2

# concat from generator of frames
df3 = pl.concat(items=(df1 for _ in range(2)))
assert_frame_equal(df2, df3)

# check that df4 is not modified following concat of itself
df4 = pl.from_records(((1, 2), (1, 2)))
_ = pl.concat([df4, df4, df4])

assert df4.shape == (2, 2)
assert df4.rows() == [(1, 1), (2, 2)]

# misc error conditions
with pytest.raises(ValueError):
_ = pl.concat([])

with pytest.raises(ValueError):
pl.concat([df1, df1], how="rubbish") # type: ignore[arg-type]


def test_concat_to_empty() -> None:
assert pl.concat([pl.DataFrame([]), pl.DataFrame({"a": [1]})]).to_dict(
as_series=False
) == {"a": [1]}


def test_concat_multiple_parquet_inmem() -> None:
f = io.BytesIO()
g = io.BytesIO()

df1 = pl.DataFrame(
{
"a": [1, 2, 3],
"b": ["xyz", "abc", "wow"],
}
)
df2 = pl.DataFrame(
{
"a": [5, 6, 7],
"b": ["a", "few", "entries"],
}
)

dfs = pl.concat([df1, df2])

df1.write_parquet(f)
df2.write_parquet(g)

f.seek(0)
g.seek(0)

items: list[IO[bytes]] = [f, g]
assert_frame_equal(pl.read_parquet(items), dfs)

f.seek(0)
g.seek(0)

assert_frame_equal(pl.read_parquet(items, use_pyarrow=True), dfs)

f.seek(0)
g.seek(0)

fb = f.read()
gb = g.read()

assert_frame_equal(pl.read_parquet([fb, gb]), dfs)
assert_frame_equal(pl.read_parquet([fb, gb], use_pyarrow=True), dfs)


def test_concat_series() -> None:
s = pl.Series("a", [2, 1, 3])

assert pl.concat([s, s]).len() == 6
# check if s remains unchanged
assert s.len() == 3