Skip to content

Commit

Permalink
Merge pull request #70 from GiacomoPope/tweak_tests
Browse files Browse the repository at this point in the history
Tweak tests and clean up
  • Loading branch information
GiacomoPope authored Jul 24, 2024
2 parents 99efb58 + 6821038 commit 2f7af63
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 60 deletions.
11 changes: 3 additions & 8 deletions src/kyber_py/kyber/kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@


class Kyber:
def __init__(self, parameter_set, seed=None):
def __init__(self, parameter_set):
"""
Initialise Kyber with specified lattice parameters.
:param dict params: the lattice parameters
:param bytes seed: the optional seed for a DRBG, must be unique and
unpredictable
"""
self.k = parameter_set["k"]
self.eta_1 = parameter_set["eta_1"]
Expand All @@ -22,13 +20,10 @@ def __init__(self, parameter_set, seed=None):
self.M = ModuleKyber()
self.R = self.M.ring

# Use system randomness by default
# Use system randomness by default, for deterministic randomness
# use the method `set_drbg_seed()`
self.random_bytes = os.urandom

# If a seed is supplied, use deterministic randomness
if seed is not None:
self.set_drbg_seed(seed)

def set_drbg_seed(self, seed):
"""
Change entropy source to a DRBG and seed it with provided value.
Expand Down
11 changes: 3 additions & 8 deletions src/kyber_py/ml_kem/ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@


class ML_KEM:
def __init__(self, params, seed=None):
def __init__(self, params):
"""
Initialise the ML-KEM with specified lattice parameters.
:param dict params: the lattice parameters
:param bytes seed: the optional seed for a DRBG, must be unique and
unpredictable
"""
# ml-kem params
self.k = params["k"]
Expand All @@ -27,13 +25,10 @@ def __init__(self, params, seed=None):
self.M = ModuleKyber()
self.R = self.M.ring

# Use system randomness by default
# Use system randomness by default, for deterministic randomness
# use the method `set_drbg_seed()`
self.random_bytes = os.urandom

# If a seed is supplied, use deterministic randomness
if seed is not None:
self.set_drbg_seed(seed)

def set_drbg_seed(self, seed):
"""
Change entropy source to a DRBG and seed it with provided value.
Expand Down
6 changes: 1 addition & 5 deletions src/kyber_py/modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ def decode_vector(self, input_bytes, k, d, is_ntt=False):

class MatrixKyber(Matrix):
def __init__(self, parent, matrix_data, transpose=False):
self.parent = parent
self._data = matrix_data
self._transpose = transpose
if not self._check_dimensions():
raise ValueError("Inconsistent row lengths in matrix")
super().__init__(parent, matrix_data, transpose=transpose)

def encode(self, d):
output = b""
Expand Down
37 changes: 18 additions & 19 deletions src/kyber_py/polynomials/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ def __init__(self):

root_of_unity = 17
self.ntt_zetas = [
pow(root_of_unity, self.br(i, 7), 3329) for i in range(128)
pow(root_of_unity, self._br(i, 7), 3329) for i in range(128)
]
self.ntt_f = pow(128, -1, 3329)

@staticmethod
def br(i, k):
def _br(i, k):
"""
bit reversal of an unsigned k-bit integer
"""
Expand Down Expand Up @@ -105,7 +105,7 @@ def __call__(self, coefficients, is_ntt=False):
return element(self, [coefficients])
if not isinstance(coefficients, list):
raise TypeError(
f"Polynomials should be constructed from a list of integers, of length at most d = {256}"
f"Polynomials should be constructed from a list of integers, of length at most n = {256}"
)
return element(self, coefficients)

Expand All @@ -122,15 +122,15 @@ def encode(self, d):
bit_string = "".join(format(c, f"0{d}b")[::-1] for c in self.coeffs)
return bitstring_to_bytes(bit_string)

def compress_ele(self, x, d):
def _compress_ele(self, x, d):
"""
Compute round((2^d / q) * x) % 2^d
"""
t = 1 << d
y = (t * x + 1664) // 3329 # 1664 = 3329 // 2
return y % t

def decompress_ele(self, x, d):
def _decompress_ele(self, x, d):
"""
Compute round((q / 2^d) * x)
"""
Expand All @@ -143,7 +143,7 @@ def compress(self, d):
Compress the polynomial by compressing each coefficient
NOTE: This is lossy compression
"""
self.coeffs = [self.compress_ele(c, d) for c in self.coeffs]
self.coeffs = [self._compress_ele(c, d) for c in self.coeffs]
return self

def decompress(self, d):
Expand All @@ -153,7 +153,7 @@ def decompress(self, d):
x' = decompress(compress(x)), which x' != x, but is
close in magnitude.
"""
self.coeffs = [self.decompress_ele(c, d) for c in self.coeffs]
self.coeffs = [self._decompress_ele(c, d) for c in self.coeffs]
return self

def to_ntt(self):
Expand Down Expand Up @@ -182,7 +182,7 @@ def to_ntt(self):
return self.parent(coeffs, is_ntt=True)

def from_ntt(self):
raise TypeError(f"Polynomial is of type: {type(self)}")
raise TypeError(f"Polynomial not in the NTT domain: {type(self) = }")


class PolynomialKyberNTT(PolynomialKyber):
Expand All @@ -191,7 +191,9 @@ def __init__(self, parent, coefficients):
self.coeffs = self._parse_coefficients(coefficients)

def to_ntt(self):
raise TypeError(f"Polynomial is of type: {type(self)}")
raise TypeError(
f"Polynomial is already in the NTT domain: {type(self) = }"
)

def from_ntt(self):
"""
Expand Down Expand Up @@ -222,26 +224,26 @@ def from_ntt(self):
return self.parent(coeffs, is_ntt=False)

@staticmethod
def ntt_base_multiplication(a0, a1, b0, b1, zeta):
def _ntt_base_multiplication(a0, a1, b0, b1, zeta):
"""
Base case for ntt multiplication
"""
r0 = (a0 * b0 + zeta * a1 * b1) % 3329
r1 = (a1 * b0 + a0 * b1) % 3329
return r0, r1

def ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
def _ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
new_coeffs = []
zetas = self.parent.ntt_zetas
for i in range(64):
r0, r1 = self.ntt_base_multiplication(
r0, r1 = self._ntt_base_multiplication(
f_coeffs[4 * i + 0],
f_coeffs[4 * i + 1],
g_coeffs[4 * i + 0],
g_coeffs[4 * i + 1],
zetas[64 + i],
)
r2, r3 = self.ntt_base_multiplication(
r2, r3 = self._ntt_base_multiplication(
f_coeffs[4 * i + 2],
f_coeffs[4 * i + 3],
g_coeffs[4 * i + 2],
Expand All @@ -251,15 +253,12 @@ def ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
new_coeffs += [r0, r1, r2, r3]
return new_coeffs

def ntt_multiplication(self, other):
def _ntt_multiplication(self, other):
"""
Number Theoretic Transform multiplication.
Only implemented (currently) for n = 256
"""
if not isinstance(other, type(self)):
raise ValueError

new_coeffs = self.ntt_coefficient_multiplication(
new_coeffs = self._ntt_coefficient_multiplication(
self.coeffs, other.coeffs
)
return new_coeffs
Expand All @@ -274,7 +273,7 @@ def __sub__(self, other):

def __mul__(self, other):
if isinstance(other, type(self)):
new_coeffs = self.ntt_multiplication(other)
new_coeffs = self._ntt_multiplication(other)
elif isinstance(other, int):
new_coeffs = [(c * other) % 3329 for c in self.coeffs]
else:
Expand Down
9 changes: 2 additions & 7 deletions tests/test_kyber.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import unittest
import os
from itertools import islice
import pytest
from kyber_py.kyber import Kyber512, Kyber768, Kyber1024
from kyber_py.drbg.aes256_ctr_drbg import AES256_CTR_DRBG
Expand Down Expand Up @@ -143,13 +142,9 @@ def test_generic_kyber_known_answer(Kyber, seed, data):

# Assert encapsulation matches
ss, ct = Kyber.encaps(pk)
assert ct == data["ct"]
assert ss == data["ss"]
assert ct == data["ct"]

# Assert decapsulation matches
_ss = Kyber.decaps(ct, sk)
assert ss == data["ss"]


if __name__ == "__main__":
unittest.main()
assert _ss == data["ss"]
4 changes: 0 additions & 4 deletions tests/test_ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,3 @@ def test_mlkem_known_answer(ML_KEM, seed, kat_vals):
# Assert decapsulation with faulty ciphertext
ss_n = ML_KEM.decaps(data["ct_n"], dk)
assert ss_n == data["ss_n"]


if __name__ == "__main__":
unittest.main()
20 changes: 20 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
from random import randint
from kyber_py.modules.modules import ModuleKyber


class TestModuleKyber(unittest.TestCase):
M = ModuleKyber()
R = M.ring

def test_decode_vector(self):
for _ in range(100):
k = randint(1, 5)
v = self.M.random_element(k, 1)
v_bytes = v.encode(12)
self.assertEqual(v, self.M.decode_vector(v_bytes, k, 12))

def test_recode_vector_wrong_length(self):
self.assertRaises(
ValueError, lambda: self.M.decode_vector(b"1", 2, 12)
)
4 changes: 0 additions & 4 deletions tests/test_module_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,3 @@ def test_print(self):
su = "[1 + 2*x, 3 + 4*x + 5*x^2 + 6*x^3]"
self.assertEqual(str(A), sA)
self.assertEqual(str(u), su)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 2f7af63

Please sign in to comment.