Skip to content

Commit

Permalink
primitives: deprecate wrap_in_cse
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Nov 9, 2024
1 parent 1352ad1 commit 670bbbb
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 54 deletions.
9 changes: 6 additions & 3 deletions pymbolic/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def get_cse(self, expr, key=None):
try:
return self.canonical_subexprs[key]
except KeyError:
new_expr = prim.wrap_in_cse(
getattr(IdentityMapper, expr.mapper_method)(self, expr))
new_expr = prim.make_common_subexpression(
getattr(IdentityMapper, expr.mapper_method)(self, expr)
)
self.canonical_subexprs[key] = new_expr
return new_expr

Expand All @@ -113,7 +114,9 @@ def map_sum(self, expr):
def map_common_subexpression(self, expr):
# Avoid creating CSE(CSE(...))
if type(expr) is prim.CommonSubexpression:
return prim.wrap_in_cse(self.rec(expr.child), expr.prefix)
return prim.make_common_subexpression(
self.rec(expr.child), expr.prefix, expr.scope
)
else:
# expr is of a derived CSE type
result = self.rec(expr.child)
Expand Down
97 changes: 46 additions & 51 deletions pymbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,29 +1926,18 @@ def is_zero(value: object) -> bool:
def wrap_in_cse(expr: ExpressionT,
prefix: str | None = None,
scope: str | None = None) -> ExpressionT:
if isinstance(expr, Variable | Subscript):
return expr

if scope is None:
scope = cse_scope.EVALUATION
warn("'wrap_in_cse' is deprecated and will be removed in 2025. Use "
"'make_common_subexpression' with the `wrap_vars=False` flag instead.",
DeprecationWarning, stacklevel=2)

if isinstance(expr, CommonSubexpression):
if prefix is None:
return expr

if expr.prefix is None and type(expr) is CommonSubexpression:
return CommonSubexpression(expr.child, prefix, scope)

# existing prefix wins
return expr

else:
return CommonSubexpression(expr, prefix, scope)
return make_common_subexpression(expr, prefix, scope, wrap_vars=False)


def make_common_subexpression(expr: ExpressionT,
prefix: str | None = None,
scope: str | None = None) -> ExpressionT:
scope: str | None = None,
*,
wrap_vars: bool = True) -> ExpressionT:
"""Wrap *expr* in a :class:`CommonSubexpression` with *prefix*.
If *expr* is a :mod:`numpy` object array, each individual entry is instead
Expand All @@ -1958,68 +1947,74 @@ def make_common_subexpression(expr: ExpressionT,
See :class:`CommonSubexpression` for the meaning of *prefix* and *scope*. The
scope defaults to :attr:`cse_scope.EVALUATION`.
:arg wrap_vars: If *True*, this also wraps a :class:`~pymbolic.primitives.Variable`
and its subscripts. Otherwise, these expressions are returned as is.
"""

if scope is None:
scope = cse_scope.EVALUATION
if is_constant(expr):
return expr

if (isinstance(expr, CommonSubexpression)
and (scope == cse_scope.EVALUATION or expr.scope == scope)):
# Don't re-wrap
if not wrap_vars and isinstance(expr, Variable | Subscript):
return expr

try:
import numpy
# handle CSE re-wrapping
if scope is None:
scope = cse_scope.EVALUATION

if isinstance(expr, numpy.ndarray) and expr.dtype.char == "O":
is_obj_array = True
logical_shape = expr.shape
else:
is_obj_array = False
logical_shape = ()
except ImportError:
is_obj_array = False
logical_shape = ()
if isinstance(expr, CommonSubexpression):
if scope == cse_scope.EVALUATION or expr.scope == scope:
return expr

# handle MultiVector
from pymbolic.geometric_algebra import MultiVector

if isinstance(expr, MultiVector):
new_data = {}
for bits, coeff in expr.data.items():
if prefix is not None:
blade_str = expr.space.blade_bits_to_str(bits, "")
component_prefix = prefix+"_"+blade_str
component_prefix = f"{prefix}_{blade_str}"
else:
component_prefix = None

new_data[bits] = make_common_subexpression(
coeff, component_prefix, scope)
coeff, component_prefix, scope, wrap_vars=wrap_vars)

return MultiVector(new_data, expr.space)

elif is_obj_array and logical_shape != ():
assert isinstance(expr, numpy.ndarray)
# handle numpy object arrays
try:
import numpy as np

result = numpy.zeros(logical_shape, dtype=object)
for i in numpy.ndindex(logical_shape):
if isinstance(expr, np.ndarray) and expr.dtype.char == "O":
is_obj_array = True
logical_shape = expr.shape
else:
is_obj_array = False
logical_shape = ()
except ImportError:
is_obj_array = False
logical_shape = ()

if is_obj_array and logical_shape != ():
assert isinstance(expr, np.ndarray)

result = np.zeros(logical_shape, dtype=object)
for i in np.ndindex(logical_shape):
if prefix is not None:
component_prefix = prefix+"_".join(str(i_i) for i_i in i)
bits = "_".join(str(i_i) for i_i in i)
component_prefix = f"{prefix}_{bits}"
else:
component_prefix = None

if is_constant(expr[i]):
result[i] = expr[i]
else:
result[i] = make_common_subexpression(
expr[i], component_prefix, scope)
result[i] = make_common_subexpression(
expr[i], component_prefix, scope, wrap_vars=wrap_vars)

return result

else:
if is_constant(expr):
return expr
else:
return CommonSubexpression(expr, prefix, scope)
# everything else gets re-wrapped
return CommonSubexpression(expr, prefix, scope)


def make_sym_vector(name, components, var_factory=Variable):
Expand Down

0 comments on commit 670bbbb

Please sign in to comment.