diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index bc5c0a9c..ece7211f 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -10,7 +10,7 @@ * * Sources: * [FIPS-203]: National Institute of Standards and Technology. Module-Lattice- - * Basead Key-Encapsulation Mechanism Standard. (Department of Commerce, + * Based Key-Encapsulation Mechanism Standard. (Department of Commerce, * Washington, D.C.), Federal Information Processing Standards Publication * (FIPS) NIST FIPS 203. August 2024. * @see https://doi.org/10.6028/NIST.FIPS.203 @@ -63,6 +63,33 @@ type Byte = [8] */ type Z_q_256 = [n](Z q) +/** + * An element in the ring `R_q`. + * + * An element in the ring is a polynomial of degree at most 255 (e.g. with 256 + * terms). The `i`th element in this array represents the coefficient of the + * degree-`i` term. + * + * [FIPS-203] Section 2.3 (definition of the ring). + * [FIPS-203] Section 2.4.4, Equation 2.5 (definition of the representation of + * elements in the ring). + */ +type Rq = [n](Z q) + +/** + * An element in the ring `T_q`. + * + * An element in this ring (sometimes called the "NTT representation") is a + * tuple of 128 polynomials, each of degree at most one (e.g. with two terms). + * The `2i` and `2i+1`th terms in this array represent the degree-0 and + * degree-1 coefficients of the `i`th polynomial, respectively. + * + * [FIPS-203] Section 2.3 (definition of the `T_q`). + * [FIPS-203] Section 2.4.4 Equation 2.7 (definition of the representation of + * an element in `T_q`). + */ +type Tq = [n](Z q) + /** * Pseudorandom function (PRF). * [FIPS-203] Section 4.1, Equations 4.2 and 4.3. @@ -336,7 +363,7 @@ property CorrectnessEncodeDecode' f = Decode'`{12}(Encode'`{12} f) == f * * [FIPS-203] Section 4.2.2, Algorithm 7. */ -SampleNTT : [34]Byte -> Z_q_256 +SampleNTT : [34]Byte -> Tq SampleNTT B = a_hat' where // Steps 1-2, 5. // We (lazily) take an infinite stream from the XOF and remove only as @@ -401,7 +428,7 @@ SampleNTT B = a_hat' where * * [FIPS-203] Section 4.2.2, Algorithm 8. */ -SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Z_q_256 +SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Rq SamplePolyCBD B = f where // Step 1. b = BytesToBits B @@ -450,237 +477,267 @@ property Is256thRootOfq p = (p == 0) || (p >= 256) || (zeta^^p != 1) * Reverse the unsigned 7-bit value corresponding to an input integer in * `[0, ..., 127]`. * [FIPS-203] Section 4.3 "The mathematical structure of the NTT." - */ -BitRev7 : [8] -> [8] -BitRev7 = reverse - - -////////////////////////////////////////////////////////////// -// This section specifies a naive O(N**2) NTT and Inverse NTT -// -// A "fast" O(N log N) version is below, followed by a -// proof of their equivalence -////////////////////////////////////////////////////////////// - -/** - * Compute the NTT representation of the polynomial `f`. - * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. - */ -ParametricNTT : Z_q_256 -> (Z q) -> Z_q_256 -ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] - -/** - * Compute most of the polynomial that corresponds to the NTT representation - * `f`. - * (The last step 14 is in a separate function) * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. - */ -ParametricNTTInv : Z_q_256 -> (Z q) -> Z_q_256 -ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] - -/** - * Number theoretic transform: converts elements in `R_q` to `T_q`. - * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. - */ -NaiveNTT : Z_q_256 -> Z_q_256 -NaiveNTT f = ParametricNTT f zeta - -/** - * Inverse of the number theoretic transform: converts elements in `T_q` to - * `R_q`. + * This diverges from the spec by operating over an 8-bit vector; + * this is to ease prior and subsequent computations that would overflow a + * 7-bit vector, like: + * - `2 * (BitRev7 i) + 1` + * - `2 * i + 1` * - * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. - */ -NaiveNTTInv : Z_q_256 -> Z_q_256 -NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)] - -////////////////////////////////////////////////////////////// -// This section specifies fast O(N log N) NTT and Inverse NTT -// -// A readable explanation of the derivation of this form of -// the NTT is in "A Complete Beginner Guide to the Number -// Theoretic Transform (NTT)" by Ardianto Satriawan, -// Rella Mareta, and Hanho Lee. Available from: -// https://eprint.iacr.org/2024/585 -// -// This section Copyright Amazon.com, Inc. or its affiliates. -////////////////////////////////////////////////////////////// - -// Simple lookup table for Zeta value given K -zeta_expc : [128](Z q) -zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848, - 1062, 1919, 193, 797, 2786, 3260, 569, 1746, - 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, - 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, - 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, - 650, 1977, 2513, 632, 2865, 33, 1320, 1915, - 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, - 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, - 17, 2761, 583, 2649, 1637, 723, 2288, 1100, - 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, - 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, - 939, 2308, 2437, 2388, 733, 2337, 268, 641, - 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, - 1063, 319, 2773, 757, 2099, 561, 2466, 2594, - 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, - 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ] - -// Fast recursive CT-NTT -ct_butterfly : - {m, hm} - (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => - [2^^m](Z q) -> (Z q) -> [2^^m](Z q) -ct_butterfly v z = new_v - where - halflen = 2^^`hm - lower, upper : [2^^hm](Z q) - lower@x = v@x + z * v@(x + halflen) - upper@x = v@x - z * v@(x + halflen) - new_v = lower # upper - -fast_nttl : - {lv} // Length of v is a member of {256,128,64,32,16,8,4} - (lv >= 2, lv <= 8) => - [2^^lv](Z q) -> [8] -> [2^^lv](Z q) -fast_nttl v k - // Base case. lv==2 so just compute the butterfly and return - | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) - - // Recursive case. Butterfly what we have, then recurse on each half, - // concatenate the results and return. - | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # - (fast_nttl`{lv-1} s1 (k * 2 + 1)) - where - t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) - // Split t into two halves s0 and s1 - [s0, s1] = split t - -// Top level entry point - start with lv=256, k=1 -fast_ntt : Z_q_256 -> Z_q_256 -fast_ntt v = fast_nttl v 1 - -// Fast recursive GS-Inverse-NTT -gs_butterfly : - {m, hm} - (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => - [2^^m](Z q) -> (Z q) -> [2^^m](Z q) -gs_butterfly v z = new_v - where - halflen = 2^^`hm - lower, upper : [2^^hm](Z q) - lower@x = v@x + v@(x + halflen) - upper@x = z * (v@(x + halflen) - v@x) - new_v = lower # upper - -fast_invnttl : - {lv} // Length of v is a member of {256,128,64,32,16,8,4} - (lv >= 2, lv <= 8) => - [2^^lv](Z q) -> [8] -> [2^^lv](Z q) - -fast_invnttl v k - // Base case. lv==2 so just compute the butterfly and return - | lv == 2 => gs_butterfly`{lv,lv-1} v (zeta_expc@k) - - // Recursive case. Recurse on each half, - // concatenate the results, butterfly that, and return. - | lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k) - where - // Split t into two halves s0 and s1 - [s0, s1] = split v - t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) # - (fast_invnttl`{lv-1} s1 (k * 2)) - -// Multiply all elements of v by the reciprocal of 128 (modulo q) -recip_128_modq = (recip 128) : (Z q) -mul_recip128 : Z_q_256 -> Z_q_256 -mul_recip128 v = [ v@x * recip_128_modq | x <- [0 .. Z_q_256 -fast_invntt v = mul_recip128 (fast_invnttl v 1) - -////////////////////////////////////////////////////////////// -// Properties and proofs of Naive and Fast NTT -////////////////////////////////////////////////////////////// - -/** - * This property demonstrates that NaiveNTT is self-inverting. - * ``` - * :prove NaiveNTT_Inverts - * ``` - */ -NaiveNTT_Inverts : Z_q_256 -> Bit -property NaiveNTT_Inverts f = NaiveNTTInv (NaiveNTT f) == f - -/** - * This property demonstrates that NaiveNTTInv is self-inverting. - * ``` - * :prove NaiveNTTInv_Inverts - * ``` - */ -NaiveNTTInv_Inverts : Z_q_256 -> Bit -property NaiveNTTInv_Inverts f = NaiveNTT (NaiveNTTInv f) == f - -/** - * This property demonstrates that `fast_ntt` is the inverse of `fast_invntt`. - * ``` - * :prove fast_ntt_inverts + * A "pure" implementation of `BitRev7` in Cryptol is the `reverse` function + * on 7-bit vectors. This mini-property shows equivalence: + * ```repl + * :prove \(x:[7]) -> ([0] # reverse x) == BitRev7 ([0] # x) * ``` */ -fast_ntt_inverts : Z_q_256 -> Bit -property fast_ntt_inverts f = fast_invntt (fast_ntt f) == f +BitRev7 : [8] -> [8] +BitRev7 i = if i > 255 then error "BitRev7 called with invalid input" + else (reverse i) >> 1 -/** - * This property demonstrates that `fast_invntt` is the inverse of `fast_ntt`. - * ``` - * :prove fast_invntt_inverts - * ``` - */ -fast_invntt_inverts : Z_q_256 -> Bit -property fast_invntt_inverts f = fast_ntt (fast_invntt f) == f /** - * This property demonstrates that `naive_ntt` is equivalent to `fast_ntt`. - * ``` - * :prove naive_fast_ntt_equiv - * ``` - */ -naive_fast_ntt_equiv : Z_q_256 -> Bit -property naive_fast_ntt_equiv f = NaiveNTT f == fast_ntt f + * This section specifies the number-theoretic transform (NTT). + * + * It includes the version from [FIPS-203] Section 4.3 as well + * as a faster O(N log N) version, and a proof of their equivalence. + * + * This is explicitly allowed by the spec: "For every computational procedure + * [...] a conforming implementation may replace the given set of steps with + * any mathematically equivalent set of steps". The equivalence properties + * prove mathematical equivalence. + * [FIPS-203] Introduction, "7. Implementations". + */ +import submodule NTT +submodule NTT where + private + /** + * Number theoretic transform: compute the "NTT representation" in + * `T_q` of a polynomial in `R_q`. + * + * [FIPS-203] Section 4.3, Algorithm 9. + */ + NaiveNTT : Rq -> Tq + NaiveNTT f = join [[f2i i, f2iPlus1 i] | i <- [0 .. 127]] where + f2i i = sum [f @(2*j) * zeta_term i j | j <- [0 .. 127]] + f2iPlus1 i = sum [f @(2*j+1) * zeta_term i j | j <- [0 .. 127]] + zeta_term i j = zeta ^^ ((2 * BitRev7 i + 1) * j) + + /** + * Inverse of the number theoretic transform: converts from the "NTT + * representation" in `T_q` to a polynomial in `R_q`. + * + * [FIPS-203] Section 4.3, Algorithm 10. + */ + NaiveNTTInv : Tq -> Rq + NaiveNTTInv f_hat = [f_i * 3303 | f_i <- f] where + f = join [[f2i i, f2iPlus1 i] | i <- [0 .. 127]] + f2i i = sum [f_hat @(2*j) * zeta_term i j | j <- [0 .. 127]] + f2iPlus1 i = sum [f_hat @(2*j+1) * zeta_term i j | j <- [0 .. 127]] + zeta_term i j = (recip zeta) ^^ ((2 * BitRev7 j + 1) * i) + + ////////////////////////////////////////////////////////////// + // This section specifies fast O(N log N) NTT and Inverse NTT + // + // A readable explanation of the derivation of this form of + // the NTT is in "A Complete Beginner Guide to the Number + // Theoretic Transform (NTT)" by Ardianto Satriawan, + // Rella Mareta, and Hanho Lee. Available from: + // https://eprint.iacr.org/2024/585 + // + // This section Copyright Amazon.com, Inc. or its affiliates. + ////////////////////////////////////////////////////////////// + + /** + * Lookup table for `zeta` values computed in `NTT` and `NTTInv`. + * + * [FIPS-203] Appendix A. + */ + zeta_expc : [128](Z q) + zeta_expc = [ + 1, 1729, 2580, 3289, 2642, 630, 1897, 848, + 1062, 1919, 193, 797, 2786, 3260, 569, 1746, + 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, + 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, + 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, + 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, + 17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, + 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, + 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, + 1063, 319, 2773, 757, 2099, 561, 2466, 2594, + 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 + ] + + + // Top level entry point - start with lv=256, k=1 + fast_ntt : Rq -> Tq + fast_ntt v = fast_nttl v 1 + + // Recursive NTT function + fast_nttl : + {lv} // Length of v is a member of {256,128,64,32,16,8,4} + (lv >= 2, lv <= 8) => + [2^^lv](Z q) -> [8] -> [2^^lv](Z q) + fast_nttl v k + // Base case. lv==2 so just compute the butterfly and return + | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) + + // Recursive case. Butterfly what we have, then recurse on each half, + // concatenate the results and return. + | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # + (fast_nttl`{lv-1} s1 (k * 2 + 1)) + where + t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) + // Split t into two halves s0 and s1 + [s0, s1] = split t + + // Fast recursive CT-NTT + ct_butterfly : + {m, hm} + (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => + [2^^m](Z q) -> (Z q) -> [2^^m](Z q) + ct_butterfly v z = new_v where + halflen = 2^^`hm + lower, upper : [2^^hm](Z q) + lower@x = v@x + z * v@(x + halflen) + upper@x = v@x - z * v@(x + halflen) + new_v = lower # upper + + + // Recursive inverse-NTT function + fast_invnttl : + {lv} // Length of v is a member of {256,128,64,32,16,8,4} + (lv >= 2, lv <= 8) => + [2^^lv](Z q) -> [8] -> [2^^lv](Z q) + fast_invnttl v k + // Base case. lv==2 so just compute the butterfly and return + | lv == 2 => gs_butterfly`{lv,lv-1} v (zeta_expc@k) + + // Recursive case. Recurse on each half, + // concatenate the results, butterfly that, and return. + | lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k) where + // Split t into two halves s0 and s1 + [s0, s1] = split v + t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) # + (fast_invnttl`{lv-1} s1 (k * 2)) + + // Fast recursive GS-Inverse-NTT + gs_butterfly : + {m, hm} + (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => + [2^^m](Z q) -> (Z q) -> [2^^m](Z q) + gs_butterfly v z = new_v where + halflen = 2^^`hm + lower, upper : [2^^hm](Z q) + lower@x = v@x + v@(x + halflen) + upper@x = z * (v@(x + halflen) - v@x) + new_v = lower # upper + + // Top level entry point - start with lv=256, k=1 + fast_invntt : Tq -> Rq + fast_invntt v = mul_recip128 (fast_invnttl v 1) where + + // Multiplicative inverse of 128, mod `q`. + recip_128_modq = (recip 128) : (Z q) + + // Multiply all elements of v' by the reciprocal of 128 (modulo q) + mul_recip128 : Tq -> Tq + mul_recip128 v' = [ v'@x * recip_128_modq | x <- [0 .. Bit + property NaiveNTT_Inverts f = NaiveNTTInv (NaiveNTT f) == f + + /** + * This property demonstrates that NaiveNTTInv is self-inverting. + * ```repl + * :prove NaiveNTTInv_Inverts + * ``` + */ + NaiveNTTInv_Inverts : Tq -> Bit + property NaiveNTTInv_Inverts f = NaiveNTT (NaiveNTTInv f) == f + + /** + * This property demonstrates that `fast_ntt` is the inverse of `fast_invntt`. + * ```repl + * :prove fast_ntt_inverts + * ``` + */ + fast_ntt_inverts : Rq -> Bit + property fast_ntt_inverts f = fast_invntt (fast_ntt f) == f + + /** + * This property demonstrates that `fast_invntt` is the inverse of `fast_ntt`. + * ```repl + * :prove fast_invntt_inverts + * ``` + */ + fast_invntt_inverts : Tq -> Bit + property fast_invntt_inverts f = fast_ntt (fast_invntt f) == f + + /** + * This property demonstrates that `naive_ntt` is equivalent to `fast_ntt`. + * ```repl + * :prove naive_fast_ntt_equiv + * ``` + */ + naive_fast_ntt_equiv : Rq -> Bit + property naive_fast_ntt_equiv f = NaiveNTT f == fast_ntt f + + /** + * This property demonstrates that `naive_invntt` is equivalent to `fast_invntt`. + * ```repl + * :prove naive_fast_invntt_equiv + * ``` + */ + naive_fast_invntt_equiv : Tq -> Bit + property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f -/** - * This property demonstrates that `naive_invntt` is equivalent to `fast_invntt`. - * ``` - * :prove naive_fast_invntt_equiv - * ``` - */ -naive_fast_invntt_equiv : Z_q_256 -> Bit -property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f + /** + * The Number-Theoretic Transform (NTT) is used to improve the efficiency + * of multiplication in the ring `R_q`. We choose to use the fast version + * of NTT, which is equivalent to the version described in the spec. + */ + NTT : Rq -> Tq + NTT = fast_ntt -////////////////////////////////////////////////////////////// -// NTT "dispatcher" -// -// Here, we can choose to call either the naive or fast NTT -////////////////////////////////////////////////////////////// + /** + * The inverse of the Number-Theoretic Transform (NTT) is used to improve + * the efficiency of multiplication in the ring `R_q`. We choose to use + * the fast version of inverse function, which is equivalent to the version + * described in the spec. + */ + NTTInv : Tq -> Rq + NTTInv = fast_invntt -NTT' : Z_q_256 -> Z_q_256 -// fast -NTT' f = fast_ntt f -// slow -//NTT' f = NaiveNTT f + /** + * The notation `NTT` is overloaded to mean both a single application of `NTT` + * to an element of `R_q` and also `k` applications of `NTT` to every element + * of a `k`-length vector. + * [FIPS-203] Section 2.4.6 Equation 2.9. + */ + NTT_Vec v = map NTT v -NTTInv' : Z_q_256 -> Z_q_256 -// fast -NTTInv' f = fast_invntt f -// slow -//NTTInv' f = NaiveNTTInv f + /** + * The notation `NTTInv` is overloaded to mean both a single application of + * `NTTInv` to an element of `R_q` and also `k` applications of `NTTInv` to + * every element of a `k`-length vector. + * [FIPS-203] Section 2.4.6. + */ + NTTInv_Vec v = map NTTInv v ////////////////////////////////////////////////////////////// // Polynomial multiplication in the NTT domain @@ -691,25 +748,28 @@ NTTInv' f = fast_invntt f * quadratic modulus. * [FIPS-203] Section 4.3.1 Algorithm 12. */ -BaseCaseMultiply : [2] (Z q) -> [2] (Z q) -> (Z q) -> [2] (Z q) -BaseCaseMultiply a b root = [c0, c1] +BaseCaseMultiply : (Z q) -> (Z q) -> (Z q) -> (Z q) -> (Z q) -> [2](Z q) +BaseCaseMultiply a0 a1 b0 b1 γ = [c0, c1] where - c0 = a@1 * b@1 * root + a@0 * b@0 - c1 = a@0 * b@1 + a@1 * b@0 + c0 = a0 * b0 + a1 * b1 * γ + c1 = a0 * b1 + a1 * b0 /** * Compute the product (in the ring `T_q`) of two NTT representations. * [FIPS-203] Section 4.3.1 Algorithm 11. */ -MultiplyNTTs : Z_q_256 -> Z_q_256 -> Z_q_256 -MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : Byte <- [0 .. 127]] - where - f_hat_i i = [a@(2*i),a@(2*i+1)] - g_hat_i i = [b@(2*i),b@(2*i+1)] - root i = (zeta^^(reverse (64 + (i >> 1)) >> 1) * ((-1 : (Z q)) ^^ (i))) +MultiplyNTTs : Tq -> Tq -> Tq +MultiplyNTTs f_hat g_hat = join h_hat where + h_hat = [ BaseCaseMultiply + (f_hat @(2*i)) + (f_hat @(2*i+1)) + (g_hat @(2*i)) + (g_hat @(2*i+1)) + (zeta ^^(2 * BitRev7 i + 1)) + | i <- [0 .. 127] ] /** - * Testing that (1+x)^2 = 1+2x+x^2 + * Testing that (1+x)^2 = 1+2x+x^2. * ```repl * :prove TestMult * ``` @@ -719,39 +779,23 @@ property TestMult = prod f f == fsq where f = [1, 1] # [0 | i <- [3 .. 256]] fsq = [1,2,1] # [0 | i <- [4 .. 256]] - prod : Z_q_256 -> Z_q_256 -> Z_q_256 - prod a b = NTTInv' (MultiplyNTTs (NTT' a) (NTT' b)) + prod : Rq -> Rq -> Rq + prod a b = NTTInv (MultiplyNTTs (NTT a) (NTT b)) /** * The cross product notation ×𝑇𝑞 is defined as the `MultiplyNTTs` function * (also referred to as `T_q` multiplication). * [FIPS-203] Section 2.4.5 Equation 2.8. */ -dot : Z_q_256 -> Z_q_256 -> Z_q_256 +dot : Tq -> Tq -> Tq dot f g = MultiplyNTTs f g -/** - * The notation `NTT` is overloaded to mean both a single application of `NTT` - * to an element of `R_q` and also `k` applications of `NTT` to every element - * of a `k`-length vector. - * [FIPS-203] Section 2.4.6 Equation 2.9. - */ -NTT v = map NTT' v - -/** - * The notation `NTTInv` is overloaded to mean both a single application of - * `NTTInv` to an element of `R_q` and also `k` applications of `NTTInv` to - * every element of a `k`-length vector. - * [FIPS-203] Section 2.4.6. - */ -NTTInv v = map NTTInv' v - /** * Overloaded `dot` function between two vectors is a standard dot-product * functionality with `T_q` multiplication as the base operation. * [FIPS-203] Section 2.4.7 Equation 2.14. */ -dotVecVec : {k1} (fin k1) => [k1]Z_q_256 -> [k1]Z_q_256 -> Z_q_256 +dotVecVec : {k1} (fin k1) => [k1]Tq -> [k1]Tq -> Tq dotVecVec v1 v2 = sum (zipWith dot v1 v2) /** @@ -759,7 +803,7 @@ dotVecVec v1 v2 = sum (zipWith dot v1 v2) * vector multiplication with `T_q` multiplication as the base operation. * [FIPS-203] Section 2.4.7 Equation 2.12 and 2.13. */ -dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Z_q_256 -> [k2]Z_q_256 -> [k1]Z_q_256 +dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Tq -> [k2]Tq -> [k1]Tq dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix] /** @@ -768,7 +812,7 @@ dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix] * [FIPS-203] Section 2.4.7. */ dotMatMat :{k1,k2,k3} (fin k1, fin k2, fin k3) => - [k1][k2]Z_q_256 -> [k2][k3]Z_q_256 -> [k1][k3]Z_q_256 + [k1][k2]Tq -> [k2][k3]Tq -> [k1][k3]Tq dotMatMat matrix1 matrix2 = transpose [dotMatVec matrix1 vector | vector <- m'] where m' = transpose matrix2 @@ -821,9 +865,9 @@ private submodule K_PKE where e = [SamplePolyCBD`{eta_1} (PRF σ N) | N <- [k .. 2 * k - 1]] // Step 16. - s_hat = NTT s + s_hat = NTT_Vec s // Step 17. - e_hat = NTT e + e_hat = NTT_Vec e // Step 18. t_hat = (dotMatVec A_hat s_hat) + e_hat // Step 19. @@ -861,13 +905,13 @@ private submodule K_PKE where // value instead. e2 = SamplePolyCBD`{eta_2} (PRF r (2 * `k)) // Step 18. - y_hat = NTT y + y_hat = NTT_Vec y // Step 19. - u = NTTInv (dotMatVec (transpose A_hat) y_hat) + e1 + u = NTTInv_Vec (dotMatVec (transpose A_hat) y_hat) + e1 // Step 20. mu = Decompress'`{1} (DecodeBytes'`{1} m) // Step 21. - v = (NTTInv' (dotVecVec t_hat y_hat)) + e2 + mu + v = (NTTInv (dotVecVec t_hat y_hat)) + e2 + mu // Step 22. c1 = EncodeBytes`{d_u} (Compress`{d_u} u) // Step 23. @@ -897,7 +941,7 @@ private submodule K_PKE where // Step 5. s_hat = Decode`{12} dkPKE // Step 6. - w = v' - NTTInv' (dotVecVec s_hat (NTT u')) + w = v' - NTTInv (dotVecVec s_hat (NTT_Vec u')) // Step 7. m = EncodeBytes'`{1} (Compress'`{1} w)