Skip to content

Commit

Permalink
Fix dask bugs (#507)
Browse files Browse the repository at this point in the history
* Fix dask bugs

* update
  • Loading branch information
goodwanghan authored Aug 17, 2023
1 parent ae269ee commit 81276ef
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
7 changes: 5 additions & 2 deletions fugue_dask/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def hash_repartition(df: dd.DataFrame, num: int, cols: List[Any]) -> dd.DataFram
return df
if num == 1:
return df.repartition(1)
df = df.reset_index(drop=True).clear_divisions()
idf, ct = _add_hash_index(df, num, cols)
return _postprocess(idf, ct, num)

Expand All @@ -63,9 +64,10 @@ def even_repartition(df: dd.DataFrame, num: int, cols: List[Any]) -> dd.DataFram
"""
if num == 1:
return df.repartition(1)
if len(cols) == 0 and num <= 0:
return df
df = df.reset_index(drop=True).clear_divisions()
if len(cols) == 0:
if num <= 0:
return df
idf, ct = _add_continuous_index(df)
else:
idf, ct = _add_group_index(df, cols, shuffle=False)
Expand Down Expand Up @@ -97,6 +99,7 @@ def rand_repartition(
return df
if num == 1:
return df.repartition(1)
df = df.reset_index(drop=True).clear_divisions()
if len(cols) == 0:
idf, ct = _add_random_index(df, num=num, seed=seed)
else:
Expand Down
2 changes: 1 addition & 1 deletion fugue_ray/_utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _save_csv(
if "header" in kw:
kw["include_header"] = kw.pop("header")

def _fn() -> Dict[str, Any]:
def _fn() -> Dict[str, Any]: # pragma: no cover
return dict(write_options=pacsv.WriteOptions(**kw))

df.native.write_csv(
Expand Down
45 changes: 43 additions & 2 deletions tests/fugue_dask/test_execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from threading import RLock
from typing import Any, List, Optional

import dask
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pytest
from dask.distributed import Client
Expand All @@ -25,7 +25,6 @@
from fugue_test.builtin_suite import BuiltInTests
from fugue_test.execution_suite import ExecutionEngineTests


_CONF = {
"fugue.rpc.server": "fugue.rpc.flask.FlaskRPCServer",
"fugue.rpc.flask_server.host": "127.0.0.1",
Expand Down Expand Up @@ -321,6 +320,48 @@ def tr(df: List[List[Any]], add: Optional[callable]) -> List[List[Any]]:
assert 5 == cb.n


def test_multiple_transforms(fugue_dask_client):
    """Chained hash- and coarse-partitioned transforms on Dask must match pandas.

    Runs the same two-step transform pipeline once on the Dask engine and once
    on the local (None) engine, then asserts the results agree elementwise.
    """

    def duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
        # Stack the frame on top of itself so every (a, b) key occurs twice.
        return pd.concat([df, df])

    def first_per_group(df: pd.DataFrame) -> pd.DataFrame:
        # Keep one representative row per (a, b) group; dropna=False keeps
        # groups whose key contains NaN.
        grouped = df.groupby(["a", "b"], as_index=False, dropna=False)
        return grouped.apply(lambda x: x.head(1)).reset_index(drop=True)

    def run_pipeline(source: pd.DataFrame, engine) -> pd.DataFrame:
        # Execute both transforms under the given engine and normalize the
        # output (float dtype, NaN fill, stable sort) so results are comparable.
        with fa.engine_context(engine):
            fugue_df = fa.as_fugue_df(source)
            step1 = fa.transform(
                fugue_df, duplicate_rows, schema="*", partition=dict(algo="hash")
            )
            step2 = fa.transform(
                step1,
                first_per_group,
                schema="*",
                partition=dict(by=["a", "b"], presort="c", algo="coarse", num=2),
            )
            normalized = step2.as_pandas().astype("float64")
            return normalized.fillna(float("nan")).sort_values(["a", "b"])

    # Fixed seed so both pipeline runs see the identical input frame.
    np.random.seed(0)
    data = pd.DataFrame(
        dict(
            a=np.random.randint(1, 5, 1000),
            b=np.random.choice([1, 2, 3, None], 1000),
            c=np.random.rand(1000),
        )
    )

    actual = run_pipeline(data, fugue_dask_client)
    expected = run_pipeline(data, None)
    assert np.allclose(actual, expected, equal_nan=True)


@transformer("ct:long")
def count_partition(df: List[List[Any]]) -> List[List[Any]]:
    # Emit a single one-column row holding this partition's row count.
    size = len(df)
    return [[size]]
Expand Down

0 comments on commit 81276ef

Please sign in to comment.