diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs
index fb3b5e55d..06f63c004 100644
--- a/jolt-core/src/jolt/vm/mod.rs
+++ b/jolt-core/src/jolt/vm/mod.rs
@@ -4,6 +4,7 @@ use crate::field::JoltField;
 use crate::r1cs::builder::CombinedUniformBuilder;
 use crate::r1cs::jolt_constraints::{construct_jolt_constraints, JoltIn};
 use crate::r1cs::spartan::{self, UniformSpartanProof};
+use crate::utils::profiling;
 use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
 use ark_std::log2;
 use common::constants::RAM_START_ADDRESS;
@@ -402,8 +403,10 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
             &mut transcript,
         );

-        drop_in_background_thread(jolt_polynomials);
+        // drop_in_background_thread(jolt_polynomials);
+        drop(jolt_polynomials);

+        profiling::print_current_memory_usage("pre_spartan");
         let spartan_proof = UniformSpartanProof::prove_precommitted(
             &preprocessing.generators,
             r1cs_builder,
@@ -412,6 +415,7 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
             &mut transcript,
         )
         .expect("r1cs proof failed");
+        profiling::print_current_memory_usage("post_spartan");
         let r1cs_proof = R1CSProof {
             key: spartan_key,
             proof: spartan_proof,
diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs
index 992b7fe69..6eb701531 100644
--- a/jolt-core/src/poly/dense_mlpoly.rs
+++ b/jolt-core/src/poly/dense_mlpoly.rs
@@ -1,6 +1,6 @@
 #![allow(clippy::too_many_arguments)]
 use crate::poly::eq_poly::EqPolynomial;
-use crate::utils::thread::unsafe_allocate_zero_vec;
+use crate::utils::thread::{drop_in_background_thread, unsafe_allocate_zero_vec};
 use crate::utils::{self, compute_dotproduct, compute_dotproduct_low_optimized};

 use crate::field::JoltField;
@@ -201,11 +201,17 @@ impl<F: JoltField> DensePolynomial<F> {
         }
     }

+    #[tracing::instrument(skip_all)]
     pub fn bound_poly_var_bot(&mut self, r: &F) {
         let n = self.len() / 2;
-        for i in 0..n {
-            self.Z[i] = self.Z[2 * i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]);
-        }
+        let mut new_z = unsafe_allocate_zero_vec(n);
+        new_z.par_iter_mut().enumerate().for_each(|(i, z)| {
+            *z = self.Z[2 * i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]);
+        });
+
+        let old_Z = std::mem::replace(&mut self.Z, new_z);
+        drop_in_background_thread(old_Z);
+
         self.num_vars -= 1;
         self.len = n;
     }
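The rewritten `bound_poly_var_bot` allocates a fresh half-size buffer so the pairwise reduction can run in parallel, then hands the old buffer to a background thread to amortize deallocation. A minimal standalone sketch of the binding rule itself (i64 stands in for the field; illustration only, not the library code):

    // Binding the lowest-order variable of an MLE with evals z to challenge r:
    // adjacent entries collapse as z'[i] = z[2i] + r * (z[2i+1] - z[2i]).
    fn bound_var_bot_sketch(z: &[i64], r: i64) -> Vec<i64> {
        (0..z.len() / 2)
            .map(|i| z[2 * i] + r * (z[2 * i + 1] - z[2 * i]))
            .collect()
    }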
diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs
index d0c1c409e..e16bd7f7f 100644
--- a/jolt-core/src/r1cs/builder.rs
+++ b/jolt-core/src/r1cs/builder.rs
@@ -2,8 +2,7 @@ use crate::{
     field::{JoltField, OptimizedMul},
     r1cs::key::{SparseConstraints, UniformR1CS},
     utils::{
-        mul_0_1_optimized,
-        thread::{drop_in_background_thread, unsafe_allocate_zero_vec},
+        math::Math, mul_0_1_optimized, thread::{drop_in_background_thread, unsafe_allocate_zero_vec}
     },
 };
 #[allow(unused_imports)] // clippy thinks these aren't needed lol
@@ -13,7 +12,7 @@ use std::{collections::HashMap, fmt::Debug};

 use super::{
     key::{NonUniformR1CS, SparseEqualityItem},
-    ops::{ConstraintInput, Term, Variable, LC},
+    ops::{ConstraintInput, Term, Variable, LC}, special_polys::SparsePolynomial,
 };

 pub trait R1CSConstraintBuilder<F: JoltField> {
@@ -848,13 +847,127 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
         (Az, Bz, Cz)
     }

+    /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]]
+    /// aux should be of the format [[Aux(0), Aux(0), ...], ... [Aux(self.next_aux - 1), ...]]
+    #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan")]
+    pub fn compute_spartan_Az_Bz_Cz_sparse(
+        &self,
+        inputs: &[Vec<F>],
+        aux: &[Vec<F>],
+    ) -> (SparsePolynomial<F>, SparsePolynomial<F>, SparsePolynomial<F>) {
+        assert_eq!(inputs.len(), I::COUNT);
+        let num_aux = self.uniform_builder.num_aux();
+        assert_eq!(aux.len(), num_aux);
+        assert!(inputs
+            .iter()
+            .chain(aux.iter())
+            .all(|inner_input| inner_input.len() == self.uniform_repeat));
+
+        let uniform_constraint_rows = self.uniform_repeat_constraint_rows();
+        // TODO(sragss): Allocation can overshoot by up to a factor of 2; Spartan could handle non-pow-2 Az, Bz, Cz.
+        let constraint_rows = self.constraint_rows();
+        let (mut Az, mut Bz, mut Cz) = (
+            unsafe_allocate_zero_vec(constraint_rows),
+            unsafe_allocate_zero_vec(constraint_rows),
+            unsafe_allocate_zero_vec(constraint_rows),
+        );
+
+        let batch_inputs = |lc: &LC<I>| batch_inputs(lc, inputs, aux);
+
+        // uniform_constraints: Xz[0..uniform_constraint_rows]
+        // TODO(sragss): Attempt moving onto key and computing from materialized rows rather than linear combos
+        let span = tracing::span!(tracing::Level::DEBUG, "compute_constraints");
+        let enter = span.enter();
+        let az_chunks = Az.par_chunks_mut(self.uniform_repeat);
+        let bz_chunks = Bz.par_chunks_mut(self.uniform_repeat);
+        let cz_chunks = Cz.par_chunks_mut(self.uniform_repeat);
+
+        self.uniform_builder
+            .constraints
+            .par_iter()
+            .zip(az_chunks.zip(bz_chunks.zip(cz_chunks)))
+            .for_each(|(constraint, (az_chunk, (bz_chunk, cz_chunk)))| {
+                let a_inputs = batch_inputs(&constraint.a);
+                let b_inputs = batch_inputs(&constraint.b);
+                let c_inputs = batch_inputs(&constraint.c);
+
+                constraint.a.evaluate_batch_mut(&a_inputs, az_chunk);
+                constraint.b.evaluate_batch_mut(&b_inputs, bz_chunk);
+                constraint.c.evaluate_batch_mut(&c_inputs, cz_chunk);
+            });
+        drop(enter);
+
+        // offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + uniform_repeat]
+        // (a - b) * condition == 0
+        // For the final step we do not compute the offset terms and assume the condition to be 0.
+        let span = tracing::span!(tracing::Level::DEBUG, "offset_eq");
+        let _enter = span.enter();
+
+        let constr = &self.offset_equality_constraint;
+        let condition_evals = constr
+            .cond
+            .1
+            .evaluate_batch(&batch_inputs(&constr.cond.1), self.uniform_repeat);
+        let eq_a_evals = constr
+            .a
+            .1
+            .evaluate_batch(&batch_inputs(&constr.a.1), self.uniform_repeat);
+        let eq_b_evals = constr
+            .b
+            .1
+            .evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat);
+
+        let Az_off = Az[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat]
+            .par_iter_mut();
+        let Bz_off = Bz[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat]
+            .par_iter_mut();
+
+        (0..self.uniform_repeat)
+            .into_par_iter()
+            .zip(Az_off.zip(Bz_off))
+            .for_each(|(step_index, (az, bz))| {
+                // Write the corresponding values; if the offset lands outside the step range, only include the constant term.
+                let a_step = step_index + if constr.a.0 { 1 } else { 0 };
+                let b_step = step_index + if constr.b.0 { 1 } else { 0 };
+                let a = eq_a_evals
+                    .get(a_step)
+                    .cloned()
+                    .unwrap_or(constr.a.1.constant_term_field());
+                let b = eq_b_evals
+                    .get(b_step)
+                    .cloned()
+                    .unwrap_or(constr.b.1.constant_term_field());
+                *az = a - b;
+
+                let condition_step = step_index + if constr.cond.0 { 1 } else { 0 };
+                *bz = condition_evals
+                    .get(condition_step)
+                    .cloned()
+                    .unwrap_or(constr.cond.1.constant_term_field());
+            });
+        drop(_enter);
+
+        #[cfg(test)]
+        self.assert_valid(&Az, &Bz, &Cz);
+
+        let num_vars = self.constraint_rows().next_power_of_two().log_2();
+        let (az_poly, (bz_poly, cz_poly)) = rayon::join(
+            || SparsePolynomial::from_dense_evals(num_vars, Az),
+            || {
+                rayon::join(
+                    || SparsePolynomial::from_dense_evals(num_vars, Bz),
+                    || SparsePolynomial::from_dense_evals(num_vars, Cz),
+                )
+            },
+        );
+
+        (az_poly, bz_poly, cz_poly)
+    }
+
     #[cfg(test)]
     pub fn assert_valid(&self, az: &[F], bz: &[F], cz: &[F]) {
         let rows = az.len();
-        let expected_rows = self.constraint_rows().next_power_of_two();
-        assert_eq!(az.len(), expected_rows);
-        assert_eq!(bz.len(), expected_rows);
-        assert_eq!(cz.len(), expected_rows);
+        assert_eq!(bz.len(), rows);
+        assert_eq!(cz.len(), rows);
         for constraint_index in 0..rows {
             if az[constraint_index] * bz[constraint_index] != cz[constraint_index] {
                 let uniform_constraint_index = constraint_index / self.uniform_repeat;
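For orientation: the flattened Xz vectors above are laid out with one contiguous block of `uniform_repeat` rows per uniform constraint, followed by one block for the offset-equality constraint, then zero padding up to the next power of two. A hypothetical helper (`spartan_row` is not part of the PR) that makes the mapping explicit:

    // Row of (constraint, step) in the flattened Az/Bz/Cz vectors. The single
    // offset-equality constraint behaves as constraint_index == num_uniform.
    fn spartan_row(constraint_index: usize, step_index: usize, uniform_repeat: usize) -> usize {
        constraint_index * uniform_repeat + step_index
    }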
diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs
index 6a19d6a41..256ab9b14 100644
--- a/jolt-core/src/r1cs/spartan.rs
+++ b/jolt-core/src/r1cs/spartan.rs
@@ -4,8 +4,9 @@ use crate::field::JoltField;
 use crate::poly::commitment::commitment_scheme::BatchType;
 use crate::poly::commitment::commitment_scheme::CommitmentScheme;
 use crate::r1cs::key::UniformSpartanKey;
-use crate::utils::compute_dotproduct_low_optimized;
+use crate::r1cs::special_polys::SegmentedPaddedWitness;
 use crate::utils::math::Math;
+use crate::utils::profiling;
 use crate::utils::thread::drop_in_background_thread;
 use crate::utils::transcript::ProofTranscript;
@@ -58,87 +59,6 @@ pub enum SpartanError {
     InvalidPCSProof,
 }

-// TODO: Rather than using these ad hoc virtual indexable polys, create a DensePolynomial that takes any impl Index inner
-// and can run all the normal DensePolynomial ops.
-#[derive(Clone)]
-pub struct SegmentedPaddedWitness<F: JoltField> {
-    total_len: usize,
-    segments: Vec<Vec<F>>,
-    segment_len: usize,
-    zero: F,
-}
-
-impl<F: JoltField> SegmentedPaddedWitness<F> {
-    pub fn new(total_len: usize, segments: Vec<Vec<F>>) -> Self {
-        let segment_len = segments[0].len();
-        assert!(segment_len.is_power_of_two());
-        for segment in &segments {
-            assert_eq!(
-                segment.len(),
-                segment_len,
-                "All segments must be the same length"
-            );
-        }
-        SegmentedPaddedWitness {
-            total_len,
-            segments,
-            segment_len,
-            zero: F::zero(),
-        }
-    }
-
-    pub fn len(&self) -> usize {
-        self.total_len
-    }
-
-    #[tracing::instrument(skip_all, name = "SegmentedPaddedWitness::evaluate_all")]
-    pub fn evaluate_all(&self, point: Vec<F>) -> Vec<F> {
-        let chi = EqPolynomial::evals(&point);
-        assert!(chi.len() >= self.segment_len);
-
-        let evals = self
-            .segments
-            .par_iter()
-            .map(|segment| compute_dotproduct_low_optimized(&chi[0..self.segment_len], segment))
-            .collect();
-        drop_in_background_thread(chi);
-        evals
-    }
-
-    pub fn into_dense_polys(self) -> Vec<DensePolynomial<F>> {
-        self.segments
-            .into_iter()
-            .map(|poly| DensePolynomial::new(poly))
-            .collect()
-    }
-}
-
-impl<F: JoltField> std::ops::Index<usize> for SegmentedPaddedWitness<F> {
-    type Output = F;
-
-    fn index(&self, index: usize) -> &Self::Output {
-        if index >= self.segments.len() * self.segment_len {
-            &self.zero
-        } else if index >= self.total_len {
-            panic!("index too high");
-        } else {
-            let segment_index = index / self.segment_len;
-            let inner_index = index % self.segment_len;
-            &self.segments[segment_index][inner_index]
-        }
-    }
-}
-
-impl<F: JoltField> IndexablePoly<F> for SegmentedPaddedWitness<F> {
-    fn len(&self) -> usize {
-        self.total_len
-    }
-}
-
-pub trait IndexablePoly<F: JoltField>: std::ops::Index<usize, Output = F> + Sync {
-    fn len(&self) -> usize;
-}
-
 /// A succinct proof of knowledge of a witness to a relaxed R1CS instance
 /// The proof is produced using Spartan's combination of the sum-check and
 /// the commitment to a vector viewed as a polynomial commitment
@@ -188,38 +108,15 @@ impl<F: JoltField, C: CommitmentScheme<Field = F>> UniformSpartanProof<F, C> {
         let tau = (0..num_rounds_x)
             .map(|_i| transcript.challenge_scalar(b"t"))
             .collect::<Vec<F>>();
+        profiling::print_current_memory_usage("pre_poly_tau");
         let mut poly_tau = DensePolynomial::new(EqPolynomial::evals(&tau));
+        profiling::print_current_memory_usage("post_poly_tau");

         let inputs = &segmented_padded_witness.segments[0..I::COUNT];
         let aux = &segmented_padded_witness.segments[I::COUNT..];
-        let (az, bz, cz) = constraint_builder.compute_spartan_Az_Bz_Cz(inputs, aux);
-        // TODO: Do not require these padded, Sumcheck should handle sparsity.
-        assert!(az.len().is_power_of_two());
-        assert!(bz.len().is_power_of_two());
-        assert!(cz.len().is_power_of_two());
-
-        let mut poly_Az = DensePolynomial::new(az);
-        let mut poly_Bz = DensePolynomial::new(bz);
-        let mut poly_Cz = DensePolynomial::new(cz);
-
-        #[cfg(test)]
-        {
-            // Check that Z is a satisfying assignment
-            for (i, ((az, bz), cz)) in poly_Az
-                .evals_ref()
-                .iter()
-                .zip(poly_Bz.evals_ref())
-                .zip(poly_Cz.evals_ref())
-                .enumerate()
-            {
-                if *az * *bz != *cz {
-                    let padded_segment_len = segmented_padded_witness.segment_len;
-                    let error_segment_index = i / padded_segment_len;
-                    let error_step_index = i % padded_segment_len;
-                    panic!("witness is not a satisfying assignment. Failed on segment {error_segment_index} at step {error_step_index}");
-                }
-            }
-        }
+        profiling::print_current_memory_usage("pre_az_bz_cz");
+        let (mut az, mut bz, mut cz) =
+            constraint_builder.compute_spartan_Az_Bz_Cz_sparse(inputs, aux);
+        profiling::print_current_memory_usage("post_az_bz_cz");

         let comb_func_outer = |A: &F, B: &F, C: &F, D: &F| -> F {
             // Below is an optimized form of: *A * (*B * *C - *D)
             if B.is_zero() || C.is_zero() {
                 if D.is_zero() {
                     F::zero()
                 } else {
                     *A * (-(*D))
                 }
             } else {
-                *A * (*B * *C - *D)
+                let inner = *B * *C - *D;
+                if inner.is_zero() {
+                    F::zero()
+                } else {
+                    *A * inner
+                }
             }
         };
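Because Az, Bz, Cz are now sparse, most (B, C, D) triples passed to this combinator are zero; the branchy form above avoids field multiplications in those common cases while computing the same value. A hedged i64 sketch of the identical control flow:

    // i64 stands in for the field; mirrors the zero-short-circuit structure above.
    fn comb(a: i64, b: i64, c: i64, d: i64) -> i64 {
        if b == 0 || c == 0 {
            if d == 0 { 0 } else { a * (-d) }
        } else {
            let inner = b * c - d;
            if inner == 0 { 0 } else { a * inner }
        }
    }
    // For all inputs, comb(a, b, c, d) == a * (b * c - d).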
+        // profiling::start_memory_tracing_span("outersumcheck");
+        profiling::print_current_memory_usage("pre_outersumcheck");
         let (outer_sumcheck_proof, outer_sumcheck_r, outer_sumcheck_claims) =
             SumcheckInstanceProof::prove_spartan_cubic::<_>(
                 &F::zero(), // claim is zero
                 num_rounds_x,
                 &mut poly_tau,
-                &mut poly_Az,
-                &mut poly_Bz,
-                &mut poly_Cz,
+                &mut az,
+                &mut bz,
+                &mut cz,
                 comb_func_outer,
                 transcript,
             );
+        // Bound from the bottom variable up, so reverse the challenges into evaluation order.
+        let outer_sumcheck_r: Vec<F> = outer_sumcheck_r.into_iter().rev().collect();
-        drop_in_background_thread(poly_Az);
-        drop_in_background_thread(poly_Bz);
-        drop_in_background_thread(poly_Cz);
-        drop_in_background_thread(poly_tau);
+        // drop_in_background_thread((poly_Az, poly_Bz, poly_Cz, poly_tau));
+        drop((az, bz, cz, poly_tau));
+        profiling::print_current_memory_usage("post_outersumcheck");

         // claims from the end of sum-check
         // claim_Az is the (scalar) value v_A = \sum_y A(r_x, y) * z(r_x) where r_x is the sumcheck randomness
@@ -276,8 +180,10 @@ impl<F: JoltField, C: CommitmentScheme<Field = F>> UniformSpartanProof<F, C> {
             .ilog2();
         let (rx_con, rx_ts) =
             outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize);
+        profiling::print_current_memory_usage("pre_poly_ABC");
         let mut poly_ABC =
             DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC));
+        profiling::print_current_memory_usage("post_poly_ABC");

         let (inner_sumcheck_proof, inner_sumcheck_r, _claims_inner) =
             SumcheckInstanceProof::prove_spartan_quadratic::<SegmentedPaddedWitness<F>>(
@@ -287,7 +193,8 @@ impl<F: JoltField, C: CommitmentScheme<Field = F>> UniformSpartanProof<F, C> {
                 &segmented_padded_witness,
                 transcript,
             );
-        drop_in_background_thread(poly_ABC);
+        // drop_in_background_thread(poly_ABC);
+        drop(poly_ABC);

         // Requires 'r_col_segment_bits' to index the (const, segment). Within that segment we index the step using 'r_col_step'.
         let r_col_segment_bits = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1;
@@ -346,6 +253,9 @@
             .verify(F::zero(), num_rounds_x, 3, transcript)
             .map_err(|_| SpartanError::InvalidOuterSumcheckProof)?;

+        // Outer sumcheck is bound from the bottom variable up; reverse the Fiat-Shamir randomness to match.
+        let r_x: Vec<F> = r_x.into_iter().rev().collect();
+
         // verify claim_outer_final
         let (claim_Az, claim_Bz, claim_Cz) = self.outer_sumcheck_claims;
         let taus_bound_rx = EqPolynomial::new(tau).evaluate(&r_x);
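Both reversals above exist because the new sumcheck binds tables starting at the lowest-order variable, while evaluation points are conventionally written highest-order first. A self-contained check of that bookkeeping (i64 stands in for the field; illustration only):

    fn bind_bot(z: &[i64], r: i64) -> Vec<i64> {
        (0..z.len() / 2).map(|i| z[2 * i] + r * (z[2 * i + 1] - z[2 * i])).collect()
    }

    fn main() {
        let z = [3, 5, 7, 11]; // evals indexed by bits (b1, b0) as z[2 * b1 + b0]
        let (r1, r2) = (2, 9);
        let bound = bind_bot(&bind_bot(&z, r1), r2)[0];
        // The MLE evaluated at the *reversed* point (x1, x0) = (r2, r1):
        let mle = |x1: i64, x0: i64| {
            (1 - x1) * (1 - x0) * z[0] + (1 - x1) * x0 * z[1] + x1 * (1 - x0) * z[2] + x1 * x0 * z[3]
        };
        assert_eq!(bound, mle(r2, r1)); // challenges come out low-variable-first
    }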
diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs
index 54dacc491..64371a447 100644
--- a/jolt-core/src/r1cs/special_polys.rs
+++ b/jolt-core/src/r1cs/special_polys.rs
@@ -1,33 +1,50 @@
-use crate::field::JoltField;
+use crate::{
+    field::JoltField,
+    poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial},
+    utils::{
+        compute_dotproduct_low_optimized,
+        math::Math,
+        thread::{drop_in_background_thread, unsafe_allocate_zero_vec},
+    },
+};
+use num_integer::Integer;
 use rayon::prelude::*;

+#[derive(Clone)]
 pub struct SparsePolynomial<F: JoltField> {
     num_vars: usize,
+    Z: Vec<(usize, F)>,
 }

-impl<Scalar: JoltField> SparsePolynomial<Scalar> {
-    pub fn new(num_vars: usize, Z: Vec<(usize, Scalar)>) -> Self {
+impl<F: JoltField> SparsePolynomial<F> {
+    pub fn new(num_vars: usize, Z: Vec<(usize, F)>) -> Self {
         SparsePolynomial { num_vars, Z }
     }

+    // TODO(sragss): rm
+    #[tracing::instrument(skip_all)]
+    pub fn from_dense_evals(num_vars: usize, evals: Vec<F>) -> Self {
+        assert!(num_vars.pow2() >= evals.len());
+        let non_zero_count: usize = evals.par_iter().filter(|f| !f.is_zero()).count();
+        let mut sparse: Vec<(usize, F)> = Vec::with_capacity(non_zero_count);
+        evals.into_iter().enumerate().for_each(|(dense_index, f)| {
+            if !f.is_zero() {
+                sparse.push((dense_index, f));
+            }
+        });
+        Self::new(num_vars, sparse)
+    }
+
     /// Computes the $\tilde{eq}$ extension polynomial.
     /// return 1 when a == r, otherwise return 0.
-    fn compute_chi(a: &[bool], r: &[Scalar]) -> Scalar {
+    fn compute_chi(a: &[bool], r: &[F]) -> F {
         assert_eq!(a.len(), r.len());
-        let mut chi_i = Scalar::one();
+        let mut chi_i = F::one();
         for j in 0..r.len() {
             if a[j] {
                 chi_i *= r[j];
             } else {
-                chi_i *= Scalar::one() - r[j];
+                chi_i *= F::one() - r[j];
             }
         }
         chi_i
     }

     // Takes O(n log n)
-    pub fn evaluate(&self, r: &[Scalar]) -> Scalar {
+    pub fn evaluate(&self, r: &[F]) -> F {
         assert_eq!(self.num_vars, r.len());

         (0..self.Z.len())
@@ -38,6 +55,400 @@ impl<F: JoltField> SparsePolynomial<F> {
             })
             .sum()
     }
+
+    /// Returns n chunks of roughly even size without orphaning siblings (adjacent dense indices).
+    /// Additionally returns a vector of [low, high) dense index ranges.
+    fn chunk_no_orphans(&self, n: usize) -> (Vec<&[(usize, F)]>, Vec<(usize, usize)>) {
+        if self.Z.len() < n * 2 {
+            return (vec![&self.Z], vec![(0, self.num_vars.pow2())]);
+        }
+
+        let target_chunk_size = self.Z.len() / n;
+        let mut chunks: Vec<&[(usize, F)]> = Vec::with_capacity(n);
+        let mut dense_ranges: Vec<(usize, usize)> = Vec::with_capacity(n);
+        let mut dense_start_index = 0;
+        let mut sparse_start_index = 0;
+        let mut sparse_end_index = target_chunk_size;
+
+        for _ in 1..n {
+            let mut dense_end_index = self.Z[sparse_end_index].0;
+            // Never split an (even, odd) sibling pair across a chunk boundary.
+            if dense_end_index % 2 != 0 {
+                dense_end_index += 1;
+                sparse_end_index += 1;
+            }
+            chunks.push(&self.Z[sparse_start_index..sparse_end_index]);
+            dense_ranges.push((dense_start_index, dense_end_index));
+            dense_start_index = dense_end_index;
+
+            sparse_start_index = sparse_end_index;
+            sparse_end_index =
+                std::cmp::min(sparse_end_index + target_chunk_size, self.Z.len() - 1);
+        }
+        chunks.push(&self.Z[sparse_start_index..]);
+        // TODO(sragss): Likely makes more sense to return the full range, then truncate when needed (triple iterator).
+        let highest_non_zero = self.Z.last().map(|&(index, _)| index).unwrap();
+        dense_ranges.push((dense_start_index, highest_non_zero + 1));
+        assert_eq!(chunks.len(), n);
+        assert_eq!(dense_ranges.len(), n);
+
+        // TODO(sragss): To use chunk_no_orphans in the triple iterator, we have to overwrite the top of dense_ranges.
+
+        (chunks, dense_ranges)
+    }
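The cut rule in `chunk_no_orphans` is easiest to see on concrete data. A self-contained sketch of just the boundary adjustment (`first_cut` is illustrative, not part of the PR):

    // Returns (sparse_end, dense_end) for the first chunk boundary.
    fn first_cut(z: &[(usize, u64)], target: usize) -> (usize, usize) {
        let mut sparse_end = target;
        let mut dense_end = z[sparse_end].0;
        if dense_end % 2 != 0 {
            // Never strand a right sibling: move the cut past it.
            dense_end += 1;
            sparse_end += 1;
        }
        (sparse_end, dense_end)
    }

    // first_cut(&[(0, 1), (2, 1), (4, 1), (5, 1), (8, 1), (9, 1)], 3) == (4, 6):
    // dense index 5 is a right sibling, so the pair (4, 5) stays in the first
    // chunk, whose dense range becomes [0, 6).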
+
+    #[tracing::instrument(skip_all)]
+    pub fn bound_poly_var_bot(&mut self, r: &F) {
+        // TODO(sragss): Do this with a scan instead.
+        let n = self.Z.len();
+        let span = tracing::span!(tracing::Level::DEBUG, "allocate");
+        let _enter = span.enter();
+        let mut new_Z: Vec<(usize, F)> = Vec::with_capacity(n);
+        drop(_enter);
+        for (sparse_index, (dense_index, value)) in self.Z.iter().enumerate() {
+            if dense_index.is_even() {
+                let new_dense_index = dense_index / 2;
+                // TODO(sragss): Can likely combine these conditions for better speculative execution.
+                if sparse_index + 1 < self.Z.len() && self.Z[sparse_index + 1].0 == dense_index + 1 {
+                    // Both siblings present: low + r * (high - low)
+                    let upper = self.Z[sparse_index + 1].1;
+                    let eval = *value + *r * (upper - value);
+                    new_Z.push((new_dense_index, eval));
+                } else {
+                    // Odd sibling is an implicit zero: (1 - r) * low
+                    new_Z.push((new_dense_index, (F::one() - r) * value));
+                }
+            } else if sparse_index > 0 && self.Z[sparse_index - 1].0 == dense_index - 1 {
+                // Already merged with its even sibling in the previous iteration.
+                continue;
+            } else {
+                // Even sibling is an implicit zero: r * high
+                let new_dense_index = (dense_index - 1) / 2;
+                new_Z.push((new_dense_index, *r * value));
+            }
+        }
+        self.Z = new_Z;
+        self.num_vars -= 1;
+    }
+
+    #[tracing::instrument(skip_all)]
+    pub fn bound_poly_var_bot_par(&mut self, r: &F) {
+        // TODO(sragss): Do this with a scan instead.
+        let (chunks, _ranges) = self.chunk_no_orphans(rayon::current_num_threads() * 8);
+        // TODO(sragss): We can scan up front and collect directly into the result.
+        let new_Z: Vec<(usize, F)> = chunks
+            .into_par_iter()
+            .map(|chunk| {
+                let mut chunk_Z: Vec<(usize, F)> = Vec::with_capacity(chunk.len());
+                for (sparse_index, (dense_index, value)) in chunk.iter().enumerate() {
+                    if dense_index.is_even() {
+                        let new_dense_index = dense_index / 2;
+                        // TODO(sragss): Can likely combine these conditions for better speculative execution.
+                        // Sibling lookups stay within the chunk: chunk_no_orphans
+                        // guarantees sibling pairs are never split across chunks.
+                        if sparse_index + 1 < chunk.len() && chunk[sparse_index + 1].0 == dense_index + 1 {
+                            let upper = chunk[sparse_index + 1].1;
+                            let eval = *value + *r * (upper - value);
+                            chunk_Z.push((new_dense_index, eval));
+                        } else {
+                            chunk_Z.push((new_dense_index, (F::one() - r) * value));
+                        }
+                    } else if sparse_index > 0 && chunk[sparse_index - 1].0 == dense_index - 1 {
+                        continue;
+                    } else {
+                        let new_dense_index = (dense_index - 1) / 2;
+                        chunk_Z.push((new_dense_index, *r * value));
+                    }
+                }
+
+                chunk_Z
+            })
+            .flatten()
+            .collect();
+
+        self.Z = new_Z;
+        self.num_vars -= 1;
+    }
+
+    pub fn final_eval(&self) -> F {
+        assert_eq!(self.num_vars, 0);
+        if self.Z.is_empty() {
+            F::zero()
+        } else {
+            assert_eq!(self.Z.len(), 1);
+            let item = self.Z[0];
+            assert_eq!(item.0, 0);
+            item.1
+        }
+    }
+
+    #[cfg(test)]
+    #[tracing::instrument(skip_all)]
+    pub fn to_dense(self) -> DensePolynomial<F> {
+        use crate::utils::math::Math;
+
+        let mut evals = unsafe_allocate_zero_vec(self.num_vars.pow2());
+
+        for (index, value) in self.Z {
+            evals[index] = value;
+        }
+
+        DensePolynomial::new(evals)
+    }
+}
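The sparse binding above merges each (even, odd) sibling pair into one entry, treating a missing sibling as an implicit zero: both present gives low + r * (high - low), a lone even entry gives (1 - r) * low, and a lone odd entry gives r * high. A minimal sequential sketch of the same merge rule on i64 stand-ins:

    fn bound_bot_sparse(z: &[(usize, i64)], r: i64) -> Vec<(usize, i64)> {
        let mut out = Vec::new();
        let mut i = 0;
        while i < z.len() {
            let (d, v) = z[i];
            if d % 2 == 0 {
                if i + 1 < z.len() && z[i + 1].0 == d + 1 {
                    out.push((d / 2, v + r * (z[i + 1].1 - v))); // both siblings present
                    i += 2;
                } else {
                    out.push((d / 2, (1 - r) * v)); // odd sibling is an implicit zero
                    i += 1;
                }
            } else {
                out.push(((d - 1) / 2, r * v)); // even sibling is an implicit zero
                i += 1;
            }
        }
        out
    }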
+
+pub struct SparseTripleIterator<'a, F: JoltField> {
+    dense_index: usize,
+    end_index: usize,
+    a: &'a [(usize, F)],
+    b: &'a [(usize, F)],
+    c: &'a [(usize, F)],
+}
+
+impl<'a, F: JoltField> SparseTripleIterator<'a, F> {
+    #[tracing::instrument(skip_all)]
+    pub fn chunks(
+        a: &'a SparsePolynomial<F>,
+        b: &'a SparsePolynomial<F>,
+        c: &'a SparsePolynomial<F>,
+        n: usize,
+    ) -> Vec<Self> {
+        // When the instance is small enough, don't worry about parallelism
+        let total_len = a.num_vars.pow2();
+        if n * 2 > b.Z.len() {
+            return vec![SparseTripleIterator {
+                dense_index: 0,
+                end_index: total_len,
+                a: &a.Z,
+                b: &b.Z,
+                c: &c.Z,
+            }];
+        }
+        // Can be made more generic, but this is an optimization / simplification.
+        assert!(
+            b.Z.len() >= a.Z.len() && b.Z.len() >= c.Z.len(),
+            "b.Z.len() assumed to be longest of a, b, and c"
+        );
+
+        // Strategy: chunk b (the densest input) into roughly even pieces without
+        // splitting sibling pairs, then align a and c slices to the same dense
+        // index ranges so each iterator can stream all three in lockstep.
+        let target_chunk_size = b.Z.len() / n;
+        let mut b_chunks: Vec<&[(usize, F)]> = Vec::with_capacity(n);
+        let mut dense_ranges: Vec<(usize, usize)> = Vec::with_capacity(n);
+        let mut dense_start_index = 0;
+        let mut sparse_start_index = 0;
+        let mut sparse_end_index = target_chunk_size;
+
+        for _ in 1..n {
+            let mut dense_end_index = b.Z[sparse_end_index].0;
+            if dense_end_index % 2 != 0 {
+                dense_end_index += 1;
+                sparse_end_index += 1;
+            }
+            b_chunks.push(&b.Z[sparse_start_index..sparse_end_index]);
+            dense_ranges.push((dense_start_index, dense_end_index));
+            dense_start_index = dense_end_index;
+
+            sparse_start_index = sparse_end_index;
+            sparse_end_index = std::cmp::min(sparse_end_index + target_chunk_size, b.Z.len() - 1);
+        }
+        b_chunks.push(&b.Z[sparse_start_index..]);
+
+        let highest_non_zero = {
+            let a_last = a.Z.last().map(|&(index, _)| index);
+            let b_last = b.Z.last().map(|&(index, _)| index);
+            let c_last = c.Z.last().map(|&(index, _)| index);
+            *a_last.iter().chain(b_last.iter()).chain(c_last.iter()).max().unwrap()
+        };
+        dense_ranges.push((dense_start_index, highest_non_zero + 1));
+        assert_eq!(b_chunks.len(), n);
+        assert_eq!(dense_ranges.len(), n);
+
+        // Create a and c chunks which overlap with b's dense index ranges
+        let mut a_chunks: Vec<&[(usize, F)]> = vec![&[]; n];
+        let mut c_chunks: Vec<&[(usize, F)]> = vec![&[]; n];
+        let mut a_i = 0;
+        let mut c_i = 0;
+        let span = tracing::span!(tracing::Level::DEBUG, "a, c scanning");
+        let _enter = span.enter();
+        for (chunk_index, range) in dense_ranges.iter().enumerate().skip(1) {
+            // Find the a and c entries that fall before this chunk's dense start
+            let prev_chunk_end = range.0;
+
+            if a_i < a.Z.len() && a.Z[a_i].0 < prev_chunk_end {
+                let a_start = a_i;
+                while a_i < a.Z.len() && a.Z[a_i].0 < prev_chunk_end {
+                    a_i += 1;
+                }
+                a_chunks[chunk_index - 1] = &a.Z[a_start..a_i];
+            }
+
+            if c_i < c.Z.len() && c.Z[c_i].0 < prev_chunk_end {
+                let c_start = c_i;
+                while c_i < c.Z.len() && c.Z[c_i].0 < prev_chunk_end {
+                    c_i += 1;
+                }
+                c_chunks[chunk_index - 1] = &c.Z[c_start..c_i];
+            }
+        }
+        drop(_enter);
+        a_chunks[n - 1] = &a.Z[a_i..];
+        c_chunks[n - 1] = &c.Z[c_i..];
+
+        #[cfg(test)]
+        {
+            assert_eq!(a_chunks.concat(), a.Z);
+            assert_eq!(b_chunks.concat(), b.Z);
+            assert_eq!(c_chunks.concat(), c.Z);
+        }
+
+        let mut iterators: Vec<SparseTripleIterator<'a, F>> = Vec::with_capacity(n);
+        for (((a_chunk, b_chunk), c_chunk), range) in a_chunks
+            .iter()
+            .zip(b_chunks.iter())
+            .zip(c_chunks.iter())
+            .zip(dense_ranges.iter())
+        {
+            #[cfg(test)]
+            {
+                assert!(a_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1));
+                assert!(b_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1));
+                assert!(c_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1));
+            }
+            let iter = SparseTripleIterator {
+                dense_index: range.0,
+                end_index: range.1,
+                a: a_chunk,
+                b: b_chunk,
+                c: c_chunk,
+            };
+            iterators.push(iter);
+        }
+
+        iterators
+    }
+
+    pub fn has_next(&self) -> bool {
+        self.dense_index < self.end_index
+    }
+
+    pub fn next_pairs(&mut self) -> (usize, F, F, F, F, F, F) {
+        // TODO(sragss): We can store a map of big ranges of zeros and skip them rather than hitting each dense index.
+        let low_index = self.dense_index;
+        let match_and_advance = |slice: &mut &[(usize, F)], index: usize| -> F {
+            if let Some(first_item) = slice.first() {
+                if first_item.0 == index {
+                    let ret = first_item.1;
+                    *slice = &slice[1..];
+                    ret
+                } else {
+                    F::zero()
+                }
+            } else {
+                F::zero()
+            }
+        };
+
+        let a_lower_val = match_and_advance(&mut self.a, self.dense_index);
+        let b_lower_val = match_and_advance(&mut self.b, self.dense_index);
+        let c_lower_val = match_and_advance(&mut self.c, self.dense_index);
+        self.dense_index += 1;
+        let a_upper_val = match_and_advance(&mut self.a, self.dense_index);
+        let b_upper_val = match_and_advance(&mut self.b, self.dense_index);
+        let c_upper_val = match_and_advance(&mut self.c, self.dense_index);
+        self.dense_index += 1;
+
+        (
+            low_index,
+            a_lower_val,
+            a_upper_val,
+            b_lower_val,
+            b_upper_val,
+            c_lower_val,
+            c_upper_val,
+        )
+    }
+}
+
+pub trait IndexablePoly<F: JoltField>: std::ops::Index<usize, Output = F> + Sync {
+    fn len(&self) -> usize;
+}
+
+impl<F: JoltField> IndexablePoly<F> for DensePolynomial<F> {
+    fn len(&self) -> usize {
+        self.Z.len()
+    }
+}
+
+// TODO: Rather than using these ad hoc virtual indexable polys, create a DensePolynomial that takes any impl Index inner
+// and can run all the normal DensePolynomial ops.
+#[derive(Clone)]
+pub struct SegmentedPaddedWitness<F: JoltField> {
+    total_len: usize,
+    pub segments: Vec<Vec<F>>,
+    pub segment_len: usize,
+    zero: F,
+}
+
+impl<F: JoltField> SegmentedPaddedWitness<F> {
+    pub fn new(total_len: usize, segments: Vec<Vec<F>>) -> Self {
+        let segment_len = segments[0].len();
+        assert!(segment_len.is_power_of_two());
+        for segment in &segments {
+            assert_eq!(
+                segment.len(),
+                segment_len,
+                "All segments must be the same length"
+            );
+        }
+        SegmentedPaddedWitness {
+            total_len,
+            segments,
+            segment_len,
+            zero: F::zero(),
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.total_len
+    }
+
+    #[tracing::instrument(skip_all, name = "SegmentedPaddedWitness::evaluate_all")]
+    pub fn evaluate_all(&self, point: Vec<F>) -> Vec<F> {
+        let chi = EqPolynomial::evals(&point);
+        assert!(chi.len() >= self.segment_len);
+
+        let evals = self
+            .segments
+            .par_iter()
+            .map(|segment| compute_dotproduct_low_optimized(&chi[0..self.segment_len], segment))
+            .collect();
+        drop_in_background_thread(chi);
+        evals
+    }
+
+    pub fn into_dense_polys(self) -> Vec<DensePolynomial<F>> {
+        self.segments
+            .into_iter()
+            .map(|poly| DensePolynomial::new(poly))
+            .collect()
+    }
+}
+
+impl<F: JoltField> std::ops::Index<usize> for SegmentedPaddedWitness<F> {
+    type Output = F;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        if index >= self.segments.len() * self.segment_len {
+            &self.zero
+        } else if index >= self.total_len {
+            panic!("index too high");
+        } else {
+            let segment_index = index / self.segment_len;
+            let inner_index = index % self.segment_len;
+            &self.segments[segment_index][inner_index]
+        }
+    }
+}
+
+impl<F: JoltField> IndexablePoly<F> for SegmentedPaddedWitness<F> {
+    fn len(&self) -> usize {
+        self.total_len
+    }
+}

 /// Returns the `num_bits` from n in a canonical order
@@ -73,3 +484,131 @@ pub fn eq_plus_one<F: JoltField>(x: &[F], y: &[F], l: usize) -> F {
         })
         .sum()
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ark_bn254::Fr;
+    use ark_std::Zero;
+
+    #[test]
+    fn sparse_bound_bot_all_left() {
+        let dense_evals = vec![Fr::from(10), Fr::zero(), Fr::from(20), Fr::zero()];
+        let sparse_evals = vec![(0, Fr::from(10)), (2, Fr::from(20))];
+
+        let mut dense = DensePolynomial::new(dense_evals);
+        let mut sparse = SparsePolynomial::new(2, sparse_evals);
+
+        assert_eq!(sparse.clone().to_dense(), dense);
+
+        let r = Fr::from(121);
+        sparse.bound_poly_var_bot(&r);
+        dense.bound_poly_var_bot(&r);
+        assert_eq!(sparse.to_dense(), dense);
+    }
+
+    #[test]
+    fn sparse_bound_bot_all_right() {
+        let dense_evals = vec![Fr::zero(), Fr::from(10), Fr::zero(), Fr::from(20)];
+        let sparse_evals = vec![(1, Fr::from(10)), (3, Fr::from(20))];
+
+        let mut dense = DensePolynomial::new(dense_evals);
+        let mut sparse = SparsePolynomial::new(2, sparse_evals);
+
+        assert_eq!(sparse.clone().to_dense(), dense);
+
+        let r = Fr::from(121);
+        sparse.bound_poly_var_bot(&r);
+        dense.bound_poly_var_bot(&r);
+        assert_eq!(sparse.to_dense(), dense);
+    }
+
+    #[test]
+    fn sparse_bound_bot_mixed() {
+        let dense_evals = vec![
+            Fr::zero(),
+            Fr::from(10),
+            Fr::zero(),
+            Fr::from(20),
+            Fr::from(30),
+            Fr::from(40),
+            Fr::zero(),
+            Fr::from(50),
+        ];
+        let sparse_evals = vec![
+            (1, Fr::from(10)),
+            (3, Fr::from(20)),
+            (4, Fr::from(30)),
+            (5, Fr::from(40)),
+            (7, Fr::from(50)),
+        ];
+
+        let mut dense = DensePolynomial::new(dense_evals);
+        let mut sparse = SparsePolynomial::new(3, sparse_evals);
+
+        assert_eq!(sparse.clone().to_dense(), dense);
+
+        let r = Fr::from(121);
+        sparse.bound_poly_var_bot(&r);
+        dense.bound_poly_var_bot(&r);
+        assert_eq!(sparse.to_dense(), dense);
+    }
+
+    #[test]
+    fn sparse_triple_iterator() {
+        let a = vec![(9, Fr::from(9)), (10, Fr::from(10)), (12, Fr::from(12))];
+        let b = vec![
+            (0, Fr::from(100)),
+            (1, Fr::from(1)),
+            (2, Fr::from(2)),
+            (3, Fr::from(3)),
+            (4, Fr::from(4)),
+            (5, Fr::from(5)),
+            (6, Fr::from(6)),
+            (7, Fr::from(7)),
+            (8, Fr::from(8)),
+            (9, Fr::from(9)),
+            (10, Fr::from(10)),
+            (11, Fr::from(11)),
+            (12, Fr::from(12)),
+            (13, Fr::from(13)),
+            (14, Fr::from(14)),
+            (15, Fr::from(15)),
+        ];
+        let c = vec![(0, Fr::from(12)), (3, Fr::from(3))];
+
+        let a_poly = SparsePolynomial::new(4, a);
+        let b_poly = SparsePolynomial::new(4, b);
+        let c_poly = SparsePolynomial::new(4, c);
+
+        let iterators = SparseTripleIterator::chunks(&a_poly, &b_poly, &c_poly, 4);
+        assert_eq!(iterators.len(), 4);
+    }
+
+    #[test]
+    fn sparse_triple_iterator_random() {
+        use rand::Rng;
+
+        let mut rng = rand::thread_rng();
+
+        let prob_exists = 0.32;
+        let num_vars = 10;
+        let total_len = 1 << num_vars;
+
+        let mut a = vec![];
+        let mut b = vec![];
+        let mut c = vec![];
+
+        for i in 0usize..total_len {
+            if rng.gen::<f64>() < prob_exists {
+                a.push((i, Fr::from(i as u64)));
+            }
+            if rng.gen::<f64>() < prob_exists * 2f64 {
+                b.push((i, Fr::from(i as u64)));
+            }
+            if rng.gen::<f64>() < prob_exists {
+                c.push((i, Fr::from(i as u64)));
+            }
+        }
+
+        let a_poly = SparsePolynomial::new(num_vars, a);
+        let b_poly = SparsePolynomial::new(num_vars, b);
+        let c_poly = SparsePolynomial::new(num_vars, c);
+
+        let mut iterators = SparseTripleIterator::chunks(&a_poly, &b_poly, &c_poly, 8);
+
+        let mut new_a = vec![Fr::zero(); total_len];
+        let mut new_b = vec![Fr::zero(); total_len];
+        let mut new_c = vec![Fr::zero(); total_len];
+        let mut expected_dense_index = 0;
+        for iterator in iterators.iter_mut() {
+            while iterator.has_next() {
+                let (dense_index, a_low, a_high, b_low, b_high, c_low, c_high) =
+                    iterator.next_pairs();
+
+                new_a[dense_index] = a_low;
+                new_a[dense_index + 1] = a_high;
+
+                new_b[dense_index] = b_low;
+                new_b[dense_index + 1] = b_high;
+
+                new_c[dense_index] = c_low;
+                new_c[dense_index + 1] = c_high;
+
+                assert_eq!(dense_index, expected_dense_index);
+                expected_dense_index += 2;
+            }
+        }
+
+        assert_eq!(a_poly.to_dense().Z, new_a);
+        assert_eq!(b_poly.to_dense().Z, new_b);
+        assert_eq!(c_poly.to_dense().Z, new_c);
+    }
+}
\ No newline at end of file
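The triple iterator always advances two dense indices at a time, filling implicit zeros for whichever of a, b, c has no entry at the current index, so the cubic sumcheck below can consume aligned sibling pairs without materializing dense vectors. A quick illustration of `next_pairs` (placeholder values, not from the tests above):

    // Streaming Z_a = [(2, x)], Z_b = [(0, y), (3, w)], Z_c = [] over dense range [0, 4):
    //   next_pairs() -> (0, 0, 0, y, 0, 0, 0)  // pair (0, 1): only b has a low entry
    //   next_pairs() -> (2, x, 0, 0, w, 0, 0)  // pair (2, 3): a has a low entry, b a high one
    // Tuple layout: (low_index, a_low, a_high, b_low, b_high, c_low, c_high).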
diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs
index 0b841de14..abf6ea7eb 100644
--- a/jolt-core/src/subprotocols/sumcheck.rs
+++ b/jolt-core/src/subprotocols/sumcheck.rs
@@ -4,7 +4,7 @@ use crate::field::JoltField;
 use crate::poly::dense_mlpoly::DensePolynomial;
 use crate::poly::unipoly::{CompressedUniPoly, UniPoly};
-use crate::r1cs::spartan::IndexablePoly;
+use crate::r1cs::special_polys::{IndexablePoly, SparsePolynomial, SparseTripleIterator};
 use crate::utils::errors::ProofVerifyError;
 use crate::utils::mul_0_optimized;
 use crate::utils::thread::drop_in_background_thread;
@@ -179,51 +179,65 @@ impl<F: JoltField> SumcheckInstanceProof<F> {
         skip_all,
         name = "Spartan2::sumcheck::compute_eval_points_spartan_cubic"
     )]
+    /// Binds from the bottom rather than the top.
    pub fn compute_eval_points_spartan_cubic<Func>(
-        poly_A: &DensePolynomial<F>,
-        poly_B: &DensePolynomial<F>,
-        poly_C: &DensePolynomial<F>,
-        poly_D: &DensePolynomial<F>,
+        poly_eq: &DensePolynomial<F>,
+        poly_A: &SparsePolynomial<F>,
+        poly_B: &SparsePolynomial<F>,
+        poly_C: &SparsePolynomial<F>,
         comb_func: &Func,
     ) -> (F, F, F)
     where
         Func: Fn(&F, &F, &F, &F) -> F + Sync,
     {
-        let len = poly_A.len() / 2;
-        (0..len)
-            .into_par_iter()
-            .map(|i| {
-                // eval 0: bound_func is A(low)
-                let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]);
-
-                let m_A = poly_A[len + i] - poly_A[i];
-                let m_B = poly_B[len + i] - poly_B[i];
-                let m_C = poly_C[len + i] - poly_C[i];
-                let m_D = poly_D[len + i] - poly_D[i];
-
-                // eval 2: bound_func is -A(low) + 2*A(high)
-                let poly_A_bound_point = poly_A[len + i] + m_A;
-                let poly_B_bound_point = poly_B[len + i] + m_B;
-                let poly_C_bound_point = poly_C[len + i] + m_C;
-                let poly_D_bound_point = poly_D[len + i] + m_D;
-                let eval_point_2 = comb_func(
-                    &poly_A_bound_point,
-                    &poly_B_bound_point,
-                    &poly_C_bound_point,
-                    &poly_D_bound_point,
-                );
-
-                // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
-                let poly_A_bound_point = poly_A_bound_point + m_A;
-                let poly_B_bound_point = poly_B_bound_point + m_B;
-                let poly_C_bound_point = poly_C_bound_point + m_C;
-                let poly_D_bound_point = poly_D_bound_point + m_D;
-                let eval_point_3 = comb_func(
-                    &poly_A_bound_point,
-                    &poly_B_bound_point,
-                    &poly_C_bound_point,
-                    &poly_D_bound_point,
-                );
+        // num_threads * 16 enables better work stealing
+        let mut iterators =
+            SparseTripleIterator::chunks(poly_A, poly_B, poly_C, rayon::current_num_threads() * 16);
+        iterators
+            .par_iter_mut()
+            .map(|iterator| {
+                let span = tracing::span!(tracing::Level::DEBUG, "eval_par_inner");
+                let _enter = span.enter();
+                let mut eval_point_0 = F::zero();
+                let mut eval_point_2 = F::zero();
+                let mut eval_point_3 = F::zero();
+                while iterator.has_next() {
+                    let (dense_index, a_low, a_high, b_low, b_high, c_low, c_high) =
+                        iterator.next_pairs();
+                    assert!(dense_index % 2 == 0);
+
+                    // eval 0: bound_func is A(low)
+                    eval_point_0 += comb_func(&poly_eq[dense_index], &a_low, &b_low, &c_low);
+
+                    let m_eq = poly_eq[dense_index + 1] - poly_eq[dense_index];
+                    let m_A = a_high - a_low;
+                    let m_B = b_high - b_low;
+                    let m_C = c_high - c_low;
+
+                    // eval 2: bound_func is -A(low) + 2*A(high), i.e. high + m
+                    let poly_A_bound_point = poly_eq[dense_index + 1] + m_eq;
+                    let poly_B_bound_point = a_high + m_A;
+                    let poly_C_bound_point = b_high + m_B;
+                    let poly_D_bound_point = c_high + m_C;
+                    eval_point_2 += comb_func(
+                        &poly_A_bound_point,
+                        &poly_B_bound_point,
+                        &poly_C_bound_point,
+                        &poly_D_bound_point,
+                    );
+
+                    // eval 3: computed incrementally from eval(2) by adding the slopes again
+                    let poly_A_bound_point = poly_A_bound_point + m_eq;
+                    let poly_B_bound_point = poly_B_bound_point + m_A;
+                    let poly_C_bound_point = poly_C_bound_point + m_B;
+                    let poly_D_bound_point = poly_D_bound_point + m_C;
+                    eval_point_3 += comb_func(
+                        &poly_A_bound_point,
+                        &poly_B_bound_point,
+                        &poly_C_bound_point,
+                        &poly_D_bound_point,
+                    );
+                }
                 (eval_point_0, eval_point_2, eval_point_3)
             })
             .reduce(
@@ -236,10 +250,10 @@ impl<F: JoltField> SumcheckInstanceProof<F> {
     pub fn prove_spartan_cubic<Func>(
         claim: &F,
         num_rounds: usize,
-        poly_A: &mut DensePolynomial<F>,
-        poly_B: &mut DensePolynomial<F>,
-        poly_C: &mut DensePolynomial<F>,
-        poly_D: &mut DensePolynomial<F>,
+        poly_eq: &mut DensePolynomial<F>,
+        poly_A: &mut SparsePolynomial<F>,
+        poly_B: &mut SparsePolynomial<F>,
+        poly_C: &mut SparsePolynomial<F>,
         comb_func: Func,
         transcript: &mut ProofTranscript,
     ) -> (Self, Vec<F>, Vec<F>)
@@ -255,7 +269,7 @@ impl<F: JoltField> SumcheckInstanceProof<F> {
             // Make an iterator returning the contributions to the evaluations
             let (eval_point_0, eval_point_2, eval_point_3) =
                 Self::compute_eval_points_spartan_cubic(
-                    poly_A, poly_B, poly_C, poly_D, &comb_func,
+                    poly_eq, poly_A, poly_B, poly_C, &comb_func,
                 );

             let evals = [
@@ -264,6 +278,7 @@ impl<F: JoltField> SumcheckInstanceProof<F> {
                 eval_point_2,
                 eval_point_3,
             ];
+
             UniPoly::from_evals(&evals)
         };
@@ -280,14 +295,14 @@ impl<F: JoltField> SumcheckInstanceProof<F> {
             // bound all tables to the verifier's challenge
             rayon::join(
-                || poly_A.bound_poly_var_top_par(&r_i),
+                || poly_eq.bound_poly_var_bot(&r_i),
                 || {
                     rayon::join(
-                        || poly_B.bound_poly_var_top_zero_optimized(&r_i),
+                        || poly_A.bound_poly_var_bot(&r_i),
                         || {
                             rayon::join(
-                                || poly_C.bound_poly_var_top_zero_optimized(&r_i),
-                                || poly_D.bound_poly_var_top_zero_optimized(&r_i),
+                                || poly_B.bound_poly_var_bot(&r_i),
+                                || poly_C.bound_poly_var_bot(&r_i),
                             )
                         },
                     )
                 },
             );
@@ -298,7 +313,12 @@ impl<F: JoltField> SumcheckInstanceProof<F> {
         (
             SumcheckInstanceProof::new(polys),
             r,
-            vec![poly_A[0], poly_B[0], poly_C[0], poly_D[0]],
+            vec![
+                poly_eq[0],
+                poly_A.final_eval(),
+                poly_B.final_eval(),
+                poly_C.final_eval(),
+            ],
         )
     }
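Each round of this cubic sumcheck needs the round polynomial s(t) = sum over pairs of eq(t) * (A(t) * B(t) - C(t)) at t = 0, 2, 3 (the evaluation at t = 1 is implied by the running claim). On each sibling pair every table is linear in t, P(t) = low + t * (high - low), so P(2) = high + m and P(3) = P(2) + m with slope m = high - low, which is exactly the incremental update in the loop above. A self-contained sketch with i64 stand-ins for field elements (illustration only):

    // One sibling pair's contribution to the round polynomial at t = 0, 2, 3.
    fn round_evals(eq: (i64, i64), a: (i64, i64), b: (i64, i64), c: (i64, i64)) -> (i64, i64, i64) {
        // Each table is linear on the pair: P(t) = low + t * (high - low).
        let at = |p: (i64, i64), t: i64| p.0 + t * (p.1 - p.0);
        let comb = |e: i64, x: i64, y: i64, z: i64| e * (x * y - z); // eq * (A * B - C)
        (
            comb(at(eq, 0), at(a, 0), at(b, 0), at(c, 0)),
            comb(at(eq, 2), at(a, 2), at(b, 2), at(c, 2)),
            comb(at(eq, 3), at(a, 3), at(b, 3), at(c, 3)),
        )
    }

Computing eval 3 by reusing the bound points from eval 2 replaces four fresh interpolations with four additions per pair, which is the design choice the production loop makes.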