Fix bugs (#508)
* Fix bugs

* update

* fix edge cases
goodwanghan authored Aug 20, 2023
1 parent 81276ef commit 8e1cf90
Showing 11 changed files with 145 additions and 36 deletions.
7 changes: 5 additions & 2 deletions fugue/dataframe/pandas_dataframe.py
@@ -59,8 +59,11 @@ def __init__( # noqa: C901
apply_schema = True
if df is None:
schema = _input_schema(schema).assert_not_empty()
df = []
if isinstance(df, PandasDataFrame):
pdf = schema.create_empty_pandas_df(
use_extension_types=True, use_arrow_dtype=False
)
apply_schema = False
elif isinstance(df, PandasDataFrame):
# TODO: This is useless if in this way and wrong
pdf = df.native
schema = None
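The change above replaces the old df = [] fallback with schema.create_empty_pandas_df(use_extension_types=True, use_arrow_dtype=False), so an empty PandasDataFrame gets schema-derived column dtypes instead of whatever pandas infers from an empty list. Below is a minimal sketch of the same idea using plain pyarrow and pandas; the empty_pandas_df helper is illustrative only, Fugue itself delegates to triad's Schema.create_empty_pandas_df:

import pandas as pd
import pyarrow as pa

def empty_pandas_df(schema: pa.Schema) -> pd.DataFrame:
    # Build a zero-row Arrow table, then convert it to pandas while mapping
    # Arrow types to pandas nullable (extension) dtypes, so even an empty
    # frame keeps meaningful column dtypes.
    mapping = {pa.int64(): pd.Int64Dtype(), pa.string(): pd.StringDtype()}
    tbl = pa.Table.from_pylist([], schema=schema)
    return tbl.to_pandas(types_mapper=mapping.get)

pdf = empty_pandas_df(pa.schema([("a", pa.int64()), ("b", pa.string())]))
print(pdf.dtypes)  # a: Int64, b: string -- not the lossy object/float64 defaults
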
22 changes: 16 additions & 6 deletions fugue_dask/execution_engine.py
@@ -331,12 +331,22 @@ def join(
on: Optional[List[str]] = None,
) -> DataFrame:
key_schema, output_schema = get_join_schemas(df1, df2, how=how, on=on)
d = self.pl_utils.join(
self.to_df(df1).native,
self.to_df(df2).native,
join_type=how,
on=key_schema.names,
)
# Dask joins on different types such as int64 vs Int64 can occasionally fail
# so we need to cast to the same type
ndf1 = self.to_df(df1).native
ntp1 = ndf1.dtypes[key_schema.names].to_dict()
ndf2 = self.to_df(df2).native
ntp2 = ndf2.dtypes[key_schema.names].to_dict()
if ntp1 != ntp2:
ntp = key_schema.to_pandas_dtype(
use_extension_types=True, use_arrow_dtype=FUGUE_DASK_USE_ARROW
)
if ntp1 != ntp:
ndf1 = ndf1.astype(ntp)
if ntp2 != ntp:
ndf2 = ndf2.astype(ntp)

d = self.pl_utils.join(ndf1, ndf2, join_type=how, on=key_schema.names)
return DaskDataFrame(d, output_schema, type_safe=False)

def union(
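The comment in the hunk above explains the fix: Dask joins can fail when the key columns use different integer dtypes on the two sides (for example NumPy int64 vs the nullable Int64 extension dtype), so both sides are cast to a common dtype before joining. A minimal sketch of that unification pattern with plain pandas; the column names and the Int64 target dtype are illustrative:

import pandas as pd

# Two frames whose join key uses different integer dtypes:
# plain NumPy int64 on the left, nullable Int64 on the right.
df1 = pd.DataFrame({"a": [10, 11], "b": [1, 3]})                    # a: int64
df2 = pd.DataFrame({"a": [10], "c": ["x"]}).astype({"a": "Int64"})  # a: Int64

keys = ["a"]
tp1 = df1.dtypes[keys].astype(str).to_dict()
tp2 = df2.dtypes[keys].astype(str).to_dict()
if tp1 != tp2:
    # unify both sides on one target dtype for the key columns before joining
    target = {k: "Int64" for k in keys}
    if tp1 != target:
        df1 = df1.astype(target)
    if tp2 != target:
        df2 = df2.astype(target)

print(df1.merge(df2, on=keys))  # both key columns now share the same dtype
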
19 changes: 18 additions & 1 deletion fugue_spark/_utils/convert.py
@@ -8,7 +8,12 @@
import pyspark.sql.types as pt
from packaging import version
from pyarrow.types import is_list, is_struct, is_timestamp
from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
from pyspark.sql.pandas.types import (
from_arrow_schema,
from_arrow_type,
to_arrow_schema,
to_arrow_type,
)
from triad.collections import Schema
from triad.utils.assertion import assert_arg_not_none, assert_or_throw
from triad.utils.pyarrow import TRIAD_DEFAULT_TIMESTAMP
@@ -28,6 +33,18 @@
_PYSPARK_ARROW_FRIENDLY = version.parse(pyspark.__version__) >= version.parse("3.3")


def pandas_udf_can_accept(schema: Schema, is_input: bool) -> bool:
try:
# pyspark's own from_arrow_schema/to_arrow_schema round trip
# can validate whether a type is supported by pandas udfs
if not is_input and any(pa.types.is_struct(t) for t in schema.types):
return False
to_arrow_schema(from_arrow_schema(schema.pa_schema))
return True
except Exception:
return False


def to_spark_schema(obj: Any) -> pt.StructType:
assert_arg_not_none(obj, "schema")
if isinstance(obj, pt.StructType):
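The new pandas_udf_can_accept helper decides whether a schema can travel through a pandas UDF by round-tripping it through PySpark's own Arrow converters and treating any exception as unsupported; struct columns are additionally rejected on the output side. A hedged sketch of the same round-trip check for a single Arrow type; the helper name below is illustrative and the exact set of supported types depends on the PySpark version:

import pyarrow as pa
from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type

def arrow_type_supported(tp: pa.DataType) -> bool:
    # Round-trip the type through PySpark's Arrow converters; if either
    # direction raises, pandas UDFs cannot carry values of this type.
    try:
        to_arrow_type(from_arrow_type(tp))
        return True
    except Exception:
        return False

print(arrow_type_supported(pa.int64()))                    # True
print(arrow_type_supported(pa.list_(pa.timestamp("us"))))  # False (mirrors the new test: list of timestamp is rejected)
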
30 changes: 17 additions & 13 deletions fugue_spark/execution_engine.py
@@ -43,11 +43,12 @@

from ._constants import FUGUE_SPARK_CONF_USE_PANDAS_UDF, FUGUE_SPARK_DEFAULT_CONF
from ._utils.convert import (
_PYSPARK_ARROW_FRIENDLY,
pandas_udf_can_accept,
to_schema,
to_spark_df,
to_spark_schema,
to_type_safe_input,
_PYSPARK_ARROW_FRIENDLY,
)
from ._utils.io import SparkIO
from ._utils.misc import is_spark_connect as _is_spark_connect
@@ -108,24 +109,27 @@ def is_spark_connect(self) -> bool:
"""Whether the spark session is created by spark connect"""
return self.execution_engine.is_spark_connect # type:ignore

def _should_use_pandas_udf(self, schema: Schema) -> bool:
def _should_use_pandas_udf(
self, input_schema: Schema, output_schema: Schema
) -> bool:
if self.is_spark_connect: # pragma: no cover
return True
possible = hasattr(ps.DataFrame, "mapInPandas") # must be new version of Spark
# else: # this condition seems to be unnecessary
# possible &= self.execution_engine.conf.get(
# "spark.sql.execution.arrow.pyspark.enabled", False
# )
compatible = pandas_udf_can_accept(
input_schema, is_input=True
) and pandas_udf_can_accept(output_schema, is_input=False)
enabled = self.execution_engine.conf.get_or_throw(
FUGUE_SPARK_CONF_USE_PANDAS_UDF, bool
)
if not possible or any(pa.types.is_nested(t) for t in schema.types):
if enabled and not possible: # pragma: no cover
self.log.warning(
f"{FUGUE_SPARK_CONF_USE_PANDAS_UDF}"
" is enabled but the current PySpark session"
"did not enable Pandas UDF support"
)
if enabled and not compatible: # pragma: no cover
self.log.warning(
f"{FUGUE_SPARK_CONF_USE_PANDAS_UDF}"
f" is enabled but {input_schema} or {output_schema}"
" is not compatible with pandas udf, using RDD instead"
)
return False
return enabled

@@ -139,7 +143,7 @@ def map_dataframe(
map_func_format_hint: Optional[str] = None,
) -> DataFrame:
output_schema = Schema(output_schema)
if self._should_use_pandas_udf(output_schema):
if self._should_use_pandas_udf(df.schema, output_schema):
if len(partition_spec.partition_by) > 0:
if partition_spec.algo in ["coarse", "even"]:
return self._map_by_pandas_udf(
@@ -210,7 +214,7 @@ def _group_map_by_pandas_udf(
def _udf_pandas(pdf: Any) -> pd.DataFrame: # pragma: no cover
if pdf.shape[0] == 0:
return _to_safe_spark_worker_pandas(
PandasDataFrame([], output_schema).as_pandas()
PandasDataFrame(schema=output_schema).as_pandas()
)
if len(partition_spec.presort) > 0:
pdf = pdf.sort_values(presort_keys, ascending=presort_asc)
@@ -270,7 +274,7 @@ def get_dfs() -> Iterable[LocalDataFrame]:
input_df = IterablePandasDataFrame(get_dfs(), input_schema)
if input_df.empty:
yield _to_safe_spark_worker_pandas(
PandasDataFrame([], output_schema).as_pandas()
PandasDataFrame(schema=output_schema).as_pandas()
)
return
if on_init_once is not None:
24 changes: 24 additions & 0 deletions fugue_test/execution_suite.py
@@ -344,6 +344,30 @@ def test_map_with_dict_col(self):
)
df_eq(c, o, no_pandas=True, check_order=True, throw=True)

# input has dict, output doesn't
def mp2(cursor, data):
return data[["a"]]

c = e.map_engine.map_dataframe(
o, mp2, "a:datetime", PartitionSpec(by=["a"])
)
df_eq(
c,
PandasDataFrame([[dt]], "a:datetime"),
no_pandas=True,
check_order=True,
throw=True,
)

# input doesn't have dict, output has
def mp3(cursor, data):
return PandasDataFrame([[dt, dict(a=1)]], "a:datetime,b:{a:long}")

c = e.map_engine.map_dataframe(
c, mp3, "a:datetime,b:{a:long}", PartitionSpec(by=["a"])
)
df_eq(c, o, no_pandas=True, check_order=True, throw=True)

def test_map_with_binary(self):
e = self.engine
o = ArrayDataFrame(
5 changes: 3 additions & 2 deletions setup.py
@@ -28,10 +28,10 @@ def get_version() -> str:
license="Apache-2.0",
author="The Fugue Development Team",
author_email="[email protected]",
keywords="distributed spark dask sql dsl domain specific language",
keywords="distributed spark dask ray sql dsl domain specific language",
url="http://github.com/fugue-project/fugue",
install_requires=[
"triad==0.9.2.dev2",
"triad==0.9.2.dev3",
"adagio>=0.2.4",
# sql dependencies
"qpd>=0.4.4",
@@ -67,6 +67,7 @@ def get_version() -> str:
"fugue-sql-antlr[cpp]>=0.1.6",
"pyspark>=3.1.1",
"dask[distributed,dataframe]>=2023.5.0",
"dask-sql",
"ray[data]>=2.1.0",
"notebook",
"jupyterlab",
6 changes: 3 additions & 3 deletions tests/fugue/dataframe/test_function_wrapper.py
@@ -154,14 +154,14 @@ def test_iterable_pandas_dataframes():
df = PandasDataFrame(pdf)
data = list(p.to_input_data(df, ctx=None))
assert 1 == len(data)
assert data[0] is pdf # this is to guarantee no copy in any wrapping logic
assert data[0] is df.native # this is to guarantee no copy in any wrapping logic
assert data[0].values.tolist() == [[0, "x"]]

dfs = IterablePandasDataFrame([df, df])
data = list(p.to_input_data(dfs, ctx=None))
assert 2 == len(data)
assert data[0] is pdf
assert data[1] is pdf
assert data[0] is df.native
assert data[1] is df.native

def get_pdfs():
yield pdf
17 changes: 13 additions & 4 deletions tests/fugue/dataframe/test_pandas_dataframe.py
@@ -52,9 +52,19 @@ def test_init():
assert [["a", "1"], ["b", "2"]] == df.native.values.tolist()
df = PandasDataFrame(pdf, "a:str,b:int")
assert [["a", 1], ["b", 2]] == df.native.values.tolist()
assert pdf is not df.native
df = PandasDataFrame(pdf, "a:str,b:double")
assert [["a", 1.0], ["b", 2.0]] == df.native.values.tolist()

# no copy is important for performance
df = PandasDataFrame(pdf, "a:str,b:long")
assert pdf is df.native
df = PandasDataFrame(pdf, pandas_df_wrapper=True)
assert pdf is df.native
assert df.schema == "a:str,b:long"
df = PandasDataFrame(pdf, "a:str,b:int", pandas_df_wrapper=True)
assert pdf is df.native

pdf = pd.DataFrame([["a", 1], ["b", 2]], columns=["a", "b"])["b"]
assert isinstance(pdf, pd.Series)
df = PandasDataFrame(pdf, "b:str")
@@ -101,10 +111,9 @@ def test_nested():
# a = df.as_array(type_safe=True)
# assert [[dict(a="1", b=[3, 4])], [dict(a=None, b=[30, 40])]] == a

data = [[[json.dumps(dict(b=[30, "40"]))]]]
df = PandasDataFrame(data, "a:[{a:str,b:[int]}]")
a = df.as_array(type_safe=True)
assert [[[dict(a=None, b=[30, 40])]]] == a
data = [[[dict(b=[30, 40])]]]
df = PandasDataFrame(data, "a:[{b:[int]}]")
assert df.as_array(type_safe=True) == data


def test_rename():
14 changes: 14 additions & 0 deletions tests/fugue_dask/test_execution_engine.py
@@ -261,6 +261,20 @@ def assert_(df: pd.DataFrame, rc: int, n: int, check_ordered: bool) -> None:
dag.run(self.engine)


def test_join_keys_unification(fugue_dask_client):
df1 = DaskDataFrame(
pd.DataFrame([[10, 1], [11, 3]], columns=["a", "b"]).convert_dtypes(),
"a:long,b:long",
)
df2 = PandasDataFrame(
pd.DataFrame([[10, [2]]], columns=["a", "c"]),
"a:long,c:[long]",
)
with fa.engine_context(fugue_dask_client) as engine:
assert fa.as_array(fa.inner_join(df1, df2)) == [[10, 1, [2]]]
assert fa.as_array(fa.inner_join(df2, df1)) == [[10, [2], 1]]


def test_transform(fugue_dask_client):
class CB:
def __init__(self):
22 changes: 18 additions & 4 deletions tests/fugue_spark/test_execution_engine.py
@@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any, Iterable, List

import numpy as np
import pandas as pd
@@ -54,7 +54,7 @@ def test_get_parallelism(self):

def test_not_using_pandas_udf(self):
assert not self.engine.create_default_map_engine()._should_use_pandas_udf(
Schema("a:int")
Schema("a:int"), Schema("a:int")
)

def test__join_outer_pandas_incompatible(self):
@@ -144,10 +144,13 @@ def test__join_outer_pandas_incompatible(self):

def test_using_pandas_udf(self):
assert self.engine.map_engine._should_use_pandas_udf( # type: ignore
Schema("a:int")
Schema("a:int"), Schema("a:int")
)
assert not self.engine.map_engine._should_use_pandas_udf( # type: ignore
Schema("a:{x:int}")
Schema("a:int"), Schema("a:{x:int}")
)
assert self.engine.map_engine._should_use_pandas_udf( # type: ignore
Schema("a:{x:int}"), Schema("a:int")
)

def test_sample_n(self):
@@ -398,6 +401,17 @@ def make_engine(self):
return e


def test_rdd_pd_extension_types_handling(spark_session):
def tr(df: List[List[Any]]) -> Iterable[pd.DataFrame]:
assert isinstance(df[0][0], float)
yield pd.DataFrame([[0.1, [1]]], columns=["a", "b"]).convert_dtypes()

with fa.engine_context(spark_session, {"fugue.spark.use_pandas_udf": False}):
df = pd.DataFrame([[0.1]], columns=["a"]).convert_dtypes()
res = fa.as_array(fa.transform(df, tr, schema="a:double,b:[long]"))
assert res == [[0.1, [1]]]


@transformer("ct:long")
def count_partition(df: List[List[Any]]) -> List[List[Any]]:
return [[len(df)]]
15 changes: 14 additions & 1 deletion tests/fugue_spark/utils/test_convert.py
@@ -1,10 +1,23 @@
import pyarrow as pa
from pytest import raises
from triad import Schema

from fugue_spark._utils.convert import (
pandas_udf_can_accept,
to_cast_expression,
to_schema,
to_select_expression,
to_spark_schema,
)
from pytest import raises


def test_pandas_udf_can_accept():
for is_input in [True, False]:
assert pandas_udf_can_accept(Schema("a:int,b:str"), is_input)
assert pandas_udf_can_accept(Schema("a:int,b:[str],c:[float]"), is_input)
assert not pandas_udf_can_accept(Schema("a:int,b:[datetime]"), is_input)
assert pandas_udf_can_accept(Schema("a:int,b:{a:int}"), True)
assert not pandas_udf_can_accept(Schema("a:int,b:{a:int}"), False)


def test_schema_conversion(spark_session):