Skip to content
This repository has been archived by the owner on Jan 16, 2024. It is now read-only.

Commit

Permalink
Merge pull request #21 from kudelskisecurity/bugfix/kyberslash2
Browse files Browse the repository at this point in the history
Fix for KyberSlash2
  • Loading branch information
tgkudelski authored Jan 16, 2024
2 parents 56534a7 + 5d875a1 commit 2a6ca2d
Showing 1 changed file with 50 additions and 14 deletions.
64 changes: 50 additions & 14 deletions crystals-kyber/poly.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,24 @@ func polyToMsg(p Poly) []byte {
return msg
}

//compress packs a polynomial into a byte array using d bits per coefficient
//compress packs a polynomial into a byte array using d bits per coefficient - fixed against https://kyberslash.cr.yp.to/faq.html (cases d=4,5 only for now)
func (p *Poly) compress(d int) []byte {
c := make([]byte, n*d/8)
switch d {

case 3:
var t [8]uint16
var d0 uint32 /* accumulation value for fixing KyberSlash2 */
id := 0
for i := 0; i < n/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<3)+uint32(q)/2)/
uint32(q)) & ((1 << 3) - 1)
/* t[j] = uint16(((uint32(p[8*i+j])<<3)+uint32(q)/2)/
uint32(q)) & ((1 << 3) - 1) */
d0 = uint32(p[8*i+j]) << 3
d0 += 1664
d0 *= 161271
d0 >>= 29
t[j] = uint16(d0 & 0x7)
}
c[id] = byte(t[0]) | byte(t[1]<<3) | byte(t[2]<<6)
c[id+1] = byte(t[2]>>2) | byte(t[3]<<1) | byte(t[4]<<4) | byte(t[5]<<7)
Expand All @@ -207,11 +213,17 @@ func (p *Poly) compress(d int) []byte {

case 4:
var t [8]uint16
var d0 uint32 /* accumulation value for fixing KyberSlash2 */
id := 0
for i := 0; i < n/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(q)/2)/
uint32(q)) & ((1 << 4) - 1)
/* t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(q)/2)/
uint32(q)) & ((1 << 4) - 1)*/
d0 = uint32(p[8*i+j]) << 4
d0 += 1665
d0 *= 80635
d0 >>= 28
t[j] = uint16(d0 & 0xf)
}
c[id] = byte(t[0]) | byte(t[1]<<4)
c[id+1] = byte(t[2]) | byte(t[3]<<4)
Expand All @@ -222,11 +234,17 @@ func (p *Poly) compress(d int) []byte {

case 5:
var t [8]uint16
var d0 uint32 /* accumulation value for fixing KyberSlash2 */
id := 0
for i := 0; i < n/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(q)/2)/
uint32(q)) & ((1 << 5) - 1)
/* t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(q)/2)/
uint32(q)) & ((1 << 5) - 1) */
d0 = uint32(p[8*i+j]) << 5
d0 += 1664
d0 *= 40318
d0 >>= 27
t[j] = uint16(d0 & 0x1f)
}
c[id] = byte(t[0]) | byte(t[1]<<5)
c[id+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
Expand All @@ -238,11 +256,17 @@ func (p *Poly) compress(d int) []byte {

case 6:
var t [4]uint16
var d0 uint32 /* accumulation value for fixing KyberSlash2 */
id := 0
for i := 0; i < n/4; i++ {
for j := 0; j < 4; j++ {
t[j] = uint16(((uint32(p[4*i+j])<<6)+uint32(q)/2)/
uint32(q)) & ((1 << 6) - 1)
for j := 0; j < 4; j++ {
/* t[j] = uint16(((uint32(p[4*i+j])<<6)+uint32(q)/2)/
uint32(q)) & ((1 << 6) - 1) */
d0 = uint32(p[4*i+j]) << 6
d0 += 1664
d0 *= 20159
d0 >>= 26
t[j] = uint16(d0 & 0x3f)
}
c[id] = byte(t[0]) | byte(t[1]<<6)
c[id+1] = byte(t[1]>>2) | byte(t[2]<<4)
Expand All @@ -252,11 +276,17 @@ func (p *Poly) compress(d int) []byte {

case 10:
var t [4]uint16
var d0 uint64 /* accumulation value for fixing KyberSlash2 */
id := 0
for i := 0; i < n/4; i++ {
for j := 0; j < 4; j++ {
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(q)/2)/
uint32(q)) & ((1 << 10) - 1)
/* t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(q)/2)/
uint32(q)) & ((1 << 10) - 1) */
d0 = uint64(p[4*i+j]) << 10
d0 += 1665
d0 *= 1290167
d0 >>= 32
t[j] = uint16(d0 & 0x3ff)
}
c[id] = byte(t[0])
c[id+1] = byte(t[0]>>8) | byte(t[1]<<2)
Expand All @@ -267,11 +297,17 @@ func (p *Poly) compress(d int) []byte {
}
case 11:
var t [8]uint16
var d0 uint64 /* accumulation value for fixing KyberSlash2 */
id := 0
for i := 0; i < n/8; i++ {
for j := 0; j < 8; j++ {
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(q)/2)/
uint32(q)) & ((1 << 11) - 1)
/* t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(q)/2)/
uint32(q)) & ((1 << 11) - 1) */
d0 = uint64(p[8*i+j]) << 11
d0 += 1664
d0 *= 645084
d0 >>= 31
t[j] = uint16(d0 & 0x7ff)
}
c[id] = byte(t[0])
c[id+1] = byte(t[0]>>8) | byte(t[1]<<3)
Expand Down

0 comments on commit 2a6ca2d

Please sign in to comment.