BUG/REF: Use lru_cache instead of NUMBA_FUNC_CACHE (pandas-dev#46086)
mroeschke authored Feb 27, 2022
1 parent 66a5de3 commit 696a8e9
Showing 12 changed files with 230 additions and 260 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.5.0.rst
@@ -383,6 +383,8 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`)
- Bug in :meth:`.DataFrameGroupBy.size` and :meth:`.DataFrameGroupBy.transform` with ``func="size"`` produced incorrect results when ``axis=1`` (:issue:`45715`)
- Bug in :meth:`.ExponentialMovingWindow.mean` with ``axis=1`` and ``engine='numba'`` when the :class:`DataFrame` has more columns than rows (:issue:`46086`)
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)

Reshaping
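
The second new entry above is the heart of this change: the old module-level NUMBA_FUNC_CACHE keyed compiled functions only on (func, cache_key_str), so passing different engine_kwargs on a later call silently reused a function jitted with the old options. Below is a minimal sketch of the failure mode and of the functools.lru_cache fix, using plain Python stand-ins rather than pandas internals.

import functools

# Old pattern: a module-level dict keyed only on (func, name). The JIT
# options are not part of the key, so a changed option is ignored.
NUMBA_FUNC_CACHE = {}

def compile_old(func, engine_kwargs, cache_key_str):
    key = (func, cache_key_str)
    if key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[key]  # stale hit even if engine_kwargs changed
    compiled = (func, tuple(sorted(engine_kwargs.items())))  # stand-in for a jitted function
    NUMBA_FUNC_CACHE[key] = compiled
    return compiled

# New pattern: lru_cache keys on every argument, so each option combination
# gets, and keeps, its own compiled function.
@functools.lru_cache(maxsize=None)
def compile_new(func, nopython, nogil, parallel):
    return (func, nopython, nogil, parallel)

def kernel(x):
    return x + 1

a = compile_old(kernel, {"parallel": False}, "rolling_mean")
b = compile_old(kernel, {"parallel": True}, "rolling_mean")
print(a is b)   # True -> the bug: the parallel=True call got the parallel=False build

c = compile_new(kernel, True, False, False)
d = compile_new(kernel, True, False, True)
print(c is d)   # False -> each option combination is compiled separately
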
29 changes: 11 additions & 18 deletions pandas/core/_numba/executor.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
from typing import (
TYPE_CHECKING,
Callable,
@@ -10,16 +11,13 @@
from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
get_jit_arguments,
)


@functools.lru_cache(maxsize=None)
def generate_shared_aggregator(
func: Callable[..., Scalar],
engine_kwargs: dict[str, bool] | None,
cache_key_str: str,
nopython: bool,
nogil: bool,
parallel: bool,
):
"""
Generate a Numba function that loops over the columns 2D object and applies
@@ -29,22 +27,17 @@ def generate_shared_aggregator(
----------
func : function
aggregation function to be applied to each column
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit
cache_key_str: str
string to access the compiled function of the form
<caller_type>_<aggregation_type> e.g. rolling_mean, groupby_mean
nopython : bool
nopython to be passed into numba.jit
nogil : bool
nogil to be passed into numba.jit
parallel : bool
parallel to be passed into numba.jit
Returns
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, None)

cache_key = (func, cache_key_str)
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

if TYPE_CHECKING:
import numba
else:
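
generate_shared_aggregator is now memoized with functools.lru_cache, which hashes its arguments; a raw engine_kwargs dict is unhashable, which is why get_jit_arguments is now called by the caller and its result unpacked into plain boolean parameters. A small sketch of that constraint follows, with hypothetical names standing in for the pandas functions.

import functools

@functools.lru_cache(maxsize=None)
def make_aggregator(func, nopython, nogil, parallel):
    # stand-in for building the numba-jitted column-looping aggregator
    return (func, nopython, nogil, parallel)

@functools.lru_cache(maxsize=None)
def make_aggregator_bad(func, engine_kwargs):
    return (func, engine_kwargs)

def sliding_kernel(values):
    return sum(values)

# Booleans are hashable, so this caches cleanly per option combination.
agg = make_aggregator(sliding_kernel, True, False, False)

# A dict argument defeats lru_cache before the function body even runs.
try:
    make_aggregator_bad(sliding_kernel, {"parallel": True})
except TypeError as exc:
    print(exc)  # unhashable type: 'dict'
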
60 changes: 17 additions & 43 deletions pandas/core/groupby/groupby.py
@@ -114,7 +114,7 @@ class providing the base-class of operations.
from pandas.core.series import Series
from pandas.core.sorting import get_group_index_sorter
from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
get_jit_arguments,
maybe_use_numba,
)

@@ -1247,11 +1247,7 @@ def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool:
# numba

@final
def _numba_prep(self, func, data):
if not callable(func):
raise NotImplementedError(
"Numba engine can only be used with a single function."
)
def _numba_prep(self, data):
ids, _, ngroups = self.grouper.group_info
sorted_index = get_group_index_sorter(ids, ngroups)
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
@@ -1271,7 +1267,6 @@ def _numba_agg_general(
self,
func: Callable,
engine_kwargs: dict[str, bool] | None,
numba_cache_key_str: str,
*aggregator_args,
):
"""
@@ -1288,16 +1283,12 @@ with self._group_selection_context():
with self._group_selection_context():
data = self._selected_obj
df = data if data.ndim == 2 else data.to_frame()
starts, ends, sorted_index, sorted_data = self._numba_prep(func, df)
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
aggregator = executor.generate_shared_aggregator(
func, engine_kwargs, numba_cache_key_str
func, **get_jit_arguments(engine_kwargs)
)
result = aggregator(sorted_data, starts, ends, 0, *aggregator_args)

cache_key = (func, numba_cache_key_str)
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = aggregator

index = self.grouper.result_index
if data.ndim == 1:
result_kwargs = {"name": data.name}
@@ -1315,10 +1306,10 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
to generate the indices of each group in the sorted data and then passes the
data and indices into a Numba jitted function.
"""
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)

starts, ends, sorted_index, sorted_data = self._numba_prep(data)
numba_.validate_udf(func)
numba_transform_func = numba_.generate_numba_transform_func(
kwargs, func, engine_kwargs
func, **get_jit_arguments(engine_kwargs, kwargs)
)
result = numba_transform_func(
sorted_data,
@@ -1328,11 +1319,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
len(data.columns),
*args,
)

cache_key = (func, "groupby_transform")
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_transform_func

# result values needs to be resorted to their original positions since we
# evaluated the data sorted by group
return result.take(np.argsort(sorted_index), axis=0)
@@ -1346,9 +1332,11 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
to generate the indices of each group in the sorted data and then passes the
data and indices into a Numba jitted function.
"""
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)

numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs)
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
numba_.validate_udf(func)
numba_agg_func = numba_.generate_numba_agg_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
)
result = numba_agg_func(
sorted_data,
sorted_index,
@@ -1357,11 +1345,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
len(data.columns),
*args,
)

cache_key = (func, "groupby_agg")
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func

return result

# -----------------------------------------------------------------
@@ -1947,7 +1930,7 @@ def mean(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_mean

return self._numba_agg_general(sliding_mean, engine_kwargs, "groupby_mean")
return self._numba_agg_general(sliding_mean, engine_kwargs)
else:
result = self._cython_agg_general(
"mean",
@@ -2029,9 +2012,7 @@ def std(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

return np.sqrt(
self._numba_agg_general(sliding_var, engine_kwargs, "groupby_std", ddof)
)
return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof))
else:
return self._get_cythonized_result(
libgroupby.group_var,
@@ -2085,9 +2066,7 @@ def var(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

return self._numba_agg_general(
sliding_var, engine_kwargs, "groupby_var", ddof
)
return self._numba_agg_general(sliding_var, engine_kwargs, ddof)
else:
if ddof == 1:
numeric_only = self._resolve_numeric_only(lib.no_default)
@@ -2180,7 +2159,6 @@ def sum(
return self._numba_agg_general(
sliding_sum,
engine_kwargs,
"groupby_sum",
)
else:
numeric_only = self._resolve_numeric_only(numeric_only)
@@ -2221,9 +2199,7 @@ def min(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

return self._numba_agg_general(
sliding_min_max, engine_kwargs, "groupby_min", False
)
return self._numba_agg_general(sliding_min_max, engine_kwargs, False)
else:
return self._agg_general(
numeric_only=numeric_only,
@@ -2244,9 +2220,7 @@ def max(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

return self._numba_agg_general(
sliding_min_max, engine_kwargs, "groupby_max", True
)
return self._numba_agg_general(sliding_min_max, engine_kwargs, True)
else:
return self._agg_general(
numeric_only=numeric_only,
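
With the cache-key strings gone, the groupby reductions simply forward **get_jit_arguments(engine_kwargs) into the lru_cache-memoized factory. A hedged end-to-end usage example is below; it assumes pandas 1.5+ with numba installed, and the numeric result is identical either way since only the compilation options differ.

import numpy as np
import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b", "b"], "val": np.arange(4.0)})
gb = df.groupby("key")["val"]

# Before this fix, the second call could silently reuse the function compiled
# with parallel=False; now each engine_kwargs combination is JITed separately.
r1 = gb.mean(engine="numba", engine_kwargs={"parallel": False})
r2 = gb.mean(engine="numba", engine_kwargs={"parallel": True})
print(r1.equals(r2))  # True: same values, different compiled kernels
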
55 changes: 26 additions & 29 deletions pandas/core/groupby/numba_.py
@@ -1,6 +1,7 @@
"""Common utilities for Numba operations with groupby ops"""
from __future__ import annotations

import functools
import inspect
from typing import (
TYPE_CHECKING,
@@ -14,9 +15,7 @@
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
NumbaUtilError,
get_jit_arguments,
jit_user_function,
)

@@ -43,6 +42,10 @@ def f(values, index, ...):
------
NumbaUtilError
"""
if not callable(func):
raise NotImplementedError(
"Numba engine can only be used with a single function."
)
udf_signature = list(inspect.signature(func).parameters.keys())
expected_args = ["values", "index"]
min_number_args = len(expected_args)
@@ -56,10 +59,12 @@ def f(values, index, ...):
)


@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(
kwargs: dict[str, Any],
func: Callable[..., Scalar],
engine_kwargs: dict[str, bool] | None,
nopython: bool,
nogil: bool,
parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
"""
Generate a numba jitted agg function specified by values from engine_kwargs.
@@ -72,24 +77,19 @@ def generate_numba_agg_func(
Parameters
----------
kwargs : dict
**kwargs to be passed into the function
func : function
function to be applied to each window and will be JITed
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit
function to be applied to each group and will be JITed
nopython : bool
nopython to be passed into numba.jit
nogil : bool
nogil to be passed into numba.jit
parallel : bool
parallel to be passed into numba.jit
Returns
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

validate_udf(func)
cache_key = (func, "groupby_agg")
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

numba_func = jit_user_function(func, nopython, nogil, parallel)
if TYPE_CHECKING:
import numba
@@ -120,10 +120,12 @@ def group_agg(
return group_agg


@functools.lru_cache(maxsize=None)
def generate_numba_transform_func(
kwargs: dict[str, Any],
func: Callable[..., np.ndarray],
engine_kwargs: dict[str, bool] | None,
nopython: bool,
nogil: bool,
parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
"""
Generate a numba jitted transform function specified by values from engine_kwargs.
@@ -136,24 +138,19 @@ def generate_numba_transform_func(
Parameters
----------
kwargs : dict
**kwargs to be passed into the function
func : function
function to be applied to each window and will be JITed
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit
nopython : bool
nopython to be passed into numba.jit
nogil : bool
nogil to be passed into numba.jit
parallel : bool
parallel to be passed into numba.jit
Returns
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

validate_udf(func)
cache_key = (func, "groupby_transform")
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

numba_func = jit_user_function(func, nopython, nogil, parallel)
if TYPE_CHECKING:
import numba
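
validate_udf now also owns the callable check that previously lived in _numba_prep, and both factory functions are memoized with functools.lru_cache on (func, nopython, nogil, parallel). The sketch below shows the kind of validation performed in simplified form; the real helper lives in pandas/core/groupby/numba_.py and raises NumbaUtilError for the signature case.

import inspect

def validate_udf_sketch(func):
    # Reject non-callables up front (the check moved here from _numba_prep).
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    # The UDF's leading parameters must be named "values" and "index".
    expected = ["values", "index"]
    params = list(inspect.signature(func).parameters)
    if params[: len(expected)] != expected:
        raise ValueError(
            f"The first {len(expected)} arguments to {func.__name__} must be {expected}"
        )

def good_udf(values, index):
    return values.sum()

validate_udf_sketch(good_udf)      # passes silently

try:
    validate_udf_sketch("mean")    # a string alias is not allowed with engine="numba"
except NotImplementedError as exc:
    print(exc)
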
7 changes: 3 additions & 4 deletions pandas/core/util/numba_.py
@@ -13,7 +13,6 @@
from pandas.errors import NumbaUtilError

GLOBAL_USE_NUMBA: bool = False
NUMBA_FUNC_CACHE: dict[tuple[Callable, str], Callable] = {}


def maybe_use_numba(engine: str | None) -> bool:
@@ -30,7 +29,7 @@ def set_use_numba(enable: bool = False) -> None:

def get_jit_arguments(
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
) -> tuple[bool, bool, bool]:
) -> dict[str, bool]:
"""
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
@@ -43,7 +42,7 @@ def get_jit_arguments(
Returns
-------
(bool, bool, bool)
dict[str, bool]
nopython, nogil, parallel
Raises
@@ -61,7 +60,7 @@
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
return nopython, nogil, parallel
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}


def jit_user_function(
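
get_jit_arguments now returns a dict instead of a 3-tuple so that callers can splat it straight into the cached factories and, ultimately, into numba.jit. A quick hedged example of that final hop, requiring numba to be installed:

import numba

# The keys match numba.jit's keyword arguments, so the dict returned by
# get_jit_arguments can be unpacked directly.
jit_kwargs = {"nopython": True, "nogil": False, "parallel": False}

@numba.jit(**jit_kwargs)
def add_one(x):
    return x + 1

print(add_one(41))  # 42, compiled on first call
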