Skip to content

Commit

Permalink
refactor(kyberlib): 🎨 Eliminate division ops to prevent timing variab…
Browse files Browse the repository at this point in the history
…ility

Replace division with multiply and shift to avoid DIV instruction and timing issues.
  • Loading branch information
sebastienrousseau committed May 12, 2024
1 parent 2e68e3f commit 39652f6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 27 deletions.
41 changes: 31 additions & 10 deletions src/reference/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ pub(crate) fn poly_compress(r: &mut [u8], a: Poly) {
let mut k = 0usize;
let mut u: i16;

// Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
// = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
// = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
// = DIV((x << d) + q/2, q) & ((1<<d) - 1)
//
// We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
// For d in {10,11} we use 20,642,678/2^36, which computes division by x/q
// correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably.
// For d in {4,5} we use 315/2^20, which doesn't compute division by x/q
// correctly for all inputs, but it's close enough that the end result
// of the compression is correct. The advantage is that we do not need
// to use a 64-bit intermediate value.
match KYBER_POLY_COMPRESSED_BYTES {
128 => {
#[allow(clippy::needless_range_loop)]
Expand All @@ -41,9 +53,11 @@ pub(crate) fn poly_compress(r: &mut [u8], a: Poly) {
// map to positive standard representatives
u = a.coeffs[8 * i + j];
u += (u >> 15) & KYBER_Q as i16;
t[j] = (((((u as u16) << 4) + KYBER_Q as u16 / 2)
/ KYBER_Q as u16)
& 15) as u8;
let mut tmp: u32 =
(((u as u16) << 4) + KYBER_Q as u16 / 2) as u32;
tmp *= 315;
tmp >>= 20;
t[j] = ((tmp as u16) & 15) as u8;
}
r[k] = t[0] | (t[1] << 4);
r[k + 1] = t[2] | (t[3] << 4);
Expand All @@ -59,9 +73,11 @@ pub(crate) fn poly_compress(r: &mut [u8], a: Poly) {
// map to positive standard representatives
u = a.coeffs[8 * i + j];
u += (u >> 15) & KYBER_Q as i16;
t[j] = (((((u as u32) << 5) + KYBER_Q as u32 / 2)
/ KYBER_Q as u32)
& 31) as u8;
let mut tmp: u32 =
((u as u32) << 5) + KYBER_Q as u32 / 2;
tmp *= 315;
tmp >>= 20;
t[j] = ((tmp as u16) & 31) as u8;
}
r[k] = t[0] | (t[1] << 5);
r[k + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7);
Expand Down Expand Up @@ -324,14 +340,19 @@ pub(crate) fn poly_frommsg(r: &mut Poly, msg: &[u8]) {
/// Arguments: - [u8] msg: output message
/// - const poly *a: input polynomial
pub(crate) fn poly_tomsg(msg: &mut [u8], a: Poly) {
let mut t;
let mut t: u32;
#[allow(clippy::needless_range_loop)]
for i in 0..KYBER_N / 8 {
msg[i] = 0;
for j in 0..8 {
t = a.coeffs[8 * i + j];
t += (t >> 15) & KYBER_Q as i16;
t = (((t << 1) + KYBER_Q as i16 / 2) / KYBER_Q as i16) & 1;
t = a.coeffs[8 * i + j] as u32;

t <<= 1;
t = t.wrapping_add(1665);
t = t.wrapping_mul(80635);
t >>= 28;
t &= 1;

msg[i] |= (t << j) as u8;
}
}
Expand Down
36 changes: 19 additions & 17 deletions src/reference/polyvec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@ pub(crate) fn polyvec_compress(r: &mut [u8], a: Polyvec) {
let mut idx = 0usize;
for i in 0..KYBER_SECURITY_PARAMETER {
for j in 0..KYBER_N / 8 {
for k in 0..8 {
t[k] = a.vec[i].coeffs[8 * j + k] as u16;
t[k] = t[k].wrapping_add(
(((t[k] as i16) >> 15) & KYBER_Q as i16) as u16,
for (k, t_k) in t.iter_mut().enumerate() {
*t_k = a.vec[i].coeffs[8 * j + k] as u16;
*t_k = t_k.wrapping_add(
((((*t_k as i16) >> 15) & KYBER_Q as i16)
as u16),
);
t[k] = (((((t[k] as u32) << 11)
+ KYBER_Q as u32 / 2)
/ KYBER_Q as u32)
& 0x7ff) as u16;
let mut tmp: u64 =
((*t_k as u64) << 11) + (KYBER_Q as u64 / 2);
tmp *= 20642679;
tmp >>= 36;
*t_k = (tmp as u16) & 0x7ff;
}
r[idx] = (t[0]) as u8;
r[idx + 1] = ((t[0] >> 8) | (t[1] << 3)) as u8;
Expand All @@ -61,16 +63,16 @@ pub(crate) fn polyvec_compress(r: &mut [u8], a: Polyvec) {
let mut idx = 0usize;
for i in 0..KYBER_SECURITY_PARAMETER {
for j in 0..KYBER_N / 4 {
for (k, item) in t.iter_mut().enumerate() {
*item = a.vec[i].coeffs[4 * j + k] as u16;
*item = item.wrapping_add(
(((*item as i16) >> 15) & KYBER_Q as i16)
as u16,
for (k, t_k) in t.iter_mut().enumerate() {
*t_k = a.vec[i].coeffs[4 * j + k] as u16;
*t_k = t_k.wrapping_add(
(((*t_k as i16) >> 15) & KYBER_Q as i16) as u16,
);
*item = (((((*item as u32) << 10)
+ KYBER_Q as u32 / 2)
/ KYBER_Q as u32)
& 0x3ff) as u16;
let mut tmp: u64 =
((*t_k as u64) << 10) + (KYBER_Q as u64 / 2);
tmp *= 20642679;
tmp >>= 36;
*t_k = (tmp as u16) & 0x3ff;
}
r[idx] = (t[0]) as u8;
r[idx + 1] = ((t[0] >> 8) | (t[1] << 2)) as u8;
Expand Down

0 comments on commit 39652f6

Please sign in to comment.