From cd46a0613a7849a14ebeb7f1218f219078448dc9 Mon Sep 17 00:00:00 2001 From: Bas Westerbaan Date: Fri, 16 Aug 2024 15:15:21 +0200 Subject: [PATCH] xwing: align with draft 04 --- kem/xwing/scheme.go | 4 ++- kem/xwing/xwing.go | 61 ++++++++++++++++++++++------------------- kem/xwing/xwing_test.go | 18 ++++++------ 3 files changed, 45 insertions(+), 38 deletions(-) diff --git a/kem/xwing/scheme.go b/kem/xwing/scheme.go index ac1a3f980..a9a574a57 100644 --- a/kem/xwing/scheme.go +++ b/kem/xwing/scheme.go @@ -62,7 +62,9 @@ func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { return nil, kem.ErrPubKeySize } - pk.Unpack(buf) + if err := pk.Unpack(buf); err != nil { + return nil, err + } return &pk, nil } diff --git a/kem/xwing/xwing.go b/kem/xwing/xwing.go index d3683c938..367049fdf 100644 --- a/kem/xwing/xwing.go +++ b/kem/xwing/xwing.go @@ -2,7 +2,7 @@ // // https://datatracker.ietf.org/doc/draft-connolly-cfrg-xwing-kem // -// Currently implements what will likely be -01. +// Currently implements -04. package xwing import ( @@ -18,9 +18,10 @@ import ( // An X-Wing private key. type PrivateKey struct { - m mlkem768.PrivateKey - x x25519.Key - xpk x25519.Key + seed [32]byte + m mlkem768.PrivateKey + x x25519.Key + xpk x25519.Key } // An X-Wing public key. @@ -31,13 +32,13 @@ type PublicKey struct { const ( // Size of a seed of a keypair - SeedSize = 96 + SeedSize = 32 // Size of an X-Wing public key PublicKeySize = 1216 // Size of an X-Wing private key - PrivateKeySize = 2464 + PrivateKeySize = 32 // Size of the seed passed to EncapsulateTo EncapsulationSeedSize = 64 @@ -74,9 +75,7 @@ func (sk *PrivateKey) Pack(buf []byte) { if len(buf) != PrivateKeySize { panic(kem.ErrPrivKeySize) } - sk.m.Pack(buf[:mlkem768.PrivateKeySize]) - copy(buf[mlkem768.PrivateKeySize:mlkem768.PrivateKeySize+32], sk.x[:]) - copy(buf[mlkem768.PrivateKeySize+32:], sk.xpk[:]) + copy(buf, sk.seed[:]) } // Packs pk to buf. @@ -95,18 +94,29 @@ func (pk *PublicKey) Pack(buf []byte) { // // Panics if seed is not of length SeedSize. func DeriveKeyPair(seed []byte) (*PrivateKey, *PublicKey) { + var ( + sk PrivateKey + pk PublicKey + ) + + deriveKeyPair(seed, &sk, &pk) + + return &sk, &pk +} + +func deriveKeyPair(seed []byte, sk *PrivateKey, pk *PublicKey) { if len(seed) != SeedSize { panic(kem.ErrSeedSize) } - var ( - pk PublicKey - sk PrivateKey - seedm [mlkem768.KeySeedSize]byte - ) + var seedm [mlkem768.KeySeedSize]byte - copy(seedm[:], seed[:64]) - copy(sk.x[:], seed[64:]) + copy(sk.seed[:], seed) + + h := sha3.NewShake128() + _, _ = h.Write(seed) + _, _ = h.Read(seedm[:]) + _, _ = h.Read(sk.x[:]) pkm, skm := mlkem768.NewKeyFromSeed(seedm[:]) sk.m = *skm @@ -114,8 +124,6 @@ func DeriveKeyPair(seed []byte) (*PrivateKey, *PublicKey) { x25519.KeyGen(&pk.x, &sk.x) sk.xpk = pk.x - - return &sk, &pk } // DeriveKeyPairPacked derives a keypair like DeriveKeyPair, and @@ -171,7 +179,7 @@ func GenerateKeyPairPacked(rand io.Reader) ([]byte, []byte, error) { // X-Wing standard, which is the reverse of the Circl KEM API. // // Panics if pk is not of size PublicKeySize, or randomness could not -// be read from crypto/rand.Reader +// be read from crypto/rand.Reader. func Encapsulate(pk, seed []byte) (ss, ct []byte) { var pub PublicKey pub.Unpack(pk) @@ -276,24 +284,21 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { // Unpacks pk from buf. // // Panics if buf is not of size PublicKeySize. -func (pk *PublicKey) Unpack(buf []byte) { +// +// Returns ErrPubKey if pk fails the ML-KEM encapsulation key check. +func (pk *PublicKey) Unpack(buf []byte) error { if len(buf) != PublicKeySize { panic(kem.ErrPubKeySize) } copy(pk.x[:], buf[mlkem768.PublicKeySize:]) - pk.m.Unpack(buf[:mlkem768.PublicKeySize]) + return pk.m.Unpack(buf[:mlkem768.PublicKeySize]) } // Unpacks sk from buf. // // Panics if buf is not of size PrivateKeySize. func (sk *PrivateKey) Unpack(buf []byte) { - if len(buf) != PrivateKeySize { - panic(kem.ErrPrivKeySize) - } - - copy(sk.x[:], buf[mlkem768.PrivateKeySize:mlkem768.PrivateKeySize+32]) - copy(sk.xpk[:], buf[mlkem768.PrivateKeySize+32:]) - sk.m.Unpack(buf[:mlkem768.PrivateKeySize]) + var pk PublicKey + deriveKeyPair(buf, sk, &pk) } diff --git a/kem/xwing/xwing_test.go b/kem/xwing/xwing_test.go index 02508c7fe..c56ea901e 100644 --- a/kem/xwing/xwing_test.go +++ b/kem/xwing/xwing_test.go @@ -13,8 +13,8 @@ func writeHex(w io.Writer, prefix string, val interface{}) { indent := " " width := 74 hex := fmt.Sprintf("%x", val) - if len(prefix)+len(hex)+1 < width { - fmt.Fprintf(w, "%s %s\n", prefix, hex) + if len(prefix)+len(hex)+5 < width { + fmt.Fprintf(w, "%s %s\n", prefix, hex) return } fmt.Fprintf(w, "%s\n", prefix) @@ -38,19 +38,19 @@ func TestVectors(t *testing.T) { for i := 0; i < 3; i++ { var seed [SeedSize]byte _, _ = h.Read(seed[:]) - writeHex(w, "seed ", seed) + writeHex(w, "seed", seed) sk, pk := DeriveKeyPairPacked(seed[:]) - writeHex(w, "sk ", sk) - writeHex(w, "pk ", pk) + writeHex(w, "sk", sk) + writeHex(w, "pk", pk) var eseed [EncapsulationSeedSize]byte _, _ = h.Read(eseed[:]) - writeHex(w, "eseed ", eseed) + writeHex(w, "eseed", eseed) ss, ct := Encapsulate(pk, eseed[:]) - writeHex(w, "ct ", ct) - writeHex(w, "ss ", ss) + writeHex(w, "ct", ct) + writeHex(w, "ss", ss) ss2 := Decapsulate(ct, sk) if !bytes.Equal(ss, ss2) { @@ -66,7 +66,7 @@ func TestVectors(t *testing.T) { var cs [32]byte _, _ = h.Read(cs[:]) got := fmt.Sprintf("%x", cs) - want := "1b2fd3a79ad0a82d814dcdf5da62a3830bc5f48e392dfe01ac1c3f9bb37ff86e" + want := "0e414d1453095f77f7959da8ddba81559e9d62508c2f665a004467420d5d0c51" if got != want { t.Fatalf("%s ≠ %s", got, want) }