Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FFT-based specification for Poseidon MDS layer on x86 targets #886

Merged
merged 2 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion field/src/field_testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,22 @@ macro_rules! test_field_arithmetic {

use num::bigint::BigUint;
use rand::rngs::OsRng;
use rand::Rng;
use rand::{Rng, RngCore};
use $crate::types::{Field, Sample};

#[test]
fn modular_reduction() {
let mut rng = OsRng;
for _ in 0..10 {
let x_lo = rng.next_u64();
let x_hi = rng.next_u32();
let x = (x_lo as u128) + ((x_hi as u128) << 64);
let a = <$field>::from_noncanonical_u128(x);
let b = <$field>::from_noncanonical_u96((x_lo, x_hi));
assert_eq!(a, b);
}
}

#[test]
fn batch_inversion() {
for n in 0..20 {
Expand Down
13 changes: 13 additions & 0 deletions field/src/goldilocks_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ impl Field for GoldilocksField {
Self(n)
}

fn from_noncanonical_u96((n_lo, n_hi): (u64, u32)) -> Self {
reduce96((n_lo, n_hi))
}

fn from_noncanonical_u128(n: u128) -> Self {
reduce128(n)
}
Expand Down Expand Up @@ -337,6 +341,15 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
res_wrapped + EPSILON * (carry as u64)
}

/// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the
/// field order and `2^64`.
#[inline]
fn reduce96((x_lo, x_hi): (u64, u32)) -> GoldilocksField {
let t1 = x_hi as u64 * EPSILON;
let t2 = unsafe { add_no_canonicalize_trashing_input(x_lo, t1) };
GoldilocksField(t2)
}

/// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the
/// field order and `2^64`.
#[inline]
Expand Down
172 changes: 172 additions & 0 deletions plonky2/src/hash/poseidon_goldilocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
//! `poseidon_constants.sage` script in the `mir-protocol/hash-constants`
//! repository.

use plonky2_field::types::Field;
use unroll::unroll_for_loops;

use crate::field::goldilocks_field::GoldilocksField;
use crate::hash::poseidon::{Poseidon, N_PARTIAL_ROUNDS};

Expand Down Expand Up @@ -211,6 +214,39 @@ impl Poseidon for GoldilocksField {
0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b, ],
];

#[cfg(target_arch="x86_64")]
#[inline(always)]
#[unroll_for_loops]
fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
let mut result = [GoldilocksField::ZERO; 12];

// Using the linearity of the operations we can split the state into a low||high decomposition
// and operate on each with no overflow and then combine/reduce the result to a field element.
let mut state_l = [0u64; 12];
let mut state_h = [0u64; 12];

for r in 0..12 {
let s = state[r].0;
state_h[r] = s >> 32;
state_l[r] = (s as u32) as u64;
}

let state_h = mds_multiply_freq(state_h);
let state_l = mds_multiply_freq(state_l);

for r in 0..12 {
let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);

result[r] = GoldilocksField::from_noncanonical_u96((s as u64, (s >> 64) as u32));
}

// Add first element with the only non-zero diagonal matrix coefficient.
let s = Self::MDS_MATRIX_DIAG[0] as u128 * (state[0].0 as u128);
result[0] += GoldilocksField::from_noncanonical_u96((s as u64, (s >> 64) as u32));

result
}

// #[cfg(all(target_arch="x86_64", target_feature="avx2", target_feature="bmi2"))]
// #[inline]
// fn poseidon(input: [Self; 12]) -> [Self; 12] {
Expand Down Expand Up @@ -268,6 +304,142 @@ impl Poseidon for GoldilocksField {
}
}

// MDS layer helper methods
// The following code has been adapted from winterfell/crypto/src/hash/mds/mds_f64_12x12.rs
// located at https://github.com/facebook/winterfell.

const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 32, 16];
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(2, -1), (-4, 1), (16, 1)];
const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-1, -8, 2];

/// Split 3 x 4 FFT-based MDS vector-multiplication with the Poseidon circulant MDS matrix.
#[inline(always)]
fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;

let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
let (u4, u5, u6) = fft4_real([s1, s4, s7, s10]);
let (u8, u9, u10) = fft4_real([s2, s5, s8, s11]);

// This where the multiplication in frequency domain is done. More precisely, and with
// the appropriate permuations in between, the sequence of
// 3-point FFTs --> multiplication by twiddle factors --> Hadamard multiplication -->
// 3 point iFFTs --> multiplication by (inverse) twiddle factors
// is "squashed" into one step composed of the functions "block1", "block2" and "block3".
// The expressions in the aforementioned functions are the result of explicit computations
// combined with the Karatsuba trick for the multiplication of complex numbers.

let [v0, v4, v8] = block1([u0, u4, u8], MDS_FREQ_BLOCK_ONE);
let [v1, v5, v9] = block2([u1, u5, u9], MDS_FREQ_BLOCK_TWO);
let [v2, v6, v10] = block3([u2, u6, u10], MDS_FREQ_BLOCK_THREE);
// The 4th block is not computed as it is similar to the 2nd one, up to complex conjugation.

let [s0, s3, s6, s9] = ifft4_real_unreduced((v0, v1, v2));
let [s1, s4, s7, s10] = ifft4_real_unreduced((v4, v5, v6));
let [s2, s5, s8, s11] = ifft4_real_unreduced((v8, v9, v10));

[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
}

/// Real 2-FFT over u64 integers.
#[inline(always)]
fn fft2_real(x: [u64; 2]) -> [i64; 2] {
[(x[0] as i64 + x[1] as i64), (x[0] as i64 - x[1] as i64)]
}

/// Real 2-iFFT over u64 integers.
/// Division by two to complete the inverse FFT is not performed here.
#[inline(always)]
fn ifft2_real_unreduced(y: [i64; 2]) -> [u64; 2] {
[(y[0] + y[1]) as u64, (y[0] - y[1]) as u64]
}

/// Real 4-FFT over u64 integers.
#[inline(always)]
fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) {
let [z0, z2] = fft2_real([x[0], x[2]]);
let [z1, z3] = fft2_real([x[1], x[3]]);
let y0 = z0 + z1;
let y1 = (z2, -z3);
let y2 = z0 - z1;
(y0, y1, y2)
}

/// Real 4-iFFT over u64 integers.
/// Division by four to complete the inverse FFT is not performed here.
#[inline(always)]
fn ifft4_real_unreduced(y: (i64, (i64, i64), i64)) -> [u64; 4] {
let z0 = y.0 + y.2;
let z1 = y.0 - y.2;
let z2 = y.1 .0;
let z3 = -y.1 .1;

let [x0, x2] = ifft2_real_unreduced([z0, z2]);
let [x1, x3] = ifft2_real_unreduced([z1, z3]);

[x0, x1, x2, x3]
}

#[inline(always)]
fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
let [x0, x1, x2] = x;
let [y0, y1, y2] = y;
let z0 = x0 * y0 + x1 * y2 + x2 * y1;
let z1 = x0 * y1 + x1 * y0 + x2 * y2;
let z2 = x0 * y2 + x1 * y1 + x2 * y0;

[z0, z1, z2]
}

#[inline(always)]
fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] {
let [(x0r, x0i), (x1r, x1i), (x2r, x2i)] = x;
let [(y0r, y0i), (y1r, y1i), (y2r, y2i)] = y;
let x0s = x0r + x0i;
let x1s = x1r + x1i;
let x2s = x2r + x2i;
let y0s = y0r + y0i;
let y1s = y1r + y1i;
let y2s = y2r + y2i;

// Compute x0​y0 ​− ix1​y2​ − ix2​y1​ using Karatsuba for complex numbers multiplication
let m0 = (x0r * y0r, x0i * y0i);
let m1 = (x1r * y2r, x1i * y2i);
let m2 = (x2r * y1r, x2i * y1i);
let z0r = (m0.0 - m0.1) + (x1s * y2s - m1.0 - m1.1) + (x2s * y1s - m2.0 - m2.1);
let z0i = (x0s * y0s - m0.0 - m0.1) + (-m1.0 + m1.1) + (-m2.0 + m2.1);
let z0 = (z0r, z0i);

// Compute x0​y1​ + x1​y0​ − ix2​y2 using Karatsuba for complex numbers multiplication
let m0 = (x0r * y1r, x0i * y1i);
let m1 = (x1r * y0r, x1i * y0i);
let m2 = (x2r * y2r, x2i * y2i);
let z1r = (m0.0 - m0.1) + (m1.0 - m1.1) + (x2s * y2s - m2.0 - m2.1);
let z1i = (x0s * y1s - m0.0 - m0.1) + (x1s * y0s - m1.0 - m1.1) + (-m2.0 + m2.1);
let z1 = (z1r, z1i);

// Compute x0​y2​ + x1​y1 ​+ x2​y0​ using Karatsuba for complex numbers multiplication
let m0 = (x0r * y2r, x0i * y2i);
let m1 = (x1r * y1r, x1i * y1i);
let m2 = (x2r * y0r, x2i * y0i);
let z2r = (m0.0 - m0.1) + (m1.0 - m1.1) + (m2.0 - m2.1);
let z2i = (x0s * y2s - m0.0 - m0.1) + (x1s * y1s - m1.0 - m1.1) + (x2s * y0s - m2.0 - m2.1);
let z2 = (z2r, z2i);

[z0, z1, z2]
}

#[inline(always)]
fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
let [x0, x1, x2] = x;
let [y0, y1, y2] = y;
let z0 = x0 * y0 - x1 * y2 - x2 * y1;
let z1 = x0 * y1 + x1 * y0 - x2 * y2;
let z2 = x0 * y2 + x1 * y1 + x2 * y0;

[z0, z1, z2]
}

#[cfg(test)]
mod tests {
use crate::field::goldilocks_field::GoldilocksField as F;
Expand Down