From 01aa174e9d2f186d0dc47f72e98ef182e9bbb1f6 Mon Sep 17 00:00:00 2001 From: David Pearce Date: Sat, 21 Dec 2024 10:23:28 +1300 Subject: [PATCH] fix: failing wcp trace (#481) * Fix evaluation of constant expressions This fixes a problem related to the evaluation of constant of expressions. * Fix for Expr.AsConstant() There was an aliasing bug contained herein. --- pkg/corset/expression.go | 21 ++++++++++++--------- pkg/test/valid_corset_test.go | 20 ++++++++++++++++---- testdata/constant_08.accepts | 1 + testdata/constant_08.lisp | 17 +++++++++++++++++ testdata/constant_09.accepts | 1 + testdata/constant_09.lisp | 21 +++++++++++++++++++++ testdata/if_11.accepts | 1 + testdata/if_11.lisp | 12 ++++++++++++ 8 files changed, 81 insertions(+), 13 deletions(-) create mode 100644 testdata/constant_08.accepts create mode 100644 testdata/constant_08.lisp create mode 100644 testdata/constant_09.accepts create mode 100644 testdata/constant_09.lisp create mode 100644 testdata/if_11.accepts create mode 100644 testdata/if_11.lisp diff --git a/pkg/corset/expression.go b/pkg/corset/expression.go index 5575a41..820c202 100644 --- a/pkg/corset/expression.go +++ b/pkg/corset/expression.go @@ -48,7 +48,7 @@ type Add struct{ Args []Expr } // AsConstant attempts to evaluate this expression as a constant (signed) value. // If this expression is not constant, then nil is returned. func (e *Add) AsConstant() *big.Int { - fn := func(l *big.Int, r *big.Int) *big.Int { l.Add(l, r); return l } + fn := func(l *big.Int, r *big.Int) { l.Add(l, r) } return AsConstantOfExpressions(e.Args, fn) } @@ -566,7 +566,7 @@ type Mul struct{ Args []Expr } // AsConstant attempts to evaluate this expression as a constant (signed) value. // If this expression is not constant, then nil is returned. func (e *Mul) AsConstant() *big.Int { - fn := func(l *big.Int, r *big.Int) *big.Int { l.Mul(l, r); return l } + fn := func(l *big.Int, r *big.Int) { l.Mul(l, r) } return AsConstantOfExpressions(e.Args, fn) } @@ -702,7 +702,7 @@ type Sub struct{ Args []Expr } // AsConstant attempts to evaluate this expression as a constant (signed) value. // If this expression is not constant, then nil is returned. func (e *Sub) AsConstant() *big.Int { - fn := func(l *big.Int, r *big.Int) *big.Int { l.Sub(l, r); return l } + fn := func(l *big.Int, r *big.Int) { l.Sub(l, r) } return AsConstantOfExpressions(e.Args, fn) } @@ -1046,19 +1046,22 @@ func ListOfExpressions(head sexp.SExp, exprs []Expr) *sexp.List { // given operation (e.g. add, subtract, etc) to produce a constant value. If // any of the expressions are not themselves constant, then neither is the // result. -func AsConstantOfExpressions(exprs []Expr, fn func(*big.Int, *big.Int) *big.Int) *big.Int { - var val *big.Int = big.NewInt(0) +func AsConstantOfExpressions(exprs []Expr, fn func(*big.Int, *big.Int)) *big.Int { + var val big.Int // - for _, arg := range exprs { + for i, arg := range exprs { c := arg.AsConstant() if c == nil { return nil + } else if i == 0 { + // Must clone c + val.Set(c) + } else { + fn(&val, c) } - // Evaluate function - val = fn(val, c) } // - return val + return &val } func determineMultiplicity(exprs []Expr) uint { diff --git a/pkg/test/valid_corset_test.go b/pkg/test/valid_corset_test.go index a4d352a..8c964cc 100644 --- a/pkg/test/valid_corset_test.go +++ b/pkg/test/valid_corset_test.go @@ -95,6 +95,14 @@ func Test_Constant_07(t *testing.T) { Check(t, false, "constant_07") } +func Test_Constant_08(t *testing.T) { + Check(t, false, "constant_08") +} + +func Test_Constant_09(t *testing.T) { + Check(t, false, "constant_09") +} + // =================================================================== // Alias Tests // =================================================================== @@ -336,6 +344,10 @@ func Test_If_10(t *testing.T) { Check(t, false, "if_10") } +func Test_If_11(t *testing.T) { + Check(t, false, "if_11") +} + // =================================================================== // Guards // =================================================================== @@ -794,19 +806,19 @@ func Check(t *testing.T, stdlib bool, test string) { t.Fatalf("Error parsing %s: %v\n", filename, errs) } // Check valid traces are accepted - accepts_file := fmt.Sprintf("%s.%s", test, "accepts") + accepts_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "accepts") accepts := ReadTracesFile(accepts_file) CheckTraces(t, accepts_file, true, true, accepts, schema) // Check invalid traces are rejected - rejects_file := fmt.Sprintf("%s.%s", test, "rejects") + rejects_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "rejects") rejects := ReadTracesFile(rejects_file) CheckTraces(t, rejects_file, false, true, rejects, schema) // Check expanded traces are rejected - expands_file := fmt.Sprintf("%s.%s", test, "expanded") + expands_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "expanded") expands := ReadTracesFile(expands_file) CheckTraces(t, expands_file, false, false, expands, schema) // Check auto-generated valid traces (if applicable) - auto_accepts_file := fmt.Sprintf("%s.%s", test, "auto.accepts") + auto_accepts_file := fmt.Sprintf("%s/%s.%s", TestDir, test, "auto.accepts") if auto_accepts := ReadTracesFileIfExists(auto_accepts_file); auto_accepts != nil { CheckTraces(t, auto_accepts_file, true, true, auto_accepts, schema) } diff --git a/testdata/constant_08.accepts b/testdata/constant_08.accepts new file mode 100644 index 0000000..f90e1b0 --- /dev/null +++ b/testdata/constant_08.accepts @@ -0,0 +1 @@ +{"C": [15, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], "N": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "B": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], "L": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} diff --git a/testdata/constant_08.lisp b/testdata/constant_08.lisp new file mode 100644 index 0000000..68f903d --- /dev/null +++ b/testdata/constant_08.lisp @@ -0,0 +1,17 @@ +(defpurefun ((eq! :@loob) x y) (- x y)) + +(defcolumns + (C :byte) + (L :binary) + (B :binary) + (N :binary)) + +;; opcode values +(defconst + LLARGE 16 + LLARGEMO (- LLARGE 1)) + +(defconstraint bits-and-negs (:guard L) + (if (eq! C LLARGEMO) + (eq! N + (shift B (- 0 LLARGEMO))))) diff --git a/testdata/constant_09.accepts b/testdata/constant_09.accepts new file mode 100644 index 0000000..93de8cd --- /dev/null +++ b/testdata/constant_09.accepts @@ -0,0 +1 @@ +{"CT": [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], "NEG_1": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "BYTE_3": [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255], "BYTE_1": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "BITS": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], "IS_SLT": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "NEG_2": [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} diff --git a/testdata/constant_09.lisp b/testdata/constant_09.lisp new file mode 100644 index 0000000..0b4a2a2 --- /dev/null +++ b/testdata/constant_09.lisp @@ -0,0 +1,21 @@ +(defpurefun ((eq! :@loob) x y) (- x y)) +(defpurefun (if-eq x val then) (if (eq! x val) then)) +;; +(defcolumns + (CT :byte) + (IS_SLT :binary@prove) + (BITS :binary@prove) + (NEG_1 :binary@prove) + (NEG_2 :binary@prove) + (BYTE_1 :byte@prove) + (BYTE_3 :byte@prove) + ) + +;; opcode values +(defconst + LLARGE 16 + LLARGEMO (- LLARGE 1)) + +(defconstraint bits-and-negs (:guard IS_SLT) + (if-eq CT LLARGEMO + (eq! NEG_2 (shift BITS (- 0 7))))) diff --git a/testdata/if_11.accepts b/testdata/if_11.accepts new file mode 100644 index 0000000..e593dd0 --- /dev/null +++ b/testdata/if_11.accepts @@ -0,0 +1 @@ +{"CT": [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], "BITS": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], "IS_SLT": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "NEG_2": [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} diff --git a/testdata/if_11.lisp b/testdata/if_11.lisp new file mode 100644 index 0000000..73c8919 --- /dev/null +++ b/testdata/if_11.lisp @@ -0,0 +1,12 @@ +(defpurefun ((eq! :@loob) x y) (- x y)) +(defpurefun (if-eq x val then) (if (eq! x val) then)) +;; +(defcolumns + (CT :byte) + (IS_SLT :binary@prove) + (BITS :binary@prove) + (NEG_2 :binary@prove)) + +(defconstraint bits-and-negs (:guard IS_SLT) + (if-eq CT 15 + (eq! NEG_2 (shift BITS (- 0 7)))))