Skip to content

Commit

Permalink
simplify matrix generation
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 19, 2024
1 parent 9a1168e commit 3844f73
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,11 @@ def _generate_error_vector(self, sigma, eta, N, is_ntt=False):
Helper function which generates a element in the
module from the Centered Binomial Distribution.
"""
elements = []
for _ in range(self.k):
elements = [0 for _ in range(self.k)]
for i in range(self.k):
input_bytes = self._prf(sigma, bytes([N]), 64*eta)
poly = self.R.cbd(input_bytes, eta, is_ntt=is_ntt)
elements.append(poly)
N = N + 1
elements[i] = self.R.cbd(input_bytes, eta, is_ntt=is_ntt)
N += 1
v = self.M.vector(elements)
return v, N

Expand All @@ -144,18 +143,13 @@ def _generate_matrix_from_seed(self, rho, transpose=False, is_ntt=False):
When `transpose` is set to True, the matrix A is
built as the transpose.
"""
A = []
A_data = [[0 for _ in range(self.k)] for _ in range(self.k)]
for i in range(self.k):
row = []
for j in range(self.k):
if transpose:
input_bytes = self._xof(rho, bytes([i]), bytes([j]), 3*self.R.n)
else:
input_bytes = self._xof(rho, bytes([j]), bytes([i]), 3*self.R.n)
aij = self.R.parse(input_bytes, is_ntt=is_ntt)
row.append(aij)
A.append(row)
return self.M(A)
input_bytes = self._xof(rho, bytes([j]), bytes([i]), 3*self.R.n)
A_data[i][j] = self.R.parse(input_bytes, is_ntt=is_ntt)
A_hat = self.M(A_data, transpose=transpose)
return A_hat

def _cpapke_keygen(self):
"""
Expand All @@ -171,6 +165,7 @@ def _cpapke_keygen(self):
# Generate random value, hash and split
d = self.random_bytes(32)
rho, sigma = self._g(d)

# Set counter for PRF
N = 0

Expand Down Expand Up @@ -212,6 +207,7 @@ def _cpapke_enc(self, pk, m, coins):
N = 0
rho = pk[-32:]
t = self.M.decode_vector(pk, self.k, l=12, is_ntt=True)

# Encode message as polynomial
m_poly = self.R.decode(m, l=1).decompress(1)

Expand Down

0 comments on commit 3844f73

Please sign in to comment.