diff --git a/internal/conv/conv.go b/internal/conv/conv.go index 649a8e931..66485a086 100644 --- a/internal/conv/conv.go +++ b/internal/conv/conv.go @@ -5,6 +5,8 @@ import ( "fmt" "math/big" "strings" + + "golang.org/x/crypto/cryptobyte" ) // BytesLe2Hex returns an hexadecimal string of a number stored in a @@ -138,3 +140,36 @@ func BigInt2Uint64Le(z []uint64, x *big.Int) { z[i] = 0 } } + +// MarshalBinary encodes a value into a byte array in a format readable by UnmarshalBinary. +func MarshalBinary(v cryptobyte.MarshalingValue) ([]byte, error) { + var b cryptobyte.Builder + b.AddValue(v) + return b.Bytes() +} + +// A UnmarshalingValue decodes itself from a cryptobyte.String and advances the pointer. +// Returns true indicating the reading was successful. +type UnmarshalingValue interface { + ReadValue(*cryptobyte.String) bool +} + +// UnmarshalBinary recovers a value from a byte array. +// Returns an error if the recovered value is invalid. +// Any panic raised when calling to ReadValue is recovered and an error is returned instead. +func UnmarshalBinary(v UnmarshalingValue, data []byte) (err error) { + defer func() { + r := recover() + if r != nil { + err = fmt.Errorf("%T failed to unmarshal: %v", v, r) + } + }() + + r := cryptobyte.String(data) + ok := v.ReadValue(&r) + if !ok { + return fmt.Errorf("cannot read %T from input string", v) + } + + return nil +} diff --git a/internal/test/test.go b/internal/test/test.go index 576211a9f..72d37fdf8 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,6 +1,8 @@ package test import ( + "bytes" + "encoding" "errors" "fmt" "strings" @@ -58,3 +60,26 @@ func CheckPanic(f func()) error { f() return hasPanicked } + +func CheckMarshal( + t *testing.T, + x, y interface { + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler + }, +) { + t.Helper() + + want, err := x.MarshalBinary() + CheckNoErr(t, err, fmt.Sprintf("cannot marshal %T = %v", x, x)) + + err = y.UnmarshalBinary(want) + CheckNoErr(t, err, fmt.Sprintf("cannot unmarshal %T from %x", y, want)) + + got, err := y.MarshalBinary() + CheckNoErr(t, err, fmt.Sprintf("cannot marshal %T = %v", y, y)) + + if !bytes.Equal(got, want) { + ReportError(t, got, want, x, y) + } +} diff --git a/tss/rsa/keyshare.go b/tss/rsa/keyshare.go index aa19efda8..ec3a6acf8 100644 --- a/tss/rsa/keyshare.go +++ b/tss/rsa/keyshare.go @@ -3,15 +3,15 @@ package rsa import ( "crypto/rand" "crypto/rsa" - "encoding/binary" "errors" "fmt" "io" - "math" "math/big" "sync" + "github.com/cloudflare/circl/internal/conv" "github.com/cloudflare/circl/zk/qndleq" + "golang.org/x/crypto/cryptobyte" ) // VerifyKeys contains keys used to verify whether a signature share @@ -23,15 +23,17 @@ type VerifyKeys struct { VerifyKey *big.Int } +func (v VerifyKeys) String() string { + return fmt.Sprintf("groupKey: 0x%v verifyKey: 0x%v", + v.GroupKey.Text(16), v.VerifyKey.Text(16)) +} + // KeyShare represents a portion of the key. It can only be used to generate SignShare's. During the dealing phase (when Deal is called), one KeyShare is generated per player. type KeyShare struct { - si *big.Int + share + si *big.Int twoDeltaSi *big.Int // this value is used to marginally speed up SignShare generation in Sign. - Index uint // When KeyShare's are generated they are each assigned an index sequentially - - Players uint - Threshold uint // It stores keys to produce verifiable signature shares. // If it's nil, signature shares are still produced but @@ -42,137 +44,89 @@ type KeyShare struct { } func (kshare KeyShare) String() string { - return fmt.Sprintf("(t,n): (%v,%v) index: %v si: 0x%v", - kshare.Threshold, kshare.Players, kshare.Index, kshare.si.Text(16)) + return fmt.Sprintf("%v si: 0x%v twoDeltaSi: 0x%v vk: {%v}", + kshare.share, kshare.si.Text(16), kshare.twoDeltaSi.Text(16), kshare.vk, + ) } -// MarshalBinary encodes a KeyShare into a byte array in a format readable by UnmarshalBinary. -// Note: Only Index's up to math.MaxUint16 are supported -func (kshare *KeyShare) MarshalBinary() ([]byte, error) { - // The encoding format is - // | Players: uint16 | Threshold: uint16 | Index: uint16 | siLen: uint16 | si: []byte | twoDeltaSiNil: bool | twoDeltaSiLen: uint16 | twoDeltaSi: []byte | - // with all values in big-endian. +func (kshare *KeyShare) MarshalBinary() ([]byte, error) { return conv.MarshalBinary(kshare) } +func (kshare *KeyShare) UnmarshalBinary(b []byte) error { return conv.UnmarshalBinary(kshare, b) } - if kshare.Players > math.MaxUint16 { - return nil, fmt.Errorf("rsa_threshold: keyshare marshall: Players is too big to fit in a uint16") - } +func (kshare *KeyShare) Marshal(b *cryptobyte.Builder) error { + buf := make([]byte, (kshare.ModulusLength+7)/8) + b.AddValue(&kshare.share) + b.AddBytes(kshare.si.FillBytes(buf)) + b.AddBytes(kshare.twoDeltaSi.FillBytes(buf)) - if kshare.Threshold > math.MaxUint16 { - return nil, fmt.Errorf("rsa_threshold: keyshare marshall: Threhsold is too big to fit in a uint16") + isVerifiable := kshare.IsVerifiable() + var flag uint8 + if isVerifiable { + flag = 0x01 } - - if kshare.Index > math.MaxUint16 { - return nil, fmt.Errorf("rsa_threshold: keyshare marshall: Index is too big to fit in a uint16") - } - - players := uint16(kshare.Players) - threshold := uint16(kshare.Threshold) - index := uint16(kshare.Index) - - twoDeltaSiBytes := kshare.twoDeltaSi.Bytes() - twoDeltaSiLen := len(twoDeltaSiBytes) - - if twoDeltaSiLen > math.MaxInt16 { - return nil, fmt.Errorf("rsa_threshold: keyshare marshall: twoDeltaSiBytes is too big to fit it's length in a uint16") - } - - siBytes := kshare.si.Bytes() - - siLength := len(siBytes) - - if siLength == 0 { - siLength = 1 - siBytes = []byte{0} + b.AddUint8(flag) + if isVerifiable { + b.AddBytes(kshare.vk.GroupKey.FillBytes(buf)) + b.AddBytes(kshare.vk.VerifyKey.FillBytes(buf)) } - if siLength > math.MaxInt16 { - return nil, fmt.Errorf("rsa_threshold: keyshare marshall: siBytes is too big to fit it's length in a uint16") - } - - blen := 2 + 2 + 2 + 2 + 2 + 1 + siLength + twoDeltaSiLen - out := make([]byte, blen) - - binary.BigEndian.PutUint16(out[0:2], players) - binary.BigEndian.PutUint16(out[2:4], threshold) - binary.BigEndian.PutUint16(out[4:6], index) - - binary.BigEndian.PutUint16(out[6:8], uint16(siLength)) // okay because of conditions checked above - - copy(out[8:8+siLength], siBytes) - - out[8+siLength] = 1 // twoDeltaSiNil - - binary.BigEndian.PutUint16(out[8+siLength+1:8+siLength+3], uint16(twoDeltaSiLen)) - - copy(out[8+siLength+3:8+siLength+3+twoDeltaSiLen], twoDeltaSiBytes) - - return out, nil + return nil } -// UnmarshalBinary recovers a KeyShare from a slice of bytes, or returns an error if the encoding is invalid. -func (kshare *KeyShare) UnmarshalBinary(data []byte) error { - // The encoding format is - // | Players: uint16 | Threshold: uint16 | Index: uint16 | siLen: uint16 | si: []byte | twoDeltaSiNil: bool | twoDeltaSiLen: uint16 | twoDeltaSi: []byte | - // with all values in big-endian. - if len(data) < 6 { - return fmt.Errorf("rsa_threshold: keyshare unmarshalKeyShareTest failed: data length was too short for reading Players, Threashold, Index") - } - - players := binary.BigEndian.Uint16(data[0:2]) - threshold := binary.BigEndian.Uint16(data[2:4]) - index := binary.BigEndian.Uint16(data[4:6]) - - if len(data[6:]) < 2 { - return fmt.Errorf("rsa_threshold: keyshare unmarshalKeyShareTest failed: data length was too short for reading siLen length") - } - - siLen := binary.BigEndian.Uint16(data[6:8]) - - if siLen == 0 { - return fmt.Errorf("rsa_threshold: keyshare unmarshalKeyShareTest failed: si is a required field but siLen was 0") - } - - if uint16(len(data[8:])) < siLen { - return fmt.Errorf("rsa_threshold: keyshare unmarshalKeyShareTest failed: data length was too short for reading si, needed: %d found: %d", siLen, len(data[8:])) - } - - si := new(big.Int).SetBytes(data[8 : 8+siLen]) - - if len(data[8+siLen:]) < 1 { - return fmt.Errorf("rsa_threshold: keyshare unmarshalKeyShareTest failed: data length was too short for reading twoDeltaSiNil") - } - - if len(data[8+siLen+1:]) < 2 { - return fmt.Errorf("rsa_threshold: keyshare unmarshalKeyShareTest failed: data length was too short for reading twoDeltaSiLen length") - } - - twoDeltaSiLen := binary.BigEndian.Uint16(data[8+siLen+1 : 8+siLen+3]) +func (kshare *KeyShare) ReadValue(r *cryptobyte.String) bool { + var sh share + ok := sh.ReadValue(r) + if !ok { + return false + } + + mlen := int((sh.ModulusLength + 7) / 8) + var siBytes, twoDeltaSiBytes []byte + ok = r.ReadBytes(&siBytes, mlen) && + r.ReadBytes(&twoDeltaSiBytes, mlen) + if !ok { + return false + } + + var isVerifiable uint8 + ok = r.ReadUint8(&isVerifiable) + if !ok { + return false + } + + var vk *VerifyKeys + switch isVerifiable { + case 0: + vk = nil + case 1: + var groupKeyBytes, verifyKeyBytes []byte + ok = r.ReadBytes(&groupKeyBytes, mlen) && + r.ReadBytes(&verifyKeyBytes, mlen) + if !ok { + return false + } - if uint16(len(data[8+siLen+3:])) < twoDeltaSiLen { - return fmt.Errorf("rsa_threshold: keyshare unmarshalKeyShareTest failed: data length was too short for reading twoDeltaSi, needed: %d found: %d", twoDeltaSiLen, len(data[8+siLen+2:])) + vk = &VerifyKeys{ + GroupKey: new(big.Int).SetBytes(groupKeyBytes), + VerifyKey: new(big.Int).SetBytes(verifyKeyBytes), + } + default: + return false } - twoDeltaSi := new(big.Int).SetBytes(data[8+siLen+3 : 8+siLen+3+twoDeltaSiLen]) - - kshare.Players = uint(players) - kshare.Threshold = uint(threshold) - kshare.Index = uint(index) - kshare.si = si - kshare.twoDeltaSi = twoDeltaSi + kshare.share = sh + kshare.si = new(big.Int).SetBytes(siBytes) + kshare.twoDeltaSi = new(big.Int).SetBytes(twoDeltaSiBytes) + kshare.vk = vk - return nil + return true } -// Returns the cached value in twoDeltaSi or if nil, generates 2∆s_i, stores it in twoDeltaSi, and returns it -func (kshare *KeyShare) get2DeltaSi(players int64) *big.Int { - // use the cached value if it exists - if kshare.twoDeltaSi != nil { - return kshare.twoDeltaSi - } +// Returns calculates and returns twoDeltaSi = 2∆s_i mod m. +func (kshare *KeyShare) get2DeltaSi(players int64, m *big.Int) *big.Int { delta := calculateDelta(players) // 2∆s_i // delta << 1 == delta * 2 - delta.Lsh(delta, 1).Mul(delta, kshare.si) + delta.Lsh(delta, 1).Mul(delta, kshare.si).Mod(delta, m) kshare.twoDeltaSi = delta return delta } @@ -182,17 +136,18 @@ func (kshare *KeyShare) get2DeltaSi(players int64) *big.Int { func (kshare *KeyShare) IsVerifiable() bool { return kshare.vk != nil } // VerifyKeys returns a copy of the verification keys used to verify -// signature shares. Returns nil if the key share cannot produce -// verifiable signature shares. +// signature shares. Panics if the key share cannot produce +// verifiable signature shares. Use the [IsVerifiable] method to +// determine whether there are associated verification keys. func (kshare *KeyShare) VerifyKeys() (vk *VerifyKeys) { - if kshare.IsVerifiable() { - vk = &VerifyKeys{ - GroupKey: new(big.Int).Set(kshare.vk.GroupKey), - VerifyKey: new(big.Int).Set(kshare.vk.VerifyKey), - } + if !kshare.IsVerifiable() { + panic(ErrKeyShareNonVerifiable) } - return + return &VerifyKeys{ + GroupKey: new(big.Int).Set(kshare.vk.GroupKey), + VerifyKey: new(big.Int).Set(kshare.vk.VerifyKey), + } } // Sign msg using a KeyShare. msg MUST be padded and hashed. Call PadHash before this method. @@ -203,17 +158,14 @@ func (kshare *KeyShare) VerifyKeys() (vk *VerifyKeys) { // parallel indicates whether the blinding operations should use go routines to operate in parallel. // If parallel is false, blinding will take about 2x longer than nonbinding, otherwise it will take about the same time // (see benchmarks). If randSource is nil, parallel has no effect. parallel should almost always be set to true. -func (kshare *KeyShare) Sign(randSource io.Reader, pub *rsa.PublicKey, digest []byte, parallel bool) (SignShare, error) { +func (kshare *KeyShare) Sign(randSource io.Reader, pub *rsa.PublicKey, digest []byte, parallel bool) (*SignShare, error) { x := &big.Int{} x.SetBytes(digest) - exp := kshare.get2DeltaSi(int64(kshare.Players)) - - var signShare SignShare - signShare.Players = kshare.Players - signShare.Threshold = kshare.Threshold - signShare.Index = kshare.Index + exp := kshare.twoDeltaSi + signShare := new(SignShare) + signShare.share = kshare.share signShare.xi = &big.Int{} if randSource != nil { @@ -226,7 +178,7 @@ func (kshare *KeyShare) Sign(randSource io.Reader, pub *rsa.PublicKey, digest [] r, err := rand.Int(randSource, pub.N) if err != nil { - return SignShare{}, errors.New("rsa_threshold: unable to get random value for blinding") + return nil, errors.New("rsa_threshold: unable to get random value for blinding") } expPlusr := big.Int{} // exp + r @@ -254,7 +206,7 @@ func (kshare *KeyShare) Sign(randSource io.Reader, pub *rsa.PublicKey, digest [] if res == nil { // extremely unlikely, somehow x^r is p or q - return SignShare{}, errors.New("rsa_threshold: no mod inverse") + return nil, errors.New("rsa_threshold: no mod inverse") } if wg != nil { @@ -285,7 +237,7 @@ func (kshare *KeyShare) Sign(randSource io.Reader, pub *rsa.PublicKey, digest [] x4Delta, xiSqr, pub.N, SecParam) if err != nil { - return SignShare{}, err + return nil, err } signShare.proof = proof } diff --git a/tss/rsa/keyshare_test.go b/tss/rsa/keyshare_test.go index 2775767f7..2d6f36d59 100644 --- a/tss/rsa/keyshare_test.go +++ b/tss/rsa/keyshare_test.go @@ -1,12 +1,59 @@ package rsa import ( + "crypto" "crypto/rand" "crypto/rsa" + "fmt" "math/big" "testing" + + "github.com/cloudflare/circl/internal/test" ) +func TestProtocol(t *testing.T) { + const ( + bits = 512 + Players = 10 + Threshold = 5 + ) + + priv, err := GenerateKey(rand.Reader, bits) + pub := &priv.PublicKey + test.CheckNoErr(t, err, fmt.Sprintf("cannot generate keys: %v", err)) + + msg := []byte("Cloudflare!") + hash := crypto.SHA256 + padded, err := PadHash(new(PKCS1v15Padder), hash, pub, msg) + test.CheckNoErr(t, err, fmt.Sprintf("cannot pad message: %v", err)) + + keyShares, err := Deal(rand.Reader, Players, Threshold, priv) + test.CheckNoErr(t, err, fmt.Sprintf("cannot deal key shares: %v", err)) + + test.CheckMarshal(t, &keyShares[0], new(KeyShare)) + + signShares := make([]*SignShare, len(keyShares)) + for i := range keyShares { + signShares[i], err = keyShares[i].Sign(rand.Reader, pub, padded, true) + test.CheckNoErr(t, err, fmt.Sprintf("cannot create signature share: %v", err)) + + err = signShares[i].Verify(pub, keyShares[i].VerifyKeys(), padded) + test.CheckNoErr(t, err, fmt.Sprintf("signature share does not verify: %v", err)) + } + + test.CheckMarshal(t, signShares[0], new(SignShare)) + + signature, err := CombineSignShares(pub, signShares, padded) + test.CheckNoErr(t, err, fmt.Sprintf("cannot create RSA signature: %v", err)) + + hasher := hash.New() + hasher.Write(msg) + hashed := hasher.Sum(nil) + + err = rsa.VerifyPKCS1v15(pub, hash, hashed, signature) + test.CheckNoErr(t, err, fmt.Sprintf("RSA signature does not verify: %v", err)) +} + func TestKeyShare_Sign(t *testing.T) { // delta = 3! = 6 // n = 253 @@ -16,9 +63,14 @@ func TestKeyShare_Sign(t *testing.T) { // x_i = x^{2∆kshare.si} = 150^{2 * 6 * 15} = 150^180 = 243 kshare := KeyShare{ - si: big.NewInt(15), - Index: 1, - Players: 3, + share: share{ + ModulusLength: 256, + Threshold: 1, + Players: 3, + Index: 1, + }, + si: big.NewInt(15), + twoDeltaSi: big.NewInt(180), } pub := rsa.PublicKey{N: big.NewInt(253)} share, err := kshare.Sign(nil, &pub, []byte{150}, false) @@ -39,9 +91,14 @@ func testSignBlind(parallel bool, t *testing.T) { // x_i = x^{2∆kshare.si} = 150^{2 * 6 * 15} = 150^180 = 243 kshare := KeyShare{ - si: big.NewInt(15), - Index: 1, - Players: 3, + share: share{ + ModulusLength: 256, + Threshold: 1, + Players: 3, + Index: 1, + }, + si: big.NewInt(15), + twoDeltaSi: big.NewInt(180), } pub := rsa.PublicKey{N: big.NewInt(253)} share, err := kshare.Sign(rand.Reader, &pub, []byte{150}, parallel) @@ -60,102 +117,3 @@ func TestKeyShare_SignBlind(t *testing.T) { func TestKeyShare_SignBlindParallel(t *testing.T) { testSignBlind(true, t) } - -func marshalTestKeyShare(share KeyShare, t *testing.T) { - marshall, err := share.MarshalBinary() - if err != nil { - t.Fatal(err) - } - - share2 := KeyShare{} - err = share2.UnmarshalBinary(marshall) - if err != nil { - t.Fatal(err) - } - - if share.Players != share2.Players { - t.Fatalf("Players did not match, expected %d, found %d", share.Players, share2.Players) - } - - if share.Threshold != share2.Threshold { - t.Fatalf("Threshold did not match, expected %d, found %d", share.Threshold, share2.Threshold) - } - - if share.Index != share2.Index { - t.Fatalf("Index did not match, expected %d, found %d", share.Index, share2.Index) - } - - if (share.twoDeltaSi == nil || share2.twoDeltaSi == nil) && share.twoDeltaSi != share2.twoDeltaSi { - t.Fatalf("twoDeltaSi did not match, expected %v, found %v", share.twoDeltaSi, share2.twoDeltaSi) - } - - if !(share.twoDeltaSi == nil && share2.twoDeltaSi == nil) && share.twoDeltaSi.Cmp(share2.twoDeltaSi) != 0 { - t.Fatalf("twoDeltaSi did not match, expected %v, found %v", share.twoDeltaSi.Bytes(), share2.twoDeltaSi.Bytes()) - } - - if share.si.Cmp(share2.si) != 0 { - t.Fatalf("si did not match, expected %v, found %v", share.si.Bytes(), share2.si.Bytes()) - } -} - -func unmarshalKeyShareTest(t *testing.T, input []byte) { - share := KeyShare{} - err := share.UnmarshalBinary(input) - if err == nil { - t.Fatalf("unmarshall succeeded when it shouldn't have") - } -} - -func TestMarshallKeyShare(t *testing.T) { - marshalTestKeyShare(KeyShare{ - si: big.NewInt(10), - twoDeltaSi: big.NewInt(20), - Index: 30, - Threshold: 10, - Players: 2, - }, t) - - marshalTestKeyShare(KeyShare{ - si: big.NewInt(10), - twoDeltaSi: big.NewInt(20), - Index: 30, - Threshold: 0, - Players: 200, - }, t) - - marshalTestKeyShare(KeyShare{ - si: big.NewInt(0), - twoDeltaSi: big.NewInt(0), - Index: 0, - Threshold: 0, - Players: 0, - }, t) - - unmarshalKeyShareTest(t, []byte{}) - unmarshalKeyShareTest(t, []byte{1, 0, 1}) - unmarshalKeyShareTest(t, []byte{1, 0, 1}) - unmarshalKeyShareTest(t, []byte{0, 1, 0, 1, 0, 1}) - unmarshalKeyShareTest(t, []byte{0, 1, 0, 1, 0, 1, 0, 1}) - unmarshalKeyShareTest(t, []byte{0, 1, 0, 1, 0, 1, 0}) - unmarshalKeyShareTest(t, []byte{0, 1, 0, 1, 0, 1, 0, 2, 1}) - unmarshalKeyShareTest(t, []byte{0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0}) - unmarshalKeyShareTest(t, []byte{0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1}) -} - -func TestMarshallKeyShareFull(t *testing.T) { - const players = 3 - const threshold = 2 - const bits = 4096 - - key, err := rsa.GenerateKey(rand.Reader, bits) - if err != nil { - t.Fatal(err) - } - keys, err := Deal(rand.Reader, players, threshold, key) - if err != nil { - t.Fatal(err) - } - for _, share := range keys { - marshalTestKeyShare(share, t) - } -} diff --git a/tss/rsa/rsa_threshold.go b/tss/rsa/rsa_threshold.go index e18a02375..3a01a591b 100644 --- a/tss/rsa/rsa_threshold.go +++ b/tss/rsa/rsa_threshold.go @@ -80,7 +80,7 @@ func GenerateKey(random io.Reader, bits int) (*rsa.PrivateKey, error) { func validateParams(players, threshold uint) error { if players <= 1 { - return errors.New("rsa_threshold: Players (l) invalid: should be > 1") + return fmt.Errorf("rsa_threshold: Players (%v) invalid: should be > 1", players) } if threshold < 1 || threshold > players { return fmt.Errorf("rsa_threshold: Threshold (k) invalid: %d < 1 || %d > %d", threshold, threshold, players) @@ -157,18 +157,19 @@ func Deal(randSource io.Reader, players, threshold uint, key *rsa.PrivateKey) ([ return nil, err } } - + modulusLengthBits := uint(key.N.BitLen()) shares := make([]KeyShare, players) // 1 <= i <= l for i := uint(1); i <= players; i++ { + shares[i-1].ModulusLength = modulusLengthBits shares[i-1].Players = players shares[i-1].Threshold = threshold // Σ^{k-1}_{i=0} | a_i * X^i (mod m) si := computePolynomial(threshold, a, i, &m) shares[i-1].si = si shares[i-1].Index = i - shares[i-1].get2DeltaSi(int64(players)) + shares[i-1].get2DeltaSi(int64(players), &m) // If the modulus is composed by safe primes, verification keys are included. if hasSafePrimes { @@ -226,7 +227,7 @@ func PadHash(padder Padder, hash crypto.Hash, pub *rsa.PublicKey, msg []byte) ([ type Signature = []byte // CombineSignShares combines t SignShare's to produce a valid signature -func CombineSignShares(pub *rsa.PublicKey, shares []SignShare, msg []byte) (Signature, error) { +func CombineSignShares(pub *rsa.PublicKey, shares []*SignShare, msg []byte) (Signature, error) { players := shares[0].Players threshold := shares[0].Threshold @@ -313,7 +314,7 @@ func CombineSignShares(pub *rsa.PublicKey, shares []SignShare, msg []byte) (Sign // computes lagrange Interpolation for the shares // i must be an id 0..l but not in S // j must be in S -func computeLambda(delta *big.Int, S []SignShare, i, j int64) (*big.Int, error) { +func computeLambda(delta *big.Int, S []*SignShare, i, j int64) (*big.Int, error) { if i == j { return nil, errors.New("rsa_threshold: i and j can't be equal by precondition") } diff --git a/tss/rsa/rsa_threshold_test.go b/tss/rsa/rsa_threshold_test.go index 6d85c9954..d752a0ad5 100644 --- a/tss/rsa/rsa_threshold_test.go +++ b/tss/rsa/rsa_threshold_test.go @@ -81,8 +81,9 @@ func TestComputeLambda(t *testing.T) { // dem = (3 - 1) * (3 - 2) * (3 - 4) * (3 - 5) = 4 // num/dev = 40/4 = 10 // ∆ * 10 = 120 * 10 = 1200 - shares := make([]SignShare, 5) + shares := make([]*SignShare, 5) for i := uint(1); i <= 5; i++ { + shares[i-1] = new(SignShare) shares[i-1].Index = i } i := int64(0) @@ -192,7 +193,7 @@ func testIntegration(t *testing.T, algo crypto.Hash, priv *rsa.PrivateKey, thres t.Fatal(err) } - signshares := make([]SignShare, threshold) + signshares := make([]*SignShare, threshold) for i := uint(0); i < threshold; i++ { signshares[i], err = keys[i].Sign(rand.Reader, pub, msgPH, true) @@ -308,7 +309,7 @@ func benchmarkSignCombineHelper(randSource io.Reader, parallel bool, b *testing. b.Fatal(err) } - signshares := make([]SignShare, threshold) + signshares := make([]*SignShare, threshold) b.ResetTimer() for i := 0; i < b.N; i++ { for i := uint(0); i < threshold; i++ { diff --git a/tss/rsa/signShare.go b/tss/rsa/signShare.go index dfce65f29..b6542d6f7 100644 --- a/tss/rsa/signShare.go +++ b/tss/rsa/signShare.go @@ -2,23 +2,20 @@ package rsa import ( "crypto/rsa" - "encoding/binary" "errors" "fmt" - "math" "math/big" + "github.com/cloudflare/circl/internal/conv" "github.com/cloudflare/circl/zk/qndleq" + "golang.org/x/crypto/cryptobyte" ) // SignShare represents a portion of a signature. It is generated when a message is signed by a KeyShare. t SignShare's are then combined by calling CombineSignShares, where t is the Threshold. type SignShare struct { - xi *big.Int - - Index uint + share - Players uint - Threshold uint + xi *big.Int // It stores a DLEQ proof attesting that the signature // share was computed using the signer's key share. @@ -29,8 +26,7 @@ type SignShare struct { } func (s SignShare) String() string { - return fmt.Sprintf("(t,n): (%v,%v) index: %v xi: 0x%v", - s.Threshold, s.Players, s.Index, s.xi.Text(16)) + return fmt.Sprintf("%v xi: 0x%v proof: {%v}", s.share, s.xi.Text(16), s.proof) } // IsVerifiable returns true if the signature share contains @@ -58,94 +54,80 @@ func (s *SignShare) Verify(pub *rsa.PublicKey, vk *VerifyKeys, digest []byte) er xiSqr := new(big.Int).Mul(s.xi, s.xi) xiSqr.Mod(xiSqr, pub.N) - if !s.proof.Verify(vk.GroupKey, vk.VerifyKey, x4Delta, xiSqr, pub.N) { + const SecParam = 128 + if !s.proof.Verify(vk.GroupKey, vk.VerifyKey, x4Delta, xiSqr, pub.N, SecParam) { return ErrSignShareInvalid } return nil } -// MarshalBinary encodes SignShare into a byte array in a format readable by UnmarshalBinary. -// Note: Only Index's up to math.MaxUint16 are supported -func (s *SignShare) MarshalBinary() ([]byte, error) { - // | Players: uint16 | Threshold: uint16 | Index: uint16 | xiLen: uint16 | xi: []byte | - - if s.Players > math.MaxUint16 { - return nil, fmt.Errorf("rsa_threshold: signshare marshall: Players is too big to fit in a uint16") - } - - if s.Threshold > math.MaxUint16 { - return nil, fmt.Errorf("rsa_threshold: signshare marshall: Threshold is too big to fit in a uint16") - } - - if s.Index > math.MaxUint16 { - return nil, fmt.Errorf("rsa_threshold: signshare marshall: Index is too big to fit in a uint16") - } - - players := uint16(s.Players) - threshold := uint16(s.Threshold) - index := uint16(s.Index) +func (s *SignShare) Marshal(b *cryptobyte.Builder) error { + buf := make([]byte, (s.ModulusLength+7)/8) + b.AddValue(&s.share) + b.AddBytes(s.xi.FillBytes(buf)) - xiBytes := s.xi.Bytes() - xiLen := len(xiBytes) - - if xiLen > math.MaxInt16 { - return nil, fmt.Errorf("rsa_threshold: signshare marshall: xiBytes is too big to fit it's length in a uint16") + isVerifiable := s.IsVerifiable() + var flag uint8 + if isVerifiable { + flag = 0x01 } + b.AddUint8(flag) - if xiLen == 0 { - xiLen = 1 - xiBytes = []byte{0} + if isVerifiable { + b.AddValue(s.proof) } - blen := 2 + 2 + 2 + 2 + xiLen - out := make([]byte, blen) - - binary.BigEndian.PutUint16(out[0:2], players) - binary.BigEndian.PutUint16(out[2:4], threshold) - binary.BigEndian.PutUint16(out[4:6], index) - - binary.BigEndian.PutUint16(out[6:8], uint16(xiLen)) - - copy(out[8:8+xiLen], xiBytes) - - return out, nil + return nil } -// UnmarshalBinary converts a byte array outputted from Marshall into a SignShare or returns an error if the value is invalid -func (s *SignShare) UnmarshalBinary(data []byte) error { - // | Players: uint16 | Threshold: uint16 | Index: uint16 | xiLen: uint16 | xi: []byte | - if len(data) < 8 { - return fmt.Errorf("rsa_threshold: signshare unmarshalKeyShareTest failed: data length was too short for reading Players, Threshold, Index, and xiLen") +func (s *SignShare) ReadValue(r *cryptobyte.String) bool { + var sh share + ok := sh.ReadValue(r) + if !ok { + return false } - players := binary.BigEndian.Uint16(data[0:2]) - threshold := binary.BigEndian.Uint16(data[2:4]) - index := binary.BigEndian.Uint16(data[4:6]) - xiLen := binary.BigEndian.Uint16(data[6:8]) - - if xiLen == 0 { - return fmt.Errorf("rsa_threshold: signshare unmarshalKeyShareTest failed: xi is a required field but xiLen was 0") + mlen := int((sh.ModulusLength + 7) / 8) + var xiBytes []byte + ok = r.ReadBytes(&xiBytes, mlen) + if !ok { + return false } - if uint16(len(data[8:])) < xiLen { - return fmt.Errorf("rsa_threshold: signshare unmarshalKeyShareTest failed: data length was too short for reading xi, needed: %d found: %d", xiLen, len(data[6:])) + var isVerifiable uint8 + ok = r.ReadUint8(&isVerifiable) + if !ok { + return false } - xi := big.Int{} - bytes := make([]byte, xiLen) - copy(bytes, data[8:8+xiLen]) - xi.SetBytes(bytes) + var proof *qndleq.Proof + switch isVerifiable { + case 0: + proof = nil + case 1: + proof = new(qndleq.Proof) + ok = proof.ReadValue(r) + if !ok { + return false + } + + default: + return false + } - s.Players = uint(players) - s.Threshold = uint(threshold) - s.Index = uint(index) - s.xi = &xi + s.share = sh + s.xi = new(big.Int).SetBytes(xiBytes) + s.proof = proof - return nil + return true } +func (s *SignShare) MarshalBinary() ([]byte, error) { return conv.MarshalBinary(s) } +func (s *SignShare) UnmarshalBinary(b []byte) error { return conv.UnmarshalBinary(s, b) } + var ( + ErrKeyShareNonVerifiable = errors.New("key share has no verification keys") ErrSignShareNonVerifiable = errors.New("signature share is not verifiable") ErrSignShareInvalid = errors.New("signature share is invalid") ) diff --git a/tss/rsa/signShare_test.go b/tss/rsa/signShare_test.go index 031e4a135..3cda0eaee 100644 --- a/tss/rsa/signShare_test.go +++ b/tss/rsa/signShare_test.go @@ -46,17 +46,23 @@ func unmarshalSignShareTest(t *testing.T, input []byte) { func TestMarshallSignShare(t *testing.T) { marshalTestSignShare(SignShare{ - xi: big.NewInt(10), - Index: 30, - Players: 16, - Threshold: 18, + xi: big.NewInt(10), + share: share{ + ModulusLength: 256, + Index: 30, + Players: 18, + Threshold: 16, + }, }, t) marshalTestSignShare(SignShare{ - xi: big.NewInt(0), - Index: 0, - Players: 0, - Threshold: 0, + xi: big.NewInt(0), + share: share{ + ModulusLength: 256, + Index: 1, + Players: 2, + Threshold: 1, + }, }, t) unmarshalSignShareTest(t, []byte{}) diff --git a/tss/rsa/util.go b/tss/rsa/util.go index 218032697..8366e8981 100644 --- a/tss/rsa/util.go +++ b/tss/rsa/util.go @@ -1,7 +1,10 @@ package rsa import ( + "fmt" "math/big" + + "golang.org/x/crypto/cryptobyte" ) func calculateDelta(l int64) *big.Int { @@ -10,3 +13,49 @@ func calculateDelta(l int64) *big.Int { delta.MulRange(1, l) return &delta } + +type share struct { + ModulusLength uint // Size of RSA modulus in bits. + Threshold uint // Minimum number of shares to produce a signature. + Players uint // Total number of signers. + Index uint // Non-zero identifier of the signer. +} + +func (s share) String() string { + return fmt.Sprintf("(t=%v,n=%v)-RSA-%v index: %v", s.Threshold, s.Players, s.ModulusLength, s.Index) +} + +func (s *share) Marshal(b *cryptobyte.Builder) error { + b.AddUint16(uint16(s.ModulusLength)) + b.AddUint16(uint16(s.Threshold)) + b.AddUint16(uint16(s.Players)) + b.AddUint16(uint16(s.Index)) + return nil +} + +func (s *share) ReadValue(r *cryptobyte.String) bool { + var ModulusLength, Index, Threshold, Players uint16 + ok := r.ReadUint16(&ModulusLength) && + r.ReadUint16(&Threshold) && + r.ReadUint16(&Players) && + r.ReadUint16(&Index) + if !ok { + return false + } + + err := validateParams(uint(Players), uint(Threshold)) + if err != nil { + panic(err) + } + + if Index == 0 { + panic("index cannot be zero") + } + + s.ModulusLength = uint(ModulusLength) + s.Threshold = uint(Threshold) + s.Players = uint(Players) + s.Index = uint(Index) + + return true +} diff --git a/zk/qndleq/qndleq.go b/zk/qndleq/qndleq.go index d1baf72df..d4dd04fea 100644 --- a/zk/qndleq/qndleq.go +++ b/zk/qndleq/qndleq.go @@ -24,15 +24,21 @@ package qndleq import ( "crypto/rand" + "fmt" "io" "math/big" + "github.com/cloudflare/circl/internal/conv" "github.com/cloudflare/circl/internal/sha3" + "golang.org/x/crypto/cryptobyte" ) type Proof struct { - Z, C *big.Int - SecParam uint + Z, C *big.Int +} + +func (p Proof) String() string { + return fmt.Sprintf("Z: 0x%v C: 0x%v", p.Z.Text(16), p.C.Text(16)) } // SampleQn returns an element of Qn (the subgroup of squares in (Z/nZ)*). @@ -84,11 +90,11 @@ func Prove(random io.Reader, x, g, gx, h, hx, N *big.Int, secParam uint) (*Proof z.Mul(c, x).Add(z, r) r.Xor(r, r) - return &Proof{Z: z, C: c, SecParam: secParam}, nil + return &Proof{Z: z, C: c}, nil } // Verify checks whether x = Log_g(g^x) = Log_h(h^x). -func (p Proof) Verify(g, gx, h, hx, N *big.Int) bool { +func (p Proof) Verify(g, gx, h, hx, N *big.Int, secParam uint) bool { gPNum := new(big.Int).Exp(g, p.Z, N) gPDen := new(big.Int).Exp(gx, p.C, N) ok := gPDen.ModInverse(gPDen, N) @@ -107,7 +113,7 @@ func (p Proof) Verify(g, gx, h, hx, N *big.Int) bool { hP := hPNum.Mul(hPNum, hPDen) hP.Mod(hP, N) - c := doChallenge(g, gx, h, hx, gP, hP, N, p.SecParam) + c := doChallenge(g, gx, h, hx, gP, hP, N, secParam) return p.C.Cmp(c) == 0 } @@ -142,3 +148,26 @@ func doChallenge(g, gx, h, hx, gP, hP, N *big.Int, secParam uint) *big.Int { return new(big.Int).SetBytes(cBytes) } + +func (p *Proof) Marshal(b *cryptobyte.Builder) error { + b.AddUint16LengthPrefixed(func(c *cryptobyte.Builder) { c.AddBytes(p.Z.Bytes()) }) + b.AddUint16LengthPrefixed(func(c *cryptobyte.Builder) { c.AddBytes(p.C.Bytes()) }) + return nil +} + +func (p *Proof) ReadValue(r *cryptobyte.String) bool { + var zStr, cStr cryptobyte.String + ok := r.ReadUint16LengthPrefixed(&zStr) && + r.ReadUint16LengthPrefixed(&cStr) + if !ok { + return false + } + + p.Z = new(big.Int).SetBytes([]byte(zStr)) + p.C = new(big.Int).SetBytes([]byte(cStr)) + + return true +} + +func (p *Proof) MarshalBinary() ([]byte, error) { return conv.MarshalBinary(p) } +func (p *Proof) UnmarshalBinary(b []byte) error { return conv.UnmarshalBinary(p, b) } diff --git a/zk/qndleq/qndleq_test.go b/zk/qndleq/qndleq_test.go index a8af777c2..e1c1bebc2 100644 --- a/zk/qndleq/qndleq_test.go +++ b/zk/qndleq/qndleq_test.go @@ -30,7 +30,8 @@ func TestProve(t *testing.T) { proof, err := qndleq.Prove(rand.Reader, x, g, gx, h, hx, N, SecParam) test.CheckNoErr(t, err, "failed to generate proof") - test.CheckOk(proof.Verify(g, gx, h, hx, N), "failed to verify", t) + test.CheckMarshal(t, proof, new(qndleq.Proof)) + test.CheckOk(proof.Verify(g, gx, h, hx, N, SecParam), "failed to verify", t) } } @@ -78,7 +79,7 @@ func Benchmark_qndleq(b *testing.B) { b.Run("Verify", func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = proof.Verify(g, gx, h, hx, N) + _ = proof.Verify(g, gx, h, hx, N, SecParam) } }) }