Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-KEM: Improve compression and byte conversion functions #161

Merged
merged 8 commits into from
Nov 1, 2024
177 changes: 140 additions & 37 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -115,84 +115,187 @@ XOF : ([34]Byte) -> [inf]Byte
XOF(d) = groupBy (SHAKE128::xof (join d))

/**
* Conversion from bit arrays to byte arrays.
* Conversion from little-endian bit arrays to byte arrays.
* [FIPS-203] Section 4.2.1, Algorithm 3.
*/
BitsToBytes : {ell} (fin ell, ell > 0) => [ell*8]Bit -> [ell]Byte
BitsToBytes input = map reverse (groupBy input)
BitsToBytes : {ell} (fin ell) => [8 * ell]Bit -> [ell]Byte
BitsToBytes b
| ell == 0 => zero
| ell > 0 => B where
// Group the bits into the B[⌊i / 8⌋] sets; pad them to support
// subsequent operations, and correlate each bit with its index `i`.
b' = groupBy`{8} [(zext [bi], i)
| bi <- b
| i <- [0..8 * ell - 1]]

// Steps 2-4.
B = [sum [bi * (2 ^^ (i % 8))
| (bi, i) <- bi8]
| bi8 <- b']

/**
* Conversion from byte arrays to bit arrays.
* [FIPS-203] Section 4.2.1, Algorithm 4.
*/
BytesToBits : {ell} (fin ell, ell > 0) => [ell]Byte -> [ell*8]Bit
BytesToBits input = join (map reverse input)
BytesToBits : {ell} (fin ell) => [ell]Byte -> [ell*8]Bit
BytesToBits C
| ell == 0 => []
| ell > 0 => join [[ b8ij where
// Step 4. Taking the last bit is the same as modding by 2. (See
// `mod2IsFinalBit`).
b8ij = Ci' ! 0
marsella marked this conversation as resolved.
Show resolved Hide resolved
// Step 5. Shifting right is the same as the iterative
// division (see `div2IsShiftR`). This accounts for all the
// divisions "up to this point" (e.g. none when `j = 0`), which
// is why we use `Ci'` to evaluate `b8ij` above.
Ci' = Ci >> j
// Step 3.
| j <- [0..7]]
// Step 2. We iterate over `C` directly instead of indexing into it.
| Ci <- C ]

// In Cryptol, rounding is computed via the built-in function roundAway
property rounding = ((roundAway(1.5) == 2) && (roundAway(1.4) == 1))
/**
* The iterative division by 2 in `BytesToBits` is the same as shifting right.
* ```repl
* :prove div2IsShiftR
* ```
*/
div2IsShiftR : Byte -> Bit
div2IsShiftR C = take (d2 C) == shl where
// Note: division here is floor'd by default.
d2 c = [c] # d2 (c / 2)
shl = [C >> j | j <- [0..7]]

/**
* Compression from an integer mod `q` to an integer mod `2^d`.
* The conversions between bits and bytes are each others' inverses.
* [FIPS-203] Section 4.2.1 (see description on Algorithm 4).
* The sample `ell` values here are a subset of the possible values in the spec.
* ```repl
* :prove B2B2BInverts`{32}
* :prove B2B2BInverts`{192}
* :prove B2B2BInverts`{384}
* ```
*/
B2B2BInverts : {ell} (fin ell, ell > 0) => [ell * 8] -> Bit
B2B2BInverts bits = bitsWorks && bytesWorks where
bitsWorks = BytesToBits (BitsToBytes bits) == bits
bytesWorks = BitsToBytes (BytesToBits (split bits)) == split bits

/**
* This currently fails due to endianness issues!
* Check the example given in the spec for converting between bits and bytes.
* [FIPS-203] Section 4.2.1 "Converting between bits and bytes."
* ```repl
* :prove B2BExampleWorks
* ```
*/
B2BExampleWorks = BitsToBytes 0b11010001 == [139]

/**
* In Cryptol, rounding is computed via the built-in function `roundAway`.
* [FIPS-203] Section 2.3.
*/
property roundingWorks y = y >= 0 ==> roundUpWorks && roundDownWorks where
y' = fromInteger y
roundUpWorks = roundAway (y' + 0.5) == (y + 1)
roundDownWorks = roundAway (y' + 0.4) == y

/**
* Compress an integer mod `q` into an integer mod `2^d`.
* [FIPS-203] Section 4.2.1, Equation 4.7.
*/
Compress'' : {d} (d < lg2 q) => Z q -> [d]
Compress'' x = fromInteger(roundAway(((2^^`d)/.`q) * fromInteger(fromZ(x))) % 2^^`d)
Compress : {d} (d < width q) => Z q -> [d]
Compress x = y where
// Convert from an integer mod `q` to a rational number.
x' = fromInteger (fromZ x) : Rational
// Compress. Note that `/.` denotes division of rationals.
y' = roundAway (((2^^`d) /. `q) * x')
// mod 2^^d (by converting from an integer to a d-bit vector).
y = (fromInteger y') : [d]

/**
* Decompression from an integer mod `2^d` to an integer mod `q`.
* Decompress an integer mod `2^d` into an integer mod `q`.
* [FIPS-203] Section 4.2.1, Equation 4.8.
*/
Decompress'' : {d} (d < lg2 q) => [d] -> Z q
Decompress'' x = fromInteger(roundAway(((`q)/.(2^^`d))*fromInteger(toInteger(x))))
Decompress : {d} (d < width q) => [d] -> Z q
Decompress y = x where
// Convert from a d-length bit vector to a rational number.
y' = fromInteger (toInteger y) : Rational
// Decompress! As before, `/.` is division of rationals.
x' = roundAway((`q /. (2^^`d)) * y')
// Convert from an integer to an integer mod `q`.
x = (fromInteger x') : Z q

/**
* Compression inverts decompression for all inputs and bit lengths.
* We'll prove it for the bit lengths found in the
* ```repl
* :prove CompressInvertsDecompress`{1}
* :exhaust CompressInvertsDecompress`{d_u}
* :exhaust CompressInvertsDecompress`{d_v}
* ```
*/
CompressInvertsDecompress : {d} (d < width q) => [d] -> Bit
property CompressInvertsDecompress y = Compress (Decompress y) == y

/**
* When `d` is large, compression followed by decompression must not
* significantly alter the value.
* This sets `d = d_u`, which is the largest value for `d` used in the
* spec.
* [FIPS-203] Section 4.2.1, "Compression and Decompression".
* ```repl
* :exhaust DecompressMostlyInvertsCompress
* ```
*/
CorrectnessCompress : Z q -> Bit
property CorrectnessCompress x = err <= B_q`{d_u} where
x' = Decompress''`{d_u}(Compress''`{d_u}(x))
err = abs(modpm(x'-x))

DecompressMostlyInvertsCompress : Z q -> Bit
property DecompressMostlyInvertsCompress x = errIsSmallEnough where
x' = Decompress`{d_u} (Compress`{d_u} x)
err = abs (modpm (x' - x))
errIsSmallEnough = err <= B_q`{d_u}

// The spec doesn't describe formally what "not significantly altered"
// means; we use this equation.
B_q : {d} (d < lg2 q) => Integer
B_q = roundAway((`q/.(2^^(`d+1))))

modpm : {alpha} (fin alpha, alpha > 0) => Z alpha -> Integer
modpm r = if r' > (`alpha / 2) then r' - `alpha else r'
where r' = fromZ(r)
// Convert an integer mod `q` to a representation centered around 0
// (and represented as an `Integer`).
modpm : Z q -> Integer
modpm r = if r' > (`q / 2) then r' - `q else r'
where r' = fromZ r

/**
* Compression applied to a vector is equivalent to applying compression to
* each individual element.
* [FIPS-203] Section 2.4.8, Equation 2.15.
*/
Compress' : {d} (d < lg2 q) => Z_q_256 -> [n][d]
Compress' x = map Compress''`{d} x
Compress_Vec : {d} (d < lg2 q) => Z_q_256 -> [n][d]
Compress_Vec x = map Compress`{d} x

/**
* Decompression applied to a vector is equivalent to applying decompression to
* each individual element.
* [FIPS-203] Section 2.4.8.
*/
Decompress' : {d} (d < lg2 q) => [n][d] -> Z_q_256
Decompress' x = map Decompress''`{d} x
Decompress_Vec : {d} (d < lg2 q) => [n][d] -> Z_q_256
Decompress_Vec x = map Decompress`{d} x

/**
* Compression applied to an array is equivalent to applying compression to
* Compression applied to a matrix is equivalent to applying compression to
* each individual element.
* [FIPS-203] Section 2.4.8.
*/
Compress : {d, k1} (d < lg2 q, fin k1) => [k1]Z_q_256 -> [k1][n][d]
Compress x = map Compress'`{d} x
Compress_Mat : {d, k1} (d < lg2 q, fin k1) => [k1]Z_q_256 -> [k1][n][d]
Compress_Mat x = map Compress_Vec`{d} x

/**
* Decompression applied to an array is equivalent to applying decompression to
* Decompression applied to a matrix is equivalent to applying decompression to
* each individual element.
* [FIPS-203] Section 2.4.8.
*/
Decompress : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256
Decompress x = map Decompress'`{d} x
Decompress_Mat : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256
Decompress_Mat x = map Decompress_Vec`{d} x

private
/**
Expand Down Expand Up @@ -964,13 +1067,13 @@ private submodule K_PKE where
// Step 19.
u = NTTInv (dotMatVec (transpose A_hat) y_hat) + e1
// Step 20.
mu = Decompress'`{1} (ByteDecode`{1} m)
mu = Decompress_Vec`{1} (ByteDecode`{1} m)
// Step 21.
v = (NTTInv' (dotVecVec t_hat y_hat)) + e2 + mu
// Step 22.
c1 = ByteEncode_Vec`{d_u} (Compress`{d_u} u)
c1 = ByteEncode_Vec`{d_u} (Compress_Mat`{d_u} u)
// Step 23.
c2 = ByteEncode`{d_v} (Compress'`{d_v} v)
c2 = ByteEncode`{d_v} (Compress_Vec`{d_v} v)
// Step 24.
c = c1 # c2

Expand All @@ -990,15 +1093,15 @@ private submodule K_PKE where
// Step 2.
c2 = c @@[32 * d_u * k .. 32 * (d_u * k + d_v) - 1]
// Step 3.j
u' = Decompress`{d_u} (ByteDecode_Vec`{d_u} c1)
u' = Decompress_Mat`{d_u} (ByteDecode_Vec`{d_u} c1)
// Step 4.
v' = Decompress'`{d_v} (ByteDecode`{d_v} c2)
v' = Decompress_Vec`{d_v} (ByteDecode`{d_v} c2)
// Step 5.
s_hat = ByteDecode12_Vec dkPKE
// Step 6.
w = v' - NTTInv' (dotVecVec s_hat (NTT u'))
// Step 7.
m = ByteEncode`{1} (Compress'`{1} w)
m = ByteEncode`{1} (Compress_Vec`{1} w)

/**
* The K-PKE scheme must satisfy the basic properties of an encryption
Expand Down