From 296b1037f13376815d3a3f2266c919e5f18ae333 Mon Sep 17 00:00:00 2001 From: Bas Westerbaan Date: Fri, 16 Aug 2024 15:15:21 +0200 Subject: [PATCH] xwing: align with draft 04 --- kem/xwing/scheme.go | 4 ++- kem/xwing/xwing.go | 71 +++++++++++++++++++++++------------------ kem/xwing/xwing_test.go | 23 +++++++------ 3 files changed, 56 insertions(+), 42 deletions(-) diff --git a/kem/xwing/scheme.go b/kem/xwing/scheme.go index ac1a3f980..a9a574a57 100644 --- a/kem/xwing/scheme.go +++ b/kem/xwing/scheme.go @@ -62,7 +62,9 @@ func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { return nil, kem.ErrPubKeySize } - pk.Unpack(buf) + if err := pk.Unpack(buf); err != nil { + return nil, err + } return &pk, nil } diff --git a/kem/xwing/xwing.go b/kem/xwing/xwing.go index d3683c938..7c925c67a 100644 --- a/kem/xwing/xwing.go +++ b/kem/xwing/xwing.go @@ -2,7 +2,7 @@ // // https://datatracker.ietf.org/doc/draft-connolly-cfrg-xwing-kem // -// Currently implements what will likely be -01. +// Currently implements -04. package xwing import ( @@ -18,9 +18,10 @@ import ( // An X-Wing private key. type PrivateKey struct { - m mlkem768.PrivateKey - x x25519.Key - xpk x25519.Key + seed [32]byte + m mlkem768.PrivateKey + x x25519.Key + xpk x25519.Key } // An X-Wing public key. @@ -31,13 +32,13 @@ type PublicKey struct { const ( // Size of a seed of a keypair - SeedSize = 96 + SeedSize = 32 // Size of an X-Wing public key PublicKeySize = 1216 // Size of an X-Wing private key - PrivateKeySize = 2464 + PrivateKeySize = 32 // Size of the seed passed to EncapsulateTo EncapsulationSeedSize = 64 @@ -74,9 +75,7 @@ func (sk *PrivateKey) Pack(buf []byte) { if len(buf) != PrivateKeySize { panic(kem.ErrPrivKeySize) } - sk.m.Pack(buf[:mlkem768.PrivateKeySize]) - copy(buf[mlkem768.PrivateKeySize:mlkem768.PrivateKeySize+32], sk.x[:]) - copy(buf[mlkem768.PrivateKeySize+32:], sk.xpk[:]) + copy(buf, sk.seed[:]) } // Packs pk to buf. @@ -95,18 +94,29 @@ func (pk *PublicKey) Pack(buf []byte) { // // Panics if seed is not of length SeedSize. func DeriveKeyPair(seed []byte) (*PrivateKey, *PublicKey) { + var ( + sk PrivateKey + pk PublicKey + ) + + deriveKeyPair(seed, &sk, &pk) + + return &sk, &pk +} + +func deriveKeyPair(seed []byte, sk *PrivateKey, pk *PublicKey) { if len(seed) != SeedSize { panic(kem.ErrSeedSize) } - var ( - pk PublicKey - sk PrivateKey - seedm [mlkem768.KeySeedSize]byte - ) + var seedm [mlkem768.KeySeedSize]byte + + copy(sk.seed[:], seed) - copy(seedm[:], seed[:64]) - copy(sk.x[:], seed[64:]) + h := sha3.NewShake128() + _, _ = h.Write(seed) + _, _ = h.Read(seedm[:]) + _, _ = h.Read(sk.x[:]) pkm, skm := mlkem768.NewKeyFromSeed(seedm[:]) sk.m = *skm @@ -114,8 +124,6 @@ func DeriveKeyPair(seed []byte) (*PrivateKey, *PublicKey) { x25519.KeyGen(&pk.x, &sk.x) sk.xpk = pk.x - - return &sk, &pk } // DeriveKeyPairPacked derives a keypair like DeriveKeyPair, and @@ -170,15 +178,19 @@ func GenerateKeyPairPacked(rand io.Reader) ([]byte, []byte, error) { // Warning: note that the order of the returned ss and ct matches the // X-Wing standard, which is the reverse of the Circl KEM API. // +// Returns ErrPubKey if ML-KEM encapsulation key check fails. +// // Panics if pk is not of size PublicKeySize, or randomness could not -// be read from crypto/rand.Reader -func Encapsulate(pk, seed []byte) (ss, ct []byte) { +// be read from crypto/rand.Reader. +func Encapsulate(pk, seed []byte) (ss, ct []byte, err error) { var pub PublicKey - pub.Unpack(pk) + if err := pub.Unpack(pk); err != nil { + return nil, nil, err + } ct = make([]byte, CiphertextSize) ss = make([]byte, SharedKeySize) pub.EncapsulateTo(ct, ss, seed) - return ss, ct + return ss, ct, nil } // Decapsulate computes the shared key which is encapsulated in ct @@ -276,24 +288,21 @@ func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) { // Unpacks pk from buf. // // Panics if buf is not of size PublicKeySize. -func (pk *PublicKey) Unpack(buf []byte) { +// +// Returns ErrPubKey if pk fails the ML-KEM encapsulation key check. +func (pk *PublicKey) Unpack(buf []byte) error { if len(buf) != PublicKeySize { panic(kem.ErrPubKeySize) } copy(pk.x[:], buf[mlkem768.PublicKeySize:]) - pk.m.Unpack(buf[:mlkem768.PublicKeySize]) + return pk.m.Unpack(buf[:mlkem768.PublicKeySize]) } // Unpacks sk from buf. // // Panics if buf is not of size PrivateKeySize. func (sk *PrivateKey) Unpack(buf []byte) { - if len(buf) != PrivateKeySize { - panic(kem.ErrPrivKeySize) - } - - copy(sk.x[:], buf[mlkem768.PrivateKeySize:mlkem768.PrivateKeySize+32]) - copy(sk.xpk[:], buf[mlkem768.PrivateKeySize+32:]) - sk.m.Unpack(buf[:mlkem768.PrivateKeySize]) + var pk PublicKey + deriveKeyPair(buf, sk, &pk) } diff --git a/kem/xwing/xwing_test.go b/kem/xwing/xwing_test.go index 02508c7fe..7bf564978 100644 --- a/kem/xwing/xwing_test.go +++ b/kem/xwing/xwing_test.go @@ -13,8 +13,8 @@ func writeHex(w io.Writer, prefix string, val interface{}) { indent := " " width := 74 hex := fmt.Sprintf("%x", val) - if len(prefix)+len(hex)+1 < width { - fmt.Fprintf(w, "%s %s\n", prefix, hex) + if len(prefix)+len(hex)+5 < width { + fmt.Fprintf(w, "%s %s\n", prefix, hex) return } fmt.Fprintf(w, "%s\n", prefix) @@ -38,19 +38,22 @@ func TestVectors(t *testing.T) { for i := 0; i < 3; i++ { var seed [SeedSize]byte _, _ = h.Read(seed[:]) - writeHex(w, "seed ", seed) + writeHex(w, "seed", seed) sk, pk := DeriveKeyPairPacked(seed[:]) - writeHex(w, "sk ", sk) - writeHex(w, "pk ", pk) + writeHex(w, "sk", sk) + writeHex(w, "pk", pk) var eseed [EncapsulationSeedSize]byte _, _ = h.Read(eseed[:]) - writeHex(w, "eseed ", eseed) + writeHex(w, "eseed", eseed) - ss, ct := Encapsulate(pk, eseed[:]) - writeHex(w, "ct ", ct) - writeHex(w, "ss ", ss) + ss, ct, err := Encapsulate(pk, eseed[:]) + if err != nil { + t.Fatal(err) + } + writeHex(w, "ct", ct) + writeHex(w, "ss", ss) ss2 := Decapsulate(ct, sk) if !bytes.Equal(ss, ss2) { @@ -66,7 +69,7 @@ func TestVectors(t *testing.T) { var cs [32]byte _, _ = h.Read(cs[:]) got := fmt.Sprintf("%x", cs) - want := "1b2fd3a79ad0a82d814dcdf5da62a3830bc5f48e392dfe01ac1c3f9bb37ff86e" + want := "0e414d1453095f77f7959da8ddba81559e9d62508c2f665a004467420d5d0c51" if got != want { t.Fatalf("%s ≠ %s", got, want) }