diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index e772ca1..72c6298 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -115,84 +115,187 @@ XOF : ([34]Byte) -> [inf]Byte XOF(d) = groupBy (SHAKE128::xof (join d)) /** - * Conversion from bit arrays to byte arrays. + * Conversion from little-endian bit arrays to byte arrays. * [FIPS-203] Section 4.2.1, Algorithm 3. */ -BitsToBytes : {ell} (fin ell, ell > 0) => [ell*8]Bit -> [ell]Byte -BitsToBytes input = map reverse (groupBy input) +BitsToBytes : {ell} (fin ell) => [8 * ell]Bit -> [ell]Byte +BitsToBytes b + | ell == 0 => zero + | ell > 0 => B where + // Group the bits into the B[⌊i / 8⌋] sets; pad them to support + // subsequent operations, and correlate each bit with its index `i`. + b' = groupBy`{8} [(zext [bi], i) + | bi <- b + | i <- [0..8 * ell - 1]] + + // Steps 2-4. + B = [sum [bi * (2 ^^ (i % 8)) + | (bi, i) <- bi8] + | bi8 <- b'] /** * Conversion from byte arrays to bit arrays. * [FIPS-203] Section 4.2.1, Algorithm 4. */ -BytesToBits : {ell} (fin ell, ell > 0) => [ell]Byte -> [ell*8]Bit -BytesToBits input = join (map reverse input) +BytesToBits : {ell} (fin ell) => [ell]Byte -> [ell*8]Bit +BytesToBits C + | ell == 0 => [] + | ell > 0 => join [[ b8ij where + // Step 4. Taking the last bit is the same as modding by 2. (See + // `mod2IsFinalBit`). + b8ij = Ci' ! 0 + // Step 5. Shifting right is the same as the iterative + // division (see `div2IsShiftR`). This accounts for all the + // divisions "up to this point" (e.g. none when `j = 0`), which + // is why we use `Ci'` to evaluate `b8ij` above. + Ci' = Ci >> j + // Step 3. + | j <- [0..7]] + // Step 2. We iterate over `C` directly instead of indexing into it. + | Ci <- C ] -// In Cryptol, rounding is computed via the built-in function roundAway -property rounding = ((roundAway(1.5) == 2) && (roundAway(1.4) == 1)) +/** + * The iterative division by 2 in `BytesToBits` is the same as shifting right. + * ```repl + * :prove div2IsShiftR + * ``` + */ +div2IsShiftR : Byte -> Bit +div2IsShiftR C = take (d2 C) == shl where + // Note: division here is floor'd by default. + d2 c = [c] # d2 (c / 2) + shl = [C >> j | j <- [0..7]] /** - * Compression from an integer mod `q` to an integer mod `2^d`. + * The conversions between bits and bytes are each others' inverses. + * [FIPS-203] Section 4.2.1 (see description on Algorithm 4). + * The sample `ell` values here are a subset of the possible values in the spec. + * ```repl + * :prove B2B2BInverts`{32} + * :prove B2B2BInverts`{192} + * :prove B2B2BInverts`{384} + * ``` + */ +B2B2BInverts : {ell} (fin ell, ell > 0) => [ell * 8] -> Bit +B2B2BInverts bits = bitsWorks && bytesWorks where + bitsWorks = BytesToBits (BitsToBytes bits) == bits + bytesWorks = BitsToBytes (BytesToBits (split bits)) == split bits + +/** + * This currently fails due to endianness issues! + * Check the example given in the spec for converting between bits and bytes. + * [FIPS-203] Section 4.2.1 "Converting between bits and bytes." + * ```repl + * :prove B2BExampleWorks + * ``` + */ +B2BExampleWorks = BitsToBytes 0b11010001 == [139] + +/** + * In Cryptol, rounding is computed via the built-in function `roundAway`. + * [FIPS-203] Section 2.3. + */ +property roundingWorks y = y >= 0 ==> roundUpWorks && roundDownWorks where + y' = fromInteger y + roundUpWorks = roundAway (y' + 0.5) == (y + 1) + roundDownWorks = roundAway (y' + 0.4) == y + +/** + * Compress an integer mod `q` into an integer mod `2^d`. * [FIPS-203] Section 4.2.1, Equation 4.7. */ -Compress'' : {d} (d < lg2 q) => Z q -> [d] -Compress'' x = fromInteger(roundAway(((2^^`d)/.`q) * fromInteger(fromZ(x))) % 2^^`d) +Compress : {d} (d < width q) => Z q -> [d] +Compress x = y where + // Convert from an integer mod `q` to a rational number. + x' = fromInteger (fromZ x) : Rational + // Compress. Note that `/.` denotes division of rationals. + y' = roundAway (((2^^`d) /. `q) * x') + // mod 2^^d (by converting from an integer to a d-bit vector). + y = (fromInteger y') : [d] /** - * Decompression from an integer mod `2^d` to an integer mod `q`. + * Decompress an integer mod `2^d` into an integer mod `q`. * [FIPS-203] Section 4.2.1, Equation 4.8. */ -Decompress'' : {d} (d < lg2 q) => [d] -> Z q -Decompress'' x = fromInteger(roundAway(((`q)/.(2^^`d))*fromInteger(toInteger(x)))) +Decompress : {d} (d < width q) => [d] -> Z q +Decompress y = x where + // Convert from a d-length bit vector to a rational number. + y' = fromInteger (toInteger y) : Rational + // Decompress! As before, `/.` is division of rationals. + x' = roundAway((`q /. (2^^`d)) * y') + // Convert from an integer to an integer mod `q`. + x = (fromInteger x') : Z q + +/** + * Compression inverts decompression for all inputs and bit lengths. + * We'll prove it for the bit lengths found in the + * ```repl + * :prove CompressInvertsDecompress`{1} + * :exhaust CompressInvertsDecompress`{d_u} + * :exhaust CompressInvertsDecompress`{d_v} + * ``` + */ +CompressInvertsDecompress : {d} (d < width q) => [d] -> Bit +property CompressInvertsDecompress y = Compress (Decompress y) == y /** * When `d` is large, compression followed by decompression must not * significantly alter the value. + * This sets `d = d_u`, which is the largest value for `d` used in the + * spec. * [FIPS-203] Section 4.2.1, "Compression and Decompression". + * ```repl + * :exhaust DecompressMostlyInvertsCompress + * ``` */ -CorrectnessCompress : Z q -> Bit -property CorrectnessCompress x = err <= B_q`{d_u} where - x' = Decompress''`{d_u}(Compress''`{d_u}(x)) - err = abs(modpm(x'-x)) - +DecompressMostlyInvertsCompress : Z q -> Bit +property DecompressMostlyInvertsCompress x = errIsSmallEnough where + x' = Decompress`{d_u} (Compress`{d_u} x) + err = abs (modpm (x' - x)) + errIsSmallEnough = err <= B_q`{d_u} + + // The spec doesn't describe formally what "not significantly altered" + // means; we use this equation. B_q : {d} (d < lg2 q) => Integer B_q = roundAway((`q/.(2^^(`d+1)))) - modpm : {alpha} (fin alpha, alpha > 0) => Z alpha -> Integer - modpm r = if r' > (`alpha / 2) then r' - `alpha else r' - where r' = fromZ(r) + // Convert an integer mod `q` to a representation centered around 0 + // (and represented as an `Integer`). + modpm : Z q -> Integer + modpm r = if r' > (`q / 2) then r' - `q else r' + where r' = fromZ r /** * Compression applied to a vector is equivalent to applying compression to * each individual element. * [FIPS-203] Section 2.4.8, Equation 2.15. */ -Compress' : {d} (d < lg2 q) => Z_q_256 -> [n][d] -Compress' x = map Compress''`{d} x +Compress_Vec : {d} (d < lg2 q) => Z_q_256 -> [n][d] +Compress_Vec x = map Compress`{d} x /** * Decompression applied to a vector is equivalent to applying decompression to * each individual element. * [FIPS-203] Section 2.4.8. */ -Decompress' : {d} (d < lg2 q) => [n][d] -> Z_q_256 -Decompress' x = map Decompress''`{d} x +Decompress_Vec : {d} (d < lg2 q) => [n][d] -> Z_q_256 +Decompress_Vec x = map Decompress`{d} x /** - * Compression applied to an array is equivalent to applying compression to + * Compression applied to a matrix is equivalent to applying compression to * each individual element. * [FIPS-203] Section 2.4.8. */ -Compress : {d, k1} (d < lg2 q, fin k1) => [k1]Z_q_256 -> [k1][n][d] -Compress x = map Compress'`{d} x +Compress_Mat : {d, k1} (d < lg2 q, fin k1) => [k1]Z_q_256 -> [k1][n][d] +Compress_Mat x = map Compress_Vec`{d} x /** - * Decompression applied to an array is equivalent to applying decompression to + * Decompression applied to a matrix is equivalent to applying decompression to * each individual element. * [FIPS-203] Section 2.4.8. */ -Decompress : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256 -Decompress x = map Decompress'`{d} x +Decompress_Mat : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256 +Decompress_Mat x = map Decompress_Vec`{d} x private /** @@ -964,13 +1067,13 @@ private submodule K_PKE where // Step 19. u = NTTInv (dotMatVec (transpose A_hat) y_hat) + e1 // Step 20. - mu = Decompress'`{1} (ByteDecode`{1} m) + mu = Decompress_Vec`{1} (ByteDecode`{1} m) // Step 21. v = (NTTInv' (dotVecVec t_hat y_hat)) + e2 + mu // Step 22. - c1 = ByteEncode_Vec`{d_u} (Compress`{d_u} u) + c1 = ByteEncode_Vec`{d_u} (Compress_Mat`{d_u} u) // Step 23. - c2 = ByteEncode`{d_v} (Compress'`{d_v} v) + c2 = ByteEncode`{d_v} (Compress_Vec`{d_v} v) // Step 24. c = c1 # c2 @@ -990,15 +1093,15 @@ private submodule K_PKE where // Step 2. c2 = c @@[32 * d_u * k .. 32 * (d_u * k + d_v) - 1] // Step 3.j - u' = Decompress`{d_u} (ByteDecode_Vec`{d_u} c1) + u' = Decompress_Mat`{d_u} (ByteDecode_Vec`{d_u} c1) // Step 4. - v' = Decompress'`{d_v} (ByteDecode`{d_v} c2) + v' = Decompress_Vec`{d_v} (ByteDecode`{d_v} c2) // Step 5. s_hat = ByteDecode12_Vec dkPKE // Step 6. w = v' - NTTInv' (dotVecVec s_hat (NTT u')) // Step 7. - m = ByteEncode`{1} (Compress'`{1} w) + m = ByteEncode`{1} (Compress_Vec`{1} w) /** * The K-PKE scheme must satisfy the basic properties of an encryption