From 9eddd2eae84a8e08053101d0afd1bdc833720b5c Mon Sep 17 00:00:00 2001 From: "Markku-Juhani O. Saarinen" Date: Tue, 20 Aug 2024 19:08:11 +0000 Subject: [PATCH] add FIPS 205 stuff --- README.md | 44 +++- fips204.py | 96 ++++++- fips205.py | 638 ++++++++++++++++++++++++++++++++++++++++++++++ genvals_mldsa.py | 2 +- genvals_slhdsa.py | 104 ++++++++ test_slhdsa.py | 229 +++++++++++++++++ 6 files changed, 1108 insertions(+), 5 deletions(-) create mode 100644 fips205.py create mode 100644 genvals_slhdsa.py create mode 100644 test_slhdsa.py diff --git a/README.md b/README.md index a2968d8..5f67544 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,19 @@ 2024-07-01 Markku-Juhani O. Saarinen mjos@iki.fi -Updated 2024-08-18 for the release FIPS 203, FIPS 204 +Updated 2024-08-20 for the release FIPS 203, FIPS 204, FIPS 205. ``` py-acvp-pqc ├── fips203.py # Python implementation of ML-KEM ("Kyber") ├── fips204.py # Python implementation of ML-DSA ("Dilithium") +├── fips205.py # Python implementation of SLH-DSA ("SPHINCS+") ├── genvals_mlkem.py # Python wrapper for ML-KEM in NIST's C# Gen/Vals ├── genvals_mldsa.py # Python wrapper for ML-DSA in NIST's C# Gen/Vals +├── genvals_slhdsa.py # Python wrapper for SLH-DSA in NIST's C# Gen/Vals ├── test_mlkem.py # Parser/tester for ML-KEM ACVP test vectors ├── test_mldsa.py # Parser/tester for ML-DSA ACVP test vectors +├── test_slhdsa.py # Parser/tester for SLH-DSA ACVP test vectors ├── ACVP-Server # (Symlink to) NIST's ACVP-Server repo for Gen/Vals ├── json-copy # Local copy from ACVP-Server/gen-val/json-files/ ├── Makefile # Makefile for cleanups @@ -26,10 +29,11 @@ You won't need the NIST C# dependencies to run the local Python implementations * ML-KEM: [fips203.py](fips203.py) is a self-contained implementation of [FIPS 203 ML-KEM](https://doi.org/10.6028/NIST.FIPS.203) a.k.a. Kyber. * ML-DSA: [fips204.py](fips204.py) is a self-contained implementation of [FIPS 204 ML-DSA](https://doi.org/10.6028/NIST.FIPS.204) a.k.a. Dilithium. +* SLH-DSA: [fips205.py](fips205.py) is a self-contained implementation of [FIPS 205 SLH-DSA](https://doi.org/10.6028/NIST.FIPS.205) a.k.a. SPHINCS+. * Test vector json parsers: [test_mlkem.py](test_mlkem.py) and [test_mldsa.py](test_mldsa.py). * Test vectors: there's a local copy of relevant json test vectors from NIST in [json-copy](json-copy). These can be synced with [https://github.com/usnistgov/ACVP-Server/tree/master/gen-val/json-files](https://github.com/usnistgov/ACVP-Server/tree/master/gen-val/json-files). -The main functions have unit tests: +The main functions have unit tests. For ML-KEM: ``` $ python3 fips203.py @@ -40,6 +44,7 @@ ML-KEM (fips203.py) -- Total FAIL= 0 ``` _( This indicates success.)_ +Running the test for ML_DSA is similar: ``` $ python3 fips204.py ML-DSA KeyGen (fips204.py): PASS= 75 FAIL= 0 @@ -50,6 +55,24 @@ ML-DSA (fips204.py) -- Total FAIL= 0 _( If you're curious why 30 test vectors are "skipped," The non-deterministic signature code is indeed non-deterministic and makes an internal call to an RBG. Hence, we're not trying to match those answers. )_ +By default the output for SLH-DSA is a bit verbose, as it will take several minutes to run them all: + +``` +$ python3 fips205.py +SLH-DSA-SHA2-128s KeyGen/1 pass +(.. output truncated ..) +SLH-DSA-SHAKE-256f KeyGen/40 pass +SLH-DSA KeyGen (fips205.py): PASS= 40 FAIL= 0 +SLH-DSA-SHA2-192s SigGen/1 pass +(.. output truncated ..) +SLH-DSA-SHAKE-128f SigGen/88 pass +SLH-DSA SigGen (fips205.py): PASS= 88 FAIL= 0 SKIP= 0 +SLH-DSA-SHA2-192s SigVer/1 pass +(.. output truncated ..) +SLH-DSA-SHAKE-128f SigVer/45 pass +SLH-DSA SigVer (fips205.py): PASS= 45 FAIL= 0 +SLH-DSA (fips205.py) -- Total FAIL= 0 +``` # NIST Gen/Vals @@ -126,7 +149,7 @@ $ source .venv/bin/activate Note that you will have to "enter" the enviroment with `source .venv/bin/activate` to use pythonnet installed locally this way. -Anyway, we should now be able to execute our Kyber and Dilithium test programs: +Anyway, assuming that all of the DLLs are in the right places, we should be abole to run our Kyber, Dilithium, and SPHINCS+ tests: ``` (.venv) $ python3 genvals_mlkem.py ML-KEM KeyGen (NIST Gen/Vals): PASS= 75 FAIL= 0 @@ -139,6 +162,21 @@ ML-DSA KeyGen (NIST Gen/Vals): PASS= 75 FAIL= 0 ML-DSA SigGen (NIST Gen/Vals): PASS= 30 FAIL= 0 SKIP= 30 ML-DSA SigVer (NIST Gen/Vals): PASS= 45 FAIL= 0 ML-DSA (NIST Gen/Vals) -- Total FAIL= 0 + +(.venv) $ $ python3 genvals_slhdsa.py +SLH-DSA-SHA2-128s KeyGen/1 pass +(.. output truncated ..) +SLH-DSA-SHAKE-256f KeyGen/40 pass +SLH-DSA KeyGen (NIST Gen/Vals): PASS= 40 FAIL= 0 +SLH-DSA-SHA2-192s SigGen/1 pass +(.. output truncated ..) +SLH-DSA-SHAKE-128f SigGen/88 pass +SLH-DSA SigGen (NIST Gen/Vals): PASS= 88 FAIL= 0 SKIP= 0 +SLH-DSA-SHA2-192s SigVer/1 pass +(.. output truncated ..) +SLH-DSA-SHAKE-128f SigVer/45 pass +SLH-DSA SigVer (NIST Gen/Vals): PASS= 45 FAIL= 0 +SLH-DSA (NIST Gen/Vals) -- Total FAIL= 0 ``` This is a success! diff --git a/fips204.py b/fips204.py index 1040649..7c890ff 100644 --- a/fips204.py +++ b/fips204.py @@ -7,7 +7,7 @@ from test_mldsa import test_mldsa # hash functions -from Crypto.Hash import SHAKE128, SHAKE256, SHA3_256, SHA3_512 +from Crypto.Hash import SHAKE128, SHAKE256, SHA3_256, SHA3_512, SHA256, SHA512 ML_DSA_Q = 8380417 ML_DSA_N = 256 @@ -73,6 +73,100 @@ def __init__(self, param='ML-DSA-65'): def h(self, s, l): return SHAKE256.new(s).read(l) + # Algorithm 2, ML-DSA.Sign(sk, M, ctx) + # XXX: Not covered by test vectors. + + def sign(self, sk, m, ctx, rnd_in=None, param=None): + if param != None: + self.__init__(param) + + if rnd_in == None: + rnd = b'\x00'*32 + else: + rnd = rnd_in + + mp = ( self.integer_to_bytes(0, 1) + + self.integer_to_bytes(len(ctx), 1) + ctx + m ) + sig = self.sign_internal(sk, mp, rnd) + return sig + + # Algorithm 3, ML-DSA.Verify(pk, M, sigma, ctx) + # XXX: Not covered by test vectors. + + def verify(self, pk, m, sig, ctx, param=None): + if param != None: + self.__init__(param) + if len(ctx) > 255: + return False + mp = ( self.integer_to_bytes(0, 1) + + self.integer_to_bytes(len(ctx), 1) + ctx + m) + return self.verify_internal(pk, mp, sig) + + # Algorithm 4, HashML-DSA.Sign(sk, M, ctx, PH) + # XXX: Not covered by test vectors. + + def hash_ml_dsa_sign(self, sk, m, ctx, ph, rnd_in=None, param=None): + if param != None: + self.__init__(param) + if len(ctx) > 255: + return None + + if rnd_in == None: + rnd = b'\x00'*32 + else: + rnd = rnd_in + + if ph == 'SHA-256': + oid = bytes([ 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, + 0x04, 0x02, 0x01]) + phm = SHA256.new(m).digest() + elif ph == 'SHA-512': + oid = bytes([ 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, + 0x04, 0x02, 0x03]) + phm = SHA512.new(m).digest() + elif ph == 'SHAKE128': + oid = bytes([ 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, + 0x04, 0x02, 0x0B]) + phm = SHAKE128.new(m).read(256 // 8) + else: + return None + + mp = ( self.integer_to_bytes(1, 1) + + self.integer_to_bytes(len(ctx), 1) + + oid + phm ) + sig = self.sign_internal(sk, mp, rnd) + return sig + + # Algorithm 5, HashML-DSA.Verify(pk, M, sig, ctx, PH) + # Note 2024-08-20: Not covered by test vectors. + + def hash_ml_dsa_verify(self, pk, m, sig, ctx, ph, param=None): + if param != None: + self.__init__(param) + if len(ctx) > 255: + return None + + if ph == 'SHA-256': + oid = bytes([ 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, + 0x04, 0x02, 0x01]) + phm = SHA256.new(m).digest() + elif ph == 'SHA-512': + oid = bytes([ 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, + 0x04, 0x02, 0x03]) + phm = SHA512.new(m).digest() + elif ph == 'SHAKE128': + oid = bytes([ 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, + 0x04, 0x02, 0x0B]) + phm = SHAKE128.new(m).read(256 // 8) + else: + return False + + mp = ( self.integer_to_bytes(1, 1) + + self.integer_to_bytes(len(ctx), 1) + + oid + phm ) + return self.verify_internal(pk, mp, sig) + + # Algorithm 6, ML-DSA.KeyGen_internal(xi) def keygen_internal(self, xi, param=None): diff --git a/fips205.py b/fips205.py new file mode 100644 index 0000000..de93424 --- /dev/null +++ b/fips205.py @@ -0,0 +1,638 @@ +# fips205.py +# 2023-11-24 Markku-Juhani O. Saarinen < mjos@iki.fi>. See LICENSE +# === FIPS 205 implementation https://doi.org/10.6028/NIST.FIPS.205 +# SLH-DSA / Stateless Hash-Based Digital Signature Standard + +# test_slhdsa is only used by the unit test in the end +from test_slhdsa import test_slhdsa + +# hashes +from Crypto.Hash import SHAKE256, SHA256, SHA512 + +# A class for handling Addresses (Section 4.2.) + +class ADRS: + # type constants + WOTS_HASH = 0 + WOTS_PK = 1 + TREE = 2 + FORS_TREE = 3 + FORS_ROOTS = 4 + WOTS_PRF = 5 + FORS_PRF = 6 + + def __init__(self, a=32): + """Initialize.""" + self.a = bytearray(a) + + def copy(self): + """ Make a copy of self.""" + return ADRS(self.a) + + def set_layer_address(self, x): + """ Set layer address.""" + self.a[ 0: 4] = x.to_bytes(4, byteorder='big') + + def set_tree_address(self, x): + """ Set tree address.""" + self.a[ 4:16] = x.to_bytes(12, byteorder='big') + + def set_key_pair_address(self, x): + """ Set key pair Address.""" + self.a[20:24] = x.to_bytes(4, byteorder='big') + + def get_key_pair_address(self): + """ Get key pair Address.""" + return int.from_bytes(self.a[20:24], byteorder='big') + + def set_tree_height(self, x): + """ Set FORS tree height.""" + self.a[24:28] = x.to_bytes(4, byteorder='big') + + def set_chain_address(self, x): + """ Set WOTS+ chain address.""" + self.a[24:28] = x.to_bytes(4, byteorder='big') + + def set_tree_index(self, x): + """ Set FORS tree index.""" + self.a[28:32] = x.to_bytes(4, byteorder='big') + + def get_tree_index(self): + """ Get FORS tree index.""" + return int.from_bytes(self.a[28:32], byteorder='big') + + def set_hash_address(self, x): + """ Set WOTS+ hash address.""" + self.a[28:32] = x.to_bytes(4, byteorder='big') + + def set_type_and_clear(self, t): + """ The member function ADRS.setTypeAndClear(Y) for addresses sets + the type of the ADRS to Y and sets the fnal 12 bytes of the ADRS + to zero.""" + self.a[16:20] = t.to_bytes(4, byteorder='big') + for i in range(12): + self.a[20 + i] = 0 + + def adrs(self): + """ Return the ADRS as bytes.""" + return self.a + + def adrsc(self): + """ Compressed address ADRDc used with SHA-2.""" + return self.a[3:4] + self.a[8 : 16] + self.a[19:20] + self.a[20:32] + + +# SLH-DSA Implementation + +class SLH_DSA: + + # initialize + def __init__(self, hashname='SHAKE', n=16, h=66, + d=22, hp=3, a=6, k=33, lg_w=4, m=34): + self.hashname = hashname + self.n = n + self.h = h + self.d = d + self.hp = hp + self.a = a + self.k = k + self.lg_w = lg_w + self.m = m + + # instantiate hash functions + if hashname == 'SHAKE': + self.h_msg = self.shake_h_msg + self.prf = self.shake_prf + self.prf_msg = self.shake_prf_msg + self.h_f = self.shake_f + self.h_h = self.shake_f + self.h_t = self.shake_f + elif hashname == 'SHA2' and self.n == 16: + self.h_msg = self.sha256_h_msg + self.prf = self.sha256_prf + self.prf_msg = self.sha256_prf_msg + self.h_f = self.sha256_f + self.h_h = self.sha256_f + self.h_t = self.sha256_f + elif hashname == 'SHA2' and self.n > 16: + self.h_msg = self.sha512_h_msg + self.prf = self.sha256_prf + self.prf_msg = self.sha512_prf_msg + self.h_f = self.sha256_f + self.h_h = self.sha512_h + self.h_t = self.sha512_h + + # equations 5.1 - 5.4 + self.w = 2**self.lg_w + self.len1 = (8 * self.n + (self.lg_w - 1)) // self.lg_w + self.len2 = (self.len1 * + (self.w - 1)).bit_length() // self.lg_w + 1 + self.len = self.len1 + self.len2 + + # external parameter sizes + self.pk_sz = 2 * self.n + self.sk_sz = 4 * self.n + self.sig_sz = (1 + self.k*(1 + self.a) + self.h + + self.d * self.len) * self.n + + # 10.1. SLH-DSA Using SHAKE + def shake256(self, x, l): + """SHAKE256(x, l): Internal hook.""" + return SHAKE256.new(x).read(l) + + def shake_h_msg(self, r, pk_seed, pk_root, m): + return self.shake256(r + pk_seed + pk_root + m, self.m) + + def shake_prf(self, pk_seed, sk_seed, adrs): + return self.shake256(pk_seed + adrs.adrs() + sk_seed, self.n) + + def shake_prf_msg(self, sk_prf, opt_rand, m): + return self.shake256(sk_prf + opt_rand + m, self.n) + + def shake_f(self, pk_seed, adrs, m1): + return self.shake256(pk_seed + adrs.adrs() + m1, self.n) + + # Various constructions required for SHA-2 variants. + + def sha256(self, x, n=32): + """Tranc_n(SHA2-256(x)).""" + return SHA256.new(x).digest()[0:n] + + def sha512(self, x, n=64): + """Tranc_n(SHA2-512(x)).""" + return SHA512.new(x).digest()[0:n] + + def mgf(self, hash_f, hash_l, mgf_seed, mask_len): + """NIST SP 800-56B REV. 2 / The Mask Generation Function (MGF).""" + t = b'' + for c in range((mask_len + hash_l - 1) // hash_l): + t += hash_f(mgf_seed + c.to_bytes(4, byteorder='big')) + return t[0:mask_len] + + def mgf_sha256(self, mgf_seed, mask_len): + """MGF1-SHA1-256(mgfSeed, maskLen).""" + return self.mgf(self.sha256, 32, mgf_seed, mask_len) + + def mgf_sha512(self, mgf_seed, mask_len): + """MGF1-SHA1-512(mgfSeed, maskLen).""" + return self.mgf(self.sha512, 64, mgf_seed, mask_len) + + def hmac(self, hash_f, hash_l, hash_b, k, text): + """FIPS PUB 198-1 HMAC.""" + if len(k) > hash_b: + k = hash_f(k) + ipad = bytearray(hash_b) + ipad[0:len(k)] = k + opad = bytearray(ipad) + for i in range(hash_b): + ipad[i] ^= 0x36 + opad[i] ^= 0x5C + return hash_f(opad + hash_f(ipad + text)) + + def hmac_sha256(self, k, text, n=32): + """Trunc_n(HMAC-SHA-256(k, text)): Internal hook.""" + return self.hmac(self.sha256, 32, 64, k, text)[0:n] + + def hmac_sha512(self, k, text, n=64): + """Trunc_n(HMAC-SHA-256(k, text)): Internal hook.""" + return self.hmac(self.sha512, 64, 128, k, text)[0:n] + + # 10.2 SLH-DSA Using SHA2 for Security Category 1 + + def sha256_h_msg(self, r, pk_seed, pk_root, m): + return self.mgf_sha256( r + pk_seed + + self.sha256(r + pk_seed + pk_root + m), self.m) + + def sha256_prf(self, pk_seed, sk_seed, adrs): + return self.sha256(pk_seed + bytes(64 - self.n) + + adrs.adrsc() + sk_seed, self.n) + + def sha256_prf_msg(self, sk_prf, opt_rand, m): + return self.hmac_sha256(sk_prf, opt_rand + m, self.n) + + def sha256_f(self, pk_seed, adrs, m1): + return self.sha256(pk_seed + bytes(64 - self.n) + + adrs.adrsc() + m1, self.n) + + # 10.3 SLH-DSA Using SHA2 for Security Categories 3 and 5 + + def sha512_h_msg(self, r, pk_seed, pk_root, m): + return self.mgf_sha512( r + pk_seed + + self.sha512(r + pk_seed + pk_root + m), self.m) + + def sha512_prf_msg(self, sk_prf, opt_rand, m): + return self.hmac_sha512(sk_prf, opt_rand + m, self.n) + + def sha512_h(self, pk_seed, adrs, m2): + return self.sha512(pk_seed + bytes(128 - self.n) + + adrs.adrsc() + m2, self.n) + + # --- FIPS 205 Algorithms + + def to_int(self, s, n): + """ Algorithm 2: toInt(X, n). Convert a byte string to an integer.""" + t = 0 + for i in range(n): + t = (t << 8) + int(s[i]) + return t + + def to_byte(self, x, n): + """ Algorithm 3: toByte(x, n). Convert an integer to a byte string.""" + t = x + s = bytearray(n) + for i in range(n): + s[n - 1 - i] = t & 0xFF + t >>= 8 + return s + + def base_2b(self, s, b, out_len): + """ Algorithm 4: base_2b (X, b, out_len). + Compute the base 2**b representation of X.""" + i = 0 # in + c = 0 # bits + t = 0 # total + v = [] # baseb + m = (1 << b) - 1 # mask + for j in range(out_len): + while c < b: + t = (t << 8) + int(s[i]) + i += 1 + c += 8 + c -= b + v += [ (t >> c) & m ] + return v + + def chain(self, x, i, s, pk_seed, adrs): + """ Algorithm 5: chain(X, i, s, PK.seed, ADRS). + Chaining function used in WOTS+.""" + if i + s >= self.w: + return None + t = x + for j in range(i, i + s): + adrs.set_hash_address(j) + t = self.h_f(pk_seed, adrs, t) + return t + + def wots_pkgen(self, sk_seed, pk_seed, adrs): + """ Algorithm 6: wots_PKgen(SK.seed, PK.seed, ADRS). + Generate a WOTS+ public key.""" + sk_adrs = adrs.copy() + sk_adrs.set_type_and_clear(ADRS.WOTS_PRF) + sk_adrs.set_key_pair_address(adrs.get_key_pair_address()) + tmp = b'' + for i in range(self.len): + sk_adrs.set_chain_address(i) + sk = self.prf(pk_seed, sk_seed, sk_adrs) + adrs.set_chain_address(i) + tmp += self.chain(sk, 0, self.w - 1, pk_seed, adrs) + wotspk_adrs = adrs.copy() + wotspk_adrs.set_type_and_clear(ADRS.WOTS_PK) + wotspk_adrs.set_key_pair_address(adrs.get_key_pair_address()) + pk = self.h_t(pk_seed, wotspk_adrs, tmp) + return pk + + def wots_sign(self, m, sk_seed, pk_seed, adrs): + """ Algorithm 7: wots_sign(M, SK.seed, PK.seed, ADRS). + Generate a WOTS+ signature on an n-byte message.""" + csum = 0 + msg = self.base_2b(m, self.lg_w, self.len1) + for i in range(self.len1): + csum += self.w - 1 - msg[i] + csum <<= ((8 - ((self.len2 * self.lg_w) % 8)) % 8) + msg += self.base_2b(self.to_byte(csum, + (self.len2 * self.lg_w + 7) // 8), self.lg_w, self.len2) + sk_adrs = adrs.copy() + sk_adrs.set_type_and_clear(ADRS.WOTS_PRF) + sk_adrs.set_key_pair_address(adrs.get_key_pair_address()) + sig = b'' + for i in range(self.len): + sk_adrs.set_chain_address(i) + sk = self.prf(pk_seed, sk_seed, sk_adrs) + adrs.set_chain_address(i) + sig += self.chain(sk, 0, msg[i], pk_seed, adrs) + + return sig + + def wots_pk_from_sig(self, sig, m, pk_seed, adrs): + """ Algorithm 8: wots_PKFromSig(sig, M, PK.seed, ADRS). + Compute a WOTS+ public key from a message and its signature.""" + csum = 0 + msg = self.base_2b(m, self.lg_w, self.len1) + for i in range(self.len1): + csum += self.w - 1 - msg[i] + csum <<= ((8 - ((self.len2 * self.lg_w) % 8)) % 8) + msg += self.base_2b(self.to_byte(csum, + (self.len2 * self.lg_w + 7) // 8), self.lg_w, self.len2) + tmp = b'' + for i in range(self.len): + adrs.set_chain_address(i) + tmp += self.chain(sig[i*self.n:(i+1)*self.n], + msg[i], self.w - 1 - msg[i], + pk_seed, adrs) + wots_pk_adrs = adrs.copy() + wots_pk_adrs.set_type_and_clear(ADRS.WOTS_PK) + wots_pk_adrs.set_key_pair_address(adrs.get_key_pair_address()) + pk_sig = self.h_t(pk_seed, wots_pk_adrs, tmp) + return pk_sig + + def xmss_node(self, sk_seed, i, z, pk_seed, adrs): + """ Algorithm 9: xmss_node(SK.seed, i, z, PK.seed, ADRS). + Compute the root of a Merkle subtree of WOTS+ public keys.""" + if z > self.hp or i >= 2**(self.hp - z): + return None + if z == 0: + adrs.set_type_and_clear(ADRS.WOTS_HASH) + adrs.set_key_pair_address(i) + node = self.wots_pkgen(sk_seed, pk_seed, adrs) + else: + lnode = self.xmss_node(sk_seed, 2 * i, z - 1, pk_seed, adrs) + rnode = self.xmss_node(sk_seed, 2 * i + 1, z - 1, pk_seed, adrs) + adrs.set_type_and_clear(ADRS.TREE) + adrs.set_tree_height(z) + adrs.set_tree_index(i) + node = self.h_h(pk_seed, adrs, lnode + rnode) + return node + + def xmss_sign(self, m, sk_seed, idx, pk_seed, adrs): + """ Algorithm 10: xmss_sign(M, SK.seed, idx, PK.seed, ADRS). + Generate an XMSS signature.""" + auth = b'' + for j in range(self.hp): + k = (idx >> j) ^ 1 + auth += self.xmss_node(sk_seed, k, j, pk_seed, adrs) + adrs.set_type_and_clear(ADRS.WOTS_HASH) + adrs.set_key_pair_address(idx) + sig = self.wots_sign(m, sk_seed, pk_seed, adrs) + sig_xmss = sig + auth + return sig_xmss + + def xmss_pk_from_sig(self, idx, sig_xmss, m, pk_seed, adrs): + """ Algorithm 11: xmss_PKFromSig(idx, SIG_XMSS, M, PK.seed, ADRS). + Compute an XMSS public key from an XMSS signature.""" + adrs.set_type_and_clear(ADRS.WOTS_HASH) + adrs.set_key_pair_address(idx) + sig = sig_xmss[0:self.len*self.n] + auth = sig_xmss[self.len*self.n:] + node_0 = self.wots_pk_from_sig(sig, m, pk_seed, adrs) + + adrs.set_type_and_clear(ADRS.TREE) + adrs.set_tree_index(idx) + for k in range(self.hp): + adrs.set_tree_height(k + 1) + auth_k = auth[k*self.n:(k+1)*self.n] + if (idx >> k) & 1 == 0: + adrs.set_tree_index(adrs.get_tree_index() // 2) + node_1 = self.h_h(pk_seed, adrs, node_0 + auth_k) + else: + adrs.set_tree_index((adrs.get_tree_index() - 1) // 2) + node_1 = self.h_h(pk_seed, adrs, auth_k + node_0) + node_0 = node_1 + + return node_0 + + def ht_sign(self, m, sk_seed, pk_seed, i_tree, i_leaf): + """ Algorithm 12: ht_sign(M, SK.seed, PK.seed, idx_tree, idx_leaf). + Generate a hypertree signature.""" + adrs = ADRS() + adrs.set_tree_address(i_tree) + sig_tmp = self.xmss_sign(m, sk_seed, i_leaf, pk_seed, adrs) + sig_ht = sig_tmp + root = self.xmss_pk_from_sig(i_leaf, sig_tmp, m, pk_seed, adrs) + hp_m = ((1 << self.hp) - 1) + for j in range(1, self.d): + i_leaf = i_tree & hp_m + i_tree = i_tree >> self.hp + adrs.set_layer_address(j) + adrs.set_tree_address(i_tree) + sig_tmp = self.xmss_sign(root, sk_seed, i_leaf, pk_seed, adrs) + sig_ht += sig_tmp + if j < self.d - 1: + root = self.xmss_pk_from_sig(i_leaf, sig_tmp, root, + pk_seed, adrs) + return sig_ht + + def ht_verify(self, m, sig_ht, pk_seed, i_tree, i_leaf, pk_root): + """ Algorithm 13: ht_verify(M, SIG_HT, PK.seed, idx_tree, idx_leaf, + PK.root). Verify a hypertree signature.""" + adrs = ADRS() + adrs.set_tree_address(i_tree) + sig_tmp = sig_ht[0:(self.hp + self.len)*self.n] + node = self.xmss_pk_from_sig(i_leaf, sig_tmp, m, pk_seed, adrs) + + hp_m = ((1 << self.hp) - 1) + for j in range(1, self.d): + i_leaf = i_tree & hp_m + i_tree = i_tree >> self.hp + adrs.set_layer_address(j) + adrs.set_tree_address(i_tree) + sig_tmp = sig_ht[j*(self.hp + self.len)*self.n: + (j+1)*(self.hp + self.len)*self.n] + node = self.xmss_pk_from_sig(i_leaf, sig_tmp, node, + pk_seed, adrs) + return node == pk_root + + def fors_sk_gen(self, sk_seed, pk_seed, adrs, idx): + """ Algorithm 14: fors_SKgen(SK.seed, PK.seed, ADRS, idx). + Generate a FORS private-key value.""" + sk_adrs = adrs.copy() + sk_adrs.set_type_and_clear(ADRS.FORS_PRF) + sk_adrs.set_key_pair_address(adrs.get_key_pair_address()) + sk_adrs.set_tree_index(idx) + return self.prf(pk_seed, sk_seed, sk_adrs) + + def fors_node(self, sk_seed, i, z, pk_seed, adrs): + """ Algorithm 15: fors_node(SK.seed, i, z, PK.seed, ADRS). + Compute the root of a Merkle subtree of FORS public values.""" + + if z > self.a or i >= (self.k << (self.a - z)): + return None + if z == 0: + sk = self.fors_sk_gen(sk_seed, pk_seed, adrs, i) + adrs.set_tree_height(0) + adrs.set_tree_index(i) + node = self.h_f(pk_seed, adrs, sk) + else: + lnode = self.fors_node(sk_seed, 2 * i, z - 1, pk_seed, adrs) + rnode = self.fors_node(sk_seed, 2 * i + 1, z - 1, pk_seed, adrs) + adrs.set_tree_height(z) + adrs.set_tree_index(i) + node = self.h_h(pk_seed, adrs, lnode + rnode) + return node + + def fors_sign(self, md, sk_seed, pk_seed, adrs): + """ Algorithm 16: fors_sign(md, SK.seed, PK.seed, ADRS). + Generate a FORS signature.""" + sig_fors = b'' + indices = self.base_2b(md, self.a, self.k) + for i in range(self.k): + sig_fors += self.fors_sk_gen(sk_seed, pk_seed, adrs, + (i << self.a) + indices[i]) + for j in range(self.a): + s = (indices[i] >> j) ^ 1 + sig_fors += self.fors_node(sk_seed, + (i << (self.a - j)) + s, j, + pk_seed, adrs) + return sig_fors + + def fors_pk_from_sig(self, sig_fors, md, pk_seed, adrs): + """ Algorithm 17: fors_pkFromSig(SIG_FORS, md, PK.seed, ADRS). + Compute a FORS public key from a FORS signature.""" + def get_sk(sig_fors, i): + return sig_fors[i*(self.a+1)*self.n:(i*(self.a+1)+1)*self.n] + + def get_auth(sig_fors, i): + return sig_fors[(i*(self.a+1)+1)*self.n:(i+1)*(self.a+1)*self.n] + + indices = self.base_2b(md, self.a, self.k) + + root = b'' + for i in range(self.k): + sk = get_sk(sig_fors, i) + adrs.set_tree_height(0) + adrs.set_tree_index((i << self.a) + indices[i]) + node_0 = self.h_f(pk_seed, adrs, sk) + + auth = get_auth(sig_fors, i) + for j in range(self.a): + auth_j = auth[j*self.n:(j+1)*self.n] + adrs.set_tree_height(j + 1) + if (indices[i] >> j) & 1 == 0: + adrs.set_tree_index(adrs.get_tree_index() // 2) + node_1 = self.h_h(pk_seed, adrs, node_0 + auth_j) + else: + adrs.set_tree_index((adrs.get_tree_index() - 1) // 2) + node_1 = self.h_h(pk_seed, adrs, auth_j + node_0) + node_0 = node_1 + root += node_0 + + fors_pk_adrs = adrs.copy() + fors_pk_adrs.set_type_and_clear(ADRS.FORS_ROOTS) + fors_pk_adrs.set_key_pair_address(adrs.get_key_pair_address()) + pk = self.h_t(pk_seed, fors_pk_adrs, root) + return pk + + def slh_keygen_internal(self, sk_seed, sk_prf, pk_seed): + """ Algorithm 18: slh_keygen_internal().""" + + # The behavior is different if one performs three distinct + adrs = ADRS() + adrs.set_layer_address(self.d - 1) + pk_root = self.xmss_node(sk_seed, 0, self.hp, pk_seed, adrs) + sk = sk_seed + sk_prf + pk_seed + pk_root + pk = pk_seed + pk_root + return (pk, sk) # Alg 17 has (sk, pk) + + def split_digest(self, digest): + """ Helper: Lines 11-16 of Alg 18 / Lines 10-15 of Alg 19.""" + ka1 = (self.k * self.a + 7) // 8 + md = digest[0:ka1] + hd = self.h // self.d + hhd = self.h - hd + ka2 = ka1 + ((hhd + 7) // 8) + i_tree = self.to_int( digest[ka1:ka2], (hhd + 7) // 8) % (2 ** hhd) + ka3 = ka2 + ((hd + 7) // 8) + i_leaf = self.to_int( digest[ka2:ka3], (hd + 7) // 8) % (2 ** hd) + return (md, i_tree, i_leaf) + + def slh_sign_internal(self, m, sk, addrnd): + """ Algorithm 19: slh_sign_internal(M, SK). """ + adrs = ADRS() + sk_seed = sk[ 0: self.n] + sk_prf = sk[ self.n:2*self.n] + pk_seed = sk[2*self.n:3*self.n] + pk_root = sk[3*self.n:] + + if addrnd == None: + addrnd = pk_seed + + r = self.prf_msg(sk_prf, addrnd, m) + sig = r + + digest = self.h_msg(r, pk_seed, pk_root, m) + (md, i_tree, i_leaf) = self.split_digest(digest) + + adrs.set_tree_address(i_tree) + adrs.set_type_and_clear(ADRS.FORS_TREE) + adrs.set_key_pair_address(i_leaf) + + sig_fors = self.fors_sign(md, sk_seed, pk_seed, adrs) + sig += sig_fors + + pk_fors = self.fors_pk_from_sig(sig_fors, md, pk_seed, adrs) + sig_ht = self.ht_sign(pk_fors, sk_seed, pk_seed, i_tree, i_leaf) + sig += sig_ht + + return sig + + def slh_verify_internal(self, m, sig, pk): + """ Algorithm 20: slh_verify_internal(M, SIG, PK).""" + if len(sig) != self.sig_sz or len(pk) != self.pk_sz: + return False + + pk_seed = pk[:self.n] + pk_root = pk[self.n:] + + adrs = ADRS() + r = sig[0:self.n] + sig_fors = sig[self.n:(1+self.k*(1+self.a))*self.n] + sig_ht = sig[(1 + self.k*(1 + self.a))*self.n:] + + digest = self.h_msg(r, pk_seed, pk_root, m) + (md, i_tree, i_leaf) = self.split_digest(digest) + + adrs.set_tree_address(i_tree) + adrs.set_type_and_clear(ADRS.FORS_TREE) + adrs.set_key_pair_address(i_leaf) + + pk_fors = self.fors_pk_from_sig(sig_fors, md, pk_seed, adrs) + return self.ht_verify(pk_fors, sig_ht, pk_seed, + i_tree, i_leaf, pk_root) + +# Section 11: Table 2. SLH-DSA parameter sets + +SLH_DSA_PARAMS = { + 'SLH-DSA-SHA2-128s': SLH_DSA(hashname='SHA2', + n=16, h=63, d=7, hp=9, a=12, k=14, lg_w=4, m=30), + 'SLH-DSA-SHAKE-128s': SLH_DSA(hashname='SHAKE', + n=16, h=63, d=7, hp=9, a=12, k=14, lg_w=4, m=30), + 'SLH-DSA-SHA2-128f': SLH_DSA(hashname='SHA2', + n=16, h=66, d=22, hp=3, a=6, k=33, lg_w=4, m=34), + 'SLH-DSA-SHAKE-128f': SLH_DSA(hashname='SHAKE', + n=16, h=66, d=22, hp=3, a=6, k=33, lg_w=4, m=34), + 'SLH-DSA-SHA2-192s': SLH_DSA(hashname='SHA2', + n=24, h=63, d=7, hp=9, a=14, k=17, lg_w=4, m=39), + 'SLH-DSA-SHAKE-192s': SLH_DSA(hashname='SHAKE', + n=24, h=63, d=7, hp=9, a=14, k=17, lg_w=4, m=39), + 'SLH-DSA-SHA2-192f': SLH_DSA(hashname='SHA2', + n=24, h=66, d=22, hp=3, a=8, k=33, lg_w=4, m=42), + 'SLH-DSA-SHAKE-192f': SLH_DSA(hashname='SHAKE', + n=24, h=66, d=22, hp=3, a=8, k=33, lg_w=4, m=42), + 'SLH-DSA-SHA2-256s': SLH_DSA(hashname='SHA2', + n=32, h=64, d=8, hp=8, a=14, k=22, lg_w=4, m=47), + 'SLH-DSA-SHAKE-256s': SLH_DSA(hashname='SHAKE', + n=32, h=64, d=8, hp=8, a=14, k=22, lg_w=4, m=47), + 'SLH-DSA-SHA2-256f': SLH_DSA(hashname='SHA2', + n=32, h=68, d=17, hp=4, a=9, k=35, lg_w=4, m=49), + 'SLH-DSA-SHAKE-256f': SLH_DSA(hashname='SHAKE', + n=32, h=68, d=17, hp=4, a=9, k=35, lg_w=4, m=49) +} + +def param_keygen( sk_seed, sk_prf, pk_seed, param): + slh = SLH_DSA_PARAMS[param] + return slh.slh_keygen_internal(sk_seed, sk_prf, pk_seed) + +def param_sign( msg, sk, addrnd, param): + slh = SLH_DSA_PARAMS[param] + return slh.slh_sign_internal(msg, sk, addrnd) + +def param_verify( msg, sig, pk, param): + slh = SLH_DSA_PARAMS[param] + return slh.slh_verify_internal( msg, sig, pk ) + +# run the test on these functions +if __name__ == '__main__': + test_slhdsa(param_keygen, + param_sign, + param_verify, + '(fips205.py)') + diff --git a/genvals_mldsa.py b/genvals_mldsa.py index 8907250..2157cec 100644 --- a/genvals_mldsa.py +++ b/genvals_mldsa.py @@ -60,7 +60,7 @@ def nist_mldsa_keygen(seed, param='ML-DSA-65'): return (pk, sk) def nist_mldsa_sign(sk, m, rnd, param='ML-DSA-65'): - """ sig = ML-DSA.Sign(sk, M, det, param='ML-DSA-65'). """ + """ sig = ML-DSA.Sign(sk, M, rnd, param='ML-DSA-65'). """ dilithium = Dilithium( ml_dsa_ps[param], NativeFastSha.NativeShaFactory(), EntropyProvider(Random800_90())) diff --git a/genvals_slhdsa.py b/genvals_slhdsa.py new file mode 100644 index 0000000..9d9ac35 --- /dev/null +++ b/genvals_slhdsa.py @@ -0,0 +1,104 @@ +# genvals_slhdsa.py +# 2024-08-18 Markku-Juhani O. Saarinen +# === Python wrapper for SLH-DSA / SPHINCS+ in the NIST ACVTS Libraries + +# test_slhdsa is only used by the unit test in the end +from test_slhdsa import test_slhdsa + +# .NET Core +from pythonnet import load +load("coreclr") +import os,clr + +# you may have to adjust these paths (need to be absolute!) +abs_path = os.getcwd() + '/ACVP-Server/gen-val/src/crypto/' +clr.AddReference(abs_path + 'test/NIST.CVP.ACVTS.Libraries.Crypto.SLHDSA.Tests/bin/Debug/net6.0/NLog.dll') +clr.AddReference(abs_path + 'test/NIST.CVP.ACVTS.Libraries.Crypto.SLHDSA.Tests/bin/Debug/net6.0/NIST.CVP.ACVTS.Libraries.Common.dll') +clr.AddReference(abs_path + 'test/NIST.CVP.ACVTS.Libraries.Crypto.SLHDSA.Tests/bin/Debug/net6.0/NIST.CVP.ACVTS.Libraries.Crypto.dll') + +# imports for slh-dsa +from System.Collections import BitArray +from NIST.CVP.ACVTS.Libraries.Crypto.SHA import NativeFastSha +from NIST.CVP.ACVTS.Libraries.Crypto.SLHDSA import Slhdsa, Wots, Xmss, Hypertree, Fors +from NIST.CVP.ACVTS.Libraries.Crypto.Common.PQC.SLHDSA.Enums import SlhdsaParameterSet +from NIST.CVP.ACVTS.Libraries.Crypto.Common.PQC.SLHDSA.Helpers import AttributesHelper +from NIST.CVP.ACVTS.Libraries.Crypto.Common.PQC.SLHDSA import PublicKey, PrivateKey + +# XXX supress debug output as the SLH-DSA code currently has +# Console.WriteLine() debug. + +import System +#System.Console.SetOut(System.IO.TextWriter.Null); + +# SLH-DSA parameter sets + +slh_dsa_ps = { + 'SLH-DSA-SHA2-128s' : SlhdsaParameterSet.SLH_DSA_SHA2_128s, + 'SLH-DSA-SHA2-128f' : SlhdsaParameterSet.SLH_DSA_SHA2_128f, + 'SLH-DSA-SHA2-192s' : SlhdsaParameterSet.SLH_DSA_SHA2_192s, + 'SLH-DSA-SHA2-192f' : SlhdsaParameterSet.SLH_DSA_SHA2_192f, + 'SLH-DSA-SHA2-256s' : SlhdsaParameterSet.SLH_DSA_SHA2_256s, + 'SLH-DSA-SHA2-256f' : SlhdsaParameterSet.SLH_DSA_SHA2_256f, + 'SLH-DSA-SHAKE-128s' : SlhdsaParameterSet.SLH_DSA_SHAKE_128s, + 'SLH-DSA-SHAKE-128f' : SlhdsaParameterSet.SLH_DSA_SHAKE_128f, + 'SLH-DSA-SHAKE-192s' : SlhdsaParameterSet.SLH_DSA_SHAKE_192s, + 'SLH-DSA-SHAKE-192f' : SlhdsaParameterSet.SLH_DSA_SHAKE_192f, + 'SLH-DSA-SHAKE-256s' : SlhdsaParameterSet.SLH_DSA_SHAKE_256s, + 'SLH-DSA-SHAKE-256f' : SlhdsaParameterSet.SLH_DSA_SHAKE_256f +} + +# helper functions + +def nist_slh_getinstance(): + t_shaf = NativeFastSha.NativeShaFactory() + t_wots = Wots(t_shaf) + t_xmss = Xmss(t_shaf, t_wots) + t_htree = Hypertree(t_xmss) + t_fors = Fors(t_shaf) + return Slhdsa(t_shaf, t_xmss, t_htree, t_fors) + +# test wrappers for NIST functions + +def nist_slh_keygen( sk_seed, sk_prf, pk_seed, param='SLH-DSA-SHA2-128s'): + slhdsa = nist_slh_getinstance() + t_attrb = AttributesHelper.GetParameterSetAttribute(slh_dsa_ps[param]) + t_keys = slhdsa.SlhKeyGen( sk_seed, sk_prf, pk_seed, t_attrb ); + + # it seems that caller has to construct the concatenated byte blobs + pk = ( bytes(t_keys.PublicKey.PkSeed) + + bytes(t_keys.PublicKey.PkRoot) ) + sk = ( bytes(t_keys.PrivateKey.SkSeed) + + bytes(t_keys.PrivateKey.SkPrf) + + bytes(t_keys.PrivateKey.PkSeed) + + bytes(t_keys.PrivateKey.PkRoot) ) + return (pk, sk) + +def nist_slh_sign( msg, sk, addrnd, param='SLH-DSA-SHA2-128s'): + slhdsa = nist_slh_getinstance() + t_attrb = AttributesHelper.GetParameterSetAttribute(slh_dsa_ps[param]) + n = t_attrb.N + t_sk = PrivateKey(sk[0:n], sk[n:2*n], sk[2*n:3*n], sk[3*n:4*n]) + + # "substitute opt_rand <- PK.seed for the deterministic variant" + if addrnd == None: + addrnd = sk[2*n:3*n] + + t_sig = slhdsa.SlhSignNonDeterministic(msg, t_sk, addrnd, t_attrb); + return bytes(t_sig) + + +def nist_slh_verify( msg, sig, pk, param='SLH-DSA-SHA2-128s'): + slhdsa = nist_slh_getinstance() + t_attrb = AttributesHelper.GetParameterSetAttribute(slh_dsa_ps[param]) + n = t_attrb.N + t_pk = PublicKey(pk[0:n], pk[n:2*n]) + t_res = slhdsa.SlhVerify(msg, sig, t_pk, t_attrb) + return t_res.Success + +# run the test on these functions +if __name__ == '__main__': + test_slhdsa(nist_slh_keygen, + nist_slh_sign, + nist_slh_verify, + '(NIST Gen/Vals)') + diff --git a/test_slhdsa.py b/test_slhdsa.py new file mode 100644 index 0000000..1c275ce --- /dev/null +++ b/test_slhdsa.py @@ -0,0 +1,229 @@ +# test_slhdsa.py +# 2024-08-19 Markku-Juhani O. Saarinen +# === SLH-DSA / Dilithium KAT test with json files + +import json + +# === read json prompts and responses === + +# Load key generation KATs + +def slhdsa_load_keygen(req_fn, res_fn): + with open(req_fn) as f: + keygen_req = json.load(f) + with open(res_fn) as f: + keygen_res = json.load(f) + + keygen_kat = [] + for qtg in keygen_req['testGroups']: + alg = qtg['parameterSet'] + tgid = qtg['tgId'] + + rtg = None + for tg in keygen_res['testGroups']: + if tg['tgId'] == tgid: + rtg = tg['tests'] + break + + for qt in qtg['tests']: + tcid = qt['tcId'] + for t in rtg: + if t['tcId'] == tcid: + qt.update(t) + qt['parameterSet'] = alg + keygen_kat += [qt] + return keygen_kat + +# Perform key generation tests on keygen_func + +def slhdsa_test_keygen(keygen_kat, keygen_func, iut=''): + keygen_pass = 0 + keygen_fail = 0 + + for x in keygen_kat: + # run keygen + (pk, sk) = keygen_func( bytes.fromhex(x['skSeed']), + bytes.fromhex(x['skPrf']), + bytes.fromhex(x['pkSeed']), + x['parameterSet']) + # compare + tc = x['parameterSet'] + ' KeyGen/' + str(x['tcId']) + if pk == bytes.fromhex(x['pk']) and sk == bytes.fromhex(x['sk']): + keygen_pass += 1 + print(tc, 'pass') + else: + keygen_fail += 1 + print(tc, 'pk ref=', x['pk']) + print(tc, 'pk got=', pk.hex()) + print(tc, 'sk ref=', x['sk']) + print(tc, 'sk got=', sk.hex()) + + print(f'SLH-DSA KeyGen {iut}: PASS= {keygen_pass} FAIL= {keygen_fail}') + return keygen_fail + +# Load signature Generation KATs + +def slhdsa_load_siggen(req_fn, res_fn): + with open(req_fn) as f: + siggen_req = json.load(f) + with open(res_fn) as f: + siggen_res = json.load(f) + + siggen_kat = [] + for qtg in siggen_req['testGroups']: + alg = qtg['parameterSet'] + det = qtg['deterministic'] + tgid = qtg['tgId'] + + rtg = None + for tg in siggen_res['testGroups']: + if tg['tgId'] == tgid: + rtg = tg['tests'] + break + + for qt in qtg['tests']: + tcid = qt['tcId'] + for t in rtg: + if t['tcId'] == tcid: + qt.update(t) + qt['parameterSet'] = alg + qt['deterministic'] = det + siggen_kat += [qt] + return siggen_kat + +# Perform signature generation tests on siggen_func + +def slhdsa_test_siggen(siggen_kat, siggen_func, iut=''): + + siggen_pass = 0 + siggen_fail = 0 + siggen_skip = 0 + + for x in siggen_kat: + if 'additionalRandomness' in x: + addrnd = bytes.fromhex(x['additionalRandomness']) + else: + addrnd = None + # generate signature + sig = siggen_func( bytes.fromhex(x['message']), + bytes.fromhex(x['sk']), + addrnd, + x['parameterSet']) + + # compare + tc = x['parameterSet'] + ' SigGen/' + str(x['tcId']) + if sig == bytes.fromhex(x['signature']): + siggen_pass += 1 + print(tc, 'pass') + else: + siggen_fail += 1 + print(tc, 'fail') + print(tc, 'sig ref=', x['signature']) + print(tc, 'sig got=', sig.hex()) + + print( f'SLH-DSA SigGen {iut}:', + f'PASS= {siggen_pass} FAIL= {siggen_fail} SKIP= {siggen_skip}') + + return siggen_fail + +# Load signature verification KATs + +def slhdsa_load_sigver(req_fn, res_fn, int_fn): + + with open(req_fn) as f: + sigver_req = json.load(f) + with open(res_fn) as f: + sigver_res = json.load(f) + with open(int_fn) as f: + sigver_int = json.load(f) + + sigver_kat = [] + for qtg in sigver_req['testGroups']: + alg = qtg['parameterSet'] + tgid = qtg['tgId'] + + rtg = None + for tg in sigver_res['testGroups']: + if tg['tgId'] == tgid: + rtg = tg['tests'] + break + + itg = None + for tg in sigver_int['testGroups']: + if tg['tgId'] == tgid: + itg = tg['tests'] + break + + for qt in qtg['tests']: + pk = qt['pk'] + tcid = qt['tcId'] + for t in rtg: + if t['tcId'] == tcid: + qt.update(t) + # message, signature in this file overrides prompts + for t in itg: + if t['tcId'] == tcid: + qt.update(t) + qt['parameterSet'] = alg + qt['pk'] = pk + sigver_kat += [qt] + return sigver_kat + +# Perform signature verification tests on sigver_func + +def slhdsa_test_sigver(sigver_kat, sigver_func, iut=''): + + sigver_pass = 0 + sigver_fail = 0 + + for x in sigver_kat: + # verify signature + res = sigver_func( bytes.fromhex(x['message']), + bytes.fromhex(x['signature']), + bytes.fromhex(x['pk']), + x['parameterSet']) + + # compare result + tc = x['parameterSet'] + ' SigVer/' + str(x['tcId']) + if res == x['testPassed']: + sigver_pass += 1 + print(tc, 'pass') + else: + sigver_fail += 1 + print(tc, 'res ref=', x['testPassed']) + print(tc, 'res got=', res) + print(tc, x['reason']) + + print(f'SLH-DSA SigVer {iut}: PASS= {sigver_pass} FAIL= {sigver_fail}') + return sigver_fail + +# === run the tests === + +# load all KATs +#json_path = 'ACVP-Server/gen-val/json-files/' +json_path = 'json-copy/' + +keygen_kat = slhdsa_load_keygen( + json_path + 'SLH-DSA-keyGen-FIPS205/prompt.json', + json_path + 'SLH-DSA-keyGen-FIPS205/expectedResults.json') + +siggen_kat = slhdsa_load_siggen( + json_path + 'SLH-DSA-sigGen-FIPS205/prompt.json', + json_path + 'SLH-DSA-sigGen-FIPS205/expectedResults.json') + +sigver_kat = slhdsa_load_sigver( + json_path + 'SLH-DSA-sigVer-FIPS205/prompt.json', + json_path + 'SLH-DSA-sigVer-FIPS205/expectedResults.json', + json_path + 'SLH-DSA-sigVer-FIPS205/internalProjection.json') + +def test_slhdsa(keygen_func, siggen_func, sigver_func, iut=''): + fail = 0 + fail += slhdsa_test_keygen(keygen_kat, keygen_func, iut) + fail += slhdsa_test_siggen(siggen_kat, siggen_func, iut) + fail += slhdsa_test_sigver(sigver_kat, sigver_func, iut) + print(f'SLH-DSA {iut} -- Total FAIL= {fail}') + +#print(keygen_kat) + +if __name__ == '__main__': + print('no unit tests here: provide cryptographic functions to test.')