Skip to content

Commit

Permalink
feat: enable preprocessing of expressions (#461)
Browse files Browse the repository at this point in the history
This puts through a reasonable change to the compiler pipeline by
introducing a preprocessing stage.  This occurrs are resolution, but
before translation.  The main purpose is to expand invocations,
reductions and for loops.  This greatly simplies their subsequent
translation.
  • Loading branch information
DavePearce authored Dec 19, 2024
1 parent 5a2f264 commit b46f334
Show file tree
Hide file tree
Showing 8 changed files with 471 additions and 267 deletions.
5 changes: 3 additions & 2 deletions pkg/corset/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"math"
"reflect"

"github.com/consensys/go-corset/pkg/sexp"
tr "github.com/consensys/go-corset/pkg/trace"
)

Expand Down Expand Up @@ -128,14 +129,14 @@ func (p *FunctionSignature) SubtypeOf(other *FunctionSignature) bool {

// Apply a set of concreate arguments to this function. This substitutes
// them through the body of the function producing a single expression.
func (p *FunctionSignature) Apply(args []Expr) Expr {
func (p *FunctionSignature) Apply(args []Expr, srcmap *sexp.SourceMaps[Node]) Expr {
mapping := make(map[uint]Expr)
// Setup the mapping
for i, e := range args {
mapping[uint(i)] = e
}
// Substitute through
return p.body.Substitute(mapping)
return Substitute(p.body, mapping, srcmap)
}

// ============================================================================
Expand Down
6 changes: 5 additions & 1 deletion pkg/corset/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,14 @@ func (p *Compiler) Compile() (*hir.Schema, []SyntaxError) {
if len(errs) != 0 {
return nil, errs
}
// Preprocess circuit to remove invocations, reductions, etc.
if errs := PreprocessCircuit(p.debug, p.srcmap, &p.circuit); len(errs) > 0 {
return nil, errs
}
// Convert global scope into an environment by allocating all columns.
environment := scope.ToEnvironment()
// Finally, translate everything and add it to the schema.
return TranslateCircuit(environment, p.debug, p.srcmap, &p.circuit)
return TranslateCircuit(environment, p.srcmap, &p.circuit)
}

func includeStdlib(stdlib bool, srcfiles []*sexp.SourceFile) []*sexp.SourceFile {
Expand Down
200 changes: 83 additions & 117 deletions pkg/corset/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package corset
import (
"fmt"
"math/big"
"reflect"

"github.com/consensys/go-corset/pkg/sexp"
tr "github.com/consensys/go-corset/pkg/trace"
Expand All @@ -26,16 +27,10 @@ type Expr interface {
// lists return one value for each element in the list. Note, every
// expression must return at least one value.
Multiplicity() uint

// Context returns the context for this expression. Observe that the
// expression must have been resolved for this to be defined (i.e. it may
// panic if it has not been resolved yet).
Context() Context

// Substitute all variables (such as for function parameters) arising in
// this expression.
Substitute(mapping map[uint]Expr) Expr

// Return set of columns on which this declaration depends.
Dependencies() []Symbol
}
Expand Down Expand Up @@ -76,12 +71,6 @@ func (e *Add) Lisp() sexp.SExp {
return ListOfExpressions(sexp.NewSymbol("+"), e.Args)
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Add) Substitute(mapping map[uint]Expr) Expr {
return &Add{SubstituteExpressions(e.Args, mapping)}
}

// Dependencies needed to signal declaration.
func (e *Add) Dependencies() []Symbol {
return DependenciesOfExpressions(e.Args)
Expand Down Expand Up @@ -164,12 +153,6 @@ func (e *ArrayAccess) Lisp() sexp.SExp {
})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *ArrayAccess) Substitute(mapping map[uint]Expr) Expr {
return &ArrayAccess{e.name, e.arg.Substitute(mapping), e.binding}
}

// Resolve this symbol by associating it with the binding associated with
// the definition of the symbol to which this refers.
func (e *ArrayAccess) Resolve(binding Binding) bool {
Expand Down Expand Up @@ -222,12 +205,6 @@ func (e *Constant) Lisp() sexp.SExp {
return sexp.NewSymbol(e.Val.String())
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Constant) Substitute(mapping map[uint]Expr) Expr {
return e
}

// Dependencies needed to signal declaration.
func (e *Constant) Dependencies() []Symbol {
return nil
Expand Down Expand Up @@ -271,12 +248,6 @@ func (e *Debug) Lisp() sexp.SExp {
e.Arg.Lisp()})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Debug) Substitute(mapping map[uint]Expr) Expr {
return &Debug{e.Arg.Substitute(mapping)}
}

// Dependencies needed to signal declaration.
func (e *Debug) Dependencies() []Symbol {
return e.Arg.Dependencies()
Expand Down Expand Up @@ -327,12 +298,6 @@ func (e *Exp) Lisp() sexp.SExp {
e.Pow.Lisp()})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Exp) Substitute(mapping map[uint]Expr) Expr {
return &Exp{e.Arg.Substitute(mapping), e.Pow.Substitute(mapping)}
}

// Dependencies needed to signal declaration.
func (e *Exp) Dependencies() []Symbol {
return DependenciesOfExpressions([]Expr{e.Arg, e.Pow})
Expand Down Expand Up @@ -385,13 +350,6 @@ func (e *For) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *For) Substitute(mapping map[uint]Expr) Expr {
body := e.Body.Substitute(mapping)
return &For{e.Binding, e.Start, e.End, body}
}

// Dependencies needed to signal declaration.
func (e *For) Dependencies() []Symbol {
// Remove occurrences of the index variable defined by this expression. In
Expand Down Expand Up @@ -493,15 +451,6 @@ func (e *If) Lisp() sexp.SExp {
e.TrueBranch.Lisp()})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *If) Substitute(mapping map[uint]Expr) Expr {
return &If{e.kind, e.Condition.Substitute(mapping),
SubstituteOptionalExpression(e.TrueBranch, mapping),
SubstituteOptionalExpression(e.FalseBranch, mapping),
}
}

// Dependencies needed to signal declaration.
func (e *If) Dependencies() []Symbol {
return DependenciesOfExpressions([]Expr{e.Condition, e.TrueBranch, e.FalseBranch})
Expand All @@ -525,7 +474,7 @@ func (e *Invoke) AsConstant() *big.Int {
panic("unresolved invocation")
}
// Unroll body
body := e.signature.Apply(e.args)
body := e.signature.Apply(e.args, nil)
// Attempt to evaluate as constant
return body.AsConstant()
}
Expand Down Expand Up @@ -561,12 +510,6 @@ func (e *Invoke) Finalise(signature *FunctionSignature) {
e.signature = signature
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Invoke) Substitute(mapping map[uint]Expr) Expr {
return &Invoke{e.fn, e.signature, SubstituteExpressions(e.args, mapping)}
}

// Dependencies needed to signal declaration.
func (e *Invoke) Dependencies() []Symbol {
deps := DependenciesOfExpressions(e.args)
Expand Down Expand Up @@ -608,12 +551,6 @@ func (e *List) Lisp() sexp.SExp {
return ListOfExpressions(sexp.NewSymbol("begin"), e.Args)
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *List) Substitute(mapping map[uint]Expr) Expr {
return &List{SubstituteExpressions(e.Args, mapping)}
}

// Dependencies needed to signal declaration.
func (e *List) Dependencies() []Symbol {
return DependenciesOfExpressions(e.Args)
Expand Down Expand Up @@ -652,12 +589,6 @@ func (e *Mul) Lisp() sexp.SExp {
return ListOfExpressions(sexp.NewSymbol("*"), e.Args)
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Mul) Substitute(mapping map[uint]Expr) Expr {
return &Mul{SubstituteExpressions(e.Args, mapping)}
}

// Dependencies needed to signal declaration.
func (e *Mul) Dependencies() []Symbol {
return DependenciesOfExpressions(e.Args)
Expand Down Expand Up @@ -699,12 +630,6 @@ func (e *Normalise) Lisp() sexp.SExp {
e.Arg.Lisp()})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Normalise) Substitute(mapping map[uint]Expr) Expr {
return &Normalise{e.Arg.Substitute(mapping)}
}

// Dependencies needed to signal declaration.
func (e *Normalise) Dependencies() []Symbol {
return e.Arg.Dependencies()
Expand Down Expand Up @@ -750,16 +675,6 @@ func (e *Reduce) Lisp() sexp.SExp {
e.arg.Lisp()})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Reduce) Substitute(mapping map[uint]Expr) Expr {
return &Reduce{
e.fn,
e.signature,
e.arg.Substitute(mapping),
}
}

// Finalise the signature for this reduction.
func (e *Reduce) Finalise(signature *FunctionSignature) {
if signature == nil {
Expand Down Expand Up @@ -810,12 +725,6 @@ func (e *Sub) Lisp() sexp.SExp {
return ListOfExpressions(sexp.NewSymbol("-"), e.Args)
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Sub) Substitute(mapping map[uint]Expr) Expr {
return &Sub{SubstituteExpressions(e.Args, mapping)}
}

// Dependencies needed to signal declaration.
func (e *Sub) Dependencies() []Symbol {
return DependenciesOfExpressions(e.Args)
Expand Down Expand Up @@ -867,12 +776,6 @@ func (e *Shift) Lisp() sexp.SExp {
e.Shift.Lisp()})
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Shift) Substitute(mapping map[uint]Expr) Expr {
return &Shift{e.Arg.Substitute(mapping), e.Shift.Substitute(mapping)}
}

// Dependencies needed to signal declaration.
func (e *Shift) Dependencies() []Symbol {
return DependenciesOfExpressions([]Expr{e.Arg, e.Shift})
Expand Down Expand Up @@ -989,18 +892,6 @@ func (e *VariableAccess) Lisp() sexp.SExp {
return sexp.NewSymbol(name)
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *VariableAccess) Substitute(mapping map[uint]Expr) Expr {
if b, ok1 := e.binding.(*LocalVariableBinding); ok1 {
if e, ok2 := mapping[b.index]; ok2 {
return e
}
}
// Nothing to do here
return e
}

// Dependencies needed to signal declaration.
func (e *VariableAccess) Dependencies() []Symbol {
return []Symbol{e}
Expand All @@ -1025,23 +916,98 @@ func ContextOfExpressions(exprs []Expr) Context {
return context
}

// SubstituteExpressions substitutes all variables found in a given set of
// Substitute variables (such as for function parameters) in this expression
// based on a mapping of said variables to expressions. Furthermore, an
// (optional) source map is provided which will be updated, such that the
// freshly created expressions are mapped to their corresponding nodes.
func Substitute(expr Expr, mapping map[uint]Expr, srcmap *sexp.SourceMaps[Node]) Expr {
var nexpr Expr
//
switch e := expr.(type) {
case *ArrayAccess:
arg := Substitute(e.arg, mapping, srcmap)
nexpr = &ArrayAccess{e.name, arg, e.binding}
case *Add:
args := SubstituteAll(e.Args, mapping, srcmap)
nexpr = &Add{args}
case *Constant:
return e
case *Debug:
arg := Substitute(e.Arg, mapping, srcmap)
nexpr = &Debug{arg}
case *Exp:
arg := Substitute(e.Arg, mapping, srcmap)
pow := Substitute(e.Pow, mapping, srcmap)
// Done
nexpr = &Exp{arg, pow}
case *For:
body := Substitute(e.Body, mapping, srcmap)
nexpr = &For{e.Binding, e.Start, e.End, body}
case *If:
condition := Substitute(e.Condition, mapping, srcmap)
trueBranch := SubstituteOptional(e.TrueBranch, mapping, srcmap)
falseBranch := SubstituteOptional(e.FalseBranch, mapping, srcmap)
// Construct appropriate if form
nexpr = &If{e.kind, condition, trueBranch, falseBranch}
case *Invoke:
args := SubstituteAll(e.args, mapping, srcmap)
nexpr = &Invoke{e.fn, e.signature, args}
case *List:
args := SubstituteAll(e.Args, mapping, srcmap)
nexpr = &List{args}
case *Mul:
args := SubstituteAll(e.Args, mapping, srcmap)
nexpr = &Mul{args}
case *Normalise:
arg := Substitute(e.Arg, mapping, srcmap)
nexpr = &Normalise{arg}
case *Reduce:
arg := Substitute(e.arg, mapping, srcmap)
nexpr = &Reduce{e.fn, e.signature, arg}
case *Sub:
args := SubstituteAll(e.Args, mapping, srcmap)
nexpr = &Sub{args}
case *Shift:
arg := Substitute(e.Arg, mapping, srcmap)
nexpr = &Shift{arg, e.Shift}
case *VariableAccess:
//
if b, ok1 := e.binding.(*LocalVariableBinding); !ok1 {
return e
} else if e2, ok2 := mapping[b.index]; !ok2 {
return e
} else {
return e2
}
default:
panic(fmt.Sprintf("unknown expression (%s)", reflect.TypeOf(expr)))
}
//
if srcmap != nil {
// Copy over source information
srcmap.Copy(expr, nexpr)
}
// Done
return nexpr
}

// SubstituteAll substitutes all variables found in a given set of
// expressions.
func SubstituteExpressions(exprs []Expr, mapping map[uint]Expr) []Expr {
func SubstituteAll(exprs []Expr, mapping map[uint]Expr, srcmap *sexp.SourceMaps[Node]) []Expr {
nexprs := make([]Expr, len(exprs))
//
for i := 0; i < len(nexprs); i++ {
nexprs[i] = exprs[i].Substitute(mapping)
nexprs[i] = Substitute(exprs[i], mapping, srcmap)
}
//
return nexprs
}

// SubstituteOptionalExpression substitutes through an expression which is
// SubstituteOptional substitutes through an expression which is
// optional (i.e. might be nil). In such case, nil is returned.
func SubstituteOptionalExpression(expr Expr, mapping map[uint]Expr) Expr {
func SubstituteOptional(expr Expr, mapping map[uint]Expr, srcmap *sexp.SourceMaps[Node]) Expr {
if expr != nil {
expr = expr.Substitute(mapping)
expr = Substitute(expr, mapping, srcmap)
}
//
return expr
Expand Down
Loading

0 comments on commit b46f334

Please sign in to comment.