From 94db00cd34f55eece45f874cbe432845ddf252c4 Mon Sep 17 00:00:00 2001 From: einar-taiko <126954546+einar-taiko@users.noreply.github.com> Date: Tue, 30 May 2023 04:21:07 +0900 Subject: [PATCH] Resolve Compare FFT implementations #62 (#4) * FFT opt * Implement all suggestions from 2nd review * Default to `parallel` * Fix missed autoformatings --------- Co-authored-by: Brechtpd --- halo2_proofs/Cargo.toml | 1 + halo2_proofs/benches/fft.rs | 17 +- halo2_proofs/src/arithmetic.rs | 166 +++---- halo2_proofs/src/fft.rs | 235 ++++++++++ halo2_proofs/src/fft/baseline.rs | 128 +++++ halo2_proofs/src/fft/parallel.rs | 287 ++++++++++++ halo2_proofs/src/fft/recursive.rs | 467 +++++++++++++++++++ halo2_proofs/src/lib.rs | 1 + halo2_proofs/src/plonk.rs | 97 ++++ halo2_proofs/src/plonk/circuit.rs | 10 + halo2_proofs/src/plonk/evaluation.rs | 10 +- halo2_proofs/src/plonk/permutation/prover.rs | 2 +- halo2_proofs/src/plonk/prover.rs | 51 +- halo2_proofs/src/poly/domain.rs | 68 ++- 14 files changed, 1411 insertions(+), 129 deletions(-) create mode 100644 halo2_proofs/src/fft.rs create mode 100644 halo2_proofs/src/fft/baseline.rs create mode 100644 halo2_proofs/src/fft/parallel.rs create mode 100644 halo2_proofs/src/fft/recursive.rs diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index 113e3e8858..7f8d8652d8 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -58,6 +58,7 @@ tracing = "0.1" blake2b_simd = "1" sha3 = "0.9.1" rand_chacha = "0.3" +ark-std = { version = "0.3", features = ["print-trace"] } # Developer tooling dependencies plotters = { version = "0.3.0", optional = true } diff --git a/halo2_proofs/benches/fft.rs b/halo2_proofs/benches/fft.rs index 0de72a0380..89217e8d8b 100644 --- a/halo2_proofs/benches/fft.rs +++ b/halo2_proofs/benches/fft.rs @@ -1,22 +1,27 @@ #[macro_use] extern crate criterion; -use crate::arithmetic::best_fft; +use halo2_proofs::{arithmetic::best_fft, poly::EvaluationDomain}; use group::ff::Field; -use halo2_proofs::*; -use halo2curves::pasta::Fp; +use halo2curves::bn256::Fr as Scalar; use criterion::{BenchmarkId, Criterion}; use rand_core::OsRng; fn criterion_benchmark(c: &mut Criterion) { + let j = 5; let mut group = c.benchmark_group("fft"); for k in 3..19 { + let domain = EvaluationDomain::new(j,k); + let omega = domain.get_omega(); + let l = 1<>(); - let omega = Fp::random(OsRng); // would be weird if this mattered + let mut a = (0..(1 << k)).map(|_| Scalar::random(OsRng)).collect::>(); + b.iter(|| { - best_fft(&mut a, omega, k as u32); + best_fft(&mut a, omega, k as u32, data, false); }); }); } diff --git a/halo2_proofs/src/arithmetic.rs b/halo2_proofs/src/arithmetic.rs index a53b541b62..04abc20fb0 100644 --- a/halo2_proofs/src/arithmetic.rs +++ b/halo2_proofs/src/arithmetic.rs @@ -10,6 +10,16 @@ use group::{ pub use halo2curves::{CurveAffine, CurveExt}; +use crate::{ + fft::{ + self, parallel, + recursive::{self, FFTData}, + }, + plonk::{get_duration, get_time, log_info}, + poly::EvaluationDomain, +}; +use std::{env, mem}; + /// This represents an element of a group with basic operations that can be /// performed. This allows an FFT implementation (for example) to operate /// generically over either a field or elliptic curve group. @@ -25,6 +35,9 @@ where { } +/// TEMP +pub static mut MULTIEXP_TOTAL_TIME: usize = 0; + fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); @@ -147,8 +160,11 @@ pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::C pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); + log_info(format!("msm: {}", coeffs.len())); + + let start = get_time(); let num_threads = multicore::current_num_threads(); - if coeffs.len() > num_threads { + let res = if coeffs.len() > num_threads { let chunk = coeffs.len() / num_threads; let num_chunks = coeffs.chunks(chunk).len(); let mut results = vec![C::Curve::identity(); num_chunks]; @@ -170,134 +186,48 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu let mut acc = C::Curve::identity(); multiexp_serial(coeffs, bases, &mut acc); acc - } -} - -/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size -/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative -/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when -/// interpreted as the coefficients of a polynomial of degree $n - 1$, is -/// transformed into the evaluations of this polynomial at each of the $n$ -/// distinct powers of $\omega$. This transformation is invertible by providing -/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element -/// by $n$. -/// -/// This will use multithreading if beneficial. -pub fn best_fft>(a: &mut [G], omega: Scalar, log_n: u32) { - fn bitreverse(mut n: usize, l: usize) -> usize { - let mut r = 0; - for _ in 0..l { - r = (r << 1) | (n & 1); - n >>= 1; - } - r - } - - let threads = multicore::current_num_threads(); - let log_threads = log2_floor(threads); - let n = a.len() as usize; - assert_eq!(n, 1 << log_n); + }; - for k in 0..n { - let rk = bitreverse(k, log_n as usize); - if k < rk { - a.swap(rk, k); - } + let duration = get_duration(start); + #[allow(unsafe_code)] + unsafe { + crate::arithmetic::MULTIEXP_TOTAL_TIME += duration; } - // precompute twiddle factors - let twiddles: Vec<_> = (0..(n / 2) as usize) - .scan(Scalar::ONE, |w, _| { - let tw = *w; - *w *= ω - Some(tw) - }) - .collect(); - - if log_n <= log_threads { - let mut chunk = 2_usize; - let mut twiddle_chunk = (n / 2) as usize; - for _ in 0..log_n { - a.chunks_mut(chunk).for_each(|coeffs| { - let (left, right) = coeffs.split_at_mut(chunk / 2); - - // case when twiddle factor is one - let (a, left) = left.split_at_mut(1); - let (b, right) = right.split_at_mut(1); - let t = b[0]; - b[0] = a[0]; - a[0] += &t; - b[0] -= &t; - - left.iter_mut() - .zip(right.iter_mut()) - .enumerate() - .for_each(|(i, (a, b))| { - let mut t = *b; - t *= &twiddles[(i + 1) * twiddle_chunk]; - *b = *a; - *a += &t; - *b -= &t; - }); - }); - chunk *= 2; - twiddle_chunk /= 2; - } - } else { - recursive_butterfly_arithmetic(a, n, 1, &twiddles) - } + res } -/// This perform recursive butterfly arithmetic -pub fn recursive_butterfly_arithmetic>( +/// Dispatcher +pub fn best_fft>( a: &mut [G], - n: usize, - twiddle_chunk: usize, - twiddles: &[Scalar], + omega: Scalar, + log_n: u32, + data: &FFTData, + inverse: bool, ) { - if n == 2 { - let t = a[1]; - a[1] = a[0]; - a[0] += &t; - a[1] -= &t; - } else { - let (left, right) = a.split_at_mut(n / 2); - rayon::join( - || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), - || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), - ); - - // case when twiddle factor is one - let (a, left) = left.split_at_mut(1); - let (b, right) = right.split_at_mut(1); - let t = b[0]; - b[0] = a[0]; - a[0] += &t; - b[0] -= &t; - - left.iter_mut() - .zip(right.iter_mut()) - .enumerate() - .for_each(|(i, (a, b))| { - let mut t = *b; - t *= &twiddles[(i + 1) * twiddle_chunk]; - *b = *a; - *a += &t; - *b -= &t; - }); - } + fft::fft(a, omega, log_n, data, inverse); } /// Convert coefficient bases group elements to lagrange basis by inverse FFT. pub fn g_to_lagrange(g_projective: Vec, k: u32) -> Vec { let n_inv = C::Scalar::TWO_INV.pow_vartime(&[k as u64, 0, 0, 0]); + let omega = C::Scalar::ROOT_OF_UNITY; let mut omega_inv = C::Scalar::ROOT_OF_UNITY_INV; for _ in k..C::Scalar::S { omega_inv = omega_inv.square(); } let mut g_lagrange_projective = g_projective; - best_fft(&mut g_lagrange_projective, omega_inv, k); + let n = g_lagrange_projective.len(); + let fft_data = FFTData::new(n, omega, omega_inv); + + best_fft( + &mut g_lagrange_projective, + omega_inv, + k, + &fft_data, + false, + ); parallelize(&mut g_lagrange_projective, |g, _| { for g in g.iter_mut() { *g *= n_inv; @@ -402,7 +332,8 @@ pub fn parallelize(v: &mu }); } -fn log2_floor(num: usize) -> u32 { +/// Compute the binary logarithm floored. +pub fn log2_floor(num: usize) -> u32 { assert!(num > 0); let mut pow = 0; @@ -496,7 +427,18 @@ pub(crate) fn powers(base: F) -> impl Iterator { std::iter::successors(Some(F::ONE), move |power| Some(base * power)) } +/// Reverse `l` LSBs of bitvector `n` +pub fn bitreverse(mut n: usize, l: usize) -> usize { + let mut r = 0; + for _ in 0..l { + r = (r << 1) | (n & 1); + n >>= 1; + } + r +} + #[cfg(test)] +use crate::plonk::{start_measure, stop_measure}; use rand_core::OsRng; #[cfg(test)] diff --git a/halo2_proofs/src/fft.rs b/halo2_proofs/src/fft.rs new file mode 100644 index 0000000000..d4e2ef3ca9 --- /dev/null +++ b/halo2_proofs/src/fft.rs @@ -0,0 +1,235 @@ +//! This is a module for dispatching between different FFT implementations at runtime based on environment variable `FFT`. + +use std::env::var; + +use ff::Field; + +use self::recursive::FFTData; +use crate::{arithmetic::FftGroup, plonk::log_info}; + +pub mod baseline; +pub mod parallel; +pub mod recursive; + +/// Runtime dispatcher to concrete FFT implementation +pub fn fft>( + a: &mut [G], + omega: Scalar, + log_n: u32, + data: &FFTData, + inverse: bool, +) { + match var("FFT") { + Err(_) => { + // No `FFT=` environment variable specified. + log_info("=== Parallel FFT ===".to_string()); + parallel::fft(a, omega, log_n, data, inverse) + } + Ok(fft_impl) if fft_impl == "baseline"=> { + log_info("=== Baseline FFT ===".to_string()); + baseline::fft(a, omega, log_n, data, inverse) + } + Ok(fft_impl) if fft_impl == "recursive" => { + log_info("=== Recusive FFT ===".to_string()); + recursive::fft(a, omega, log_n, data, inverse) + } + Ok(fft_impl) if fft_impl == "parallel" => { + log_info("=== Parallel FFT ===".to_string()); + parallel::fft(a, omega, log_n, data, inverse) + } + _ => { + panic!("Please either specify environment variable `FFT={{baseline,recursive,parallel}}` or remove it all together.") + } + } +} + +#[cfg(test)] +mod tests { + use std::{time::Instant, env::var}; + + use ff::Field; + use halo2curves::bn256::Fr as Scalar; + use rand_core::OsRng; + + use crate::{ + fft::{self, recursive::FFTData}, + multicore, + arithmetic::{eval_polynomial, lagrange_interpolate, best_fft}, + plonk::{start_measure, log_info, stop_measure}, + poly::EvaluationDomain, + }; + + /// Read Environment Variable `DEGREE` + fn get_degree() -> usize { + var("DEGREE") + .unwrap_or_else(|_| "22".to_string()) + .parse() + .expect("Cannot parse DEGREE env var as usize") + } + + + #[test] + fn test_fft_parallel() { + let max_log_n = 22; + let min_log_n = 8; + let a = (0..(1 << max_log_n)) + .into_iter() + .map(|i| Scalar::from(i as u64)) + .collect::>(); + + log_info("\n---------- test_fft_parallel ---------".to_owned()); + for log_n in min_log_n..=max_log_n { + let domain = EvaluationDomain::::new(1, log_n); + let mut a0 = a[0..(1 << log_n)].to_vec(); + let mut a1 = a0.clone(); + + // FFTData is not used in `baseline` and `parallel` so default values suffices. + let d = FFTData::default(); + let f = false; + + // warm up & correct test + fft::baseline::fft(&mut a0, domain.get_omega(), log_n, &d, f); + fft::parallel::fft(&mut a1, domain.get_omega(), log_n, &d, f); + assert_eq!(a0, a1); + + let ori_time = Instant::now(); + fft::baseline::fft(&mut a0, domain.get_omega(), log_n, &d, f); + let ori_time = ori_time.elapsed(); + let ori_micros = f64::from(ori_time.as_micros() as u32); + + let opt_time = Instant::now(); + fft::parallel::fft(&mut a1, domain.get_omega(), log_n, &d, f); + let opt_time = opt_time.elapsed(); + let opt_micros = f64::from(opt_time.as_micros() as u32); + + log_info(format!( + " [log_n = {}] orig::fft time: {:?}, scroll::fft time: {:?}, speedup: {}", + log_n, + ori_time, + opt_time, + ori_micros / opt_micros + )); + } + } + + #[test] + fn test_fft_recursive() { + log_info("\n---------- test_fft_recursive ---------".to_owned()); + + let k = get_degree() as u32; + + let domain = EvaluationDomain::::new(1, k); + let n = domain.get_n() as usize; + + let input = vec![Scalar::random(OsRng); n]; + + let num_threads = multicore::current_num_threads(); + + let mut a = input.clone(); + let l_a= a.len(); + let start = start_measure(format!("best fft {} ({})", a.len(), num_threads), false); + fft::baseline::fft(&mut a, domain.get_omega(), k, domain.get_fft_data(l_a), false); + stop_measure(start); + + let mut b = input; + let l_b= b.len(); + let start = start_measure( + format!("recursive fft {} ({})", a.len(), num_threads), + false, + ); + fft::recursive::fft(&mut b, domain.get_omega(), k, domain.get_fft_data(l_b), false); + stop_measure(start); + + for i in 0..n { + //log_info(format!("{}: {} {}", i, a[i], b[i])); + assert_eq!(a[i], b[i]); + } + } + + #[test] + fn test_fft_all() { + log_info("\n---------- test_fft_all ---------".to_owned()); + + let k = get_degree() as u32; + + let domain = EvaluationDomain::::new(1, k); + let n = domain.get_n() as usize; + + let input = vec![Scalar::random(OsRng); n]; + + let num_threads = multicore::current_num_threads(); + + let mut data_baseline = input.clone(); + let l_baseline = data_baseline.len(); + let start = start_measure( + format!("baseline fft {} ({})", data_baseline.len(), num_threads), + false, + ); + fft::baseline::fft(&mut data_baseline, domain.get_omega(), k, domain.get_fft_data(l_baseline), false); + stop_measure(start); + + let mut data_parallel = input.clone(); + let l_parallel = data_parallel.len(); + let start = start_measure( + format!("parallel fft {} ({})", data_parallel.len(), num_threads), + false, + ); + fft::parallel::fft(&mut data_parallel, domain.get_omega(), k, domain.get_fft_data(l_parallel), false); + stop_measure(start); + + let mut data_recursive = input; + let l_recursive = data_recursive.len(); + let start = start_measure( + format!("recursive fft {} ({})", data_recursive.len(), num_threads), + false, + ); + fft::recursive::fft( + &mut data_recursive, + domain.get_omega(), + k, + domain.get_fft_data(l_recursive), + false, + ); + stop_measure(start); + + for i in 0..n { + // log_info(format!("{}: {} {}", i, data_baseline[i], data_recursive[i])); + assert_eq!(data_baseline[i], data_recursive[i]); + // log_info(format!("{}: {} {}", i, data_baseline[i], data_parallel[i])); + assert_eq!(data_baseline[i], data_parallel[i]); + } + } + + #[test] + fn test_fft_single() { + log_info("\n---------- test_fft_single ---------".to_owned()); + + let k = get_degree() as u32; + + let domain = EvaluationDomain::new(1, k); + let n = domain.get_n() as usize; + + let mut input = vec![Scalar::random(OsRng); n]; + let l = input.len(); + + let num_threads = multicore::current_num_threads(); + + let start = start_measure(format!("fft {} ({})", input.len(), num_threads), false); + fft::fft(&mut input, domain.get_omega(), k, domain.get_fft_data(l), false); + stop_measure(start); + } + + #[test] + fn test_mem_leak() { + let j = 1; + let k = 3; + let domain = EvaluationDomain::new(j,k); + let omega = domain.get_omega(); + let l = 1<>(); + + best_fft(&mut a, omega, k as u32, data, false); +} + +} diff --git a/halo2_proofs/src/fft/baseline.rs b/halo2_proofs/src/fft/baseline.rs new file mode 100644 index 0000000000..4962a3d326 --- /dev/null +++ b/halo2_proofs/src/fft/baseline.rs @@ -0,0 +1,128 @@ +//! This contains the baseline FFT implementation + +use ff::Field; + +use super::recursive::FFTData; +use crate::{ + arithmetic::{self, log2_floor, FftGroup}, + multicore, +}; + +/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size +/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative +/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when +/// interpreted as the coefficients of a polynomial of degree $n - 1$, is +/// transformed into the evaluations of this polynomial at each of the $n$ +/// distinct powers of $\omega$. This transformation is invertible by providing +/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element +/// by $n$. +/// +/// This will use multithreading if beneficial. +fn best_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + let threads = multicore::current_num_threads(); + let log_threads = log2_floor(threads); + let n = a.len() as usize; + assert_eq!(n, 1 << log_n); + + for k in 0..n { + let rk = arithmetic::bitreverse(k, log_n as usize); + if k < rk { + a.swap(rk, k); + } + } + + //let start = start_measure(format!("twiddles {} ({})", a.len(), threads), false); + // precompute twiddle factors + let twiddles: Vec<_> = (0..(n / 2) as usize) + .scan(Scalar::ONE, |w, _| { + let tw = *w; + *w *= ω + Some(tw) + }) + .collect(); + //stop_measure(start); + + if log_n <= log_threads { + let mut chunk = 2_usize; + let mut twiddle_chunk = (n / 2) as usize; + for _ in 0..log_n { + a.chunks_mut(chunk).for_each(|coeffs| { + let (left, right) = coeffs.split_at_mut(chunk / 2); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + }); + chunk *= 2; + twiddle_chunk /= 2; + } + } else { + recursive_butterfly_arithmetic(a, n, 1, &twiddles) + } +} + +/// This perform recursive butterfly arithmetic +fn recursive_butterfly_arithmetic>( + a: &mut [G], + n: usize, + twiddle_chunk: usize, + twiddles: &[Scalar], +) { + if n == 2 { + let t = a[1]; + a[1] = a[0]; + a[0] += &t; + a[1] -= &t; + } else { + let (left, right) = a.split_at_mut(n / 2); + rayon::join( + || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), + || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), + ); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + } +} + +/// Generic adaptor +pub fn fft>( + a: &mut [G], + omega: Scalar, + log_n: u32, + _data: &FFTData, + _inverse: bool, +) { + best_fft(a, omega, log_n) +} diff --git a/halo2_proofs/src/fft/parallel.rs b/halo2_proofs/src/fft/parallel.rs new file mode 100644 index 0000000000..02b42cabad --- /dev/null +++ b/halo2_proofs/src/fft/parallel.rs @@ -0,0 +1,287 @@ +//! This module provides common utilities, traits and structures for group, +//! field and polynomial arithmetic. + +use crate::arithmetic::{self, bitreverse, log2_floor, FftGroup}; + +use crate::multicore; +pub use ff::Field; +use group::{ + ff::{BatchInvert, PrimeField}, + Curve, Group as _, GroupOpsOwned, ScalarMulOwned, +}; +pub use halo2curves::{CurveAffine, CurveExt}; +use std::time::Instant; + +use super::recursive::FFTData; + +/// A constant +pub const SPARSE_TWIDDLE_DEGREE: u32 = 10; + +/// Dispatcher +fn best_fft_opt>(a: &mut [G], omega: Scalar, log_n: u32) { + let threads = multicore::current_num_threads(); + let log_split = log2_floor(threads) as usize; + let n = a.len() as usize; + let sub_n = n >> log_split; + let split_m = 1 << log_split; + + if sub_n >= split_m { + parallel_fft(a, omega, log_n); + } else { + serial_fft(a, omega, log_n); + } +} + +fn serial_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + let n = a.len() as u32; + assert_eq!(n, 1 << log_n); + + for k in 0..n as usize { + let rk = arithmetic::bitreverse(k, log_n as usize); + if k < rk { + a.swap(rk as usize, k as usize); + } + } + + let mut m = 1; + for _ in 0..log_n { + let w_m: Scalar = omega.pow_vartime(&[u64::from(n / (2 * m)), 0, 0, 0]); + + let mut k = 0; + while k < n { + let mut w = Scalar::ONE; + for j in 0..m { + let mut t = a[(k + j + m) as usize]; + t *= &w; + a[(k + j + m) as usize] = a[(k + j) as usize]; + a[(k + j + m) as usize] -= &t; + a[(k + j) as usize] += &t; + w *= &w_m; + } + + k += 2 * m; + } + + m *= 2; + } +} + +fn serial_split_fft>( + a: &mut [G], + twiddle_lut: &[Scalar], + twiddle_scale: usize, + log_n: u32, +) { + let n = a.len() as u32; + assert_eq!(n, 1 << log_n); + + let mut m = 1; + for _ in 0..log_n { + let omega_idx = twiddle_scale * n as usize / (2 * m as usize); // 1/2, 1/4, 1/8, ... + let low_idx = omega_idx % (1 << SPARSE_TWIDDLE_DEGREE); + let high_idx = omega_idx >> SPARSE_TWIDDLE_DEGREE; + let mut w_m = twiddle_lut[low_idx]; + if high_idx > 0 { + w_m = w_m * twiddle_lut[(1 << SPARSE_TWIDDLE_DEGREE) + high_idx]; + } + + let mut k = 0; + while k < n { + let mut w = Scalar::ONE; + for j in 0..m { + let mut t = a[(k + j + m) as usize]; + t *= &w; + a[(k + j + m) as usize] = a[(k + j) as usize]; + a[(k + j + m) as usize] -= &t; + a[(k + j) as usize] += &t; + w *= &w_m; + } + + k += 2 * m; + } + + m *= 2; + } +} + +fn split_radix_fft>( + tmp: &mut [G], + a: &[G], + twiddle_lut: &[Scalar], + n: usize, + sub_fft_offset: usize, + log_split: usize, +) { + let split_m = 1 << log_split; + let sub_n = n >> log_split; + + // we use out-place bitreverse here, split_m <= num_threads, so the buffer spase is small + // and it's is good for data locality + let tmp_filler_val = tmp[0]; + let mut t1 = vec![tmp_filler_val; split_m]; + for i in 0..split_m { + t1[arithmetic::bitreverse(i, log_split)] = a[(i * sub_n + sub_fft_offset)]; + } + serial_split_fft(&mut t1, twiddle_lut, sub_n, log_split as u32); + + let sparse_degree = SPARSE_TWIDDLE_DEGREE; + let omega_idx = sub_fft_offset as usize; + let low_idx = omega_idx % (1 << sparse_degree); + let high_idx = omega_idx >> sparse_degree; + let mut omega = twiddle_lut[low_idx]; + if high_idx > 0 { + omega = omega * twiddle_lut[(1 << sparse_degree) + high_idx]; + } + let mut w_m = Scalar::ONE; + for i in 0..split_m { + t1[i] *= &w_m; + tmp[i] = t1[i]; + w_m = w_m * omega; + } +} + +/// Precalculate twiddles factors +fn generate_twiddle_lookup_table( + omega: F, + log_n: u32, + sparse_degree: u32, + with_last_level: bool, +) -> Vec { + let without_last_level = !with_last_level; + let is_lut_len_large = sparse_degree > log_n; + + // dense + if is_lut_len_large { + let mut twiddle_lut = vec![F::ZERO; (1 << log_n) as usize]; + parallelize(&mut twiddle_lut, |twiddle_lut, start| { + let mut w_n = omega.pow_vartime(&[start as u64, 0, 0, 0]); + for twiddle_lut in twiddle_lut.iter_mut() { + *twiddle_lut = w_n; + w_n = w_n * omega; + } + }); + return twiddle_lut; + } + + // sparse + let low_degree_lut_len = 1 << sparse_degree; + let high_degree_lut_len = 1 << (log_n - sparse_degree - without_last_level as u32); + let mut twiddle_lut = vec![F::ZERO; (low_degree_lut_len + high_degree_lut_len) as usize]; + parallelize( + &mut twiddle_lut[..low_degree_lut_len], + |twiddle_lut, start| { + let mut w_n = omega.pow_vartime(&[start as u64, 0, 0, 0]); + for twiddle_lut in twiddle_lut.iter_mut() { + *twiddle_lut = w_n; + w_n = w_n * omega; + } + }, + ); + let high_degree_omega = omega.pow_vartime(&[(1 << sparse_degree) as u64, 0, 0, 0]); + parallelize( + &mut twiddle_lut[low_degree_lut_len..], + |twiddle_lut, start| { + let mut w_n = high_degree_omega.pow_vartime(&[start as u64, 0, 0, 0]); + for twiddle_lut in twiddle_lut.iter_mut() { + *twiddle_lut = w_n; + w_n = w_n * high_degree_omega; + } + }, + ); + twiddle_lut +} + +/// The parallel implementation +fn parallel_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + let n = a.len() as usize; + assert_eq!(n, 1 << log_n); + + let log_split = log2_floor(multicore::current_num_threads()) as usize; + let split_m = 1 << log_split; + let sub_n = n >> log_split as usize; + let twiddle_lut = generate_twiddle_lookup_table(omega, log_n, SPARSE_TWIDDLE_DEGREE, true); + + // split fft + let tmp_filler_val = a[0]; + let mut tmp = vec![tmp_filler_val; n]; + multicore::scope(|scope| { + let a = &*a; + let twiddle_lut = &*twiddle_lut; + for (chunk_idx, tmp) in tmp.chunks_mut(sub_n).enumerate() { + scope.spawn(move |_| { + let split_fft_offset = (chunk_idx * sub_n) >> log_split; + for (i, tmp) in tmp.chunks_mut(split_m).enumerate() { + let split_fft_offset = split_fft_offset + i; + split_radix_fft(tmp, a, twiddle_lut, n, split_fft_offset, log_split); + } + }); + } + }); + + // shuffle + parallelize(a, |a, start| { + for (idx, a) in a.iter_mut().enumerate() { + let idx = start + idx; + let i = idx / sub_n; + let j = idx % sub_n; + *a = tmp[j * split_m + i]; + } + }); + + // sub fft + let new_omega = omega.pow_vartime(&[split_m as u64, 0, 0, 0]); + multicore::scope(|scope| { + for a in a.chunks_mut(sub_n) { + scope.spawn(move |_| { + serial_fft(a, new_omega, log_n - log_split as u32); + }); + } + }); + + // copy & unshuffle + let mask = (1 << log_split) - 1; + parallelize(&mut tmp, |tmp, start| { + for (idx, tmp) in tmp.iter_mut().enumerate() { + let idx = start + idx; + *tmp = a[idx]; + } + }); + parallelize(a, |a, start| { + for (idx, a) in a.iter_mut().enumerate() { + let idx = start + idx; + *a = tmp[sub_n * (idx & mask) + (idx >> log_split)]; + } + }); +} + +/// This simple utility function will parallelize an operation that is to be +/// performed over a mutable slice. +fn parallelize(v: &mut [T], f: F) { + let n = v.len(); + let num_threads = multicore::current_num_threads(); + let mut chunk = (n as usize) / num_threads; + if chunk < num_threads { + chunk = n as usize; + } + + multicore::scope(|scope| { + for (chunk_num, v) in v.chunks_mut(chunk).enumerate() { + let f = f.clone(); + scope.spawn(move |_| { + let start = chunk_num * chunk; + f(v, start); + }); + } + }); +} + +/// Generic adaptor +pub fn fft>( + data_in: &mut [G], + omega: Scalar, + log_n: u32, + _data: &FFTData, + _inverse: bool, +) { + best_fft_opt(data_in, omega, log_n) +} diff --git a/halo2_proofs/src/fft/recursive.rs b/halo2_proofs/src/fft/recursive.rs new file mode 100644 index 0000000000..77fb645dee --- /dev/null +++ b/halo2_proofs/src/fft/recursive.rs @@ -0,0 +1,467 @@ +//! This contains the recursive FFT. + +use std::env; + +use crate::{ + arithmetic::{self, log2_floor, parallelize, FftGroup}, + multicore, + plonk::{get_duration, get_time, log_info}, +}; + +pub use ff::Field; +use ff::WithSmallOrderMulGroup; +use group::{ + ff::{BatchInvert, PrimeField}, + Curve, Group, GroupOpsOwned, ScalarMulOwned, +}; + +pub use halo2curves::{CurveAffine, CurveExt}; + +/// FFTStage +#[derive(Clone, Debug)] +pub struct FFTStage { + radix: usize, + length: usize, +} + +/// FFT stages +fn get_stages(size: usize, radixes: Vec) -> Vec { + let mut stages: Vec = vec![]; + + let mut n = size; + + // Use the specified radices + for &radix in &radixes { + n /= radix; + stages.push(FFTStage { radix, length: n }); + } + + // Fill in the rest of the tree if needed + let mut p = 2; + while n > 1 { + while n % p != 0 { + if p == 4 { + p = 2; + } + } + n /= p; + stages.push(FFTStage { + radix: p, + length: n, + }); + } + + /*for i in 0..stages.len() { + log_info(format!("Stage {}: {}, {}", i, stages[i].radix, stages[i].length)); + }*/ + + stages +} + +/// FFTData +#[derive(Clone, Debug)] +pub struct FFTData { + n: usize, + + stages: Vec, + + f_twiddles: Vec>, + inv_twiddles: Vec>, + //scratch: Vec, +} + +impl Default for FFTData { + fn default() -> Self { + Self { + n: Default::default(), + stages: Default::default(), + f_twiddles: Default::default(), + inv_twiddles: Default::default(), + } + } +} + +impl FFTData { + /// Create FFT data + pub fn new(n: usize, omega: F, omega_inv: F) -> Self { + let stages = get_stages(n as usize, vec![]); + let mut f_twiddles = vec![]; + let mut inv_twiddles = vec![]; + let mut scratch = vec![F::ZERO; n]; + + // Generate stage twiddles + for inv in 0..2 { + let inverse = inv == 0; + let o = if inverse { omega_inv } else { omega }; + let stage_twiddles = if inverse { + &mut inv_twiddles + } else { + &mut f_twiddles + }; + + let twiddles = &mut scratch; + + // Twiddles + parallelize(twiddles, |twiddles, start| { + let w_m = o; + let mut w = o.pow_vartime(&[start as u64, 0, 0, 0]); + for value in twiddles.iter_mut() { + *value = w; + w *= w_m; + } + }); + + // Re-order twiddles for cache friendliness + let num_stages = stages.len(); + stage_twiddles.resize(num_stages, vec![]); + for l in 0..num_stages { + let radix = stages[l].radix; + let stage_length = stages[l].length; + + let num_twiddles = stage_length * (radix - 1); + stage_twiddles[l].resize(num_twiddles + 1, F::ZERO); + + // Set j + stage_twiddles[l][num_twiddles] = twiddles[(twiddles.len() * 3) / 4]; + + let stride = n / (stage_length * radix); + let mut tws = vec![0usize; radix - 1]; + for i in 0..stage_length { + for j in 0..radix - 1 { + stage_twiddles[l][i * (radix - 1) + j] = twiddles[tws[j]]; + tws[j] += (j + 1) * stride; + } + } + } + } + + Self { + n, + stages, + f_twiddles, + inv_twiddles, + //scratch, + } + } + + /// Return private field `n` + pub fn get_n(&self) -> usize { self.n } +} + +/// Radix 2 butterfly +fn butterfly_2>( + out: &mut [G], + twiddles: &[Scalar], + stage_length: usize, +) { + let mut out_offset = 0; + let mut out_offset2 = stage_length; + + let t = out[out_offset2]; + out[out_offset2] = out[out_offset] - &t; + out[out_offset] += &t; + out_offset2 += 1; + out_offset += 1; + + for twiddle in twiddles[1..stage_length].iter() { + let t = out[out_offset2] * twiddle; + out[out_offset2] = out[out_offset] - &t; + out[out_offset] += &t; + out_offset2 += 1; + out_offset += 1; + } +} + +/// Radix 2 butterfly +fn butterfly_2_parallel>( + out: &mut [G], + twiddles: &[Scalar], + _stage_length: usize, + num_threads: usize, +) { + let n = out.len(); + let mut chunk = (n as usize) / num_threads; + if chunk < num_threads { + chunk = n as usize; + } + + multicore::scope(|scope| { + let (part_a, part_b) = out.split_at_mut(n / 2); + for (i, (part0, part1)) in part_a + .chunks_mut(chunk) + .zip(part_b.chunks_mut(chunk)) + .enumerate() + { + scope.spawn(move |_| { + let offset = i * chunk; + for k in 0..part0.len() { + let t = part1[k] * &twiddles[offset + k]; + part1[k] = part0[k] - &t; + part0[k] += &t; + } + }); + } + }); +} + +/// Radix 4 butterfly +fn butterfly_4>( + out: &mut [G], + twiddles: &[Scalar], + stage_length: usize, +) { + let j = twiddles[twiddles.len() - 1]; + let mut tw = 0; + + /* Case twiddle == one */ + { + let i0 = 0; + let i1 = stage_length; + let i2 = stage_length * 2; + let i3 = stage_length * 3; + + let z0 = out[i0]; + let z1 = out[i1]; + let z2 = out[i2]; + let z3 = out[i3]; + + let t1 = z0 + &z2; + let t2 = z1 + &z3; + let t3 = z0 - &z2; + let t4j = (z1 - &z3) * &j; + + out[i0] = t1 + &t2; + out[i1] = t3 - &t4j; + out[i2] = t1 - &t2; + out[i3] = t3 + &t4j; + + tw += 3; + } + + for k in 1..stage_length { + let i0 = k; + let i1 = k + stage_length; + let i2 = k + stage_length * 2; + let i3 = k + stage_length * 3; + + let z0 = out[i0]; + let z1 = out[i1] * &twiddles[tw]; + let z2 = out[i2] * &twiddles[tw + 1]; + let z3 = out[i3] * &twiddles[tw + 2]; + + let t1 = z0 + &z2; + let t2 = z1 + &z3; + let t3 = z0 - &z2; + let t4j = (z1 - &z3) * &j; + + out[i0] = t1 + &t2; + out[i1] = t3 - &t4j; + out[i2] = t1 - &t2; + out[i3] = t3 + &t4j; + + tw += 3; + } +} + +/// Radix 4 butterfly +fn butterfly_4_parallel>( + out: &mut [G], + twiddles: &[Scalar], + _stage_length: usize, + num_threads: usize, +) { + let j = twiddles[twiddles.len() - 1]; + + let n = out.len(); + let mut chunk = (n as usize) / num_threads; + if chunk < num_threads { + chunk = n as usize; + } + multicore::scope(|scope| { + //let mut parts: Vec<&mut [F]> = out.chunks_mut(4).collect(); + //out.chunks_mut(4).map(|c| c.chunks_mut(chunk)).fold(predicate) + let (part_a, part_b) = out.split_at_mut(n / 2); + let (part_aa, part_ab) = part_a.split_at_mut(n / 4); + let (part_ba, part_bb) = part_b.split_at_mut(n / 4); + for (i, (((part0, part1), part2), part3)) in part_aa + .chunks_mut(chunk) + .zip(part_ab.chunks_mut(chunk)) + .zip(part_ba.chunks_mut(chunk)) + .zip(part_bb.chunks_mut(chunk)) + .enumerate() + { + scope.spawn(move |_| { + let offset = i * chunk; + let mut tw = offset * 3; + for k in 0..part1.len() { + let z0 = part0[k]; + let z1 = part1[k] * &twiddles[tw]; + let z2 = part2[k] * &twiddles[tw + 1]; + let z3 = part3[k] * &twiddles[tw + 2]; + + let t1 = z0 + &z2; + let t2 = z1 + &z3; + let t3 = z0 - &z2; + let t4j = (z1 - &z3) * &j; + + part0[k] = t1 + &t2; + part1[k] = t3 - &t4j; + part2[k] = t1 - &t2; + part3[k] = t3 + &t4j; + + tw += 3; + } + }); + } + }); +} + +/// Inner recursion +fn recursive_fft_inner>( + data_in: &[G], + data_out: &mut [G], + twiddles: &Vec>, + stages: &Vec, + in_offset: usize, + stride: usize, + level: usize, + num_threads: usize, +) { + let radix = stages[level].radix; + let stage_length = stages[level].length; + + if num_threads > 1 { + if stage_length == 1 { + for i in 0..radix { + data_out[i] = data_in[in_offset + i * stride]; + } + } else { + let num_threads_recursive = if num_threads >= radix { + radix + } else { + num_threads + }; + parallelize_count(data_out, num_threads_recursive, |data_out, i| { + let num_threads_in_recursion = if num_threads < radix { + 1 + } else { + (num_threads + i) / radix + }; + recursive_fft_inner( + data_in, + data_out, + twiddles, + stages, + in_offset + i * stride, + stride * radix, + level + 1, + num_threads_in_recursion, + ) + }); + } + match radix { + 2 => butterfly_2_parallel(data_out, &twiddles[level], stage_length, num_threads), + 4 => butterfly_4_parallel(data_out, &twiddles[level], stage_length, num_threads), + _ => unimplemented!("radix unsupported"), + } + } else { + if stage_length == 1 { + for i in 0..radix { + data_out[i] = data_in[in_offset + i * stride]; + } + } else { + for i in 0..radix { + recursive_fft_inner( + data_in, + &mut data_out[i * stage_length..(i + 1) * stage_length], + twiddles, + stages, + in_offset + i * stride, + stride * radix, + level + 1, + num_threads, + ); + } + } + match radix { + 2 => butterfly_2(data_out, &twiddles[level], stage_length), + 4 => butterfly_4(data_out, &twiddles[level], stage_length), + _ => unimplemented!("radix unsupported"), + } + } +} + +/// Todo: Brechts impl starts here +fn recursive_fft>( + data: &FFTData, + data_in: &mut Vec, + inverse: bool, +) { + let num_threads = multicore::current_num_threads(); + //let start = start_measure(format!("recursive fft {} ({})", data_in.len(), num_threads), false); + + // TODO: reuse scratch buffer between FFTs + //let start_mem = start_measure(format!("alloc"), false); + let filler = data_in[0]; + let mut scratch = vec![filler; data_in.len()]; + //stop_measure(start_mem); + + recursive_fft_inner( + data_in, + &mut /*data.*/scratch, + if inverse { + &data.inv_twiddles + } else { + &data.f_twiddles + }, + &data.stages, + 0, + 1, + 0, + num_threads, + ); + //let duration = stop_measure(start); + + //let start = start_measure(format!("copy"), false); + // Will simply swap the vector's buffer, no data is actually copied + std::mem::swap(data_in, &mut /*data.*/scratch); + //stop_measure(start); +} + +/// This simple utility function will parallelize an operation that is to be +/// performed over a mutable slice. +fn parallelize_count( + v: &mut [T], + num_threads: usize, + f: F, +) { + let n = v.len(); + let mut chunk = (n as usize) / num_threads; + if chunk < num_threads { + chunk = n as usize; + } + + multicore::scope(|scope| { + for (chunk_num, v) in v.chunks_mut(chunk).enumerate() { + let f = f.clone(); + scope.spawn(move |_| { + f(v, chunk_num); + }); + } + }); +} + +/// Generic adaptor +pub fn fft>( + data_in: &mut [G], + _omega: Scalar, + _log_n: u32, + data: &FFTData, + inverse: bool, +) { + let orig_len = data_in.len(); + let mut data_in_vec = data_in.to_vec(); + recursive_fft(data, &mut data_in_vec, inverse); + data_in.copy_from_slice(&data_in_vec); + assert_eq!(orig_len, data_in.len()); +} diff --git a/halo2_proofs/src/lib.rs b/halo2_proofs/src/lib.rs index 52676ddb19..8a686e5687 100644 --- a/halo2_proofs/src/lib.rs +++ b/halo2_proofs/src/lib.rs @@ -27,6 +27,7 @@ pub mod arithmetic; pub mod circuit; +pub mod fft; pub use halo2curves; mod multicore; pub mod plonk; diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index 5485d39fb4..a32f2342f2 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -19,6 +19,7 @@ use crate::poly::{ }; use crate::transcript::{ChallengeScalar, EncodedChallenge, Transcript}; use crate::SerdeFormat; +use ark_std::perf_trace::{AtomicUsize, Ordering}; mod assigned; mod circuit; @@ -40,7 +41,103 @@ pub use prover::*; pub use verifier::*; use evaluation::Evaluator; +use std::env::var; use std::io; +use std::time::Instant; + +/// Temp +#[allow(missing_debug_implementations)] +pub struct MeasurementInfo { + /// Temp + pub measure: bool, + /// Temp + pub time: Instant, + /// Message + pub message: String, + /// Indent + pub indent: usize, +} + +/// TEMP +pub static NUM_INDENT: AtomicUsize = AtomicUsize::new(0); + +/// Temp +pub fn get_time() -> Instant { + Instant::now() +} + +/// Temp +pub fn get_duration(start: Instant) -> usize { + let final_time = Instant::now() - start; + let secs = final_time.as_secs() as usize; + let millis = final_time.subsec_millis() as usize; + let micros = (final_time.subsec_micros() % 1000) as usize; + secs * 1000000 + millis * 1000 + micros +} + +/// Temp +pub fn log_measurement(indent: Option, msg: String, duration: usize) { + let indent = indent.unwrap_or(0); + println!( + "{}{} ........ {}s", + "*".repeat(indent), + msg, + (duration as f32) / 1000000.0 + ); +} + +/// Temp +pub fn start_measure>(msg: S, always: bool) -> MeasurementInfo { + let measure: u32 = var("MEASURE") + .unwrap_or_else(|_| "0".to_string()) + .parse() + .expect("Cannot parse MEASURE env var as u32"); + + let indent = NUM_INDENT.fetch_add(1, Ordering::Relaxed); + + if always || measure == 1 + /* || msg.starts_with("compressed_cosets")*/ + { + MeasurementInfo { + measure: true, + time: get_time(), + message: msg.as_ref().to_string(), + indent, + } + } else { + MeasurementInfo { + measure: false, + time: get_time(), + message: "".to_string(), + indent, + } + } +} + +/// Temp +pub fn stop_measure(info: MeasurementInfo) -> usize { + NUM_INDENT.fetch_sub(1, Ordering::Relaxed); + let duration = get_duration(info.time); + if info.measure { + log_measurement(Some(info.indent), info.message, duration); + } + duration +} + +/// Get env variable +pub fn env_value(key: &str, default: usize) -> usize { + match var(key) { + Ok(val) => val.parse().unwrap(), + Err(_) => default, + } +} + +/// Temp +pub fn log_info(msg: String) { + if env_value("INFO", 0) != 0 { + println!("{}", msg); + } +} /// This is a verifying key which allows for the verification of proofs for a /// particular circuit. diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index 97e28fc8a1..ae3f1eb7ae 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -10,6 +10,7 @@ use ff::Field; use sealed::SealedPhase; use std::cmp::Ordering; use std::collections::HashMap; +use std::env::var; use std::fmt::{Debug, Formatter}; use std::{ convert::TryFrom, @@ -2188,6 +2189,15 @@ impl ConstraintSystem { .unwrap_or(0), ); + fn get_max_degree() -> usize { + var("MAX_DEGREE").map_or(usize::MAX, |max_degree_str| { + max_degree_str + .parse() + .expect("Cannot parse MAX_DEGREE env var as usize") + }) + } + degree = std::cmp::min(degree, get_max_degree()); + std::cmp::max(degree, self.minimum_degree.unwrap_or(1)) } diff --git a/halo2_proofs/src/plonk/evaluation.rs b/halo2_proofs/src/plonk/evaluation.rs index d5fb984024..1c05a261a0 100644 --- a/halo2_proofs/src/plonk/evaluation.rs +++ b/halo2_proofs/src/plonk/evaluation.rs @@ -26,7 +26,7 @@ use std::{ ops::{Index, Mul, MulAssign}, }; -use super::{ConstraintSystem, Expression}; +use super::{start_measure, stop_measure, ConstraintSystem, Expression}; /// Return the index in the polynomial of size `isize` after rotation `rot`. fn get_rotation_idx(idx: usize, rot: i32, rot_scale: i32, isize: i32) -> usize { @@ -303,6 +303,7 @@ impl Evaluator { let p = &pk.vk.cs.permutation; // Calculate the advice and instance cosets + let start = start_measure("cosets", false); let advice: Vec>> = advice_polys .iter() .map(|advice_polys| { @@ -321,6 +322,7 @@ impl Evaluator { .collect() }) .collect(); + stop_measure(start); let mut values = domain.empty_extended(); @@ -333,6 +335,7 @@ impl Evaluator { .zip(permutations.iter()) { // Custom gates + let start = start_measure("custom gates", false); multicore::scope(|scope| { let chunk_size = (size + num_threads - 1) / num_threads; for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() { @@ -360,8 +363,10 @@ impl Evaluator { }); } }); + stop_measure(start); // Permutations + let start = start_measure("permutations", false); let sets = &permutation.sets; if !sets.is_empty() { let blinding_factors = pk.vk.cs.blinding_factors(); @@ -442,8 +447,10 @@ impl Evaluator { } }); } + stop_measure(start); // Lookups + let start = start_measure("lookups", false); for (n, lookup) in lookups.iter().enumerate() { // Polynomials required for this lookup. // Calculated here so these only have to be kept in memory for the short time @@ -517,6 +524,7 @@ impl Evaluator { } }); } + stop_measure(start); } values } diff --git a/halo2_proofs/src/plonk/permutation/prover.rs b/halo2_proofs/src/plonk/permutation/prover.rs index 2dbc2deadc..f837a3686e 100644 --- a/halo2_proofs/src/plonk/permutation/prover.rs +++ b/halo2_proofs/src/plonk/permutation/prover.rs @@ -173,7 +173,7 @@ impl Argument { let z = domain.lagrange_to_coeff(z); let permutation_product_poly = z.clone(); - let permutation_product_coset = domain.coeff_to_extended(z.clone()); + let permutation_product_coset = domain.coeff_to_extended(z); let permutation_product_commitment = permutation_product_commitment_projective.to_affine(); diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index 1258f09fd6..626a9180eb 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -18,6 +18,9 @@ use super::{ lookup, permutation, vanishing, ChallengeBeta, ChallengeGamma, ChallengeTheta, ChallengeX, ChallengeY, Error, Expression, ProvingKey, }; +use crate::arithmetic::MULTIEXP_TOTAL_TIME; +use crate::plonk::{start_measure, stop_measure, log_info}; +use crate::poly::FFT_TOTAL_TIME; use crate::{ arithmetic::{eval_polynomial, CurveAffine}, circuit::Value, @@ -57,6 +60,12 @@ pub fn create_proof< where Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>, { + #[allow(unsafe_code)] + unsafe { + FFT_TOTAL_TIME = 0; + MULTIEXP_TOTAL_TIME = 0; + } + for instance in instances.iter() { if instance.len() != pk.vk.cs.num_instance_columns { return Err(Error::InvalidInstances); @@ -82,6 +91,7 @@ where pub instance_polys: Vec>, } + let start = start_measure("instances", false); let instance: Vec> = instances .iter() .map(|instance| -> Result, Error> { @@ -136,6 +146,7 @@ where }) }) .collect::, _>>()?; + stop_measure(start); #[derive(Clone)] struct AdviceSingle { @@ -287,6 +298,7 @@ where } } + let start = start_measure("advice_values", false); let (advice, challenges) = { let mut advice = vec![ AdviceSingle:: { @@ -315,6 +327,7 @@ where for ((circuit, advice), instances) in circuits.iter().zip(advice.iter_mut()).zip(instances) { + let start = start_measure("witness collection", false); let mut witness = WitnessCollection { k: params.k(), current_phase, @@ -336,7 +349,9 @@ where config.clone(), meta.constants.clone(), )?; + stop_measure(start); + let start = start_measure("batch invert", false); let mut advice_values = batch_invert_assigned::( witness .advice @@ -351,6 +366,7 @@ where }) .collect(), ); + stop_measure(start); // Add blinding factors to advice columns for advice_values in &mut advice_values { @@ -360,6 +376,7 @@ where } // Compute commitments to advice column polynomials + let start = start_measure("commit_lagrange", false); let blinds: Vec<_> = advice_values .iter() .map(|_| Blind(Scheme::Scalar::random(&mut rng))) @@ -387,6 +404,7 @@ where advice.advice_polys[*column_index] = advice_values; advice.advice_blinds[*column_index] = blind; } + stop_measure(start); } for (index, phase) in meta.challenge_phase.iter().enumerate() { @@ -405,10 +423,12 @@ where (advice, challenges) }; + stop_measure(start); // Sample theta challenge for keeping lookup columns linearly independent let theta: ChallengeTheta<_> = transcript.squeeze_challenge_scalar(); + let start = start_measure("lookups", false); let lookups: Vec>> = instance .iter() .zip(advice.iter()) @@ -435,6 +455,7 @@ where .collect() }) .collect::, _>>()?; + stop_measure(start); // Sample beta challenge let beta: ChallengeBeta<_> = transcript.squeeze_challenge_scalar(); @@ -443,6 +464,7 @@ where let gamma: ChallengeGamma<_> = transcript.squeeze_challenge_scalar(); // Commit to permutations. + let start = start_measure("permutation.commit", false); let permutations: Vec> = instance .iter() .zip(advice.iter()) @@ -461,7 +483,9 @@ where ) }) .collect::, _>>()?; + stop_measure(start); + let start = start_measure("lookups.commit_product", false); let lookups: Vec>> = lookups .into_iter() .map(|lookups| -> Result, _> { @@ -472,6 +496,7 @@ where .collect::, _>>() }) .collect::, _>>()?; + stop_measure(start); // Commit to the vanishing argument's random polynomial for blinding h(x_3) let vanishing = vanishing::Argument::commit(params, domain, &mut rng, transcript)?; @@ -480,6 +505,7 @@ where let y: ChallengeY<_> = transcript.squeeze_challenge_scalar(); // Calculate the advice polys + let start = start_measure("advice_polys", false); let advice: Vec> = advice .into_iter() .map( @@ -497,8 +523,10 @@ where }, ) .collect(); + stop_measure(start); // Evaluate the h(X) polynomial + let start = start_measure("evaluate_h", false); let h_poly = pk.ev.evaluate_h( pk, &advice @@ -517,6 +545,7 @@ where &lookups, &permutations, ); + stop_measure(start); // Construct the vanishing argument's h(X) commitments let vanishing = vanishing.construct(params, domain, h_poly, &mut rng, transcript)?; @@ -524,6 +553,7 @@ where let x: ChallengeX<_> = transcript.squeeze_challenge_scalar(); let xn = x.pow(&[params.n() as u64, 0, 0, 0]); + let start = start_measure("instance eval_polynomial", false); if P::QUERY_INSTANCE { // Compute and hash instance evals for each circuit instance for instance in instance.iter() { @@ -545,8 +575,10 @@ where } } } + stop_measure(start); // Compute and hash advice evals for each circuit instance + let start = start_measure("advice eval_polynomial", false); for advice in advice.iter() { // Evaluate polynomials at omega^i x let advice_evals: Vec<_> = meta @@ -565,8 +597,10 @@ where transcript.write_scalar(*eval)?; } } + stop_measure(start); // Compute and hash fixed evals (shared across all circuit instances) + let start = start_measure("fixed eval_polynomial", false); let fixed_evals: Vec<_> = meta .fixed_queries .iter() @@ -574,6 +608,7 @@ where eval_polynomial(&pk.fixed_polys[column.index()], domain.rotate_omega(*x, at)) }) .collect(); + stop_measure(start); // Hash each fixed column evaluation for eval in fixed_evals.iter() { @@ -592,6 +627,7 @@ where .collect::, _>>()?; // Evaluate the lookups, if any, at omega^i x. + let start = start_measure("lookup evaluate", false); let lookups: Vec>> = lookups .into_iter() .map(|lookups| -> Result, _> { @@ -601,6 +637,7 @@ where .collect::, _>>() }) .collect::, _>>()?; + stop_measure(start); let instances = instance .iter() @@ -650,8 +687,18 @@ where // We query the h(X) polynomial at x .chain(vanishing.open(x)); + let start = start_measure("create_proof", false); let prover = P::new(params); - prover + let proof = prover .create_proof(rng, transcript, instances) - .map_err(|_| Error::ConstraintSystemFailure) + .map_err(|_| Error::ConstraintSystemFailure); + stop_measure(start); + + #[allow(unsafe_code)] + unsafe { + log_info(format!("·FFT: {}s", (FFT_TOTAL_TIME as f32) / 1000000.0)); + log_info(format!("·MultiExps: {}s", (MULTIEXP_TOTAL_TIME as f32) / 1000000.0)); + } + + proof } diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index b442dbc8b1..fd72f48473 100644 --- a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -3,7 +3,9 @@ use crate::{ arithmetic::{best_fft, parallelize}, - plonk::Assigned, + fft::recursive::FFTData, + multicore, + plonk::{get_duration, get_time, log_info, start_measure, stop_measure, Assigned}, }; use super::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation}; @@ -13,7 +15,10 @@ use group::{ Group, }; -use std::marker::PhantomData; +use std::{env::var, marker::PhantomData}; + +/// TEMP +pub static mut FFT_TOTAL_TIME: usize = 0; /// This structure contains precomputed constants and other details needed for /// performing operations on an evaluation domain of size $2^k$ and an extended @@ -34,6 +39,11 @@ pub struct EvaluationDomain { extended_ifft_divisor: F, t_evaluations: Vec, barycentric_weight: F, + + /// Recursive stuff + fft_data: FFTData, + /// Recursive stuff for the extension field + pub extended_fft_data: FFTData, } impl> EvaluationDomain { @@ -53,6 +63,7 @@ impl> EvaluationDomain { while (1 << extended_k) < (n * quotient_poly_degree) { extended_k += 1; } + log_info(format!("k: {}, extended_k: {}", k, extended_k)); let mut extended_omega = F::ROOT_OF_UNITY; @@ -141,6 +152,12 @@ impl> EvaluationDomain { extended_ifft_divisor, t_evaluations, barycentric_weight, + fft_data: FFTData::::new(n as usize, omega, omega_inv), + extended_fft_data: FFTData::::new( + (1 << extended_k) as usize, + extended_omega, + extended_omega_inv, + ), } } @@ -227,7 +244,7 @@ impl> EvaluationDomain { assert_eq!(a.values.len(), 1 << self.k); // Perform inverse FFT to obtain the polynomial in coefficient form - Self::ifft(&mut a.values, self.omega_inv, self.k, self.ifft_divisor); + self.ifft(&mut a.values, self.omega_inv, self.k, self.ifft_divisor); Polynomial { values: a.values, @@ -245,7 +262,14 @@ impl> EvaluationDomain { self.distribute_powers_zeta(&mut a.values, true); a.values.resize(self.extended_len(), F::ZERO); - best_fft(&mut a.values, self.extended_omega, self.extended_k); + + best_fft( + &mut a.values, + self.extended_omega, + self.extended_k, + &self.extended_fft_data, + false, + ); Polynomial { values: a.values, @@ -282,7 +306,7 @@ impl> EvaluationDomain { assert_eq!(a.values.len(), self.extended_len()); // Inverse FFT - Self::ifft( + self.ifft( &mut a.values, self.extended_omega_inv, self.extended_k, @@ -350,8 +374,8 @@ impl> EvaluationDomain { }); } - fn ifft(a: &mut [F], omega_inv: F, log_n: u32, divisor: F) { - best_fft(a, omega_inv, log_n); + fn ifft(&self, a: &mut Vec, omega_inv: F, log_n: u32, divisor: F) { + self.fft_inner(a, omega_inv, log_n, true); parallelize(a, |a, _| { for a in a { // Finish iFFT @@ -360,6 +384,18 @@ impl> EvaluationDomain { }); } + fn fft_inner(&self, a: &mut Vec, omega: F, log_n: u32, inverse: bool) { + let start = get_time(); + let fft_data = self.get_fft_data(a.len()); + best_fft(a, omega, log_n, fft_data, inverse); + let duration = get_duration(start); + + #[allow(unsafe_code)] + unsafe { + FFT_TOTAL_TIME += duration; + } + } + /// Get the size of the domain pub fn k(&self) -> u32 { self.k @@ -474,6 +510,18 @@ impl> EvaluationDomain { omega: &self.omega, } } + + /// Get the private field `n` + pub fn get_n(&self) -> u64 { self.n } + + /// Get the private `fft_data` + pub fn get_fft_data(&self, l: usize) -> &FFTData { + if l == self.fft_data.get_n() { + &self.fft_data + } else { + &self.extended_fft_data + } + } } /// Represents the minimal parameters that determine an `EvaluationDomain`. @@ -485,6 +533,12 @@ pub struct PinnedEvaluationDomain<'a, F: Field> { omega: &'a F, } +#[cfg(test)] +use std::{ + env, + time::Instant, +}; + #[test] fn test_rotate() { use rand_core::OsRng;