From c480bcd0adea77ed4d9a1cdec89f2f5151fd4107 Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Tue, 15 Oct 2024 14:07:54 -0400 Subject: [PATCH 1/8] mlkem: put NTT into a submodule #147 --- .../Cipher/ML_KEM/Specification.cry | 478 +++++++++--------- 1 file changed, 240 insertions(+), 238 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index bc5c0a9c..5a077c2d 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -455,232 +455,250 @@ BitRev7 : [8] -> [8] BitRev7 = reverse -////////////////////////////////////////////////////////////// -// This section specifies a naive O(N**2) NTT and Inverse NTT -// -// A "fast" O(N log N) version is below, followed by a -// proof of their equivalence -////////////////////////////////////////////////////////////// - /** - * Compute the NTT representation of the polynomial `f`. + * This section specifies the number-theoretic transform (NTT). * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. - */ -ParametricNTT : Z_q_256 -> (Z q) -> Z_q_256 -ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] - -/** - * Compute most of the polynomial that corresponds to the NTT representation - * `f`. - * (The last step 14 is in a separate function) - * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. - */ -ParametricNTTInv : Z_q_256 -> (Z q) -> Z_q_256 -ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 
127]] - -/** - * Number theoretic transform: converts elements in `R_q` to `T_q`. - * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. - */ -NaiveNTT : Z_q_256 -> Z_q_256 -NaiveNTT f = ParametricNTT f zeta - -/** - * Inverse of the number theoretic transform: converts elements in `T_q` to - * `R_q`. - * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. - */ -NaiveNTTInv : Z_q_256 -> Z_q_256 -NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)] - -////////////////////////////////////////////////////////////// -// This section specifies fast O(N log N) NTT and Inverse NTT -// -// A readable explanation of the derivation of this form of -// the NTT is in "A Complete Beginner Guide to the Number -// Theoretic Transform (NTT)" by Ardianto Satriawan, -// Rella Mareta, and Hanho Lee. Available from: -// https://eprint.iacr.org/2024/585 -// -// This section Copyright Amazon.com, Inc. or its affiliates. -////////////////////////////////////////////////////////////// - -// Simple lookup table for Zeta value given K -zeta_expc : [128](Z q) -zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848, - 1062, 1919, 193, 797, 2786, 3260, 569, 1746, - 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, - 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, - 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, - 650, 1977, 2513, 632, 2865, 33, 1320, 1915, - 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, - 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, - 17, 2761, 583, 2649, 1637, 723, 2288, 1100, - 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, - 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, - 939, 2308, 2437, 2388, 733, 2337, 268, 641, - 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, - 1063, 319, 2773, 757, 2099, 561, 2466, 2594, - 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, - 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ] - -// Fast recursive CT-NTT -ct_butterfly : - {m, hm} - (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => - [2^^m](Z 
q) -> (Z q) -> [2^^m](Z q) -ct_butterfly v z = new_v - where - halflen = 2^^`hm - lower, upper : [2^^hm](Z q) - lower@x = v@x + z * v@(x + halflen) - upper@x = v@x - z * v@(x + halflen) - new_v = lower # upper - -fast_nttl : - {lv} // Length of v is a member of {256,128,64,32,16,8,4} - (lv >= 2, lv <= 8) => - [2^^lv](Z q) -> [8] -> [2^^lv](Z q) -fast_nttl v k - // Base case. lv==2 so just compute the butterfly and return - | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) - - // Recursive case. Butterfly what we have, then recurse on each half, - // concatenate the results and return. - | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # - (fast_nttl`{lv-1} s1 (k * 2 + 1)) - where - t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) - // Split t into two halves s0 and s1 - [s0, s1] = split t - -// Top level entry point - start with lv=256, k=1 -fast_ntt : Z_q_256 -> Z_q_256 -fast_ntt v = fast_nttl v 1 - -// Fast recursive GS-Inverse-NTT -gs_butterfly : - {m, hm} - (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => - [2^^m](Z q) -> (Z q) -> [2^^m](Z q) -gs_butterfly v z = new_v - where - halflen = 2^^`hm - lower, upper : [2^^hm](Z q) - lower@x = v@x + v@(x + halflen) - upper@x = z * (v@(x + halflen) - v@x) - new_v = lower # upper - -fast_invnttl : - {lv} // Length of v is a member of {256,128,64,32,16,8,4} - (lv >= 2, lv <= 8) => - [2^^lv](Z q) -> [8] -> [2^^lv](Z q) - -fast_invnttl v k - // Base case. lv==2 so just compute the butterfly and return - | lv == 2 => gs_butterfly`{lv,lv-1} v (zeta_expc@k) - - // Recursive case. Recurse on each half, - // concatenate the results, butterfly that, and return. 
- | lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k) - where - // Split t into two halves s0 and s1 - [s0, s1] = split v - t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) # - (fast_invnttl`{lv-1} s1 (k * 2)) - -// Multiply all elements of v by the reciprocal of 128 (modulo q) -recip_128_modq = (recip 128) : (Z q) -mul_recip128 : Z_q_256 -> Z_q_256 -mul_recip128 v = [ v@x * recip_128_modq | x <- [0 .. Z_q_256 -fast_invntt v = mul_recip128 (fast_invnttl v 1) - -////////////////////////////////////////////////////////////// -// Properties and proofs of Naive and Fast NTT -////////////////////////////////////////////////////////////// - -/** - * This property demonstrates that NaiveNTT is self-inverting. - * ``` - * :prove NaiveNTT_Inverts - * ``` - */ -NaiveNTT_Inverts : Z_q_256 -> Bit -property NaiveNTT_Inverts f = NaiveNTTInv (NaiveNTT f) == f - -/** - * This property demonstrates that NaiveNTTInv is self-inverting. - * ``` - * :prove NaiveNTTInv_Inverts - * ``` - */ -NaiveNTTInv_Inverts : Z_q_256 -> Bit -property NaiveNTTInv_Inverts f = NaiveNTT (NaiveNTTInv f) == f - -/** - * This property demonstrates that `fast_ntt` is the inverse of `fast_invntt`. - * ``` - * :prove fast_ntt_inverts - * ``` - */ -fast_ntt_inverts : Z_q_256 -> Bit -property fast_ntt_inverts f = fast_invntt (fast_ntt f) == f - -/** - * This property demonstrates that `fast_invntt` is the inverse of `fast_ntt`. - * ``` - * :prove fast_invntt_inverts - * ``` - */ -fast_invntt_inverts : Z_q_256 -> Bit -property fast_invntt_inverts f = fast_ntt (fast_invntt f) == f - -/** - * This property demonstrates that `naive_ntt` is equivalent to `fast_ntt`. - * ``` - * :prove naive_fast_ntt_equiv - * ``` - */ -naive_fast_ntt_equiv : Z_q_256 -> Bit -property naive_fast_ntt_equiv f = NaiveNTT f == fast_ntt f - -/** - * This property demonstrates that `naive_invntt` is equivalent to `fast_invntt`. 
- * ``` - * :prove naive_fast_invntt_equiv - * ``` - */ -naive_fast_invntt_equiv : Z_q_256 -> Bit -property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f + * It includes the version from [FIPS-203] Section 4.3 as well + * as a faster O(N log N) version, and a proof of their equivalence. + */ +import submodule NTT +submodule NTT where + private + /** + * Compute the NTT representation of the polynomial `f`. + * + * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. + */ + ParametricNTT : Z_q_256 -> (Z q) -> Z_q_256 + ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] + where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] + f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] + + /** + * Compute most of the polynomial that corresponds to the NTT representation + * `f`. + * (The last step 14 is in a separate function) + * + * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. + */ + ParametricNTTInv : Z_q_256 -> (Z q) -> Z_q_256 + ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] + where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] + f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] + + /** + * Number theoretic transform: converts elements in `R_q` to `T_q`. + * + * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. + */ + NaiveNTT : Z_q_256 -> Z_q_256 + NaiveNTT f = ParametricNTT f zeta + + /** + * Inverse of the number theoretic transform: converts elements in `T_q` to + * `R_q`. + * + * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. 
+ */ + NaiveNTTInv : Z_q_256 -> Z_q_256 + NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)] + + ////////////////////////////////////////////////////////////// + // This section specifies fast O(N log N) NTT and Inverse NTT + // + // A readable explanation of the derivation of this form of + // the NTT is in "A Complete Beginner Guide to the Number + // Theoretic Transform (NTT)" by Ardianto Satriawan, + // Rella Mareta, and Hanho Lee. Available from: + // https://eprint.iacr.org/2024/585 + // + // This section Copyright Amazon.com, Inc. or its affiliates. + ////////////////////////////////////////////////////////////// + + // Simple lookup table for Zeta value given K + zeta_expc : [128](Z q) + zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848, + 1062, 1919, 193, 797, 2786, 3260, 569, 1746, + 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, + 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, + 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, + 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, + 17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, + 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, + 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, + 1063, 319, 2773, 757, 2099, 561, 2466, 2594, + 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ] + + // Fast recursive CT-NTT + ct_butterfly : + {m, hm} + (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => + [2^^m](Z q) -> (Z q) -> [2^^m](Z q) + ct_butterfly v z = new_v + where + halflen = 2^^`hm + lower, upper : [2^^hm](Z q) + lower@x = v@x + z * v@(x + halflen) + upper@x = v@x - z * v@(x + halflen) + new_v = lower # upper + + fast_nttl : + {lv} // Length of v is a member of {256,128,64,32,16,8,4} + (lv >= 2, lv <= 8) => + [2^^lv](Z q) -> [8] -> [2^^lv](Z q) + fast_nttl v k + // Base 
case. lv==2 so just compute the butterfly and return + | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) + + // Recursive case. Butterfly what we have, then recurse on each half, + // concatenate the results and return. + | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # + (fast_nttl`{lv-1} s1 (k * 2 + 1)) + where + t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) + // Split t into two halves s0 and s1 + [s0, s1] = split t + + // Top level entry point - start with lv=256, k=1 + fast_ntt : Z_q_256 -> Z_q_256 + fast_ntt v = fast_nttl v 1 + + // Fast recursive GS-Inverse-NTT + gs_butterfly : + {m, hm} + (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => + [2^^m](Z q) -> (Z q) -> [2^^m](Z q) + gs_butterfly v z = new_v + where + halflen = 2^^`hm + lower, upper : [2^^hm](Z q) + lower@x = v@x + v@(x + halflen) + upper@x = z * (v@(x + halflen) - v@x) + new_v = lower # upper + + fast_invnttl : + {lv} // Length of v is a member of {256,128,64,32,16,8,4} + (lv >= 2, lv <= 8) => + [2^^lv](Z q) -> [8] -> [2^^lv](Z q) + + fast_invnttl v k + // Base case. lv==2 so just compute the butterfly and return + | lv == 2 => gs_butterfly`{lv,lv-1} v (zeta_expc@k) + + // Recursive case. Recurse on each half, + // concatenate the results, butterfly that, and return. + | lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k) + where + // Split t into two halves s0 and s1 + [s0, s1] = split v + t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) # + (fast_invnttl`{lv-1} s1 (k * 2)) + + // Multiply all elements of v by the reciprocal of 128 (modulo q) + recip_128_modq = (recip 128) : (Z q) + mul_recip128 : Z_q_256 -> Z_q_256 + mul_recip128 v = [ v@x * recip_128_modq | x <- [0 .. Z_q_256 + fast_invntt v = mul_recip128 (fast_invnttl v 1) + + ////////////////////////////////////////////////////////////// + // Properties and proofs of Naive and Fast NTT + ////////////////////////////////////////////////////////////// + + /** + * This property demonstrates that NaiveNTT is self-inverting. 
+ * ``` + * :prove NaiveNTT_Inverts + * ``` + */ + NaiveNTT_Inverts : Z_q_256 -> Bit + property NaiveNTT_Inverts f = NaiveNTTInv (NaiveNTT f) == f + + /** + * This property demonstrates that NaiveNTTInv is self-inverting. + * ``` + * :prove NaiveNTTInv_Inverts + * ``` + */ + NaiveNTTInv_Inverts : Z_q_256 -> Bit + property NaiveNTTInv_Inverts f = NaiveNTT (NaiveNTTInv f) == f + + /** + * This property demonstrates that `fast_ntt` is the inverse of `fast_invntt`. + * ``` + * :prove fast_ntt_inverts + * ``` + */ + fast_ntt_inverts : Z_q_256 -> Bit + property fast_ntt_inverts f = fast_invntt (fast_ntt f) == f + + /** + * This property demonstrates that `fast_invntt` is the inverse of `fast_ntt`. + * ``` + * :prove fast_invntt_inverts + * ``` + */ + fast_invntt_inverts : Z_q_256 -> Bit + property fast_invntt_inverts f = fast_ntt (fast_invntt f) == f + + /** + * This property demonstrates that `naive_ntt` is equivalent to `fast_ntt`. + * ``` + * :prove naive_fast_ntt_equiv + * ``` + */ + naive_fast_ntt_equiv : Z_q_256 -> Bit + property naive_fast_ntt_equiv f = NaiveNTT f == fast_ntt f + + /** + * This property demonstrates that `naive_invntt` is equivalent to `fast_invntt`. 
+ * ``` + * :prove naive_fast_invntt_equiv + * ``` + */ + naive_fast_invntt_equiv : Z_q_256 -> Bit + property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f + + ////////////////////////////////////////////////////////////// + // NTT "dispatcher" + // + // Here, we can choose to call either the naive or fast NTT + ////////////////////////////////////////////////////////////// + + NTT' : Z_q_256 -> Z_q_256 + // fast + NTT' f = fast_ntt f + // slow + //NTT' f = NaiveNTT f + + NTTInv' : Z_q_256 -> Z_q_256 + // fast + NTTInv' f = fast_invntt f + // slow + //NTTInv' f = NaiveNTTInv f -////////////////////////////////////////////////////////////// -// NTT "dispatcher" -// -// Here, we can choose to call either the naive or fast NTT -////////////////////////////////////////////////////////////// - -NTT' : Z_q_256 -> Z_q_256 -// fast -NTT' f = fast_ntt f -// slow -//NTT' f = NaiveNTT f + /** + * The notation `NTT` is overloaded to mean both a single application of `NTT` + * to an element of `R_q` and also `k` applications of `NTT` to every element + * of a `k`-length vector. + * [FIPS-203] Section 2.4.6 Equation 2.9. + */ + NTT v = map NTT' v -NTTInv' : Z_q_256 -> Z_q_256 -// fast -NTTInv' f = fast_invntt f -// slow -//NTTInv' f = NaiveNTTInv f + /** + * The notation `NTTInv` is overloaded to mean both a single application of + * `NTTInv` to an element of `R_q` and also `k` applications of `NTTInv` to + * every element of a `k`-length vector. + * [FIPS-203] Section 2.4.6. + */ + NTTInv v = map NTTInv' v ////////////////////////////////////////////////////////////// // Polynomial multiplication in the NTT domain @@ -730,22 +748,6 @@ property TestMult = prod f f == fsq where dot : Z_q_256 -> Z_q_256 -> Z_q_256 dot f g = MultiplyNTTs f g -/** - * The notation `NTT` is overloaded to mean both a single application of `NTT` - * to an element of `R_q` and also `k` applications of `NTT` to every element - * of a `k`-length vector. 
- * [FIPS-203] Section 2.4.6 Equation 2.9. - */ -NTT v = map NTT' v - -/** - * The notation `NTTInv` is overloaded to mean both a single application of - * `NTTInv` to an element of `R_q` and also `k` applications of `NTTInv` to - * every element of a `k`-length vector. - * [FIPS-203] Section 2.4.6. - */ -NTTInv v = map NTTInv' v - /** * Overloaded `dot` function between two vectors is a standard dot-product * functionality with `T_q` multiplication as the base operation. From 2474ff578b0f7c1e2134c4cb613a62b4e837444e Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Tue, 15 Oct 2024 14:33:00 -0400 Subject: [PATCH 2/8] mlkem: add Rq / Tq types and use them #147 This doesn't replace all uses of `Z_q_256`, but it gets all the easy ones. --- .../Cipher/ML_KEM/Specification.cry | 83 ++++++++++++------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index 5a077c2d..a67a5c25 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -10,7 +10,7 @@ * * Sources: * [FIPS-203]: National Institute of Standards and Technology. Module-Lattice- - * Basead Key-Encapsulation Mechanism Standard. (Department of Commerce, + * Based Key-Encapsulation Mechanism Standard. (Department of Commerce, * Washington, D.C.), Federal Information Processing Standards Publication * (FIPS) NIST FIPS 203. August 2024. * @see https://doi.org/10.6028/NIST.FIPS.203 @@ -63,6 +63,33 @@ type Byte = [8] */ type Z_q_256 = [n](Z q) +/** + * An element in the ring `R_q`. + * + * An element in the ring is a polynomial of degree at most 255 (e.g. with 256 + * terms). The `i`th element in this array represents the coefficient of the + * degree-`i` term. + * + * [FIPS-203] Section 2.3 (definition of the ring). + * [FIPS-203] Section 2.4.4, Equation 2.5 (definition of the representation of + * elements in the ring). 
+ */ +type Rq = [n](Z q) + +/** + * An element in the ring `T_q`. + * + * An element in this ring (sometimes called the "NTT representation") is a + * tuple of 128 polynomials, each of degree at most one (e.g. with two terms). + * The `2i` and `2i+1`th terms in this array represent the degree-0 and + * degree-1 coefficients of the `i`th polynomial, respectively. + * + * [FIPS-203] Section 2.3 (definition of the `T_q`). + * [FIPS-203] Section 2.4.4 Equation 2.7 (definition of the representation of + * an element in `T_q`). + */ +type Tq = [n](Z q) + /** * Pseudorandom function (PRF). * [FIPS-203] Section 4.1, Equations 4.2 and 4.3. @@ -336,7 +363,7 @@ property CorrectnessEncodeDecode' f = Decode'`{12}(Encode'`{12} f) == f * * [FIPS-203] Section 4.2.2, Algorithm 7. */ -SampleNTT : [34]Byte -> Z_q_256 +SampleNTT : [34]Byte -> Tq SampleNTT B = a_hat' where // Steps 1-2, 5. // We (lazily) take an infinite stream from the XOF and remove only as @@ -401,7 +428,7 @@ SampleNTT B = a_hat' where * * [FIPS-203] Section 4.2.2, Algorithm 8. */ -SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Z_q_256 +SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Rq SamplePolyCBD B = f where // Step 1. b = BytesToBits B @@ -469,7 +496,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. */ - ParametricNTT : Z_q_256 -> (Z q) -> Z_q_256 + ParametricNTT : Rq -> (Z q) -> Tq ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] @@ -481,7 +508,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. */ - ParametricNTTInv : Z_q_256 -> (Z q) -> Z_q_256 + ParametricNTTInv : Tq -> (Z q) -> Rq ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 
127]] where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] @@ -491,7 +518,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. */ - NaiveNTT : Z_q_256 -> Z_q_256 + NaiveNTT : Rq -> Tq NaiveNTT f = ParametricNTT f zeta /** @@ -500,7 +527,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. */ - NaiveNTTInv : Z_q_256 -> Z_q_256 + NaiveNTTInv : Tq -> Rq NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)] ////////////////////////////////////////////////////////////// @@ -565,7 +592,7 @@ submodule NTT where [s0, s1] = split t // Top level entry point - start with lv=256, k=1 - fast_ntt : Z_q_256 -> Z_q_256 + fast_ntt : Rq -> Tq fast_ntt v = fast_nttl v 1 // Fast recursive GS-Inverse-NTT @@ -605,7 +632,7 @@ submodule NTT where mul_recip128 v = [ v@x * recip_128_modq | x <- [0 .. 
Z_q_256 + fast_invntt : Tq -> Rq fast_invntt v = mul_recip128 (fast_invnttl v 1) ////////////////////////////////////////////////////////////// @@ -618,7 +645,7 @@ submodule NTT where * :prove NaiveNTT_Inverts * ``` */ - NaiveNTT_Inverts : Z_q_256 -> Bit + NaiveNTT_Inverts : Rq -> Bit property NaiveNTT_Inverts f = NaiveNTTInv (NaiveNTT f) == f /** @@ -627,7 +654,7 @@ submodule NTT where * :prove NaiveNTTInv_Inverts * ``` */ - NaiveNTTInv_Inverts : Z_q_256 -> Bit + NaiveNTTInv_Inverts : Tq -> Bit property NaiveNTTInv_Inverts f = NaiveNTT (NaiveNTTInv f) == f /** @@ -636,7 +663,7 @@ submodule NTT where * :prove fast_ntt_inverts * ``` */ - fast_ntt_inverts : Z_q_256 -> Bit + fast_ntt_inverts : Rq -> Bit property fast_ntt_inverts f = fast_invntt (fast_ntt f) == f /** @@ -645,7 +672,7 @@ submodule NTT where * :prove fast_invntt_inverts * ``` */ - fast_invntt_inverts : Z_q_256 -> Bit + fast_invntt_inverts : Tq -> Bit property fast_invntt_inverts f = fast_ntt (fast_invntt f) == f /** @@ -654,7 +681,7 @@ submodule NTT where * :prove naive_fast_ntt_equiv * ``` */ - naive_fast_ntt_equiv : Z_q_256 -> Bit + naive_fast_ntt_equiv : Rq -> Bit property naive_fast_ntt_equiv f = NaiveNTT f == fast_ntt f /** @@ -663,7 +690,7 @@ submodule NTT where * :prove naive_fast_invntt_equiv * ``` */ - naive_fast_invntt_equiv : Z_q_256 -> Bit + naive_fast_invntt_equiv : Tq -> Bit property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f ////////////////////////////////////////////////////////////// @@ -672,17 +699,11 @@ submodule NTT where // Here, we can choose to call either the naive or fast NTT ////////////////////////////////////////////////////////////// - NTT' : Z_q_256 -> Z_q_256 - // fast - NTT' f = fast_ntt f - // slow - //NTT' f = NaiveNTT f + NTT' : Rq -> Tq + NTT' = fast_ntt - NTTInv' : Z_q_256 -> Z_q_256 - // fast - NTTInv' f = fast_invntt f - // slow - //NTTInv' f = NaiveNTTInv f + NTTInv' : Tq -> Rq + NTTInv' = fast_invntt /** * The notation `NTT` is overloaded to mean 
both a single application of `NTT` @@ -719,7 +740,7 @@ BaseCaseMultiply a b root = [c0, c1] * Compute the product (in the ring `T_q`) of two NTT representations. * [FIPS-203] Section 4.3.1 Algorithm 11. */ -MultiplyNTTs : Z_q_256 -> Z_q_256 -> Z_q_256 +MultiplyNTTs : Tq -> Tq -> Tq MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : Byte <- [0 .. 127]] where f_hat_i i = [a@(2*i),a@(2*i+1)] @@ -737,7 +758,7 @@ property TestMult = prod f f == fsq where f = [1, 1] # [0 | i <- [3 .. 256]] fsq = [1,2,1] # [0 | i <- [4 .. 256]] - prod : Z_q_256 -> Z_q_256 -> Z_q_256 + prod : Rq -> Rq -> Rq prod a b = NTTInv' (MultiplyNTTs (NTT' a) (NTT' b)) /** @@ -745,7 +766,7 @@ property TestMult = prod f f == fsq where * (also referred to as `T_q` multiplication). * [FIPS-203] Section 2.4.5 Equation 2.8. */ -dot : Z_q_256 -> Z_q_256 -> Z_q_256 +dot : Tq -> Tq -> Tq dot f g = MultiplyNTTs f g /** @@ -753,7 +774,7 @@ dot f g = MultiplyNTTs f g * functionality with `T_q` multiplication as the base operation. * [FIPS-203] Section 2.4.7 Equation 2.14. */ -dotVecVec : {k1} (fin k1) => [k1]Z_q_256 -> [k1]Z_q_256 -> Z_q_256 +dotVecVec : {k1} (fin k1) => [k1]Tq -> [k1]Tq -> Tq dotVecVec v1 v2 = sum (zipWith dot v1 v2) /** @@ -761,7 +782,7 @@ dotVecVec v1 v2 = sum (zipWith dot v1 v2) * vector multiplication with `T_q` multiplication as the base operation. * [FIPS-203] Section 2.4.7 Equation 2.12 and 2.13. */ -dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Z_q_256 -> [k2]Z_q_256 -> [k1]Z_q_256 +dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Tq -> [k2]Tq -> [k1]Tq dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix] /** @@ -770,7 +791,7 @@ dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix] * [FIPS-203] Section 2.4.7. 
*/ dotMatMat :{k1,k2,k3} (fin k1, fin k2, fin k3) => - [k1][k2]Z_q_256 -> [k2][k3]Z_q_256 -> [k1][k3]Z_q_256 + [k1][k2]Tq -> [k2][k3]Tq -> [k1][k3]Tq dotMatMat matrix1 matrix2 = transpose [dotMatVec matrix1 vector | vector <- m'] where m' = transpose matrix2 From f156af4ae7d9d8f6c01fae7d4d4360b37c0d9479 Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Tue, 15 Oct 2024 14:40:36 -0400 Subject: [PATCH 3/8] mlkem: bring ntt names into alignment #147 This replaces `'`s with suffixes explicitly describing what type of data each NTT function operates over. --- .../Cipher/ML_KEM/Specification.cry | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index a67a5c25..1700d4ed 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -699,11 +699,11 @@ submodule NTT where // Here, we can choose to call either the naive or fast NTT ////////////////////////////////////////////////////////////// - NTT' : Rq -> Tq - NTT' = fast_ntt + NTT : Rq -> Tq + NTT = fast_ntt - NTTInv' : Tq -> Rq - NTTInv' = fast_invntt + NTTInv : Tq -> Rq + NTTInv = fast_invntt /** * The notation `NTT` is overloaded to mean both a single application of `NTT` @@ -711,7 +711,7 @@ submodule NTT where * of a `k`-length vector. * [FIPS-203] Section 2.4.6 Equation 2.9. */ - NTT v = map NTT' v + NTT_Vec v = map NTT v /** * The notation `NTTInv` is overloaded to mean both a single application of @@ -719,7 +719,7 @@ submodule NTT where * every element of a `k`-length vector. * [FIPS-203] Section 2.4.6. */ - NTTInv v = map NTTInv' v + NTTInv_Vec v = map NTTInv v ////////////////////////////////////////////////////////////// // Polynomial multiplication in the NTT domain @@ -759,7 +759,7 @@ property TestMult = prod f f == fsq where fsq = [1,2,1] # [0 | i <- [4 .. 
256]] prod : Rq -> Rq -> Rq - prod a b = NTTInv' (MultiplyNTTs (NTT' a) (NTT' b)) + prod a b = NTTInv (MultiplyNTTs (NTT a) (NTT b)) /** * The cross product notation ×𝑇𝑞 is defined as the `MultiplyNTTs` function @@ -844,9 +844,9 @@ private submodule K_PKE where e = [SamplePolyCBD`{eta_1} (PRF σ N) | N <- [k .. 2 * k - 1]] // Step 16. - s_hat = NTT s + s_hat = NTT_Vec s // Step 17. - e_hat = NTT e + e_hat = NTT_Vec e // Step 18. t_hat = (dotMatVec A_hat s_hat) + e_hat // Step 19. @@ -884,13 +884,13 @@ private submodule K_PKE where // value instead. e2 = SamplePolyCBD`{eta_2} (PRF r (2 * `k)) // Step 18. - y_hat = NTT y + y_hat = NTT_Vec y // Step 19. - u = NTTInv (dotMatVec (transpose A_hat) y_hat) + e1 + u = NTTInv_Vec (dotMatVec (transpose A_hat) y_hat) + e1 // Step 20. mu = Decompress'`{1} (DecodeBytes'`{1} m) // Step 21. - v = (NTTInv' (dotVecVec t_hat y_hat)) + e2 + mu + v = (NTTInv (dotVecVec t_hat y_hat)) + e2 + mu // Step 22. c1 = EncodeBytes`{d_u} (Compress`{d_u} u) // Step 23. @@ -920,7 +920,7 @@ private submodule K_PKE where // Step 5. s_hat = Decode`{12} dkPKE // Step 6. - w = v' - NTTInv' (dotVecVec s_hat (NTT u')) + w = v' - NTTInv (dotVecVec s_hat (NTT_Vec u')) // Step 7. m = EncodeBytes'`{1} (Compress'`{1} w) From 311df719f0a51dd9be160407fbc72da0121df724 Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Tue, 15 Oct 2024 14:48:47 -0400 Subject: [PATCH 4/8] mlkem: add docs about allowing equivalence #147 This adds some documentation around the NTT module explaining where the spec says it's allowed to choose any version of their algorithms that are the same. 
--- .../Cipher/ML_KEM/Specification.cry | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index 1700d4ed..003f6d09 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -486,7 +486,13 @@ BitRev7 = reverse * This section specifies the number-theoretic transform (NTT). * * It includes the version from [FIPS-203] Section 4.3 as well - * as a faster O(N log N) version, and a proof of their equivalence. + * as a faster O(N log N) version, and a proof of their equivalence. + * + * This is explicitly allowed by the spec: "For every computational procedure + * [...] a conforming implementation may replace the given set of steps with + * any mathematically equivalent set of steps". The equivalence properties + * prove mathematical equivalence. + * [FIPS-203] Introduction, "7. Implementations". */ import submodule NTT submodule NTT where @@ -693,15 +699,20 @@ submodule NTT where naive_fast_invntt_equiv : Tq -> Bit property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f - ////////////////////////////////////////////////////////////// - // NTT "dispatcher" - // - // Here, we can choose to call either the naive or fast NTT - ////////////////////////////////////////////////////////////// - + /** + * The Number-Theoretic Transform (NTT) is used to improve the efficiency + * of multiplication in the ring `R_q`. We choose to use the fast version + * of NTT, which is equivalent to the version described in the spec. + */ NTT : Rq -> Tq NTT = fast_ntt + /** + * The inverse of the Number-Theoretic Transform (NTT) is used to improve + * the efficiency of multiplication in the ring `R_q`. We choose to use + * the fast version of inverse function, which is equivalent to the version + * described in the spec. 
+ */ NTTInv : Tq -> Rq NTTInv = fast_invntt From 76c018d44454934f386bf85fcdec5e5fb94ddde2 Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Tue, 15 Oct 2024 14:52:30 -0400 Subject: [PATCH 5/8] mlkem: format properties correctly #147 Several properties didn't have correct `repl` commands in the docstrings. --- .../Asymmetric/Cipher/ML_KEM/Specification.cry | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index 003f6d09..aee7fe3a 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -647,7 +647,7 @@ submodule NTT where /** * This property demonstrates that NaiveNTT is self-inverting. - * ``` + * ```repl * :prove NaiveNTT_Inverts * ``` */ @@ -656,7 +656,7 @@ submodule NTT where /** * This property demonstrates that NaiveNTTInv is self-inverting. - * ``` + * ```repl * :prove NaiveNTTInv_Inverts * ``` */ @@ -665,7 +665,7 @@ submodule NTT where /** * This property demonstrates that `fast_ntt` is the inverse of `fast_invntt`. - * ``` + * ```repl * :prove fast_ntt_inverts * ``` */ @@ -674,7 +674,7 @@ submodule NTT where /** * This property demonstrates that `fast_invntt` is the inverse of `fast_ntt`. - * ``` + * ```repl * :prove fast_invntt_inverts * ``` */ @@ -683,7 +683,7 @@ submodule NTT where /** * This property demonstrates that `naive_ntt` is equivalent to `fast_ntt`. - * ``` + * ```repl * :prove naive_fast_ntt_equiv * ``` */ @@ -692,7 +692,7 @@ submodule NTT where /** * This property demonstrates that `naive_invntt` is equivalent to `fast_invntt`. 
- * ``` + * ```repl * :prove naive_fast_invntt_equiv * ``` */ @@ -759,7 +759,7 @@ MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : root i = (zeta^^(reverse (64 + (i >> 1)) >> 1) * ((-1 : (Z q)) ^^ (i))) /** - * Testing that (1+x)^2 = 1+2x+x^2 + * Testing that (1+x)^2 = 1+2x+x^2. * ```repl * :prove TestMult * ``` From 5a6a4a50b6c3e14844d029c5433b6a04f043e7ba Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Wed, 16 Oct 2024 11:34:39 -0400 Subject: [PATCH 6/8] mlkem: reorganize fast-NTT a little bit #147 --- .../Cipher/ML_KEM/Specification.cry | 130 ++++++++++-------- 1 file changed, 69 insertions(+), 61 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index aee7fe3a..3ba8dc75 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -548,98 +548,106 @@ submodule NTT where // This section Copyright Amazon.com, Inc. or its affiliates. ////////////////////////////////////////////////////////////// - // Simple lookup table for Zeta value given K + /** + * Lookup table for `zeta` values computed in `NTT` and `NTTInv`. + * + * [FIPS-203] Appendix A. 
+ */ zeta_expc : [128](Z q) - zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848, - 1062, 1919, 193, 797, 2786, 3260, 569, 1746, - 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, - 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, - 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, - 650, 1977, 2513, 632, 2865, 33, 1320, 1915, - 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, - 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, - 17, 2761, 583, 2649, 1637, 723, 2288, 1100, - 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, - 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, - 939, 2308, 2437, 2388, 733, 2337, 268, 641, - 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, - 1063, 319, 2773, 757, 2099, 561, 2466, 2594, - 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, - 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ] + zeta_expc = [ + 1, 1729, 2580, 3289, 2642, 630, 1897, 848, + 1062, 1919, 193, 797, 2786, 3260, 569, 1746, + 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, + 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, + 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, + 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, + 17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, + 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, + 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, + 1063, 319, 2773, 757, 2099, 561, 2466, 2594, + 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 + ] + + + // Top level entry point - start with lv=256, k=1 + fast_ntt : Rq -> Tq + fast_ntt v = fast_nttl v 1 + + // Recursive NTT function + fast_nttl : + {lv} // Length of v is a member of {256,128,64,32,16,8,4} + (lv >= 2, lv <= 8) => + [2^^lv](Z q) -> [8] -> [2^^lv](Z q) + fast_nttl v k + // Base case. 
lv==2 so just compute the butterfly and return + | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) + + // Recursive case. Butterfly what we have, then recurse on each half, + // concatenate the results and return. + | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # + (fast_nttl`{lv-1} s1 (k * 2 + 1)) + where + t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) + // Split t into two halves s0 and s1 + [s0, s1] = split t // Fast recursive CT-NTT ct_butterfly : {m, hm} (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => [2^^m](Z q) -> (Z q) -> [2^^m](Z q) - ct_butterfly v z = new_v - where + ct_butterfly v z = new_v where halflen = 2^^`hm lower, upper : [2^^hm](Z q) lower@x = v@x + z * v@(x + halflen) upper@x = v@x - z * v@(x + halflen) new_v = lower # upper - fast_nttl : + + // Recursive inverse-NTT function + fast_invnttl : {lv} // Length of v is a member of {256,128,64,32,16,8,4} (lv >= 2, lv <= 8) => [2^^lv](Z q) -> [8] -> [2^^lv](Z q) - fast_nttl v k - // Base case. lv==2 so just compute the butterfly and return - | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) - - // Recursive case. Butterfly what we have, then recurse on each half, - // concatenate the results and return. - | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # - (fast_nttl`{lv-1} s1 (k * 2 + 1)) - where - t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) - // Split t into two halves s0 and s1 - [s0, s1] = split t + fast_invnttl v k + // Base case. lv==2 so just compute the butterfly and return + | lv == 2 => gs_butterfly`{lv,lv-1} v (zeta_expc@k) - // Top level entry point - start with lv=256, k=1 - fast_ntt : Rq -> Tq - fast_ntt v = fast_nttl v 1 + // Recursive case. Recurse on each half, + // concatenate the results, butterfly that, and return. 
+ | lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k) where + // Split v into two halves s0 and s1 + [s0, s1] = split v + t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) # + (fast_invnttl`{lv-1} s1 (k * 2)) // Fast recursive GS-Inverse-NTT gs_butterfly : {m, hm} (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => [2^^m](Z q) -> (Z q) -> [2^^m](Z q) - gs_butterfly v z = new_v - where + gs_butterfly v z = new_v where halflen = 2^^`hm lower, upper : [2^^hm](Z q) lower@x = v@x + v@(x + halflen) upper@x = z * (v@(x + halflen) - v@x) new_v = lower # upper - fast_invnttl : - {lv} // Length of v is a member of {256,128,64,32,16,8,4} - (lv >= 2, lv <= 8) => - [2^^lv](Z q) -> [8] -> [2^^lv](Z q) - - fast_invnttl v k - // Base case. lv==2 so just compute the butterfly and return - | lv == 2 => gs_butterfly`{lv,lv-1} v (zeta_expc@k) - - // Recursive case. Recurse on each half, - // concatenate the results, butterfly that, and return. - | lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k) - where - // Split t into two halves s0 and s1 - [s0, s1] = split v - t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) # - (fast_invnttl`{lv-1} s1 (k * 2)) - - // Multiply all elements of v by the reciprocal of 128 (modulo q) - recip_128_modq = (recip 128) : (Z q) - mul_recip128 : Z_q_256 -> Z_q_256 - mul_recip128 v = [ v@x * recip_128_modq | x <- [0 .. Rq - fast_invntt v = mul_recip128 (fast_invnttl v 1) + fast_invntt v = mul_recip128 (fast_invnttl v 1) where + + // Multiplicative inverse of 128, mod `q`. + recip_128_modq = (recip 128) : (Z q) + + // Multiply all elements of v' by the reciprocal of 128 (modulo q) + mul_recip128 : Tq -> Tq + mul_recip128 v' = [ v'@x * recip_128_modq | x <- [0 .. 
Date: Wed, 16 Oct 2024 13:27:03 -0400 Subject: [PATCH 7/8] mlkem: clean up NTT multiplication functions #147 - Adds docs to BitRev and contains its behavior a bit better - Adjust spacing, naming, etc in MultiplyNTTs and BaseCaseMultiply --- .../Cipher/ML_KEM/Specification.cry | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index 3ba8dc75..cf8c7eed 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -477,9 +477,22 @@ property Is256thRootOfq p = (p == 0) || (p >= 256) || (zeta^^p != 1) * Reverse the unsigned 7-bit value corresponding to an input integer in * `[0, ..., 127]`. * [FIPS-203] Section 4.3 "The mathematical structure of the NTT." + * + * This diverges from the spec by operating over an 8-bit vector; + * this is to ease prior and subsequent computations that would overflow a + * 7-bit vector, like: + * - `2 * (BitRev7 i) + 1` + * - `2 * i + 1` + * + * A "pure" implementation of `BitRev7` in Cryptol is the `reverse` function + * on 7-bit vectors. This mini-property shows equivalence: + * ```repl + * :prove \(x:[7]) -> ([0] # reverse x) == BitRev7 ([0] # x) + * ``` */ BitRev7 : [8] -> [8] -BitRev7 = reverse +BitRev7 i = if i > 255 then error "BitRev7 called with invalid input" + else (reverse i) >> 1 /** @@ -504,8 +517,8 @@ submodule NTT where */ ParametricNTT : Rq -> (Z q) -> Tq ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] + where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 127]] + f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 
127]] /** * Compute most of the polynomial that corresponds to the NTT representation @@ -516,8 +529,8 @@ submodule NTT where */ ParametricNTTInv : Tq -> (Z q) -> Rq ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] + where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]] + f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]] /** * Number theoretic transform: converts elements in `R_q` to `T_q`. @@ -749,22 +762,25 @@ submodule NTT where * quadratic modulus. * [FIPS-203] Section 4.3.1 Algorithm 12. */ -BaseCaseMultiply : [2] (Z q) -> [2] (Z q) -> (Z q) -> [2] (Z q) -BaseCaseMultiply a b root = [c0, c1] +BaseCaseMultiply : (Z q) -> (Z q) -> (Z q) -> (Z q) -> (Z q) -> [2](Z q) +BaseCaseMultiply a0 a1 b0 b1 γ = [c0, c1] where - c0 = a@1 * b@1 * root + a@0 * b@0 - c1 = a@0 * b@1 + a@1 * b@0 + c0 = a0 * b0 + a1 * b1 * γ + c1 = a0 * b1 + a1 * b0 /** * Compute the product (in the ring `T_q`) of two NTT representations. * [FIPS-203] Section 4.3.1 Algorithm 11. */ MultiplyNTTs : Tq -> Tq -> Tq -MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : Byte <- [0 .. 127]] - where - f_hat_i i = [a@(2*i),a@(2*i+1)] - g_hat_i i = [b@(2*i),b@(2*i+1)] - root i = (zeta^^(reverse (64 + (i >> 1)) >> 1) * ((-1 : (Z q)) ^^ (i))) +MultiplyNTTs f_hat g_hat = join h_hat where + h_hat = [ BaseCaseMultiply + (f_hat @(2*i)) + (f_hat @(2*i+1)) + (g_hat @(2*i)) + (g_hat @(2*i+1)) + (zeta ^^(2 * BitRev7 i + 1)) + | i <- [0 .. 127] ] /** * Testing that (1+x)^2 = 1+2x+x^2. 
From 71a5776ee1ee20d4b3c5a2dd593020bd04681015 Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Wed, 16 Oct 2024 14:32:09 -0400 Subject: [PATCH 8/8] mlkem: condense naive NTT implementations #147 This doesn't make them spec adherent but it simplifies the section a bit. --- .../Cipher/ML_KEM/Specification.cry | 44 +++++++------------ 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index cf8c7eed..ece7211f 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -511,43 +511,29 @@ import submodule NTT submodule NTT where private /** - * Compute the NTT representation of the polynomial `f`. + * Number theoretic transform: compute the "NTT representation" in + * `T_q` of a polynomial in `R_q`. * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. - */ - ParametricNTT : Rq -> (Z q) -> Tq - ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 127]] - - /** - * Compute most of the polynomial that corresponds to the NTT representation - * `f`. - * (The last step 14 is in a separate function) - * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. - */ - ParametricNTTInv : Tq -> (Z q) -> Rq - ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]] - - /** - * Number theoretic transform: converts elements in `R_q` to `T_q`. - * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. + * [FIPS-203] Section 4.3, Algorithm 9. 
*/ NaiveNTT : Rq -> Tq - NaiveNTT f = ParametricNTT f zeta + NaiveNTT f = join [[f2i i, f2iPlus1 i] | i <- [0 .. 127]] where + f2i i = sum [f @(2*j) * zeta_term i j | j <- [0 .. 127]] + f2iPlus1 i = sum [f @(2*j+1) * zeta_term i j | j <- [0 .. 127]] + zeta_term i j = zeta ^^ ((2 * BitRev7 i + 1) * j) /** - * Inverse of the number theoretic transform: converts elements in `T_q` to - * `R_q`. + * Inverse of the number theoretic transform: converts from the "NTT + * representation" in `T_q` to a polynomial in `R_q`. * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. + * [FIPS-203] Section 4.3, Algorithm 10. */ NaiveNTTInv : Tq -> Rq - NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)] + NaiveNTTInv f_hat = [f_i * 3303 | f_i <- f] where + f = join [[f2i i, f2iPlus1 i] | i <- [0 .. 127]] + f2i i = sum [f_hat @(2*j) * zeta_term i j | j <- [0 .. 127]] + f2iPlus1 i = sum [f_hat @(2*j+1) * zeta_term i j | j <- [0 .. 127]] + zeta_term i j = (recip zeta) ^^ ((2 * BitRev7 j + 1) * i) ////////////////////////////////////////////////////////////// // This section specifies fast O(N log N) NTT and Inverse NTT