Skip to content

Commit

Permalink
fix: failing wcp trace (#481)
Browse files Browse the repository at this point in the history
* Fix evaluation of constant expressions

This fixes a problem related to the evaluation of constant of
expressions.

* Fix for Expr.AsConstant()

There was an aliasing bug contained herein.
  • Loading branch information
DavePearce authored Dec 20, 2024
1 parent 2441ddd commit 01aa174
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 13 deletions.
21 changes: 12 additions & 9 deletions pkg/corset/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type Add struct{ Args []Expr }
// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *Add) AsConstant() *big.Int {
fn := func(l *big.Int, r *big.Int) *big.Int { l.Add(l, r); return l }
fn := func(l *big.Int, r *big.Int) { l.Add(l, r) }
return AsConstantOfExpressions(e.Args, fn)
}

Expand Down Expand Up @@ -566,7 +566,7 @@ type Mul struct{ Args []Expr }
// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *Mul) AsConstant() *big.Int {
fn := func(l *big.Int, r *big.Int) *big.Int { l.Mul(l, r); return l }
fn := func(l *big.Int, r *big.Int) { l.Mul(l, r) }
return AsConstantOfExpressions(e.Args, fn)
}

Expand Down Expand Up @@ -702,7 +702,7 @@ type Sub struct{ Args []Expr }
// AsConstant attempts to evaluate this expression as a constant (signed) value.
// If this expression is not constant, then nil is returned.
func (e *Sub) AsConstant() *big.Int {
fn := func(l *big.Int, r *big.Int) *big.Int { l.Sub(l, r); return l }
fn := func(l *big.Int, r *big.Int) { l.Sub(l, r) }
return AsConstantOfExpressions(e.Args, fn)
}

Expand Down Expand Up @@ -1046,19 +1046,22 @@ func ListOfExpressions(head sexp.SExp, exprs []Expr) *sexp.List {
// given operation (e.g. add, subtract, etc) to produce a constant value. If
// any of the expressions are not themselves constant, then neither is the
// result.
func AsConstantOfExpressions(exprs []Expr, fn func(*big.Int, *big.Int) *big.Int) *big.Int {
var val *big.Int = big.NewInt(0)
func AsConstantOfExpressions(exprs []Expr, fn func(*big.Int, *big.Int)) *big.Int {
var val big.Int
//
for _, arg := range exprs {
for i, arg := range exprs {
c := arg.AsConstant()
if c == nil {
return nil
} else if i == 0 {
// Must clone c
val.Set(c)
} else {
fn(&val, c)
}
// Evaluate function
val = fn(val, c)
}
//
return val
return &val
}

func determineMultiplicity(exprs []Expr) uint {
Expand Down
20 changes: 16 additions & 4 deletions pkg/test/valid_corset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ func Test_Constant_07(t *testing.T) {
Check(t, false, "constant_07")
}

func Test_Constant_08(t *testing.T) {
Check(t, false, "constant_08")
}

func Test_Constant_09(t *testing.T) {
Check(t, false, "constant_09")
}

// ===================================================================
// Alias Tests
// ===================================================================
Expand Down Expand Up @@ -336,6 +344,10 @@ func Test_If_10(t *testing.T) {
Check(t, false, "if_10")
}

func Test_If_11(t *testing.T) {
Check(t, false, "if_11")
}

// ===================================================================
// Guards
// ===================================================================
Expand Down Expand Up @@ -794,19 +806,19 @@ func Check(t *testing.T, stdlib bool, test string) {
t.Fatalf("Error parsing %s: %v\n", filename, errs)
}
// Check valid traces are accepted
accepts_file := fmt.Sprintf("%s.%s", test, "accepts")
accepts_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "accepts")
accepts := ReadTracesFile(accepts_file)
CheckTraces(t, accepts_file, true, true, accepts, schema)
// Check invalid traces are rejected
rejects_file := fmt.Sprintf("%s.%s", test, "rejects")
rejects_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "rejects")
rejects := ReadTracesFile(rejects_file)
CheckTraces(t, rejects_file, false, true, rejects, schema)
// Check expanded traces are rejected
expands_file := fmt.Sprintf("%s.%s", test, "expanded")
expands_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "expanded")
expands := ReadTracesFile(expands_file)
CheckTraces(t, expands_file, false, false, expands, schema)
// Check auto-generated valid traces (if applicable)
auto_accepts_file := fmt.Sprintf("%s.%s", test, "auto.accepts")
auto_accepts_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "auto.accepts")
if auto_accepts := ReadTracesFileIfExists(auto_accepts_file); auto_accepts != nil {
CheckTraces(t, auto_accepts_file, true, true, auto_accepts, schema)
}
Expand Down
1 change: 1 addition & 0 deletions testdata/constant_08.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"C": [15, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], "N": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "B": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], "L": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
17 changes: 17 additions & 0 deletions testdata/constant_08.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
(defpurefun ((eq! :@loob) x y) (- x y))

(defcolumns
(C :byte)
(L :binary)
(B :binary)
(N :binary))

;; opcode values
(defconst
LLARGE 16
LLARGEMO (- LLARGE 1))

(defconstraint bits-and-negs (:guard L)
(if (eq! C LLARGEMO)
(eq! N
(shift B (- 0 LLARGEMO)))))
1 change: 1 addition & 0 deletions testdata/constant_09.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"CT": [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], "NEG_1": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "BYTE_3": [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255], "BYTE_1": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "BITS": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], "IS_SLT": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "NEG_2": [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
21 changes: 21 additions & 0 deletions testdata/constant_09.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
(defpurefun ((eq! :@loob) x y) (- x y))
(defpurefun (if-eq x val then) (if (eq! x val) then))
;;
(defcolumns
(CT :byte)
(IS_SLT :binary@prove)
(BITS :binary@prove)
(NEG_1 :binary@prove)
(NEG_2 :binary@prove)
(BYTE_1 :byte@prove)
(BYTE_3 :byte@prove)
)

;; opcode values
(defconst
LLARGE 16
LLARGEMO (- LLARGE 1))

(defconstraint bits-and-negs (:guard IS_SLT)
(if-eq CT LLARGEMO
(eq! NEG_2 (shift BITS (- 0 7)))))
1 change: 1 addition & 0 deletions testdata/if_11.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"CT": [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], "BITS": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], "IS_SLT": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "NEG_2": [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
12 changes: 12 additions & 0 deletions testdata/if_11.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(defpurefun ((eq! :@loob) x y) (- x y))
(defpurefun (if-eq x val then) (if (eq! x val) then))
;;
(defcolumns
(CT :byte)
(IS_SLT :binary@prove)
(BITS :binary@prove)
(NEG_2 :binary@prove))

(defconstraint bits-and-negs (:guard IS_SLT)
(if-eq CT 15
(eq! NEG_2 (shift BITS (- 0 7)))))

0 comments on commit 01aa174

Please sign in to comment.