Currently, the extraction in egglog is rather limited. It does a tree-based extraction (meaning that if a node shows up twice, it will be counted twice) and requires static costs per function.
The first issue, the type of extractor, could be alleviated by using some of the extractors from extraction-gym. The second, custom costs per item, could be addressed upstream in egglog (egraphs-good/egglog#294) but is not on the immediate roadmap.
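To make the tree versus DAG distinction concrete, here is a small sketch on a toy term representation. The nested tuples and the `cost_of` table are assumptions made for illustration only and are not part of egglog's API.

```python
# Toy terms: (operator, child, child, ...). Hypothetical layout, not egglog's.

def tree_cost(term, cost_of):
    """Sum node costs, counting a shared subterm once per occurrence."""
    op, *children = term
    return cost_of[op] + sum(tree_cost(c, cost_of) for c in children)


def dag_cost(term, cost_of):
    """Sum node costs, counting each distinct subterm only once."""
    seen = set()

    def visit(t):
        if t in seen:
            return 0
        seen.add(t)
        op, *children = t
        return cost_of[op] + sum(visit(c) for c in children)

    return visit(term)


costs = {"x": 1, "y": 1, "+": 1, "*": 1}
shared = ("+", ("x",), ("y",))   # x + y
expr = ("*", shared, shared)     # (x + y) * (x + y)

print(tree_cost(expr, costs))    # the shared x + y is paid for twice
print(dag_cost(expr, costs))     # the shared x + y is paid for once
```

With unit costs, the tree cost of `(x + y) * (x + y)` is 7 and the DAG cost is 4, which is why the choice of extractor can change which expression looks cheapest.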
Either way, it would also be nice to have fully custom extraction. Being able to iterate through the e-graph and do what you will...
Currently, it's "possible" by serializing the e-graph to JSON. But this is not ideal because then you have to look at JSON with random keys and they might not map to your Python function names and it's not type safe and... Yeah it's just a real pain!
So I think it would make sense to add an interface that allows:
Using custom extractors from extraction-gym without leaving the Python bindings by building that with egglog.
Being able to set custom costs per record before extracting.
Being able to query costs and write your own extractor in Python.
Do all of this while keeping static-type safety.
Reduce overhead as much as possible in terms of serialization and wrapping/unwrapping.
Possible Design
Here is a possible API design for the extractors:

"""Examples using egglog."""

from __future__ import annotations

from typing import Literal, Protocol, TypeVar

from egglog import Expr

EXPR = TypeVar("EXPR", bound=Expr)


class EGraph:
    def serialize(self) -> SerializedEGraph:
        """
        Serializes the e-graph into a format that can be passed to an extractor.
        """
        raise NotImplementedError

    def extract(
        self, x: EXPR, /, extractor: Extractor, include_cost: Literal["tree", "dag"] | None
    ) -> EXPR | tuple[EXPR, int]:
        """
        Extracts the given expression using the given extractor, optionally including the cost of the extraction.
        """
        extract_result = extractor.extract(self.serialize(), [x])
        res = extract_result.chosen(x)
        if include_cost is None:
            return res
        cost = extract_result.tree_cost([x]) if include_cost == "tree" else extract_result.dag_cost([x])
        return res, cost


class SerializedEGraph:
    def equivalent(self, x: EXPR, /) -> list[EXPR]:
        """
        Returns all equivalent expressions. i.e. all expressions with the same e-class.
        """
        raise NotImplementedError

    def self_cost(self, x: Expr, /) -> int:
        """
        Returns the cost of just that function, not including its children.
        """
        raise NotImplementedError

    def set_self_cost(self, x: Expr, cost: int, /) -> None:
        """
        Sets the cost of just that function, not including its children.
        """
        raise NotImplementedError


class Extractor(Protocol):
    def extract(self, egraph: SerializedEGraph, roots: list[Expr]) -> ExtractionResult: ...


class ExtractionResult:
    """
    An extraction result is a mapping from an e-class to chosen nodes, paired with the original extracted e-graph.

    Based off of https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/mod.rs but removed e-classes
    since they are not present in Python bindings and instead let you pass in any member of that e-class and get out
    representative nodes.
    """

    egraph: SerializedEGraph

    def __init__(self, egraph: SerializedEGraph) -> None: ...

    def choose(self, class_: EXPR, chosen_node: EXPR, /) -> None:
        """
        Choose an expression in the e-graph.
        """

    def chosen(self, x: EXPR, /) -> EXPR:
        """
        Given an expr that is in the e-graph, it recursively returns the chosen expressions in each e-class.
        """
        raise NotImplementedError

    def check(self) -> None:
        """
        Check the extraction result for consistency.
        """
        raise NotImplementedError

    def find_cycles(self, roots: list[Expr]) -> list[Expr]:
        """
        Returns all classes that are in a cycle, which is reachable from the roots.
        """
        raise NotImplementedError

    def tree_cost(self, roots: list[Expr]) -> int:
        """
        Returns the "tree cost" (counting duplicate nodes twice) of the trees rooted at the given roots.
        """
        raise NotImplementedError

    def dag_cost(self, roots: list[Expr]) -> int:
        """
        Returns the "dag cost" (counting duplicate nodes once) of the dag rooted at the given roots.
        """
        raise NotImplementedError
Using this interface, you could use the default costs from egglog together with a custom extractor, as shown in the helper extract method.
However, you could also set custom costs before serializing, overriding any from egglog:
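A minimal sketch of the override semantics, assuming a stand-in class. `ToySerializedEGraph` and its string node names are hypothetical and only mimic the `self_cost` / `set_self_cost` pair from the proposed design; none of this exists in egglog today.

```python
from dataclasses import dataclass, field


@dataclass
class ToySerializedEGraph:
    """Hypothetical stand-in for the proposed SerializedEGraph; nodes are plain strings."""

    default_costs: dict[str, int]                               # costs as they would come from egglog
    overrides: dict[str, int] = field(default_factory=dict)    # user-set costs

    def self_cost(self, node: str) -> int:
        # A user-set cost wins over the default one.
        return self.overrides.get(node, self.default_costs[node])

    def set_self_cost(self, node: str, cost: int) -> None:
        self.overrides[node] = cost


serialized = ToySerializedEGraph({"Pow": 1, "Mul": 1})
serialized.set_self_cost("Pow", 100)   # override before handing it to an extractor

print(serialized.self_cost("Pow"), serialized.self_cost("Mul"))
```

Any extractor that only reads costs through `self_cost` would then see the overridden value for `Pow` and the default for everything else.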
How would you be able to traverse an expression at runtime and see its children? I think with three small additions, we could be able to do this with our current API:
Primitives: Allow any primitive to be converted to a Python object, e.g. int(i64(0)).
User Defined Constants: Allow bool(eq(x).to(y)), which will resolve to whether the two sides are exactly syntactically equal.
User Defined Functions: Support a new way to get the args in a type-safe manner based on a function, e.g. fn_matches(x, f) would return a boolean to say whether the function matches, and then fn_args(x, f) would return a list of the args. They could be typed like this:
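from typing import Callable, Generic, TypeVarTuple, Unpack  # TypeVarTuple and Unpack need Python 3.11+

EXPRS = TypeVarTuple("EXPRS")


class _FnMatchesBuilder(Generic[EXPR]):
    def fn(self, y: Callable[[Unpack[EXPRS]], EXPR], /) -> tuple[Unpack[EXPRS]] | None:
        """
        Returns the list of args or None
        """
        raise NotImplementedError


def matches(x: EXPR) -> _FnMatchesBuilder[EXPR]:
    raise NotImplementedError


# Compare against None explicitly: a zero-argument function would return an empty, falsy tuple.
if (args := matches(x).fn(y)) is not None:
    x, y, z = args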
Alternatively, how would you create a custom extractor? We would want to add one more way to traverse expressions... This time not caring about what particular expression they are, just their args and a way to re-create them with different args. Using that, we could write a simple tree-based extractor:
def decompose(x: EXPR, /) -> tuple[ReturnsExpr[EXPR], list[Expr]]:
    """
    Decomposes an expression into a callable that will reconstruct it based on its args.

    For all expressions, constants or functions, this should hold:

    >>> fn, args = decompose(x)
    >>> assert fn(*args) == x

    This can be used to change the args of a function and then reconstruct it.

    If you are looking for a type safe way to deal with a particular constructor, you can use either `eq(x).to(y)` for
    constants or `matches(x).fn(y)` for functions to get their args in a type safe manner.
    """
    raise NotImplementedError


def tree_based_extractor(serialized: SerializedEGraph, expr: EXPR, /) -> tuple[EXPR, int]:
    """
    Returns the lowest cost equivalent expression and the cost of that expression, based on a tree based extraction.
    """
    min_expr, min_cost = expr_cost(serialized, expr)
    for eq in serialized.equivalent(expr):
        new_expr, new_cost = expr_cost(serialized, eq)
        if new_cost < min_cost:
            min_cost = new_cost
            min_expr = new_expr
    return min_expr, min_cost


def expr_cost(serialized: SerializedEGraph, expr: EXPR, /) -> tuple[EXPR, int]:
    """
    Returns the cost of the given expression.
    """
    cost = serialized.self_cost(expr)
    constructor, children = decompose(expr)
    best_children = []
    for child in children:
        best_child, child_cost = tree_based_extractor(serialized, child)
        cost += child_cost
        best_children.append(best_child)
    return constructor(*best_children), cost
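The recursive sketch above would recurse forever on an e-class that contains a node referring back to the same class (for example after a rule like x = x * 1). One way around that, sketched here on a toy dict-based e-graph, is to iterate costs to a fixed point instead of recursing. The data layout, the `extract_tree` name and the `node_cost` callback are assumptions for illustration, not the proposed API, and the sketch assumes every node cost is positive.

```python
import math

# Toy e-graph: e-class id -> list of e-nodes, where an e-node is
# (op, child_class_id, ...). Hypothetical layout, not egglog's serialized format.


def extract_tree(eclasses, node_cost, root):
    """Return (best term, tree cost) for `root`; safe on cyclic e-graphs."""
    # Best known (cost, node) per e-class, starting from "unknown".
    best = {cid: (math.inf, None) for cid in eclasses}
    changed = True
    while changed:  # repeat until no e-class can be improved
        changed = False
        for cid, nodes in eclasses.items():
            for node in nodes:
                op, *children = node
                cost = node_cost(op) + sum(best[c][0] for c in children)
                if cost < best[cid][0]:
                    best[cid] = (cost, node)
                    changed = True

    def build(cid):
        # Rebuild the chosen term; assumes the class has a finite-cost node.
        op, *children = best[cid][1]
        return (op, *(build(c) for c in children))

    return build(root), best[root][0]


eclasses = {
    "a": [("x",), ("*", "a", "one")],  # x and x * 1 share a class (a cycle)
    "one": [("1",)],
    "r": [("+", "a", "a")],
}
term, cost = extract_tree(eclasses, lambda op: 1, "r")
print(term, cost)
```

With unit costs this picks `x` over `x * 1` for class `a`, so the root extracts to `x + x` with a tree cost of 3.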
I think we would also need a way to get all parent nodes from a node in the serialized format for doing custom cost traversal... For example you might set some kind of length of a vec, and then want to look that up when computing costs.
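As a rough illustration of what a parent lookup could look like, here is a sketch over a toy dict-based e-graph. The layout (e-class id mapped to a list of `(op, child_class_id, ...)` tuples), the `parent_index` helper and the `SetLength` operator are all hypothetical and not part of egglog's serialized format.

```python
from collections import defaultdict


def parent_index(eclasses):
    """Map each e-class id to the (parent class id, parent node) pairs that use it."""
    parents = defaultdict(list)
    for cid, nodes in eclasses.items():
        for node in nodes:
            op, *children = node
            for child in set(children):   # record a parent once even if it uses the child twice
                parents[child].append((cid, node))
    return dict(parents)


eclasses = {
    "v": [("Vec",)],
    "n": [("3",)],
    "len": [("SetLength", "v", "n")],     # hypothetical "length of a vec" fact
}
parents = parent_index(eclasses)
print(parents["v"])
```

A cost function for the `Vec` node could then look through `parents["v"]` to find the node that records its length.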
The example expands x ** 4 to x * x * x * x, such that AST size cost will not work. The custom cost-model will penalize the Pow a lot so it will select the Mul variant.
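A small sketch of why AST size fails here, using toy nested tuples rather than the actual PoC. The term layout and both cost functions are assumptions for illustration, and the penalty of 100 is an arbitrary value.

```python
def tree_cost(term, node_cost):
    """Sum node_cost over every node of a (op, child, ...) tuple term."""
    op, *children = term
    return node_cost(op) + sum(tree_cost(c, node_cost) for c in children)


def ast_size(op):
    return 1                              # every node costs the same


def penalize_pow(op):
    return 100 if op == "Pow" else 1      # arbitrary large penalty for Pow


x = ("x",)
pow_form = ("Pow", x, ("4",))                        # x ** 4
mul_form = ("Mul", ("Mul", ("Mul", x, x), x), x)     # x * x * x * x

for name, cost_fn in [("ast_size", ast_size), ("penalize_pow", penalize_pow)]:
    best = min([pow_form, mul_form], key=lambda t: tree_cost(t, cost_fn))
    print(name, "picks", best[0])
```

Under AST size the `Pow` form costs 3 against 7 for the `Mul` chain, so it always wins; with the penalty it costs 102 and the `Mul` chain is selected instead.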
How would you be able to traverse an expression at runtime and see its children? I think with three small additions, we could be able to do this with our current API:
This is currently hard to do and therefore omitted in my PoC. I would want to know that it is a Pow(x, 4) and compute cost knowing the 4.
In the short term, is there a way to associate a node in the serialized JSON back to the egglog-python Expr object?
Alternatively, how would you create a custom extractor?
I'm very interested in the decompose() and constructor(). Our workflow is compiler IR -> egglog -> compiler IR. We need the extracted result to be translated back to the IR nodes.