Skip to content

Commit

Permalink
added support for shamir key split secret size greater than 16 bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
miketery committed Jan 20, 2022
1 parent b1cac0e commit 76e194b
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
42 changes: 42 additions & 0 deletions lib/Crypto/Protocol/SecretSharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from Crypto.Util.number import long_to_bytes, bytes_to_long
from Crypto.Random import get_random_bytes as rng

SHAMIR_BLOCK_SIZE = 16

def _mult_gf2(f1, f2):
"""Multiply two polynomials in GF(2)"""
Expand Down Expand Up @@ -276,3 +277,44 @@ def combine(shares, ssss=False):
denominator *= x_j + x_m
result += y_j * numerator * denominator.inverse()
return result.encode()

@staticmethod
def split_large(k, n, secret, ssss=False):
"""
Wrapper for Shamir.split()
when len(key) > SHAMIR_BLOCK_SIZE (16)
"""
if isinstance(secret, bytes) is not True:
raise TypeError("Secret must be bytes")
if len(secret) % 16 != 0:
raise ValueError("Secret size must be in 16 byte increments")

blocks = len(secret) // SHAMIR_BLOCK_SIZE
shares = [b'' for _ in range(n)]
for i in range(blocks):
block_shares = Shamir.split(k, n,
secret[i*SHAMIR_BLOCK_SIZE:(i+1)*SHAMIR_BLOCK_SIZE], ssss)
for j in range(n):
shares[j] += block_shares[j][1]
return [(i+1,shares[i]) for i in range(n)]

@staticmethod
def combine_large(shares, ssss=False):
"""
Wrapper for Shamir.combine()
when len(key) > SHAMIR_BLOCK_SIZE (16)
"""
share_len = len(shares[0][1])
for share in shares:
if len(share[1]) % 16 != 0:
raise ValueError(f"Share #{share[0]} is not in 16 byte increments")
if len(share[1]) != share_len:
raise ValueError("Share sizes are inconsistant")
blocks = share_len // SHAMIR_BLOCK_SIZE
result = b''
for i in range(blocks):
block_shares = [
(int(idx), share[i*SHAMIR_BLOCK_SIZE:(i+1)*SHAMIR_BLOCK_SIZE])
for idx, share in shares]
result += Shamir.combine(block_shares, ssss)
return result
4 changes: 4 additions & 0 deletions lib/Crypto/Protocol/SecretSharing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ class Shamir(object):
def split(k: int, n: int, secret: bytes, ssss: Optional[bool]) -> List[Tuple[int, bytes]]: ...
@staticmethod
def combine(shares: List[Tuple[int, bytes]], ssss: Optional[bool]) -> bytes: ...
@staticmethod
def split_large(k: int, n: int, secret: bytes, ssss: Optional[bool]) -> List[Tuple[int, bytes]]: ...
@staticmethod
def combine_large(shares: List[Tuple[int, bytes]], ssss: Optional[bool]) -> bytes: ...

35 changes: 35 additions & 0 deletions lib/Crypto/SelfTest/Protocol/test_SecretSharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,15 @@ def test2(self):
"11-065d0082c80b1aea18f4abe0c49df72e",
"12-84a09430c1d20ea9f388f3123c3733a3"),
)
test_vectors_large = (
# combine key and shares to test larger than 16 byte key
(2, test_vectors[0][1] + test_vectors[1][1],
'1-'+test_vectors[0][2][2:] + test_vectors[1][2][2:],
'2-'+test_vectors[0][3][2:] + test_vectors[1][3][2:],
'3-'+test_vectors[0][4][2:] + test_vectors[1][4][2:],
'4-'+test_vectors[0][5][2:] + test_vectors[1][5][2:],
),
)

def get_share(p):
pos = p.find('-')
Expand All @@ -225,6 +234,17 @@ def get_share(p):
result = Shamir.combine(shares, True)
self.assertEqual(secret, result)

for tv in test_vectors_large:
k = tv[0]
secret = unhexlify(tv[1])
max_perms = 10
for perm, shares_idx in enumerate(permutations(range(2, len(tv)), k)):
if perm > max_perms:
break
shares = [ get_share(tv[x]) for x in shares_idx ]
result = Shamir.combine_large(shares, True)
self.assertEqual(secret, result)

def test3(self):
# Loopback split/recombine
secret = unhexlify(b("000102030405060708090a0b0c0d0e0f"))
Expand Down Expand Up @@ -253,6 +273,21 @@ def test5(self):
shares = Shamir.split(2, 3, secret)
self.assertRaises(ValueError, Shamir.combine, (shares[0], shares[0]))

def test6(self):
# Test key size greater than 16 bytes
secret = unhexlify(b("000102030405060708090a0b0c0d0e0f000102030405060708090a0b0c0d0e0f"))
shares = Shamir.split_large(2, 3, secret)

secret2 = Shamir.combine_large(shares[:2])
self.assertEqual(secret, secret2)

def test7(self):
secret = unhexlify(b("000102030405060708090a0b0c0d0e0f0001020304050607"))
self.assertRaises(ValueError, Shamir.split_large, 2, 3, secret)

def test8(self):
secret = 123456
self.assertRaises(TypeError, Shamir.split_large, 2, 3, secret)

def get_tests(config={}):
tests = []
Expand Down

0 comments on commit 76e194b

Please sign in to comment.