Skip to content

Commit

Permalink
utils: remove coerceSize and fix uses #101
Browse files Browse the repository at this point in the history
  • Loading branch information
marsella committed Aug 29, 2024
1 parent c40efc8 commit 4769f56
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 210 deletions.
33 changes: 14 additions & 19 deletions Common/ntt.cry
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,17 @@ roots = iterate ((*) (r * r)) 1
/**
* An O(n log n) number theortic transform for Dilithium.
*/
import Common::utils (coerceSize)

ntt : [nn]Fld -> [nn]Fld
ntt a = ntt_r`{lg2 nn} 0 a

ntt_r : {n} (fin n) => Integer -> [2 ^^ n]Fld -> [2 ^^ n]Fld
ntt_r depth a =
if `n == 0 then
a
else
coerceSize (butterfly depth even odd)
ntt_r depth a
| n == 0 => a
| n > 0 => butterfly depth even odd
where
(lft, rht) = shuffle (coerceSize a)
even = ntt_r`{max 1 n - 1} (depth + 1) lft
odd = ntt_r`{max 1 n - 1} (depth + 1) rht
(lft, rht) = shuffle a
even = ntt_r`{n - 1} (depth + 1) lft
odd = ntt_r`{n - 1} (depth + 1) rht

/**
* Group even indices in first half and odd indices in second half.
Expand All @@ -115,13 +111,14 @@ shuffle a =
/**
* Perform the butterfly operation.
*/
butterfly : {n} (fin n, n > 0) => Integer -> [n]Fld -> [n]Fld -> [2 * n]Fld
butterfly : {n} (fin n) => Integer -> [2^^n]Fld -> [2^^n]Fld -> [2 ^^ (n+1)]Fld
butterfly depth even odd =
lft # rht
where
j = 2 ^^ depth
lft = [ even @ i + roots @ (i * j) * odd @ i | i <- [0 .. <n] ]
rht = [ even @ i - roots @ (i * j) * odd @ i | i <- [0 .. <n] ]
lft = [ even @ i + roots @ (i * j) * odd @ i | i <- [0 .. <len] ]
rht = [ even @ i - roots @ (i * j) * odd @ i | i <- [0 .. <len] ]
type len = 2^^n

/* INVERSE NTT */

Expand All @@ -139,13 +136,11 @@ ivntt a =
map ((*) ivn) (ivntt_r`{lg2 nn} 0 a)

ivntt_r : {n} (fin n) => Integer -> [2 ^^ n]Fld -> [2 ^^ n]Fld
ivntt_r depth a =
if `n == 0 then
a
else
coerceSize (ivbutterfly depth even odd)
ivntt_r depth a
| n == 0 => a
| n > 0 => ivbutterfly depth even odd
where
(lft, rht) = shuffle (coerceSize a)
(lft, rht) = shuffle a
even = ivntt_r`{max 1 n - 1} (depth + 1) lft
odd = ivntt_r`{max 1 n - 1} (depth + 1) rht

Expand Down
3 changes: 0 additions & 3 deletions Common/utils.cry
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,3 @@ mp_mod_inv c = if c == 0 then error "Zero does not have a multiplicative inverse
*/
mp_mod_inv_correct : {a} (fin a, prime a, a >=2) => Z a -> Bit
property mp_mod_inv_correct x = x != 0 ==> x * mp_mod_inv x == 1

coerceSize : {m, n, a} [m]a -> [n]a
coerceSize xs = [ xs @ i | i <- [0 .. <n]]
42 changes: 9 additions & 33 deletions Primitive/Asymmetric/Cipher/ML-KEM/specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)]
// This section Copyright Amazon.com, Inc. or its affiliates.
//////////////////////////////////////////////////////////////

import Common::utils (coerceSize)

// Simple lookup table for Zeta value given K
zeta_expc : [128](Z q)
zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848,
Expand All @@ -175,16 +173,6 @@ zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848,
1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ]

// Fast recursive CT-NTT
//
// The "coerceSize" calls in this code are required to satisfy
// Cryptol's type constraint solver that this code really
// is type-correct by effectively changing a static type-check
// into a dynamic one.
//
// As the static type constraint prover improves, this
// might become unncessesary.
//
// See https://github.com/GaloisInc/cryptol/issues/1489 for more details.
ct_butterfly :
{m, hm}
(m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) =>
Expand All @@ -195,7 +183,7 @@ ct_butterfly v z = new_v
lower, upper : [2^^hm](Z q)
lower@x = v@x + z * v@(x + halflen)
upper@x = v@x - z * v@(x + halflen)
new_v = coerceSize (lower # upper)
new_v = lower # upper

fast_nttl :
{lv} // Length of v is a member of {256,128,64,32,16,8,4}
Expand All @@ -206,30 +194,19 @@ fast_nttl v k
| lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k)

// Recursive case. Butterfly what we have, then recurse on each half,
// concatenate the results and return. As above, we need coerceSize
// here (twice) to satisfy the type checker.
| lv > 2 => coerceSize ((fast_nttl`{lv-1} s0 (k * 2)) #
(fast_nttl`{lv-1} s1 (k * 2 + 1)))
// concatenate the results and return.
| lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) #
(fast_nttl`{lv-1} s1 (k * 2 + 1))
where
t = ct_butterfly`{lv,lv-1} v (zeta_expc@k)
// Split t into two halves s0 and s1
[s0, s1] = split (coerceSize t)
[s0, s1] = split t

// Top level entry point - start with lv=256, k=1
fast_ntt : Z_q_256 -> Z_q_256
fast_ntt v = fast_nttl v 1

// Fast recursive GS-Inverse-NTT
//
// The "coerceSize" calls in this code are required to satisfy
// Cryptol's type constraint solver that this code really
// is type-correct by effectively changing a static type-check
// into a dynamic one.
//
// As the static type constraint prover improves, this
// might become unncessesary.
//
// See https://github.com/GaloisInc/cryptol/issues/1489 for more details.
gs_butterfly :
{m, hm}
(m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) =>
Expand All @@ -240,7 +217,7 @@ gs_butterfly v z = new_v
lower, upper : [2^^hm](Z q)
lower@x = v@x + v@(x + halflen)
upper@x = z * (v@(x + halflen) - v@x)
new_v = coerceSize (lower # upper)
new_v = lower # upper

fast_invnttl :
{lv} // Length of v is a member of {256,128,64,32,16,8,4}
Expand All @@ -253,13 +230,12 @@ fast_invnttl v k

// Recursive case. Recurse on each half,
// concatenate the results, butterfly that, and return.
// As above, we need coerceSize here (twice) to satisfy the type checker.
| lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k)
where
// Split t into two halves s0 and s1
[s0, s1] = split (coerceSize v)
t = coerceSize ((fast_invnttl`{lv-1} s0 (k * 2 + 1)) #
(fast_invnttl`{lv-1} s1 (k * 2)))
[s0, s1] = split v
t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) #
(fast_invnttl`{lv-1} s1 (k * 2))

// Multiply all elements of v by the reciprocal of 128 (modulo q)
recip_128_modq = (recip 128) : (Z q)
Expand Down
Loading

0 comments on commit 4769f56

Please sign in to comment.