diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index 527c4215d22ca..2072dda2965d7 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -383,6 +383,8 @@ Groupby/resample/rolling - Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`) - Bug in :meth:`.DataFrameGroupBy.transform` fails when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`) - Bug in :meth:`.DataFrameGroupBy.size` and :meth:`.DataFrameGroupBy.transform` with ``func="size"`` produced incorrect results when ``axis=1`` (:issue:`45715`) +- Bug in :meth:`.ExponentialMovingWindow.mean` with ``axis=1`` and ``engine='numba'`` when the :class:`DataFrame` has more columns than rows (:issue:`46086`) +- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`) - Bug in :meth:`.DataFrameGroupby.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`) Reshaping diff --git a/pandas/core/_numba/executor.py b/pandas/core/_numba/executor.py index 0b59d0717a476..13d8b52bae39c 100644 --- a/pandas/core/_numba/executor.py +++ b/pandas/core/_numba/executor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools from typing import ( TYPE_CHECKING, Callable, @@ -10,16 +11,13 @@ from pandas._typing import Scalar from pandas.compat._optional import import_optional_dependency -from pandas.core.util.numba_ import ( - NUMBA_FUNC_CACHE, - get_jit_arguments, -) - +@functools.lru_cache(maxsize=None) def generate_shared_aggregator( func: Callable[..., Scalar], - engine_kwargs: dict[str, bool] | None, - cache_key_str: str, + nopython: bool, + nogil: bool, + parallel: bool, ): """ Generate a Numba function that loops over the columns 2D object and applies @@ -29,22 +27,17 @@ def generate_shared_aggregator( ---------- func : function aggregation function to be applied to each column - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit - cache_key_str: str - string to access the compiled function of the form - _ e.g. rolling_mean, groupby_mean + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs, None) - - cache_key = (func, cache_key_str) - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - if TYPE_CHECKING: import numba else: diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f8106edeb5d62..ad2ec2a88af82 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -114,7 +114,7 @@ class providing the base-class of operations. from pandas.core.series import Series from pandas.core.sorting import get_group_index_sorter from pandas.core.util.numba_ import ( - NUMBA_FUNC_CACHE, + get_jit_arguments, maybe_use_numba, ) @@ -1247,11 +1247,7 @@ def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool: # numba @final - def _numba_prep(self, func, data): - if not callable(func): - raise NotImplementedError( - "Numba engine can only be used with a single function." - ) + def _numba_prep(self, data): ids, _, ngroups = self.grouper.group_info sorted_index = get_group_index_sorter(ids, ngroups) sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False) @@ -1271,7 +1267,6 @@ def _numba_agg_general( self, func: Callable, engine_kwargs: dict[str, bool] | None, - numba_cache_key_str: str, *aggregator_args, ): """ @@ -1288,16 +1283,12 @@ def _numba_agg_general( with self._group_selection_context(): data = self._selected_obj df = data if data.ndim == 2 else data.to_frame() - starts, ends, sorted_index, sorted_data = self._numba_prep(func, df) + starts, ends, sorted_index, sorted_data = self._numba_prep(df) aggregator = executor.generate_shared_aggregator( - func, engine_kwargs, numba_cache_key_str + func, **get_jit_arguments(engine_kwargs) ) result = aggregator(sorted_data, starts, ends, 0, *aggregator_args) - cache_key = (func, numba_cache_key_str) - if cache_key not in NUMBA_FUNC_CACHE: - NUMBA_FUNC_CACHE[cache_key] = aggregator - index = self.grouper.result_index if data.ndim == 1: result_kwargs = {"name": data.name} @@ -1315,10 +1306,10 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) to generate the indices of each group in the sorted data and then passes the data and indices into a Numba jitted function. """ - starts, ends, sorted_index, sorted_data = self._numba_prep(func, data) - + starts, ends, sorted_index, sorted_data = self._numba_prep(data) + numba_.validate_udf(func) numba_transform_func = numba_.generate_numba_transform_func( - kwargs, func, engine_kwargs + func, **get_jit_arguments(engine_kwargs, kwargs) ) result = numba_transform_func( sorted_data, @@ -1328,11 +1319,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) len(data.columns), *args, ) - - cache_key = (func, "groupby_transform") - if cache_key not in NUMBA_FUNC_CACHE: - NUMBA_FUNC_CACHE[cache_key] = numba_transform_func - # result values needs to be resorted to their original positions since we # evaluated the data sorted by group return result.take(np.argsort(sorted_index), axis=0) @@ -1346,9 +1332,11 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) to generate the indices of each group in the sorted data and then passes the data and indices into a Numba jitted function. """ - starts, ends, sorted_index, sorted_data = self._numba_prep(func, data) - - numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs) + starts, ends, sorted_index, sorted_data = self._numba_prep(data) + numba_.validate_udf(func) + numba_agg_func = numba_.generate_numba_agg_func( + func, **get_jit_arguments(engine_kwargs, kwargs) + ) result = numba_agg_func( sorted_data, sorted_index, @@ -1357,11 +1345,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) len(data.columns), *args, ) - - cache_key = (func, "groupby_agg") - if cache_key not in NUMBA_FUNC_CACHE: - NUMBA_FUNC_CACHE[cache_key] = numba_agg_func - return result # ----------------------------------------------------------------- @@ -1947,7 +1930,7 @@ def mean( if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_mean - return self._numba_agg_general(sliding_mean, engine_kwargs, "groupby_mean") + return self._numba_agg_general(sliding_mean, engine_kwargs) else: result = self._cython_agg_general( "mean", @@ -2029,9 +2012,7 @@ def std( if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_var - return np.sqrt( - self._numba_agg_general(sliding_var, engine_kwargs, "groupby_std", ddof) - ) + return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof)) else: return self._get_cythonized_result( libgroupby.group_var, @@ -2085,9 +2066,7 @@ def var( if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_var - return self._numba_agg_general( - sliding_var, engine_kwargs, "groupby_var", ddof - ) + return self._numba_agg_general(sliding_var, engine_kwargs, ddof) else: if ddof == 1: numeric_only = self._resolve_numeric_only(lib.no_default) @@ -2180,7 +2159,6 @@ def sum( return self._numba_agg_general( sliding_sum, engine_kwargs, - "groupby_sum", ) else: numeric_only = self._resolve_numeric_only(numeric_only) @@ -2221,9 +2199,7 @@ def min( if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_min_max - return self._numba_agg_general( - sliding_min_max, engine_kwargs, "groupby_min", False - ) + return self._numba_agg_general(sliding_min_max, engine_kwargs, False) else: return self._agg_general( numeric_only=numeric_only, @@ -2244,9 +2220,7 @@ def max( if maybe_use_numba(engine): from pandas.core._numba.kernels import sliding_min_max - return self._numba_agg_general( - sliding_min_max, engine_kwargs, "groupby_max", True - ) + return self._numba_agg_general(sliding_min_max, engine_kwargs, True) else: return self._agg_general( numeric_only=numeric_only, diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index 24d66725caa70..acfc690ab3fdb 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -1,6 +1,7 @@ """Common utilities for Numba operations with groupby ops""" from __future__ import annotations +import functools import inspect from typing import ( TYPE_CHECKING, @@ -14,9 +15,7 @@ from pandas.compat._optional import import_optional_dependency from pandas.core.util.numba_ import ( - NUMBA_FUNC_CACHE, NumbaUtilError, - get_jit_arguments, jit_user_function, ) @@ -43,6 +42,10 @@ def f(values, index, ...): ------ NumbaUtilError """ + if not callable(func): + raise NotImplementedError( + "Numba engine can only be used with a single function." + ) udf_signature = list(inspect.signature(func).parameters.keys()) expected_args = ["values", "index"] min_number_args = len(expected_args) @@ -56,10 +59,12 @@ def f(values, index, ...): ) +@functools.lru_cache(maxsize=None) def generate_numba_agg_func( - kwargs: dict[str, Any], func: Callable[..., Scalar], - engine_kwargs: dict[str, bool] | None, + nopython: bool, + nogil: bool, + parallel: bool, ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]: """ Generate a numba jitted agg function specified by values from engine_kwargs. @@ -72,24 +77,19 @@ def generate_numba_agg_func( Parameters ---------- - kwargs : dict - **kwargs to be passed into the function func : function - function to be applied to each window and will be JITed - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit + function to be applied to each group and will be JITed + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs) - - validate_udf(func) - cache_key = (func, "groupby_agg") - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - numba_func = jit_user_function(func, nopython, nogil, parallel) if TYPE_CHECKING: import numba @@ -120,10 +120,12 @@ def group_agg( return group_agg +@functools.lru_cache(maxsize=None) def generate_numba_transform_func( - kwargs: dict[str, Any], func: Callable[..., np.ndarray], - engine_kwargs: dict[str, bool] | None, + nopython: bool, + nogil: bool, + parallel: bool, ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]: """ Generate a numba jitted transform function specified by values from engine_kwargs. @@ -136,24 +138,19 @@ def generate_numba_transform_func( Parameters ---------- - kwargs : dict - **kwargs to be passed into the function func : function function to be applied to each window and will be JITed - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs) - - validate_udf(func) - cache_key = (func, "groupby_transform") - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - numba_func = jit_user_function(func, nopython, nogil, parallel) if TYPE_CHECKING: import numba diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index 06630989444bb..be798e022ac6e 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -13,7 +13,6 @@ from pandas.errors import NumbaUtilError GLOBAL_USE_NUMBA: bool = False -NUMBA_FUNC_CACHE: dict[tuple[Callable, str], Callable] = {} def maybe_use_numba(engine: str | None) -> bool: @@ -30,7 +29,7 @@ def set_use_numba(enable: bool = False) -> None: def get_jit_arguments( engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None -) -> tuple[bool, bool, bool]: +) -> dict[str, bool]: """ Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. @@ -43,7 +42,7 @@ def get_jit_arguments( Returns ------- - (bool, bool, bool) + dict[str, bool] nopython, nogil, parallel Raises @@ -61,7 +60,7 @@ def get_jit_arguments( ) nogil = engine_kwargs.get("nogil", False) parallel = engine_kwargs.get("parallel", False) - return nopython, nogil, parallel + return {"nopython": nopython, "nogil": nogil, "parallel": parallel} def jit_user_function( diff --git a/pandas/core/window/ewm.py b/pandas/core/window/ewm.py index 4bebc56273805..2d633ba1a2bcd 100644 --- a/pandas/core/window/ewm.py +++ b/pandas/core/window/ewm.py @@ -32,7 +32,10 @@ ExponentialMovingWindowIndexer, GroupbyIndexer, ) -from pandas.core.util.numba_ import maybe_use_numba +from pandas.core.util.numba_ import ( + get_jit_arguments, + maybe_use_numba, +) from pandas.core.window.common import zsqrt from pandas.core.window.doc import ( _shared_docs, @@ -406,7 +409,9 @@ def __init__( "times is not None." ) # Without times, points are equally spaced - self._deltas = np.ones(max(len(self.obj) - 1, 0), dtype=np.float64) + self._deltas = np.ones( + max(self.obj.shape[self.axis] - 1, 0), dtype=np.float64 + ) self._com = get_center_of_mass( # error: Argument 3 to "get_center_of_mass" has incompatible type # "Union[float, Any, None, timedelta64, signedinteger[_64Bit]]"; @@ -527,22 +532,17 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): if self.method == "single": func = generate_numba_ewm_func - numba_cache_key = (lambda x: x, "ewm_mean") else: func = generate_numba_ewm_table_func - numba_cache_key = (lambda x: x, "ewm_mean_table") ewm_func = func( - engine_kwargs=engine_kwargs, + **get_jit_arguments(engine_kwargs), com=self._com, adjust=self.adjust, ignore_na=self.ignore_na, - deltas=self._deltas, + deltas=tuple(self._deltas), normalize=True, ) - return self._apply( - ewm_func, - numba_cache_key=numba_cache_key, - ) + return self._apply(ewm_func) elif engine in ("cython", None): if engine_kwargs is not None: raise ValueError("cython engine does not accept engine_kwargs") @@ -583,22 +583,17 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): if self.method == "single": func = generate_numba_ewm_func - numba_cache_key = (lambda x: x, "ewm_sum") else: func = generate_numba_ewm_table_func - numba_cache_key = (lambda x: x, "ewm_sum_table") ewm_func = func( - engine_kwargs=engine_kwargs, + **get_jit_arguments(engine_kwargs), com=self._com, adjust=self.adjust, ignore_na=self.ignore_na, - deltas=self._deltas, + deltas=tuple(self._deltas), normalize=False, ) - return self._apply( - ewm_func, - numba_cache_key=numba_cache_key, - ) + return self._apply(ewm_func) elif engine in ("cython", None): if engine_kwargs is not None: raise ValueError("cython engine does not accept engine_kwargs") @@ -1011,7 +1006,9 @@ def mean(self, *args, update=None, update_times=None, **kwargs): else: result_kwargs["name"] = self._selected_obj.name np_array = self._selected_obj.astype(np.float64).to_numpy() - ewma_func = generate_online_numba_ewma_func(self.engine_kwargs) + ewma_func = generate_online_numba_ewma_func( + **get_jit_arguments(self.engine_kwargs) + ) result = self._mean.run_ewm( np_array if is_frame else np_array[:, np.newaxis], update_deltas, diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index 0e8eea3ec671e..3b14f0d14ecab 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -12,18 +12,15 @@ from pandas._typing import Scalar from pandas.compat._optional import import_optional_dependency -from pandas.core.util.numba_ import ( - NUMBA_FUNC_CACHE, - get_jit_arguments, - jit_user_function, -) +from pandas.core.util.numba_ import jit_user_function +@functools.lru_cache(maxsize=None) def generate_numba_apply_func( - kwargs: dict[str, Any], func: Callable[..., Scalar], - engine_kwargs: dict[str, bool] | None, - name: str, + nopython: bool, + nogil: bool, + parallel: bool, ): """ Generate a numba jitted apply function specified by values from engine_kwargs. @@ -36,25 +33,19 @@ def generate_numba_apply_func( Parameters ---------- - kwargs : dict - **kwargs to be passed into the function func : function function to be applied to each window and will be JITed - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit - name: str - name of the caller (Rolling/Expanding) + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs) - - cache_key = (func, f"{name}_apply_single") - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - numba_func = jit_user_function(func, nopython, nogil, parallel) if TYPE_CHECKING: import numba @@ -84,12 +75,15 @@ def roll_apply( return roll_apply +@functools.lru_cache(maxsize=None) def generate_numba_ewm_func( - engine_kwargs: dict[str, bool] | None, + nopython: bool, + nogil: bool, + parallel: bool, com: float, adjust: bool, ignore_na: bool, - deltas: np.ndarray, + deltas: tuple, normalize: bool, ): """ @@ -98,25 +92,22 @@ def generate_numba_ewm_func( Parameters ---------- - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit com : float adjust : bool ignore_na : bool - deltas : numpy.ndarray + deltas : tuple normalize : bool Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - - str_key = "ewm_mean" if normalize else "ewm_sum" - cache_key = (lambda x: x, str_key) - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - if TYPE_CHECKING: import numba else: @@ -183,11 +174,12 @@ def ewm( return ewm +@functools.lru_cache(maxsize=None) def generate_numba_table_func( - kwargs: dict[str, Any], func: Callable[..., np.ndarray], - engine_kwargs: dict[str, bool] | None, - name: str, + nopython: bool, + nogil: bool, + parallel: bool, ): """ Generate a numba jitted function to apply window calculations table-wise. @@ -201,25 +193,19 @@ def generate_numba_table_func( Parameters ---------- - kwargs : dict - **kwargs to be passed into the function func : function function to be applied to each window and will be JITed - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit - name : str - caller (Rolling/Expanding) and original method name for numba cache key + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs) - - cache_key = (func, f"{name}_table") - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - numba_func = jit_user_function(func, nopython, nogil, parallel) if TYPE_CHECKING: import numba @@ -272,12 +258,15 @@ def nan_agg_with_axis(table): return nan_agg_with_axis +@functools.lru_cache(maxsize=None) def generate_numba_ewm_table_func( - engine_kwargs: dict[str, bool] | None, + nopython: bool, + nogil: bool, + parallel: bool, com: float, adjust: bool, ignore_na: bool, - deltas: np.ndarray, + deltas: tuple, normalize: bool, ): """ @@ -286,25 +275,22 @@ def generate_numba_ewm_table_func( Parameters ---------- - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit com : float adjust : bool ignore_na : bool - deltas : numpy.ndarray + deltas : tuple normalize: bool Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - - str_key = "ewm_mean_table" if normalize else "ewm_sum_table" - cache_key = (lambda x: x, str_key) - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - if TYPE_CHECKING: import numba else: diff --git a/pandas/core/window/online.py b/pandas/core/window/online.py index 8ef4aee154db4..94408e6df2504 100644 --- a/pandas/core/window/online.py +++ b/pandas/core/window/online.py @@ -1,37 +1,32 @@ -from typing import ( - TYPE_CHECKING, - Dict, - Optional, -) +from typing import TYPE_CHECKING import numpy as np from pandas.compat._optional import import_optional_dependency -from pandas.core.util.numba_ import ( - NUMBA_FUNC_CACHE, - get_jit_arguments, -) - -def generate_online_numba_ewma_func(engine_kwargs: Optional[Dict[str, bool]]): +def generate_online_numba_ewma_func( + nopython: bool, + nogil: bool, + parallel: bool, +): """ Generate a numba jitted groupby ewma function specified by values from engine_kwargs. + Parameters ---------- - engine_kwargs : dict - dictionary of arguments to be passed into numba.jit + nopython : bool + nopython to be passed into numba.jit + nogil : bool + nogil to be passed into numba.jit + parallel : bool + parallel to be passed into numba.jit + Returns ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - - cache_key = (lambda x: x, "online_ewma") - if cache_key in NUMBA_FUNC_CACHE: - return NUMBA_FUNC_CACHE[cache_key] - if TYPE_CHECKING: import numba else: diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 72ebbcbc65e5e..0f7e624967e9b 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -73,7 +73,7 @@ ) from pandas.core.reshape.concat import concat from pandas.core.util.numba_ import ( - NUMBA_FUNC_CACHE, + get_jit_arguments, maybe_use_numba, ) from pandas.core.window.common import ( @@ -530,7 +530,6 @@ def _apply( self, func: Callable[..., Any], name: str | None = None, - numba_cache_key: tuple[Callable, str] | None = None, numba_args: tuple[Any, ...] = (), **kwargs, ): @@ -543,8 +542,6 @@ def _apply( ---------- func : callable function to apply name : str, - numba_cache_key : tuple - caching key to be used to store a compiled numba func numba_args : tuple args to be passed when func is a numba func **kwargs @@ -581,9 +578,6 @@ def calc(x): with np.errstate(all="ignore"): result = calc(values) - if numba_cache_key is not None: - NUMBA_FUNC_CACHE[numba_cache_key] = func - return result if self.method == "single": @@ -594,7 +588,6 @@ def calc(x): def _numba_apply( self, func: Callable[..., Any], - numba_cache_key_str: str, engine_kwargs: dict[str, bool] | None = None, *func_args, ): @@ -618,10 +611,9 @@ def _numba_apply( ) self._check_window_bounds(start, end, len(values)) aggregator = executor.generate_shared_aggregator( - func, engine_kwargs, numba_cache_key_str + func, **get_jit_arguments(engine_kwargs) ) result = aggregator(values, start, end, min_periods, *func_args) - NUMBA_FUNC_CACHE[(func, numba_cache_key_str)] = aggregator result = result.T if self.axis == 1 else result if obj.ndim == 1: result = result.squeeze() @@ -673,14 +665,12 @@ def _apply( self, func: Callable[..., Any], name: str | None = None, - numba_cache_key: tuple[Callable, str] | None = None, numba_args: tuple[Any, ...] = (), **kwargs, ) -> DataFrame | Series: result = super()._apply( func, name, - numba_cache_key, numba_args, **kwargs, ) @@ -1101,7 +1091,6 @@ def _apply( self, func: Callable[[np.ndarray, int, int], np.ndarray], name: str | None = None, - numba_cache_key: tuple[Callable, str] | None = None, numba_args: tuple[Any, ...] = (), **kwargs, ): @@ -1114,8 +1103,6 @@ def _apply( ---------- func : callable function to apply name : str, - use_numba_cache : tuple - unused numba_args : tuple unused **kwargs @@ -1294,23 +1281,19 @@ def apply( if not is_bool(raw): raise ValueError("raw parameter must be `True` or `False`") - numba_cache_key = None numba_args: tuple[Any, ...] = () if maybe_use_numba(engine): if raw is False: raise ValueError("raw must be `True` when using the numba engine") - caller_name = type(self).__name__ numba_args = args if self.method == "single": apply_func = generate_numba_apply_func( - kwargs, func, engine_kwargs, caller_name + func, **get_jit_arguments(engine_kwargs, kwargs) ) - numba_cache_key = (func, f"{caller_name}_apply_single") else: apply_func = generate_numba_table_func( - kwargs, func, engine_kwargs, f"{caller_name}_apply" + func, **get_jit_arguments(engine_kwargs, kwargs) ) - numba_cache_key = (func, f"{caller_name}_apply_table") elif engine in ("cython", None): if engine_kwargs is not None: raise ValueError("cython engine does not accept engine_kwargs") @@ -1320,7 +1303,6 @@ def apply( return self._apply( apply_func, - numba_cache_key=numba_cache_key, numba_args=numba_args, ) @@ -1369,7 +1351,7 @@ def sum( else: from pandas.core._numba.kernels import sliding_sum - return self._numba_apply(sliding_sum, "rolling_sum", engine_kwargs) + return self._numba_apply(sliding_sum, engine_kwargs) window_func = window_aggregations.roll_sum return self._apply(window_func, name="sum", **kwargs) @@ -1393,9 +1375,7 @@ def max( else: from pandas.core._numba.kernels import sliding_min_max - return self._numba_apply( - sliding_min_max, "rolling_max", engine_kwargs, True - ) + return self._numba_apply(sliding_min_max, engine_kwargs, True) window_func = window_aggregations.roll_max return self._apply(window_func, name="max", **kwargs) @@ -1419,9 +1399,7 @@ def min( else: from pandas.core._numba.kernels import sliding_min_max - return self._numba_apply( - sliding_min_max, "rolling_min", engine_kwargs, False - ) + return self._numba_apply(sliding_min_max, engine_kwargs, False) window_func = window_aggregations.roll_min return self._apply(window_func, name="min", **kwargs) @@ -1445,7 +1423,7 @@ def mean( else: from pandas.core._numba.kernels import sliding_mean - return self._numba_apply(sliding_mean, "rolling_mean", engine_kwargs) + return self._numba_apply(sliding_mean, engine_kwargs) window_func = window_aggregations.roll_mean return self._apply(window_func, name="mean", **kwargs) @@ -1485,9 +1463,7 @@ def std( else: from pandas.core._numba.kernels import sliding_var - return zsqrt( - self._numba_apply(sliding_var, "rolling_std", engine_kwargs, ddof) - ) + return zsqrt(self._numba_apply(sliding_var, engine_kwargs, ddof)) window_func = window_aggregations.roll_var def zsqrt_func(values, begin, end, min_periods): @@ -1514,9 +1490,7 @@ def var( else: from pandas.core._numba.kernels import sliding_var - return self._numba_apply( - sliding_var, "rolling_var", engine_kwargs, ddof - ) + return self._numba_apply(sliding_var, engine_kwargs, ddof) window_func = partial(window_aggregations.roll_var, ddof=ddof) return self._apply( window_func, diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index e7fa2e0690066..ba58ac27284b8 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -12,7 +12,6 @@ option_context, ) import pandas._testing as tm -from pandas.core.util.numba_ import NUMBA_FUNC_CACHE @td.skip_if_no("numba") @@ -33,8 +32,8 @@ def incorrect_function(x): @td.skip_if_no("numba") def test_check_nopython_kwargs(): - def incorrect_function(x, **kwargs): - return sum(x) * 2.7 + def incorrect_function(values, index): + return sum(values) * 2.7 data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, @@ -106,14 +105,11 @@ def func_2(values, index): result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython") tm.assert_equal(result, expected) - # func_1 should be in the cache now - assert (func_1, "groupby_agg") in NUMBA_FUNC_CACHE # Add func_2 to the cache result = grouped.agg(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython") tm.assert_equal(result, expected) - assert (func_2, "groupby_agg") in NUMBA_FUNC_CACHE # Retest func_1 which should use the cache result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs) @@ -187,3 +183,31 @@ def f(values, index): [-1.5, -3.0], columns=["v"], index=Index(["A", "B"], name="group") ) tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_engine_kwargs_not_cached(): + # If the user passes a different set of engine_kwargs don't return the same + # jitted function + nogil = True + parallel = False + nopython = True + + def func_kwargs(values, index): + return nogil + parallel + nopython + + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + df = DataFrame({"value": [0, 0, 0]}) + result = df.groupby(level=0).aggregate( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + nogil = False + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby(level=0).aggregate( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 4e1b777296d5b..a404e0b9304cc 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -9,7 +9,6 @@ option_context, ) import pandas._testing as tm -from pandas.core.util.numba_ import NUMBA_FUNC_CACHE @td.skip_if_no("numba") @@ -30,8 +29,8 @@ def incorrect_function(x): @td.skip_if_no("numba") def test_check_nopython_kwargs(): - def incorrect_function(x, **kwargs): - return x + 1 + def incorrect_function(values, index): + return values + 1 data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, @@ -103,14 +102,10 @@ def func_2(values, index): result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) - # func_1 should be in the cache now - assert (func_1, "groupby_transform") in NUMBA_FUNC_CACHE - # Add func_2 to the cache result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x * 5, engine="cython") tm.assert_equal(result, expected) - assert (func_2, "groupby_transform") in NUMBA_FUNC_CACHE # Retest func_1 which should use the cache result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) @@ -176,3 +171,31 @@ def f(values, index): result = df.groupby("group").transform(f, engine="numba") expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3]) tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_engine_kwargs_not_cached(): + # If the user passes a different set of engine_kwargs don't return the same + # jitted function + nogil = True + parallel = False + nopython = True + + def func_kwargs(values, index): + return nogil + parallel + nopython + + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + df = DataFrame({"value": [0, 0, 0]}) + result = df.groupby(level=0).transform( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + nogil = False + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.groupby(level=0).transform( + func_kwargs, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index b514891ae5c92..6fd45606ae98d 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -16,7 +16,6 @@ to_datetime, ) import pandas._testing as tm -from pandas.core.util.numba_ import NUMBA_FUNC_CACHE # TODO(GH#44584): Mark these as pytest.mark.single_cpu pytestmark = pytest.mark.skipif( @@ -95,14 +94,6 @@ def test_numba_vs_cython_rolling_methods( engine="numba", engine_kwargs=engine_kwargs, **kwargs ) expected = getattr(roll, method)(engine="cython", **kwargs) - - # Check the cache - if method not in ("mean", "sum", "var", "std", "max", "min"): - assert ( - getattr(np, f"nan{method}"), - "Rolling_apply_single", - ) in NUMBA_FUNC_CACHE - tm.assert_equal(result, expected) @pytest.mark.parametrize( @@ -122,14 +113,6 @@ def test_numba_vs_cython_expanding_methods( engine="numba", engine_kwargs=engine_kwargs, **kwargs ) expected = getattr(expand, method)(engine="cython", **kwargs) - - # Check the cache - if method not in ("mean", "sum", "var", "std", "max", "min"): - assert ( - getattr(np, f"nan{method}"), - "Expanding_apply_single", - ) in NUMBA_FUNC_CACHE - tm.assert_equal(result, expected) @pytest.mark.parametrize("jit", [True, False]) @@ -156,9 +139,6 @@ def func_2(x): expected = roll.apply(func_1, engine="cython", raw=True) tm.assert_series_equal(result, expected) - # func_1 should be in the cache now - assert (func_1, "Rolling_apply_single") in NUMBA_FUNC_CACHE - result = roll.apply( func_2, engine="numba", engine_kwargs=engine_kwargs, raw=True ) @@ -200,6 +180,32 @@ def add(values, x): expected = DataFrame({"value": [2.0, 2.0, 2.0]}) tm.assert_frame_equal(result, expected) + def test_dont_cache_engine_kwargs(self): + # If the user passes a different set of engine_kwargs don't return the same + # jitted function + nogil = False + parallel = True + nopython = True + + def func(x): + return nogil + parallel + nopython + + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + df = DataFrame({"value": [0, 0, 0]}) + result = df.rolling(1).apply( + func, raw=True, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + parallel = False + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.rolling(1).apply( + func, raw=True, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) + @td.skip_if_no("numba") class TestEWM: