From 972cc85c88921fbd839eb38e1ac5927323dc77fe Mon Sep 17 00:00:00 2001 From: Bogdan Ursu Date: Fri, 20 Dec 2024 17:26:26 +0100 Subject: [PATCH] =?UTF-8?q?Prover/feat/small=20field=20exploratory?= =?UTF-8?q?=E2=80=94Smartvectors=20that=20support=20both=20base=20field=20?= =?UTF-8?q?elements=20and=20field=20extensions.=20=20(#469)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * degree-2 extension poc * refactor * refactor * templating smartvectors * smartvectors wip * helper functions, wip * utilfunctions wip * remove templating from RandVect * REVERT templating in smartvectors * api changes in existing smartvector types * filling gaps in the base API * intermediary Get function * bug fixes * mempool for fext * fext vectors, wip * code for ext types of smartvector * bug fix in isZero * arithmetic_test for extended smartvectors, wip * rotated ext test * window ext test * setInt and setUint touch only one coordinate * removed field amd files * another intemediary function * linter errors, wip * prunning fext, attempting to fix linter errors * cosmetic change * Update vectorext.go * changing build flags * build flags + removed fieldElement interface --------- Co-authored-by: AlexandreBelling --- prover/maths/common/mempoolext/debug_pool.go | 146 +++ .../common/mempoolext/debug_pool_test.go | 51 ++ .../maths/common/mempoolext/from_sync_pool.go | 72 ++ prover/maths/common/mempoolext/mempool.go | 60 ++ prover/maths/common/mempoolext/slice_arena.go | 51 ++ prover/maths/common/polyext/poly.go | 124 +++ .../common/smartvectors/arithmetic_gen.go | 14 +- .../common/smartvectors/circular_interval.go | 148 +-- .../smartvectors/circular_interval_test.go | 610 ++++++------ prover/maths/common/smartvectors/constant.go | 42 +- prover/maths/common/smartvectors/fft.go | 4 +- prover/maths/common/smartvectors/fuzzing.go | 12 +- .../maths/common/smartvectors/polynomial.go | 2 +- prover/maths/common/smartvectors/regular.go | 42 +- prover/maths/common/smartvectors/rotated.go | 52 +- .../common/smartvectors/smartvector_test.go | 2 +- .../maths/common/smartvectors/smartvectors.go | 70 +- .../common/smartvectors/vectorext/vector.go | 334 +++++++ .../smartvectors/vectorext/vector_ops.go | 58 ++ .../smartvectors/vectorext/vectorext.go | 258 ++++++ prover/maths/common/smartvectors/windowed.go | 98 +- .../smartvectorsext/arithmetic_basic_ext.go | 296 ++++++ .../smartvectorsext/arithmetic_basic_test.go | 125 +++ .../smartvectorsext/arithmetic_ext_test.go | 337 +++++++ .../common/smartvectorsext/arithmetic_gen.go | 175 ++++ .../common/smartvectorsext/arithmetic_op.go | 304 ++++++ .../common/smartvectorsext/constant_ext.go | 97 ++ .../common/smartvectorsext/fuzzing_ext.go | 335 +++++++ .../common/smartvectorsext/fuzzing_heavy.go | 5 + .../common/smartvectorsext/regular_ext.go | 199 ++++ .../common/smartvectorsext/rotated_ext.go | 210 +++++ .../smartvectorsext/rotated_ext_test.go | 96 ++ .../common/smartvectorsext/smartvectors_op.go | 158 ++++ .../common/smartvectorsext/temp_parameters.go | 13 + .../maths/common/smartvectorsext/vecutil.go | 39 + .../common/smartvectorsext/windowed_ext.go | 351 +++++++ .../smartvectorsext/windowed_ext_test.go | 67 ++ prover/maths/field/fext/additional.go | 30 + prover/maths/field/fext/e12.go | 866 ++++++++++++++++++ prover/maths/field/fext/e12_pairing.go | 107 +++ prover/maths/field/fext/e12_test.go | 569 ++++++++++++ prover/maths/field/fext/e2.go | 305 ++++++ prover/maths/field/fext/e2_bls377.go | 117 +++ prover/maths/field/fext/e2_fallback.go | 37 + prover/maths/field/fext/e2_fallback_new.go | 37 + prover/maths/field/fext/e2_test.go | 531 +++++++++++ prover/maths/field/fext/e2new.go | 317 +++++++ prover/maths/field/fext/e2new_bls377.go | 110 +++ prover/maths/field/fext/e2new_test.go | 600 ++++++++++++ prover/maths/field/fext/e6.go | 343 +++++++ prover/maths/field/fext/e6_test.go | 363 ++++++++ prover/maths/field/fext/element.go | 353 +++++++ prover/maths/field/fext/frobenius.go | 227 +++++ prover/maths/field/fext/generators_test.go | 50 + prover/maths/field/fext/parameters.go | 33 + prover/maths/field/fext/temp_functionality.go | 11 + prover/maths/field/fext/unexportedFr.go | 107 +++ 57 files changed, 9729 insertions(+), 441 deletions(-) create mode 100644 prover/maths/common/mempoolext/debug_pool.go create mode 100644 prover/maths/common/mempoolext/debug_pool_test.go create mode 100644 prover/maths/common/mempoolext/from_sync_pool.go create mode 100644 prover/maths/common/mempoolext/mempool.go create mode 100644 prover/maths/common/mempoolext/slice_arena.go create mode 100644 prover/maths/common/polyext/poly.go create mode 100644 prover/maths/common/smartvectors/vectorext/vector.go create mode 100644 prover/maths/common/smartvectors/vectorext/vector_ops.go create mode 100644 prover/maths/common/smartvectors/vectorext/vectorext.go create mode 100644 prover/maths/common/smartvectorsext/arithmetic_basic_ext.go create mode 100644 prover/maths/common/smartvectorsext/arithmetic_basic_test.go create mode 100644 prover/maths/common/smartvectorsext/arithmetic_ext_test.go create mode 100644 prover/maths/common/smartvectorsext/arithmetic_gen.go create mode 100644 prover/maths/common/smartvectorsext/arithmetic_op.go create mode 100644 prover/maths/common/smartvectorsext/constant_ext.go create mode 100644 prover/maths/common/smartvectorsext/fuzzing_ext.go create mode 100644 prover/maths/common/smartvectorsext/fuzzing_heavy.go create mode 100644 prover/maths/common/smartvectorsext/regular_ext.go create mode 100644 prover/maths/common/smartvectorsext/rotated_ext.go create mode 100644 prover/maths/common/smartvectorsext/rotated_ext_test.go create mode 100644 prover/maths/common/smartvectorsext/smartvectors_op.go create mode 100644 prover/maths/common/smartvectorsext/temp_parameters.go create mode 100644 prover/maths/common/smartvectorsext/vecutil.go create mode 100644 prover/maths/common/smartvectorsext/windowed_ext.go create mode 100644 prover/maths/common/smartvectorsext/windowed_ext_test.go create mode 100644 prover/maths/field/fext/additional.go create mode 100644 prover/maths/field/fext/e12.go create mode 100644 prover/maths/field/fext/e12_pairing.go create mode 100644 prover/maths/field/fext/e12_test.go create mode 100644 prover/maths/field/fext/e2.go create mode 100644 prover/maths/field/fext/e2_bls377.go create mode 100644 prover/maths/field/fext/e2_fallback.go create mode 100644 prover/maths/field/fext/e2_fallback_new.go create mode 100644 prover/maths/field/fext/e2_test.go create mode 100644 prover/maths/field/fext/e2new.go create mode 100644 prover/maths/field/fext/e2new_bls377.go create mode 100644 prover/maths/field/fext/e2new_test.go create mode 100644 prover/maths/field/fext/e6.go create mode 100644 prover/maths/field/fext/e6_test.go create mode 100644 prover/maths/field/fext/element.go create mode 100644 prover/maths/field/fext/frobenius.go create mode 100644 prover/maths/field/fext/generators_test.go create mode 100644 prover/maths/field/fext/parameters.go create mode 100644 prover/maths/field/fext/temp_functionality.go create mode 100644 prover/maths/field/fext/unexportedFr.go diff --git a/prover/maths/common/mempoolext/debug_pool.go b/prover/maths/common/mempoolext/debug_pool.go new file mode 100644 index 000000000..2e38f6fe1 --- /dev/null +++ b/prover/maths/common/mempoolext/debug_pool.go @@ -0,0 +1,146 @@ +package mempoolext + +import ( + "errors" + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "runtime" + "strconv" + "unsafe" + + "github.com/consensys/linea-monorepo/prover/utils" +) + +type DebuggeableCall struct { + Parent MemPool + Logs map[uintptr]*[]Record +} + +func NewDebugPool(p MemPool) *DebuggeableCall { + return &DebuggeableCall{ + Parent: p, + Logs: make(map[uintptr]*[]Record), + } +} + +type Record struct { + Where string + What recordType +} + +func (m *DebuggeableCall) Prewarm(nbPrewarm int) MemPool { + m.Parent.Prewarm(nbPrewarm) + return m +} + +type recordType string + +const ( + AllocRecord recordType = "alloc" + FreeRecord recordType = "free" +) + +func (m *DebuggeableCall) Alloc() *[]fext.Element { + + var ( + v = m.Parent.Alloc() + uptr = uintptr(unsafe.Pointer(v)) + logs *[]Record + _, file, line, _ = runtime.Caller(2) + ) + + logs, found := m.Logs[uptr] + + if !found { + logs = &[]Record{} + m.Logs[uptr] = logs + } + + *logs = append(*logs, Record{ + Where: file + ":" + strconv.Itoa(line), + What: AllocRecord, + }) + + return v +} + +func (m *DebuggeableCall) Free(v *[]fext.Element) error { + + var ( + uptr = uintptr(unsafe.Pointer(v)) + logs *[]Record + _, file, line, _ = runtime.Caller(2) + ) + + logs, found := m.Logs[uptr] + + if !found { + logs = &[]Record{} + m.Logs[uptr] = logs + } + + *logs = append(*logs, Record{ + Where: file + ":" + strconv.Itoa(line), + What: FreeRecord, + }) + + return m.Parent.Free(v) +} + +func (m *DebuggeableCall) Size() int { + return m.Parent.Size() +} + +func (m *DebuggeableCall) TearDown() { + if p, ok := m.Parent.(*SliceArena); ok { + p.TearDown() + } +} + +func (m *DebuggeableCall) Errors() error { + + var err error + + for _, logs_ := range m.Logs { + + if logs_ == nil || len(*logs_) == 0 { + utils.Panic("got a nil entry") + } + + logs := *logs_ + + for i := range logs { + if i == 0 && logs[i].What == FreeRecord { + err = errors.Join(err, fmt.Errorf("freed a vector that was not from the pool: where=%v", logs[i].Where)) + } + + if i == len(logs)-1 && logs[i].What == AllocRecord { + err = errors.Join(err, fmt.Errorf("leaked a vector out of the pool: where=%v", logs[i].Where)) + } + + if i == 0 { + continue + } + + if logs[i-1].What == AllocRecord && logs[i].What == AllocRecord { + wheres := []string{logs[i-1].Where, logs[i].Where} + for k := i + 1; k < len(logs) && logs[k].What == AllocRecord; k++ { + wheres = append(wheres, logs[k].Where) + } + + err = errors.Join(err, fmt.Errorf("vector was allocated multiple times concurrently where=%v", wheres)) + } + + if logs[i-1].What == FreeRecord && logs[i].What == FreeRecord { + wheres := []string{logs[i-1].Where, logs[i].Where} + for k := i + 1; k < len(logs) && logs[k].What == FreeRecord; k++ { + wheres = append(wheres, logs[k].Where) + } + + err = errors.Join(err, fmt.Errorf("vector was freed multiple times concurrently where=%v", wheres)) + } + } + } + + return err +} diff --git a/prover/maths/common/mempoolext/debug_pool_test.go b/prover/maths/common/mempoolext/debug_pool_test.go new file mode 100644 index 000000000..7767396f6 --- /dev/null +++ b/prover/maths/common/mempoolext/debug_pool_test.go @@ -0,0 +1,51 @@ +package mempoolext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/stretchr/testify/assert" + "strings" + "testing" +) + +func TestDebugPool(t *testing.T) { + + t.Run("leak-detection", func(t *testing.T) { + + pool := NewDebugPool(CreateFromSyncPool(32)) + + for i := 0; i < 16; i++ { + func() { + _ = pool.Alloc() + }() + } + + err := pool.Errors().Error() + assert.True(t, strings.HasPrefix(err, "leaked a vector out of the pool")) + }) + + t.Run("double-free", func(t *testing.T) { + + pool := NewDebugPool(CreateFromSyncPool(32)) + + v := pool.Alloc() + + for i := 0; i < 16; i++ { + pool.Free(v) + } + + err := pool.Errors().Error() + assert.Truef(t, strings.HasPrefix(err, "vector was freed multiple times concurrently"), err) + }) + + t.Run("foreign-free", func(t *testing.T) { + + pool := NewDebugPool(CreateFromSyncPool(32)) + + v := make([]fext.Element, 32) + pool.Free(&v) + + err := pool.Errors().Error() + assert.Truef(t, strings.HasPrefix(err, "freed a vector that was not from the pool"), err) + }) + +} diff --git a/prover/maths/common/mempoolext/from_sync_pool.go b/prover/maths/common/mempoolext/from_sync_pool.go new file mode 100644 index 000000000..f6463eda3 --- /dev/null +++ b/prover/maths/common/mempoolext/from_sync_pool.go @@ -0,0 +1,72 @@ +package mempoolext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/consensys/linea-monorepo/prover/utils/parallel" + "sync" +) + +// FromSyncPool pools the allocation for slices of [fext.Element] of size `Size`. +// It should be used with great caution and every slice allocated via this pool +// must be manually freed and only once. +// +// FromSyncPool is used to reduce the number of allocation which can be significant +// when doing operations over field elements. +type FromSyncPool struct { + size int + P sync.Pool +} + +// CreateFromSyncPool initializes the Pool with the given number of elements in it. +func CreateFromSyncPool(size int) *FromSyncPool { + // Initializes the pool + return &FromSyncPool{ + size: size, + P: sync.Pool{ + New: func() any { + res := make([]fext.Element, size) + return &res + }, + }, + } +} + +// Prewarm the Pool by preallocating `nbPrewarm` in it. +func (p *FromSyncPool) Prewarm(nbPrewarm int) MemPool { + prewarmed := make([]fext.Element, p.size*nbPrewarm) + parallel.Execute(nbPrewarm, func(start, stop int) { + for i := start; i < stop; i++ { + vec := prewarmed[i*p.size : (i+1)*p.size] + p.P.Put(&vec) + } + }) + return p +} + +// Alloc returns a vector allocated from the pool. Vector allocated via the +// pool should ideally be returned to the pool. If not, they are still going to +// be picked up by the GC. +func (p *FromSyncPool) Alloc() *[]fext.Element { + res := p.P.Get().(*[]fext.Element) + return res +} + +// Free returns an object to the pool. It must never be called twice over +// the same object or undefined behaviours are going to arise. It is fine to +// pass objects allocated to outside of the pool as long as they have the right +// dimension. +func (p *FromSyncPool) Free(vec *[]fext.Element) error { + // Check the vector has the right size + if len(*vec) != p.size { + utils.Panic("expected size %v, expected %v", len(*vec), p.Size()) + } + + p.P.Put(vec) + + return nil +} + +func (p *FromSyncPool) Size() int { + return p.size +} diff --git a/prover/maths/common/mempoolext/mempool.go b/prover/maths/common/mempoolext/mempool.go new file mode 100644 index 000000000..74c1af2dd --- /dev/null +++ b/prover/maths/common/mempoolext/mempool.go @@ -0,0 +1,60 @@ +package mempoolext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" +) + +type MemPool interface { + Prewarm(nbPrewarm int) MemPool + Alloc() *[]fext.Element + Free(vec *[]fext.Element) error + Size() int +} + +// ExtractCheckOptionalStrict returns +// - p[0], true if the expectedSize matches the one of the provided pool +// - nil, false if no `p` is provided +// - panic if the assigned size of the pool does not match +// - panic if the caller provides `nil` as argument for `p` +// +// This is used to unwrap a [FromSyncPool] that is commonly passed to functions as an +// optional variadic parameter and at the same time validating that the pool +// object has the right size. +func ExtractCheckOptionalStrict(expectedSize int, p ...MemPool) (pool MemPool, ok bool) { + // Checks if there is a pool + hasPool := len(p) > 0 && p[0] != nil + if hasPool { + pool = p[0] + } + + // Sanity-check that the size of the pool is actually what we expected + if hasPool && pool.Size() != expectedSize { + utils.Panic("pooled vector size are %v, but required %v", pool.Size(), expectedSize) + } + + return pool, hasPool +} + +// ExtractCheckOptionalSoft returns +// - p[0], true if the expectedSize matches the one of the provided pool +// - nil, false if no `p` is provided +// - nil, false if the length of the vector does not match the one of the pool +// - panic if the caller provides `nil` as argument for `p` +// +// This is used to unwrap a [FromSyncPool] that is commonly passed to functions as an +// optional variadic parameter. +func ExtractCheckOptionalSoft(expectedSize int, p ...MemPool) (pool MemPool, ok bool) { + // Checks if there is a pool + hasPool := len(p) > 0 + if hasPool { + pool = p[0] + } + + // Sanity-check that the size of the pool is actually what we expected + if hasPool && pool.Size() != expectedSize { + return nil, false + } + + return pool, hasPool +} diff --git a/prover/maths/common/mempoolext/slice_arena.go b/prover/maths/common/mempoolext/slice_arena.go new file mode 100644 index 000000000..1741f99ec --- /dev/null +++ b/prover/maths/common/mempoolext/slice_arena.go @@ -0,0 +1,51 @@ +package mempoolext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/field/fext" +) + +// SliceArena is a simple not-threadsafe arena implementation that uses a +// mempool to carry its allocation. It will only put back free memory in the +// the parent pool when TearDown is called. +type SliceArena struct { + frees []*[]fext.Element + parent MemPool +} + +func WrapsWithMemCache(pool MemPool) *SliceArena { + return &SliceArena{ + frees: make([]*[]fext.Element, 0, 1<<7), + parent: pool, + } +} + +func (m *SliceArena) Prewarm(nbPrewarm int) MemPool { + m.parent.Prewarm(nbPrewarm) + return m +} + +func (m *SliceArena) Alloc() *[]fext.Element { + + if len(m.frees) == 0 { + return m.parent.Alloc() + } + + last := m.frees[len(m.frees)-1] + m.frees = m.frees[:len(m.frees)-1] + return last +} + +func (m *SliceArena) Free(v *[]fext.Element) error { + m.frees = append(m.frees, v) + return nil +} + +func (m *SliceArena) Size() int { + return m.parent.Size() +} + +func (m *SliceArena) TearDown() { + for i := range m.frees { + m.parent.Free(m.frees[i]) + } +} diff --git a/prover/maths/common/polyext/poly.go b/prover/maths/common/polyext/poly.go new file mode 100644 index 000000000..dd2e5b297 --- /dev/null +++ b/prover/maths/common/polyext/poly.go @@ -0,0 +1,124 @@ +package polyext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// EvalUnivariate evaluates a univariate polynomial `pol` given as a vector of +// coefficients. Coefficients are for increasing degree monomials: meaning that +// pol[0] is the constant term and pol[len(pol) - 1] is the highest degree term. +// The evaluation is done using the Horner method. +// +// If the empty slice is provided, it is understood as the zero polynomial and +// the function returns zero. +func EvalUnivariate(pol []fext.Element, x fext.Element) fext.Element { + var res fext.Element + for i := len(pol) - 1; i >= 0; i-- { + res.Mul(&res, &x) + res.Add(&res, &pol[i]) + } + return res +} + +// Mul multiplies two polynomials expressed by their coefficients using the +// naive method and writes the result in res. `a` and `b` may have distinct +// degrees and the result is returned in a slice of size len(a) + len(b) - 1. +// +// The algorithm is the schoolbook algorithm and runs in O(n^2). The usage of +// this usage should be reserved to reasonnably small polynomial. Otherwise, +// FFT methods should be preferred to this end. +// +// The empty slice is understood as the zero polynomial. If provided on either +// side the function returns []fext.Element{} +func Mul(a, b []fext.Element) (res []fext.Element) { + + if len(a) == 0 || len(b) == 0 { + return []fext.Element{} + } + + res = make([]fext.Element, len(a)+len(b)-1) + + for i := 0; i < len(a); i++ { + for j := 0; j < len(b); j++ { + var tmp fext.Element + tmp.Mul(&a[i], &b[j]) + res[i+j].Add(&res[i+j], &tmp) + } + } + + return res +} + +// Add adds two polynomials in coefficient form of possibly distinct degree. +// The returned slice has length = max(len(a), len(b)). +// The empty slice is understood as the zero polynomial and if both a and b are +// empty, the function returns the empty slice. +func Add(a, b []fext.Element) (res []fext.Element) { + + res = make([]fext.Element, utils.Max(len(a), len(b))) + copy(res, a) + for i := range b { + res[i].Add(&res[i], &b[i]) + } + + return res +} + +// ScalarMul multiplies a polynomials in coefficient form by a scalar. +func ScalarMul(p []fext.Element, x fext.Element) (res []fext.Element) { + res = make([]fext.Element, len(p)) + vectorext.ScalarMul(res, p, x) + return res +} + +// EvaluateLagrangesAnyDomain evaluates all the Lagrange polynomials for a +// custom domain defined as the point point x. The function implements the naive +// schoolbook algorithm and is only relevant for small domains. +// +// The function panics if provided an empty domain. +func EvaluateLagrangesAnyDomain(domain []fext.Element, x fext.Element) []fext.Element { + + if len(domain) == 0 { + utils.Panic("got provided an empty domain") + } + + lagrange := make([]fext.Element, len(domain)) + + for i := range domain { + // allocate outside of the loop to avoid memory aliasing in for loop + // (gosec G601) + hi := domain[i] + + lhix := fext.One() + for j := range domain { + // allocate outside of the loop to avoid memory aliasing in for loop + // (gosec G601) + hj := domain[j] + + if i == j { + // Skip it + continue + } + + // Otherwise, it would divide by zeri + if hi == hj { + utils.Panic("the domain contained a duplicate %v (at %v and %v)", hi.String(), i, j) + } + + // more convenient to store -h instead of h + hj.Neg(&hj) + factor := x + factor.Add(&factor, &hj) + hj.Add(&hi, &hj) // so x - h + hj.Inverse(&hj) + factor.Mul(&factor, &hj) + + lhix.Mul(&lhix, &factor) + } + lagrange[i] = lhix + } + + return lagrange +} diff --git a/prover/maths/common/smartvectors/arithmetic_gen.go b/prover/maths/common/smartvectors/arithmetic_gen.go index ef1b34cab..f5dccbc3c 100644 --- a/prover/maths/common/smartvectors/arithmetic_gen.go +++ b/prover/maths/common/smartvectors/arithmetic_gen.go @@ -134,17 +134,17 @@ func processOperator(op operator, coeffs []int, svecs []SmartVector, p ...mempoo length := len(regvec) // The windows rolls over - if interval.doesWrapAround() { - op.vecTermIntoVec(regvec[:interval.stop()], windowRes.window[length-interval.start():]) - op.vecTermIntoVec(regvec[interval.start():], windowRes.window[:length-interval.start()]) - op.constTermIntoVec(regvec[interval.stop():interval.start()], &windowRes.paddingVal) + if interval.DoesWrapAround() { + op.vecTermIntoVec(regvec[:interval.Stop()], windowRes.window[length-interval.Start():]) + op.vecTermIntoVec(regvec[interval.Start():], windowRes.window[:length-interval.Start()]) + op.constTermIntoVec(regvec[interval.Stop():interval.Start()], &windowRes.paddingVal) return regularRes } // Else, no roll-over - op.vecTermIntoVec(regvec[interval.start():interval.stop()], windowRes.window) - op.constTermIntoVec(regvec[:interval.start()], &windowRes.paddingVal) - op.constTermIntoVec(regvec[interval.stop():], &windowRes.paddingVal) + op.vecTermIntoVec(regvec[interval.Start():interval.Stop()], windowRes.window) + op.constTermIntoVec(regvec[:interval.Start()], &windowRes.paddingVal) + op.constTermIntoVec(regvec[interval.Stop():], &windowRes.paddingVal) return regularRes } } diff --git a/prover/maths/common/smartvectors/circular_interval.go b/prover/maths/common/smartvectors/circular_interval.go index e34a20e85..2ce6224c1 100644 --- a/prover/maths/common/smartvectors/circular_interval.go +++ b/prover/maths/common/smartvectors/circular_interval.go @@ -6,14 +6,14 @@ import ( "github.com/consensys/linea-monorepo/prover/utils" ) -// circularInterval represents an interval over a discretized circle. The +// CircularInterval represents an interval over a discretized circle. The // discretized circle is assumed to be equipped with an origin point; thus // allowing to set a unique coordinate for each point. // // - The intervals are "cardinal": meaning that the largest possible interval // is the full-circuit // - The empty interval is considered as invalid and should never be constructed -type circularInterval struct { +type CircularInterval struct { // circleSize is the size of the circle circleSize int // istart is the starting point of the interval (included in the interval). @@ -21,89 +21,89 @@ type circularInterval struct { // istart must always be within the bound of the circle (can't be negative // or be larger or equal to `circleSize`. istart int - // intervalLen is length of the interval. Meaning the number of points in + // IntervalLen is length of the interval. Meaning the number of points in // the interval - intervalLen int + IntervalLen int } -// ivalWithFullLen returns an interval representing the full-circle. -func ivalWithFullLen(n int) circularInterval { +// IvalWithFullLen returns an interval representing the full-circle. +func IvalWithFullLen(n int) CircularInterval { if n <= 0 { panic("zero or negative length interval is not allowed") } - return circularInterval{ + return CircularInterval{ istart: 0, - intervalLen: n, + IntervalLen: n, circleSize: n, } } -// ivalWithStartLen constructs an interval by passing the start, the len and n +// IvalWithStartLen constructs an interval by passing the Start, the len and n // being the size of the circle. -func ivalWithStartLen(start, len, n int) circularInterval { +func IvalWithStartLen(start, len, n int) CircularInterval { // empty length is forbidden if len == 0 { panic("empty interval") } - // ensures that start is within bounds + // ensures that Start is within bounds if 0 > start || start >= n { - panic("start out of bounds") + panic("Start out of bounds") } // full length is forbidden if len >= n { panic("full length is forbidden") } - return circularInterval{ + return CircularInterval{ circleSize: n, istart: start, - intervalLen: len, + IntervalLen: len, } } -// ivalWithStartStop constructs a [circularInterval] by using its starting and +// IvalWithStartStop constructs a [CircularInterval] by using its starting and // stopping points. -func ivalWithStartStop(start, stop, n int) circularInterval { +func IvalWithStartStop(start, stop, n int) CircularInterval { // empty interval is forbidden if start == stop { panic("empty interval") } - // ensures that start is within bounds + // ensures that Start is within bounds if 0 > start || start >= n { - panic("start out of bounds") + panic("Start out of bounds") } // full length is forbidden if 0 > stop || stop >= n { - panic("stop out of bound") + panic("Stop out of bound") } - return circularInterval{ + return CircularInterval{ circleSize: n, istart: start, - intervalLen: utils.PositiveMod(stop-start, n), + IntervalLen: utils.PositiveMod(stop-start, n), } } // Start returns the starting point (included) of the interval -func (c circularInterval) start() int { +func (c CircularInterval) Start() int { return c.istart } // Stop returns the stopping point (excluded) of the interval of the interval -func (c circularInterval) stop() int { - return utils.PositiveMod(c.istart+c.intervalLen, c.circleSize) +func (c CircularInterval) Stop() int { + return utils.PositiveMod(c.istart+c.IntervalLen, c.circleSize) } -// doesWrapAround returns true iff the interval rolls over -func (c circularInterval) doesWrapAround() bool { - return c.stop() < c.start() +// DoesWrapAround returns true iff the interval rolls over +func (c CircularInterval) DoesWrapAround() bool { + return c.Stop() < c.Start() } -// isFullCircle returns true of the interval is the full circle -func (c circularInterval) isFullCircle() bool { - return c.intervalLen == c.circleSize +// IsFullCircle returns true of the interval is the full circle +func (c CircularInterval) IsFullCircle() bool { + return c.IntervalLen == c.circleSize } // Returns true iff `p` is included in the receiver interval -func (c circularInterval) doesInclude(p int) bool { +func (c CircularInterval) DoesInclude(p int) bool { // forbidden : the point does not belong on the circle if p < 0 || p > c.circleSize { @@ -111,52 +111,52 @@ func (c circularInterval) doesInclude(p int) bool { } // edge-case - if c.isFullCircle() { + if c.IsFullCircle() { return true } // if the interval wraps around the origin point - if c.doesWrapAround() { - return p < c.stop() || p >= c.start() + if c.DoesWrapAround() { + return p < c.Stop() || p >= c.Start() } // "normal" case - return p >= c.start() && p < c.stop() + return p >= c.Start() && p < c.Stop() } -// doesFullyContain returns true if `c` fully contains `other` -func (c circularInterval) doesFullyContain(other circularInterval) bool { +// DoesFullyContain returns true if `c` fully contains `other` +func (c CircularInterval) DoesFullyContain(other CircularInterval) bool { // edge case : c is the complete circle - if c.isFullCircle() { + if c.IsFullCircle() { return true } // edge case : c is not the complete circle but other is - if !c.isFullCircle() && other.isFullCircle() { + if !c.IsFullCircle() && other.IsFullCircle() { return false } - if !c.doesWrapAround() { - return c.doesInclude(other.start()) && - c.doesInclude(other.stop()-1) && - !other.doesWrapAround() + if !c.DoesWrapAround() { + return c.DoesInclude(other.Start()) && + c.DoesInclude(other.Stop()-1) && + !other.DoesWrapAround() } // Here, we can assume that c wraps around // Case : 1, other is on the left arm - if !other.doesWrapAround() && other.stop() <= c.stop() { + if !other.DoesWrapAround() && other.Stop() <= c.Stop() { return true } // Case : 2, other is on the right arm - if !other.doesWrapAround() && other.start() >= c.start() { + if !other.DoesWrapAround() && other.Start() >= c.Start() { return true } // Case 3 : other also wraps around - if other.doesWrapAround() && other.start() >= c.start() && other.stop() <= c.stop() { + if other.DoesWrapAround() && other.Start() >= c.Start() && other.Stop() <= c.Stop() { return true } @@ -164,22 +164,22 @@ func (c circularInterval) doesFullyContain(other circularInterval) bool { } /* -tryOverlapWith returns true if the left of `c` touches the right of `other` +TryOverlapWith returns true if the left of `c` touches the right of `other` - c.start-------------c.stop - other.start---------other.stop + c.Start-------------c.Stop + other.Start---------other.Stop OR - |c.start|-------------|c.stop| + |c.Start|-------------|c.Stop| - |other.start|---------|other.stop| + |other.Start|---------|other.Stop| -This also include the edge cases where `other.stop`. Also +This also include the edge cases where `other.Stop`. Also returns the resulting circular interval obtained by connecting the two. */ -func (c circularInterval) tryOverlapWith(other circularInterval) (ok bool, connected circularInterval) { +func (c CircularInterval) TryOverlapWith(other CircularInterval) (ok bool, connected CircularInterval) { // Sanity-check, both sides should have the same circle size if c.circleSize != other.circleSize { @@ -191,8 +191,8 @@ func (c circularInterval) tryOverlapWith(other circularInterval) (ok bool, conne // There are still edge-cases for when either c or other are the full circle. // Once these cases are eliminated, we process by case enumeration. - if c.isFullCircle() || other.isFullCircle() { - return true, ivalWithFullLen(n) + if c.IsFullCircle() || other.IsFullCircle() { + return true, IvalWithFullLen(n) } /* @@ -204,9 +204,9 @@ func (c circularInterval) tryOverlapWith(other circularInterval) (ok bool, conne [o0, o1) represents the interval of 'other' */ - c1 := utils.PositiveMod(c.stop()-c.start(), n) - o0 := utils.PositiveMod(other.start()-c.start(), n) - o1 := utils.PositiveMod(other.stop()-c.start(), n) + c1 := utils.PositiveMod(c.Stop()-c.Start(), n) + o0 := utils.PositiveMod(other.Start()-c.Start(), n) + o1 := utils.PositiveMod(other.Stop()-c.Start(), n) /* |-----------------c1 @@ -221,7 +221,7 @@ func (c circularInterval) tryOverlapWith(other circularInterval) (ok bool, conne --------o1 o0--------------- */ if 0 <= o1 && o1 < o0 && o0 <= c1 { - return true, ivalWithFullLen(n) + return true, IvalWithFullLen(n) } /* @@ -229,7 +229,7 @@ func (c circularInterval) tryOverlapWith(other circularInterval) (ok bool, conne o0----------------o1 */ if 0 <= o0 && o0 <= c1 && c1 <= o1 { - return true, ivalWithStartStop(c.start(), other.stop(), n) + return true, IvalWithStartStop(c.Start(), other.Stop(), n) } /* @@ -237,7 +237,7 @@ func (c circularInterval) tryOverlapWith(other circularInterval) (ok bool, conne --------o1 o0-------- */ if 0 <= o1 && o1 <= c1 && c1 < o0 { - return true, ivalWithStartStop(other.start(), c.stop(), n) + return true, IvalWithStartStop(other.Start(), c.Stop(), n) } /* @@ -248,13 +248,13 @@ func (c circularInterval) tryOverlapWith(other circularInterval) (ok bool, conne return true, other } - return false, circularInterval{} + return false, CircularInterval{} } // Returns the smallest windows that covers the entire set -func smallestCoverInterval(intervals []circularInterval) circularInterval { +func SmallestCoverInterval(intervals []CircularInterval) CircularInterval { // Deep-copy the inputs to prevent any side effect - intervals = append([]circularInterval{}, intervals...) + intervals = append([]CircularInterval{}, intervals...) if len(intervals) == 0 { panic("no windows passed") @@ -267,11 +267,11 @@ func smallestCoverInterval(intervals []circularInterval) circularInterval { // First step, we aggregate the windows whose union is a circle arc // into disjoints buckets. Thereafter, we take the complements of the // largest gap between buckets as our result. - sort.Slice(intervals, func(i, j int) bool { return intervals[i].start() <= intervals[j].start() }) - overlaps := []circularInterval{} + sort.Slice(intervals, func(i, j int) bool { return intervals[i].Start() <= intervals[j].Start() }) + overlaps := []CircularInterval{} // Then we group the intervals whose union is still an interval. Since - // the intervals are now sorted by their "start" argument, it suffices + // the intervals are now sorted by their "Start" argument, it suffices // to try and merge each with the next one. It they are not connected on // the right, then the following ones won't either. for i, interval := range intervals { @@ -285,11 +285,11 @@ func smallestCoverInterval(intervals []circularInterval) circularInterval { panic("inconsistent sizes") } - // Since the input intervals are sorted by their start at the beginning, + // Since the input intervals are sorted by their Start at the beginning, // it suffices to try to merge with the last one. last := overlaps[len(overlaps)-1] - if ok, newW := last.tryOverlapWith(interval); ok { + if ok, newW := last.TryOverlapWith(interval); ok { overlaps[len(overlaps)-1] = newW } else { // Else create a new bucket @@ -310,7 +310,7 @@ func smallestCoverInterval(intervals []circularInterval) circularInterval { last := overlaps[len(overlaps)-1] - if ok, newW := last.tryOverlapWith(overlaps[0]); ok { + if ok, newW := last.TryOverlapWith(overlaps[0]); ok { overlaps[len(overlaps)-1] = newW overlaps = overlaps[1:] } else { @@ -328,7 +328,7 @@ func smallestCoverInterval(intervals []circularInterval) circularInterval { for i, w := range overlaps { nextW := overlaps[(i+1)%len(overlaps)] - gap := utils.PositiveMod(nextW.start()-w.stop(), circleSize) + gap := utils.PositiveMod(nextW.Start()-w.Stop(), circleSize) if gap > maxGap { maxGap = gap @@ -341,8 +341,8 @@ func smallestCoverInterval(intervals []circularInterval) circularInterval { utils.Panic("Max gap is %v", maxGap) } - start := overlaps[(posMaxGap+1)%len(overlaps)].start() - stop := overlaps[posMaxGap].stop() - return ivalWithStartStop(start, stop, circleSize) + start := overlaps[(posMaxGap+1)%len(overlaps)].Start() + stop := overlaps[posMaxGap].Stop() + return IvalWithStartStop(start, stop, circleSize) } diff --git a/prover/maths/common/smartvectors/circular_interval_test.go b/prover/maths/common/smartvectors/circular_interval_test.go index 156b62a5f..a66834eb0 100644 --- a/prover/maths/common/smartvectors/circular_interval_test.go +++ b/prover/maths/common/smartvectors/circular_interval_test.go @@ -9,172 +9,172 @@ import ( func TestCircularIntervalConstructors(t *testing.T) { t.Run("for a normal interval", func(t *testing.T) { - i := ivalWithStartLen(2, 5, 10) - assert.Equal(t, 2, i.start(), "start") - assert.Equal(t, 7, i.stop(), "stop") - assert.Equal(t, 5, i.intervalLen, "interval length") - assert.False(t, i.doesWrapAround(), "wrap around") - assert.False(t, i.isFullCircle(), "full circle") + i := IvalWithStartLen(2, 5, 10) + assert.Equal(t, 2, i.Start(), "Start") + assert.Equal(t, 7, i.Stop(), "Stop") + assert.Equal(t, 5, i.IntervalLen, "interval length") + assert.False(t, i.DoesWrapAround(), "wrap around") + assert.False(t, i.IsFullCircle(), "full circle") - assert.True(t, i.doesInclude(5), "in the middle of the interval") - assert.True(t, i.doesInclude(2), "it should be closed on the left") - assert.False(t, i.doesInclude(7), "it should be open on the right") + assert.True(t, i.DoesInclude(5), "in the middle of the interval") + assert.True(t, i.DoesInclude(2), "it should be closed on the left") + assert.False(t, i.DoesInclude(7), "it should be open on the right") - assert.False(t, i.doesInclude(0), "point on the left") - assert.False(t, i.doesInclude(8), "point on the right") + assert.False(t, i.DoesInclude(0), "point on the left") + assert.False(t, i.DoesInclude(8), "point on the right") - assert.Equal(t, ivalWithStartStop(2, 7, 10), i) + assert.Equal(t, IvalWithStartStop(2, 7, 10), i) }) t.Run("for a wrapped around vector", func(t *testing.T) { - i := ivalWithStartLen(7, 5, 10) + i := IvalWithStartLen(7, 5, 10) - assert.Equal(t, 7, i.start(), "start") - assert.Equal(t, 2, i.stop(), "stop") - assert.Equal(t, 5, i.intervalLen, "interval length") - assert.True(t, i.doesWrapAround(), "wrap around") - assert.False(t, i.isFullCircle(), "full circle") + assert.Equal(t, 7, i.Start(), "Start") + assert.Equal(t, 2, i.Stop(), "Stop") + assert.Equal(t, 5, i.IntervalLen, "interval length") + assert.True(t, i.DoesWrapAround(), "wrap around") + assert.False(t, i.IsFullCircle(), "full circle") - assert.False(t, i.doesInclude(5), "in the middle of the interval") - assert.True(t, i.doesInclude(7), "it should be closed on the left") - assert.False(t, i.doesInclude(2), "it should be open on the right") + assert.False(t, i.DoesInclude(5), "in the middle of the interval") + assert.True(t, i.DoesInclude(7), "it should be closed on the left") + assert.False(t, i.DoesInclude(2), "it should be open on the right") - assert.True(t, i.doesInclude(0), "point on the left") - assert.True(t, i.doesInclude(8), "point on the right") + assert.True(t, i.DoesInclude(0), "point on the left") + assert.True(t, i.DoesInclude(8), "point on the right") - assert.Equal(t, ivalWithStartStop(7, 2, 10), i) + assert.Equal(t, IvalWithStartStop(7, 2, 10), i) }) t.Run("for a full vector", func(t *testing.T) { - i := ivalWithFullLen(10) - assert.Equal(t, 0, i.start(), "start") - assert.Equal(t, 0, i.stop(), "stop") - assert.Equal(t, 10, i.intervalLen, "interval length") - assert.False(t, i.doesWrapAround(), "wrap around") - assert.True(t, i.isFullCircle(), "full circle") - - assert.True(t, i.doesInclude(5), "in the middle of the interval") - assert.True(t, i.doesInclude(7), "it should be closed on the left") - assert.True(t, i.doesInclude(2), "it should be open on the right") - - assert.True(t, i.doesInclude(0), "point on the left") - assert.True(t, i.doesInclude(8), "point on the right") + i := IvalWithFullLen(10) + assert.Equal(t, 0, i.Start(), "Start") + assert.Equal(t, 0, i.Stop(), "Stop") + assert.Equal(t, 10, i.IntervalLen, "interval length") + assert.False(t, i.DoesWrapAround(), "wrap around") + assert.True(t, i.IsFullCircle(), "full circle") + + assert.True(t, i.DoesInclude(5), "in the middle of the interval") + assert.True(t, i.DoesInclude(7), "it should be closed on the left") + assert.True(t, i.DoesInclude(2), "it should be open on the right") + + assert.True(t, i.DoesInclude(0), "point on the left") + assert.True(t, i.DoesInclude(8), "point on the right") }) } func TestDoesFullyContain(t *testing.T) { t.Run("for a normal vector", func(t *testing.T) { - i := ivalWithStartStop(5, 10, 15) - - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 3, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 8, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 1, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(5, 8, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(5, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(5, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(5, 3, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 8, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 3, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 6, 15))) - - assert.False(t, i.doesFullyContain(ivalWithStartStop(10, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(10, 3, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(10, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(10, 8, 15))) - - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 3, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 8, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 11, 15))) - - assert.False(t, i.doesFullyContain(ivalWithFullLen(15))) + i := IvalWithStartStop(5, 10, 15) + + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 3, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 8, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 1, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(5, 8, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(5, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(5, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(5, 3, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 8, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 3, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 6, 15))) + + assert.False(t, i.DoesFullyContain(IvalWithStartStop(10, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(10, 3, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(10, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(10, 8, 15))) + + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 3, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 8, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 11, 15))) + + assert.False(t, i.DoesFullyContain(IvalWithFullLen(15))) }) t.Run("for a wrap around", func(t *testing.T) { - i := ivalWithStartStop(10, 5, 15) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 3, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 8, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(2, 1, 15))) - - assert.False(t, i.doesFullyContain(ivalWithStartStop(5, 8, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(5, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(5, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(5, 3, 15))) - - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 8, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 13, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 3, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(7, 6, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(10, 13, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(10, 3, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(10, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(10, 8, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 13, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 3, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 5, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 8, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 10, 15))) - assert.False(t, i.doesFullyContain(ivalWithStartStop(12, 11, 15))) - - assert.False(t, i.doesFullyContain(ivalWithFullLen(15))) + i := IvalWithStartStop(10, 5, 15) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 3, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 8, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(2, 1, 15))) + + assert.False(t, i.DoesFullyContain(IvalWithStartStop(5, 8, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(5, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(5, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(5, 3, 15))) + + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 8, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 13, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 3, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(7, 6, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(10, 13, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(10, 3, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(10, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(10, 8, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 13, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 3, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 5, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 8, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 10, 15))) + assert.False(t, i.DoesFullyContain(IvalWithStartStop(12, 11, 15))) + + assert.False(t, i.DoesFullyContain(IvalWithFullLen(15))) }) t.Run("for a wrap around", func(t *testing.T) { - i := ivalWithFullLen(15) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 3, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 5, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 8, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 10, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 13, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(2, 1, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(5, 8, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(5, 10, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(5, 13, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(5, 3, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 8, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 10, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 13, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 3, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 5, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(7, 6, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(10, 13, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(10, 3, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(10, 5, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(10, 8, 15))) - - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 13, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 3, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 5, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 8, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 10, 15))) - assert.True(t, i.doesFullyContain(ivalWithStartStop(12, 11, 15))) - - assert.True(t, i.doesFullyContain(ivalWithFullLen(15))) + i := IvalWithFullLen(15) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 3, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 5, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 8, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 10, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 13, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(2, 1, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(5, 8, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(5, 10, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(5, 13, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(5, 3, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 8, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 10, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 13, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 3, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 5, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(7, 6, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(10, 13, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(10, 3, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(10, 5, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(10, 8, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 13, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 3, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 5, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 8, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 10, 15))) + assert.True(t, i.DoesFullyContain(IvalWithStartStop(12, 11, 15))) + + assert.True(t, i.DoesFullyContain(IvalWithFullLen(15))) }) @@ -183,343 +183,343 @@ func TestDoesFullyContain(t *testing.T) { func TestTryOverlap(t *testing.T) { var ok bool - var res circularInterval + var res CircularInterval t.Run("for a normal vector", func(t *testing.T) { - i := ivalWithStartStop(5, 10, 15) + i := IvalWithStartStop(5, 10, 15) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 3, 15)) assert.False(t, ok) - assert.Equal(t, circularInterval{}, res) + assert.Equal(t, CircularInterval{}, res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(2, 10, 15), res) + assert.Equal(t, IvalWithStartStop(2, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(2, 10, 15), res) + assert.Equal(t, IvalWithStartStop(2, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(2, 10, 15), res) + assert.Equal(t, IvalWithStartStop(2, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(2, 13, 15), res) + assert.Equal(t, IvalWithStartStop(2, 13, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 1, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 1, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(2, 1, 15), res) + assert.Equal(t, IvalWithStartStop(2, 1, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 10, 15), res) + assert.Equal(t, IvalWithStartStop(5, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 10, 15), res) + assert.Equal(t, IvalWithStartStop(5, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 13, 15), res) + assert.Equal(t, IvalWithStartStop(5, 13, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 3, 15), res) + assert.Equal(t, IvalWithStartStop(5, 3, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 10, 15), res) + assert.Equal(t, IvalWithStartStop(5, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 10, 15), res) + assert.Equal(t, IvalWithStartStop(5, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 13, 15), res) + assert.Equal(t, IvalWithStartStop(5, 13, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 3, 15), res) + assert.Equal(t, IvalWithStartStop(5, 3, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 6, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 6, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 13, 15), res) + assert.Equal(t, IvalWithStartStop(5, 13, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(5, 3, 15), res) + assert.Equal(t, IvalWithStartStop(5, 3, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 13, 15)) assert.False(t, ok) - assert.Equal(t, circularInterval{}, res) + assert.Equal(t, CircularInterval{}, res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 3, 15)) assert.False(t, ok) - assert.Equal(t, circularInterval{}, res) + assert.Equal(t, CircularInterval{}, res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(12, 10, 15), res) + assert.Equal(t, IvalWithStartStop(12, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(12, 10, 15), res) + assert.Equal(t, IvalWithStartStop(12, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(12, 10, 15), res) + assert.Equal(t, IvalWithStartStop(12, 10, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 11, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 11, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(12, 11, 15), res) + assert.Equal(t, IvalWithStartStop(12, 11, 15), res) - ok, res = i.tryOverlapWith(ivalWithFullLen(15)) + ok, res = i.TryOverlapWith(IvalWithFullLen(15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) }) t.Run("for a wrap around", func(t *testing.T) { - i := ivalWithStartStop(10, 5, 15) + i := IvalWithStartStop(10, 5, 15) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 8, 15), res) + assert.Equal(t, IvalWithStartStop(10, 8, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 1, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 1, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 8, 15), res) + assert.Equal(t, IvalWithStartStop(10, 8, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 8, 15)) assert.False(t, ok) - assert.Equal(t, circularInterval{}, res) + assert.Equal(t, CircularInterval{}, res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(7, 5, 15), res) + assert.Equal(t, IvalWithStartStop(7, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(7, 5, 15), res) + assert.Equal(t, IvalWithStartStop(7, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(7, 5, 15), res) + assert.Equal(t, IvalWithStartStop(7, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(7, 5, 15), res) + assert.Equal(t, IvalWithStartStop(7, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 6, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 6, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(7, 6, 15), res) + assert.Equal(t, IvalWithStartStop(7, 6, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 8, 15), res) + assert.Equal(t, IvalWithStartStop(10, 8, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 5, 15), res) + assert.Equal(t, IvalWithStartStop(10, 5, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithStartStop(10, 8, 15), res) + assert.Equal(t, IvalWithStartStop(10, 8, 15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 11, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 11, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithFullLen(15)) + ok, res = i.TryOverlapWith(IvalWithFullLen(15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) }) t.Run("for a wrap around", func(t *testing.T) { - i := ivalWithFullLen(15) + i := IvalWithFullLen(15) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(2, 1, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(2, 1, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(5, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(5, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(7, 6, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(7, 6, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(10, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(10, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 13, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 13, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 3, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 3, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 5, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 5, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 8, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 8, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 10, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 10, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithStartStop(12, 11, 15)) + ok, res = i.TryOverlapWith(IvalWithStartStop(12, 11, 15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) - ok, res = i.tryOverlapWith(ivalWithFullLen(15)) + ok, res = i.TryOverlapWith(IvalWithFullLen(15)) assert.True(t, ok) - assert.Equal(t, ivalWithFullLen(15), res) + assert.Equal(t, IvalWithFullLen(15), res) }) diff --git a/prover/maths/common/smartvectors/constant.go b/prover/maths/common/smartvectors/constant.go index fbb491218..8779f8eb8 100644 --- a/prover/maths/common/smartvectors/constant.go +++ b/prover/maths/common/smartvectors/constant.go @@ -2,6 +2,7 @@ package smartvectors import ( "fmt" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" "github.com/consensys/linea-monorepo/prover/maths/field" "github.com/consensys/linea-monorepo/prover/utils" @@ -25,7 +26,17 @@ func NewConstant(val field.Element, length int) *Constant { func (c *Constant) Len() int { return c.length } // Returns an entry of the constant -func (c *Constant) Get(int) field.Element { return c.val } +func (c *Constant) GetBase(int) (field.Element, error) { return c.val, nil } + +func (c *Constant) GetExt(int) fext.Element { return *new(fext.Element).SetFromBase(&c.val) } + +func (r *Constant) Get(n int) field.Element { + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res +} // Returns a subvector func (c *Constant) SubVector(start, stop int) SmartVector { @@ -36,7 +47,7 @@ func (c *Constant) SubVector(start, stop int) SmartVector { utils.Panic("zero length are not allowed") } assertCorrectBound(start, c.length) - // The +1 is because we accept if "stop = length" + // The +1 is because we accept if "Stop = length" assertCorrectBound(stop, c.length+1) return NewConstant(c.val, stop-start) } @@ -54,6 +65,12 @@ func (c *Constant) WriteInSlice(s []field.Element) { } } +func (c *Constant) WriteInSliceExt(s []fext.Element) { + for i := 0; i < len(s); i++ { + s[i].SetFromBase(&c.val) + } +} + func (c *Constant) Val() field.Element { return c.val } @@ -67,5 +84,24 @@ func (c *Constant) DeepCopy() SmartVector { } func (c *Constant) IntoRegVecSaveAlloc() []field.Element { - return IntoRegVec(c) + res, err := c.IntoRegVecSaveAllocBase() + if err != nil { + panic(conversionError) + } + return res +} + +// Temporary function for code transition +func (c *Constant) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return IntoRegVec(c), nil +} + +func (c *Constant) IntoRegVecSaveAllocExt() []fext.Element { + temp := IntoRegVec(c) + res := make([]fext.Element, len(temp)) + for i := 0; i < len(temp); i++ { + elem := temp[i] + res[i].SetFromBase(&elem) + } + return res } diff --git a/prover/maths/common/smartvectors/fft.go b/prover/maths/common/smartvectors/fft.go index ef21f5a11..2e9e3222e 100644 --- a/prover/maths/common/smartvectors/fft.go +++ b/prover/maths/common/smartvectors/fft.go @@ -49,7 +49,7 @@ func FFT(v SmartVector, decimation fft.Decimation, bitReverse bool, cosetRatio i // The polynomial is the constant polynomial, response does not depends on the decimation // or bitReverse interval := x.interval() - if interval.intervalLen == 1 && interval.start() == 0 && x.paddingVal.IsZero() { + if interval.IntervalLen == 1 && interval.Start() == 0 && x.paddingVal.IsZero() { // In this case, the response is a constant vector return NewConstant(x.window[0], x.Len()) } @@ -130,7 +130,7 @@ func FFTInverse(v SmartVector, decimation fft.Decimation, bitReverse bool, coset // It's a multiple of the first Lagrange polynomial c * (1 + x + x^2 + x^3 + ...) // The response is (c) = (c/N, c/N, c/N, ...) interval := x.interval() - if interval.intervalLen == 1 && interval.start() == 0 && x.paddingVal.IsZero() { + if interval.IntervalLen == 1 && interval.Start() == 0 && x.paddingVal.IsZero() { constTerm := field.NewElement(uint64(x.Len())) constTerm.Inverse(&constTerm) constTerm.Mul(&constTerm, &x.window[0]) diff --git a/prover/maths/common/smartvectors/fuzzing.go b/prover/maths/common/smartvectors/fuzzing.go index 1fd07178a..efe60e3d4 100644 --- a/prover/maths/common/smartvectors/fuzzing.go +++ b/prover/maths/common/smartvectors/fuzzing.go @@ -124,9 +124,9 @@ func (gen *testCaseGen) NewTestCaseForProd() (tcase testCase) { case windowT: v := gen.genWindow(val, val) tcase.svecs[i] = v - start := normalize(v.interval().start(), gen.windowMustStartAfter, gen.fullLen) + start := normalize(v.interval().Start(), gen.windowMustStartAfter, gen.fullLen) winMinStart = utils.Min(winMinStart, start) - stop := normalize(v.interval().stop(), gen.windowMustStartAfter, gen.fullLen) + stop := normalize(v.interval().Stop(), gen.windowMustStartAfter, gen.fullLen) if stop < start { stop += gen.fullLen } @@ -200,10 +200,10 @@ func (gen *testCaseGen) NewTestCaseForLinComb() (tcase testCase) { case windowT: v := gen.genWindow(val, val) tcase.svecs[i] = v - start := normalize(v.interval().start(), gen.windowMustStartAfter, gen.fullLen) + start := normalize(v.interval().Start(), gen.windowMustStartAfter, gen.fullLen) winMinStart = utils.Min(winMinStart, start) - stop := normalize(v.interval().stop(), gen.windowMustStartAfter, gen.fullLen) + stop := normalize(v.interval().Stop(), gen.windowMustStartAfter, gen.fullLen) if stop < start { stop += gen.fullLen } @@ -268,10 +268,10 @@ func (gen *testCaseGen) NewTestCaseForPolyEval() (tcase testCase) { case windowT: v := gen.genWindow(val, val) tcase.svecs[i] = v - start := normalize(v.interval().start(), gen.windowMustStartAfter, gen.fullLen) + start := normalize(v.interval().Start(), gen.windowMustStartAfter, gen.fullLen) winMinStart = utils.Min(winMinStart, start) - stop := normalize(v.interval().stop(), gen.windowMustStartAfter, gen.fullLen) + stop := normalize(v.interval().Stop(), gen.windowMustStartAfter, gen.fullLen) if stop < start { stop += gen.fullLen } diff --git a/prover/maths/common/smartvectors/polynomial.go b/prover/maths/common/smartvectors/polynomial.go index 5f7ede085..68ad841ff 100644 --- a/prover/maths/common/smartvectors/polynomial.go +++ b/prover/maths/common/smartvectors/polynomial.go @@ -127,7 +127,7 @@ func BatchInterpolate(vs []SmartVector, x field.Element, oncoset ...bool) []fiel } // non-constant vectors - polys[i] = vs[i].IntoRegVecSaveAlloc() + polys[i], _ = vs[i].IntoRegVecSaveAllocBase() } }) diff --git a/prover/maths/common/smartvectors/regular.go b/prover/maths/common/smartvectors/regular.go index 79d3bb773..05e895551 100644 --- a/prover/maths/common/smartvectors/regular.go +++ b/prover/maths/common/smartvectors/regular.go @@ -2,6 +2,7 @@ package smartvectors import ( "fmt" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" "github.com/consensys/linea-monorepo/prover/maths/common/mempool" "github.com/consensys/linea-monorepo/prover/maths/common/vector" @@ -24,7 +25,19 @@ func NewRegular(v []field.Element) *Regular { func (r *Regular) Len() int { return len(*r) } // Returns a particular element of the vector -func (r *Regular) Get(n int) field.Element { return (*r)[n] } +func (r *Regular) GetBase(n int) (field.Element, error) { return (*r)[n], nil } + +func (r *Regular) GetExt(n int) fext.Element { + return *new(fext.Element).SetFromBase(&(*r)[n]) +} + +func (r *Regular) Get(n int) field.Element { + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res +} // Returns a subvector of the regular func (r *Regular) SubVector(start, stop int) SmartVector { @@ -73,6 +86,14 @@ func (r *Regular) WriteInSlice(s []field.Element) { copy(s, *r) } +func (r *Regular) WriteInSliceExt(s []fext.Element) { + assertHasLength(len(s), len(*r)) + for i := 0; i < len(s); i++ { + elem, _ := r.GetBase(i) + s[i].SetFromBase(&elem) + } +} + func (r *Regular) Pretty() string { return fmt.Sprintf("Regular[%v]", vector.Prettify(*r)) } @@ -136,7 +157,24 @@ func (r *Regular) DeepCopy() SmartVector { // Converts a smart-vector into a normal vec. The implementation minimizes // then number of copies. func (r *Regular) IntoRegVecSaveAlloc() []field.Element { - return (*r)[:] + res, err := r.IntoRegVecSaveAllocBase() + if err != nil { + panic(conversionError) + } + return res +} + +func (r *Regular) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return (*r)[:], nil +} + +func (r *Regular) IntoRegVecSaveAllocExt() []fext.Element { + temp := make([]fext.Element, r.Len()) + for i := 0; i < r.Len(); i++ { + elem, _ := r.GetBase(i) + temp[i].SetFromBase(&elem) + } + return temp } type Pooled struct { diff --git a/prover/maths/common/smartvectors/rotated.go b/prover/maths/common/smartvectors/rotated.go index ffe344477..b251fbcd5 100644 --- a/prover/maths/common/smartvectors/rotated.go +++ b/prover/maths/common/smartvectors/rotated.go @@ -2,6 +2,7 @@ package smartvectors import ( "fmt" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" @@ -54,12 +55,26 @@ func (r *Rotated) Len() int { } // Returns a particular element of the vector +func (r *Rotated) GetBase(n int) (field.Element, error) { + return r.v.GetBase(utils.PositiveMod(n+r.offset, r.Len())) +} + +// Returns a particular element of the vector +func (r *Rotated) GetExt(n int) fext.Element { + temp, _ := r.v.GetBase(utils.PositiveMod(n+r.offset, r.Len())) + return *new(fext.Element).SetFromBase(&temp) +} + func (r *Rotated) Get(n int) field.Element { - return r.v.Get(utils.PositiveMod(n+r.offset, r.Len())) + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res } // Returns a particular element. The subvector is taken at indices -// [start, stop). (stop being excluded from the span) +// [Start, Stop). (Stop being excluded from the span) func (r *Rotated) SubVector(start, stop int) SmartVector { if stop+r.offset < len(r.v.Regular) && start+r.offset > 0 { @@ -73,16 +88,16 @@ func (r *Rotated) SubVector(start, stop int) SmartVector { // checking if stop <= start { - utils.Panic("the start %v >= stop %v", start, stop) + utils.Panic("the Start %v >= Stop %v", start, stop) } // boundary checks if start < 0 { - utils.Panic("the start value was negative %v", start) + utils.Panic("the Start value was negative %v", start) } if stop > size { - utils.Panic("the stop is OOO : %v (the length is %v)", stop, size) + utils.Panic("the Stop is OOO : %v (the length is %v)", stop, size) } // normalize the offset to something positive [0: size) @@ -130,6 +145,14 @@ func (r *Rotated) WriteInSlice(s []field.Element) { res.WriteInSlice(s) } +func (r *Rotated) WriteInSliceExt(s []fext.Element) { + temp := rotatedAsRegular(r) + for i := 0; i < temp.Len(); i++ { + elem, _ := temp.GetBase(i) + s[i].SetFromBase(&elem) + } +} + func (r *Rotated) Pretty() string { return fmt.Sprintf("Rotated[%v, %v]", r.v.Pretty(), r.offset) } @@ -141,7 +164,24 @@ func rotatedAsRegular(r *Rotated) *Regular { } func (r *Rotated) IntoRegVecSaveAlloc() []field.Element { - return *rotatedAsRegular(r) + res, err := r.IntoRegVecSaveAllocBase() + if err != nil { + panic(conversionError) + } + return res +} + +func (r *Rotated) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return *rotatedAsRegular(r), nil +} + +func (r *Rotated) IntoRegVecSaveAllocExt() []fext.Element { + temp := *rotatedAsRegular(r) + res := make([]fext.Element, temp.Len()) + for i := 0; i < temp.Len(); i++ { + res[i].SetFromBase(&temp[i]) + } + return res } // SoftRotate converts v into a [SmartVector] representing the same diff --git a/prover/maths/common/smartvectors/smartvector_test.go b/prover/maths/common/smartvectors/smartvector_test.go index 854b0a8c4..6fb5014f1 100644 --- a/prover/maths/common/smartvectors/smartvector_test.go +++ b/prover/maths/common/smartvectors/smartvector_test.go @@ -101,7 +101,7 @@ func TestSubvectorFuzzy(t *testing.T) { for i := 0; i < stop-start; i++ { expected := v.Get(start + i) actual := sub.Get(i) - require.Equal(t, expected.String(), actual.String(), "start %v, stop %v, i %v", start, stop, i) + require.Equal(t, expected.String(), actual.String(), "Start %v, Stop %v, i %v", start, stop, i) } }, diff --git a/prover/maths/common/smartvectors/smartvectors.go b/prover/maths/common/smartvectors/smartvectors.go index 5ce9b2478..4a735c775 100644 --- a/prover/maths/common/smartvectors/smartvectors.go +++ b/prover/maths/common/smartvectors/smartvectors.go @@ -2,6 +2,7 @@ package smartvectors import ( "fmt" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" "math/rand" "github.com/consensys/gnark/frontend" @@ -10,6 +11,8 @@ import ( "github.com/consensys/linea-monorepo/prover/utils" ) +const conversionError = "smartvector holds field extensions, but a base element was requested" + // SmartVector is an abstraction over vectors of field elements that can be // optimized for structured vectors. For instance, if we have a vector of // repeated elements we can use smartvectors.NewConstant(x, n) to represent it. @@ -27,14 +30,17 @@ type SmartVector interface { // Len returns the length of the SmartVector Len() int // Get returns an entry of the SmartVector at particular position + GetBase(int) (field.Element, error) Get(int) field.Element - // SubVector returns a subvector of the [SmartVector]. It mirrors slice[start:stop] + GetExt(int) fext.Element + // SubVector returns a subvector of the [SmartVector]. It mirrors slice[Start:Stop] SubVector(int, int) SmartVector // RotateRight cyclically rotates the SmartVector RotateRight(int) SmartVector // WriteInSlice writes the SmartVector into a slice. The slice must be just // as large as [Len] otherwise the function will panic WriteInSlice([]field.Element) + WriteInSliceExt([]fext.Element) // Pretty returns a prettified version of the vector, useful for debugging. Pretty() string // DeepCopy returns a deep-copy of the SmartVector which can be freely @@ -43,6 +49,8 @@ type SmartVector interface { // IntoRegVecSaveAlloc converts a smart-vector into a normal vec. The // implementation minimizes then number of copies IntoRegVecSaveAlloc() []field.Element + IntoRegVecSaveAllocBase() ([]field.Element, error) + IntoRegVecSaveAllocExt() []fext.Element } // AllocateRegular returns a newly allocated smart-vector @@ -84,11 +92,26 @@ func IntoRegVec(s SmartVector) []field.Element { return res } +func IntoRegVecExt(s SmartVector) []fext.Element { + res := make([]fext.Element, s.Len()) + s.WriteInSliceExt(res) + return res +} + // IntoGnarkAssignment converts a smart-vector into a gnark assignment func IntoGnarkAssignment(sv SmartVector) []frontend.Variable { res := make([]frontend.Variable, sv.Len()) - for i := range res { - res[i] = sv.Get(i) + _, err := sv.GetBase(0) + if err == nil { + for i := range res { + elem, _ := sv.GetBase(i) + res[i] = elem + } + } else { + for i := range res { + elem := sv.GetExt(i) + res[i] = elem + } } return res } @@ -163,15 +186,48 @@ func Density(v SmartVector) int { // if the vector is Padded with zeroes it return the window. // Namely, the part without zero pads. func Window(v SmartVector) []field.Element { + res, err := WindowBase(v) + if err != nil { + panic(conversionError) + } + return res +} + +func WindowBase(v SmartVector) ([]field.Element, error) { + switch w := v.(type) { + case *Constant: + return w.IntoRegVecSaveAllocBase() + case *PaddedCircularWindow: + return w.window, nil + case *Regular: + return *w, nil + case *Rotated: + return w.IntoRegVecSaveAllocBase() + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } +} + +func WindowExt(v SmartVector) []fext.Element { switch w := v.(type) { case *Constant: - return w.IntoRegVecSaveAlloc() + return w.IntoRegVecSaveAllocExt() case *PaddedCircularWindow: - return w.window + temp := make([]fext.Element, len(w.window)) + for i := 0; i < len(w.window); i++ { + elem := w.window[i] + temp[i].SetFromBase(&elem) + } + return temp case *Regular: - return *w + temp := make([]fext.Element, len(*w)) + for i := 0; i < len(*w); i++ { + elem, _ := w.GetBase(i) + temp[i].SetFromBase(&elem) + } + return temp case *Rotated: - return w.IntoRegVecSaveAlloc() + return w.IntoRegVecSaveAllocExt() default: panic(fmt.Sprintf("unexpected type %T", v)) } diff --git a/prover/maths/common/smartvectors/vectorext/vector.go b/prover/maths/common/smartvectors/vectorext/vector.go new file mode 100644 index 000000000..56993d597 --- /dev/null +++ b/prover/maths/common/smartvectors/vectorext/vector.go @@ -0,0 +1,334 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package vectorext + +import ( + "bytes" + "encoding/binary" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "io" + "runtime" + "strings" + "sync" + "sync/atomic" + "unsafe" +) + +const ( + frBytes = 32 + Bytes = 64 // number of bytes needed to represent a Element +) + +// Vector represents a slice of Element. +// +// It implements the following interfaces: +// - Stringer +// - io.WriterTo +// - io.ReaderFrom +// - encoding.BinaryMarshaler +// - encoding.BinaryUnmarshaler +// - sort.Interface +type Vector []fext.Element + +// MarshalBinary implements encoding.BinaryMarshaler +func (vector *Vector) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + if _, err = vector.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (vector *Vector) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := vector.ReadFrom(r) + return err +} + +// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. +// Length of the vector is encoded as a uint32 on the first 4 bytes. +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { + // encode slice length + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { + return 0, err + } + + n := int64(4) + + var buf [frBytes]byte + for i := 0; i < len(*vector); i++ { + subElems := [2]fr.Element{(*vector)[i].A0, (*vector)[i].A1} + for j := 0; j < 2; j++ { + fr.BigEndian.PutElement(&buf, subElems[j]) + m, err := w.Write(buf[:]) + n += int64(m) + if err != nil { + return n, err + } + } + } + return n, nil +} + +// AsyncReadFrom reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. +// It also returns a channel that will be closed when the validation is done. +// The validation consist of checking that the elements are smaller than the modulus, and +// converting them to montgomery form. +func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { + chErr := make(chan error, 1) + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + close(chErr) + return int64(read), err, chErr + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + if sliceLen == 0 { + close(chErr) + return n, nil, chErr + } + + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) + read, err := io.ReadFull(r, bSlice) + n += int64(read) + if err != nil { + close(chErr) + return n, err, chErr + } + + go func() { + var cptErrors uint64 + // process the elements in parallel + execute(int(sliceLen), func(start, end int) { + + subElems := make([]fr.Element, 2) + for i := start; i < end; i++ { + // we have to set vector[i] + for j := 0; j < 2; j++ { + var z fr.Element + bstart := (i + j) * frBytes + bend := bstart + frBytes + b := bSlice[bstart:bend] + z[0] = binary.BigEndian.Uint64(b[24:32]) + z[1] = binary.BigEndian.Uint64(b[16:24]) + z[2] = binary.BigEndian.Uint64(b[8:16]) + z[3] = binary.BigEndian.Uint64(b[0:8]) + + if !fext.SmallerThanModulus(&z) { + atomic.AddUint64(&cptErrors, 1) + return + } + fext.ToMont(&z) + subElems[j] = z + } + zExt := fext.Element{subElems[0], subElems[1]} + (*vector)[i] = zExt + } + }) + + if cptErrors > 0 { + chErr <- fmt.Errorf("async read: %d elements failed validation", cptErrors) + } + close(chErr) + }() + return n, nil, chErr +} + +// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { + + var buf [frBytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + return int64(read), err + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + + for i := 0; i < int(sliceLen); i++ { + var A0, A1 fr.Element + read, err := io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + A0, err = fr.BigEndian.Element(&buf) + if err != nil { + return n, err + } + + // read the second element + read, err = io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + A1, err = fr.BigEndian.Element(&buf) + if err != nil { + return n, err + } + (*vector)[i] = fext.Element{A0, A1} + } + + return n, nil +} + +// String implements fmt.Stringer interface +func (vector Vector) String() string { + var sbb strings.Builder + sbb.WriteByte('[') + for i := 0; i < len(vector); i++ { + sbb.WriteString(vector[i].String()) + if i != len(vector)-1 { + sbb.WriteByte(',') + } + } + sbb.WriteByte(']') + return sbb.String() +} + +// Len is the number of elements in the collection. +func (vector Vector) Len() int { + return len(vector) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (vector Vector) Less(i, j int) bool { + return vector[i].Cmp(&vector[j]) == -1 +} + +// Swap swaps the elements with indexes i and j. +func (vector Vector) Swap(i, j int) { + vector[i], vector[j] = vector[j], vector[i] +} + +func addVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Add: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Add(&a[i], &b[i]) + } +} + +func subVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Sub: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Sub(&a[i], &b[i]) + } +} + +func scalarMulVecGeneric(res, a Vector, b *fext.Element) { + if len(a) != len(res) { + panic("vector.ScalarMul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], b) + } +} + +func sumVecGeneric(res *fext.Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + +func innerProductVecGeneric(res *fext.Element, a, b Vector) { + if len(a) != len(b) { + panic("vector.InnerProduct: vectors don't have the same length") + } + var tmp fext.Element + for i := 0; i < len(a); i++ { + tmp.Mul(&a[i], &b[i]) + res.Add(res, &tmp) + } +} + +func mulVecGeneric(res, a, b Vector) { + if len(a) != len(b) || len(a) != len(res) { + panic("vector.Mul: vectors don't have the same length") + } + for i := 0; i < len(a); i++ { + res[i].Mul(&a[i], &b[i]) + } +} + +// TODO @gbotrel make a public package out of that. +// execute executes the work function in parallel. +// this is copy paste from internal/parallel/parallel.go +// as we don't want to generate code importing internal/ +func execute(nbIterations int, work func(int, int), maxCpus ...int) { + + nbTasks := runtime.NumCPU() + if len(maxCpus) == 1 { + nbTasks = maxCpus[0] + if nbTasks < 1 { + nbTasks = 1 + } else if nbTasks > 512 { + nbTasks = 512 + } + } + + if nbTasks == 1 { + // no go routines + work(0, nbIterations) + return + } + + nbIterationsPerCpus := nbIterations / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = nbIterations + } + + var wg sync.WaitGroup + + extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + go func() { + work(_start, _end) + wg.Done() + }() + } + + wg.Wait() +} diff --git a/prover/maths/common/smartvectors/vectorext/vector_ops.go b/prover/maths/common/smartvectors/vectorext/vector_ops.go new file mode 100644 index 000000000..2869465ac --- /dev/null +++ b/prover/maths/common/smartvectors/vectorext/vector_ops.go @@ -0,0 +1,58 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package vectorext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/field/fext" +) + +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *fext.Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res fext.Element) { + sumVecGeneric(&res, *vector) + return +} + +// InnerProduct computes the inner product of two vectors. +// It panics if the vectors don't have the same length. +func (vector *Vector) InnerProduct(other Vector) (res fext.Element) { + innerProductVecGeneric(&res, *vector, other) + return +} + +// Mul multiplies two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Mul(a, b Vector) { + mulVecGeneric(*vector, a, b) +} diff --git a/prover/maths/common/smartvectors/vectorext/vectorext.go b/prover/maths/common/smartvectors/vectorext/vectorext.go new file mode 100644 index 000000000..73c34f196 --- /dev/null +++ b/prover/maths/common/smartvectors/vectorext/vectorext.go @@ -0,0 +1,258 @@ +// vector offers a set of utility function relating to slices of field element +// and that are commonly used as part of the repo. +package vectorext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "math/rand" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// DeepCopy deep-copies the input vector +func DeepCopy(pol []fext.Element) []fext.Element { + return append([]fext.Element{}, pol...) +} + +// ScalarMul multiplies a vector by a scalar - in place. +// The result should be preallocated or it is going to panic. +// res = vec is a valid parameter assignment. +func ScalarMul(res, vec []fext.Element, scalar fext.Element) { + + if len(res)+len(vec) == 0 { + return + } + + r := Vector(res) + r.ScalarMul(Vector(vec), &scalar) +} + +// ScalarProd returns the scalar (inner) product of a and b. The function panics +// if a and b do not have the same size. If they have both empty vectors, the +// function returns 0. +func ScalarProd(a, b []fext.Element) fext.Element { + // The length checks is done by gnark-crypto already + a_ := Vector(a) + res := a_.InnerProduct(Vector(b)) + return res +} + +// Rand creates a random vector of size n +func Rand(n int) []fext.Element { + vec := make([]fext.Element, n) + for i := range vec { + _, err := vec[i].SetRandom() + // Just to enfore never having to deal with zeroes + if err != nil { + panic(err) + } + } + return vec +} + +// MulElementWise multiplies two vectors element wise and write the result in +// res. res = a is a valid assignment. +func MulElementWise(res, a, b []fext.Element) { + // The length checks is done by gnark-crypto already + res_ := Vector(res) + res_.Mul(Vector(a), Vector(b)) +} + +// Prettify returns a string representing `a` in a human-readable fashion +func Prettify(a []fext.Element) string { + res := "[" + + for i := range a { + // Discards the case first element when adding a comma + if i > 0 { + res += ", " + } + + res += fmt.Sprintf("%v", a[i].String()) + } + res += "]" + + return res +} + +// Reverse the elements of a vector inplace +func Reverse(v []fext.Element) { + n := len(v) - 1 + for i := 0; i < len(v)/2; i++ { + v[i], v[n-i] = v[n-i], v[i] + } +} + +// Repeat returns a vector of size n whose values are all equal to x. +func Repeat(x fext.Element, n int) []fext.Element { + res := make([]fext.Element, n) + for i := range res { + res[i].Set(&x) + } + return res +} + +// ForTest returns a vector instantiated from a list of integers. +func ForTest(xs ...int) []fext.Element { + res := make([]fext.Element, len(xs)) + for i, x := range xs { + res[i].SetInt64(int64(x)) + } + return res +} + +// Add adds two vectors `a` and `b` and put the result in `res` +// `res` must be pre-allocated by the caller and res, a and b must all have +// the same size. +// res == a or res == b or both is valid assignment. +func Add(res, a, b []fext.Element, extras ...[]fext.Element) { + + if len(res)+len(a)+len(b) == 0 { + return + } + + r := Vector(res) + r.Add(a, b) + + for _, x := range extras { + r.Add(r, Vector(x)) + } +} + +func AddExt(res, a, b []fext.Element, extras ...[]fext.Element) { + + if len(res)+len(a)+len(b) == 0 { + return + } + + r := Vector(res) + r.Add(a, b) + + for _, x := range extras { + r.Add(r, Vector(x)) + } +} + +// Sub substracts two vectors `a` and `b` and put the result in `res` +// `res` must be pre-allocated by the caller and res, a and b must all have +// the same size. +// res == a or res == b or both is valid assignment. +func Sub(res, a, b []fext.Element) { + + if len(res)+len(a)+len(b) == 0 { + return + } + + r := Vector(res) + r.Sub(Vector(a), Vector(b)) +} + +// ZeroPad pads a vector to a given length. +// If the newLen is smaller than len(v), the function panic. It pads to the +// right (appending, not prepending) +// The resulting slice is allocated by the function, so it can be safely +// modified by the caller after the function returns. +func ZeroPad(v []fext.Element, newLen int) []fext.Element { + if newLen < len(v) { + utils.Panic("newLen (%v) < len(v) (%v)", newLen, len(v)) + } + res := make([]fext.Element, newLen) + copy(res, v) + return res +} + +// Interleave interleave two vectors: +// +// (a, a, a, a), (b, b, b, b) -> (a, b, a, b, a, b, a, b) +// +// The vecs[i] vectors must all have the same length +func Interleave(vecs ...[]fext.Element) []fext.Element { + numVecs := len(vecs) + vecSize := len(vecs[0]) + + // all vectors must have the same length + for i := range vecs { + if len(vecs[i]) != vecSize { + utils.Panic("length mismatch, %v != %v", len(vecs[i]), vecSize) + } + } + + res := make([]fext.Element, numVecs*vecSize) + for i := 0; i < vecSize; i++ { + for j := 0; j < numVecs; j++ { + res[i*numVecs+j] = vecs[j][i] + } + } + + return res +} + +// Fill a vector `vec` in place with the given value `val`. +func Fill(v []fext.Element, val fext.Element) { + for i := range v { + v[i] = val + } +} + +// PowerVec allocates and returns a vector of size n consisting of consecutive +// powers of x, starting from x^0 = 1 and ending on x^{n-1}. The function panics +// if given x=0 and returns an empty vector if n=0. +func PowerVec(x fext.Element, n int) []fext.Element { + if x == fext.Zero() { + utils.Panic("cannot build a power vec for x=0") + } + + if n == 0 { + return []fext.Element{} + } + + res := make([]fext.Element, n) + res[0].SetOne() + + for i := 1; i < n; i++ { + res[i].Mul(&res[i-1], &x) + } + + return res +} + +// IntoGnarkAssignment converts an array of field.Element into an array of +// frontend.Variable that can be used to assign a vector of frontend.Variable +// in a circuit or to generate a vector of constant in the circuit definition. +func IntoGnarkAssignment(msgData []fext.Element) []frontend.Variable { + assignedMsg := []frontend.Variable{} + for _, x := range msgData { + assignedMsg = append(assignedMsg, frontend.Variable(x)) + } + return assignedMsg +} + +// Equal compares a and b and returns a boolean indicating whether they contain +// the same value. The function assumes that a and b have the same length. It +// panics otherwise. +func Equal(a, b []fext.Element) bool { + + if len(a) != len(b) { + utils.Panic("a and b don't have the same length: %v %v", len(a), len(b)) + } + + for i := range a { + if a[i] != b[i] { + return false + } + } + + return true +} + +// PseudoRand generates a vector of field element with a given size using the +// provided random number generator +func PseudoRand(rng *rand.Rand, size int) []fext.Element { + slice := make([]fext.Element, size) + for i := range slice { + slice[i] = fext.PseudoRand(rng) + } + return slice +} diff --git a/prover/maths/common/smartvectors/windowed.go b/prover/maths/common/smartvectors/windowed.go index a1981263a..44fb103e3 100644 --- a/prover/maths/common/smartvectors/windowed.go +++ b/prover/maths/common/smartvectors/windowed.go @@ -2,6 +2,7 @@ package smartvectors import ( "fmt" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/maths/field" @@ -47,26 +48,39 @@ func (p *PaddedCircularWindow) Len() int { } // Returns a queries position -func (p *PaddedCircularWindow) Get(n int) field.Element { +func (p *PaddedCircularWindow) GetBase(n int) (field.Element, error) { // Check if the queried index is in the window posFromWindowsPoV := utils.PositiveMod(n-p.offset, p.totLen) if posFromWindowsPoV < len(p.window) { - return p.window[posFromWindowsPoV] + return p.window[posFromWindowsPoV], nil } // Else, return the padding value - return p.paddingVal + return p.paddingVal, nil } -// Extract a subvector from p[start:stop), the subvector cannot "roll-over". -// i.e, we enforce that start < stop +func (p *PaddedCircularWindow) GetExt(n int) fext.Element { + elem, _ := p.GetBase(n) + return *new(fext.Element).SetFromBase(&elem) +} + +func (r *PaddedCircularWindow) Get(n int) field.Element { + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res +} + +// Extract a subvector from p[Start:Stop), the subvector cannot "roll-over". +// i.e, we enforce that Start < Stop func (p *PaddedCircularWindow) SubVector(start, stop int) SmartVector { - // negative start value is not allowed + // negative Start value is not allowed if start < 0 { - panic("negative start value is not allowed") + panic("negative Start value is not allowed") } // Sanity checks for all subvectors assertCorrectBound(start, p.totLen) - // The +1 is because we accept if "stop = length" + // The +1 is because we accept if "Stop = length" assertCorrectBound(stop, p.totLen+1) if start > stop { @@ -104,8 +118,8 @@ func (p *PaddedCircularWindow) SubVector(start, stop int) SmartVector { n := p.Len() b := stop - start - c := normalize(p.interval().start(), start, n) - d := normalize(p.interval().stop(), start, n) + c := normalize(p.interval().Start(), start, n) + d := normalize(p.interval().Stop(), start, n) // Case 1 : return a constant vector if b <= c && c < d { @@ -178,12 +192,22 @@ func (p *PaddedCircularWindow) WriteInSlice(buff []field.Element) { } } +func (p *PaddedCircularWindow) WriteInSliceExt(buff []fext.Element) { + temp := make([]field.Element, len(buff)) + p.WriteInSlice(temp) + for i := 0; i < len(buff); i++ { + elem := temp[i] + buff[i].SetFromBase(&elem) + } + +} + func (p *PaddedCircularWindow) Pretty() string { return fmt.Sprintf("Windowed[totlen=%v offset=%v, paddingVal=%v, window=%v]", p.totLen, p.offset, p.paddingVal.String(), vector.Prettify(p.window)) } -func (p *PaddedCircularWindow) interval() circularInterval { - return ivalWithStartLen(p.offset, len(p.window), p.totLen) +func (p *PaddedCircularWindow) interval() CircularInterval { + return IvalWithStartLen(p.offset, len(p.window), p.totLen) } // normalize converts the (circle) coordinator x to another coordinate by changing @@ -209,7 +233,7 @@ func processWindowedOnly(op operator, svecs []SmartVector, coeffs_ []int) (res S // First we compute the union windows. length := svecs[0].Len() windows := []PaddedCircularWindow{} - intervals := []circularInterval{} + intervals := []CircularInterval{} coeffs := []int{} // Gather all the windows into a slice @@ -229,32 +253,33 @@ func processWindowedOnly(op operator, svecs []SmartVector, coeffs_ []int) (res S } // has the dimension of the cover with garbage values in it - smallestCover := smallestCoverInterval(intervals) + smallestCover := SmallestCoverInterval(intervals) // Edge-case: in case the smallest-cover of the pcw found in svecs is the // full-circle the code below will not work as it assumes that is possible - if smallestCover.isFullCircle() { + if smallestCover.IsFullCircle() { for i, svec := range svecs { if _, ok := svec.(*PaddedCircularWindow); ok { - svecs[i] = NewRegular(svec.IntoRegVecSaveAlloc()) + temp, _ := svec.IntoRegVecSaveAllocBase() + svecs[i] = NewRegular(temp) } } return nil, 0 } - // Sanity-check : normally all offset are normalized, this should ensure that start + // Sanity-check : normally all offset are normalized, this should ensure that Start // is positive. This is critical here because if some of the offset are not normalized // then we may end up with a union windows that does not make sense. - if smallestCover.start() < 0 { - utils.Panic("All offset should be normalized, but start is %v", smallestCover.start()) + if smallestCover.Start() < 0 { + utils.Panic("All offset should be normalized, but Start is %v", smallestCover.Start()) } // Ensures we do not reuse an input vector here to limit the risk of overwriting one // of the input. This can happen if there is only a single window or if one windows // covers all the other. - unionWindow := make([]field.Element, smallestCover.intervalLen) + unionWindow := make([]field.Element, smallestCover.IntervalLen) var paddedTerm field.Element - offset := smallestCover.start() + offset := smallestCover.Start() /* Now we actually compute the linear combinations for all offsets @@ -265,8 +290,8 @@ func processWindowedOnly(op operator, svecs []SmartVector, coeffs_ []int) (res S interval := intervals[i] // Find the intersection with the larger window - start_ := normalize(interval.start(), offset, length) - stop_ := normalize(interval.stop(), offset, length) + start_ := normalize(interval.Start(), offset, length) + stop_ := normalize(interval.Stop(), offset, length) if stop_ == 0 { stop_ = length } @@ -283,10 +308,10 @@ func processWindowedOnly(op operator, svecs []SmartVector, coeffs_ []int) (res S continue } - // sanity-check : start and stop are consistent with the size of pcw + // sanity-check : Start and Stop are consistent with the size of pcw if stop_-start_ != len(pcw.window) { utils.Panic( - "sanity-check failed. The renormalized coordinates (start=%v, stop=%v) are inconsistent with pcw : (len=%v)", + "sanity-check failed. The renormalized coordinates (Start=%v, Stop=%v) are inconsistent with pcw : (len=%v)", start_, stop_, len(pcw.window), ) } @@ -305,7 +330,7 @@ func processWindowedOnly(op operator, svecs []SmartVector, coeffs_ []int) (res S op.constIntoVec(unionWindow[stop_:], &pcw.paddingVal, coeffs[i]) } - if smallestCover.isFullCircle() { + if smallestCover.IsFullCircle() { return NewRegular(unionWindow), numMatches } @@ -320,5 +345,24 @@ func (w *PaddedCircularWindow) DeepCopy() SmartVector { // Converts a smart-vector into a normal vec. The implementation minimizes // then number of copies. func (w *PaddedCircularWindow) IntoRegVecSaveAlloc() []field.Element { - return IntoRegVec(w) + res, err := w.IntoRegVecSaveAllocBase() + if err != nil { + panic(conversionError) + } + return res + +} + +func (w *PaddedCircularWindow) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return IntoRegVec(w), nil +} + +func (w *PaddedCircularWindow) IntoRegVecSaveAllocExt() []fext.Element { + temp, _ := w.IntoRegVecSaveAllocBase() + res := make([]fext.Element, len(temp)) + for i := 0; i < len(temp); i++ { + elem := temp[i] + res[i].SetFromBase(&elem) + } + return res } diff --git a/prover/maths/common/smartvectorsext/arithmetic_basic_ext.go b/prover/maths/common/smartvectorsext/arithmetic_basic_ext.go new file mode 100644 index 000000000..1dc550f02 --- /dev/null +++ b/prover/maths/common/smartvectorsext/arithmetic_basic_ext.go @@ -0,0 +1,296 @@ +package smartvectorsext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/common/mempoolext" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// Add returns a smart-vector obtained by position-wise adding [SmartVector]. +// - all inputs `vecs` must have the same size, or the function panics +// - the output smart-vector has the same size as the input vectors +// - if no input vectors are provided, the function panics +func Add(vecs ...smartvectors.SmartVector) smartvectors.SmartVector { + + coeffs := make([]int, len(vecs)) + for i := range coeffs { + coeffs[i] = 1 + } + + return LinComb(coeffs, vecs) +} + +// Mul returns a smart-vector obtained by position-wise multiplying [SmartVector]. +// - all inputs `vecs` must have the same size, or the function panics +// - the output smart-vector has the same size as the input vectors +// - if no input vectors are provided, the function panics +func Mul(vecs ...smartvectors.SmartVector) smartvectors.SmartVector { + coeffs := make([]int, len(vecs)) + for i := range coeffs { + coeffs[i] = 1 + } + + return Product(coeffs, vecs) +} + +// ScalarMul returns a smart-vector obtained by multiplying a scalar with a [SmartVector]. +// - the output smart-vector has the same size as the input vector +func ScalarMul(vec smartvectors.SmartVector, x fext.Element) smartvectors.SmartVector { + xVec := NewConstantExt(x, vec.Len()) + return Mul(vec, xVec) +} + +// InnerProduct returns a scalar obtained as the inner-product of `a` and `b`. +// - a and b must have the same length, otherwise the function panics +func InnerProduct(a, b smartvectors.SmartVector) fext.Element { + if a.Len() != b.Len() { + panic("length mismatch") + } + + var res fext.Element + + for i := 0; i < a.Len(); i++ { + var tmp fext.Element + a_, b_ := a.GetExt(i), b.GetExt(i) + tmp.Mul(&a_, &b_) + res.Add(&res, &tmp) + } + + return res +} + +// PolyEval returns a [SmartVector] computed as: +// +// result = vecs[0] + vecs[1] * x + vecs[2] * x^2 + vecs[3] * x^3 + ... +// +// where `x` is a scalar and `vecs[i]` are [SmartVector] +func PolyEval(vecs []smartvectors.SmartVector, x fext.Element, p ...mempoolext.MemPool) (result smartvectors.SmartVector) { + + if len(vecs) == 0 { + panic("no input vectors") + } + + length := vecs[0].Len() + pool, hasPool := mempoolext.ExtractCheckOptionalStrict(length, p...) + + // Preallocate the intermediate values + var resReg, tmpVec []fext.Element + if !hasPool { + resReg = make([]fext.Element, length) + tmpVec = make([]fext.Element, length) + } else { + a := AllocFromPoolExt(pool) + b := AllocFromPoolExt(pool) + resReg, tmpVec = a.RegularExt, b.RegularExt + vectorext.Fill(resReg, fext.Zero()) + defer b.Free(pool) + } + + var tmpF, resCon fext.Element + var anyReg, anyCon bool + xPow := fext.One() + + accumulateReg := func(acc, v []fext.Element, x fext.Element) { + for i := 0; i < length; i++ { + tmpF.Mul(&v[i], &x) + acc[i].Add(&acc[i], &tmpF) + } + } + + // Computes the polynomial operation separately on the const, + // windows and regular and the aggregate the results at the end. + // The computation is done following horner's method. + for i := range vecs { + + v := vecs[i] + if asRotated, ok := v.(*RotatedExt); ok { + v = rotatedAsRegular(asRotated) + } + + switch casted := v.(type) { + case *ConstantExt: + anyCon = true + tmpF.Mul(&casted.val, &xPow) + resCon.Add(&resCon, &tmpF) + case *RegularExt: + anyReg = true + v := *casted + accumulateReg(resReg, v, xPow) + case *PooledExt: // e.g. from product + anyReg = true + v := casted.RegularExt + accumulateReg(resReg, v, xPow) + case *PaddedCircularWindowExt: + // treat it as a regular, reusing the buffer + anyReg = true + casted.WriteInSliceExt(tmpVec) + accumulateReg(resReg, tmpVec, xPow) + } + + xPow.Mul(&x, &xPow) + } + + switch { + case anyCon && anyReg: + for i := range resReg { + resReg[i].Add(&resReg[i], &resCon) + } + return NewRegularExt(resReg) + case anyCon && !anyReg: + // and we can directly unpool resreg because it was not used + if hasPool { + pool.Free(&resReg) + } + return NewConstantExt(resCon, length) + case !anyCon && anyReg: + return NewRegularExt(resReg) + } + + // can only happen if no vectors are found or if an unknow type is found + panic("unreachable") +} + +// BatchInvert performs the batch inverse operation over a [SmartVector] and +// returns a SmartVector of the same type. When an input element is zero, the +// function returns 0 at the corresponding position. +func BatchInvert(x smartvectors.SmartVector) smartvectors.SmartVector { + + switch v := x.(type) { + case *ConstantExt: + res := &ConstantExt{length: v.length} + res.val.Inverse(&v.val) + return res + case *PaddedCircularWindowExt: + res := &PaddedCircularWindowExt{ + totLen: v.totLen, + offset: v.offset, + window: fext.BatchInvert(v.window), + } + res.paddingVal.Inverse(&v.paddingVal) + return res + case *RotatedExt: + return NewRotatedExt( + fext.BatchInvert(v.v.RegularExt), + v.offset, + ) + case *PooledExt: + return NewRegularExt(fext.BatchInvert(v.RegularExt)) + case *RegularExt: + return NewRegularExt(fext.BatchInvert(*v)) + } + + panic("unsupported type") +} + +// IsZero returns a [SmartVector] z with the same type of structure than x such +// that x[i] = 0 => z[i] = 1 AND x[i] != 0 => z[i] = 0. +func IsZero(x smartvectors.SmartVector) smartvectors.SmartVector { + switch v := x.(type) { + + case *ConstantExt: + res := &ConstantExt{length: v.length} + if v.val == fext.Zero() { + res.val = fext.One() + } + return res + + case *PaddedCircularWindowExt: + res := &PaddedCircularWindowExt{ + totLen: v.totLen, + offset: v.offset, + window: make([]fext.Element, len(v.window)), + } + + if v.paddingVal == fext.Zero() { + res.paddingVal = fext.One() + } + + for i := range res.window { + if v.window[i] == fext.Zero() { + res.window[i] = fext.One() + } + } + return res + + case *RotatedExt: + res := make([]fext.Element, len(v.v.RegularExt)) + for i := range res { + if v.v.RegularExt[i] == fext.Zero() { + res[i] = fext.One() + } + } + return NewRotatedExt( + res, + v.offset, + ) + + case *RegularExt: + res := make([]fext.Element, len(*v)) + for i := range res { + if (*v)[i] == fext.Zero() { + res[i] = fext.One() + } + } + return NewRegularExt(res) + + case *PooledExt: + res := make([]fext.Element, len(v.RegularExt)) + for i := range res { + if v.RegularExt[i] == fext.Zero() { + res[i] = fext.One() + } + } + return NewRegularExt(res) + } + + panic("unsupported type") +} + +// Sum returns the field summation of all the elements contained in the vector +func Sum(a smartvectors.SmartVector) (res fext.Element) { + + switch v := a.(type) { + case *RegularExt: + res := fext.Zero() + for i := range *v { + res.Add(&res, &(*v)[i]) + } + return res + + case *PaddedCircularWindowExt: + res := fext.Zero() + for i := range v.window { + res.Add(&res, &v.window[i]) + } + constTerm := fext.NewElement(uint64(v.totLen-len(v.window)), 0) + constTerm.Mul(&constTerm, &v.paddingVal) + res.Add(&res, &constTerm) + return res + + case *ConstantExt: + res := fext.NewElement(uint64(v.length), 0) + res.Mul(&res, &v.val) + return res + + case *RotatedExt: + res := fext.Zero() + for i := range v.v.RegularExt { + res.Add(&res, &v.v.RegularExt[i]) + } + return res + + case *PooledExt: + res := fext.Zero() + for i := range v.RegularExt { + res.Add(&res, &v.RegularExt[i]) + } + return res + + default: + utils.Panic("unsupported type: %T", v) + } + + return res +} diff --git a/prover/maths/common/smartvectorsext/arithmetic_basic_test.go b/prover/maths/common/smartvectorsext/arithmetic_basic_test.go new file mode 100644 index 000000000..2eeb74f2f --- /dev/null +++ b/prover/maths/common/smartvectorsext/arithmetic_basic_test.go @@ -0,0 +1,125 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/stretchr/testify/assert" +) + +func TestBatchInvert(t *testing.T) { + + testCases := []smartvectors.SmartVector{ + NewConstantExt(fext.Zero(), 4), + NewConstantExt(fext.One(), 4), + ForTestExt(0, 1, 2, 3, 0, 0, 4, 4), + ForTestExt(0, 0, 0, 0), + ForTestExt(12, 13, 14, 15), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 2, 2)), 0), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 2, 2)), 1), + NewRotatedExt(RegularExt(vectorext.ForTest(1, 1, 2, 2)), 0), + NewRotatedExt(RegularExt(vectorext.ForTest(3, 3, 2, 2)), 1), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 0, 0)), 0), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 0, 0)), 1), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.Zero(), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.Zero(), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.Zero(), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.NewElement(42, 43), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.NewElement(42, 43), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.NewElement(42, 43), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.Zero(), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.Zero(), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.Zero(), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.NewElement(42, 43), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.NewElement(42, 43), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.NewElement(42, 43), 2, 8), + } + + for i := range testCases { + t.Run(fmt.Sprintf("testcase-%v", i), func(t *testing.T) { + + bi := BatchInvert(testCases[i]) + + assert.Equal(t, bi.Len(), testCases[i].Len()) + + for k := 0; k < bi.Len(); k++ { + var ( + x = bi.GetExt(k) + y = testCases[i].GetExt(k) + ) + + if y == fext.Zero() { + assert.Equal(t, fext.Zero(), x) + continue + } + + y.Inverse(&y) + assert.Equal(t, x, y) + } + }) + } +} + +func TestIsZero(t *testing.T) { + + testCases := []smartvectors.SmartVector{ + NewConstantExt(fext.Zero(), 4), + NewConstantExt(fext.One(), 4), + ForTestExt(0, 1, 2, 3, 0, 0, 4, 4), + ForTestExt(0, 0, 0, 0), + ForTestExt(12, 13, 14, 15), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 2, 2)), 0), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 2, 2)), 1), + NewRotatedExt(RegularExt(vectorext.ForTest(1, 1, 2, 2)), 0), + NewRotatedExt(RegularExt(vectorext.ForTest(3, 3, 2, 2)), 1), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 0, 0)), 0), + NewRotatedExt(RegularExt(vectorext.ForTest(0, 0, 0, 0)), 1), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.Zero(), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.Zero(), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.Zero(), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.NewElement(42, 43), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.NewElement(42, 43), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.NewElement(42, 43), 0, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.Zero(), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.Zero(), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.Zero(), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 0, 0), fext.NewElement(42, 43), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(0, 0, 1, 1), fext.NewElement(42, 43), 2, 8), + NewPaddedCircularWindowExt(vectorext.ForTest(1, 1, 2, 2), fext.NewElement(42, 43), 2, 8), + } + + for i := range testCases { + t.Run(fmt.Sprintf("testcase-%v", i), func(t *testing.T) { + + iz := IsZero(testCases[i]) + + assert.Equal(t, iz.Len(), testCases[i].Len()) + + for k := 0; k < iz.Len(); k++ { + var ( + x = iz.GetExt(k) + y = testCases[i].GetExt(k) + ) + + if y == fext.Zero() { + a, b := x.Uint64() + assert.Equal(t, uint64(1), a) + assert.Equal(t, uint64(0), b) + } + + if y != fext.Zero() { + a, b := x.Uint64() + assert.Equal(t, uint64(0), a) + assert.Equal(t, uint64(0), b) + } + + if t.Failed() { + t.Fatalf("failed at position %v for testcase %v", k, i) + } + } + }) + } +} diff --git a/prover/maths/common/smartvectorsext/arithmetic_ext_test.go b/prover/maths/common/smartvectorsext/arithmetic_ext_test.go new file mode 100644 index 000000000..79e8983b2 --- /dev/null +++ b/prover/maths/common/smartvectorsext/arithmetic_ext_test.go @@ -0,0 +1,337 @@ +//go:build !fuzzlight + +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/mempoolext" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestFuzzProduct(t *testing.T) { + + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForProd() + + success := t.Run(tcase.name, func(t *testing.T) { + actualProd := Product(tcase.coeffs, tcase.svecs) + require.Equal(t, tcase.expectedValue.Pretty(), actualProd.Pretty(), "product failed") + + // And let us do it a second time for idempotency + actualProd = Product(tcase.coeffs, tcase.svecs) + require.Equal(t, tcase.expectedValue.Pretty(), actualProd.Pretty(), "product failed") + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } + +} + +func TestFuzzLinComb(t *testing.T) { + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForLinComb() + + success := t.Run(tcase.name, func(t *testing.T) { + + actualLinComb := LinComb(tcase.coeffs, tcase.svecs) + require.Equal(t, tcase.expectedValue.Pretty(), actualLinComb.Pretty(), "linear combination failed") + + // And a second time for idempotency + actualLinComb = LinComb(tcase.coeffs, tcase.svecs) + require.Equal(t, tcase.expectedValue.Pretty(), actualLinComb.Pretty(), "linear combination failed") + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } +} + +func TestFuzzPolyEval(t *testing.T) { + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForPolyEval() + + success := t.Run(tcase.name, func(t *testing.T) { + + actualRes := PolyEval(tcase.svecs, tcase.evaluationPoint) + require.Equal(t, tcase.expectedValue.Pretty(), actualRes.Pretty(), "linear combination failed") + + // and a second time to ensure idempotency + actualRes = PolyEval(tcase.svecs, tcase.evaluationPoint) + require.Equal(t, tcase.expectedValue.Pretty(), actualRes.Pretty(), "linear combination failed") + + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } +} + +func TestFuzzProductWithPool(t *testing.T) { + + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForProd() + + success := t.Run(tcase.name, func(t *testing.T) { + + pool := mempoolext.CreateFromSyncPool(tcase.svecs[0].Len()) + + t.Logf("TEST CASE %v\n", tcase.String()) + + prodWithPool := Product(tcase.coeffs, tcase.svecs, pool) + require.Equal(t, tcase.expectedValue.Pretty(), prodWithPool.Pretty(), "product with pool failed") + + // And let us do it a second time for idempotency + prodWithPool = Product(tcase.coeffs, tcase.svecs, pool) + require.Equal(t, tcase.expectedValue.Pretty(), prodWithPool.Pretty(), "product with pool failed") + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } + +} + +func TestFuzzProductWithPoolCompare(t *testing.T) { + + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForProd() + + success := t.Run(tcase.name, func(t *testing.T) { + + pool := mempoolext.CreateFromSyncPool(tcase.svecs[0].Len()) + + t.Logf("TEST CASE %v\n", tcase.String()) + + // Product() with pool + prodWithPool := Product(tcase.coeffs, tcase.svecs, pool) + require.Equal(t, tcase.expectedValue.Pretty(), prodWithPool.Pretty(), "Product() with pool failed") + + // Product() without pool + prod := Product(tcase.coeffs, tcase.svecs) + + // check if Product() with pool = Product() without pool + require.Equal(t, prodWithPool.Pretty(), prod.Pretty(), "Product() w/ and w/o pool are different") + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } + +} + +func TestFuzzLinCombWithPool(t *testing.T) { + + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForLinComb() + + success := t.Run(tcase.name, func(t *testing.T) { + + pool := mempoolext.CreateFromSyncPool(tcase.svecs[0].Len()) + + t.Logf("TEST CASE %v\n", tcase.String()) + + linCombWithPool := LinComb(tcase.coeffs, tcase.svecs, pool) + require.Equal(t, tcase.expectedValue.Pretty(), linCombWithPool.Pretty(), "LinComb() with pool failed") + + // And let us do it a second time for idempotency + linCombWithPool = LinComb(tcase.coeffs, tcase.svecs, pool) + require.Equal(t, tcase.expectedValue.Pretty(), linCombWithPool.Pretty(), "LinComb() with pool failed") + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } +} + +func TestFuzzLinCombWithPoolCompare(t *testing.T) { + + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForLinComb() + + success := t.Run(tcase.name, func(t *testing.T) { + + pool := mempoolext.CreateFromSyncPool(tcase.svecs[0].Len()) + + t.Logf("TEST CASE %v\n", tcase.String()) + + // LinComb() with pool + linCombWithPool := LinComb(tcase.coeffs, tcase.svecs, pool) + require.Equal(t, tcase.expectedValue.Pretty(), linCombWithPool.Pretty(), "LinComb() with pool failed") + + // LinComb() without pool + linComb := LinComb(tcase.coeffs, tcase.svecs) + + // check if LinComb() with pool = LinComb() without pool + require.Equal(t, linCombWithPool.Pretty(), linComb.Pretty(), "LinComb() w/ and w/o pool are different") + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } +} + +func TestOpBasicEdgeCases(t *testing.T) { + + two := fext.NewElement(2, fieldPaddingInt()) + + testCases := []struct { + explainer string + inputs []smartvectors.SmartVector + expectedRes smartvectors.SmartVector + fn func(...smartvectors.SmartVector) smartvectors.SmartVector + }{ + { + explainer: "full-covering windows and a constant", + inputs: []smartvectors.SmartVector{ + NewConstantExt(two, 16), + LeftPadded(vectorext.Repeat(two, 12), two, 16), + RightPadded(vectorext.Repeat(two, 12), two, 16), + }, + expectedRes: NewRegularExt(vectorext.Repeat(fext.NewElement(6, fieldPaddingInt()), 16)), + fn: Add, + }, + { + explainer: "full-covering windows and a constant (mul)", + inputs: []smartvectors.SmartVector{ + NewConstantExt(two, 16), + LeftPadded(vectorext.Repeat(two, 12), two, 16), + RightPadded(vectorext.Repeat(two, 12), two, 16), + }, + expectedRes: NewRegularExt(vectorext.Repeat(fext.NewElement(8, fieldPaddingInt()), 16)), + fn: Mul, + }, + { + explainer: "full-covering windows, a regular and a constant", + inputs: []smartvectors.SmartVector{ + NewConstantExt(two, 16), + LeftPadded(vectorext.Repeat(two, 12), two, 16), + RightPadded(vectorext.Repeat(two, 12), two, 16), + NewRegularExt(vectorext.Repeat(two, 16)), + }, + expectedRes: NewRegularExt(vectorext.Repeat(fext.NewElement(8, fieldPaddingInt()), 16)), + fn: Add, + }, + } + + for i, testCase := range testCases { + t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) { + t.Logf("test-case details: %v", testCase.explainer) + res := testCase.fn(testCase.inputs...).(*PooledExt) + actual := NewRegularExt(res.RegularExt) + require.Equal(t, testCase.expectedRes, actual, "expectedRes=%v\nres=%v", testCase.expectedRes.Pretty(), res.Pretty()) + }) + } +} + +func TestInnerProduct(t *testing.T) { + testCases := []struct { + a, b smartvectors.SmartVector + y fext.Element + }{ + { + a: ForTestExt(1, 2, 1, 2, 1), + b: ForTestExt(1, -1, 2, -1, 2), + y: fext.NewElement(1, fieldPaddingInt()), + }, + } + + for i, testCase := range testCases { + t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) { + y := InnerProduct(testCase.a, testCase.b) + assert.Equal(t, testCase.y.String(), y.String()) + }) + } +} + +func TestScalarMul(t *testing.T) { + testCases := []struct { + a smartvectors.SmartVector + b fext.Element + y smartvectors.SmartVector + }{ + { + a: ForTestExt(1, 2, 1, 2, 1), + b: fext.NewElement(3, fieldPaddingInt()), + y: ForTestExt(3, 6, 3, 6, 3), + }, + } + + for i, testCase := range testCases { + t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) { + y := ScalarMul(testCase.a, testCase.b) + assert.Equal(t, testCase.y.Pretty(), y.Pretty()) + }) + } +} + +func TestFuzzPolyEvalWithPool(t *testing.T) { + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForPolyEval() + + success := t.Run(tcase.name, func(t *testing.T) { + + pool := mempoolext.CreateFromSyncPool(tcase.svecs[0].Len()) + + // PolyEval() with pool + polyEvalWithPool := PolyEval(tcase.svecs, tcase.evaluationPoint, pool) + require.Equal(t, tcase.expectedValue.Pretty(), polyEvalWithPool.Pretty(), "linear combination with pool failed") + + // and a second time to ensure idempotency + polyEvalWithPool = PolyEval(tcase.svecs, tcase.evaluationPoint, pool) + require.Equal(t, tcase.expectedValue.Pretty(), polyEvalWithPool.Pretty(), "linear combination with pool failed") + + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } +} + +func TestFuzzPolyEvalWithPoolCompare(t *testing.T) { + for i := 0; i < fuzzIteration; i++ { + tcase := newTestBuilder(i).NewTestCaseForPolyEval() + + success := t.Run(tcase.name, func(t *testing.T) { + + pool := mempoolext.CreateFromSyncPool(tcase.svecs[0].Len()) + + // PolyEval() with pool + polyEvalWithPool := PolyEval(tcase.svecs, tcase.evaluationPoint, pool) + require.Equal(t, tcase.expectedValue.Pretty(), polyEvalWithPool.Pretty(), "PolyEval() with pool failed") + + // PolyEval() without pool + polyEval := PolyEval(tcase.svecs, tcase.evaluationPoint) + + // check if PolyEval() with pool = PolyEval() without pool + require.Equal(t, polyEvalWithPool.Pretty(), polyEval.Pretty(), "PolyEval() w/ and w/o pool are different") + + }) + + if !success { + t.Logf("TEST CASE %v\n", tcase.String()) + t.FailNow() + } + } +} diff --git a/prover/maths/common/smartvectorsext/arithmetic_gen.go b/prover/maths/common/smartvectorsext/arithmetic_gen.go new file mode 100644 index 000000000..86bd10cee --- /dev/null +++ b/prover/maths/common/smartvectorsext/arithmetic_gen.go @@ -0,0 +1,175 @@ +package smartvectorsext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/common/mempoolext" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// LinComb computes a linear combination of the given vectors with integer coefficients. +// - The function panics if provided SmartVector of different lengths +// - The function panics if svecs is empty +// - The function panics if the length of coeffs does not match the length of +// svecs +func LinComb(coeffs []int, svecs []smartvectors.SmartVector, p ...mempoolext.MemPool) smartvectors.SmartVector { + // Sanity check : all svec should have the same length + length := svecs[0].Len() + for i := 0; i < len(svecs); i++ { + if svecs[i].Len() != length { + utils.Panic("bad size %v, expected %v", svecs[i].Len(), length) + } + } + return processOperator(linCombOp{}, coeffs, svecs, p...) +} + +// Product computes a product of smart-vectors with integer exponents +// - The function panics if provided SmartVector of different lengths +// - The function panics if svecs is empty +// - The function panics if the length of exponents does not match the length of +// svecs +func Product(exponents []int, svecs []smartvectors.SmartVector, p ...mempoolext.MemPool) smartvectors.SmartVector { + return processOperator(productOp{}, exponents, svecs, p...) +} + +// processOperator computes the result of an [operator] and put the result into res +// - The function panics if provided SmartVector of different lengths +// - The function panics if svecs is empty +// - The function panics if the length of coeffs does not match the length of +// svecs +func processOperator(op operator, coeffs []int, svecs []smartvectors.SmartVector, p ...mempoolext.MemPool) smartvectors.SmartVector { + + // There should be as many coeffs than there are vectors + if len(coeffs) != len(svecs) { + utils.Panic("there are %v coeffs and %v vectors", len(coeffs), len(svecs)) + } + + // Sanity-check to ensure there is at least one vector to lincombine + if len(svecs) == 0 { + utils.Panic("no vector to process") + } + + // Total number of vector passed as operands* + totalToMatch := len(svecs) + + // Sanity-check, they should all have the same length + for i := range svecs { + assertHasLength(svecs[0].Len(), svecs[i].Len()) + } + + // Sanity-check, length zero or negative should be forbidden + assertStrictPositiveLen(svecs[0].Len()) + + // Accumulate the constant + constRes, matchedConst := processConstOnly(op, svecs, coeffs) + + // Full-constant operation, return the constant vec + if matchedConst == totalToMatch { + return constRes + } + + // Special-case : if the operation is a product and the constRes is + // zero, we can early return zero ignoring the rest. + if _, ok := op.(productOp); ok && constRes != nil && constRes.val.IsZero() { + return constRes + } + + // Accumulate the windowed smart-vectors + windowRes, matchedWindow := processWindowedOnly(op, svecs, coeffs) + + // Edge-case : the list of smart-vectors to combine is windowed-only. In + // this case we can return directly. + if matchedWindow == totalToMatch { + return windowRes + } + + // If we had matches for both constants vectors and the windows, we merge + // the constant into the window. + if matchedWindow > 0 && matchedConst > 0 { + switch w := windowRes.(type) { + case *PaddedCircularWindowExt: + op.constTermIntoVec(w.window, &constRes.val) + op.constTermIntoConst(&w.paddingVal, &constRes.val) + case *RegularExt: + op.constTermIntoVec(*w, &constRes.val) + } + } + + // Edge-case : all vectors in the list are either window or constants + if matchedWindow+matchedConst == totalToMatch { + return windowRes + } + + // Accumulate the regular part of the vector + regularRes, matchedRegular := processRegularOnlyExt(op, svecs, coeffs, p...) + + // Sanity-check : all of the vector should fall into only one of the two + // category. + if matchedConst+matchedWindow+matchedRegular != totalToMatch { + utils.Panic("Mismatch between the number of matched vector and the total number of vectors (%v + %v + %v = %v)", matchedConst, matchedWindow, matchedRegular, totalToMatch) + } + + switch { + case matchedRegular == totalToMatch: + return regularRes + case matchedRegular+matchedConst == totalToMatch: + // In this case, there are no windowed in the list. This means we only + // need to merge the const one into the regular one before returning + op.constTermIntoVec(regularRes.RegularExt, &constRes.val) + return regularRes + default: + + // If windowRes is a regular (can happen if all windows arguments cover the full circle) + if w, ok := windowRes.(*RegularExt); ok { + op.vecTermIntoVec(regularRes.RegularExt, *w) + return regularRes + } + + // Overwrite window with its casting into an actual circular windows + windowRes := windowRes.(*PaddedCircularWindowExt) + + // In this case, the constant is already accumulated into the windowed. + // Thus, we just have to merge the windowed one into the regular one. + interval := windowRes.interval() + regvec := regularRes.RegularExt + length := len(regvec) + + // The windows rolls over + if interval.DoesWrapAround() { + op.vecTermIntoVec(regvec[:interval.Stop()], windowRes.window[length-interval.Start():]) + op.vecTermIntoVec(regvec[interval.Start():], windowRes.window[:length-interval.Start()]) + op.constTermIntoVec(regvec[interval.Stop():interval.Start()], &windowRes.paddingVal) + return regularRes + } + + // Else, no roll-over + op.vecTermIntoVec(regvec[interval.Start():interval.Stop()], windowRes.window) + op.constTermIntoVec(regvec[:interval.Start()], &windowRes.paddingVal) + op.constTermIntoVec(regvec[interval.Stop():], &windowRes.paddingVal) + return regularRes + } +} + +// Returns the result of the linear combination including only the constant. numMatches denotes +// the number of Constant smart-vectors found in the list of arguments. +func processConstOnly(op operator, svecs []smartvectors.SmartVector, coeffs []int) (constRes *ConstantExt, numMatches int) { + var constVal fext.Element + for i, svec := range svecs { + if cnst, ok := svec.(*ConstantExt); ok { + if numMatches < 1 { + // First one, no need to add it into constVal since constVal is zero + op.constIntoTerm(&constVal, &cnst.val, coeffs[i]) + numMatches++ + continue + } + op.constIntoConst(&constVal, &cnst.val, coeffs[i]) + numMatches++ + } + } + + if numMatches == 0 { + return nil, 0 + } + + return &ConstantExt{val: constVal, length: svecs[0].Len()}, numMatches +} diff --git a/prover/maths/common/smartvectorsext/arithmetic_op.go b/prover/maths/common/smartvectorsext/arithmetic_op.go new file mode 100644 index 000000000..964787fa3 --- /dev/null +++ b/prover/maths/common/smartvectorsext/arithmetic_op.go @@ -0,0 +1,304 @@ +package smartvectorsext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "math/big" +) + +// operator represents a mathematical operation that can be performed between +// scalars and integers. It is implemented by [linCombOp] and [productOp]. The +// operator interface allows applying the operator in all the combination of +// scalars or vectors operands, in immutable version or assigning version. +// +// In the terminology of this interface: +// - "const" means a scalar. Or equivalently, abstractly, a vector whose all +// coordinates have the same value. +// - "vec" means a slice of field element +// - "term" is a couple (const|vec, coeff) +// - "coeff" means either a linear combination coefficient or an exponent and +// is always assumed to be reasonnably small +// +// The reason to resort to this interface is because applying n-ary mathematical +// operator to smart-vector comes with a lot of inherent complexity. This is +// mitigated that we have a single function [processOperator] owning all the +// "smartvector" logic and all the logic pertaining to doing additions, +// multiplication etc.. is implemented by the [operator] interface. +type operator interface { + // constIntoConst applies the operator over `res` and `(c, coeff)` and sets + // the result into res. This is specialized for the case where both res and + // x are scalars. + // + // res += x * coeff or res *= x^coeff + constIntoConst(res, x *fext.Element, coeff int) + // vecIntoVec applies the operator over `res` and `(c, coeff)` and sets + // the result into res. This is specialized for the case where both res and + // x are vectors. + // + // res += x * coeff or res *= x^coeff + vecIntoVec(res, x []fext.Element, coeff int) + // VecIntoVec applies the operator over `res` and `(c, coeff)` and sets + // the result into res. This is specialized for the case where res is a + // vector and c is a constant. + // + // res += x * coeff or res *= x^coeff + constIntoVec(res []fext.Element, x *fext.Element, coeff int) + // constIntoTerm evaluates the operator over (x, coeff) and sets the result + // into `res`, overwriting it. + // It is specialized for the case where x and res are both scalars. + // + // res = x * coeff or res = x^coeff + constIntoTerm(res, x *fext.Element, coeff int) + // vecIntoTerm evaluates the operator over (x, coeff) and sets the result + // into `res`, overwriting it. + // It is specialized for the case where x and res are both vectors. + // + // res = x * coeff or res = x^coeff where x is a vector + vecIntoTerm(res, x []fext.Element, coeff int) + // constTermIntoConst updates applies the operator over res and term and + // sets the result into res. + // This function is specialized for the case where the term and res are + // scalar. + // + // res += term or res *= term for constants + constTermIntoConst(res, term *fext.Element) + // vecTermIntoVec updates applies the operator over res and term and + // sets the result into res. + // This function is specialized for the case where the term and res are + // vector. + // + // res += term or res *= term + vecTermIntoVec(res, term []fext.Element) + // constTermIntoVec updates a vector `res` by applying the operator over + // it + // + // res += term or res *= term + constTermIntoVec(res []fext.Element, term *fext.Element) +} + +// linCompOp is an implementation of the [operator] interface. It represents a +// linear combination with coefficients. +type linCombOp struct{} + +func (linCombOp) constIntoConst(res, x *fext.Element, coeff int) { + switch coeff { + case 1: + res.Add(res, x) + case -1: + res.Sub(res, x) + case 2: + res.Add(res, x).Add(res, x) + case -2: + res.Sub(res, x).Sub(res, x) + default: + var c fext.Element + c.SetInt64(int64(coeff)) + c.Mul(&c, x) + res.Add(res, &c) + } +} + +func (linCombOp) vecIntoVec(res, x []fext.Element, coeff int) { + // Sanity-check + assertHasLength(len(res), len(x)) + switch coeff { + case 1: + vectorext.Add(res, res, x) + case -1: + vectorext.Sub(res, res, x) + case 2: + for i := range res { + res[i].Add(&res[i], &x[i]).Add(&res[i], &x[i]) + } + case -2: + for i := range res { + res[i].Sub(&res[i], &x[i]).Sub(&res[i], &x[i]) + } + default: + var c, tmp fext.Element + c.SetInt64(int64(coeff)) + for i := range res { + tmp.Mul(&c, &x[i]) + res[i].Add(&res[i], &tmp) + } + } +} + +func (linCombOp) constIntoVec(res []fext.Element, val *fext.Element, coeff int) { + var term fext.Element + linCombOp.constIntoTerm(linCombOp{}, &term, val, coeff) + linCombOp.constTermIntoVec(linCombOp{}, res, &term) +} + +func (linCombOp) vecIntoTerm(term, x []fext.Element, coeff int) { + switch coeff { + case 1: + copy(term, x) + case -1: + for i := range term { + term[i].Neg(&x[i]) + } + case 2: + vectorext.Add(term, x, x) + case -2: + for i := range term { + term[i].Add(&x[i], &x[i]).Neg(&term[i]) + } + default: + var c fext.Element + c.SetInt64(int64(coeff)) + for i := range term { + term[i].Mul(&c, &x[i]) + } + } +} + +func (linCombOp) constIntoTerm(term, x *fext.Element, coeff int) { + switch coeff { + case 1: + term.Set(x) + case -1: + term.Neg(x) + case 2: + term.Add(x, x) + case -2: + term.Add(x, x).Neg(term) + default: + var c fext.Element + c.SetInt64(int64(coeff)) + term.Mul(&c, x) + } +} + +func (linCombOp) constTermIntoConst(res, term *fext.Element) { + res.Add(res, term) +} + +func (linCombOp) vecTermIntoVec(res, term []fext.Element) { + vectorext.Add(res, res, term) +} + +func (linCombOp) constTermIntoVec(res []fext.Element, term *fext.Element) { + for i := range res { + res[i].Add(&res[i], term) + } +} + +type productOp struct{} + +// res *= x ^coeff where both res and x are constants +func (productOp) constIntoConst(res, x *fext.Element, coeff int) { + switch coeff { + case 0: + // Nothing to do + case 1: + res.Mul(res, x) + case 2: + res.Mul(res, x).Mul(res, x) + case 3: + var tmp fext.Element + tmp.Square(x) + tmp.Mul(&tmp, x) + res.Mul(res, &tmp) + default: + var tmp fext.Element + tmp.Exp(*x, big.NewInt(int64(coeff))) + res.Mul(res, &tmp) + } +} + +// res *= x ^coeff where both res and x are vectors +func (productOp) vecIntoVec(res, x []fext.Element, coeff int) { + + // Sanity-check + assertHasLength(len(res), len(x)) + + switch coeff { + case 0: + // Nothing to do + case 1: + vectorext.MulElementWise(res, res, x) + case 2: + for i := range res { + res[i].Mul(&res[i], &x[i]).Mul(&res[i], &x[i]) + } + case 3: + for i := range res { + var tmp fext.Element + tmp.Square(&x[i]) + tmp.Mul(&tmp, &x[i]) + res[i].Mul(&res[i], &tmp) + } + default: + var tmp fext.Element + for i := range res { + fext.ExpToInt(&tmp, x[i], coeff) + res[i].Mul(&res[i], &tmp) + } + } +} + +// res *= x ^coeff where res is a vector and x is a constant +func (productOp) constIntoVec(res []fext.Element, x *fext.Element, coeff int) { + var term fext.Element + productOp.constIntoTerm(productOp{}, &term, x, coeff) + productOp.constTermIntoVec(productOp{}, res, &term) +} + +// res = x ^ coeff where x is a constant +func (productOp) constIntoTerm(res, x *fext.Element, coeff int) { + switch coeff { + case 0: + res.SetOne() + case 1: + res.Set(x) + case 2: + res.Square(x) + case 3: + var tmp fext.Element + tmp.Square(x) + res.Mul(&tmp, x) + default: + res.Exp(*x, big.NewInt(int64(coeff))) + } +} + +// res = x * coeff or res = x ^ coeff where x is a vector +func (productOp) vecIntoTerm(res, x []fext.Element, coeff int) { + switch coeff { + case 0: + vectorext.Fill(res, fext.One()) + case 1: + copy(res, x) + case 2: + vectorext.MulElementWise(res, x, x) + case 3: + for i := range res { + // Creating a new variable for the case where res and x are the same variable + var tmp fext.Element + tmp.Square(&x[i]) + res[i].Mul(&tmp, &x[i]) + } + default: + c := big.NewInt(int64(coeff)) + for i := range res { + res[i].Exp(x[i], c) + } + } +} + +// res += term or res *= coeff for constants +func (productOp) constTermIntoConst(res, term *fext.Element) { + res.Mul(res, term) +} + +// res += term for vectors +func (productOp) vecTermIntoVec(res, term []fext.Element) { + vectorext.MulElementWise(res, res, term) + +} + +// res += term where res is a vector and term is a constant +func (productOp) constTermIntoVec(res []fext.Element, term *fext.Element) { + vectorext.ScalarMul(res, res, *term) +} diff --git a/prover/maths/common/smartvectorsext/constant_ext.go b/prover/maths/common/smartvectorsext/constant_ext.go new file mode 100644 index 000000000..502cfa4b1 --- /dev/null +++ b/prover/maths/common/smartvectorsext/constant_ext.go @@ -0,0 +1,97 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// A constant vector is a vector obtained by repeated "length" time the same value +type ConstantExt struct { + val fext.Element + length int +} + +// Construct a new "Constant" smart-vector +func NewConstantExt(val fext.Element, length int) *ConstantExt { + if length <= 0 { + utils.Panic("zero or negative length are not allowed") + } + return &ConstantExt{val: val, length: length} +} + +// Return the length of the smart-vector +func (c *ConstantExt) Len() int { return c.length } + +// Returns an entry of the constant +func (c *ConstantExt) GetBase(int) (field.Element, error) { + return field.Zero(), fmt.Errorf(conversionError) +} + +func (c *ConstantExt) GetExt(int) fext.Element { return c.val } + +func (r *ConstantExt) Get(n int) field.Element { + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res +} + +// Returns a subvector +func (c *ConstantExt) SubVector(start, stop int) smartvectors.SmartVector { + if start > stop { + utils.Panic("negative length are not allowed") + } + if start == stop { + utils.Panic("zero length are not allowed") + } + assertCorrectBound(start, c.length) + // The +1 is because we accept if "stop = length" + assertCorrectBound(stop, c.length+1) + return NewConstantExt(c.val, stop-start) +} + +// Returns a rotated version of the slice +func (c *ConstantExt) RotateRight(int) smartvectors.SmartVector { + return NewConstantExt(c.val, c.length) +} + +// Write the constant vector in a slice +func (c *ConstantExt) WriteInSlice(s []field.Element) { + panic(conversionError) +} + +func (c *ConstantExt) WriteInSliceExt(s []fext.Element) { + for i := 0; i < len(s); i++ { + s[i].Set(&c.val) + } +} + +func (c *ConstantExt) Val() fext.Element { + return c.val +} + +func (c *ConstantExt) Pretty() string { + return fmt.Sprintf("Constant[%v;%v]", c.val.String(), c.length) +} + +func (c *ConstantExt) DeepCopy() smartvectors.SmartVector { + return NewConstantExt(c.val, c.length) +} + +func (c *ConstantExt) IntoRegVecSaveAlloc() []field.Element { + panic(conversionError) +} + +func (c *ConstantExt) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return nil, fmt.Errorf(conversionError) +} + +func (c *ConstantExt) IntoRegVecSaveAllocExt() []fext.Element { + res := smartvectors.IntoRegVecExt(c) + return res +} diff --git a/prover/maths/common/smartvectorsext/fuzzing_ext.go b/prover/maths/common/smartvectorsext/fuzzing_ext.go new file mode 100644 index 000000000..4a5cd0bfb --- /dev/null +++ b/prover/maths/common/smartvectorsext/fuzzing_ext.go @@ -0,0 +1,335 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/polyext" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "math/big" + "math/rand" + + "github.com/consensys/linea-monorepo/prover/utils" +) + +type smartVecType int + +// The order matters : combining type x with type y implies that the result +// will be of type max(x, y) +const ( + constantT smartVecType = iota + windowT + RegularExtT + RotatedExtT +) + +var smartVecTypeList = []smartVecType{constantT, windowT, RegularExtT, RotatedExtT} + +type testCase struct { + name string + svecs []smartvectors.SmartVector + coeffs []int + expectedValue smartvectors.SmartVector + evaluationPoint fext.Element // Only used for polynomial evaluation +} + +func (tc testCase) String() string { + res := "Testcase:\n" + res += "\tSVECS:\n" + for i := range tc.svecs { + res += fmt.Sprintf("\t\t %v : %v\n", i, tc.svecs[i].Pretty()) + } + res += fmt.Sprintf("\tCOEFFs: %v\n", tc.coeffs) + res += fmt.Sprintf("\tEXPECTED_VALUE: %v\n", tc.expectedValue.Pretty()) + return res +} + +type testCaseGen struct { + // Randomness parameters + seed int + gen *rand.Rand + // Length and number of target vectors + fullLen, numVec int + // Parameters relevant for creating windows. This enforces the windows + // to be included in a certain (which can possible roll over fullLen) + windowWithLen int + windowMustStartAfter int + // Allowed smart-vector types for this testcase + allowedTypes []smartVecType +} + +func newTestBuilder(seed int) *testCaseGen { + // Use a deterministic randomness source + res := &testCaseGen{seed: seed} + // #nosec G404 --we don't need a cryptographic RNG for fuzzing purpose + res.gen = rand.New(rand.NewSource(int64(seed))) + + // We should have some quarantee that the length is not too small + // for the test generation + res.fullLen = 1 << (res.gen.Intn(5) + 3) + res.numVec = res.gen.Intn(8) + 1 + + // In the test, we may restrict the inputs vectors to have a certain type + allowedTypes := append([]smartVecType{}, smartVecTypeList...) + res.gen.Shuffle(len(allowedTypes), func(i, j int) { + allowedTypes[i], allowedTypes[j] = allowedTypes[j], allowedTypes[i] + }) + res.allowedTypes = allowedTypes[:res.gen.Intn(len(allowedTypes)-1)+1] + + // Generating the window : it should be roughly half of the total length + // this aims at maximizing the coverage. + res.windowWithLen = res.gen.Intn(res.fullLen-4)/2 + 2 + res.windowMustStartAfter = res.gen.Intn(res.fullLen) + return res +} + +func (gen *testCaseGen) NewTestCaseForProd() (tcase testCase) { + + tcase.name = fmt.Sprintf("fuzzy-with-seed-%v-prod", gen.seed) + tcase.svecs = make([]smartvectors.SmartVector, gen.numVec) + tcase.coeffs = make([]int, gen.numVec) + + // resVal will contain the value of the repeated in the expected result + // we will compute its value as we instantiate test vectors. + resVal := fext.One() + maxType := constantT + + // For the windows, we need to track the dimension of the windows + winMinStart := gen.fullLen + winMaxStop := 0 + + // Has constant vec keeps track of the case where we incluse a constant + // vector equal to zero in the testcases + hasConstZero := false + + for i := 0; i < gen.numVec; i++ { + // Generate one by one the different vectors + val := gen.genValue() + tcase.coeffs[i] = gen.gen.Intn(5) + chosenType := gen.allowedTypes[gen.gen.Intn(len(gen.allowedTypes))] + maxType = utils.Max(maxType, chosenType) + + // Update the expected res value + var tmp fext.Element + tmp.Exp(val, big.NewInt(int64(tcase.coeffs[i]))) + resVal.Mul(&resVal, &tmp) + + switch chosenType { + case constantT: + // Our implementation uses the convention that 0^0 == 0 + // Even though, this case is avoided by the calling code. + if val.IsZero() && tcase.coeffs[i] != 0 { + hasConstZero = true + } + tcase.svecs[i] = NewConstantExt(val, gen.fullLen) + case windowT: + v := gen.genWindow(val, val) + tcase.svecs[i] = v + start := normalize(v.interval().Start(), gen.windowMustStartAfter, gen.fullLen) + winMinStart = utils.Min(winMinStart, start) + stop := normalize(v.interval().Stop(), gen.windowMustStartAfter, gen.fullLen) + if stop < start { + stop += gen.fullLen + } + winMaxStop = utils.Max(winMaxStop, stop) + case RegularExtT: + tcase.svecs[i] = gen.genRegularExt(val) + case RotatedExtT: + tcase.svecs[i] = gen.genRotatedExt(val) + } + } + + // If there are no windows, then the initial condition that we use + // do pass this sanity-check + if winMaxStop-winMinStart > gen.windowWithLen { + utils.Panic("inconsistent window dimension %v %v with gen %++v", winMinStart, winMaxStop, gen) + } + + // This switch statement resolves the type of smart-vector that we are + // expected for the result. It crucially relies on the number associated + // to the variants of the smartVecTypes enum. + switch { + case hasConstZero: + tcase.expectedValue = NewConstantExt(fext.Zero(), gen.fullLen) + case maxType == constantT: + tcase.expectedValue = NewConstantExt(resVal, gen.fullLen) + case maxType == windowT: + tcase.expectedValue = NewPaddedCircularWindowExt( + vectorext.Repeat(resVal, winMaxStop-winMinStart), + resVal, + normalize(winMinStart, -gen.windowMustStartAfter, gen.fullLen), + gen.fullLen, + ) + case maxType == RegularExtT || maxType == RotatedExtT: + tcase.expectedValue = NewRegularExt(vectorext.Repeat(resVal, gen.fullLen)) + } + + return tcase +} + +func (gen *testCaseGen) NewTestCaseForLinComb() (tcase testCase) { + + tcase.name = fmt.Sprintf("fuzzy-with-seed-%v-lincomb", gen.seed) + tcase.svecs = make([]smartvectors.SmartVector, gen.numVec) + tcase.coeffs = make([]int, gen.numVec) + + // resVal will contain the value of the repeated in the expected result + // we will compute its value as we instantiate test vectors. + resVal := fext.Zero() + maxType := constantT + + // For the windows, we need to track the dimension of the windows + winMinStart := gen.fullLen + winMaxStop := 0 + + for i := 0; i < gen.numVec; i++ { + // Generate one by one the different vectors + val := gen.genValue() + tcase.coeffs[i] = gen.gen.Intn(10) - 5 + chosenType := gen.allowedTypes[gen.gen.Intn(len(gen.allowedTypes))] + maxType = utils.Max(maxType, chosenType) + + // Update the expected res value + var tmp, coeffField fext.Element + coeffField.SetInt64(int64(tcase.coeffs[i])) + tmp.Mul(&val, &coeffField) + resVal.Add(&resVal, &tmp) + + switch chosenType { + case constantT: + tcase.svecs[i] = NewConstantExt(val, gen.fullLen) + case windowT: + v := gen.genWindow(val, val) + tcase.svecs[i] = v + start := normalize(v.interval().Start(), gen.windowMustStartAfter, gen.fullLen) + winMinStart = utils.Min(winMinStart, start) + + stop := normalize(v.interval().Stop(), gen.windowMustStartAfter, gen.fullLen) + if stop < start { + stop += gen.fullLen + } + winMaxStop = utils.Max(winMaxStop, stop) + case RegularExtT: + tcase.svecs[i] = gen.genRegularExt(val) + case RotatedExtT: + tcase.svecs[i] = gen.genRotatedExt(val) + } + } + + // If there are no windows, then the initial condition that we use + // do pass this sanity-check + if winMaxStop-winMinStart > gen.windowWithLen { + utils.Panic("inconsistent window dimension %v %v with gen %++v", winMinStart, winMaxStop, gen) + } + + switch { + case maxType == constantT: + tcase.expectedValue = NewConstantExt(resVal, gen.fullLen) + case maxType == windowT: + tcase.expectedValue = NewPaddedCircularWindowExt( + vectorext.Repeat(resVal, winMaxStop-winMinStart), + resVal, + normalize(winMinStart, -gen.windowMustStartAfter, gen.fullLen), + gen.fullLen, + ) + case maxType == RegularExtT || maxType == RotatedExtT: + tcase.expectedValue = NewRegularExt(vectorext.Repeat(resVal, gen.fullLen)) + } + + return tcase +} + +func (gen *testCaseGen) NewTestCaseForPolyEval() (tcase testCase) { + + tcase.name = fmt.Sprintf("fuzzy-with-seed-%v-poly-eval", gen.seed) + tcase.svecs = make([]smartvectors.SmartVector, gen.numVec) + tcase.coeffs = make([]int, gen.numVec) + tcase.evaluationPoint.SetRandom() + x := tcase.evaluationPoint + vals := []fext.Element{} + + // MaxType is used to determine what type should the result be + maxType := constantT + + // For the windows, we need to track the dimension of the windows + winMinStart := gen.fullLen + winMaxStop := 0 + + for i := 0; i < gen.numVec; i++ { + // Generate one by one the different vectors + val := gen.genValue() + vals = append(vals, val) + tcase.coeffs[i] = gen.gen.Intn(10) - 5 + chosenType := gen.allowedTypes[gen.gen.Intn(len(gen.allowedTypes))] + maxType = utils.Max(maxType, chosenType) + + switch chosenType { + case constantT: + tcase.svecs[i] = NewConstantExt(val, gen.fullLen) + case windowT: + v := gen.genWindow(val, val) + tcase.svecs[i] = v + start := normalize(v.interval().Start(), gen.windowMustStartAfter, gen.fullLen) + winMinStart = utils.Min(winMinStart, start) + + stop := normalize(v.interval().Stop(), gen.windowMustStartAfter, gen.fullLen) + if stop < start { + stop += gen.fullLen + } + winMaxStop = utils.Max(winMaxStop, stop) + case RegularExtT: + tcase.svecs[i] = gen.genRegularExt(val) + case RotatedExtT: + tcase.svecs[i] = gen.genRotatedExt(val) + } + } + + // If there are no windows, then the initial condition that we use + // do pass this sanity-check + if winMaxStop-winMinStart > gen.windowWithLen { + utils.Panic("inconsistent window dimension %v %v with gen %++v", winMinStart, winMaxStop, gen) + } + resVal := polyext.EvalUnivariate(vals, x) + + switch { + case maxType == constantT: + tcase.expectedValue = NewConstantExt(resVal, gen.fullLen) + case maxType == RegularExtT || maxType == windowT || maxType == RotatedExtT: + tcase.expectedValue = NewRegularExt(vectorext.Repeat(resVal, gen.fullLen)) + } + + return tcase +} + +func (gen *testCaseGen) genValue() fext.Element { + // May increase the ceil of the generator to increase the probability to pick + // an actually random value. + switch gen.gen.Intn(4) { + case 0: + return fext.Zero() + case 1: + return fext.One() + default: + return fext.NewElement(uint64(gen.gen.Uint64()), fieldPaddingInt()) + } + +} + +func (gen *testCaseGen) genWindow(val, paddingVal fext.Element) *PaddedCircularWindowExt { + start := gen.windowMustStartAfter + gen.gen.Intn(gen.windowWithLen)/2 + maxStop := gen.windowWithLen + gen.windowMustStartAfter + winLen := gen.gen.Intn(maxStop - start) + if winLen == 0 { + winLen = 1 + } + return NewPaddedCircularWindowExt(vectorext.Repeat(val, winLen), paddingVal, start, gen.fullLen) +} + +func (gen *testCaseGen) genRegularExt(val fext.Element) *RegularExt { + return NewRegularExt(vectorext.Repeat(val, gen.fullLen)) +} + +func (gen *testCaseGen) genRotatedExt(val fext.Element) *RotatedExt { + offset := gen.gen.Intn(gen.fullLen) + return NewRotatedExt(*gen.genRegularExt(val), offset) +} diff --git a/prover/maths/common/smartvectorsext/fuzzing_heavy.go b/prover/maths/common/smartvectorsext/fuzzing_heavy.go new file mode 100644 index 000000000..d20307a2f --- /dev/null +++ b/prover/maths/common/smartvectorsext/fuzzing_heavy.go @@ -0,0 +1,5 @@ +//go:build !fuzzlight + +package smartvectorsext + +const fuzzIteration int = 20000 diff --git a/prover/maths/common/smartvectorsext/regular_ext.go b/prover/maths/common/smartvectorsext/regular_ext.go new file mode 100644 index 000000000..d11fc98a9 --- /dev/null +++ b/prover/maths/common/smartvectorsext/regular_ext.go @@ -0,0 +1,199 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + + "github.com/consensys/linea-monorepo/prover/maths/common/mempoolext" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/utils" +) + +const conversionError = "smartvector holds field extensions, but a base element was requested" + +// It's normal vector in a nutshell +type RegularExt []fext.Element + +// Instanstiate a new regular from a slice. Returns a pointer so that the result +// can be reused without referencing as a SmartVector. +func NewRegularExt(v []fext.Element) *RegularExt { + assertStrictPositiveLen(len(v)) + res := RegularExt(v) + return &res +} + +// Returns the length of the regular vector +func (r *RegularExt) Len() int { return len(*r) } + +// Returns a particular element of the vector +func (r *RegularExt) GetBase(n int) (field.Element, error) { + return field.Zero(), fmt.Errorf(conversionError) +} + +func (r *RegularExt) GetExt(n int) fext.Element { + return (*r)[n] +} + +func (r *RegularExt) Get(n int) field.Element { + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res +} + +// Returns a subvector of the regular +func (r *RegularExt) SubVector(start, stop int) smartvectors.SmartVector { + if start > stop { + utils.Panic("Negative length are not allowed") + } + if start == stop { + utils.Panic("Subvector of zero lengths are not allowed") + } + res := RegularExt((*r)[start:stop]) + return &res +} + +// Rotates the vector into a new one +func (r *RegularExt) RotateRight(offset int) smartvectors.SmartVector { + resSlice := make(RegularExt, r.Len()) + + if offset == 0 { + copy(resSlice, *r) + return &resSlice + } + + if offset > 0 { + // v and w may be the same vector thus we should use a + // separate leftover buffer for temporary memory buffers. + cutAt := len(*r) - offset + leftovers := vectorext.DeepCopy((*r)[cutAt:]) + copy(resSlice[offset:], (*r)[:cutAt]) + copy(resSlice[:offset], leftovers) + return &resSlice + } + + if offset < 0 { + glueAt := len(*r) + offset + leftovers := vectorext.DeepCopy((*r)[:-offset]) + copy(resSlice[:glueAt], (*r)[-offset:]) + copy(resSlice[glueAt:], leftovers) + return &resSlice + } + + panic("unreachable") +} + +func (r *RegularExt) WriteInSlice(s []field.Element) { + panic("conversionError") +} + +func (r *RegularExt) WriteInSliceExt(s []fext.Element) { + assertHasLength(len(s), len(*r)) + for i := 0; i < len(s); i++ { + elem, _ := r.GetBase(i) + s[i].SetFromBase(&elem) + } +} + +func (r *RegularExt) Pretty() string { + return fmt.Sprintf("Regular[%v]", vectorext.Prettify(*r)) +} + +func processRegularOnlyExt(op operator, svecs []smartvectors.SmartVector, coeffs []int, p ...mempoolext.MemPool) (result *PooledExt, numMatches int) { + + length := svecs[0].Len() + + pool, hasPool := mempoolext.ExtractCheckOptionalStrict(length, p...) + + var resvec *PooledExt + + isFirst := true + numMatches = 0 + + for i := range svecs { + + svec := svecs[i] + // In case the current vec is Rotated, we reduce it to a regular form + // NB : this could use the pool. + if rot, ok := svec.(*RotatedExt); ok { + svec = rotatedAsRegular(rot) + } + + if pooled, ok := svec.(*smartvectors.Pooled); ok { + svec = &pooled.Regular + } + + if reg, ok := svec.(*RegularExt); ok { + numMatches++ + // For the first one, we can save by just copying the result + // Importantly, we do not need to assume that regRes is originally + // zero. + if isFirst { + if hasPool { + resvec = AllocFromPoolExt(pool) + } else { + resvec = &PooledExt{RegularExt: make([]fext.Element, length)} + } + + isFirst = false + op.vecIntoTerm(resvec.RegularExt, *reg, coeffs[i]) + continue + } + + op.vecIntoVec(resvec.RegularExt, *reg, coeffs[i]) + } + } + + if numMatches == 0 { + return nil, 0 + } + + return resvec, numMatches +} + +func (r *RegularExt) DeepCopy() smartvectors.SmartVector { + return NewRegularExt(vectorext.DeepCopy(*r)) +} + +// Converts a smart-vector into a normal vec. The implementation minimizes +// then number of copies. +func (r *RegularExt) IntoRegVecSaveAlloc() []field.Element { + panic(conversionError) +} + +func (r *RegularExt) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return nil, fmt.Errorf(conversionError) +} + +func (r *RegularExt) IntoRegVecSaveAllocExt() []fext.Element { + temp := make([]fext.Element, r.Len()) + for i := 0; i < r.Len(); i++ { + elem, _ := r.GetBase(i) + temp[i].SetFromBase(&elem) + } + return temp +} + +type PooledExt struct { + RegularExt + poolPtr *[]fext.Element +} + +func AllocFromPoolExt(pool mempoolext.MemPool) *PooledExt { + poolPtr := pool.Alloc() + return &PooledExt{ + RegularExt: *poolPtr, + poolPtr: poolPtr, + } +} + +func (p *PooledExt) Free(pool mempoolext.MemPool) { + if p.poolPtr != nil { + pool.Free(p.poolPtr) + } + p.poolPtr = nil + p.RegularExt = nil +} diff --git a/prover/maths/common/smartvectorsext/rotated_ext.go b/prover/maths/common/smartvectorsext/rotated_ext.go new file mode 100644 index 000000000..22c8e3e41 --- /dev/null +++ b/prover/maths/common/smartvectorsext/rotated_ext.go @@ -0,0 +1,210 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// Rotated represents a rotated version of a regular smartvector and also +// implements the [SmartVector] interface. Rotated have a very niche use-case +// in the repository as they are used to help saving FFT operations in the +// [github.com/consensys/linea-monorepo/prover/protocol/compiler/arithmetic.CompileGlobal] +// compiler when the coset evaluation is done over a cyclic rotation of a +// smart-vector. +// +// Rotated works by abstractly storing the offset and only applying the rotation +// when the vector is written or sub-vectored. This makes rotations essentially +// free. +type RotatedExt struct { + v *PooledExt + offset int +} + +// NewRotated constructs a new Rotated, positive offset means a cyclic left shift. +func NewRotatedExt(reg RegularExt, offset int) *RotatedExt { + + // empty vector + if len(reg) == 0 { + utils.Panic("got an empty vector") + } + + // negative offset is not allowed + if offset < 0 { + if -offset > len(reg) { + utils.Panic("len %v is less than, offset %v", len(reg), offset) + } + } + + // offset larger than the vector itself + if offset > len(reg) { + utils.Panic("len %v, is less than, offset %v", len(reg), offset) + } + + return &RotatedExt{ + v: &PooledExt{RegularExt: reg}, offset: offset, + } +} + +// Returns the lenght of the vector +func (r *RotatedExt) Len() int { + return r.v.Len() +} + +// Returns a particular element of the vector +func (r *RotatedExt) GetBase(n int) (field.Element, error) { + return field.Zero(), fmt.Errorf(conversionError) +} + +// Returns a particular element of the vector +func (r *RotatedExt) GetExt(n int) fext.Element { + return r.v.GetExt(utils.PositiveMod(n+r.offset, r.Len())) +} + +func (r *RotatedExt) Get(n int) field.Element { + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res +} + +// Returns a particular element. The subvector is taken at indices +// [start, stop). (stop being excluded from the span) +func (r *RotatedExt) SubVector(start, stop int) smartvectors.SmartVector { + + if stop+r.offset < len(r.v.RegularExt) && start+r.offset > 0 { + res := RegularExt(r.v.RegularExt[start+r.offset : stop+r.offset]) + return &res + } + + res := make([]fext.Element, stop-start) + size := r.Len() + spanSize := stop - start + + // checking + if stop <= start { + utils.Panic("the start %v >= stop %v", start, stop) + } + + // boundary checks + if start < 0 { + utils.Panic("the start value was negative %v", start) + } + + if stop > size { + utils.Panic("the stop is OOO : %v (the length is %v)", stop, size) + } + + // normalize the offset to something positive [0: size) + startWithOffsetClean := utils.PositiveMod(start+r.offset, size) + + // NB: we may need to construct the res in several steps + // in case + copy(res, r.v.RegularExt[startWithOffsetClean:utils.Min(size, startWithOffsetClean+spanSize)]) + + // If this is negative of zero, it means the first copy already copied + // everything we needed to copy + howManyElementLeftToCopy := startWithOffsetClean + spanSize - size + howManyAlreadyCopied := spanSize - howManyElementLeftToCopy + if howManyElementLeftToCopy <= 0 { + ret := RegularExt(res) + return &ret + } + + // if necessary perform a second + copy(res[howManyAlreadyCopied:], r.v.RegularExt[:howManyElementLeftToCopy]) + ret := RegularExt(res) + return &ret +} + +// Rotates the vector into a new one, a positive offset means a left cyclic shift +func (r *RotatedExt) RotateRight(offset int) smartvectors.SmartVector { + // We limit the offset value to prevent integer overflow + if offset > 1<<40 { + utils.Panic("offset is too large") + } + return &RotatedExt{ + v: &PooledExt{ + RegularExt: vectorext.DeepCopy(r.v.RegularExt), + }, + offset: r.offset + offset, + } +} + +func (r *RotatedExt) DeepCopy() smartvectors.SmartVector { + return NewRotatedExt(vectorext.DeepCopy(r.v.RegularExt), r.offset) +} + +func (r *RotatedExt) WriteInSlice(s []field.Element) { + panic(conversionError) +} + +func (r *RotatedExt) WriteInSliceExt(s []fext.Element) { + temp := rotatedAsRegular(r) + assertHasLength(len(s), len(*temp)) + copy(s, *temp) +} + +func (r *RotatedExt) Pretty() string { + return fmt.Sprintf("Rotated[%v, %v]", r.v.Pretty(), r.offset) +} + +// rotatedAsRegular converts a [Rotated] into a [Regular] by effecting the +// symbolic shifting operation. The function allocates the result. +func rotatedAsRegular(r *RotatedExt) *RegularExt { + return r.SubVector(0, r.Len()).(*RegularExt) +} + +func (r *RotatedExt) IntoRegVecSaveAlloc() []field.Element { + panic(conversionError) +} + +func (r *RotatedExt) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return nil, fmt.Errorf(conversionError) +} + +func (r *RotatedExt) IntoRegVecSaveAllocExt() []fext.Element { + temp := *rotatedAsRegular(r) + res := make([]fext.Element, temp.Len()) + for i := 0; i < temp.Len(); i++ { + res[i].Set(&temp[i]) + } + return res +} + +// SoftRotate converts v into a [SmartVector] representing the same +// [SmartVector]. The function tries to not reallocate the result. This means +// that changing the v can subsequently affects the result of this function. +func SoftRotate(v smartvectors.SmartVector, offset int) smartvectors.SmartVector { + + switch casted := v.(type) { + case *RegularExt: + return NewRotatedExt(*casted, offset) + case *RotatedExt: + return NewRotatedExt(casted.v.RegularExt, utils.PositiveMod(offset+casted.offset, v.Len())) + case *PaddedCircularWindowExt: + return NewPaddedCircularWindowExt( + casted.window, + casted.paddingVal, + utils.PositiveMod(casted.offset+offset, casted.Len()), + casted.Len(), + ) + case *ConstantExt: + // It's a constant so it does not need to be rotated + return v + case *PooledExt: + return &RotatedExt{ + v: casted, + offset: offset, + } + default: + utils.Panic("unknown type %T", v) + } + + panic("unreachable") + +} diff --git a/prover/maths/common/smartvectorsext/rotated_ext_test.go b/prover/maths/common/smartvectorsext/rotated_ext_test.go new file mode 100644 index 000000000..839ecaede --- /dev/null +++ b/prover/maths/common/smartvectorsext/rotated_ext_test.go @@ -0,0 +1,96 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/vector" + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/stretchr/testify/require" +) + +func TestRotatedSubVector(t *testing.T) { + + size := 16 + original := vector.Rand(size) + myVec := smartvectors.NewRegular(original) + + for offset := 0; offset < size; offset++ { + rotated := smartvectors.NewRotated(*myVec, offset) + for start := 0; start < size; start++ { + for stop := start + 1; stop <= size; stop++ { + + sub := rotated.SubVector(start, stop) + recovered := sub.Get(0) + require.Equal(t, recovered.String(), original[utils.PositiveMod(start+offset, size)].String()) + + } + } + } + +} + +func TestRotatedWriteInSlice(t *testing.T) { + + size := 16 + original := vectorext.Rand(size) + myVec := NewRegularExt(original) + + for offset := 0; offset < size; offset++ { + rotated := NewRotatedExt(*myVec, offset) + written := make([]fext.Element, size) + rotated.WriteInSliceExt(written) + + for i := range written { + require.Equal(t, written[i].String(), original[utils.PositiveMod(i+offset, size)].String()) + } + } +} + +func TestRotatedOffsetOverflow(t *testing.T) { + v := []fext.Element{fext.NewFromString("1"), + fext.NewFromString("2"), + fext.NewFromString("3"), + fext.NewFromString("4"), + fext.NewFromString("5")} + myVec := NewRegularExt(v) + // First we check that the absolute value for negative offsets + // are not allowed to be larger than the length + negOffset := -7 + expectedPanicMessage := fmt.Sprintf("len %v is less than, offset %v", 5, negOffset) + require.PanicsWithValue(t, expectedPanicMessage, func() { NewRotatedExt(*myVec, negOffset) }, + "NewRotated should panic with 'got negative offset' message") + // Next we check offset overflow + offset := 1 + r := NewRotatedExt(*myVec, offset) + rotateOffset := 1 << 41 + // The function should panic as rotateOffset is too large + require.PanicsWithValue(t, "offset is too large", func() { r.RotateRight(rotateOffset) }, + "RotateRight should panic with 'offset is too large' message") +} + +func TestRotateRightSimple(t *testing.T) { + v := []fext.Element{fext.NewFromString("0"), + fext.NewFromString("1"), + fext.NewFromString("2"), + fext.NewFromString("3"), + fext.NewFromString("4")} + myVec := NewRegularExt(v) + offset := 1 + rotated := NewRotatedExt(*myVec, offset) + // Next we rotate the vector + rotated_ := rotated.RotateRight(2) + m := make([]fext.Element, 0, len(v)) + for i := range v { + m = append(m, rotated_.GetExt(i)) + } + v_shifted := []fext.Element{fext.NewFromString("3"), + fext.NewFromString("4"), + fext.NewFromString("0"), + fext.NewFromString("1"), + fext.NewFromString("2")} + require.Equal(t, m, v_shifted) +} diff --git a/prover/maths/common/smartvectorsext/smartvectors_op.go b/prover/maths/common/smartvectorsext/smartvectors_op.go new file mode 100644 index 000000000..0cfceb359 --- /dev/null +++ b/prover/maths/common/smartvectorsext/smartvectors_op.go @@ -0,0 +1,158 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/gnark/frontend" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// ForTest returns a witness from a explicit litteral assignement +func ForTestExt(xs ...int) smartvectors.SmartVector { + return NewRegularExt(vectorext.ForTest(xs...)) +} + +// IntoRegVec converts a smart-vector into a normal vec. The resulting vector +// is always reallocated and can be safely mutated without side-effects +// on s. +func IntoRegVec(s smartvectors.SmartVector) []field.Element { + res := make([]field.Element, s.Len()) + s.WriteInSlice(res) + return res +} + +func IntoRegVecExt(s smartvectors.SmartVector) []fext.Element { + res := make([]fext.Element, s.Len()) + s.WriteInSliceExt(res) + return res +} + +// IntoGnarkAssignment converts a smart-vector into a gnark assignment +func IntoGnarkAssignment(sv smartvectors.SmartVector) []frontend.Variable { + res := make([]frontend.Variable, sv.Len()) + _, err := sv.GetBase(0) + if err == nil { + for i := range res { + elem, _ := sv.GetBase(i) + res[i] = elem + } + } else { + for i := range res { + elem := sv.GetExt(i) + res[i] = elem + } + } + return res +} + +// LeftPadded creates a new padded vector (padded on the left) +func LeftPadded(v []fext.Element, padding fext.Element, targetLen int) smartvectors.SmartVector { + + if len(v) > targetLen { + utils.Panic("target length %v must be less than %v", len(v), targetLen) + } + + if len(v) == targetLen { + return NewRegularExt(v) + } + + if len(v) == 0 { + return NewConstantExt(padding, targetLen) + } + + return NewPaddedCircularWindowExt(v, padding, targetLen-len(v), targetLen) +} + +// RightPadded creates a new vector (padded on the right) +func RightPadded(v []fext.Element, padding fext.Element, targetLen int) smartvectors.SmartVector { + + if len(v) > targetLen { + utils.Panic("target length %v must be less than %v", len(v), targetLen) + } + + if len(v) == targetLen { + return NewRegularExt(v) + } + + if len(v) == 0 { + return NewConstantExt(padding, targetLen) + } + + return NewPaddedCircularWindowExt(v, padding, 0, targetLen) +} + +// RightZeroPadded creates a new vector (padded on the right) +func RightZeroPadded(v []fext.Element, targetLen int) smartvectors.SmartVector { + return RightPadded(v, fext.Zero(), targetLen) +} + +// LeftZeroPadded creates a new vector (padded on the left) +func LeftZeroPadded(v []fext.Element, targetLen int) smartvectors.SmartVector { + return LeftPadded(v, fext.Zero(), targetLen) +} + +// Density returns the density of a smart-vector. By density we mean the size +// of the concrete underlying vectors. This can be used as a proxi for the +// memory required to store the smart-vector. +func Density(v smartvectors.SmartVector) int { + switch w := v.(type) { + case *ConstantExt: + return 0 + case *PaddedCircularWindowExt: + return len(w.window) + case *RegularExt: + return len(*w) + case *RotatedExt: + return len(w.v.RegularExt) + case *PooledExt: + return len(w.RegularExt) + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } +} + +// Window returns the effective window of the vector, +// if the vector is Padded with zeroes it return the window. +// Namely, the part without zero pads. +func Window(v smartvectors.SmartVector) []fext.Element { + switch w := v.(type) { + case *ConstantExt: + return w.IntoRegVecSaveAllocExt() + case *PaddedCircularWindowExt: + return w.window + case *RegularExt: + return *w + case *RotatedExt: + return w.IntoRegVecSaveAllocExt() + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } +} + +func WindowExt(v smartvectors.SmartVector) []fext.Element { + switch w := v.(type) { + case *ConstantExt: + return w.IntoRegVecSaveAllocExt() + case *PaddedCircularWindowExt: + temp := make([]fext.Element, len(w.window)) + for i := 0; i < len(w.window); i++ { + elem := w.window[i] + temp[i].Set(&elem) + } + return temp + case *RegularExt: + temp := make([]fext.Element, len(*w)) + for i := 0; i < len(*w); i++ { + elem := w.GetExt(i) + temp[i].Set(&elem) + } + return temp + case *RotatedExt: + return w.IntoRegVecSaveAllocExt() + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } +} diff --git a/prover/maths/common/smartvectorsext/temp_parameters.go b/prover/maths/common/smartvectorsext/temp_parameters.go new file mode 100644 index 000000000..327702589 --- /dev/null +++ b/prover/maths/common/smartvectorsext/temp_parameters.go @@ -0,0 +1,13 @@ +package smartvectorsext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/field" +) + +func fieldPadding() field.Element { + return field.Zero() +} + +func fieldPaddingInt() uint64 { + return 0 +} diff --git a/prover/maths/common/smartvectorsext/vecutil.go b/prover/maths/common/smartvectorsext/vecutil.go new file mode 100644 index 000000000..c52aff4b4 --- /dev/null +++ b/prover/maths/common/smartvectorsext/vecutil.go @@ -0,0 +1,39 @@ +package smartvectorsext + +import ( + "github.com/consensys/linea-monorepo/prover/utils" + "github.com/sirupsen/logrus" +) + +// assertCorrectBound panics if pos >= length +func assertCorrectBound(pos, length int) { + if pos >= length { + logrus.Panicf("Bound assertion failed, cannot access pos %v for vector of length %v", pos, length) + } +} + +// assertHasLength panics if a and b are not equal +func assertHasLength(a, b int) { + if a != b { + utils.Panic("the two slices should have the same length (found %v and %v)", a, b) + } +} + +// assertPowerOfTwoLen panics if l is not a power of two +func assertPowerOfTwoLen(l int) { + if !utils.IsPowerOfTwo(l) { + logrus.Panicf("Slice should have a power of two length but has %v", l) + } +} + +// assertStrictPositiveLen panics if l is 0 or negative +func assertStrictPositiveLen(l int) { + + if l == 0 { + logrus.Panicf("FORBIDDEN : Got a null length vector") + } + + if l <= 0 { + logrus.Panicf("FORBIDDEN : Got a negative length %v", l) + } +} diff --git a/prover/maths/common/smartvectorsext/windowed_ext.go b/prover/maths/common/smartvectorsext/windowed_ext.go new file mode 100644 index 000000000..6e2ef8013 --- /dev/null +++ b/prover/maths/common/smartvectorsext/windowed_ext.go @@ -0,0 +1,351 @@ +package smartvectorsext + +import ( + "fmt" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// It's a slice - zero padded up to a certain length - and RotatedExt +type PaddedCircularWindowExt struct { + window []fext.Element + paddingVal fext.Element + // Totlen is the length of the represented vector + totLen, offset int +} + +// Create a new padded circular window vector +func NewPaddedCircularWindowExt(window []fext.Element, paddingVal fext.Element, offset, totLen int) *PaddedCircularWindowExt { + // The window should not be larger than the total length + if len(window) > totLen { + utils.Panic("The window size is too large %v because totlen is %v", len(window), totLen) + } + + if len(window) == totLen { + utils.Panic("Forbidden : the window should not take the full length") + } + + if len(window) == 0 { + utils.Panic("Forbidden : empty window") + } + + // Normalize the offset to be in range [0:totlen) + offset = utils.PositiveMod(offset, totLen) + return &PaddedCircularWindowExt{ + window: window, + paddingVal: paddingVal, + offset: offset, + totLen: totLen, + } +} + +// Returns the length of the vector +func (p *PaddedCircularWindowExt) Len() int { + return p.totLen +} + +// Returns a queries position +func (p *PaddedCircularWindowExt) GetBase(n int) (field.Element, error) { + return field.Zero(), fmt.Errorf(conversionError) +} + +func (p *PaddedCircularWindowExt) GetExt(n int) fext.Element { + // Check if the queried index is in the window + posFromWindowsPoV := utils.PositiveMod(n-p.offset, p.totLen) + if posFromWindowsPoV < len(p.window) { + return p.window[posFromWindowsPoV] + } + // Else, return the padding value + return p.paddingVal +} + +func (r *PaddedCircularWindowExt) Get(n int) field.Element { + res, err := r.GetBase(n) + if err != nil { + panic(err) + } + return res +} + +// Extract a subvector from p[start:stop), the subvector cannot "roll-over". +// i.e, we enforce that start < stop +func (p *PaddedCircularWindowExt) SubVector(start, stop int) smartvectors.SmartVector { + // negative start value is not allowed + if start < 0 { + panic("negative start value is not allowed") + } + // Sanity checks for all subvectors + assertCorrectBound(start, p.totLen) + // The +1 is because we accept if "stop = length" + assertCorrectBound(stop, p.totLen+1) + + if start > stop { + panic("rollover are forbidden") + } + + if start == stop { + panic("zero length subvector is forbidden") + } + + /* + This function has a high-combinatoric complexity and in order to reason about + each case, we represent them as follows: + [a, b) is the interval of the subvector. We can assume that [a, b) does not + roll-over the vector. + [c,d) is the interval of the window of `p`. It can roll-over. + + We use 'a' as the origin for other coordinates and reduce + the ongoing combinatoric when listing all the cases. Let b_ = b-a + + 0xxxxxxxxxxxxxxxxxxxxxb N + | | | + 1) | | c------------d | + 2) |-------------------------d c------------| (including b == d) + 2*) c-------------------------d | (including b == d) + 3) | c-------d | | + 4) | c------------------------d | (including b == d) + 5) |-----d c----------------------------| + 6) |-------d | c----------| (including c == b) + 7) d | c----------| (including c == b) + + + For consistency, with the above picture, we rename as the offset coordinates + */ + + n := p.Len() + b := stop - start + c := normalize(p.interval().Start(), start, n) + d := normalize(p.interval().Stop(), start, n) + + // Case 1 : return a ConstantExt vector + if b <= c && c < d { + return NewConstantExt(p.paddingVal, b) + } + + // Case 2 : return a RegularExt vector + if b <= d && d < c { + reg := RegularExt(p.window[n-c : n-c+b]) + return ® + } + + // Case 2* : same as 2 but c == 0 + if b <= d && c == 0 { + reg := RegularExt(p.window[:b]) + return ® + } + + // Case 3 : the window is fully contained in the subvector + if c < d && d <= b { + return NewPaddedCircularWindowExt(p.window, p.paddingVal, c, b) + } + + // Case 4 : left-ended + if c < b && c <= d { + return NewPaddedCircularWindowExt(p.window[:b-c], p.paddingVal, c, b) + } + + // Case 5 : the window is double ended (we skip some element in the center of the window) + if d < c && c < b { + left := p.window[:b-c] + right := p.window[n-c:] + + // The deep-copy of left ensures that we do not append + // on the same concrete slice. + w := append(vectorext.DeepCopy(left), right...) + return NewPaddedCircularWindowExt(w, p.paddingVal, c, b) + } + + // Case 6 : right-ended + if 0 < d && d < b && b <= c { + return NewPaddedCircularWindowExt(p.window[n-c:], p.paddingVal, 0, b) + } + + // Case 7 : d == 0 and c is out + if d == 0 && b <= c { + return NewConstantExt(p.paddingVal, b) + } + + panic(fmt.Sprintf("unsupported case : b %v, c %v, d %v", b, c, d)) + +} + +// Rotate the vector +func (p *PaddedCircularWindowExt) RotateRight(offset int) smartvectors.SmartVector { + return NewPaddedCircularWindowExt(vectorext.DeepCopy(p.window), p.paddingVal, p.offset+offset, p.totLen) +} + +func (p *PaddedCircularWindowExt) WriteInSlice(buff []field.Element) { + panic(conversionError) +} + +func (p *PaddedCircularWindowExt) WriteInSliceExt(buff []fext.Element) { + assertHasLength(len(buff), p.totLen) + + for i := range p.window { + pos := utils.PositiveMod(i+p.offset, p.totLen) + buff[pos] = p.window[i] + } + + for i := len(p.window); i < p.totLen; i++ { + pos := utils.PositiveMod(i+p.offset, p.totLen) + buff[pos] = p.paddingVal + } +} + +func (p *PaddedCircularWindowExt) Pretty() string { + return fmt.Sprintf("Windowed[totlen=%v offset=%v, paddingVal=%v, window=%v]", p.totLen, p.offset, p.paddingVal.String(), vectorext.Prettify(p.window)) +} + +func (p *PaddedCircularWindowExt) interval() smartvectors.CircularInterval { + return smartvectors.IvalWithStartLen(p.offset, len(p.window), p.totLen) +} + +// normalize converts the (circle) coordinator x to another coordinate by changing +// the origin point on the discret circle. mod denotes the number of points in +// the circle. +func normalize(x, newRef, mod int) int { + return utils.PositiveMod(x-newRef, mod) +} + +// processWindowedOnly applies the operator `op` to all the smartvectors +// contained in `svecs` with `coeffs` that have the type [PaddedCircularWindowExt] +// +// The function does so by attempting to fit result on the smallest possible +// window. +// +// In case, this is not possible. The function will "give up" and convert all +// the PaddedCircularWindowExt into RegularExts and pretend it did not find any. +// +// The function returns the partial result of the operation and the number of +// padded circular windows SmartVector that it found. +func processWindowedOnly(op operator, svecs []smartvectors.SmartVector, coeffs_ []int) (res smartvectors.SmartVector, numMatches int) { + + // First we compute the union windows. + length := svecs[0].Len() + windows := []PaddedCircularWindowExt{} + intervals := []smartvectors.CircularInterval{} + coeffs := []int{} + + // Gather all the windows into a slice + for i, svec := range svecs { + if pcw, ok := svec.(*PaddedCircularWindowExt); ok { + windows = append(windows, *pcw) + intervals = append(intervals, pcw.interval()) + coeffs = append(coeffs, coeffs_[i]) // collect the coeffs related to each window + // Sanity-check : all vectors must have the same length + assertHasLength(svec.Len(), length) + numMatches++ + } + } + + if numMatches == 0 { + return nil, numMatches + } + + // has the dimension of the cover with garbage values in it + smallestCover := smartvectors.SmallestCoverInterval(intervals) + + // Edge-case: in case the smallest-cover of the pcw found in svecs is the + // full-circle the code below will not work as it assumes that is possible + if smallestCover.IsFullCircle() { + for i, svec := range svecs { + if _, ok := svec.(*PaddedCircularWindowExt); ok { + temp := svec.IntoRegVecSaveAllocExt() + svecs[i] = NewRegularExt(temp) + } + } + return nil, 0 + } + + // Sanity-check : normally all offset are normalized, this should ensure that start + // is positive. This is critical here because if some of the offset are not normalized + // then we may end up with a union windows that does not make sense. + if smallestCover.Start() < 0 { + utils.Panic("All offset should be normalized, but start is %v", smallestCover.Start()) + } + + // Ensures we do not reuse an input vector here to limit the risk of overwriting one + // of the input. This can happen if there is only a single window or if one windows + // covers all the other. + unionWindow := make([]fext.Element, smallestCover.IntervalLen) + var paddedTerm fext.Element + offset := smallestCover.Start() + + /* + Now we actually compute the linear combinations for all offsets + */ + + isFirst := true + for i, pcw := range windows { + interval := intervals[i] + + // Find the intersection with the larger window + start_ := normalize(interval.Start(), offset, length) + stop_ := normalize(interval.Stop(), offset, length) + if stop_ == 0 { + stop_ = length + } + + // For the first match, we can save the operations by copying instead of + // multiplying / adding + if isFirst { + isFirst = false + op.vecIntoTerm(unionWindow[start_:stop_], pcw.window, coeffs[i]) + // #nosec G601 -- Deliberate pass by reference. (We trust the pointed object is not mutated) + op.constIntoTerm(&paddedTerm, &pcw.paddingVal, coeffs[i]) + vectorext.Fill(unionWindow[:start_], paddedTerm) + vectorext.Fill(unionWindow[stop_:], paddedTerm) + continue + } + + // sanity-check : start and stop are consistent with the size of pcw + if stop_-start_ != len(pcw.window) { + utils.Panic( + "sanity-check failed. The renormalized coordinates (start=%v, stop=%v) are inconsistent with pcw : (len=%v)", + start_, stop_, len(pcw.window), + ) + } + + op.vecIntoVec(unionWindow[start_:stop_], pcw.window, coeffs[i]) + + // Update the padded term + // #nosec G601 -- Deliberate pass by reference. (We trust the pointed object is not mutated) + op.constIntoConst(&paddedTerm, &pcw.paddingVal, coeffs[i]) + + // Complete the left and the right-side of the window (i.e) the part + // of unionWindow that does not overlap with pcw.window. + // #nosec G601 -- Deliberate pass by reference. (We trust the pointed object is not mutated) + op.constIntoVec(unionWindow[:start_], &pcw.paddingVal, coeffs[i]) + // #nosec G601 -- Deliberate pass by reference. (We trust the pointed object is not mutated) + op.constIntoVec(unionWindow[stop_:], &pcw.paddingVal, coeffs[i]) + } + + if smallestCover.IsFullCircle() { + return NewRegularExt(unionWindow), numMatches + } + + return NewPaddedCircularWindowExt(unionWindow, paddedTerm, offset, length), numMatches +} + +func (w *PaddedCircularWindowExt) DeepCopy() smartvectors.SmartVector { + window := vectorext.DeepCopy(w.window) + return NewPaddedCircularWindowExt(window, w.paddingVal, w.offset, w.totLen) +} + +// Converts a smart-vector into a normal vec. The implementation minimizes +// then number of copies. +func (w *PaddedCircularWindowExt) IntoRegVecSaveAlloc() []field.Element { + panic(conversionError) +} + +func (w *PaddedCircularWindowExt) IntoRegVecSaveAllocBase() ([]field.Element, error) { + return nil, fmt.Errorf(conversionError) +} + +func (w *PaddedCircularWindowExt) IntoRegVecSaveAllocExt() []fext.Element { + res := IntoRegVecExt(w) + return res +} diff --git a/prover/maths/common/smartvectorsext/windowed_ext_test.go b/prover/maths/common/smartvectorsext/windowed_ext_test.go new file mode 100644 index 000000000..64e7207ac --- /dev/null +++ b/prover/maths/common/smartvectorsext/windowed_ext_test.go @@ -0,0 +1,67 @@ +package smartvectorsext + +import ( + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors/vectorext" + "github.com/consensys/linea-monorepo/prover/maths/field/fext" + "testing" + + "github.com/stretchr/testify/require" +) + +// This is a simple error case we have faces in the past, the test ensures that +// it does go through. +func TestProcessWindowed(_ *testing.T) { + + a := NewPaddedCircularWindowExt( + vectorext.Rand(5), + fext.Zero(), + 0, + 16, + ) + + b := NewPaddedCircularWindowExt( + vectorext.Rand(12), + fext.Zero(), + 4, + 16, + ) + + _, _ = processWindowedOnly( + linCombOp{}, + []smartvectors.SmartVector{b, a}, + []int{1, 1}, + ) +} + +func TestEdgeCases(t *testing.T) { + require.PanicsWithValue(t, "zero length subvector is forbidden", func() { + NewPaddedCircularWindowExt( + vectorext.Rand(5), + fext.Zero(), + 0, + 16, + ).SubVector(0, 0) + }, + "SubVector should panic with 'zero length subvector is forbidden' message") + require.PanicsWithValue(t, "Subvector of zero lengths are not allowed", func() { + NewRegularExt([]fext.Element{fext.Zero()}).SubVector(0, 0) + }, + "SubVector should panic with 'Subvector of zero lengths are not allowed' message") + require.PanicsWithValue(t, "zero or negative length are not allowed", func() { + NewConstantExt(fext.Zero(), 0) + }, + "NewConstant should panic with 'zero or negative length are not allowed' message") + require.PanicsWithValue(t, "zero or negative length are not allowed", func() { + NewConstantExt(fext.Zero(), -1) + }, + "NewConstant should panic with 'zero or negative length are not allowed' message") + require.PanicsWithValue(t, "negative length are not allowed", func() { + NewConstantExt(fext.Zero(), 10).SubVector(3, 1) + }, + "NewConstant.Subvector should panic with 'negative length are not allowed' message") + require.PanicsWithValue(t, "zero length are not allowed", func() { + NewConstantExt(fext.Zero(), 10).SubVector(3, 3) + }, + "NewConstant.Subvector should panic with 'zero length are not allowed' message") +} diff --git a/prover/maths/field/fext/additional.go b/prover/maths/field/fext/additional.go new file mode 100644 index 000000000..4a8aff633 --- /dev/null +++ b/prover/maths/field/fext/additional.go @@ -0,0 +1,30 @@ +package fext + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/linea-monorepo/prover/maths/field" +) + +// SetUint64 sets z to v and returns z +func (z *Element) SetUint64(v uint64) *Element { + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form + var x fr.Element + x.SetUint64(v) + y := field.Zero() + z = &Element{x, y} + return z // z.toMont() +} + +// SetInt64 sets z to v and returns z +func (z *Element) SetInt64(v int64) *Element { + + var x fr.Element + x.SetInt64(v) + y := field.Zero() + z = &Element{x, y} + return z // z.toMont() +} + +func (z *Element) Uint64() (uint64, uint64) { + return z.A0.Bits()[0], z.A1.Bits()[0] +} diff --git a/prover/maths/field/fext/e12.go b/prover/maths/field/fext/e12.go new file mode 100644 index 000000000..2ea78fc96 --- /dev/null +++ b/prover/maths/field/fext/e12.go @@ -0,0 +1,866 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math/big" + "sync" +) + +var bigIntPool = sync.Pool{ + New: func() interface{} { + return new(big.Int) + }, +} + +// E12 is a degree two finite field extension of fp6 +type E12 struct { + C0, C1 E6 +} + +// Equal returns true if z equals x, false otherwise +func (z *E12) Equal(x *E12) bool { + return z.C0.Equal(&x.C0) && z.C1.Equal(&x.C1) +} + +// String puts E12 in string form +func (z *E12) String() string { + return z.C0.String() + "+(" + z.C1.String() + ")*w" +} + +// SetString sets a E12 from string +func (z *E12) SetString(s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11 string) *E12 { + z.C0.SetString(s0, s1, s2, s3, s4, s5) + z.C1.SetString(s6, s7, s8, s9, s10, s11) + return z +} + +// Set copies x into z and returns z +func (z *E12) Set(x *E12) *E12 { + z.C0 = x.C0 + z.C1 = x.C1 + return z +} + +// SetOne sets z to 1 in Montgomery form and returns z +func (z *E12) SetOne() *E12 { + *z = E12{} + z.C0.B0.A0.SetOne() + return z +} + +// Add sets z=x+y in E12 and returns z +func (z *E12) Add(x, y *E12) *E12 { + z.C0.Add(&x.C0, &y.C0) + z.C1.Add(&x.C1, &y.C1) + return z +} + +// Sub sets z to x-y and returns z +func (z *E12) Sub(x, y *E12) *E12 { + z.C0.Sub(&x.C0, &y.C0) + z.C1.Sub(&x.C1, &y.C1) + return z +} + +// Double sets z=2*x and returns z +func (z *E12) Double(x *E12) *E12 { + z.C0.Double(&x.C0) + z.C1.Double(&x.C1) + return z +} + +// SetRandom used only in tests +func (z *E12) SetRandom() (*E12, error) { + if _, err := z.C0.SetRandom(); err != nil { + return nil, err + } + if _, err := z.C1.SetRandom(); err != nil { + return nil, err + } + return z, nil +} + +// IsZero returns true if z is zero, false otherwise +func (z *E12) IsZero() bool { + return z.C0.IsZero() && z.C1.IsZero() +} + +// IsOne returns true if z is one, false otherwise +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() +} + +// Mul sets z=x*y in E12 and returns z +func (z *E12) Mul(x, y *E12) *E12 { + var a, b, c E6 + a.Add(&x.C0, &x.C1) + b.Add(&y.C0, &y.C1) + a.Mul(&a, &b) + b.Mul(&x.C0, &y.C0) + c.Mul(&x.C1, &y.C1) + z.C1.Sub(&a, &b).Sub(&z.C1, &c) + z.C0.MulByNonResidue(&c).Add(&z.C0, &b) + return z +} + +// Square sets z=x*x in E12 and returns z +func (z *E12) Square(x *E12) *E12 { + + //Algorithm 22 from https://eprint.iacr.org/2010/354.pdf + var c0, c2, c3 E6 + c0.Sub(&x.C0, &x.C1) + c3.MulByNonResidue(&x.C1).Neg(&c3).Add(&x.C0, &c3) + c2.Mul(&x.C0, &x.C1) + c0.Mul(&c0, &c3).Add(&c0, &c2) + z.C1.Double(&c2) + c2.MulByNonResidue(&c2) + z.C0.Add(&c0, &c2) + + return z +} + +// Karabina's compressed cyclotomic square +// https://eprint.iacr.org/2010/542.pdf +// Th. 3.2 with minor modifications to fit our tower +func (z *E12) CyclotomicSquareCompressed(x *E12) *E12 { + + var t [7]E2 + + // t0 = g1^2 + t[0].Square(&x.C0.B1) + // t1 = g5^2 + t[1].Square(&x.C1.B2) + // t5 = g1 + g5 + t[5].Add(&x.C0.B1, &x.C1.B2) + // t2 = (g1 + g5)^2 + t[2].Square(&t[5]) + + // t3 = g1^2 + g5^2 + t[3].Add(&t[0], &t[1]) + // t5 = 2 * g1 * g5 + t[5].Sub(&t[2], &t[3]) + + // t6 = g3 + g2 + t[6].Add(&x.C1.B0, &x.C0.B2) + // t3 = (g3 + g2)^2 + t[3].Square(&t[6]) + // t2 = g3^2 + t[2].Square(&x.C1.B0) + + // t6 = 2 * nr * g1 * g5 + t[6].MulByNonResidue(&t[5]) + // t5 = 4 * nr * g1 * g5 + 2 * g3 + t[5].Add(&t[6], &x.C1.B0). + Double(&t[5]) + // z3 = 6 * nr * g1 * g5 + 2 * g3 + z.C1.B0.Add(&t[5], &t[6]) + + // t4 = nr * g5^2 + t[4].MulByNonResidue(&t[1]) + // t5 = nr * g5^2 + g1^2 + t[5].Add(&t[0], &t[4]) + // t6 = nr * g5^2 + g1^2 - g2 + t[6].Sub(&t[5], &x.C0.B2) + + // t1 = g2^2 + t[1].Square(&x.C0.B2) + + // t6 = 2 * nr * g5^2 + 2 * g1^2 - 2*g2 + t[6].Double(&t[6]) + // z2 = 3 * nr * g5^2 + 3 * g1^2 - 2*g2 + z.C0.B2.Add(&t[6], &t[5]) + + // t4 = nr * g2^2 + t[4].MulByNonResidue(&t[1]) + // t5 = g3^2 + nr * g2^2 + t[5].Add(&t[2], &t[4]) + // t6 = g3^2 + nr * g2^2 - g1 + t[6].Sub(&t[5], &x.C0.B1) + // t6 = 2 * g3^2 + 2 * nr * g2^2 - 2 * g1 + t[6].Double(&t[6]) + // z1 = 3 * g3^2 + 3 * nr * g2^2 - 2 * g1 + z.C0.B1.Add(&t[6], &t[5]) + + // t0 = g2^2 + g3^2 + t[0].Add(&t[2], &t[1]) + // t5 = 2 * g3 * g2 + t[5].Sub(&t[3], &t[0]) + // t6 = 2 * g3 * g2 + g5 + t[6].Add(&t[5], &x.C1.B2) + // t6 = 4 * g3 * g2 + 2 * g5 + t[6].Double(&t[6]) + // z5 = 6 * g3 * g2 + 2 * g5 + z.C1.B2.Add(&t[5], &t[6]) + + return z +} + +// DecompressKarabina Karabina's cyclotomic square result +// if g3 != 0 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// if g3 == 0 +// +// g4 = 2g1g5/g2 +// +// if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) +// Theorem 3.1 is well-defined for all x in Gϕₙ\{1} +func (z *E12) DecompressKarabina(x *E12) *E12 { + + var t [3]E2 + var one E2 + one.SetOne() + + if x.C1.B2.IsZero() /* g3 == 0 */ { + t[0].Mul(&x.C0.B1, &x.C1.B2). + Double(&t[0]) + // t1 = g2 + t[1].Set(&x.C0.B2) + + if t[1].IsZero() /* g2 == g3 == 0 */ { + return z.SetOne() + } + } else /* g3 != 0 */ { + + // t0 = g1^2 + t[0].Square(&x.C0.B1) + // t1 = 3 * g1^2 - 2 * g2 + t[1].Sub(&t[0], &x.C0.B2). + Double(&t[1]). + Add(&t[1], &t[0]) + // t0 = E * g5^2 + t1 + t[2].Square(&x.C1.B2) + t[0].MulByNonResidue(&t[2]). + Add(&t[0], &t[1]) + // t1 = 4 * g3 + t[1].Double(&x.C1.B0). + Double(&t[1]) + } + + // z4 = g4 + z.C1.B1.Div(&t[0], &t[1]) // costly + + // t1 = g2 * g1 + t[1].Mul(&x.C0.B2, &x.C0.B1) + // t2 = 2 * g4^2 - 3 * g2 * g1 + t[2].Square(&z.C1.B1). + Sub(&t[2], &t[1]). + Double(&t[2]). + Sub(&t[2], &t[1]) + // t1 = g3 * g5 (g3 can be 0) + t[1].Mul(&x.C1.B0, &x.C1.B2) + // c_0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1 + t[2].Add(&t[2], &t[1]) + z.C0.B0.MulByNonResidue(&t[2]). + Add(&z.C0.B0, &one) + + z.C0.B1.Set(&x.C0.B1) + z.C0.B2.Set(&x.C0.B2) + z.C1.B0.Set(&x.C1.B0) + z.C1.B2.Set(&x.C1.B2) + + return z +} + +// BatchDecompressKarabina multiple Karabina's cyclotomic square results +// if g3 != 0 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// if g3 == 0 +// +// g4 = 2g1g5/g2 +// +// if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) +// Theorem 3.1 is well-defined for all x in Gϕₙ\{1} +// +// Divisions by 4g3 or g2 is batched using Montgomery batch inverse +func BatchDecompressKarabina(x []E12) []E12 { + + n := len(x) + if n == 0 { + return x + } + + t0 := make([]E2, n) + t1 := make([]E2, n) + t2 := make([]E2, n) + zeroes := make([]bool, n) + + var one E2 + one.SetOne() + + for i := 0; i < n; i++ { + if x[i].C1.B2.IsZero() /* g3 == 0 */ { + t0[i].Mul(&x[i].C0.B1, &x[i].C1.B2). + Double(&t0[i]) + // t1 = g2 + t1[i].Set(&x[i].C0.B2) + + if t1[i].IsZero() /* g3 == g2 == 0 */ { + x[i].SetOne() + zeroes[i] = true + continue + } + } else /* g3 != 0 */ { + // t0 = g1^2 + t0[i].Square(&x[i].C0.B1) + // t1 = 3 * g1^2 - 2 * g2 + t1[i].Sub(&t0[i], &x[i].C0.B2). + Double(&t1[i]). + Add(&t1[i], &t0[i]) + // t0 = E * g5^2 + t1 + t2[i].Square(&x[i].C1.B2) + t0[i].MulByNonResidue(&t2[i]). + Add(&t0[i], &t1[i]) + // t1 = 4 * g3 + t1[i].Double(&x[i].C1.B0). + Double(&t1[i]) + } + } + + t1 = BatchInvertE2(t1) // costs 1 inverse + + for i := 0; i < n; i++ { + if zeroes[i] { + continue + } + + // z4 = g4 + x[i].C1.B1.Mul(&t0[i], &t1[i]) + + // t1 = g2 * g1 + t1[i].Mul(&x[i].C0.B2, &x[i].C0.B1) + // t2 = 2 * g4^2 - 3 * g2 * g1 + t2[i].Square(&x[i].C1.B1) + t2[i].Sub(&t2[i], &t1[i]) + t2[i].Double(&t2[i]) + t2[i].Sub(&t2[i], &t1[i]) + + // t1 = g3 * g5 (g3s can be 0s) + t1[i].Mul(&x[i].C1.B0, &x[i].C1.B2) + // z0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1 + t2[i].Add(&t2[i], &t1[i]) + x[i].C0.B0.MulByNonResidue(&t2[i]). + Add(&x[i].C0.B0, &one) + } + + return x +} + +// Granger-Scott's cyclotomic square +// https://eprint.iacr.org/2009/565.pdf, 3.2 +func (z *E12) CyclotomicSquare(x *E12) *E12 { + + // x=(x0,x1,x2,x3,x4,x5,x6,x7) in E2^6 + // cyclosquare(x)=(3*x4^2*u + 3*x0^2 - 2*x0, + // 3*x2^2*u + 3*x3^2 - 2*x1, + // 3*x5^2*u + 3*x1^2 - 2*x2, + // 6*x1*x5*u + 2*x3, + // 6*x0*x4 + 2*x4, + // 6*x2*x3 + 2*x5) + + var t [9]E2 + + t[0].Square(&x.C1.B1) + t[1].Square(&x.C0.B0) + t[6].Add(&x.C1.B1, &x.C0.B0).Square(&t[6]).Sub(&t[6], &t[0]).Sub(&t[6], &t[1]) // 2*x4*x0 + t[2].Square(&x.C0.B2) + t[3].Square(&x.C1.B0) + t[7].Add(&x.C0.B2, &x.C1.B0).Square(&t[7]).Sub(&t[7], &t[2]).Sub(&t[7], &t[3]) // 2*x2*x3 + t[4].Square(&x.C1.B2) + t[5].Square(&x.C0.B1) + t[8].Add(&x.C1.B2, &x.C0.B1).Square(&t[8]).Sub(&t[8], &t[4]).Sub(&t[8], &t[5]).MulByNonResidue(&t[8]) // 2*x5*x1*u + + t[0].MulByNonResidue(&t[0]).Add(&t[0], &t[1]) // x4^2*u + x0^2 + t[2].MulByNonResidue(&t[2]).Add(&t[2], &t[3]) // x2^2*u + x3^2 + t[4].MulByNonResidue(&t[4]).Add(&t[4], &t[5]) // x5^2*u + x1^2 + + z.C0.B0.Sub(&t[0], &x.C0.B0).Double(&z.C0.B0).Add(&z.C0.B0, &t[0]) + z.C0.B1.Sub(&t[2], &x.C0.B1).Double(&z.C0.B1).Add(&z.C0.B1, &t[2]) + z.C0.B2.Sub(&t[4], &x.C0.B2).Double(&z.C0.B2).Add(&z.C0.B2, &t[4]) + + z.C1.B0.Add(&t[8], &x.C1.B0).Double(&z.C1.B0).Add(&z.C1.B0, &t[8]) + z.C1.B1.Add(&t[6], &x.C1.B1).Double(&z.C1.B1).Add(&z.C1.B1, &t[6]) + z.C1.B2.Add(&t[7], &x.C1.B2).Double(&z.C1.B2).Add(&z.C1.B2, &t[7]) + + return z +} + +// Inverse sets z to the inverse of x in E12 and returns z +// +// if x == 0, sets and returns z = x +func (z *E12) Inverse(x *E12) *E12 { + // Algorithm 23 from https://eprint.iacr.org/2010/354.pdf + + var t0, t1, tmp E6 + t0.Square(&x.C0) + t1.Square(&x.C1) + tmp.MulByNonResidue(&t1) + t0.Sub(&t0, &tmp) + t1.Inverse(&t0) + z.C0.Mul(&x.C0, &t1) + z.C1.Mul(&x.C1, &t1).Neg(&z.C1) + + return z +} + +// BatchInvertE12 returns a new slice with every element in a inverted. +// It uses Montgomery batch inversion trick. +// +// if a[i] == 0, returns result[i] = a[i] +func BatchInvertE12(a []E12) []E12 { + res := make([]E12, len(a)) + if len(a) == 0 { + return res + } + + zeroes := make([]bool, len(a)) + var accumulator E12 + accumulator.SetOne() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes[i] = true + continue + } + res[i].Set(&accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes[i] { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +// Exp sets z=xᵏ (mod q¹²) and returns it +// uses 2-bits windowed method +func (z *E12) Exp(x E12, k *big.Int) *E12 { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q¹²) == (x⁻¹)ᵏ (mod q¹²) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = bigIntPool.Get().(*big.Int) + defer bigIntPool.Put(e) + e.Neg(k) + } + + var res E12 + var ops [3]E12 + + res.SetOne() + ops[0].Set(&x) + ops[1].Square(&ops[0]) + ops[2].Set(&ops[0]).Mul(&ops[2], &ops[1]) + + b := e.Bytes() + for i := range b { + w := b[i] + mask := byte(0xc0) + for j := 0; j < 4; j++ { + res.Square(&res).Square(&res) + c := (w & mask) >> (6 - 2*j) + if c != 0 { + res.Mul(&res, &ops[c-1]) + } + mask = mask >> 2 + } + } + z.Set(&res) + + return z +} + +// CyclotomicExp sets z=xᵏ (mod q¹²) and returns it +// uses 2-NAF decomposition +// x must be in the cyclotomic subgroup +// TODO: use a windowed method +func (z *E12) CyclotomicExp(x E12, k *big.Int) *E12 { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert (=conjugate) + // if k < 0: xᵏ (mod q¹²) == (x⁻¹)ᵏ (mod q¹²) + x.Conjugate(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = bigIntPool.Get().(*big.Int) + defer bigIntPool.Put(e) + e.Neg(k) + } + + var res, xInv E12 + xInv.InverseUnitary(&x) + res.SetOne() + eNAF := make([]int8, e.BitLen()+3) + n := ecc.NafDecomposition(e, eNAF[:]) + for i := n - 1; i >= 0; i-- { + res.CyclotomicSquare(&res) + if eNAF[i] == 1 { + res.Mul(&res, &x) + } else if eNAF[i] == -1 { + res.Mul(&res, &xInv) + } + } + z.Set(&res) + return z +} + +// ExpGLV sets z=xᵏ (q¹²) and returns it +// uses 2-dimensional GLV with 2-bits windowed method +// x must be in GT +// TODO: use 2-NAF +// TODO: use higher dimensional decomposition +func (z *E12) ExpGLV(x E12, k *big.Int) *E12 { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert (=conjugate) + // if k < 0: xᵏ (mod q¹²) == (x⁻¹)ᵏ (mod q¹²) + x.Conjugate(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = bigIntPool.Get().(*big.Int) + defer bigIntPool.Put(e) + e.Neg(k) + } + + var table [15]E12 + var res E12 + var s1, s2 fr.Element + + res.SetOne() + + // table[b3b2b1b0-1] = b3b2*Frobinius(x) + b1b0*x + table[0].Set(&x) + table[3].Frobenius(&x) + + // split the scalar, modifies ±x, Frob(x) accordingly + s := ecc.SplitScalar(e, &glvBasis) + + if s[0].Sign() == -1 { + s[0].Neg(&s[0]) + table[0].InverseUnitary(&table[0]) + } + if s[1].Sign() == -1 { + s[1].Neg(&s[1]) + table[3].InverseUnitary(&table[3]) + } + + // precompute table (2 bits sliding window) + // table[b3b2b1b0-1] = b3b2*Frobenius(x) + b1b0*x if b3b2b1b0 != 0 + table[1].CyclotomicSquare(&table[0]) + table[2].Mul(&table[1], &table[0]) + table[4].Mul(&table[3], &table[0]) + table[5].Mul(&table[3], &table[1]) + table[6].Mul(&table[3], &table[2]) + table[7].CyclotomicSquare(&table[3]) + table[8].Mul(&table[7], &table[0]) + table[9].Mul(&table[7], &table[1]) + table[10].Mul(&table[7], &table[2]) + table[11].Mul(&table[7], &table[3]) + table[12].Mul(&table[11], &table[0]) + table[13].Mul(&table[11], &table[1]) + table[14].Mul(&table[11], &table[2]) + + // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() + + maxBit := s1.BitLen() + if s2.BitLen() > maxBit { + maxBit = s2.BitLen() + } + hiWordIndex := (maxBit - 1) / 64 + + // loop starts from len(s1)/2 due to the bounds + for i := hiWordIndex; i >= 0; i-- { + mask := uint64(3) << 62 + for j := 0; j < 32; j++ { + res.CyclotomicSquare(&res).CyclotomicSquare(&res) + b1 := (s1[i] & mask) >> (62 - 2*j) + b2 := (s2[i] & mask) >> (62 - 2*j) + if b1|b2 != 0 { + s := (b2<<2 | b1) + res.Mul(&res, &table[s-1]) + } + mask = mask >> 2 + } + } + + z.Set(&res) + return z +} + +// InverseUnitary inverses a unitary element +func (z *E12) InverseUnitary(x *E12) *E12 { + return z.Conjugate(x) +} + +// Conjugate sets z to x conjugated and returns z +func (z *E12) Conjugate(x *E12) *E12 { + *z = *x + z.C1.Neg(&z.C1) + return z +} + +// SizeOfGT represents the size in bytes that a GT element need in binary form +const SizeOfGT = 48 * 12 + +// Marshal converts z to a byte slice +func (z *E12) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// Unmarshal is an alias to SetBytes() +func (z *E12) Unmarshal(buf []byte) error { + return z.SetBytes(buf) +} + +// Bytes returns the regular (non montgomery) value +// of z as a big-endian byte array. +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +func (z *E12) Bytes() (r [SizeOfGT]byte) { + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[528:528+fp.Bytes]), z.C0.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[480:480+fp.Bytes]), z.C0.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[432:432+fp.Bytes]), z.C0.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[384:384+fp.Bytes]), z.C0.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[336:336+fp.Bytes]), z.C0.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[288:288+fp.Bytes]), z.C0.B2.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[240:240+fp.Bytes]), z.C1.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[192:192+fp.Bytes]), z.C1.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[144:144+fp.Bytes]), z.C1.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[96:96+fp.Bytes]), z.C1.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[48:48+fp.Bytes]), z.C1.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[0:0+fp.Bytes]), z.C1.B2.A1) + + return +} + +// SetBytes interprets e as the bytes of a big-endian GT +// sets z to that value (in Montgomery form), and returns z. +// size(e) == 48 * 12 +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +func (z *E12) SetBytes(e []byte) error { + if len(e) != SizeOfGT { + return errors.New("invalid buffer size") + } + if err := z.C0.B0.A0.SetBytesCanonical(e[528 : 528+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B0.A1.SetBytesCanonical(e[480 : 480+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A0.SetBytesCanonical(e[432 : 432+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A1.SetBytesCanonical(e[384 : 384+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A0.SetBytesCanonical(e[336 : 336+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A1.SetBytesCanonical(e[288 : 288+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A0.SetBytesCanonical(e[240 : 240+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A1.SetBytesCanonical(e[192 : 192+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A0.SetBytesCanonical(e[144 : 144+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A1.SetBytesCanonical(e[96 : 96+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A0.SetBytesCanonical(e[48 : 48+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A1.SetBytesCanonical(e[0 : 0+fp.Bytes]); err != nil { + return err + } + + return nil +} + +// IsInSubGroup ensures GT/E12 is in correct subgroup +func (z *E12) IsInSubGroup() bool { + var a, b E12 + + // check z^(phi_k(p)) == 1 + a.FrobeniusSquare(z) + b.FrobeniusSquare(&a).Mul(&b, z) + + if !a.Equal(&b) { + return false + } + + // check z^(p+1-t) == 1 + a.Frobenius(z) + b.Expt(z) + + return a.Equal(&b) +} + +// CompressTorus GT/E12 element to half its size +// z must be in the cyclotomic subgroup +// i.e. z^(p^4-p^2+1)=1 +// e.g. GT +// "COMPRESSION IN FINITE FIELDS AND TORUS-BASED CRYPTOGRAPHY", K. RUBIN AND A. SILVERBERG +// z.C1 == 0 only when z \in {-1,1} +func (z *E12) CompressTorus() (E6, error) { + + if z.C1.IsZero() { + return E6{}, errors.New("invalid input") + } + + var res, tmp, one E6 + one.SetOne() + tmp.Inverse(&z.C1) + res.Add(&z.C0, &one). + Mul(&res, &tmp) + + return res, nil +} + +// BatchCompressTorus GT/E12 elements to half their size using a batch inversion. +// +// if len(x) == 0 or if any of the x[i].C1 coordinate is 0, this function returns an error. +func BatchCompressTorus(x []E12) ([]E6, error) { + + n := len(x) + if n == 0 { + return nil, errors.New("invalid input size") + } + + var one E6 + one.SetOne() + res := make([]E6, n) + + for i := 0; i < n; i++ { + res[i].Set(&x[i].C1) + // throw an error if any of the x[i].C1 is 0 + if res[i].IsZero() { + return nil, errors.New("invalid input; C1 is 0") + } + } + + t := BatchInvertE6(res) // costs 1 inverse + + for i := 0; i < n; i++ { + res[i].Add(&x[i].C0, &one). + Mul(&res[i], &t[i]) + } + + return res, nil +} + +// DecompressTorus GT/E12 a compressed element +// element must be in the cyclotomic subgroup +// "COMPRESSION IN FINITE FIELDS AND TORUS-BASED CRYPTOGRAPHY", K. RUBIN AND A. SILVERBERG +func (z *E6) DecompressTorus() E12 { + + var res, num, denum E12 + num.C0.Set(z) + num.C1.SetOne() + denum.C0.Set(z) + denum.C1.SetOne().Neg(&denum.C1) + res.Inverse(&denum). + Mul(&res, &num) + + return res +} + +// BatchDecompressTorus GT/E12 compressed elements +// using a batch inversion +func BatchDecompressTorus(x []E6) ([]E12, error) { + + n := len(x) + if n == 0 { + return []E12{}, errors.New("invalid input size") + } + + res := make([]E12, n) + num := make([]E12, n) + denum := make([]E12, n) + + for i := 0; i < n; i++ { + num[i].C0.Set(&x[i]) + num[i].C1.SetOne() + denum[i].C0.Set(&x[i]) + denum[i].C1.SetOne().Neg(&denum[i].C1) + } + + denum = BatchInvertE12(denum) // costs 1 inverse + + for i := 0; i < n; i++ { + res[i].Mul(&num[i], &denum[i]) + } + + return res, nil +} + +// Select is conditional move. +// If cond = 0, it sets z to caseZ and returns it. otherwise caseNz. +func (z *E12) Select(cond int, caseZ *E12, caseNz *E12) *E12 { + //Might be able to save a nanosecond or two by an aggregate implementation + + z.C0.Select(cond, &caseZ.C0, &caseNz.C0) + z.C1.Select(cond, &caseZ.C1, &caseNz.C1) + + return z +} + +// Div divides an element in E12 by an element in E12 +func (z *E12) Div(x *E12, y *E12) *E12 { + var r E12 + r.Inverse(y).Mul(x, &r) + return z.Set(&r) +} diff --git a/prover/maths/field/fext/e12_pairing.go b/prover/maths/field/fext/e12_pairing.go new file mode 100644 index 000000000..209791448 --- /dev/null +++ b/prover/maths/field/fext/e12_pairing.go @@ -0,0 +1,107 @@ +package fext + +func (z *E12) nSquare(n int) { + for i := 0; i < n; i++ { + z.CyclotomicSquare(z) + } +} + +func (z *E12) nSquareCompressed(n int) { + for i := 0; i < n; i++ { + z.CyclotomicSquareCompressed(z) + } +} + +// Expt set z to x^t in E12 and return z +func (z *E12) Expt(x *E12) *E12 { + // const tAbsVal uint64 = 9586122913090633729 + // tAbsVal in binary: 1000010100001000110000000000000000000000000000000000000000000001 + // drop the low 46 bits (all 0 except the least significant bit): 100001010000100011 = 136227 + // Shortest addition chains can be found at https://wwwhomes.uni-bielefeld.de/achim/addition_chain.html + + var result, x33 E12 + + // a shortest addition chain for 136227 + result.Set(x) + result.nSquare(5) + result.Mul(&result, x) + x33.Set(&result) + result.nSquare(7) + result.Mul(&result, &x33) + result.nSquare(4) + result.Mul(&result, x) + result.CyclotomicSquare(&result) + result.Mul(&result, x) + + // the remaining 46 bits + result.nSquareCompressed(46) + result.DecompressKarabina(&result) + result.Mul(&result, x) + + z.Set(&result) + return z +} + +// MulBy034 multiplication by sparse element (c0,0,0,c3,c4,0) +func (z *E12) MulBy034(c0, c3, c4 *E2) *E12 { + + var a, b, d E6 + + a.MulByE2(&z.C0, c0) + + b.Set(&z.C1) + b.MulBy01(c3, c4) + + var d0 E2 + d0.Add(c0, c3) + d.Add(&z.C0, &z.C1) + d.MulBy01(&d0, c4) + + z.C1.Add(&a, &b).Neg(&z.C1).Add(&z.C1, &d) + z.C0.MulByNonResidue(&b).Add(&z.C0, &a) + + return z +} + +// MulBy34 multiplication by sparse element (1,0,0,c3,c4,0) +func (z *E12) MulBy34(c3, c4 *E2) *E12 { + + var a, b, d E6 + + a.Set(&z.C0) + + b.Set(&z.C1) + b.MulBy01(c3, c4) + + var d0 E2 + d0.SetOne().Add(&d0, c3) + d.Add(&z.C0, &z.C1) + d.MulBy01(&d0, c4) + + z.C1.Add(&a, &b).Neg(&z.C1).Add(&z.C1, &d) + z.C0.MulByNonResidue(&b).Add(&z.C0, &a) + + return z +} + +// MulBy01234 multiplies z by an E12 sparse element of the form (x0, x1, x2, x3, x4, 0) +func (z *E12) MulBy01234(x *[5]E2) *E12 { + var c1, a, b, c, z0, z1 E6 + c0 := &E6{B0: x[0], B1: x[1], B2: x[2]} + c1.B0 = x[3] + c1.B1 = x[4] + a.Add(&z.C0, &z.C1) + b.Add(c0, &c1) + a.Mul(&a, &b) + b.Mul(&z.C0, c0) + c.Set(&z.C1).MulBy01(&x[3], &x[4]) + z1.Sub(&a, &b) + z1.Sub(&z1, &c) + z0.MulByNonResidue(&c) + z0.Add(&z0, &b) + + z.C0 = z0 + z.C1 = z1 + + return z +} diff --git a/prover/maths/field/fext/e12_test.go b/prover/maths/field/fext/e12_test.go new file mode 100644 index 000000000..3406e7b30 --- /dev/null +++ b/prover/maths/field/fext/e12_test.go @@ -0,0 +1,569 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +// ------------------------------------------------------------ +// tests + +func TestE12Serialization(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE12() + + properties.Property("[BLS12-377] SetBytes(Bytes()) should stay constant", prop.ForAll( + func(a *E12) bool { + var b E12 + buf := a.Bytes() + if err := b.SetBytes(buf[:]); err != nil { + return false + } + return a.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestE12ReceiverIsOperand(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE12() + genB := GenE12() + + properties.Property("[BLS12-377] Having the receiver as operand (addition) should output the same result", prop.ForAll( + func(a, b *E12) bool { + var c, d E12 + d.Set(a) + c.Add(a, b) + a.Add(a, b) + b.Add(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (sub) should output the same result", prop.ForAll( + func(a, b *E12) bool { + var c, d E12 + d.Set(a) + c.Sub(a, b) + a.Sub(a, b) + b.Sub(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul) should output the same result", prop.ForAll( + func(a, b *E12) bool { + var c, d E12 + d.Set(a) + c.Mul(a, b) + a.Mul(a, b) + b.Mul(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (square) should output the same result", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Square(a) + a.Square(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (double) should output the same result", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Double(a) + a.Double(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Inverse) should output the same result", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Inverse(a) + a.Inverse(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Cyclotomic square) should output the same result", prop.ForAll( + func(a *E12) bool { + var b E12 + b.CyclotomicSquare(a) + a.CyclotomicSquare(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Conjugate) should output the same result", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Conjugate(a) + a.Conjugate(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Frobenius) should output the same result", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Frobenius(a) + a.Frobenius(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (FrobeniusSquare) should output the same result", prop.ForAll( + func(a *E12) bool { + var b E12 + b.FrobeniusSquare(a) + a.FrobeniusSquare(a) + return a.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestE12Ops(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE12() + genB := GenE12() + genExp := GenFp() + + properties.Property("[BLS12-377] sub & add should leave an element invariant", prop.ForAll( + func(a, b *E12) bool { + var c E12 + c.Set(a) + c.Add(&c, b).Sub(&c, b) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] mul & inverse should leave an element invariant", prop.ForAll( + func(a, b *E12) bool { + var c, d E12 + d.Inverse(b) + c.Set(a) + c.Mul(&c, b).Mul(&c, &d) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] inverse twice should leave an element invariant", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Inverse(a).Inverse(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] square and mul should output the same result", prop.ForAll( + func(a *E12) bool { + var b, c E12 + b.Mul(a, a) + c.Square(a) + return b.Equal(&c) + }, + genA, + )) + + properties.Property("[BLS12-377] a + pi(a), a-pi(a) should be real", prop.ForAll( + func(a *E12) bool { + var b, c, d E12 + var e, f, g E6 + b.Conjugate(a) + c.Add(a, &b) + d.Sub(a, &b) + e.Double(&a.C0) + f.Double(&a.C1) + return c.C1.Equal(&g) && d.C0.Equal(&g) && e.Equal(&c.C0) && f.Equal(&d.C1) + }, + genA, + )) + + properties.Property("[BLS12-377] Torus-based Compress/decompress E12 elements in the cyclotomic subgroup", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Conjugate(a) + a.Inverse(a) + b.Mul(&b, a) + a.FrobeniusSquare(&b).Mul(a, &b) + + c, _ := a.CompressTorus() + d := c.DecompressTorus() + return a.Equal(&d) + }, + genA, + )) + + properties.Property("[BLS12-377] Torus-based batch Compress/decompress E12 elements in the cyclotomic subgroup", prop.ForAll( + func(a, e, f *E12) bool { + var b E12 + b.Conjugate(a) + a.Inverse(a) + b.Mul(&b, a) + a.FrobeniusSquare(&b).Mul(a, &b) + + e.CyclotomicSquare(a) + f.CyclotomicSquare(e) + + c, _ := BatchCompressTorus([]E12{*a, *e, *f}) + d, _ := BatchDecompressTorus(c) + return a.Equal(&d[0]) && e.Equal(&d[1]) && f.Equal(&d[2]) + }, + genA, + genA, + genA, + )) + + properties.Property("[BLS12-377] pi**12=id", prop.ForAll( + func(a *E12) bool { + var b E12 + b.Frobenius(a). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b). + Frobenius(&b) + return b.Equal(a) + }, + genA, + )) + + properties.Property("[BLS12-377] (pi**2)**6=id", prop.ForAll( + func(a *E12) bool { + var b E12 + b.FrobeniusSquare(a). + FrobeniusSquare(&b). + FrobeniusSquare(&b). + FrobeniusSquare(&b). + FrobeniusSquare(&b). + FrobeniusSquare(&b) + return b.Equal(a) + }, + genA, + )) + + properties.Property("[BLS12-377] cyclotomic square (Granger-Scott) and square should be the same in the cyclotomic subgroup", prop.ForAll( + func(a *E12) bool { + var b, c, d E12 + b.Conjugate(a) + a.Inverse(a) + b.Mul(&b, a) + a.FrobeniusSquare(&b).Mul(a, &b) + c.Square(a) + d.CyclotomicSquare(a) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("[BLS12-377] compressed cyclotomic square (Karabina) and square should be the same in the cyclotomic subgroup", prop.ForAll( + func(a *E12) bool { + var _a, b, c, d, _c, _d E12 + _a.SetOne().Double(&_a) + + // put a and _a in the cyclotomic subgroup + // a (g3 != 0 probably) + b.Conjugate(a) + a.Inverse(a) + b.Mul(&b, a) + a.FrobeniusSquare(&b).Mul(a, &b) + // _a (g3 == 0) + b.Conjugate(&_a) + _a.Inverse(&_a) + b.Mul(&b, &_a) + _a.FrobeniusSquare(&b).Mul(&_a, &b) + + // case g3 != 0 + c.Square(a) + d.CyclotomicSquareCompressed(a).DecompressKarabina(&d) + + // case g3 == 0 + _c.Square(&_a) + _d.CyclotomicSquareCompressed(&_a).DecompressKarabina(&_d) + + return c.Equal(&d) + }, + genA, + )) + + properties.Property("[BLS12-377] batch decompress and individual decompress (Karabina) should be the same", prop.ForAll( + func(a *E12) bool { + var _a, b E12 + _a.SetOne().Double(&_a) + + // put a and _a in the cyclotomic subgroup + // a (g3 !=0 probably) + b.Conjugate(a) + a.Inverse(a) + b.Mul(&b, a) + a.FrobeniusSquare(&b).Mul(a, &b) + // _a (g3 == 0) + b.Conjugate(&_a) + _a.Inverse(&_a) + b.Mul(&b, &_a) + _a.FrobeniusSquare(&b).Mul(&_a, &b) + + var a2, a4, a17 E12 + a2.Set(&_a) + a4.Set(a) + a17.Set(a) + a2.nSquareCompressed(2) // case g3 == 0 + a4.nSquareCompressed(4) + a17.nSquareCompressed(17) + batch := BatchDecompressKarabina([]E12{a2, a4, a17}) + a2.DecompressKarabina(&a2) + a4.DecompressKarabina(&a4) + a17.DecompressKarabina(&a17) + + return a2.Equal(&batch[0]) && a4.Equal(&batch[1]) && a17.Equal(&batch[2]) + }, + genA, + )) + + properties.Property("[BLS12-377] Exp and CyclotomicExp results must be the same in the cyclotomic subgroup", prop.ForAll( + func(a *E12, e fp.Element) bool { + var b, c, d E12 + // put in the cyclo subgroup + b.Conjugate(a) + a.Inverse(a) + b.Mul(&b, a) + a.FrobeniusSquare(&b).Mul(a, &b) + + var _e big.Int + k := new(big.Int).SetUint64(12) + e.Exp(e, k) + e.BigInt(&_e) + + c.Exp(*a, &_e) + d.CyclotomicExp(*a, &_e) + + return c.Equal(&d) + }, + genA, + genExp, + )) + + properties.Property("[BLS12-377] Frobenius of x in E12 should be equal to x^q", prop.ForAll( + func(a *E12) bool { + var b, c E12 + q := fp.Modulus() + b.Frobenius(a) + c.Exp(*a, q) + return c.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] FrobeniusSquare of x in E12 should be equal to x^(q^2)", prop.ForAll( + func(a *E12) bool { + var b, c E12 + q := fp.Modulus() + b.FrobeniusSquare(a) + c.Exp(*a, q).Exp(c, q) + return c.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +// ------------------------------------------------------------ +// benches + +func BenchmarkE12Add(b *testing.B) { + var a, c E12 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Add(&a, &c) + } +} + +func BenchmarkE12Sub(b *testing.B) { + var a, c E12 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Sub(&a, &c) + } +} + +func BenchmarkE12Mul(b *testing.B) { + var a, c E12 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Mul(&a, &c) + } +} + +func BenchmarkE12Cyclosquare(b *testing.B) { + var a E12 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.CyclotomicSquare(&a) + } +} + +func BenchmarkE12Square(b *testing.B) { + var a E12 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Square(&a) + } +} + +func BenchmarkE12Inverse(b *testing.B) { + var a E12 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Inverse(&a) + } +} + +func BenchmarkE12Conjugate(b *testing.B) { + var a E12 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Conjugate(&a) + } +} + +func BenchmarkE12Frobenius(b *testing.B) { + var a E12 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Frobenius(&a) + } +} + +func BenchmarkE12FrobeniusSquare(b *testing.B) { + var a E12 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.FrobeniusSquare(&a) + } +} + +func BenchmarkE12Expt(b *testing.B) { + var a E12 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Expt(&a) + } +} + +func TestE12Div(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + properties := gopter.NewProperties(parameters) + + genA := GenE12() + genB := GenE12() + + properties.Property("[BLS12-377] dividing then multiplying by the same element does nothing", prop.ForAll( + func(a, b *E12) bool { + var c E12 + c.Div(a, b) + c.Mul(&c, b) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} diff --git a/prover/maths/field/fext/e2.go b/prover/maths/field/fext/e2.go new file mode 100644 index 000000000..3863d5025 --- /dev/null +++ b/prover/maths/field/fext/e2.go @@ -0,0 +1,305 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "math/big" +) + +// E2 is a degree two finite field extension of fp.Element +type E2 struct { + A0, A1 fp.Element +} + +// Equal returns true if z equals x, false otherwise +func (z *E2) Equal(x *E2) bool { + return z.A0.Equal(&x.A0) && z.A1.Equal(&x.A1) +} + +// Bits +// TODO @gbotrel fixme this shouldn't return a E2 +func (z *E2) Bits() E2 { + r := E2{} + r.A0 = z.A0.Bits() + r.A1 = z.A1.Bits() + return r +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *E2) Cmp(x *E2) int { + if a1 := z.A1.Cmp(&x.A1); a1 != 0 { + return a1 + } + return z.A0.Cmp(&x.A0) +} + +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *E2) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + if z.A1.IsZero() { + return z.A0.LexicographicallyLargest() + } + return z.A1.LexicographicallyLargest() +} + +// SetString sets a E2 element from strings +func (z *E2) SetString(s1, s2 string) (*E2, error) { + _, err := z.A0.SetString(s1) + if err != nil { + return nil, err + } + + _, err = z.A1.SetString(s2) + if err != nil { + return nil, err + } + return z, nil +} + +// SetZero sets an E2 elmt to zero +func (z *E2) SetZero() *E2 { + z.A0.SetZero() + z.A1.SetZero() + return z +} + +// Set sets an E2 from x +func (z *E2) Set(x *E2) *E2 { + z.A0 = x.A0 + z.A1 = x.A1 + return z +} + +// SetOne sets z to 1 in Montgomery form and returns z +func (z *E2) SetOne() *E2 { + z.A0.SetOne() + z.A1.SetZero() + return z +} + +// SetRandom sets a0 and a1 to random values +func (z *E2) SetRandom() (*E2, error) { + if _, err := z.A0.SetRandom(); err != nil { + return nil, err + } + if _, err := z.A1.SetRandom(); err != nil { + return nil, err + } + return z, nil +} + +// IsZero returns true if z is zero, false otherwise +func (z *E2) IsZero() bool { + return z.A0.IsZero() && z.A1.IsZero() +} + +// IsOne returns true if z is one, false otherwise +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + +// Add adds two elements of E2 +func (z *E2) Add(x, y *E2) *E2 { + addE2(z, x, y) + return z +} + +// Sub subtracts two elements of E2 +func (z *E2) Sub(x, y *E2) *E2 { + subE2(z, x, y) + return z +} + +// Double doubles an E2 element +func (z *E2) Double(x *E2) *E2 { + doubleE2(z, x) + return z +} + +// Neg negates an E2 element +func (z *E2) Neg(x *E2) *E2 { + negE2(z, x) + return z +} + +// String implements Stringer interface for fancy printing +func (z *E2) String() string { + return z.A0.String() + "+" + z.A1.String() + "*u" +} + +// MulByElement multiplies an element in E2 by an element in fp +func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { + var yCopy fp.Element + yCopy.Set(y) + z.A0.Mul(&x.A0, &yCopy) + z.A1.Mul(&x.A1, &yCopy) + return z +} + +// Conjugate conjugates an element in E2 +func (z *E2) Conjugate(x *E2) *E2 { + z.A0 = x.A0 + z.A1.Neg(&x.A1) + return z +} + +// Halve sets z to z / 2 +func (z *E2) Halve() { + z.A0.Halve() + z.A1.Halve() +} + +// Legendre returns the Legendre symbol of z +func (z *E2) Legendre() int { + var n fp.Element + z.norm(&n) + return n.Legendre() +} + +// Exp sets z=xᵏ (mod q²) and returns it +func (z *E2) Exp(x E2, k *big.Int) *E2 { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q²) == (x⁻¹)ᵏ (mod q²) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = bigIntPool.Get().(*big.Int) + defer bigIntPool.Put(e) + e.Neg(k) + } + + z.SetOne() + b := e.Bytes() + for i := 0; i < len(b); i++ { + w := b[i] + for j := 0; j < 8; j++ { + z.Square(z) + if (w & (0b10000000 >> j)) != 0 { + z.Mul(z, &x) + } + } + } + + return z +} + +// Sqrt sets z to the square root of and returns z +// The function does not test whether the square root +// exists or not, it's up to the caller to call +// Legendre beforehand. +// cf https://eprint.iacr.org/2012/685.pdf (algo 10) +func (z *E2) Sqrt(x *E2) *E2 { + + // precomputation + var b, c, d, e, f, x0 E2 + var _b, o fp.Element + + // c must be a non square (works for p=1 mod 12 hence 1 mod 4, only bls377 has such a p currently) + c.A1.SetOne() + + q := fp.Modulus() + var exp, one big.Int + one.SetUint64(1) + exp.Set(q).Sub(&exp, &one).Rsh(&exp, 1) + d.Exp(c, &exp) + e.Mul(&d, &c).Inverse(&e) + f.Mul(&d, &c).Square(&f) + + // computation + exp.Rsh(&exp, 1) + b.Exp(*x, &exp) + b.norm(&_b) + o.SetOne() + if _b.Equal(&o) { + x0.Square(&b).Mul(&x0, x) + _b.Set(&x0.A0).Sqrt(&_b) + z.Conjugate(&b).MulByElement(z, &_b) + return z + } + x0.Square(&b).Mul(&x0, x).Mul(&x0, &f) + _b.Set(&x0.A0).Sqrt(&_b) + z.Conjugate(&b).MulByElement(z, &_b).Mul(z, &e) + + return z +} + +// BatchInvertE2 returns a new slice with every element in a inverted. +// It uses Montgomery batch inversion trick. +// +// if a[i] == 0, returns result[i] = a[i] +func BatchInvertE2(a []E2) []E2 { + res := make([]E2, len(a)) + if len(a) == 0 { + return res + } + + zeroes := make([]bool, len(a)) + var accumulator E2 + accumulator.SetOne() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes[i] = true + continue + } + res[i].Set(&accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes[i] { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +// Select is conditional move. +// If cond = 0, it sets z to caseZ and returns it. otherwise caseNz. +func (z *E2) Select(cond int, caseZ *E2, caseNz *E2) *E2 { + //Might be able to save a nanosecond or two by an aggregate implementation + + z.A0.Select(cond, &caseZ.A0, &caseNz.A0) + z.A1.Select(cond, &caseZ.A1, &caseNz.A1) + + return z +} + +// Div divides an element in E2 by an element in E2 +func (z *E2) Div(x *E2, y *E2) *E2 { + var r E2 + r.Inverse(y).Mul(x, &r) + return z.Set(&r) +} diff --git a/prover/maths/field/fext/e2_bls377.go b/prover/maths/field/fext/e2_bls377.go new file mode 100644 index 000000000..694b73034 --- /dev/null +++ b/prover/maths/field/fext/e2_bls377.go @@ -0,0 +1,117 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fext + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" +) + +// Mul sets z to the E2-product of x,y, returns z +func (z *E2) Mul(x, y *E2) *E2 { + var a, b, c fp.Element + a.Add(&x.A0, &x.A1) + b.Add(&y.A0, &y.A1) + a.Mul(&a, &b) + b.Mul(&x.A0, &y.A0) + c.Mul(&x.A1, &y.A1) + z.A1.Sub(&a, &b).Sub(&z.A1, &c) + fp.MulBy5(&c) + z.A0.Sub(&b, &c) + return z +} + +// Square sets z to the E2-product of x,x returns z +func (z *E2) Square(x *E2) *E2 { + //algo 22 https://eprint.iacr.org/2010/354.pdf + var c0, c2 fp.Element + c0.Add(&x.A0, &x.A1) + c2.Neg(&x.A1) + fp.MulBy5(&c2) + c2.Add(&c2, &x.A0) + + c0.Mul(&c0, &c2) // (x1+x2)*(x1+(u**2)x2) + c2.Mul(&x.A0, &x.A1).Double(&c2) + z.A1 = c2 + c2.Double(&c2) + z.A0.Add(&c0, &c2) + + return z +} + +// MulByNonResidue multiplies a E2 by (0,1) +func (z *E2) MulByNonResidue(x *E2) *E2 { + a := x.A0 + b := x.A1 // fetching x.A1 in the function below is slower + fp.MulBy5(&b) + z.A0.Neg(&b) + z.A1 = a + return z +} + +// MulByNonResidueInv multiplies a E2 by (0,1)^{-1} +func (z *E2) MulByNonResidueInv(x *E2) *E2 { + //z.A1.MulByNonResidueInv(&x.A0) + a := x.A1 + fiveinv := fp.Element{ + 330620507644336508, + 9878087358076053079, + 11461392860540703536, + 6973035786057818995, + 8846909097162646007, + 104838758629667239, + } + z.A1.Mul(&x.A0, &fiveinv).Neg(&z.A1) + z.A0 = a + return z +} + +// Inverse sets z to the E2-inverse of x, returns z +func (z *E2) Inverse(x *E2) *E2 { + // Algorithm 8 from https://eprint.iacr.org/2010/354.pdf + //var a, b, t0, t1, tmp fp.Element + var t0, t1, tmp fp.Element + a := &x.A0 // creating the buffers a, b is faster than querying &x.A0, &x.A1 in the functions call below + b := &x.A1 + t0.Square(a) + t1.Square(b) + tmp.Set(&t1) + fp.MulBy5(&tmp) + t0.Add(&t0, &tmp) + t1.Inverse(&t0) + z.A0.Mul(a, &t1) + z.A1.Mul(b, &t1).Neg(&z.A1) + + return z +} + +// norm sets x to the norm of z +func (z *E2) norm(x *fp.Element) { + var tmp fp.Element + x.Square(&z.A1) + tmp.Set(x) + fp.MulBy5(&tmp) + x.Square(&z.A0).Add(x, &tmp) +} + +// MulBybTwistCurveCoeff multiplies by 1/(0,1) +func (z *E2) MulBybTwistCurveCoeff(x *E2) *E2 { + + var res E2 + res.A0.Set(&x.A1) + res.A1.MulByNonResidueInv(&x.A0) + z.Set(&res) + + return z +} diff --git a/prover/maths/field/fext/e2_fallback.go b/prover/maths/field/fext/e2_fallback.go new file mode 100644 index 000000000..fd8d703f4 --- /dev/null +++ b/prover/maths/field/fext/e2_fallback.go @@ -0,0 +1,37 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +func addE2(z, x, y *E2) { + z.A0.Add(&x.A0, &y.A0) + z.A1.Add(&x.A1, &y.A1) +} + +func subE2(z, x, y *E2) { + z.A0.Sub(&x.A0, &y.A0) + z.A1.Sub(&x.A1, &y.A1) +} + +func doubleE2(z, x *E2) { + z.A0.Double(&x.A0) + z.A1.Double(&x.A1) +} + +func negE2(z, x *E2) { + z.A0.Neg(&x.A0) + z.A1.Neg(&x.A1) +} diff --git a/prover/maths/field/fext/e2_fallback_new.go b/prover/maths/field/fext/e2_fallback_new.go new file mode 100644 index 000000000..7567bc99d --- /dev/null +++ b/prover/maths/field/fext/e2_fallback_new.go @@ -0,0 +1,37 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +func addE2New(z, x, y *Element) { + z.A0.Add(&x.A0, &y.A0) + z.A1.Add(&x.A1, &y.A1) +} + +func subE2New(z, x, y *Element) { + z.A0.Sub(&x.A0, &y.A0) + z.A1.Sub(&x.A1, &y.A1) +} + +func doubleE2New(z, x *Element) { + z.A0.Double(&x.A0) + z.A1.Double(&x.A1) +} + +func negE2New(z, x *Element) { + z.A0.Neg(&x.A0) + z.A1.Neg(&x.A1) +} diff --git a/prover/maths/field/fext/e2_test.go b/prover/maths/field/fext/e2_test.go new file mode 100644 index 000000000..6f36b3c7a --- /dev/null +++ b/prover/maths/field/fext/e2_test.go @@ -0,0 +1,531 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "crypto/rand" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +// ------------------------------------------------------------ +// tests + +const ( + nbFuzzShort = 10 + nbFuzz = 50 +) + +func TestE2ReceiverIsOperand(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE2() + genB := GenE2() + genfp := GenFp() + + properties.Property("[BLS12-377] Having the receiver as operand (addition) should output the same result", prop.ForAll( + func(a, b *E2) bool { + var c, d E2 + d.Set(a) + c.Add(a, b) + a.Add(a, b) + b.Add(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (sub) should output the same result", prop.ForAll( + func(a, b *E2) bool { + var c, d E2 + d.Set(a) + c.Sub(a, b) + a.Sub(a, b) + b.Sub(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul) should output the same result", prop.ForAll( + func(a, b *E2) bool { + var c, d E2 + d.Set(a) + c.Mul(a, b) + a.Mul(a, b) + b.Mul(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (square) should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Square(a) + a.Square(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (neg) should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Neg(a) + a.Neg(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (double) should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Double(a) + a.Double(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by non residue) should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + b.MulByNonResidue(a) + a.MulByNonResidue(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by non residue inverse) should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + b.MulByNonResidueInv(a) + a.MulByNonResidueInv(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Inverse) should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Inverse(a) + a.Inverse(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Conjugate) should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Conjugate(a) + a.Conjugate(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by element) should output the same result", prop.ForAll( + func(a *E2, b fp.Element) bool { + var c E2 + c.MulByElement(a, &b) + a.MulByElement(a, &b) + return a.Equal(&c) + }, + genA, + genfp, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Sqrt) should output the same result", prop.ForAll( + func(a *E2) bool { + var b, c, d, s E2 + + s.Square(a) + a.Set(&s) + b.Set(&s) + + a.Sqrt(a) + b.Sqrt(&b) + + c.Square(a) + d.Square(&b) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestE2MulMaxed(t *testing.T) { + // let's pick a and b, with maxed A0 and A1 + var a, b E2 + fpMaxValue := fp.Element{ + 9586122913090633729, + 1660523435060625408, + 2230234197602682880, + 1883307231910630287, + 14284016967150029115, + 121098312706494698, + } + fpMaxValue[0]-- + + a.A0 = fpMaxValue + a.A1 = fpMaxValue + b.A0 = fpMaxValue + b.A1 = fpMaxValue + + var c, d E2 + d.Inverse(&b) + c.Set(&a) + c.Mul(&c, &b).Mul(&c, &d) + if !c.Equal(&a) { + t.Fatal("mul with max fp failed") + } +} + +func TestE2Ops(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE2() + genB := GenE2() + genfp := GenFp() + + properties.Property("[BLS12-377] sub & add should leave an element invariant", prop.ForAll( + func(a, b *E2) bool { + var c E2 + c.Set(a) + c.Add(&c, b).Sub(&c, b) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] mul & inverse should leave an element invariant", prop.ForAll( + func(a, b *E2) bool { + var c, d E2 + d.Inverse(b) + c.Set(a) + c.Mul(&c, b).Mul(&c, &d) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] BatchInvertE2 should output the same result as Inverse", prop.ForAll( + func(a, b, c *E2) bool { + + batch := BatchInvertE2([]E2{*a, *b, *c}) + a.Inverse(a) + b.Inverse(b) + c.Inverse(c) + return a.Equal(&batch[0]) && b.Equal(&batch[1]) && c.Equal(&batch[2]) + }, + genA, + genA, + genA, + )) + + properties.Property("[BLS12-377] inverse twice should leave an element invariant", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Inverse(a).Inverse(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] neg twice should leave an element invariant", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Neg(a).Neg(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] square and mul should output the same result", prop.ForAll( + func(a *E2) bool { + var b, c E2 + b.Mul(a, a) + c.Square(a) + return b.Equal(&c) + }, + genA, + )) + + properties.Property("[BLS12-377] MulByElement MulByElement inverse should leave an element invariant", prop.ForAll( + func(a *E2, b fp.Element) bool { + var c E2 + var d fp.Element + d.Inverse(&b) + c.MulByElement(a, &b).MulByElement(&c, &d) + return c.Equal(a) + }, + genA, + genfp, + )) + + properties.Property("[BLS12-377] Double and mul by 2 should output the same result", prop.ForAll( + func(a *E2) bool { + var b E2 + var c fp.Element + c.SetUint64(2) + b.Double(a) + a.MulByElement(a, &c) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Mulbynonres mulbynonresinv should leave the element invariant", prop.ForAll( + func(a *E2) bool { + var b E2 + b.MulByNonResidue(a).MulByNonResidueInv(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] a + pi(a), a-pi(a) should be real", prop.ForAll( + func(a *E2) bool { + var b, c, d E2 + var e, f fp.Element + b.Conjugate(a) + c.Add(a, &b) + d.Sub(a, &b) + e.Double(&a.A0) + f.Double(&a.A1) + return c.A1.IsZero() && d.A0.IsZero() && e.Equal(&c.A0) && f.Equal(&d.A1) + }, + genA, + )) + + properties.Property("[BLS12-377] Legendre on square should output 1", prop.ForAll( + func(a *E2) bool { + var b E2 + b.Square(a) + c := b.Legendre() + return c == 1 + }, + genA, + )) + + properties.Property("[BLS12-377] square(sqrt) should leave an element invariant", prop.ForAll( + func(a *E2) bool { + var b, c, d, e E2 + b.Square(a) + c.Sqrt(&b) + d.Square(&c) + e.Neg(a) + return (c.Equal(a) || c.Equal(&e)) && d.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] neg(E2) == neg(E2.A0, E2.A1)", prop.ForAll( + func(a *E2) bool { + var b, c E2 + b.Neg(a) + c.A0.Neg(&a.A0) + c.A1.Neg(&a.A1) + return c.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Cmp and LexicographicallyLargest should be consistent", prop.ForAll( + func(a *E2) bool { + var negA E2 + negA.Neg(a) + cmpResult := a.Cmp(&negA) + lResult := a.LexicographicallyLargest() + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +// ------------------------------------------------------------ +// benches + +func BenchmarkE2Add(b *testing.B) { + var a, c E2 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Add(&a, &c) + } +} + +func BenchmarkE2Sub(b *testing.B) { + var a, c E2 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Sub(&a, &c) + } +} + +func BenchmarkE2Mul(b *testing.B) { + var a, c E2 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Mul(&a, &c) + } +} + +func BenchmarkE2MulByElement(b *testing.B) { + var a E2 + var c fp.Element + _, _ = c.SetRandom() + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.MulByElement(&a, &c) + } +} + +func BenchmarkE2Square(b *testing.B) { + var a E2 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Square(&a) + } +} + +func BenchmarkE2Sqrt(b *testing.B) { + var a E2 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Sqrt(&a) + } +} + +func BenchmarkE2Exp(b *testing.B) { + var x E2 + _, _ = x.SetRandom() + b1, _ := rand.Int(rand.Reader, fp.Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Exp(x, b1) + } +} + +func BenchmarkE2Inverse(b *testing.B) { + var a E2 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Inverse(&a) + } +} + +func BenchmarkE2MulNonRes(b *testing.B) { + var a E2 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.MulByNonResidue(&a) + } +} + +func BenchmarkE2MulNonResInv(b *testing.B) { + var a E2 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.MulByNonResidueInv(&a) + } +} + +func BenchmarkE2Conjugate(b *testing.B) { + var a E2 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Conjugate(&a) + } +} + +func TestE2Div(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + properties := gopter.NewProperties(parameters) + + genA := GenE2() + genB := GenE2() + + properties.Property("[BLS12-377] dividing then multiplying by the same element does nothing", prop.ForAll( + func(a, b *E2) bool { + var c E2 + c.Div(a, b) + c.Mul(&c, b) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} diff --git a/prover/maths/field/fext/e2new.go b/prover/maths/field/fext/e2new.go new file mode 100644 index 000000000..e85f0766f --- /dev/null +++ b/prover/maths/field/fext/e2new.go @@ -0,0 +1,317 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/linea-monorepo/prover/maths/field" + "math/big" + "math/rand" +) + +const noQNR = 11 + +// Element is a degree two finite field extension of fr.Element +type Element struct { + A0, A1 fr.Element +} + +// Equal returns true if z equals x, false otherwise +func (z *Element) Equal(x *Element) bool { + return z.A0.Equal(&x.A0) && z.A1.Equal(&x.A1) +} + +// Bits +// TODO @gbotrel fixme this shouldn't return a Element +func (z *Element) Bits() Element { + r := Element{} + r.A0 = z.A0.Bits() + r.A1 = z.A1.Bits() + return r +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Element) Cmp(x *Element) int { + if a1 := z.A1.Cmp(&x.A1); a1 != 0 { + return a1 + } + return z.A0.Cmp(&x.A0) +} + +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + if z.A1.IsZero() { + return z.A0.LexicographicallyLargest() + } + return z.A1.LexicographicallyLargest() +} + +// SetString sets a Element element from strings +func (z *Element) SetString(s1, s2 string) (*Element, error) { + _, err := z.A0.SetString(s1) + if err != nil { + return z, err + } + + _, err = z.A1.SetString(s2) + if err != nil { + return z, err + } + return z, nil +} + +// SetZero sets an Element elmt to zero +func (z *Element) SetZero() *Element { + z.A0.SetZero() + z.A1.SetZero() + return z +} + +// Set sets an Element from x +func (z *Element) Set(x *Element) *Element { + z.A0 = x.A0 + z.A1 = x.A1 + return z +} + +// SetOne sets z to 1 in Montgomery form and returns z +func (z *Element) SetOne() *Element { + z.A0.SetOne() + z.A1.SetZero() + return z +} + +// SetRandom sets a0 and a1 to random values +func (z *Element) SetRandom() (*Element, error) { + if _, err := z.A0.SetRandom(); err != nil { + return nil, err + } + if _, err := z.A1.SetRandom(); err != nil { + return nil, err + } + return z, nil +} + +// IsZero returns true if z is zero, false otherwise +func (z *Element) IsZero() bool { + return z.A0.IsZero() && z.A1.IsZero() +} + +// IsOne returns true if z is one, false otherwise +func (z *Element) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + +// Add adds two elements of Element +func (z *Element) Add(x, y *Element) *Element { + addE2New(z, x, y) + return z +} + +// Sub subtracts two elements of Element +func (z *Element) Sub(x, y *Element) *Element { + subE2New(z, x, y) + return z +} + +// Double doubles an Element element +func (z *Element) Double(x *Element) *Element { + doubleE2New(z, x) + return z +} + +// Neg negates an Element element +func (z *Element) Neg(x *Element) *Element { + negE2New(z, x) + return z +} + +// String implements Stringer interface for fancy printing +func (z *Element) String() string { + return z.A0.String() + "+" + z.A1.String() + "*u" +} + +// MulByElement multiplies an element in Element by an element in fp +func (z *Element) MulByElement(x *Element, y *fr.Element) *Element { + var yCopy fr.Element + yCopy.Set(y) + z.A0.Mul(&x.A0, &yCopy) + z.A1.Mul(&x.A1, &yCopy) + return z +} + +// Conjugate conjugates an element in Element +func (z *Element) Conjugate(x *Element) *Element { + z.A0 = x.A0 + z.A1.Neg(&x.A1) + return z +} + +// Halve sets z to z / 2 +func (z *Element) Halve() { + z.A0.Halve() + z.A1.Halve() +} + +// Legendre returns the Legendre symbol of z +func (z *Element) Legendre() int { + var n fr.Element + z.norm(&n) + return n.Legendre() +} + +// Exp sets z=xᵏ (mod q²) and returns it +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q²) == (x⁻¹)ᵏ (mod q²) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = bigIntPool.Get().(*big.Int) + defer bigIntPool.Put(e) + e.Neg(k) + } + + z.SetOne() + b := e.Bytes() + for i := 0; i < len(b); i++ { + w := b[i] + for j := 0; j < 8; j++ { + z.Square(z) + if (w & (0b10000000 >> j)) != 0 { + z.Mul(z, &x) + } + } + } + + return z +} + +// Sqrt sets z to the square root of and returns z +// The function does not test whether the square root +// exists or not, it's up to the caller to call +// Legendre beforehand. +// cf https://eprint.iacr.org/2012/685.pdf (algo 10) +func (z *Element) Sqrt(x *Element) *Element { + + // precomputation + var b, c, d, e, f, x0 Element + var _b, o fr.Element + + // c must be a non square (works for p=1 mod 12 hence 1 mod 4, only bls377 has such a p currently) + c.A1.SetOne() + + q := fp.Modulus() + var exp, one big.Int + one.SetUint64(1) + exp.Set(q).Sub(&exp, &one).Rsh(&exp, 1) + d.Exp(c, &exp) + e.Mul(&d, &c).Inverse(&e) + f.Mul(&d, &c).Square(&f) + + // computation + exp.Rsh(&exp, 1) + b.Exp(*x, &exp) + b.norm(&_b) + o.SetOne() + if _b.Equal(&o) { + x0.Square(&b).Mul(&x0, x) + _b.Set(&x0.A0).Sqrt(&_b) + z.Conjugate(&b).MulByElement(z, &_b) + return z + } + x0.Square(&b).Mul(&x0, x).Mul(&x0, &f) + _b.Set(&x0.A0).Sqrt(&_b) + z.Conjugate(&b).MulByElement(z, &_b).Mul(z, &e) + + return z +} + +// BatchInvertE2New returns a new slice with every element in a inverted. +// It uses Montgomery batch inversion trick. +// +// if a[i] == 0, returns result[i] = a[i] +func BatchInvertE2New(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := make([]bool, len(a)) + var accumulator Element + accumulator.SetOne() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes[i] = true + continue + } + res[i].Set(&accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes[i] { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +// Select is conditional move. +// If cond = 0, it sets z to caseZ and returns it. otherwise caseNz. +func (z *Element) Select(cond int, caseZ *Element, caseNz *Element) *Element { + //Might be able to save a nanosecond or two by an aggregate implementation + + z.A0.Select(cond, &caseZ.A0, &caseNz.A0) + z.A1.Select(cond, &caseZ.A1, &caseNz.A1) + + return z +} + +// Div divides an element in Element by an element in Element +func (z *Element) Div(x *Element, y *Element) *Element { + var r Element + r.Inverse(y).Mul(x, &r) + return z.Set(&r) +} + +func PseudoRand(rng *rand.Rand) Element { + x := field.PseudoRand(rng) + y := field.PseudoRand(rng) + result := new(Element).SetZero() + return *result.Add(result, &Element{x, y}) +} diff --git a/prover/maths/field/fext/e2new_bls377.go b/prover/maths/field/fext/e2new_bls377.go new file mode 100644 index 000000000..028adcb86 --- /dev/null +++ b/prover/maths/field/fext/e2new_bls377.go @@ -0,0 +1,110 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fext + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// Mul sets z to the Element-product of x,y, returns z +func (z *Element) Mul(x, y *Element) *Element { + var a, b, c fr.Element + a.Add(&x.A0, &x.A1) + b.Add(&y.A0, &y.A1) + a.Mul(&a, &b) + b.Mul(&x.A0, &y.A0) + c.Mul(&x.A1, &y.A1) + z.A1.Sub(&a, &b).Sub(&z.A1, &c) + MulByQnr(&c) + z.A0.Sub(&b, &c) + return z +} + +// Square sets z to the Element-product of x,x returns z +func (z *Element) Square(x *Element) *Element { + //algo 22 https://eprint.iacr.org/2010/354.pdf + z.Mul(x, x) + return z +} + +// MulByNonResidue multiplies a Element by (0,1) +func (z *Element) MulByNonResidue(x *Element) *Element { + a := x.A0 + b := x.A1 // fetching x.A1 in the function below is slower + MulByQnr(&b) + z.A0.Neg(&b) + z.A1 = a + return z +} + +// MulByNonResidueInv multiplies a Element by (0,1)^{-1} +func (z *Element) MulByNonResidueInv(x *Element) *Element { + //z.A1.MulByNonResidueInv(&x.A0) + a := x.A1 + qnr := new(fr.Element).SetInt64(noQNR) + var qnrInv fr.Element + qnrInv.Inverse(qnr) + z.A1.Mul(&x.A0, &qnrInv).Neg(&z.A1) + z.A0 = a + return z +} + +// Inverse sets z to the Element-inverse of x, returns z +func (z *Element) Inverse(x *Element) *Element { + // Algorithm 8 from https://eprint.iacr.org/2010/354.pdf + //var a, b, t0, t1, tmp fr.Element + var t0, t1, tmp fr.Element + a := &x.A0 // creating the buffers a, b is faster than querying &x.A0, &x.A1 in the functions call below + b := &x.A1 + t0.Square(a) + t1.Square(b) + tmp.Set(&t1) + MulByQnr(&tmp) + t0.Add(&t0, &tmp) + t1.Inverse(&t0) + z.A0.Mul(a, &t1) + z.A1.Mul(b, &t1).Neg(&z.A1) + + return z +} + +// norm sets x to the norm of z +func (z *Element) norm(x *fr.Element) { + var tmp fr.Element + x.Square(&z.A1) + tmp.Set(x) + MulByQnr(&tmp) + x.Square(&z.A0).Add(x, &tmp) + // A0^2+A1^2*QNR +} +func MulByQnr(x *fr.Element) { + old := new(fr.Element).Set(x) + for i := 0; i < noQNR-1; i++ { + x.Add(x, old) + } +} + +/* +// MulBybTwistCurveCoeff multiplies by 1/(0,1) +func (z *Element) MulBybTwistCurveCoeff(x *Element) *Element { + + var res Element + res.A0.Set(&x.A1) + res.A1.MulByNonResidueInv(&x.A0) + z.Set(&res) + + return z +} +*/ diff --git a/prover/maths/field/fext/e2new_test.go b/prover/maths/field/fext/e2new_test.go new file mode 100644 index 000000000..f58c8880e --- /dev/null +++ b/prover/maths/field/fext/e2new_test.go @@ -0,0 +1,600 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +// ------------------------------------------------------------ +// tests + +func GenFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + + if _, err := elmt.SetRandom(); err != nil { + panic(err) + } + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +func GenE2New() gopter.Gen { + return gopter.CombineGens( + GenFr(), + GenFr(), + ).Map(func(values []interface{}) *Element { + return &Element{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} + +func TestE2NewNewReceiverIsOperand(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE2New() + genB := GenE2New() + GenFr := GenFr() + + properties.Property("[BLS12-377] Having the receiver as operand (addition) should output the same result", prop.ForAll( + func(a, b *Element) bool { + var c, d Element + d.Set(a) + c.Add(a, b) + a.Add(a, b) + b.Add(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (sub) should output the same result", prop.ForAll( + func(a, b *Element) bool { + var c, d Element + d.Set(a) + c.Sub(a, b) + a.Sub(a, b) + b.Sub(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul) should output the same result", prop.ForAll( + func(a, b *Element) bool { + var c, d Element + d.Set(a) + c.Mul(a, b) + a.Mul(a, b) + b.Mul(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (square) should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + b.Square(a) + a.Square(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (neg) should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + b.Neg(a) + a.Neg(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (double) should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + b.Double(a) + a.Double(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by non residue) should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + b.MulByNonResidue(a) + a.MulByNonResidue(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by non residue inverse) should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + b.MulByNonResidueInv(a) + a.MulByNonResidueInv(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Inverse) should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + b.Inverse(a) + a.Inverse(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Conjugate) should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + b.Conjugate(a) + a.Conjugate(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by element) should output the same result", prop.ForAll( + func(a *Element, b fr.Element) bool { + var c Element + c.MulByElement(a, &b) + a.MulByElement(a, &b) + return a.Equal(&c) + }, + genA, + GenFr, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Sqrt) should output the same result", prop.ForAll( + func(a *Element) bool { + var b, c, d, s Element + + s.Square(a) + a.Set(&s) + b.Set(&s) + + a.Sqrt(a) + b.Sqrt(&b) + + c.Square(a) + d.Square(&b) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +/* +func TestE2NewMulMaxed(t *testing.T) { + // let's pick a and b, with maxed A0 and A1 + var a, b Element + fpMaxValue := fp.Element{ + 9586122913090633729, + 1660523435060625408, + 2230234197602682880, + 1883307231910630287, + 14284016967150029115, + 121098312706494698, + } + fpMaxValue[0]-- + + a.A0 = fpMaxValue + a.A1 = fpMaxValue + b.A0 = fpMaxValue + b.A1 = fpMaxValue + + var c, d Element + d.Inverse(&b) + c.Set(&a) + c.Mul(&c, &b).Mul(&c, &d) + if !c.Equal(&a) { + t.Fatal("mul with max fp failed") + } +}*/ + +func TestE2NewOps(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE2New() + genB := GenE2New() + GenFr := GenFr() + + properties.Property("[BLS12-377] sub & add should leave an element invariant", prop.ForAll( + func(a, b *Element) bool { + var c Element + c.Set(a) + c.Add(&c, b).Sub(&c, b) // a+b-b + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] mul & inverse should leave an element invariant", prop.ForAll( + func(a, b *Element) bool { + var c, d Element + d.Inverse(b) + c.Set(a) + c.Mul(&c, b).Mul(&c, &d) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] BatchInvertE2New should output the same result as Inverse", prop.ForAll( + func(a, b, c *Element) bool { + + batch := BatchInvertE2New([]Element{*a, *b, *c}) + a.Inverse(a) + b.Inverse(b) + c.Inverse(c) + return a.Equal(&batch[0]) && b.Equal(&batch[1]) && c.Equal(&batch[2]) + }, + genA, + genA, + genA, + )) + + properties.Property("[BLS12-377] inverse twice should leave an element invariant", prop.ForAll( + func(a *Element) bool { + var b Element + b.Inverse(a).Inverse(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] neg twice should leave an element invariant", prop.ForAll( + func(a *Element) bool { + var b Element + b.Neg(a).Neg(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] square and mul should output the same result", prop.ForAll( + func(a *Element) bool { + var b, c Element + b.Mul(a, a) + c.Square(a) + res := b.Equal(&c) + fmt.Println("square and mul should output the same result———", res) + return b.Equal(&c) + }, + genA, + )) + + properties.Property("[BLS12-377] MulByElement MulByElement inverse should leave an element invariant", prop.ForAll( + func(a *Element, b fr.Element) bool { + var c Element + var d fr.Element + d.Inverse(&b) + c.MulByElement(a, &b).MulByElement(&c, &d) + return c.Equal(a) + }, + genA, + GenFr, + )) + + properties.Property("[BLS12-377] Double and mul by 2 should output the same result", prop.ForAll( + func(a *Element) bool { + var b Element + var c fr.Element + c.SetUint64(2) + b.Double(a) + a.MulByElement(a, &c) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Mulbynonres mulbynonresinv should leave the element invariant", prop.ForAll( + func(a *Element) bool { + var b Element + b.MulByNonResidue(a).MulByNonResidueInv(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] a + pi(a), a-pi(a) should be real", prop.ForAll( + func(a *Element) bool { + var b, c, d Element + var e, f fr.Element + b.Conjugate(a) + c.Add(a, &b) + d.Sub(a, &b) + e.Double(&a.A0) + f.Double(&a.A1) + return c.A1.IsZero() && d.A0.IsZero() && e.Equal(&c.A0) && f.Equal(&d.A1) + }, + genA, + )) + + properties.Property("[BLS12-377] Legendre on square should output 1", prop.ForAll( + func(a *Element) bool { + var b Element + b.Mul(a, a) + c := b.Legendre() + fmt.Println("DEBUG MESSAGE ", c) + return c == 1 + }, + genA, + )) + + /* + properties.Property("[BLS12-377] square(sqrt) should leave an element invariant", prop.ForAll( + func(a *Element) bool { + var b, c, d, e Element + b.Square(a) + c.Sqrt(&b) + d.Square(&c) + e.Neg(a) + return (c.Equal(a) || c.Equal(&e)) && d.Equal(&b) + }, + genA, + )) + + */ + + properties.Property("[BLS12-377] neg(Element) == neg(Element.A0, Element.A1)", prop.ForAll( + func(a *Element) bool { + var b, c Element + b.Neg(a) + c.A0.Neg(&a.A0) + c.A1.Neg(&a.A1) + return c.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Cmp and LexicographicallyLargest should be consistent", prop.ForAll( + func(a *Element) bool { + var negA Element + negA.Neg(a) + cmpResult := a.Cmp(&negA) + lResult := a.LexicographicallyLargest() + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +// ------------------------------------------------------------ +// benches + +func BenchmarkE2NewAdd(b *testing.B) { + var a, c Element + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Add(&a, &c) + } +} + +func BenchmarkE2NewSub(b *testing.B) { + var a, c Element + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Sub(&a, &c) + } +} + +func BenchmarkE2NewMul(b *testing.B) { + var a, c Element + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Mul(&a, &c) + } +} + +func BenchmarkE2NewMulByElement(b *testing.B) { + var a Element + var c fr.Element + _, _ = c.SetRandom() + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.MulByElement(&a, &c) + } +} + +func BenchmarkE2NewSquare(b *testing.B) { + var a Element + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Square(&a) + } +} + +func BenchmarkE2NewSqrt(b *testing.B) { + var a Element + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Sqrt(&a) + } +} + +func BenchmarkE2NewExp(b *testing.B) { + var x Element + _, _ = x.SetRandom() + b1, _ := rand.Int(rand.Reader, fp.Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Exp(x, b1) + } +} + +func BenchmarkE2NewInverse(b *testing.B) { + var a Element + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Inverse(&a) + } +} + +func BenchmarkE2NewMulNonRes(b *testing.B) { + var a Element + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.MulByNonResidue(&a) + } +} + +func BenchmarkE2NewMulNonResInv(b *testing.B) { + var a Element + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.MulByNonResidueInv(&a) + } +} + +func BenchmarkE2NewConjugate(b *testing.B) { + var a Element + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Conjugate(&a) + } +} + +func TestE2NewDiv(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + properties := gopter.NewProperties(parameters) + + genA := GenE2New() + genB := GenE2New() + + properties.Property("[BLS12-377] dividing then multiplying by the same element does nothing", prop.ForAll( + func(a, b *Element) bool { + var c Element + c.Div(a, b) + c.Mul(&c, b) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestBasicQnrEquality(t *testing.T) { + var xCoeff = [...]int64{1, 3, 3} + var yCoeff = [...]int64{2, 1, 2} + var zCoeff = [...]int64{-43, -2, -35} + var wCoeff = [...]int64{4, 6, 12} + + for i := 0; i < len(xCoeff); i++ { + x := new(fr.Element).SetInt64(xCoeff[i]) + y := new(fr.Element).SetInt64(yCoeff[i]) + z := new(fr.Element).SetInt64(zCoeff[i]) + w := new(fr.Element).SetInt64(wCoeff[i]) + a := Element{*x, *y} + c := new(Element) + c.Mul(&a, &a) + d := Element{*z, *w} + res := c.Equal(&d) + fmt.Println(res) + } + +} + +func TestBasicQnrEqualitySquare(t *testing.T) { + var xCoeff = [...]int64{1, 3, 3} + var yCoeff = [...]int64{2, 1, 2} + var zCoeff = [...]int64{-43, -2, -35} + var wCoeff = [...]int64{4, 6, 12} + for i := 0; i < len(xCoeff); i++ { + x := new(fr.Element).SetInt64(xCoeff[i]) + y := new(fr.Element).SetInt64(yCoeff[i]) + z := new(fr.Element).SetInt64(zCoeff[i]) + w := new(fr.Element).SetInt64(wCoeff[i]) + a := Element{*x, *y} + c := new(Element) + c.Square(&a) + d := Element{*z, *w} + res := c.Equal(&d) + //fmt.Println(c.A0.String(), c.A1.String()) + fmt.Println(c.String()) + fmt.Println(d.String()) + fmt.Println(res) + } + +} diff --git a/prover/maths/field/fext/e6.go b/prover/maths/field/fext/e6.go new file mode 100644 index 000000000..06fa573bc --- /dev/null +++ b/prover/maths/field/fext/e6.go @@ -0,0 +1,343 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +// E6 is a degree three finite field extension of fp2 +type E6 struct { + B0, B1, B2 E2 +} + +// Equal returns true if z equals x, false otherwise +func (z *E6) Equal(x *E6) bool { + return z.B0.Equal(&x.B0) && z.B1.Equal(&x.B1) && z.B2.Equal(&x.B2) +} + +// SetString sets a E6 elmt from stringf +func (z *E6) SetString(s1, s2, s3, s4, s5, s6 string) *E6 { + z.B0.SetString(s1, s2) + z.B1.SetString(s3, s4) + z.B2.SetString(s5, s6) + return z +} + +// Set Sets a E6 elmt form another E6 elmt +func (z *E6) Set(x *E6) *E6 { + z.B0 = x.B0 + z.B1 = x.B1 + z.B2 = x.B2 + return z +} + +// SetOne sets z to 1 in Montgomery form and returns z +func (z *E6) SetOne() *E6 { + *z = E6{} + z.B0.A0.SetOne() + return z +} + +// SetRandom set z to a random elmt +func (z *E6) SetRandom() (*E6, error) { + if _, err := z.B0.SetRandom(); err != nil { + return nil, err + } + if _, err := z.B1.SetRandom(); err != nil { + return nil, err + } + if _, err := z.B2.SetRandom(); err != nil { + return nil, err + } + return z, nil +} + +// IsZero returns true if z is zero, false otherwise +func (z *E6) IsZero() bool { + return z.B0.IsZero() && z.B1.IsZero() && z.B2.IsZero() +} + +// IsOne returns true if z is one, false otherwise +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() && z.B2.IsZero() +} + +// Add adds two elements of E6 +func (z *E6) Add(x, y *E6) *E6 { + z.B0.Add(&x.B0, &y.B0) + z.B1.Add(&x.B1, &y.B1) + z.B2.Add(&x.B2, &y.B2) + return z +} + +// Neg negates the E6 number +func (z *E6) Neg(x *E6) *E6 { + z.B0.Neg(&x.B0) + z.B1.Neg(&x.B1) + z.B2.Neg(&x.B2) + return z +} + +// Sub subtracts two elements of E6 +func (z *E6) Sub(x, y *E6) *E6 { + z.B0.Sub(&x.B0, &y.B0) + z.B1.Sub(&x.B1, &y.B1) + z.B2.Sub(&x.B2, &y.B2) + return z +} + +// Double doubles an element in E6 +func (z *E6) Double(x *E6) *E6 { + z.B0.Double(&x.B0) + z.B1.Double(&x.B1) + z.B2.Double(&x.B2) + return z +} + +// String puts E6 elmt in string form +func (z *E6) String() string { + return z.B0.String() + "+(" + z.B1.String() + ")*v+(" + z.B2.String() + ")*v**2" +} + +// MulByNonResidue mul x by (0,1,0) +func (z *E6) MulByNonResidue(x *E6) *E6 { + z.B2, z.B1, z.B0 = x.B1, x.B0, x.B2 + z.B0.MulByNonResidue(&z.B0) + return z +} + +// MulByE2 multiplies an element in E6 by an element in E2 +func (z *E6) MulByE2(x *E6, y *E2) *E6 { + var yCopy E2 + yCopy.Set(y) + z.B0.Mul(&x.B0, &yCopy) + z.B1.Mul(&x.B1, &yCopy) + z.B2.Mul(&x.B2, &yCopy) + return z +} + +// MulBy12 multiplication by sparse element (0,b1,b2) +func (x *E6) MulBy12(b1, b2 *E2) *E6 { + var t1, t2, c0, tmp, c1, c2 E2 + t1.Mul(&x.B1, b1) + t2.Mul(&x.B2, b2) + c0.Add(&x.B1, &x.B2) + tmp.Add(b1, b2) + c0.Mul(&c0, &tmp) + c0.Sub(&c0, &t1) + c0.Sub(&c0, &t2) + c0.MulByNonResidue(&c0) + c1.Add(&x.B0, &x.B1) + c1.Mul(&c1, b1) + c1.Sub(&c1, &t1) + tmp.MulByNonResidue(&t2) + c1.Add(&c1, &tmp) + tmp.Add(&x.B0, &x.B2) + c2.Mul(b2, &tmp) + c2.Sub(&c2, &t2) + c2.Add(&c2, &t1) + + x.B0 = c0 + x.B1 = c1 + x.B2 = c2 + + return x +} + +// MulBy01 multiplication by sparse element (c0,c1,0) +func (z *E6) MulBy01(c0, c1 *E2) *E6 { + + var a, b, tmp, t0, t1, t2 E2 + + a.Mul(&z.B0, c0) + b.Mul(&z.B1, c1) + + tmp.Add(&z.B1, &z.B2) + t0.Mul(c1, &tmp) + t0.Sub(&t0, &b) + t0.MulByNonResidue(&t0) + t0.Add(&t0, &a) + + tmp.Add(&z.B0, &z.B2) + t2.Mul(c0, &tmp) + t2.Sub(&t2, &a) + t2.Add(&t2, &b) + + t1.Add(c0, c1) + tmp.Add(&z.B0, &z.B1) + t1.Mul(&t1, &tmp) + t1.Sub(&t1, &a) + t1.Sub(&t1, &b) + + z.B0.Set(&t0) + z.B1.Set(&t1) + z.B2.Set(&t2) + + return z +} + +// MulBy1 multiplication of E6 by sparse element (0, c1, 0) +func (z *E6) MulBy1(c1 *E2) *E6 { + + var b, tmp, t0, t1 E2 + b.Mul(&z.B1, c1) + + tmp.Add(&z.B1, &z.B2) + t0.Mul(c1, &tmp) + t0.Sub(&t0, &b) + t0.MulByNonResidue(&t0) + + tmp.Add(&z.B0, &z.B1) + t1.Mul(c1, &tmp) + t1.Sub(&t1, &b) + + z.B0.Set(&t0) + z.B1.Set(&t1) + z.B2.Set(&b) + + return z +} + +// Mul sets z to the E6 product of x,y, returns z +func (z *E6) Mul(x, y *E6) *E6 { + // Algorithm 13 from https://eprint.iacr.org/2010/354.pdf + var t0, t1, t2, c0, c1, c2, tmp E2 + t0.Mul(&x.B0, &y.B0) + t1.Mul(&x.B1, &y.B1) + t2.Mul(&x.B2, &y.B2) + + c0.Add(&x.B1, &x.B2) + tmp.Add(&y.B1, &y.B2) + c0.Mul(&c0, &tmp).Sub(&c0, &t1).Sub(&c0, &t2).MulByNonResidue(&c0).Add(&c0, &t0) + + c1.Add(&x.B0, &x.B1) + tmp.Add(&y.B0, &y.B1) + c1.Mul(&c1, &tmp).Sub(&c1, &t0).Sub(&c1, &t1) + tmp.MulByNonResidue(&t2) + c1.Add(&c1, &tmp) + + tmp.Add(&x.B0, &x.B2) + c2.Add(&y.B0, &y.B2).Mul(&c2, &tmp).Sub(&c2, &t0).Sub(&c2, &t2).Add(&c2, &t1) + + z.B0.Set(&c0) + z.B1.Set(&c1) + z.B2.Set(&c2) + + return z +} + +// Square sets z to the E6 product of x,x, returns z +func (z *E6) Square(x *E6) *E6 { + + // Algorithm 16 from https://eprint.iacr.org/2010/354.pdf + var c4, c5, c1, c2, c3, c0 E2 + c4.Mul(&x.B0, &x.B1).Double(&c4) + c5.Square(&x.B2) + c1.MulByNonResidue(&c5).Add(&c1, &c4) + c2.Sub(&c4, &c5) + c3.Square(&x.B0) + c4.Sub(&x.B0, &x.B1).Add(&c4, &x.B2) + c5.Mul(&x.B1, &x.B2).Double(&c5) + c4.Square(&c4) + c0.MulByNonResidue(&c5).Add(&c0, &c3) + z.B2.Add(&c2, &c4).Add(&z.B2, &c5).Sub(&z.B2, &c3) + z.B0.Set(&c0) + z.B1.Set(&c1) + + return z +} + +// Inverse an element in E6 +// +// if x == 0, sets and returns z = x +func (z *E6) Inverse(x *E6) *E6 { + // Algorithm 17 from https://eprint.iacr.org/2010/354.pdf + // step 9 is wrong in the paper it's t1-t4 + var t0, t1, t2, t3, t4, t5, t6, c0, c1, c2, d1, d2 E2 + t0.Square(&x.B0) + t1.Square(&x.B1) + t2.Square(&x.B2) + t3.Mul(&x.B0, &x.B1) + t4.Mul(&x.B0, &x.B2) + t5.Mul(&x.B1, &x.B2) + c0.MulByNonResidue(&t5).Neg(&c0).Add(&c0, &t0) + c1.MulByNonResidue(&t2).Sub(&c1, &t3) + c2.Sub(&t1, &t4) + t6.Mul(&x.B0, &c0) + d1.Mul(&x.B2, &c1) + d2.Mul(&x.B1, &c2) + d1.Add(&d1, &d2).MulByNonResidue(&d1) + t6.Add(&t6, &d1) + t6.Inverse(&t6) + z.B0.Mul(&c0, &t6) + z.B1.Mul(&c1, &t6) + z.B2.Mul(&c2, &t6) + + return z +} + +// BatchInvertE6 returns a new slice with every element in a inverted. +// It uses Montgomery batch inversion trick. +// +// if a[i] == 0, returns result[i] = a[i] +func BatchInvertE6(a []E6) []E6 { + res := make([]E6, len(a)) + if len(a) == 0 { + return res + } + + zeroes := make([]bool, len(a)) + var accumulator E6 + accumulator.SetOne() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes[i] = true + continue + } + res[i].Set(&accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes[i] { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +// Select is conditional move. +// If cond = 0, it sets z to caseZ and returns it. otherwise caseNz. +func (z *E6) Select(cond int, caseZ *E6, caseNz *E6) *E6 { + //Might be able to save a nanosecond or two by an aggregate implementation + + z.B0.Select(cond, &caseZ.B0, &caseNz.B0) + z.B1.Select(cond, &caseZ.B1, &caseNz.B1) + z.B2.Select(cond, &caseZ.B2, &caseNz.B2) + + return z +} + +// Div divides an element in E6 by an element in E6 +func (z *E6) Div(x *E6, y *E6) *E6 { + var r E6 + r.Inverse(y).Mul(x, &r) + return z.Set(&r) +} diff --git a/prover/maths/field/fext/e6_test.go b/prover/maths/field/fext/e6_test.go new file mode 100644 index 000000000..afe22556b --- /dev/null +++ b/prover/maths/field/fext/e6_test.go @@ -0,0 +1,363 @@ +// Copyright 2020 Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +// ------------------------------------------------------------ +// tests + +func TestE6ReceiverIsOperand(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE6() + genB := GenE6() + genE2 := GenE2() + + properties.Property("[BLS12-377] Having the receiver as operand (addition) should output the same result", prop.ForAll( + func(a, b *E6) bool { + var c, d E6 + d.Set(a) + c.Add(a, b) + a.Add(a, b) + b.Add(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (sub) should output the same result", prop.ForAll( + func(a, b *E6) bool { + var c, d E6 + d.Set(a) + c.Sub(a, b) + a.Sub(a, b) + b.Sub(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul) should output the same result", prop.ForAll( + func(a, b *E6) bool { + var c, d E6 + d.Set(a) + c.Mul(a, b) + a.Mul(a, b) + b.Mul(&d, b) + return a.Equal(b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (square) should output the same result", prop.ForAll( + func(a *E6) bool { + var b E6 + b.Square(a) + a.Square(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (neg) should output the same result", prop.ForAll( + func(a *E6) bool { + var b E6 + b.Neg(a) + a.Neg(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (double) should output the same result", prop.ForAll( + func(a *E6) bool { + var b E6 + b.Double(a) + a.Double(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by non residue) should output the same result", prop.ForAll( + func(a *E6) bool { + var b E6 + b.MulByNonResidue(a) + a.MulByNonResidue(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (Inverse) should output the same result", prop.ForAll( + func(a *E6) bool { + var b E6 + b.Inverse(a) + a.Inverse(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Having the receiver as operand (mul by E2) should output the same result", prop.ForAll( + func(a *E6, b *E2) bool { + var c E6 + c.MulByE2(a, b) + a.MulByE2(a, b) + return a.Equal(&c) + }, + genA, + genE2, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestE6Ops(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := GenE6() + genB := GenE6() + genE2 := GenE2() + + properties.Property("[BLS12-377] sub & add should leave an element invariant", prop.ForAll( + func(a, b *E6) bool { + var c E6 + c.Set(a) + c.Add(&c, b).Sub(&c, b) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] mul & inverse should leave an element invariant", prop.ForAll( + func(a, b *E6) bool { + var c, d E6 + d.Inverse(b) + c.Set(a) + c.Mul(&c, b).Mul(&c, &d) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.Property("[BLS12-377] inverse twice should leave an element invariant", prop.ForAll( + func(a *E6) bool { + var b E6 + b.Inverse(a).Inverse(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] BatchInvertE6 should output the same result as Inverse", prop.ForAll( + func(a, b, c *E6) bool { + + batch := BatchInvertE6([]E6{*a, *b, *c}) + a.Inverse(a) + b.Inverse(b) + c.Inverse(c) + return a.Equal(&batch[0]) && b.Equal(&batch[1]) && c.Equal(&batch[2]) + }, + genA, + genA, + genA, + )) + + properties.Property("[BLS12-377] neg twice should leave an element invariant", prop.ForAll( + func(a *E6) bool { + var b E6 + b.Neg(a).Neg(&b) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] square and mul should output the same result", prop.ForAll( + func(a *E6) bool { + var b, c E6 + b.Mul(a, a) + c.Square(a) + return b.Equal(&c) + }, + genA, + )) + + properties.Property("[BLS12-377] Double and add twice should output the same result", prop.ForAll( + func(a *E6) bool { + var b E6 + b.Add(a, a) + a.Double(a) + return a.Equal(&b) + }, + genA, + )) + + properties.Property("[BLS12-377] Mul by non residue should be the same as multiplying by (0,1,0)", prop.ForAll( + func(a *E6) bool { + var b, c E6 + b.B1.A0.SetOne() + c.Mul(a, &b) + a.MulByNonResidue(a) + return a.Equal(&c) + }, + genA, + )) + + properties.Property("[BLS12-377] MulByE2 MulByE2 inverse should leave an element invariant", prop.ForAll( + func(a *E6, b *E2) bool { + var c E6 + var d E2 + d.Inverse(b) + c.MulByE2(a, b).MulByE2(&c, &d) + return c.Equal(a) + }, + genA, + genE2, + )) + + properties.Property("[BLS12-377] Mul and MulBy01 should output the same result", prop.ForAll( + func(a *E6, c0, c1 *E2) bool { + var b E6 + b.B0.Set(c0) + b.B1.Set(c1) + b.Mul(&b, a) + a.MulBy01(c0, c1) + return b.Equal(a) + }, + genA, + genE2, + genE2, + )) + + properties.Property("[BLS12-377] Mul and MulBy1 should output the same result", prop.ForAll( + func(a *E6, c1 *E2) bool { + var b E6 + b.B1.Set(c1) + b.Mul(&b, a) + a.MulBy1(c1) + return b.Equal(a) + }, + genA, + genE2, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +// ------------------------------------------------------------ +// benches + +func BenchmarkE6Add(b *testing.B) { + var a, c E6 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Add(&a, &c) + } +} + +func BenchmarkE6Sub(b *testing.B) { + var a, c E6 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Sub(&a, &c) + } +} + +func BenchmarkE6Mul(b *testing.B) { + var a, c E6 + _, _ = a.SetRandom() + _, _ = c.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Mul(&a, &c) + } +} + +func BenchmarkE6Square(b *testing.B) { + var a E6 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Square(&a) + } +} + +func BenchmarkE6Inverse(b *testing.B) { + var a E6 + _, _ = a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.Inverse(&a) + } +} + +func TestE6Div(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + properties := gopter.NewProperties(parameters) + + genA := GenE6() + genB := GenE6() + + properties.Property("[BLS12-377] dividing then multiplying by the same element does nothing", prop.ForAll( + func(a, b *E6) bool { + var c E6 + c.Div(a, b) + c.Mul(&c, b) + return c.Equal(a) + }, + genA, + genB, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} diff --git a/prover/maths/field/fext/element.go b/prover/maths/field/fext/element.go new file mode 100644 index 000000000..5ff220fee --- /dev/null +++ b/prover/maths/field/fext/element.go @@ -0,0 +1,353 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fext + +import ( + "encoding/binary" + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/linea-monorepo/prover/maths/field" + "math/big" + "math/bits" + "strings" + + "github.com/bits-and-blooms/bitset" +) + +const ( + frBytes = 32 // number of bytes needed to represent a Element +) + +func NewElement(v1 uint64, v2 uint64) Element { + z1 := fr.Element{v1} + z1.Mul(&z1, &rSquare) + + z2 := fr.Element{v2} + z2.Mul(&z2, &rSquare) + return Element{z1, z2} +} + +// NotEqual returns 0 if and only if z == x; constant-time +func (z *Element) NotEqual(x *Element) uint64 { + if z.Equal(x) { + return 0 + } else { + return 1 + } +} + +func Zero() Element { + x := field.Zero() + y := field.Zero() + return Element{x, y} +} + +// One returns 1 +func One() Element { + var one Element + one.SetOne() + return one +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := bitset.New(uint(len(a))) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes.Set(uint(i)) + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes.Test(uint(i)) { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [2 * frBytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// Unmarshal is an alias for SetBytes, it sets z to the value of e. +func (z *Element) Unmarshal(e []byte) { + z.SetBytes(e) +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == 2*frBytes { + // fast path + vect := (*[2 * frBytes]byte)(e) + v1, err1 := fr.BigEndian.Element((*[frBytes]byte)(vect[0:frBytes])) + v2, err2 := fr.BigEndian.Element((*[frBytes]byte)(vect[frBytes : 2*frBytes])) + if err1 == nil && err2 == nil { + *z = Element{v1, v2} + return z + } + } + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != frBytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[2 * frBytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v1 *big.Int, v2 *big.Int) *Element { + vBits := v1.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z.A0[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z.A0[i/2] = uint64(vBits[i]) + } else { + z.A0[i/2] |= uint64(vBits[i]) << 32 + } + } + } + ToMont(&z.A0) + + vBits = v2.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z.A1[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z.A1[i/2] = uint64(vBits[i]) + } else { + z.A1[i/2] |= uint64(vBits[i]) << 32 + } + } + } + ToMont(&z.A1) + return z +} + +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + s := z.A0.Text(10) + + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + + s = z.A1.Text(10) + /* + if len(s) <= maxSafeBound { + return []byte(s), nil + } + */ + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[frBytes]byte) (Element, error) + PutElement(*[frBytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[2 * frBytes]byte) (Element, error) { + var z fr.Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + var x fr.Element + x[0] = binary.BigEndian.Uint64((*b)[frBytes+24 : frBytes+32]) + x[1] = binary.BigEndian.Uint64((*b)[frBytes+16 : frBytes+24]) + x[2] = binary.BigEndian.Uint64((*b)[frBytes+8 : frBytes+16]) + x[3] = binary.BigEndian.Uint64((*b)[frBytes+0 : frBytes+8]) + + if !SmallerThanModulus(&z) { + return Element{}, errors.New("invalid fr.Element encoding") + } + + ToMont(&z) + + if !SmallerThanModulus(&x) { + return Element{}, errors.New("invalid fr.Element encoding") + } + ToMont(&x) + + return Element{z, x}, nil +} + +func (bigEndian) PutElement(b *[2 * frBytes]byte, e Element) { + FromMont(&e.A0) + binary.BigEndian.PutUint64((*b)[24:32], e.A0[0]) + binary.BigEndian.PutUint64((*b)[16:24], e.A0[1]) + binary.BigEndian.PutUint64((*b)[8:16], e.A0[2]) + binary.BigEndian.PutUint64((*b)[0:8], e.A0[3]) + + FromMont(&e.A1) + binary.BigEndian.PutUint64((*b)[frBytes+24:frBytes+32], e.A1[0]) + binary.BigEndian.PutUint64((*b)[frBytes+16:frBytes+24], e.A1[1]) + binary.BigEndian.PutUint64((*b)[frBytes+8:frBytes+16], e.A1[2]) + binary.BigEndian.PutUint64((*b)[frBytes+0:frBytes+8], e.A1[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[2 * frBytes]byte) (Element, error) { + var z Element + z.A0[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z.A0[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z.A0[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z.A0[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !SmallerThanModulus(&z.A0) { + return Element{}, errors.New("invalid fr.Element encoding") + } + ToMont(&z.A0) + + z.A1[0] = binary.LittleEndian.Uint64((*b)[frBytes+0 : frBytes+8]) + z.A1[1] = binary.LittleEndian.Uint64((*b)[frBytes+8 : frBytes+16]) + z.A1[2] = binary.LittleEndian.Uint64((*b)[frBytes+16 : frBytes+24]) + z.A1[3] = binary.LittleEndian.Uint64((*b)[frBytes+24 : frBytes+32]) + + ToMont(&z.A1) + return z, nil +} + +func (littleEndian) PutElement(b *[2 * frBytes]byte, e Element) { + FromMont(&e.A0) + binary.LittleEndian.PutUint64((*b)[0:8], e.A0[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e.A0[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e.A0[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e.A0[3]) + + FromMont(&e.A1) + binary.LittleEndian.PutUint64((*b)[frBytes+0:frBytes+8], e.A1[0]) + binary.LittleEndian.PutUint64((*b)[frBytes+8:frBytes+16], e.A1[1]) + binary.LittleEndian.PutUint64((*b)[frBytes+16:frBytes+24], e.A1[2]) + binary.LittleEndian.PutUint64((*b)[frBytes+24:frBytes+32], e.A1[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +// linearComb z = xC * x + yC * y; +// 0 ≤ x, y < 2²⁵³ +// |xC|, |yC| < 2⁶³ +func (z *Element) linearComb(x *Element, xC int64, y *Element, yC int64) { + var e1, e2 Element + e1.Set(x) + e2.Set(y) + var i int64 + for i = 0; i < xC-1; i++ { + e1.Add(&e1, x) + } + for i = 0; i < yC-1; i++ { + e2.Add(&e1, y) + } + z.Add(&e1, &e2) +} + +func (z *Element) SetFromBase(x *fr.Element) *Element { + z.A0.Set(x) + z.A1.SetZero() + return z +} + +func ExpToInt(z *Element, x Element, k int) *Element { + if k == 0 { + return z.SetOne() + } + + if k < 0 { + x.Inverse(&x) + k = -k + } + + z.Set(&x) + + for i := bits.Len(uint(k)) - 2; i >= 0; i-- { + z.Square(z) + if (k>>i)&1 == 1 { + z.Mul(z, &x) + } + } + + return z +} diff --git a/prover/maths/field/fext/frobenius.go b/prover/maths/field/fext/frobenius.go new file mode 100644 index 000000000..9174b40d3 --- /dev/null +++ b/prover/maths/field/fext/frobenius.go @@ -0,0 +1,227 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fext + +import "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + +// Frobenius set z to Frobenius(x), return z +func (z *E12) Frobenius(x *E12) *E12 { + // Algorithm 28 from https://eprint.iacr.org/2010/354.pdf (beware typos!) + var t [6]E2 + + // Frobenius acts on fp2 by conjugation + t[0].Conjugate(&x.C0.B0) + t[1].Conjugate(&x.C0.B1) + t[2].Conjugate(&x.C0.B2) + t[3].Conjugate(&x.C1.B0) + t[4].Conjugate(&x.C1.B1) + t[5].Conjugate(&x.C1.B2) + + t[1].MulByNonResidue1Power2(&t[1]) + t[2].MulByNonResidue1Power4(&t[2]) + t[3].MulByNonResidue1Power1(&t[3]) + t[4].MulByNonResidue1Power3(&t[4]) + t[5].MulByNonResidue1Power5(&t[5]) + + z.C0.B0 = t[0] + z.C0.B1 = t[1] + z.C0.B2 = t[2] + z.C1.B0 = t[3] + z.C1.B1 = t[4] + z.C1.B2 = t[5] + + return z +} + +// FrobeniusSquare set z to Frobenius^2(x), and return z +func (z *E12) FrobeniusSquare(x *E12) *E12 { + // Algorithm 29 from https://eprint.iacr.org/2010/354.pdf (beware typos!) + var t [6]E2 + + t[1].MulByNonResidue2Power2(&x.C0.B1) + t[2].MulByNonResidue2Power4(&x.C0.B2) + t[3].MulByNonResidue2Power1(&x.C1.B0) + t[4].MulByNonResidue2Power3(&x.C1.B1) + t[5].MulByNonResidue2Power5(&x.C1.B2) + + z.C0.B0 = x.C0.B0 + z.C0.B1 = t[1] + z.C0.B2 = t[2] + z.C1.B0 = t[3] + z.C1.B1 = t[4] + z.C1.B2 = t[5] + + return z +} + +// MulByNonResidue1Power1 set z=x*(0,1)^(1*(p^1-1)/6) and return z +func (z *E2) MulByNonResidue1Power1(x *E2) *E2 { + // 92949345220277864758624960506473182677953048909283248980960104381795901929519566951595905490535835115111760994353 + b := fp.Element{ + 7981638599956744862, + 11830407261614897732, + 6308788297503259939, + 10596665404780565693, + 11693741422477421038, + 61545186993886319, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue1Power2 set z=x*(0,1)^(2*(p^1-1)/6) and return z +func (z *E2) MulByNonResidue1Power2(x *E2) *E2 { + // 80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410946 + b := fp.Element{ + 6382252053795993818, + 1383562296554596171, + 11197251941974877903, + 6684509567199238270, + 6699184357838251020, + 19987743694136192, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue1Power3 set z=x*(0,1)^(3*(p^1-1)/6) and return z +func (z *E2) MulByNonResidue1Power3(x *E2) *E2 { + // 216465761340224619389371505802605247630151569547285782856803747159100223055385581585702401816380679166954762214499 + b := fp.Element{ + 10965161018967488287, + 18251363109856037426, + 7036083669251591763, + 16109345360066746489, + 4679973768683352764, + 96952949334633821, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue1Power4 set z=x*(0,1)^(4*(p^1-1)/6) and return z +func (z *E2) MulByNonResidue1Power4(x *E2) *E2 { + // 80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410945 + b := fp.Element{ + 15766275933608376691, + 15635974902606112666, + 1934946774703877852, + 18129354943882397960, + 15437979634065614942, + 101285514078273488, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue1Power5 set z=x*(0,1)^(5*(p^1-1)/6) and return z +func (z *E2) MulByNonResidue1Power5(x *E2) *E2 { + // 123516416119946754630746545296132064952198520638002533875843642777304321125866014634106496325844844051843001220146 + b := fp.Element{ + 2983522419010743425, + 6420955848241139694, + 727295371748331824, + 5512679955286180796, + 11432976419915483342, + 35407762340747501, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue2Power1 set z=x*(0,1)^(1*(p^2-1)/6) and return z +func (z *E2) MulByNonResidue2Power1(x *E2) *E2 { + // 80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410946 + b := fp.Element{ + 6382252053795993818, + 1383562296554596171, + 11197251941974877903, + 6684509567199238270, + 6699184357838251020, + 19987743694136192, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue2Power2 set z=x*(0,1)^(2*(p^2-1)/6) and return z +func (z *E2) MulByNonResidue2Power2(x *E2) *E2 { + // 80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410945 + b := fp.Element{ + 15766275933608376691, + 15635974902606112666, + 1934946774703877852, + 18129354943882397960, + 15437979634065614942, + 101285514078273488, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue2Power3 set z=x*(0,1)^(3*(p^2-1)/6) and return z +func (z *E2) MulByNonResidue2Power3(x *E2) *E2 { + // 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458176 + b := fp.Element{ + 9384023879812382873, + 14252412606051516495, + 9184438906438551565, + 11444845376683159689, + 8738795276227363922, + 81297770384137296, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue2Power4 set z=x*(0,1)^(4*(p^2-1)/6) and return z +func (z *E2) MulByNonResidue2Power4(x *E2) *E2 { + // 258664426012969093929703085429980814127835149614277183275038967946009968870203535512256352201271898244626862047231 + b := fp.Element{ + 3203870859294639911, + 276961138506029237, + 9479726329337356593, + 13645541738420943632, + 7584832609311778094, + 101110569012358506, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} + +// MulByNonResidue2Power5 set z=x*(0,1)^(5*(p^2-1)/6) and return z +func (z *E2) MulByNonResidue2Power5(x *E2) *E2 { + // 258664426012969093929703085429980814127835149614277183275038967946009968870203535512256352201271898244626862047232 + b := fp.Element{ + 12266591053191808654, + 4471292606164064357, + 295287422898805027, + 2200696361737783943, + 17292781406793965788, + 19812798628221209, + } + z.A0.Mul(&x.A0, &b) + z.A1.Mul(&x.A1, &b) + return z +} diff --git a/prover/maths/field/fext/generators_test.go b/prover/maths/field/fext/generators_test.go new file mode 100644 index 000000000..547770b34 --- /dev/null +++ b/prover/maths/field/fext/generators_test.go @@ -0,0 +1,50 @@ +package fext + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/leanovate/gopter" +) + +// Fp generates an Fp element +func GenFp() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fp.Element + + if _, err := elmt.SetRandom(); err != nil { + panic(err) + } + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// E2 generates an E2 elmt +func GenE2() gopter.Gen { + return gopter.CombineGens( + GenFp(), + GenFp(), + ).Map(func(values []interface{}) *E2 { + return &E2{A0: values[0].(fp.Element), A1: values[1].(fp.Element)} + }) +} + +// E6 generates an E6 elmt +func GenE6() gopter.Gen { + return gopter.CombineGens( + GenE2(), + GenE2(), + GenE2(), + ).Map(func(values []interface{}) *E6 { + return &E6{B0: *values[0].(*E2), B1: *values[1].(*E2), B2: *values[2].(*E2)} + }) +} + +// E12 generates an E6 elmt +func GenE12() gopter.Gen { + return gopter.CombineGens( + GenE6(), + GenE6(), + ).Map(func(values []interface{}) *E12 { + return &E12{C0: *values[0].(*E6), C1: *values[1].(*E6)} + }) +} diff --git a/prover/maths/field/fext/parameters.go b/prover/maths/field/fext/parameters.go new file mode 100644 index 000000000..67406061d --- /dev/null +++ b/prover/maths/field/fext/parameters.go @@ -0,0 +1,33 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fext + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// generator of the curve +var xGen big.Int + +var glvBasis ecc.Lattice + +func init() { + xGen.SetString("9586122913090633729", 10) + _r := fr.Modulus() + ecc.PrecomputeLattice(_r, &xGen, &glvBasis) +} diff --git a/prover/maths/field/fext/temp_functionality.go b/prover/maths/field/fext/temp_functionality.go new file mode 100644 index 000000000..c267fdee1 --- /dev/null +++ b/prover/maths/field/fext/temp_functionality.go @@ -0,0 +1,11 @@ +package fext + +import "github.com/consensys/linea-monorepo/prover/maths/field" + +/* +Currently, this function only sets the first coordinate of the field extension +*/ +func NewFromString(s string) (res Element) { + elem := field.NewFromString(s) + return Element{elem, field.Zero()} +} diff --git a/prover/maths/field/fext/unexportedFr.go b/prover/maths/field/fext/unexportedFr.go new file mode 100644 index 000000000..ff67ca101 --- /dev/null +++ b/prover/maths/field/fext/unexportedFr.go @@ -0,0 +1,107 @@ +package fext + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math/bits" +) + +var rSquare = fr.Element{ + 2726216793283724667, + 14712177743343147295, + 12091039717619697043, + 81024008013859129, +} + +const ( + q0 uint64 = 725501752471715841 + q1 uint64 = 6461107452199829505 + q2 uint64 = 6968279316240510977 + q3 uint64 = 1345280370688173398 +) + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func ToMont(z *fr.Element) *fr.Element { + return z.Mul(z, &rSquare) +} + +func SmallerThanModulus(z *fr.Element) bool { + return (z[3] < q3 || (z[3] == q3 && (z[2] < q2 || (z[2] == q2 && (z[1] < q1 || (z[1] == q1 && (z[0] < q0))))))) +} + +const qInvNeg uint64 = 725501752471715839 + +// madd0 hi = a*b + c (discards lo bits) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func FromMontGeneric(z *fr.Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + + // if z ⩾ q → z -= q + if !SmallerThanModulus(z) { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +func FromMont(z *fr.Element) { + FromMontGeneric(z) +}