Skip to content

Commit

Permalink
Merge pull request #30 from GiacomoPope/restructure
Browse files Browse the repository at this point in the history
Restructure code into modules
  • Loading branch information
GiacomoPope authored Jul 20, 2024
2 parents 9ff7a04 + cdb0f33 commit ad4f12c
Show file tree
Hide file tree
Showing 27 changed files with 129 additions and 83 deletions.
Empty file added __init__.py
Empty file.
File renamed without changes.
73 changes: 73 additions & 0 deletions benchmarks/benchmark_ml_kem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from ml_kem import ML_KEM128, ML_KEM192, ML_KEM256
import cProfile
from time import time


def profile_ml_kem(ML_KEM):
(ek, dk) = ML_KEM.keygen()
(K, c) = ML_KEM.encaps(ek)

gvars = {}
lvars = {"ML_KEM": ML_KEM, "c": c, "ek": ek, "dk": dk}

cProfile.runctx(
"[ML_KEM.keygen() for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
)
cProfile.runctx(
"[ML_KEM.encaps(ek) for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
)
cProfile.runctx(
"[ML_KEM.decaps(c, dk) for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
)


def benchmark_ml_kem(ML_KEM, name, count):
keygen_times = []
enc_times = []
dec_times = []

for _ in range(count):
t0 = time()
ek, dk = ML_KEM.keygen()
keygen_times.append(time() - t0)

t1 = time()
_, c = ML_KEM.encaps(ek)
enc_times.append(time() - t1)

t2 = time()
_ = ML_KEM.decaps(c, dk)
dec_times.append(time() - t2)

avg_keygen = sum(keygen_times) / count
avg_enc = sum(enc_times) / count
avg_dec = sum(dec_times) / count
print(
f" {name:11} |"
f"{avg_keygen*1000:8.2f}ms {1/avg_keygen:11.2f}"
f"{avg_enc*1000:8.2f}ms {1/avg_enc:10.2f}"
f"{avg_dec*1000:8.2f}ms {1/avg_dec:8.2f}"
)


if __name__ == "__main__":
count = 1000
# common banner
print("-" * 80)
print(
" Params | keygen | keygen/s | encap | encap/s "
"| decap | decap/s"
)
print("-" * 80)
benchmark_ml_kem(ML_KEM128, "ML_KEM128", count)
benchmark_ml_kem(ML_KEM192, "ML_KEM192", count)
benchmark_ml_kem(ML_KEM256, "ML_KEM256", count)
8 changes: 2 additions & 6 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@ API
===

.. toctree::
aes256_ctr_drbg
benchmark_kyber
drbg
kyber
ml_kem
modules
modules_generic
polynomials
polynomials_generic
run_kyber
utils
utilities
4 changes: 2 additions & 2 deletions docs/source/aes256_ctr_drbg.rst → docs/source/drbg.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
aes256\_ctr\_drbg module
drbg module
========================

.. automodule:: aes256_ctr_drbg
.. automodule:: drbg
:members:
:undoc-members:
:show-inheritance:
7 changes: 0 additions & 7 deletions docs/source/modules_generic.rst

This file was deleted.

7 changes: 0 additions & 7 deletions docs/source/polynomials_generic.rst

This file was deleted.

7 changes: 0 additions & 7 deletions docs/source/run_kyber.rst

This file was deleted.

4 changes: 2 additions & 2 deletions docs/source/utils.rst → docs/source/utilities.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
utils module
utilities module
============

.. automodule:: utils
.. automodule:: utilities
:members:
:undoc-members:
:show-inheritance:
Empty file added drbg/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions aes256_ctr_drbg.py → drbg/aes256_ctr_drbg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from utils import xor_bytes
from utilities.utils import xor_bytes
from Crypto.Cipher import AES


Expand Down Expand Up @@ -88,7 +88,7 @@ def random_bytes(self, num_bytes, additional=None):
if len(additional) > self.seed_length:
raise ValueError(
f"The additional input must be of length at most: "
f"{self.seed_length}. Input has length {len(seed)}"
f"{self.seed_length}. Input has length {len(additional)}"
)
elif len(additional) < self.seed_length:
additional += bytes([0]) * (self.seed_length - len(additional))
Expand Down
7 changes: 3 additions & 4 deletions kyber.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from polynomials import PolynomialRingKyber
from modules import ModuleKyber
from modules.modules import ModuleKyber

try:
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG
except ImportError as e:
print(
"Error importing AES CTR DRBG. Have you tried installing requirements?"
Expand Down Expand Up @@ -47,8 +46,8 @@ def __init__(self, parameter_set):
self.du = parameter_set["du"]
self.dv = parameter_set["dv"]

self.R = PolynomialRingKyber()
self.M = ModuleKyber()
self.R = self.M.ring

self.drbg = None
self.random_bytes = os.urandom
Expand Down
7 changes: 3 additions & 4 deletions ml_kem.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from polynomials import PolynomialRingKyber
from modules import ModuleKyber
from modules.modules import ModuleKyber

try:
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG
except ImportError as e:
print(
"Error importing AES CTR DRBG. Have you tried installing requirements?"
Expand Down Expand Up @@ -32,8 +31,8 @@ def __init__(self, params, seed=None):
self.du = params["du"]
self.dv = params["dv"]

self.R = PolynomialRingKyber()
self.M = ModuleKyber()
self.R = self.M.ring

# NIST approved randomness
if seed is None:
Expand Down
Empty file added modules/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions modules.py → modules/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from polynomials import PolynomialRingKyber
from modules_generic import Module, Matrix
from polynomials.polynomials import PolynomialRingKyber
from modules.modules_generic import Module, Matrix


class ModuleKyber(Module):
Expand Down
File renamed without changes.
Empty file added polynomials/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions polynomials.py → polynomials/polynomials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from polynomials_generic import PolynomialRing, Polynomial
from utils import bytes_to_bits, bitstring_to_bytes
from polynomials.polynomials_generic import PolynomialRing, Polynomial
from utilities.utils import bytes_to_bits, bitstring_to_bytes


class PolynomialRingKyber(PolynomialRing):
Expand Down
File renamed without changes.
8 changes: 0 additions & 8 deletions run_kyber.py

This file was deleted.

28 changes: 0 additions & 28 deletions test_ml_kem.py

This file was deleted.

Empty file added tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion test_kyber.py → tests/test_kyber.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import os
from kyber import Kyber512, Kyber768, Kyber1024
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG


def parse_kat_data(data):
Expand Down
7 changes: 6 additions & 1 deletion test_kyber_kat.py → tests/test_kyber_kat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""
An alternative way of checking the Kyber KAT
Does nothing which isn't already checked in test_kyber.py
"""

from kyber import Kyber512, Kyber768, Kyber1024
from hashlib import sha256
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG


def generate_kat_hash(kyber):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_ml_kem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
from ml_kem import ML_KEM128, ML_KEM192, ML_KEM256


class TestML_KEM(unittest.TestCase):
"""
Test ML_KEM levels for internal
consistency by generating key pairs
and shared secrets.
"""

def generic_test_ML_KEM(self, ML_KEM, count):
for _ in range(count):
(ek, dk) = ML_KEM.keygen()
for _ in range(count):
(K, c) = ML_KEM.encaps(ek)
K_prime = ML_KEM.decaps(c, dk)
self.assertEqual(K, K_prime)

def test_ML_KEM128(self):
self.generic_test_ML_KEM(ML_KEM128, 5)

def test_ML_KEM192(self):
self.generic_test_ML_KEM(ML_KEM192, 5)

def test_ML_KEM256(self):
self.generic_test_ML_KEM(ML_KEM256, 5)


if __name__ == "__main__":
unittest.main()
Empty file added utilities/__init__.py
Empty file.
File renamed without changes.

0 comments on commit ad4f12c

Please sign in to comment.