From 2e13acdb1b1e6a5ccdf7a87749447a6283da57a3 Mon Sep 17 00:00:00 2001 From: sragss Date: Fri, 21 Jun 2024 19:30:35 -0700 Subject: [PATCH 01/17] init working sparsity --- jolt-core/src/jolt/vm/mod.rs | 6 +- jolt-core/src/poly/dense_mlpoly.rs | 14 +- jolt-core/src/r1cs/builder.rs | 127 +++++- jolt-core/src/r1cs/spartan.rs | 148 ++----- jolt-core/src/r1cs/special_polys.rs | 553 ++++++++++++++++++++++++- jolt-core/src/subprotocols/sumcheck.rs | 116 +++--- 6 files changed, 778 insertions(+), 186 deletions(-) diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index fb3b5e55d..06f63c004 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -4,6 +4,7 @@ use crate::field::JoltField; use crate::r1cs::builder::CombinedUniformBuilder; use crate::r1cs::jolt_constraints::{construct_jolt_constraints, JoltIn}; use crate::r1cs::spartan::{self, UniformSpartanProof}; +use crate::utils::profiling; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::log2; use common::constants::RAM_START_ADDRESS; @@ -402,8 +403,10 @@ pub trait Jolt, const C: usize, c &mut transcript, ); - drop_in_background_thread(jolt_polynomials); + // drop_in_background_thread(jolt_polynomials); + drop(jolt_polynomials); + profiling::print_current_memory_usage("pre_spartan"); let spartan_proof = UniformSpartanProof::::prove_precommitted( &preprocessing.generators, r1cs_builder, @@ -412,6 +415,7 @@ pub trait Jolt, const C: usize, c &mut transcript, ) .expect("r1cs proof failed"); + profiling::print_current_memory_usage("post_spartan"); let r1cs_proof = R1CSProof { key: spartan_key, proof: spartan_proof, diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs index 992b7fe69..6eb701531 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -1,6 +1,6 @@ #![allow(clippy::too_many_arguments)] use crate::poly::eq_poly::EqPolynomial; -use crate::utils::thread::unsafe_allocate_zero_vec; 
+use crate::utils::thread::{drop_in_background_thread, unsafe_allocate_zero_vec}; use crate::utils::{self, compute_dotproduct, compute_dotproduct_low_optimized}; use crate::field::JoltField; @@ -201,11 +201,17 @@ impl DensePolynomial { } } + #[tracing::instrument(skip_all)] pub fn bound_poly_var_bot(&mut self, r: &F) { let n = self.len() / 2; - for i in 0..n { - self.Z[i] = self.Z[2 * i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]); - } + let mut new_z = unsafe_allocate_zero_vec(n); + new_z.par_iter_mut().enumerate().for_each(|(i, z)| { + *z = self.Z[2*i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]); + }); + + let old_Z = std::mem::replace(&mut self.Z, new_z); + drop_in_background_thread(old_Z); + self.num_vars -= 1; self.len = n; } diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index d0c1c409e..e16bd7f7f 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -2,8 +2,7 @@ use crate::{ field::{JoltField, OptimizedMul}, r1cs::key::{SparseConstraints, UniformR1CS}, utils::{ - mul_0_1_optimized, - thread::{drop_in_background_thread, unsafe_allocate_zero_vec}, + math::Math, mul_0_1_optimized, thread::{drop_in_background_thread, unsafe_allocate_zero_vec} }, }; #[allow(unused_imports)] // clippy thinks these aren't needed lol @@ -13,7 +12,7 @@ use std::{collections::HashMap, fmt::Debug}; use super::{ key::{NonUniformR1CS, SparseEqualityItem}, - ops::{ConstraintInput, Term, Variable, LC}, + ops::{ConstraintInput, Term, Variable, LC}, special_polys::SparsePolynomial, }; pub trait R1CSConstraintBuilder { @@ -848,13 +847,127 @@ impl CombinedUniformBuilder { (Az, Bz, Cz) } + /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]] + /// aux should be of the format [[Aux(0), Aux(0), ...], ... 
[Aux(self.next_aux - 1), ...]] + #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan")] + pub fn compute_spartan_Az_Bz_Cz_sparse( + &self, + inputs: &[Vec], + aux: &[Vec], + ) -> (SparsePolynomial, SparsePolynomial, SparsePolynomial) { + assert_eq!(inputs.len(), I::COUNT); + let num_aux = self.uniform_builder.num_aux(); + assert_eq!(aux.len(), num_aux); + assert!(inputs + .iter() + .chain(aux.iter()) + .all(|inner_input| inner_input.len() == self.uniform_repeat)); + + let uniform_constraint_rows = self.uniform_repeat_constraint_rows(); + // TODO(sragss): Allocation can overshoot by up to a factor of 2, Spartan could handle non-pow-2 Az,Bz,Cz + let constraint_rows = self.constraint_rows(); + let (mut Az, mut Bz, mut Cz) = ( + unsafe_allocate_zero_vec(constraint_rows), + unsafe_allocate_zero_vec(constraint_rows), + unsafe_allocate_zero_vec(constraint_rows), + ); + + let batch_inputs = |lc: &LC| batch_inputs(lc, inputs, aux); + + // uniform_constraints: Xz[0..uniform_constraint_rows] + // TODO(sragss): Attempt moving onto key and computing from materialized rows rather than linear combos + let span = tracing::span!(tracing::Level::DEBUG, "compute_constraints"); + let enter = span.enter(); + let az_chunks = Az.par_chunks_mut(self.uniform_repeat); + let bz_chunks = Bz.par_chunks_mut(self.uniform_repeat); + let cz_chunks = Cz.par_chunks_mut(self.uniform_repeat); + + self.uniform_builder + .constraints + .par_iter() + .zip(az_chunks.zip(bz_chunks.zip(cz_chunks))) + .for_each(|(constraint, (az_chunk, (bz_chunk, cz_chunk)))| { + let a_inputs = batch_inputs(&constraint.a); + let b_inputs = batch_inputs(&constraint.b); + let c_inputs = batch_inputs(&constraint.c); + + constraint.a.evaluate_batch_mut(&a_inputs, az_chunk); + constraint.b.evaluate_batch_mut(&b_inputs, bz_chunk); + constraint.c.evaluate_batch_mut(&c_inputs, cz_chunk); + }); + drop(enter); + + // offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + 1] + // 
(a - b) * condition == 0 + // For the final step we will not compute the offset terms, and will assume the condition to be set to 0 + let span = tracing::span!(tracing::Level::DEBUG, "offset_eq"); + let _enter = span.enter(); + + let constr = &self.offset_equality_constraint; + let condition_evals = constr + .cond + .1 + .evaluate_batch(&batch_inputs(&constr.cond.1), self.uniform_repeat); + let eq_a_evals = constr + .a + .1 + .evaluate_batch(&batch_inputs(&constr.a.1), self.uniform_repeat); + let eq_b_evals = constr + .b + .1 + .evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat); + + let Az_off = Az[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat] + .par_iter_mut(); + let Bz_off = Bz[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat] + .par_iter_mut(); + + (0..self.uniform_repeat) + .into_par_iter() + .zip(Az_off.zip(Bz_off)) + .for_each(|(step_index, (az, bz))| { + // Write corresponding values, if outside the step range, only include the constant. 
+ let a_step = step_index + if constr.a.0 { 1 } else { 0 }; + let b_step = step_index + if constr.b.0 { 1 } else { 0 }; + let a = eq_a_evals + .get(a_step) + .cloned() + .unwrap_or(constr.a.1.constant_term_field()); + let b = eq_b_evals + .get(b_step) + .cloned() + .unwrap_or(constr.b.1.constant_term_field()); + *az = a - b; + + let condition_step = step_index + if constr.cond.0 { 1 } else { 0 }; + *bz = condition_evals + .get(condition_step) + .cloned() + .unwrap_or(constr.cond.1.constant_term_field()); + }); + drop(_enter); + + #[cfg(test)] + self.assert_valid(&Az, &Bz, &Cz); + + let num_vars = self.constraint_rows().next_power_of_two().log_2(); + let (az_poly, (bz_poly, cz_poly)) = rayon::join( + || SparsePolynomial::from_dense_evals(num_vars, Az), + || rayon::join( + || SparsePolynomial::from_dense_evals(num_vars, Bz), + || SparsePolynomial::from_dense_evals(num_vars, Cz) + ) + ); + + (az_poly, bz_poly, cz_poly) + } + + #[cfg(test)] pub fn assert_valid(&self, az: &[F], bz: &[F], cz: &[F]) { let rows = az.len(); - let expected_rows = self.constraint_rows().next_power_of_two(); - assert_eq!(az.len(), expected_rows); - assert_eq!(bz.len(), expected_rows); - assert_eq!(cz.len(), expected_rows); + assert_eq!(bz.len(), rows); + assert_eq!(cz.len(), rows); for constraint_index in 0..rows { if az[constraint_index] * bz[constraint_index] != cz[constraint_index] { let uniform_constraint_index = constraint_index / self.uniform_repeat; diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 6a19d6a41..256ab9b14 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -4,8 +4,9 @@ use crate::field::JoltField; use crate::poly::commitment::commitment_scheme::BatchType; use crate::poly::commitment::commitment_scheme::CommitmentScheme; use crate::r1cs::key::UniformSpartanKey; -use crate::utils::compute_dotproduct_low_optimized; +use crate::r1cs::special_polys::SegmentedPaddedWitness; use crate::utils::math::Math; +use 
crate::utils::profiling; use crate::utils::thread::drop_in_background_thread; use crate::utils::transcript::ProofTranscript; @@ -58,87 +59,6 @@ pub enum SpartanError { InvalidPCSProof, } -// TODO: Rather than use these adhoc virtual indexable polys – create a DensePolynomial which takes any impl Index inner -// and can run all the normal DensePolynomial ops. -#[derive(Clone)] -pub struct SegmentedPaddedWitness { - total_len: usize, - segments: Vec>, - segment_len: usize, - zero: F, -} - -impl SegmentedPaddedWitness { - pub fn new(total_len: usize, segments: Vec>) -> Self { - let segment_len = segments[0].len(); - assert!(segment_len.is_power_of_two()); - for segment in &segments { - assert_eq!( - segment.len(), - segment_len, - "All segments must be the same length" - ); - } - SegmentedPaddedWitness { - total_len, - segments, - segment_len, - zero: F::zero(), - } - } - - pub fn len(&self) -> usize { - self.total_len - } - - #[tracing::instrument(skip_all, name = "SegmentedPaddedWitness::evaluate_all")] - pub fn evaluate_all(&self, point: Vec) -> Vec { - let chi = EqPolynomial::evals(&point); - assert!(chi.len() >= self.segment_len); - - let evals = self - .segments - .par_iter() - .map(|segment| compute_dotproduct_low_optimized(&chi[0..self.segment_len], segment)) - .collect(); - drop_in_background_thread(chi); - evals - } - - pub fn into_dense_polys(self) -> Vec> { - self.segments - .into_iter() - .map(|poly| DensePolynomial::new(poly)) - .collect() - } -} - -impl std::ops::Index for SegmentedPaddedWitness { - type Output = F; - - fn index(&self, index: usize) -> &Self::Output { - if index >= self.segments.len() * self.segment_len { - &self.zero - } else if index >= self.total_len { - panic!("index too high"); - } else { - let segment_index = index / self.segment_len; - let inner_index = index % self.segment_len; - &self.segments[segment_index][inner_index] - } - } -} - -impl IndexablePoly for SegmentedPaddedWitness { - fn len(&self) -> usize { - self.total_len - 
} -} - -pub trait IndexablePoly: std::ops::Index + Sync { - fn len(&self) -> usize; -} - /// A succinct proof of knowledge of a witness to a relaxed R1CS instance /// The proof is produced using Spartan's combination of the sum-check and /// the commitment to a vector viewed as a polynomial commitment @@ -188,38 +108,15 @@ impl> UniformSpartanProof { let tau = (0..num_rounds_x) .map(|_i| transcript.challenge_scalar(b"t")) .collect::>(); + profiling::print_current_memory_usage("pre_poly_tau"); let mut poly_tau = DensePolynomial::new(EqPolynomial::evals(&tau)); + profiling::print_current_memory_usage("post_poly_tau"); let inputs = &segmented_padded_witness.segments[0..I::COUNT]; let aux = &segmented_padded_witness.segments[I::COUNT..]; - let (az, bz, cz) = constraint_builder.compute_spartan_Az_Bz_Cz(inputs, aux); - // TODO: Do not require these padded, Sumcheck should handle sparsity. - assert!(az.len().is_power_of_two()); - assert!(bz.len().is_power_of_two()); - assert!(cz.len().is_power_of_two()); - - let mut poly_Az = DensePolynomial::new(az); - let mut poly_Bz = DensePolynomial::new(bz); - let mut poly_Cz = DensePolynomial::new(cz); - - #[cfg(test)] - { - // Check that Z is a satisfying assignment - for (i, ((az, bz), cz)) in poly_Az - .evals_ref() - .iter() - .zip(poly_Bz.evals_ref()) - .zip(poly_Cz.evals_ref()) - .enumerate() - { - if *az * *bz != *cz { - let padded_segment_len = segmented_padded_witness.segment_len; - let error_segment_index = i / padded_segment_len; - let error_step_index = i % padded_segment_len; - panic!("witness is not a satisfying assignment. 
Failed on segment {error_segment_index} at step {error_step_index}"); - } - } - } + profiling::print_current_memory_usage("pre_az_bz_cz"); + let (mut az, mut bz, mut cz) = constraint_builder.compute_spartan_Az_Bz_Cz_sparse(inputs, aux); + profiling::print_current_memory_usage("post_az_bz_cz"); let comb_func_outer = |A: &F, B: &F, C: &F, D: &F| -> F { // Below is an optimized form of: *A * (*B * *C - *D) @@ -230,25 +127,32 @@ impl> UniformSpartanProof { *A * (-(*D)) } } else { - *A * (*B * *C - *D) + let inner = *B * *C - *D; + if inner.is_zero() { + F::zero() + } else { + *A * inner + } } }; + // profiling::start_memory_tracing_span("outersumcheck"); + profiling::print_current_memory_usage("pre_outersumcheck"); let (outer_sumcheck_proof, outer_sumcheck_r, outer_sumcheck_claims) = SumcheckInstanceProof::prove_spartan_cubic::<_>( &F::zero(), // claim is zero num_rounds_x, &mut poly_tau, - &mut poly_Az, - &mut poly_Bz, - &mut poly_Cz, + &mut az, + &mut bz, + &mut cz, comb_func_outer, transcript, ); - drop_in_background_thread(poly_Az); - drop_in_background_thread(poly_Bz); - drop_in_background_thread(poly_Cz); - drop_in_background_thread(poly_tau); + let outer_sumcheck_r: Vec = outer_sumcheck_r.into_iter().rev().collect(); + // drop_in_background_thread((poly_Az, poly_Bz, poly_Cz, poly_tau)); + drop((az, bz, cz, poly_tau)); + profiling::print_current_memory_usage("post_outersumcheck"); // claims from the end of sum-check // claim_Az is the (scalar) value v_A = \sum_y A(r_x, y) * z(r_x) where r_x is the sumcheck randomness @@ -276,8 +180,10 @@ impl> UniformSpartanProof { .ilog2(); let (rx_con, rx_ts) = outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize); + profiling::print_current_memory_usage("pre_poly_ABC"); let mut poly_ABC = DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC)); + profiling::print_current_memory_usage("post_poly_ABC"); let (inner_sumcheck_proof, inner_sumcheck_r, _claims_inner) = 
SumcheckInstanceProof::prove_spartan_quadratic::>( @@ -287,7 +193,8 @@ impl> UniformSpartanProof { &segmented_padded_witness, transcript, ); - drop_in_background_thread(poly_ABC); + // drop_in_background_thread(poly_ABC); + drop(poly_ABC); // Requires 'r_col_segment_bits' to index the (const, segment). Within that segment we index the step using 'r_col_step' let r_col_segment_bits = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1; @@ -346,6 +253,9 @@ impl> UniformSpartanProof { .verify(F::zero(), num_rounds_x, 3, transcript) .map_err(|_| SpartanError::InvalidOuterSumcheckProof)?; + // Outer sumcheck is bound from the top, reverse the fiat shamir randomness + let r_x: Vec = r_x.into_iter().rev().collect(); + // verify claim_outer_final let (claim_Az, claim_Bz, claim_Cz) = self.outer_sumcheck_claims; let taus_bound_rx = EqPolynomial::new(tau).evaluate(&r_x); diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index 54dacc491..64371a447 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -1,33 +1,50 @@ -use crate::field::JoltField; +use crate::{field::JoltField, poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial}, utils::{compute_dotproduct_low_optimized, math::Math, thread::{drop_in_background_thread, unsafe_allocate_zero_vec}}}; +use num_integer::Integer; use rayon::prelude::*; +#[derive(Clone)] pub struct SparsePolynomial { num_vars: usize, + Z: Vec<(usize, F)>, } -impl SparsePolynomial { - pub fn new(num_vars: usize, Z: Vec<(usize, Scalar)>) -> Self { +impl SparsePolynomial { + pub fn new(num_vars: usize, Z: Vec<(usize, F)>) -> Self { SparsePolynomial { num_vars, Z } } + // TODO(sragss): rm + #[tracing::instrument(skip_all)] + pub fn from_dense_evals(num_vars: usize, evals: Vec) -> Self { + assert!(num_vars.pow2() >= evals.len()); + let non_zero_count: usize = evals.par_iter().filter(|f| !f.is_zero()).count(); + let mut sparse: Vec<(usize, F)> = 
Vec::with_capacity(non_zero_count); + evals.into_iter().enumerate().for_each(|(dense_index, f)| { + if !f.is_zero() { + sparse.push((dense_index, f)); + } + }); + Self::new(num_vars, sparse) + } + /// Computes the $\tilde{eq}$ extension polynomial. /// return 1 when a == r, otherwise return 0. - fn compute_chi(a: &[bool], r: &[Scalar]) -> Scalar { + fn compute_chi(a: &[bool], r: &[F]) -> F { assert_eq!(a.len(), r.len()); - let mut chi_i = Scalar::one(); + let mut chi_i = F::one(); for j in 0..r.len() { if a[j] { chi_i *= r[j]; } else { - chi_i *= Scalar::one() - r[j]; + chi_i *= F::one() - r[j]; } } chi_i } // Takes O(n log n) - pub fn evaluate(&self, r: &[Scalar]) -> Scalar { + pub fn evaluate(&self, r: &[F]) -> F { assert_eq!(self.num_vars, r.len()); (0..self.Z.len()) @@ -38,6 +55,400 @@ impl SparsePolynomial { }) .sum() } + + /// Returns n chunks of roughly even size without orphaning siblings (adjacent dense indices). Additionally returns a vector of (low, high] dense index ranges. + fn chunk_no_orphans(&self, n: usize) -> (Vec<&[(usize, F)]>, Vec<(usize, usize)>) { + if self.Z.len() < n * 2 { + return (vec![(&self.Z)], vec![(0, self.num_vars.pow2())]); + } + + let target_chunk_size = self.Z.len() / n; + let mut chunks: Vec<&[(usize, F)]> = Vec::with_capacity(n); + let mut dense_ranges: Vec<(usize, usize)> = Vec::with_capacity(n); + let mut dense_start_index = 0; + let mut sparse_start_index = 0; + let mut sparse_end_index = target_chunk_size; + for _ in 1..n { + let mut dense_end_index = self.Z[sparse_end_index].0; + if dense_end_index % 2 != 0 { + dense_end_index += 1; + sparse_end_index += 1; + } + chunks.push(&self.Z[sparse_start_index..sparse_end_index]); + dense_ranges.push((dense_start_index, dense_end_index)); + dense_start_index = dense_end_index; + + sparse_start_index = sparse_end_index; + sparse_end_index = std::cmp::min(sparse_end_index + target_chunk_size, self.Z.len() - 1); + } + chunks.push(&self.Z[sparse_start_index..]); + // TODO(sragss): 
likely makes more sense to return full range then truncate when needed (triple iterator) + let highest_non_zero = self.Z.last().map(|&(index, _)| index).unwrap(); + dense_ranges.push((dense_start_index, highest_non_zero + 1)); + assert_eq!(chunks.len(), n); + assert_eq!(dense_ranges.len(), n); + + // TODO(sragss): To use chunk_no_orphans in the triple iterator, we have to overwrite the top of the dense_ranges. + + (chunks, dense_ranges) + } + + #[tracing::instrument(skip_all)] + pub fn bound_poly_var_bot(&mut self, r: &F) { + // TODO(sragss): Do this with a scan instead. + let n = self.Z.len(); + let span = tracing::span!(tracing::Level::DEBUG, "allocate"); + let _enter = span.enter(); + let mut new_Z: Vec<(usize, F)> = Vec::with_capacity(n); + drop(_enter); + for (sparse_index, (dense_index, value)) in self.Z.iter().enumerate() { + if dense_index.is_even() { + let new_dense_index = dense_index / 2; + // TODO(sragss): Can likely combine these conditions for better speculative execution. + if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].0 == dense_index + 1 { + let upper = self.Z[sparse_index + 1].1; + let eval = *value + *r * (upper - value); + new_Z.push((new_dense_index, eval)); + } else { + new_Z.push((new_dense_index, (F::one() - r) * value)); + } + } else { + if sparse_index > 0 && self.Z[sparse_index - 1].0 == dense_index - 1 { + continue; + } else { + let new_dense_index = (dense_index - 1) / 2; + new_Z.push((new_dense_index, *r * value)); + } + } + } + self.Z = new_Z; + self.num_vars -= 1; + } + + #[tracing::instrument(skip_all)] + pub fn bound_poly_var_bot_par(&mut self, r: &F) { + // TODO(sragss): Do this with a scan instead. + let n = self.Z.len(); + // let mut new_Z: Vec<(usize, F)> = Vec::with_capacity(n); + + let (chunks, _range) = self.chunk_no_orphans(rayon::current_num_threads() * 8); + // TODO(sragsss): We can scan up front and collect directly into the thing. 
+ let new_Z: Vec<(usize, F)> = chunks.into_par_iter().map(|chunk| { + // TODO(sragss): Do this with a scan instead; + let mut chunk_Z: Vec<(usize, F)> = Vec::with_capacity(chunk.len()); + for (sparse_index, (dense_index, value)) in chunk.iter().enumerate() { + if dense_index.is_even() { + let new_dense_index = dense_index / 2; + // TODO(sragss): Can likely combine these conditions for better speculative execution. + if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].0 == dense_index + 1 { + let upper = self.Z[sparse_index + 1].1; + let eval = *value + *r * (upper - value); + chunk_Z.push((new_dense_index, eval)); + } else { + chunk_Z.push((new_dense_index, (F::one() - r) * value)); + } + } else { + if sparse_index > 0 && self.Z[sparse_index - 1].0 == dense_index - 1 { + continue; + } else { + let new_dense_index = (dense_index - 1) / 2; + chunk_Z.push((new_dense_index, *r * value)); + } + } + } + + chunk_Z + }).flatten().collect(); + + // for (sparse_index, (dense_index, value)) in self.Z.iter().enumerate() { + // if dense_index.is_even() { + // let new_dense_index = dense_index / 2; + // // TODO(sragss): Can likely combine these conditions for better speculative execution. 
+ // if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].0 == dense_index + 1 { + // let upper = self.Z[sparse_index + 1].1; + // let eval = *value + *r * (upper - value); + // new_Z.push((new_dense_index, eval)); + // } else { + // new_Z.push((new_dense_index, (F::one() - r) * value)); + // } + // } else { + // if sparse_index > 0 && self.Z[sparse_index - 1].0 == dense_index - 1 { + // continue; + // } else { + // let new_dense_index = (dense_index - 1) / 2; + // new_Z.push((new_dense_index, *r * value)); + // } + // } + // } + self.Z = new_Z; + self.num_vars -= 1; + } + + pub fn final_eval(&self) -> F { + assert_eq!(self.num_vars, 0); + if self.Z.len() == 0 { + F::zero() + } else { + assert_eq!(self.Z.len(), 1); + let item = self.Z[0]; + assert_eq!(item.0, 0); + item.1 + } + } + + #[cfg(test)] + #[tracing::instrument(skip_all)] + pub fn to_dense(self) -> DensePolynomial { + use crate::utils::math::Math; + + let mut evals = unsafe_allocate_zero_vec(self.num_vars.pow2()); + + for (index, value) in self.Z { + evals[index] = value; + } + + DensePolynomial::new(evals) + } +} + +pub struct SparseTripleIterator<'a, F: JoltField> { + dense_index: usize, + end_index: usize, + a: &'a [(usize, F)], + b: &'a [(usize, F)], + c: &'a [(usize, F)], +} + +impl<'a, F: JoltField> SparseTripleIterator<'a, F> { + #[tracing::instrument(skip_all)] + pub fn chunks(a: &'a SparsePolynomial, b: &'a SparsePolynomial, c: &'a SparsePolynomial, n: usize) -> Vec { + // When the instance is small enough, don't worry about parallelism + let total_len = a.num_vars.pow2(); + if n * 2 > b.Z.len() { + return vec![SparseTripleIterator { + dense_index: 0, + end_index: total_len, + a: &a.Z, + b: &b.Z, + c: &c.Z + }]; + } + // Can be made more generic, but this is an optimization / simplification. 
+ assert!(b.Z.len() >= a.Z.len() && b.Z.len() >= c.Z.len(), "b.Z.len() assumed to be longest of a, b, and c"); + + // TODO(sragss): Explain the strategy + + let target_chunk_size = b.Z.len() / n; + let mut b_chunks: Vec<&[(usize, F)]> = Vec::with_capacity(n); + let mut dense_ranges: Vec<(usize, usize)> = Vec::with_capacity(n); + let mut dense_start_index = 0; + let mut sparse_start_index = 0; + let mut sparse_end_index = target_chunk_size; + for _ in 1..n { + let mut dense_end_index = b.Z[sparse_end_index].0; + if dense_end_index % 2 != 0 { + dense_end_index += 1; + sparse_end_index += 1; + } + b_chunks.push(&b.Z[sparse_start_index..sparse_end_index]); + dense_ranges.push((dense_start_index, dense_end_index)); + dense_start_index = dense_end_index; + + sparse_start_index = sparse_end_index; + sparse_end_index = std::cmp::min(sparse_end_index + target_chunk_size, b.Z.len() - 1); + } + b_chunks.push(&b.Z[sparse_start_index..]); + let highest_non_zero = { + let a_last = a.Z.last().map(|&(index, _)| index); + let b_last = b.Z.last().map(|&(index, _)| index); + let c_last = c.Z.last().map(|&(index, _)| index); + *a_last.iter().chain(b_last.iter()).chain(c_last.iter()).max().unwrap() + }; + dense_ranges.push((dense_start_index, highest_non_zero + 1)); + assert_eq!(b_chunks.len(), n); + assert_eq!(dense_ranges.len(), n); + + // Create chunks which overlap with b's sparse indices + let mut a_chunks: Vec<&[(usize, F)]> = vec![&[]; n]; + let mut c_chunks: Vec<&[(usize, F)]> = vec![&[]; n]; + let mut a_i = 0; + let mut c_i = 0; + let span = tracing::span!(tracing::Level::DEBUG, "a, c scanning"); + let _enter = span.enter(); + for (chunk_index, range) in dense_ranges.iter().enumerate().skip(1) { + // Find the corresponding a, c chunks + let prev_chunk_end = range.0; + + if a_i < a.Z.len() && a.Z[a_i].0 < prev_chunk_end { + let a_start = a_i; + while a_i < a.Z.len() && a.Z[a_i].0 < prev_chunk_end { + a_i += 1; + } + + a_chunks[chunk_index - 1] = &a.Z[a_start..a_i]; + } + + if 
c_i < c.Z.len() && c.Z[c_i].0 < prev_chunk_end { + let c_start = c_i; + while c_i < c.Z.len() && c.Z[c_i].0 < prev_chunk_end { + c_i += 1; + } + + c_chunks[chunk_index - 1] = &c.Z[c_start..c_i]; + } + + } + drop(_enter); + a_chunks[n-1] = &a.Z[a_i..]; + c_chunks[n-1] = &c.Z[c_i..]; + + #[cfg(test)] + { + assert_eq!(a_chunks.concat(), a.Z); + assert_eq!(b_chunks.concat(), b.Z); + assert_eq!(c_chunks.concat(), c.Z); + } + + let mut iterators: Vec> = Vec::with_capacity(n); + for (((a_chunk, b_chunk), c_chunk), range) in a_chunks.iter().zip(b_chunks.iter()).zip(c_chunks.iter()).zip(dense_ranges.iter()) { + #[cfg(test)] + { + assert!(a_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1)); + assert!(b_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1)); + assert!(c_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1)); + } + let iter = SparseTripleIterator { + dense_index: range.0, + end_index: range.1, + a: a_chunk, + b: b_chunk, + c: c_chunk + }; + iterators.push(iter); + } + + iterators + } + + pub fn has_next(&self) -> bool { + self.dense_index < self.end_index + } + + pub fn next_pairs(&mut self) -> (usize, F, F, F, F, F, F) { + // TODO(sragss): We can store a map of big ranges of zeros and skip them rather than hitting each dense index. 
+ let low_index = self.dense_index; + let match_and_advance = |slice: &mut &[(usize, F)], index: usize| -> F { + if let Some(first_item) = slice.first() { + if first_item.0 == index { + let ret = first_item.1; + *slice = &slice[1..]; + ret + } else { + F::zero() + } + } else { + F::zero() + } + }; + + let a_lower_val = match_and_advance(&mut self.a, self.dense_index); + let b_lower_val = match_and_advance(&mut self.b, self.dense_index); + let c_lower_val = match_and_advance(&mut self.c, self.dense_index); + self.dense_index += 1; + let a_upper_val = match_and_advance(&mut self.a, self.dense_index); + let b_upper_val = match_and_advance(&mut self.b, self.dense_index); + let c_upper_val = match_and_advance(&mut self.c, self.dense_index); + self.dense_index += 1; + + (low_index, a_lower_val, a_upper_val, b_lower_val, b_upper_val, c_lower_val, c_upper_val) + } +} + +pub trait IndexablePoly: std::ops::Index + Sync { + fn len(&self) -> usize; +} + +impl IndexablePoly for DensePolynomial { + fn len(&self) -> usize { + self.Z.len() + } +} + +// TODO: Rather than use these adhoc virtual indexable polys – create a DensePolynomial which takes any impl Index inner +// and can run all the normal DensePolynomial ops. 
+#[derive(Clone)] +pub struct SegmentedPaddedWitness { + total_len: usize, + pub segments: Vec>, + pub segment_len: usize, + zero: F, +} + +impl SegmentedPaddedWitness { + pub fn new(total_len: usize, segments: Vec>) -> Self { + let segment_len = segments[0].len(); + assert!(segment_len.is_power_of_two()); + for segment in &segments { + assert_eq!( + segment.len(), + segment_len, + "All segments must be the same length" + ); + } + SegmentedPaddedWitness { + total_len, + segments, + segment_len, + zero: F::zero(), + } + } + + pub fn len(&self) -> usize { + self.total_len + } + + #[tracing::instrument(skip_all, name = "SegmentedPaddedWitness::evaluate_all")] + pub fn evaluate_all(&self, point: Vec) -> Vec { + let chi = EqPolynomial::evals(&point); + assert!(chi.len() >= self.segment_len); + + let evals = self + .segments + .par_iter() + .map(|segment| compute_dotproduct_low_optimized(&chi[0..self.segment_len], segment)) + .collect(); + drop_in_background_thread(chi); + evals + } + + pub fn into_dense_polys(self) -> Vec> { + self.segments + .into_iter() + .map(|poly| DensePolynomial::new(poly)) + .collect() + } +} + +impl std::ops::Index for SegmentedPaddedWitness { + type Output = F; + + fn index(&self, index: usize) -> &Self::Output { + if index >= self.segments.len() * self.segment_len { + &self.zero + } else if index >= self.total_len { + panic!("index too high"); + } else { + let segment_index = index / self.segment_len; + let inner_index = index % self.segment_len; + &self.segments[segment_index][inner_index] + } + } +} + +impl IndexablePoly for SegmentedPaddedWitness { + fn len(&self) -> usize { + self.total_len + } } /// Returns the `num_bits` from n in a canonical order @@ -73,3 +484,131 @@ pub fn eq_plus_one(x: &[F], y: &[F], l: usize) -> F { }) .sum() } + +#[cfg(test)] +mod tests { + use super::*; + use ark_bn254::Fr; + use ark_std::Zero; + + #[test] + fn sparse_bound_bot_all_left() { + let dense_evals = vec![Fr::from(10), Fr::zero(), Fr::from(20), 
Fr::zero()]; + let sparse_evals = vec![(0, Fr::from(10)), (2, Fr::from(20))]; + + let mut dense = DensePolynomial::new(dense_evals); + let mut sparse = SparsePolynomial::new(2, sparse_evals); + + assert_eq!(sparse.clone().to_dense(), dense); + + let r = Fr::from(121); + sparse.bound_poly_var_bot(&r); + dense.bound_poly_var_bot(&r); + assert_eq!(sparse.to_dense(), dense); + } + + #[test] + fn sparse_bound_bot_all_right() { + let dense_evals = vec![Fr::zero(), Fr::from(10), Fr::zero(), Fr::from(20)]; + let sparse_evals = vec![(1, Fr::from(10)), (3, Fr::from(20))]; + + let mut dense = DensePolynomial::new(dense_evals); + let mut sparse = SparsePolynomial::new(2, sparse_evals); + + assert_eq!(sparse.clone().to_dense(), dense); + + let r = Fr::from(121); + sparse.bound_poly_var_bot(&r); + dense.bound_poly_var_bot(&r); + assert_eq!(sparse.to_dense(), dense); + } + + #[test] + fn sparse_bound_bot_mixed() { + let dense_evals = vec![Fr::zero(), Fr::from(10), Fr::zero(), Fr::from(20), Fr::from(30), Fr::from(40), Fr::zero(), Fr::from(50)]; + let sparse_evals = vec![(1, Fr::from(10)), (3, Fr::from(20)), (4, Fr::from(30)), (5, Fr::from(40)), (7, Fr::from(50))]; + + let mut dense = DensePolynomial::new(dense_evals); + let mut sparse = SparsePolynomial::new(3, sparse_evals); + + assert_eq!(sparse.clone().to_dense(), dense); + + let r = Fr::from(121); + sparse.bound_poly_var_bot(&r); + dense.bound_poly_var_bot(&r); + assert_eq!(sparse.to_dense(), dense); + } + + #[test] + fn sparse_triple_iterator() { + let a = vec![(9, Fr::from(9)), (10, Fr::from(10)), (12, Fr::from(12))]; + let b = vec![(0, Fr::from(100)), (1, Fr::from(1)), (2, Fr::from(2)), (3, Fr::from(3)), (4, Fr::from(4)), (5, Fr::from(5)), (6, Fr::from(6)), (7, Fr::from(7)), (8, Fr::from(8)), (9, Fr::from(9)), (10, Fr::from(10)), (11, Fr::from(11)), (12, Fr::from(12)), (13, Fr::from(13)), (14, Fr::from(14)), (15, Fr::from(15))]; + let c = vec![(0, Fr::from(12)), (3, Fr::from(3))]; + + let a_poly = SparsePolynomial::new(4, 
a); + let b_poly = SparsePolynomial::new(4, b); + let c_poly = SparsePolynomial::new(4, c); + + let iterators = SparseTripleIterator::chunks(&a_poly, &b_poly, &c_poly, 4); + assert_eq!(iterators.len(), 4); + } + + #[test] + fn sparse_triple_iterator_random() { + use rand::Rng; + + let mut rng = rand::thread_rng(); + + let prob_exists = 0.32; + let num_vars = 10; + let total_len = 1 << num_vars; + + let mut a = vec![]; + let mut b = vec![]; + let mut c = vec![]; + + for i in 0usize..total_len { + if rng.gen::() < prob_exists { + a.push((i, Fr::from(i as u64))); + } + if rng.gen::() < prob_exists * 2f64 { + b.push((i, Fr::from(i as u64))); + } + if rng.gen::() < prob_exists { + c.push((i, Fr::from(i as u64))); + } + } + + let a_poly = SparsePolynomial::new(num_vars, a); + let b_poly = SparsePolynomial::new(num_vars, b); + let c_poly = SparsePolynomial::new(num_vars, c); + + let mut iterators = SparseTripleIterator::chunks(&a_poly, &b_poly, &c_poly, 8); + + let mut new_a = vec![Fr::zero(); total_len]; + let mut new_b = vec![Fr::zero(); total_len]; + let mut new_c = vec![Fr::zero(); total_len]; + let mut expected_dense_index = 0; + for iterator in iterators.iter_mut() { + while iterator.has_next() { + let (dense_index, a_low, a_high, b_low, b_high, c_low, c_high) = iterator.next_pairs(); + + new_a[dense_index] = a_low; + new_a[dense_index+1] = a_high; + + new_b[dense_index] = b_low; + new_b[dense_index+1] = b_high; + + new_c[dense_index] = c_low; + new_c[dense_index+1] = c_high; + + assert_eq!(dense_index, expected_dense_index); + expected_dense_index += 2; + } + } + + assert_eq!(a_poly.to_dense().Z, new_a); + assert_eq!(b_poly.to_dense().Z, new_b); + assert_eq!(c_poly.to_dense().Z, new_c); + } +} \ No newline at end of file diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index 0b841de14..abf6ea7eb 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -4,7 +4,7 @@ use 
crate::field::JoltField; use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::unipoly::{CompressedUniPoly, UniPoly}; -use crate::r1cs::spartan::IndexablePoly; +use crate::r1cs::special_polys::{IndexablePoly, SparsePolynomial, SparseTripleIterator}; use crate::utils::errors::ProofVerifyError; use crate::utils::mul_0_optimized; use crate::utils::thread::drop_in_background_thread; @@ -179,51 +179,65 @@ impl SumcheckInstanceProof { skip_all, name = "Spartan2::sumcheck::compute_eval_points_spartan_cubic" )] + /// Binds from the bottom rather than the top. pub fn compute_eval_points_spartan_cubic( - poly_A: &DensePolynomial, - poly_B: &DensePolynomial, - poly_C: &DensePolynomial, - poly_D: &DensePolynomial, + poly_eq: &DensePolynomial, + poly_A: &SparsePolynomial, + poly_B: &SparsePolynomial, + poly_C: &SparsePolynomial, comb_func: &Func, ) -> (F, F, F) where Func: Fn(&F, &F, &F, &F) -> F + Sync, { - let len = poly_A.len() / 2; - (0..len) - .into_par_iter() - .map(|i| { - // eval 0: bound_func is A(low) - let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]); + // num_threads * 8 enables better work stealing + let mut iterators = + SparseTripleIterator::chunks(poly_A, poly_B, poly_C, rayon::current_num_threads() * 16); + iterators + .par_iter_mut() + .map(|iterator| { + let span = tracing::span!(tracing::Level::DEBUG, "eval_par_inner"); + let _enter = span.enter(); + let mut eval_point_0 = F::zero(); + let mut eval_point_2 = F::zero(); + let mut eval_point_3 = F::zero(); + while iterator.has_next() { + let (dense_index, a_low, a_high, b_low, b_high, c_low, c_high) = + iterator.next_pairs(); + assert!(dense_index % 2 == 0); - let m_A = poly_A[len + i] - poly_A[i]; - let m_B = poly_B[len + i] - poly_B[i]; - let m_C = poly_C[len + i] - poly_C[i]; - let m_D = poly_D[len + i] - poly_D[i]; + // eval 0: bound_func is A(low) + eval_point_0 += comb_func(&poly_eq[dense_index], &a_low, &b_low, &c_low); + + let m_eq = poly_eq[dense_index + 1] - 
poly_eq[dense_index]; + let m_A = a_high - a_low; + let m_B = b_high - b_low; + let m_C = c_high - c_low; + + // eval 2 + let poly_A_bound_point = poly_eq[dense_index + 1] + m_eq; + let poly_B_bound_point = a_high + m_A; + let poly_C_bound_point = b_high + m_B; + let poly_D_bound_point = c_high + m_C; + eval_point_2 += comb_func( + &poly_A_bound_point, + &poly_B_bound_point, + &poly_C_bound_point, + &poly_D_bound_point, + ); - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + m_A; - let poly_B_bound_point = poly_B[len + i] + m_B; - let poly_C_bound_point = poly_C[len + i] + m_C; - let poly_D_bound_point = poly_D[len + i] + m_D; - let eval_point_2 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); - - // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + m_A; - let poly_B_bound_point = poly_B_bound_point + m_B; - let poly_C_bound_point = poly_C_bound_point + m_C; - let poly_D_bound_point = poly_D_bound_point + m_D; - let eval_point_3 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, - ); + // eval 3 + let poly_A_bound_point = poly_A_bound_point + m_eq; + let poly_B_bound_point = poly_B_bound_point + m_A; + let poly_C_bound_point = poly_C_bound_point + m_B; + let poly_D_bound_point = poly_D_bound_point + m_C; + eval_point_3 += comb_func( + &poly_A_bound_point, + &poly_B_bound_point, + &poly_C_bound_point, + &poly_D_bound_point, + ); + } (eval_point_0, eval_point_2, eval_point_3) }) .reduce( @@ -236,10 +250,10 @@ impl SumcheckInstanceProof { pub fn prove_spartan_cubic( claim: &F, num_rounds: usize, - poly_A: &mut DensePolynomial, - poly_B: &mut DensePolynomial, - poly_C: &mut DensePolynomial, - poly_D: &mut DensePolynomial, + poly_eq: &mut DensePolynomial, + poly_A: &mut SparsePolynomial, + poly_B: &mut SparsePolynomial, 
+ poly_C: &mut SparsePolynomial, comb_func: Func, transcript: &mut ProofTranscript, ) -> (Self, Vec, Vec) @@ -255,7 +269,7 @@ impl SumcheckInstanceProof { // Make an iterator returning the contributions to the evaluations let (eval_point_0, eval_point_2, eval_point_3) = Self::compute_eval_points_spartan_cubic( - poly_A, poly_B, poly_C, poly_D, &comb_func, + poly_eq, poly_A, poly_B, poly_C, &comb_func, ); let evals = [ @@ -264,6 +278,7 @@ impl SumcheckInstanceProof { eval_point_2, eval_point_3, ]; + UniPoly::from_evals(&evals) }; @@ -280,14 +295,14 @@ impl SumcheckInstanceProof { // bound all tables to the verifier's challenege rayon::join( - || poly_A.bound_poly_var_top_par(&r_i), + || poly_eq.bound_poly_var_bot(&r_i), || { rayon::join( - || poly_B.bound_poly_var_top_zero_optimized(&r_i), + || poly_A.bound_poly_var_bot(&r_i), || { rayon::join( - || poly_C.bound_poly_var_top_zero_optimized(&r_i), - || poly_D.bound_poly_var_top_zero_optimized(&r_i), + || poly_B.bound_poly_var_bot(&r_i), + || poly_C.bound_poly_var_bot(&r_i), ) }, ) @@ -298,7 +313,12 @@ impl SumcheckInstanceProof { ( SumcheckInstanceProof::new(polys), r, - vec![poly_A[0], poly_B[0], poly_C[0], poly_D[0]], + vec![ + poly_eq[0], + poly_A.final_eval(), + poly_B.final_eval(), + poly_C.final_eval(), + ], ) } From 17e6d3f2be3bf0b4407fa6470b389614e18989d0 Mon Sep 17 00:00:00 2001 From: sragss Date: Sat, 22 Jun 2024 08:49:41 -0700 Subject: [PATCH 02/17] par binding algos --- jolt-core/src/poly/dense_mlpoly.rs | 9 ++- jolt-core/src/r1cs/special_polys.rs | 103 ++++++++++++++++++++----- jolt-core/src/subprotocols/sumcheck.rs | 32 ++++---- jolt-core/src/utils/thread.rs | 26 +++++++ 4 files changed, 134 insertions(+), 36 deletions(-) diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs index 6eb701531..928b63915 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -206,7 +206,14 @@ impl DensePolynomial { let n = self.len() / 2; let mut new_z = 
unsafe_allocate_zero_vec(n); new_z.par_iter_mut().enumerate().for_each(|(i, z)| { - *z = self.Z[2*i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]); + let m = self.Z[2*i + 1] - self.Z[2*i]; + *z = if m.is_zero() { + self.Z[2*i] + } else if m.is_one() { + self.Z[2*i] + r + }else { + self.Z[2*i] + *r * m + } }); let old_Z = std::mem::replace(&mut self.Z, new_z); diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index 64371a447..76b1aab3f 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -1,4 +1,4 @@ -use crate::{field::JoltField, poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial}, utils::{compute_dotproduct_low_optimized, math::Math, thread::{drop_in_background_thread, unsafe_allocate_zero_vec}}}; +use crate::{field::JoltField, poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial}, utils::{compute_dotproduct_low_optimized, math::Math, mul_0_1_optimized, thread::{drop_in_background_thread, unsafe_allocate_sparse_zero_vec, unsafe_allocate_zero_vec}}}; use num_integer::Integer; use rayon::prelude::*; @@ -57,6 +57,7 @@ impl SparsePolynomial { } /// Returns n chunks of roughly even size without orphaning siblings (adjacent dense indices). Additionally returns a vector of (low, high] dense index ranges. + #[tracing::instrument(skip_all)] fn chunk_no_orphans(&self, n: usize) -> (Vec<&[(usize, F)]>, Vec<(usize, usize)>) { if self.Z.len() < n * 2 { return (vec![(&self.Z)], vec![(0, self.num_vars.pow2())]); @@ -83,12 +84,12 @@ impl SparsePolynomial { } chunks.push(&self.Z[sparse_start_index..]); // TODO(sragss): likely makes more sense to return full range then truncate when needed (triple iterator) + // TODO(sragss): To use chunk_no_orphans in the triple iterator, we have to overwrite the top of the dense_ranges. 
let highest_non_zero = self.Z.last().map(|&(index, _)| index).unwrap(); dense_ranges.push((dense_start_index, highest_non_zero + 1)); assert_eq!(chunks.len(), n); assert_eq!(dense_ranges.len(), n); - // TODO(sragss): To use chunk_no_orphans in the triple iterator, we have to overwrite the top of the dense_ranges. (chunks, dense_ranges) } @@ -127,38 +128,75 @@ impl SparsePolynomial { #[tracing::instrument(skip_all)] pub fn bound_poly_var_bot_par(&mut self, r: &F) { - // TODO(sragss): Do this with a scan instead. - let n = self.Z.len(); - // let mut new_Z: Vec<(usize, F)> = Vec::with_capacity(n); - + // TODO(sragss): better parallelism. + let count_span = tracing::span!(tracing::Level::DEBUG, "counting"); + let count_enter = count_span.enter(); let (chunks, _range) = self.chunk_no_orphans(rayon::current_num_threads() * 8); + let chunk_sizes: Vec = chunks.par_iter().map(|chunk| { + let mut chunk_size = 0; + let mut i = 0; + while i < chunk.len() { + chunk_size += 1; + + // If they're siblings, avoid double counting + if chunk[i].0.is_even() && i + 1 < chunk.len() && chunk[i].0 + 1 == chunk[i + 1].0 { + i += 1; + } + i += 1; + } + chunk_size + }).collect(); + drop(count_enter); + + let alloc_span = tracing::span!(tracing::Level::DEBUG, "alloc_new_Z"); + let alloc_enter = alloc_span.enter(); + let total_len: usize = chunk_sizes.iter().sum(); + let mut new_Z: Vec<(usize, F)> = unsafe_allocate_sparse_zero_vec(total_len); + drop(alloc_enter); + + let mut mutable_chunks: Vec<&mut [(usize, F)]> = vec![]; + let mut remainder = new_Z.as_mut_slice(); + for chunk_size in chunk_sizes { + let (first, second) = remainder.split_at_mut(chunk_size); + mutable_chunks.push(first); + remainder = second; + } + assert_eq!(mutable_chunks.len(), chunks.len()); + // TODO(sragsss): We can scan up front and collect directly into the thing. 
- let new_Z: Vec<(usize, F)> = chunks.into_par_iter().map(|chunk| { + chunks.into_par_iter().zip(mutable_chunks.par_iter_mut()).for_each(|(chunk, mutable)| { + let span = tracing::span!(tracing::Level::DEBUG, "chunk"); + let _enter = span.enter(); // TODO(sragss): Do this with a scan instead; - let mut chunk_Z: Vec<(usize, F)> = Vec::with_capacity(chunk.len()); + // let mut chunk_Z: Vec<(usize, F)> = Vec::with_capacity(chunk.len()); + let mut write_index = 0; for (sparse_index, (dense_index, value)) in chunk.iter().enumerate() { if dense_index.is_even() { let new_dense_index = dense_index / 2; // TODO(sragss): Can likely combine these conditions for better speculative execution. - if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].0 == dense_index + 1 { - let upper = self.Z[sparse_index + 1].1; - let eval = *value + *r * (upper - value); - chunk_Z.push((new_dense_index, eval)); - } else { - chunk_Z.push((new_dense_index, (F::one() - r) * value)); + + // All exist + if chunk.len() >= 2 && sparse_index <= chunk.len() - 2 && chunk[sparse_index + 1].0 == dense_index + 1 { + let upper = chunk[sparse_index + 1].1; + let eval = *value + mul_0_1_optimized(r, &(upper - value)); + mutable[write_index] = (new_dense_index, eval); + write_index += 1; + } else { // low exists + mutable[write_index] = (new_dense_index, mul_0_1_optimized(&(F::one() - r), value)); + write_index += 1; } } else { - if sparse_index > 0 && self.Z[sparse_index - 1].0 == dense_index - 1 { + // low and high exist + if sparse_index > 0 && chunk[sparse_index - 1].0 == dense_index - 1 { continue; - } else { + } else { // high only exists let new_dense_index = (dense_index - 1) / 2; - chunk_Z.push((new_dense_index, *r * value)); + mutable[write_index] = (new_dense_index, mul_0_1_optimized(r, value)); + write_index += 1; } } } - - chunk_Z - }).flatten().collect(); + }); // for (sparse_index, (dense_index, value)) in self.Z.iter().enumerate() { // if dense_index.is_even() { @@ 
-180,7 +218,8 @@ impl SparsePolynomial { // } // } // } - self.Z = new_Z; + let old_Z = std::mem::replace(&mut self.Z, new_Z); + drop_in_background_thread(old_Z); self.num_vars -= 1; } @@ -611,4 +650,26 @@ mod tests { assert_eq!(b_poly.to_dense().Z, new_b); assert_eq!(c_poly.to_dense().Z, new_c); } + + #[test] + fn binding() { + use rand::Rng; + + let mut rng = rand::thread_rng(); + let prob_exists = 0.32; + let num_vars = 6; + let total_len = 1 << num_vars; + let mut a = vec![]; + for i in 0usize..total_len { + if rng.gen::() < prob_exists { + a.push((i, Fr::from(i as u64))); + } + } + + let mut a_poly = SparsePolynomial::new(num_vars, a); + + let r = Fr::from(100); + assert_eq!(a_poly.clone().bound_poly_var_bot(&r), a_poly.bound_poly_var_bot_par(&r)); + + } } \ No newline at end of file diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index abf6ea7eb..ab5d919c9 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -294,20 +294,24 @@ impl SumcheckInstanceProof { claim_per_round = poly.evaluate(&r_i); // bound all tables to the verifier's challenege - rayon::join( - || poly_eq.bound_poly_var_bot(&r_i), - || { - rayon::join( - || poly_A.bound_poly_var_bot(&r_i), - || { - rayon::join( - || poly_B.bound_poly_var_bot(&r_i), - || poly_C.bound_poly_var_bot(&r_i), - ) - }, - ) - }, - ); + poly_eq.bound_poly_var_bot(&r_i); + poly_A.bound_poly_var_bot_par(&r_i); + poly_B.bound_poly_var_bot_par(&r_i); + poly_C.bound_poly_var_bot_par(&r_i); + // rayon::join( + // || poly_eq.bound_poly_var_bot(&r_i), + // || { + // rayon::join( + // || poly_A.bound_poly_var_bot_par(&r_i), + // || { + // rayon::join( + // || poly_B.bound_poly_var_bot_par(&r_i), + // || poly_C.bound_poly_var_bot_par(&r_i), + // ) + // }, + // ) + // }, + // ); } ( diff --git a/jolt-core/src/utils/thread.rs b/jolt-core/src/utils/thread.rs index 196d9f8a5..eda50a466 100644 --- a/jolt-core/src/utils/thread.rs +++ 
b/jolt-core/src/utils/thread.rs @@ -44,6 +44,32 @@ pub fn unsafe_allocate_zero_vec(size: usize) -> Vec { result } +#[tracing::instrument(skip_all, name = "unsafe_allocate_sparse_zero_vec")] +pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec<(usize, F)> { + // Check for safety of 0 allocation + unsafe { + let value = &F::zero(); + let ptr = value as *const F as *const u8; + let bytes = std::slice::from_raw_parts(ptr, std::mem::size_of::()); + assert!(bytes.iter().all(|&byte| byte == 0)); + } + + // Bulk allocate zeros, unsafely + let result: Vec<(usize, F)>; + unsafe { + let layout = std::alloc::Layout::array::<(usize, F)>(size).unwrap(); + let ptr = std::alloc::alloc_zeroed(layout) as *mut (usize, F); + + if ptr.is_null() { + panic!("Zero vec allocation failed"); + } + + result = Vec::from_raw_parts(ptr, size, size); + } + result +} + + pub fn join_triple(oper_a: A, oper_b: B, oper_c: C) -> (RA, RB, RC) where A: FnOnce() -> RA + Send, From 3647c37cf7caf7105da0f4c18011c09fda0f953e Mon Sep 17 00:00:00 2001 From: sragss Date: Tue, 25 Jun 2024 17:21:51 -0700 Subject: [PATCH 03/17] switch sparsity ordering --- jolt-core/src/r1cs/builder.rs | 27 ++++- jolt-core/src/r1cs/key.rs | 2 +- jolt-core/src/r1cs/ops.rs | 5 + jolt-core/src/r1cs/special_polys.rs | 174 ++++++++++++---------------- jolt-core/src/utils/thread.rs | 8 +- 5 files changed, 109 insertions(+), 107 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index e16bd7f7f..1a350547e 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -878,15 +878,16 @@ impl CombinedUniformBuilder { // TODO(sragss): Attempt moving onto key and computing from materialized rows rather than linear combos let span = tracing::span!(tracing::Level::DEBUG, "compute_constraints"); let enter = span.enter(); - let az_chunks = Az.par_chunks_mut(self.uniform_repeat); - let bz_chunks = Bz.par_chunks_mut(self.uniform_repeat); - let cz_chunks = 
Cz.par_chunks_mut(self.uniform_repeat); + let az_chunks = Az.chunks_mut(self.uniform_repeat); + let bz_chunks = Bz.chunks_mut(self.uniform_repeat); + let cz_chunks = Cz.chunks_mut(self.uniform_repeat); self.uniform_builder .constraints - .par_iter() + .iter() + .enumerate() .zip(az_chunks.zip(bz_chunks.zip(cz_chunks))) - .for_each(|(constraint, (az_chunk, (bz_chunk, cz_chunk)))| { + .for_each(|((constraint_index, constraint), (az_chunk, (bz_chunk, cz_chunk)))| { let a_inputs = batch_inputs(&constraint.a); let b_inputs = batch_inputs(&constraint.b); let c_inputs = batch_inputs(&constraint.c); @@ -894,9 +895,25 @@ impl CombinedUniformBuilder { constraint.a.evaluate_batch_mut(&a_inputs, az_chunk); constraint.b.evaluate_batch_mut(&b_inputs, bz_chunk); constraint.c.evaluate_batch_mut(&c_inputs, cz_chunk); + + let az_zero = az_chunk.iter().filter(|item| item.is_zero()).count(); + let bz_zero = bz_chunk.iter().filter(|item| item.is_zero()).count(); + let cz_zero = cz_chunk.iter().filter(|item| item.is_zero()).count(); + + println!("[{constraint_index}] empty map az: {} bz: {} cz: {}", az_zero == az_chunk.len(), bz_zero == bz_chunk.len(), cz_zero == cz_chunk.len()); }); drop(enter); + let az_non_zero = Az.iter().filter(|item| !item.is_zero()).count(); + let bz_non_zero = Bz.iter().filter(|item| !item.is_zero()).count(); + let cz_non_zero = Cz.iter().filter(|item| !item.is_zero()).count(); + + println!("Uniform repeat: {}", self.uniform_repeat); + println!("Num constraints: {constraint_rows}"); + println!("Az sparsity: {az_non_zero}/{}", Az.len()); + println!("Bz sparsity: {bz_non_zero}/{}", Bz.len()); + println!("Cz sparsity: {cz_non_zero}/{}", Cz.len()); + // offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + 1] // (a - b) * condition == 0 // For the final step we will not compute the offset terms, and will assume the condition to be set to 0 diff --git a/jolt-core/src/r1cs/key.rs b/jolt-core/src/r1cs/key.rs index f361c03dd..f4dcfb166 
100644 --- a/jolt-core/src/r1cs/key.rs +++ b/jolt-core/src/r1cs/key.rs @@ -269,7 +269,7 @@ impl UniformSpartanKey { let eval_variables: F = (0..self.uniform_r1cs.num_vars) .map(|var_index| r_var_eq[var_index] * segment_evals[var_index]) .sum(); - let const_poly = SparsePolynomial::new(self.num_vars_total().log_2(), vec![(0, F::one())]); + let const_poly = SparsePolynomial::new(self.num_vars_total().log_2(), vec![(F::one(), 0)]); let eval_const = const_poly.evaluate(r_rest); (F::one() - r_const) * eval_variables + r_const * eval_const diff --git a/jolt-core/src/r1cs/ops.rs b/jolt-core/src/r1cs/ops.rs index 48d6c06bf..6cb179490 100644 --- a/jolt-core/src/r1cs/ops.rs +++ b/jolt-core/src/r1cs/ops.rs @@ -129,6 +129,11 @@ impl LC { let terms: Vec = self.to_field_elements(); + if terms.len() == 0 { + println!("no terms"); + return; + } + output .par_iter_mut() .enumerate() diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index 76b1aab3f..dba878c9c 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -6,11 +6,11 @@ use rayon::prelude::*; pub struct SparsePolynomial { num_vars: usize, - Z: Vec<(usize, F)>, + Z: Vec<(F, usize)>, } impl SparsePolynomial { - pub fn new(num_vars: usize, Z: Vec<(usize, F)>) -> Self { + pub fn new(num_vars: usize, Z: Vec<(F, usize)>) -> Self { SparsePolynomial { num_vars, Z } } @@ -18,13 +18,23 @@ impl SparsePolynomial { #[tracing::instrument(skip_all)] pub fn from_dense_evals(num_vars: usize, evals: Vec) -> Self { assert!(num_vars.pow2() >= evals.len()); - let non_zero_count: usize = evals.par_iter().filter(|f| !f.is_zero()).count(); - let mut sparse: Vec<(usize, F)> = Vec::with_capacity(non_zero_count); - evals.into_iter().enumerate().for_each(|(dense_index, f)| { - if !f.is_zero() { - sparse.push((dense_index, f)); + let non_zero_count: usize = evals.par_chunks(10_000).map(|chunk| chunk.iter().filter(|f| !f.is_zero()).count()).sum(); + + let span_allocate = 
tracing::span!(tracing::Level::DEBUG, "allocate"); + let _enter_allocate = span_allocate.enter(); + let mut sparse: Vec<(F, usize)> = unsafe_allocate_sparse_zero_vec(non_zero_count); + drop(_enter_allocate); + + let span_copy = tracing::span!(tracing::Level::DEBUG, "copy"); + let _enter_copy = span_copy.enter(); + let mut sparse_index = 0; + for (dense_index, dense) in evals.iter().enumerate() { + if !dense.is_zero() { + sparse[sparse_index] = (*dense, dense_index); + sparse_index += 1; } - }); + } + drop(_enter_copy); Self::new(num_vars, sparse) } @@ -50,27 +60,27 @@ impl SparsePolynomial { (0..self.Z.len()) .into_par_iter() .map(|i| { - let bits = get_bits(self.Z[0].0, r.len()); - SparsePolynomial::compute_chi(&bits, r) * self.Z[i].1 + let bits = get_bits(self.Z[0].1, r.len()); + SparsePolynomial::compute_chi(&bits, r) * self.Z[i].0 }) .sum() } /// Returns n chunks of roughly even size without orphaning siblings (adjacent dense indices). Additionally returns a vector of (low, high] dense index ranges. 
#[tracing::instrument(skip_all)] - fn chunk_no_orphans(&self, n: usize) -> (Vec<&[(usize, F)]>, Vec<(usize, usize)>) { + fn chunk_no_orphans(&self, n: usize) -> (Vec<&[(F, usize)]>, Vec<(usize, usize)>) { if self.Z.len() < n * 2 { return (vec![(&self.Z)], vec![(0, self.num_vars.pow2())]); } let target_chunk_size = self.Z.len() / n; - let mut chunks: Vec<&[(usize, F)]> = Vec::with_capacity(n); + let mut chunks: Vec<&[(F, usize)]> = Vec::with_capacity(n); let mut dense_ranges: Vec<(usize, usize)> = Vec::with_capacity(n); let mut dense_start_index = 0; let mut sparse_start_index = 0; let mut sparse_end_index = target_chunk_size; for _ in 1..n { - let mut dense_end_index = self.Z[sparse_end_index].0; + let mut dense_end_index = self.Z[sparse_end_index].1; if dense_end_index % 2 != 0 { dense_end_index += 1; sparse_end_index += 1; @@ -83,9 +93,7 @@ impl SparsePolynomial { sparse_end_index = std::cmp::min(sparse_end_index + target_chunk_size, self.Z.len() - 1); } chunks.push(&self.Z[sparse_start_index..]); - // TODO(sragss): likely makes more sense to return full range then truncate when needed (triple iterator) - // TODO(sragss): To use chunk_no_orphans in the triple iterator, we have to overwrite the top of the dense_ranges. 
- let highest_non_zero = self.Z.last().map(|&(index, _)| index).unwrap(); + let highest_non_zero = self.Z.last().map(|&(_, index)| index).unwrap(); dense_ranges.push((dense_start_index, highest_non_zero + 1)); assert_eq!(chunks.len(), n); assert_eq!(dense_ranges.len(), n); @@ -100,25 +108,24 @@ impl SparsePolynomial { let n = self.Z.len(); let span = tracing::span!(tracing::Level::DEBUG, "allocate"); let _enter = span.enter(); - let mut new_Z: Vec<(usize, F)> = Vec::with_capacity(n); + let mut new_Z: Vec<(F, usize)> = Vec::with_capacity(n); drop(_enter); - for (sparse_index, (dense_index, value)) in self.Z.iter().enumerate() { + for (sparse_index, (value, dense_index)) in self.Z.iter().enumerate() { if dense_index.is_even() { let new_dense_index = dense_index / 2; - // TODO(sragss): Can likely combine these conditions for better speculative execution. - if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].0 == dense_index + 1 { - let upper = self.Z[sparse_index + 1].1; + if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].1 == dense_index + 1 { + let upper = self.Z[sparse_index + 1].0; let eval = *value + *r * (upper - value); - new_Z.push((new_dense_index, eval)); + new_Z.push((eval, new_dense_index)); } else { - new_Z.push((new_dense_index, (F::one() - r) * value)); + new_Z.push(((F::one() - r) * value, new_dense_index)); } } else { - if sparse_index > 0 && self.Z[sparse_index - 1].0 == dense_index - 1 { + if sparse_index > 0 && self.Z[sparse_index - 1].1 == dense_index - 1 { continue; } else { let new_dense_index = (dense_index - 1) / 2; - new_Z.push((new_dense_index, *r * value)); + new_Z.push((*r * value, new_dense_index)); } } } @@ -139,7 +146,7 @@ impl SparsePolynomial { chunk_size += 1; // If they're siblings, avoid double counting - if chunk[i].0.is_even() && i + 1 < chunk.len() && chunk[i].0 + 1 == chunk[i + 1].0 { + if chunk[i].1.is_even() && i + 1 < chunk.len() && chunk[i].1 + 1 == chunk[i 
+ 1].1 { i += 1; } i += 1; @@ -151,10 +158,10 @@ impl SparsePolynomial { let alloc_span = tracing::span!(tracing::Level::DEBUG, "alloc_new_Z"); let alloc_enter = alloc_span.enter(); let total_len: usize = chunk_sizes.iter().sum(); - let mut new_Z: Vec<(usize, F)> = unsafe_allocate_sparse_zero_vec(total_len); + let mut new_Z: Vec<(F, usize)> = unsafe_allocate_sparse_zero_vec(total_len); drop(alloc_enter); - let mut mutable_chunks: Vec<&mut [(usize, F)]> = vec![]; + let mut mutable_chunks: Vec<&mut [(F, usize)]> = vec![]; let mut remainder = new_Z.as_mut_slice(); for chunk_size in chunk_sizes { let (first, second) = remainder.split_at_mut(chunk_size); @@ -163,61 +170,34 @@ impl SparsePolynomial { } assert_eq!(mutable_chunks.len(), chunks.len()); - // TODO(sragsss): We can scan up front and collect directly into the thing. chunks.into_par_iter().zip(mutable_chunks.par_iter_mut()).for_each(|(chunk, mutable)| { let span = tracing::span!(tracing::Level::DEBUG, "chunk"); let _enter = span.enter(); - // TODO(sragss): Do this with a scan instead; - // let mut chunk_Z: Vec<(usize, F)> = Vec::with_capacity(chunk.len()); let mut write_index = 0; - for (sparse_index, (dense_index, value)) in chunk.iter().enumerate() { + for (sparse_index, (value, dense_index)) in chunk.iter().enumerate() { if dense_index.is_even() { let new_dense_index = dense_index / 2; - // TODO(sragss): Can likely combine these conditions for better speculative execution. 
- - // All exist - if chunk.len() >= 2 && sparse_index <= chunk.len() - 2 && chunk[sparse_index + 1].0 == dense_index + 1 { - let upper = chunk[sparse_index + 1].1; + if chunk.len() >= 2 && sparse_index <= chunk.len() - 2 && chunk[sparse_index + 1].1 == dense_index + 1 { + let upper = chunk[sparse_index + 1].0; let eval = *value + mul_0_1_optimized(r, &(upper - value)); - mutable[write_index] = (new_dense_index, eval); + mutable[write_index] = (eval, new_dense_index); write_index += 1; - } else { // low exists - mutable[write_index] = (new_dense_index, mul_0_1_optimized(&(F::one() - r), value)); + } else { + mutable[write_index] = (mul_0_1_optimized(&(F::one() - r), value), new_dense_index); write_index += 1; } } else { - // low and high exist - if sparse_index > 0 && chunk[sparse_index - 1].0 == dense_index - 1 { + if sparse_index > 0 && chunk[sparse_index - 1].1 == dense_index - 1 { continue; - } else { // high only exists + } else { let new_dense_index = (dense_index - 1) / 2; - mutable[write_index] = (new_dense_index, mul_0_1_optimized(r, value)); + mutable[write_index] = (mul_0_1_optimized(r, value), new_dense_index); write_index += 1; } } } }); - // for (sparse_index, (dense_index, value)) in self.Z.iter().enumerate() { - // if dense_index.is_even() { - // let new_dense_index = dense_index / 2; - // // TODO(sragss): Can likely combine these conditions for better speculative execution. 
- // if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].0 == dense_index + 1 { - // let upper = self.Z[sparse_index + 1].1; - // let eval = *value + *r * (upper - value); - // new_Z.push((new_dense_index, eval)); - // } else { - // new_Z.push((new_dense_index, (F::one() - r) * value)); - // } - // } else { - // if sparse_index > 0 && self.Z[sparse_index - 1].0 == dense_index - 1 { - // continue; - // } else { - // let new_dense_index = (dense_index - 1) / 2; - // new_Z.push((new_dense_index, *r * value)); - // } - // } - // } let old_Z = std::mem::replace(&mut self.Z, new_Z); drop_in_background_thread(old_Z); self.num_vars -= 1; @@ -230,8 +210,8 @@ impl SparsePolynomial { } else { assert_eq!(self.Z.len(), 1); let item = self.Z[0]; - assert_eq!(item.0, 0); - item.1 + assert_eq!(item.1, 0); + item.0 } } @@ -242,7 +222,7 @@ impl SparsePolynomial { let mut evals = unsafe_allocate_zero_vec(self.num_vars.pow2()); - for (index, value) in self.Z { + for (value, index) in self.Z { evals[index] = value; } @@ -253,9 +233,9 @@ impl SparsePolynomial { pub struct SparseTripleIterator<'a, F: JoltField> { dense_index: usize, end_index: usize, - a: &'a [(usize, F)], - b: &'a [(usize, F)], - c: &'a [(usize, F)], + a: &'a [(F, usize)], + b: &'a [(F, usize)], + c: &'a [(F, usize)], } impl<'a, F: JoltField> SparseTripleIterator<'a, F> { @@ -278,13 +258,13 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { // TODO(sragss): Explain the strategy let target_chunk_size = b.Z.len() / n; - let mut b_chunks: Vec<&[(usize, F)]> = Vec::with_capacity(n); + let mut b_chunks: Vec<&[(F, usize)]> = Vec::with_capacity(n); let mut dense_ranges: Vec<(usize, usize)> = Vec::with_capacity(n); let mut dense_start_index = 0; let mut sparse_start_index = 0; let mut sparse_end_index = target_chunk_size; for _ in 1..n { - let mut dense_end_index = b.Z[sparse_end_index].0; + let mut dense_end_index = b.Z[sparse_end_index].1; if dense_end_index % 2 != 0 { dense_end_index += 
1; sparse_end_index += 1; @@ -298,9 +278,9 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { } b_chunks.push(&b.Z[sparse_start_index..]); let highest_non_zero = { - let a_last = a.Z.last().map(|&(index, _)| index); - let b_last = b.Z.last().map(|&(index, _)| index); - let c_last = c.Z.last().map(|&(index, _)| index); + let a_last = a.Z.last().map(|&(_, index)| index); + let b_last = b.Z.last().map(|&(_, index)| index); + let c_last = c.Z.last().map(|&(_, index)| index); *a_last.iter().chain(b_last.iter()).chain(c_last.iter()).max().unwrap() }; dense_ranges.push((dense_start_index, highest_non_zero + 1)); @@ -308,8 +288,8 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { assert_eq!(dense_ranges.len(), n); // Create chunks which overlap with b's sparse indices - let mut a_chunks: Vec<&[(usize, F)]> = vec![&[]; n]; - let mut c_chunks: Vec<&[(usize, F)]> = vec![&[]; n]; + let mut a_chunks: Vec<&[(F, usize)]> = vec![&[]; n]; + let mut c_chunks: Vec<&[(F, usize)]> = vec![&[]; n]; let mut a_i = 0; let mut c_i = 0; let span = tracing::span!(tracing::Level::DEBUG, "a, c scanning"); @@ -318,18 +298,18 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { // Find the corresponding a, c chunks let prev_chunk_end = range.0; - if a_i < a.Z.len() && a.Z[a_i].0 < prev_chunk_end { + if a_i < a.Z.len() && a.Z[a_i].1 < prev_chunk_end { let a_start = a_i; - while a_i < a.Z.len() && a.Z[a_i].0 < prev_chunk_end { + while a_i < a.Z.len() && a.Z[a_i].1 < prev_chunk_end { a_i += 1; } a_chunks[chunk_index - 1] = &a.Z[a_start..a_i]; } - if c_i < c.Z.len() && c.Z[c_i].0 < prev_chunk_end { + if c_i < c.Z.len() && c.Z[c_i].1 < prev_chunk_end { let c_start = c_i; - while c_i < c.Z.len() && c.Z[c_i].0 < prev_chunk_end { + while c_i < c.Z.len() && c.Z[c_i].1 < prev_chunk_end { c_i += 1; } @@ -352,9 +332,9 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { for (((a_chunk, b_chunk), c_chunk), range) in 
a_chunks.iter().zip(b_chunks.iter()).zip(c_chunks.iter()).zip(dense_ranges.iter()) { #[cfg(test)] { - assert!(a_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1)); - assert!(b_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1)); - assert!(c_chunk.iter().all(|(index, _)| *index >= range.0 && *index <= range.1)); + assert!(a_chunk.iter().all(|(_, index)| *index >= range.0 && *index <= range.1)); + assert!(b_chunk.iter().all(|(_, index)| *index >= range.0 && *index <= range.1)); + assert!(c_chunk.iter().all(|(_, index)| *index >= range.0 && *index <= range.1)); } let iter = SparseTripleIterator { dense_index: range.0, @@ -376,10 +356,10 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { pub fn next_pairs(&mut self) -> (usize, F, F, F, F, F, F) { // TODO(sragss): We can store a map of big ranges of zeros and skip them rather than hitting each dense index. let low_index = self.dense_index; - let match_and_advance = |slice: &mut &[(usize, F)], index: usize| -> F { + let match_and_advance = |slice: &mut &[(F, usize)], index: usize| -> F { if let Some(first_item) = slice.first() { - if first_item.0 == index { - let ret = first_item.1; + if first_item.1 == index { + let ret = first_item.0; *slice = &slice[1..]; ret } else { @@ -533,7 +513,7 @@ mod tests { #[test] fn sparse_bound_bot_all_left() { let dense_evals = vec![Fr::from(10), Fr::zero(), Fr::from(20), Fr::zero()]; - let sparse_evals = vec![(0, Fr::from(10)), (2, Fr::from(20))]; + let sparse_evals = vec![(Fr::from(10), 0), (Fr::from(20), 2)]; let mut dense = DensePolynomial::new(dense_evals); let mut sparse = SparsePolynomial::new(2, sparse_evals); @@ -549,7 +529,7 @@ mod tests { #[test] fn sparse_bound_bot_all_right() { let dense_evals = vec![Fr::zero(), Fr::from(10), Fr::zero(), Fr::from(20)]; - let sparse_evals = vec![(1, Fr::from(10)), (3, Fr::from(20))]; + let sparse_evals = vec![(Fr::from(10), 1), (Fr::from(20), 3)]; let mut dense = DensePolynomial::new(dense_evals); 
let mut sparse = SparsePolynomial::new(2, sparse_evals); @@ -565,7 +545,7 @@ mod tests { #[test] fn sparse_bound_bot_mixed() { let dense_evals = vec![Fr::zero(), Fr::from(10), Fr::zero(), Fr::from(20), Fr::from(30), Fr::from(40), Fr::zero(), Fr::from(50)]; - let sparse_evals = vec![(1, Fr::from(10)), (3, Fr::from(20)), (4, Fr::from(30)), (5, Fr::from(40)), (7, Fr::from(50))]; + let sparse_evals = vec![(Fr::from(10), 1), (Fr::from(20), 3), (Fr::from(30), 4), (Fr::from(40), 5), (Fr::from(50), 7)]; let mut dense = DensePolynomial::new(dense_evals); let mut sparse = SparsePolynomial::new(3, sparse_evals); @@ -580,9 +560,9 @@ mod tests { #[test] fn sparse_triple_iterator() { - let a = vec![(9, Fr::from(9)), (10, Fr::from(10)), (12, Fr::from(12))]; - let b = vec![(0, Fr::from(100)), (1, Fr::from(1)), (2, Fr::from(2)), (3, Fr::from(3)), (4, Fr::from(4)), (5, Fr::from(5)), (6, Fr::from(6)), (7, Fr::from(7)), (8, Fr::from(8)), (9, Fr::from(9)), (10, Fr::from(10)), (11, Fr::from(11)), (12, Fr::from(12)), (13, Fr::from(13)), (14, Fr::from(14)), (15, Fr::from(15))]; - let c = vec![(0, Fr::from(12)), (3, Fr::from(3))]; + let a = vec![(Fr::from(9), 9), (Fr::from(10), 10), (Fr::from(12), 12)]; + let b = vec![(Fr::from(100), 0), (Fr::from(1), 1), (Fr::from(2), 2), (Fr::from(3), 3), (Fr::from(4), 4), (Fr::from(5), 5), (Fr::from(6), 6), (Fr::from(7), 7), (Fr::from(8), 8), (Fr::from(9), 9), (Fr::from(10), 10), (Fr::from(11), 11), (Fr::from(12), 12), (Fr::from(13), 13), (Fr::from(14), 14), (Fr::from(15), 15)]; + let c = vec![(Fr::from(12), 0), (Fr::from(3), 3)]; let a_poly = SparsePolynomial::new(4, a); let b_poly = SparsePolynomial::new(4, b); @@ -608,13 +588,13 @@ mod tests { for i in 0usize..total_len { if rng.gen::() < prob_exists { - a.push((i, Fr::from(i as u64))); + a.push((Fr::from(i as u64), i)); } if rng.gen::() < prob_exists * 2f64 { - b.push((i, Fr::from(i as u64))); + b.push((Fr::from(i as u64), i)); } if rng.gen::() < prob_exists { - c.push((i, Fr::from(i as u64))); + 
c.push((Fr::from(i as u64), i)); } } @@ -662,7 +642,7 @@ mod tests { let mut a = vec![]; for i in 0usize..total_len { if rng.gen::() < prob_exists { - a.push((i, Fr::from(i as u64))); + a.push((Fr::from(i as u64), i)); } } diff --git a/jolt-core/src/utils/thread.rs b/jolt-core/src/utils/thread.rs index eda50a466..7deaa28e6 100644 --- a/jolt-core/src/utils/thread.rs +++ b/jolt-core/src/utils/thread.rs @@ -45,7 +45,7 @@ pub fn unsafe_allocate_zero_vec(size: usize) -> Vec { } #[tracing::instrument(skip_all, name = "unsafe_allocate_sparse_zero_vec")] -pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec<(usize, F)> { +pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec<(F, usize)> { // Check for safety of 0 allocation unsafe { let value = &F::zero(); @@ -55,10 +55,10 @@ pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec } // Bulk allocate zeros, unsafely - let result: Vec<(usize, F)>; + let result: Vec<(F, usize)>; unsafe { - let layout = std::alloc::Layout::array::<(usize, F)>(size).unwrap(); - let ptr = std::alloc::alloc_zeroed(layout) as *mut (usize, F); + let layout = std::alloc::Layout::array::<(F, usize)>(size).unwrap(); + let ptr = std::alloc::alloc_zeroed(layout) as *mut (F, usize); if ptr.is_null() { panic!("Zero vec allocation failed"); From 7817ca785193e9a2fa4f1d95087f8b1f22248494 Mon Sep 17 00:00:00 2001 From: sragss Date: Tue, 25 Jun 2024 21:24:44 -0700 Subject: [PATCH 04/17] e2e sparsity working --- jolt-core/src/r1cs/builder.rs | 154 ++++++++++++++++------------ jolt-core/src/r1cs/ops.rs | 5 - jolt-core/src/r1cs/special_polys.rs | 2 +- jolt-core/src/utils/thread.rs | 48 ++++++++- 4 files changed, 137 insertions(+), 72 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 1a350547e..340f2edde 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -2,7 +2,7 @@ use crate::{ field::{JoltField, OptimizedMul}, r1cs::key::{SparseConstraints, UniformR1CS}, utils::{ - 
math::Math, mul_0_1_optimized, thread::{drop_in_background_thread, unsafe_allocate_zero_vec} + math::Math, mul_0_1_optimized, thread::{drop_in_background_thread, par_flatten_triple, unsafe_allocate_sparse_zero_vec, unsafe_allocate_zero_vec} }, }; #[allow(unused_imports)] // clippy thinks these aren't needed lol @@ -21,12 +21,19 @@ pub trait R1CSConstraintBuilder { fn build_constraints(&self, builder: &mut R1CSBuilder); } +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum EvaluationHint { + Zero = 0, + Other = 2, +} + /// Constraints over a single row. Each variable points to a single item in Z and the corresponding coefficient. #[derive(Clone, Debug)] struct Constraint { a: LC, b: LC, c: LC, + evaluation_hint: (EvaluationHint, EvaluationHint, EvaluationHint) } impl Constraint { @@ -224,6 +231,7 @@ impl R1CSBuilder { a, b, c: LC::zero(), + evaluation_hint: (EvaluationHint::Zero, EvaluationHint::Other, EvaluationHint::Zero) }; self.constraints.push(constraint); } @@ -242,7 +250,7 @@ impl R1CSBuilder { let a = condition; let b = left - right; let c = LC::zero(); - let constraint = Constraint { a, b, c }; + let constraint = Constraint { a, b, c, evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Zero) }; // TODO(sragss): Can do better on middle term. self.constraints.push(constraint); } @@ -250,11 +258,12 @@ impl R1CSBuilder { let one: LC = Variable::Constant.into(); let a: LC = value.into(); let b = one - a.clone(); - // value * (1 - value) + // value * (1 - value) == 0 let constraint = Constraint { a, b, c: LC::zero(), + evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Zero) }; self.constraints.push(constraint); } @@ -278,6 +287,7 @@ impl R1CSBuilder { a: condition.clone(), b: (result_true - result_false.clone()), c: (alleged_result - result_false), + evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Other) // TODO(sragss): Is this the best we can do? 
}; self.constraints.push(constraint); } @@ -423,6 +433,7 @@ impl R1CSBuilder { a: x.into(), b: y.into(), c: z.into(), + evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Other) }; self.constraints.push(constraint); } @@ -849,7 +860,7 @@ impl CombinedUniformBuilder { /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]] /// aux should be of the format [[Aux(0), Aux(0), ...], ... [Aux(self.next_aux - 1), ...]] - #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan")] + #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan_sparse")] pub fn compute_spartan_Az_Bz_Cz_sparse( &self, inputs: &[Vec], @@ -864,55 +875,69 @@ impl CombinedUniformBuilder { .all(|inner_input| inner_input.len() == self.uniform_repeat)); let uniform_constraint_rows = self.uniform_repeat_constraint_rows(); - // TODO(sragss): Allocation can overshoot by up to a factor of 2, Spartan could handle non-pow-2 Az,Bz,Cz - let constraint_rows = self.constraint_rows(); - let (mut Az, mut Bz, mut Cz) = ( - unsafe_allocate_zero_vec(constraint_rows), - unsafe_allocate_zero_vec(constraint_rows), - unsafe_allocate_zero_vec(constraint_rows), - ); let batch_inputs = |lc: &LC| batch_inputs(lc, inputs, aux); - // uniform_constraints: Xz[0..uniform_constraint_rows] - // TODO(sragss): Attempt moving onto key and computing from materialized rows rather than linear combos - let span = tracing::span!(tracing::Level::DEBUG, "compute_constraints"); - let enter = span.enter(); - let az_chunks = Az.chunks_mut(self.uniform_repeat); - let bz_chunks = Bz.chunks_mut(self.uniform_repeat); - let cz_chunks = Cz.chunks_mut(self.uniform_repeat); - - self.uniform_builder - .constraints - .iter() - .enumerate() - .zip(az_chunks.zip(bz_chunks.zip(cz_chunks))) - .for_each(|((constraint_index, constraint), (az_chunk, (bz_chunk, cz_chunk)))| { + // Enforce correctness of hints. 
+ // TODO(sragss): Can be moved into assert_valid. + #[cfg(test)] + self.uniform_builder.constraints.iter().enumerate().for_each(|(constraint_index, constraint)| { + let assert_hint = |constraint: &Constraint| { let a_inputs = batch_inputs(&constraint.a); let b_inputs = batch_inputs(&constraint.b); let c_inputs = batch_inputs(&constraint.c); - constraint.a.evaluate_batch_mut(&a_inputs, az_chunk); - constraint.b.evaluate_batch_mut(&b_inputs, bz_chunk); - constraint.c.evaluate_batch_mut(&c_inputs, cz_chunk); + let a = constraint.a.evaluate_batch(&a_inputs, self.uniform_repeat); + let b = constraint.b.evaluate_batch(&b_inputs, self.uniform_repeat); + let c = constraint.c.evaluate_batch(&c_inputs, self.uniform_repeat); - let az_zero = az_chunk.iter().filter(|item| item.is_zero()).count(); - let bz_zero = bz_chunk.iter().filter(|item| item.is_zero()).count(); - let cz_zero = cz_chunk.iter().filter(|item| item.is_zero()).count(); + if constraint.evaluation_hint.0 == EvaluationHint::Zero { + a.iter().for_each(|item| assert_eq!(*item, F::zero(), "Wrong hint: {constraint_index} {constraint:?}")); + } + if constraint.evaluation_hint.1 == EvaluationHint::Zero { + b.iter().for_each(|item| assert_eq!(*item, F::zero(), "Wrong hint: {constraint_index} {constraint:?}")); + } + if constraint.evaluation_hint.2 == EvaluationHint::Zero { + c.iter().for_each(|item| assert_eq!(*item, F::zero(), "Wrong hint: {constraint_index} {constraint:?}")); + } + }; - println!("[{constraint_index}] empty map az: {} bz: {} cz: {}", az_zero == az_chunk.len(), bz_zero == bz_chunk.len(), cz_zero == cz_chunk.len()); - }); - drop(enter); + assert_hint(constraint); + }); + + // uniform_constraints: Xz[0..uniform_constraint_rows] + let span = tracing::span!(tracing::Level::DEBUG, "uniform_evals"); + let _enter = span.enter(); + let uni_constraint_evals: Vec<(Vec<(F, usize)>, Vec<(F, usize)>, Vec<(F, usize)>)> = self.uniform_builder.constraints.par_iter().enumerate().map(|(constraint_index, constraint)| { + 
let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); + + let mut evaluate_lc_chunk = |hint, lc: &LC| { + if hint != EvaluationHint::Zero { + let inputs = batch_inputs(lc); + lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); + + // Take only the non-zero elements and represent them as sparse tuples (eval, dense_index) + let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot + dense_output_buffer.iter().enumerate().for_each(|(local_index, item)| { + if !item.is_zero() { + let global_index = constraint_index * self.uniform_repeat + local_index; + sparse.push((*item, global_index)); + } + }); + sparse + } else { + vec![] + } + }; - let az_non_zero = Az.iter().filter(|item| !item.is_zero()).count(); - let bz_non_zero = Bz.iter().filter(|item| !item.is_zero()).count(); - let cz_non_zero = Cz.iter().filter(|item| !item.is_zero()).count(); + let a_chunk: Vec<(F, usize)> = evaluate_lc_chunk(constraint.evaluation_hint.0, &constraint.a); + let b_chunk: Vec<(F, usize)> = evaluate_lc_chunk(constraint.evaluation_hint.1, &constraint.b); + let c_chunk: Vec<(F, usize)> = evaluate_lc_chunk(constraint.evaluation_hint.2, &constraint.c); - println!("Uniform repeat: {}", self.uniform_repeat); - println!("Num constraints: {constraint_rows}"); - println!("Az sparsity: {az_non_zero}/{}", Az.len()); - println!("Bz sparsity: {bz_non_zero}/{}", Bz.len()); - println!("Cz sparsity: {cz_non_zero}/{}", Cz.len()); + (a_chunk, b_chunk, c_chunk) + }).collect(); + + let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple(uni_constraint_evals, unsafe_allocate_sparse_zero_vec, self.uniform_repeat); // offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + 1] // (a - b) * condition == 0 @@ -934,15 +959,9 @@ impl CombinedUniformBuilder { .1 .evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat); - let Az_off = Az[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat] - .par_iter_mut(); - let 
Bz_off = Bz[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat] - .par_iter_mut(); - - (0..self.uniform_repeat) + let dense_az_bz: Vec<(F, F)> = (0..self.uniform_repeat) .into_par_iter() - .zip(Az_off.zip(Bz_off)) - .for_each(|(step_index, (az, bz))| { + .map(|step_index| { // Write corresponding values, if outside the step range, only include the constant. let a_step = step_index + if constr.a.0 { 1 } else { 0 }; let b_step = step_index + if constr.b.0 { 1 } else { 0 }; @@ -954,27 +973,34 @@ impl CombinedUniformBuilder { .get(b_step) .cloned() .unwrap_or(constr.b.1.constant_term_field()); - *az = a - b; + let az = a - b; let condition_step = step_index + if constr.cond.0 { 1 } else { 0 }; - *bz = condition_evals + let bz = condition_evals .get(condition_step) .cloned() .unwrap_or(constr.cond.1.constant_term_field()); - }); - drop(_enter); - - #[cfg(test)] - self.assert_valid(&Az, &Bz, &Cz); + (az, bz) + }).collect(); + + // Sparsify: take only the non-zero elements + for (local_index, (az, bz)) in dense_az_bz.iter().enumerate() { + let global_index = uniform_constraint_rows + local_index; + if !az.is_zero() { + az_sparse.push((*az, global_index)); + } + if !bz.is_zero() { + bz_sparse.push((*bz, global_index)); + } + } let num_vars = self.constraint_rows().next_power_of_two().log_2(); - let (az_poly, (bz_poly, cz_poly)) = rayon::join( - || SparsePolynomial::from_dense_evals(num_vars, Az), - || rayon::join( - || SparsePolynomial::from_dense_evals(num_vars, Bz), - || SparsePolynomial::from_dense_evals(num_vars, Cz) - ) - ); + let az_poly = SparsePolynomial::new(num_vars, az_sparse); + let bz_poly = SparsePolynomial::new(num_vars, bz_sparse); + let cz_poly = SparsePolynomial::new(num_vars, cz_sparse); + + #[cfg(test)] + self.assert_valid(&az_poly.clone().to_dense().evals_ref(), &bz_poly.clone().to_dense().evals_ref(), &cz_poly.clone().to_dense().evals_ref()); (az_poly, bz_poly, cz_poly) } diff --git a/jolt-core/src/r1cs/ops.rs 
b/jolt-core/src/r1cs/ops.rs index 6cb179490..48d6c06bf 100644 --- a/jolt-core/src/r1cs/ops.rs +++ b/jolt-core/src/r1cs/ops.rs @@ -129,11 +129,6 @@ impl LC { let terms: Vec = self.to_field_elements(); - if terms.len() == 0 { - println!("no terms"); - return; - } - output .par_iter_mut() .enumerate() diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index dba878c9c..cf53420fc 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -2,7 +2,7 @@ use crate::{field::JoltField, poly::{dense_mlpoly::DensePolynomial, eq_poly::EqP use num_integer::Integer; use rayon::prelude::*; -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq)] pub struct SparsePolynomial { num_vars: usize, diff --git a/jolt-core/src/utils/thread.rs b/jolt-core/src/utils/thread.rs index 7deaa28e6..4678022a7 100644 --- a/jolt-core/src/utils/thread.rs +++ b/jolt-core/src/utils/thread.rs @@ -1,4 +1,5 @@ use std::thread::{self, JoinHandle}; +use rayon::prelude::*; use crate::field::JoltField; @@ -17,7 +18,7 @@ pub fn allocate_vec_in_background( thread::spawn(move || vec![value; size]) } -#[tracing::instrument(skip_all, name = "unsafe_allocate_zero_vec")] +#[tracing::instrument(skip_all)] pub fn unsafe_allocate_zero_vec(size: usize) -> Vec { // https://stackoverflow.com/questions/59314686/how-to-efficiently-create-a-large-vector-of-items-initialized-to-the-same-value @@ -44,7 +45,7 @@ pub fn unsafe_allocate_zero_vec(size: usize) -> Vec { result } -#[tracing::instrument(skip_all, name = "unsafe_allocate_sparse_zero_vec")] +#[tracing::instrument(skip_all)] pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec<(F, usize)> { // Check for safety of 0 allocation unsafe { @@ -69,6 +70,49 @@ pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec result } +#[tracing::instrument(skip_all)] +pub fn par_flatten_triple Vec>( + triple: Vec<(Vec, Vec, Vec)>, + allocate: F, + excess_alloc: usize) -> (Vec, Vec, Vec) { + let az_len: usize = 
triple.iter().map(|item| item.0.len()).sum(); + let bz_len: usize = triple.iter().map(|item| item.1.len()).sum(); + let cz_len: usize = triple.iter().map(|item| item.2.len()).sum(); + + let (mut a_sparse, mut b_sparse, mut c_sparse): (Vec, Vec, Vec) = ( + allocate(az_len), + allocate(bz_len), + allocate(cz_len), + ); + + let mut a_slices = Vec::with_capacity(triple.len() + excess_alloc); + let mut b_slices = Vec::with_capacity(triple.len() + excess_alloc); + let mut c_slices = Vec::with_capacity(triple.len() + excess_alloc); + + let mut a_rest: &mut [T] = a_sparse.as_mut_slice(); + let mut b_rest: &mut [T] = b_sparse.as_mut_slice(); + let mut c_rest: &mut [T] = c_sparse.as_mut_slice(); + + for item in &triple { + let (a_chunk, a_new_rest) = a_rest.split_at_mut(item.0.len()); + a_slices.push(a_chunk); + a_rest = a_new_rest; + + let (b_chunk, b_new_rest) = b_rest.split_at_mut(item.1.len()); + b_slices.push(b_chunk); + b_rest = b_new_rest; + + let (c_chunk, c_new_rest) = c_rest.split_at_mut(item.2.len()); + c_slices.push(c_chunk); + c_rest = c_new_rest; + } + + triple.into_par_iter().zip(a_slices.par_iter_mut().zip(b_slices.par_iter_mut().zip(c_slices.par_iter_mut()))).for_each(|(chunk, (a, (b, c)))| { + join_triple(|| a.copy_from_slice(&chunk.0), || b.copy_from_slice(&chunk.1), || c.copy_from_slice(&chunk.2)); + }); + + (a_sparse, b_sparse, c_sparse) +} pub fn join_triple(oper_a: A, oper_b: B, oper_c: C) -> (RA, RB, RC) where From 968b050d45a400d0584d7a70dc2b32070c665153 Mon Sep 17 00:00:00 2001 From: sragss Date: Tue, 25 Jun 2024 21:25:42 -0700 Subject: [PATCH 05/17] cargo fmt --- jolt-core/src/poly/dense_mlpoly.rs | 10 +- jolt-core/src/r1cs/builder.rs | 180 ++++++++++++++-------- jolt-core/src/r1cs/spartan.rs | 3 +- jolt-core/src/r1cs/special_polys.rs | 228 ++++++++++++++++++++-------- jolt-core/src/utils/thread.rs | 35 +++-- 5 files changed, 310 insertions(+), 146 deletions(-) diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs 
index 928b63915..6bfd3aab2 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -206,13 +206,13 @@ impl DensePolynomial { let n = self.len() / 2; let mut new_z = unsafe_allocate_zero_vec(n); new_z.par_iter_mut().enumerate().for_each(|(i, z)| { - let m = self.Z[2*i + 1] - self.Z[2*i]; + let m = self.Z[2 * i + 1] - self.Z[2 * i]; *z = if m.is_zero() { - self.Z[2*i] + self.Z[2 * i] } else if m.is_one() { - self.Z[2*i] + r - }else { - self.Z[2*i] + *r * m + self.Z[2 * i] + r + } else { + self.Z[2 * i] + *r * m } }); diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 340f2edde..0a5bc5859 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -2,7 +2,12 @@ use crate::{ field::{JoltField, OptimizedMul}, r1cs::key::{SparseConstraints, UniformR1CS}, utils::{ - math::Math, mul_0_1_optimized, thread::{drop_in_background_thread, par_flatten_triple, unsafe_allocate_sparse_zero_vec, unsafe_allocate_zero_vec} + math::Math, + mul_0_1_optimized, + thread::{ + drop_in_background_thread, par_flatten_triple, unsafe_allocate_sparse_zero_vec, + unsafe_allocate_zero_vec, + }, }, }; #[allow(unused_imports)] // clippy thinks these aren't needed lol @@ -12,7 +17,8 @@ use std::{collections::HashMap, fmt::Debug}; use super::{ key::{NonUniformR1CS, SparseEqualityItem}, - ops::{ConstraintInput, Term, Variable, LC}, special_polys::SparsePolynomial, + ops::{ConstraintInput, Term, Variable, LC}, + special_polys::SparsePolynomial, }; pub trait R1CSConstraintBuilder { @@ -33,7 +39,7 @@ struct Constraint { a: LC, b: LC, c: LC, - evaluation_hint: (EvaluationHint, EvaluationHint, EvaluationHint) + evaluation_hint: (EvaluationHint, EvaluationHint, EvaluationHint), } impl Constraint { @@ -231,7 +237,11 @@ impl R1CSBuilder { a, b, c: LC::zero(), - evaluation_hint: (EvaluationHint::Zero, EvaluationHint::Other, EvaluationHint::Zero) + evaluation_hint: ( + EvaluationHint::Zero, + EvaluationHint::Other, + 
EvaluationHint::Zero, + ), }; self.constraints.push(constraint); } @@ -250,7 +260,16 @@ impl R1CSBuilder { let a = condition; let b = left - right; let c = LC::zero(); - let constraint = Constraint { a, b, c, evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Zero) }; // TODO(sragss): Can do better on middle term. + let constraint = Constraint { + a, + b, + c, + evaluation_hint: ( + EvaluationHint::Other, + EvaluationHint::Other, + EvaluationHint::Zero, + ), + }; // TODO(sragss): Can do better on middle term. self.constraints.push(constraint); } @@ -263,7 +282,11 @@ impl R1CSBuilder { a, b, c: LC::zero(), - evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Zero) + evaluation_hint: ( + EvaluationHint::Other, + EvaluationHint::Other, + EvaluationHint::Zero, + ), }; self.constraints.push(constraint); } @@ -287,7 +310,11 @@ impl R1CSBuilder { a: condition.clone(), b: (result_true - result_false.clone()), c: (alleged_result - result_false), - evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Other) // TODO(sragss): Is this the best we can do? + evaluation_hint: ( + EvaluationHint::Other, + EvaluationHint::Other, + EvaluationHint::Other, + ), // TODO(sragss): Is this the best we can do? }; self.constraints.push(constraint); } @@ -433,7 +460,11 @@ impl R1CSBuilder { a: x.into(), b: y.into(), c: z.into(), - evaluation_hint: (EvaluationHint::Other, EvaluationHint::Other, EvaluationHint::Other) + evaluation_hint: ( + EvaluationHint::Other, + EvaluationHint::Other, + EvaluationHint::Other, + ), }; self.constraints.push(constraint); } @@ -858,14 +889,18 @@ impl CombinedUniformBuilder { (Az, Bz, Cz) } - /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]] + /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]] /// aux should be of the format [[Aux(0), Aux(0), ...], ... 
[Aux(self.next_aux - 1), ...]] #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan_sparse")] pub fn compute_spartan_Az_Bz_Cz_sparse( &self, inputs: &[Vec], aux: &[Vec], - ) -> (SparsePolynomial, SparsePolynomial, SparsePolynomial) { + ) -> ( + SparsePolynomial, + SparsePolynomial, + SparsePolynomial, + ) { assert_eq!(inputs.len(), I::COUNT); let num_aux = self.uniform_builder.num_aux(); assert_eq!(aux.len(), num_aux); @@ -881,63 +916,82 @@ impl CombinedUniformBuilder { // Enforce correctness of hints. // TODO(sragss): Can be moved into assert_valid. #[cfg(test)] - self.uniform_builder.constraints.iter().enumerate().for_each(|(constraint_index, constraint)| { - let assert_hint = |constraint: &Constraint| { - let a_inputs = batch_inputs(&constraint.a); - let b_inputs = batch_inputs(&constraint.b); - let c_inputs = batch_inputs(&constraint.c); - - let a = constraint.a.evaluate_batch(&a_inputs, self.uniform_repeat); - let b = constraint.b.evaluate_batch(&b_inputs, self.uniform_repeat); - let c = constraint.c.evaluate_batch(&c_inputs, self.uniform_repeat); - - if constraint.evaluation_hint.0 == EvaluationHint::Zero { - a.iter().for_each(|item| assert_eq!(*item, F::zero(), "Wrong hint: {constraint_index} {constraint:?}")); - } - if constraint.evaluation_hint.1 == EvaluationHint::Zero { - b.iter().for_each(|item| assert_eq!(*item, F::zero(), "Wrong hint: {constraint_index} {constraint:?}")); - } - if constraint.evaluation_hint.2 == EvaluationHint::Zero { - c.iter().for_each(|item| assert_eq!(*item, F::zero(), "Wrong hint: {constraint_index} {constraint:?}")); - } - }; + self.uniform_builder + .constraints + .iter() + .for_each(|constraint| { + let assert_hint = |constraint: &Constraint| { + let a_inputs = batch_inputs(&constraint.a); + let b_inputs = batch_inputs(&constraint.b); + let c_inputs = batch_inputs(&constraint.c); + + let a = constraint.a.evaluate_batch(&a_inputs, self.uniform_repeat); + let b = constraint.b.evaluate_batch(&b_inputs, 
self.uniform_repeat); + let c = constraint.c.evaluate_batch(&c_inputs, self.uniform_repeat); + + if constraint.evaluation_hint.0 == EvaluationHint::Zero { + a.iter().for_each(|item| assert_eq!(*item, F::zero(),)); + } + if constraint.evaluation_hint.1 == EvaluationHint::Zero { + b.iter().for_each(|item| assert_eq!(*item, F::zero(),)); + } + if constraint.evaluation_hint.2 == EvaluationHint::Zero { + c.iter().for_each(|item| assert_eq!(*item, F::zero(),)); + } + }; - assert_hint(constraint); - }); + assert_hint(constraint); + }); // uniform_constraints: Xz[0..uniform_constraint_rows] let span = tracing::span!(tracing::Level::DEBUG, "uniform_evals"); let _enter = span.enter(); - let uni_constraint_evals: Vec<(Vec<(F, usize)>, Vec<(F, usize)>, Vec<(F, usize)>)> = self.uniform_builder.constraints.par_iter().enumerate().map(|(constraint_index, constraint)| { - let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); - - let mut evaluate_lc_chunk = |hint, lc: &LC| { - if hint != EvaluationHint::Zero { - let inputs = batch_inputs(lc); - lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); - - // Take only the non-zero elements and represent them as sparse tuples (eval, dense_index) - let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot - dense_output_buffer.iter().enumerate().for_each(|(local_index, item)| { - if !item.is_zero() { - let global_index = constraint_index * self.uniform_repeat + local_index; - sparse.push((*item, global_index)); + let uni_constraint_evals: Vec<(Vec<(F, usize)>, Vec<(F, usize)>, Vec<(F, usize)>)> = + self.uniform_builder + .constraints + .par_iter() + .enumerate() + .map(|(constraint_index, constraint)| { + let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); + + let mut evaluate_lc_chunk = |hint, lc: &LC| { + if hint != EvaluationHint::Zero { + let inputs = batch_inputs(lc); + lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); + + // Take only the non-zero elements and 
represent them as sparse tuples (eval, dense_index) + let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot + dense_output_buffer.iter().enumerate().for_each( + |(local_index, item)| { + if !item.is_zero() { + let global_index = + constraint_index * self.uniform_repeat + local_index; + sparse.push((*item, global_index)); + } + }, + ); + sparse + } else { + vec![] } - }); - sparse - } else { - vec![] - } - }; + }; - let a_chunk: Vec<(F, usize)> = evaluate_lc_chunk(constraint.evaluation_hint.0, &constraint.a); - let b_chunk: Vec<(F, usize)> = evaluate_lc_chunk(constraint.evaluation_hint.1, &constraint.b); - let c_chunk: Vec<(F, usize)> = evaluate_lc_chunk(constraint.evaluation_hint.2, &constraint.c); + let a_chunk: Vec<(F, usize)> = + evaluate_lc_chunk(constraint.evaluation_hint.0, &constraint.a); + let b_chunk: Vec<(F, usize)> = + evaluate_lc_chunk(constraint.evaluation_hint.1, &constraint.b); + let c_chunk: Vec<(F, usize)> = + evaluate_lc_chunk(constraint.evaluation_hint.2, &constraint.c); - (a_chunk, b_chunk, c_chunk) - }).collect(); + (a_chunk, b_chunk, c_chunk) + }) + .collect(); - let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple(uni_constraint_evals, unsafe_allocate_sparse_zero_vec, self.uniform_repeat); + let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple( + uni_constraint_evals, + unsafe_allocate_sparse_zero_vec, + self.uniform_repeat, + ); // offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + 1] // (a - b) * condition == 0 @@ -981,7 +1035,8 @@ impl CombinedUniformBuilder { .cloned() .unwrap_or(constr.cond.1.constant_term_field()); (az, bz) - }).collect(); + }) + .collect(); // Sparsify: take only the non-zero elements for (local_index, (az, bz)) in dense_az_bz.iter().enumerate() { @@ -1000,12 +1055,15 @@ impl CombinedUniformBuilder { let cz_poly = SparsePolynomial::new(num_vars, cz_sparse); #[cfg(test)] - self.assert_valid(&az_poly.clone().to_dense().evals_ref(), 
&bz_poly.clone().to_dense().evals_ref(), &cz_poly.clone().to_dense().evals_ref()); + self.assert_valid( + &az_poly.clone().to_dense().evals_ref(), + &bz_poly.clone().to_dense().evals_ref(), + &cz_poly.clone().to_dense().evals_ref(), + ); (az_poly, bz_poly, cz_poly) } - #[cfg(test)] pub fn assert_valid(&self, az: &[F], bz: &[F], cz: &[F]) { let rows = az.len(); diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 256ab9b14..b436769e3 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -115,7 +115,8 @@ impl> UniformSpartanProof { let inputs = &segmented_padded_witness.segments[0..I::COUNT]; let aux = &segmented_padded_witness.segments[I::COUNT..]; profiling::print_current_memory_usage("pre_az_bz_cz"); - let (mut az, mut bz, mut cz) = constraint_builder.compute_spartan_Az_Bz_Cz_sparse(inputs, aux); + let (mut az, mut bz, mut cz) = + constraint_builder.compute_spartan_Az_Bz_Cz_sparse(inputs, aux); profiling::print_current_memory_usage("post_az_bz_cz"); let comb_func_outer = |A: &F, B: &F, C: &F, D: &F| -> F { diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index cf53420fc..d757a24c7 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -1,11 +1,22 @@ -use crate::{field::JoltField, poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial}, utils::{compute_dotproduct_low_optimized, math::Math, mul_0_1_optimized, thread::{drop_in_background_thread, unsafe_allocate_sparse_zero_vec, unsafe_allocate_zero_vec}}}; +use crate::{ + field::JoltField, + poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial}, + utils::{ + compute_dotproduct_low_optimized, + math::Math, + mul_0_1_optimized, + thread::{ + drop_in_background_thread, unsafe_allocate_sparse_zero_vec, unsafe_allocate_zero_vec, + }, + }, +}; use num_integer::Integer; use rayon::prelude::*; #[derive(Clone, Debug, PartialEq)] pub struct SparsePolynomial { num_vars: usize, - + Z: 
Vec<(F, usize)>, } @@ -18,7 +29,10 @@ impl SparsePolynomial { #[tracing::instrument(skip_all)] pub fn from_dense_evals(num_vars: usize, evals: Vec) -> Self { assert!(num_vars.pow2() >= evals.len()); - let non_zero_count: usize = evals.par_chunks(10_000).map(|chunk| chunk.iter().filter(|f| !f.is_zero()).count()).sum(); + let non_zero_count: usize = evals + .par_chunks(10_000) + .map(|chunk| chunk.iter().filter(|f| !f.is_zero()).count()) + .sum(); let span_allocate = tracing::span!(tracing::Level::DEBUG, "allocate"); let _enter_allocate = span_allocate.enter(); @@ -90,7 +104,8 @@ impl SparsePolynomial { dense_start_index = dense_end_index; sparse_start_index = sparse_end_index; - sparse_end_index = std::cmp::min(sparse_end_index + target_chunk_size, self.Z.len() - 1); + sparse_end_index = + std::cmp::min(sparse_end_index + target_chunk_size, self.Z.len() - 1); } chunks.push(&self.Z[sparse_start_index..]); let highest_non_zero = self.Z.last().map(|&(_, index)| index).unwrap(); @@ -98,7 +113,6 @@ impl SparsePolynomial { assert_eq!(chunks.len(), n); assert_eq!(dense_ranges.len(), n); - (chunks, dense_ranges) } @@ -113,9 +127,12 @@ impl SparsePolynomial { for (sparse_index, (value, dense_index)) in self.Z.iter().enumerate() { if dense_index.is_even() { let new_dense_index = dense_index / 2; - if self.Z.len() >= 2 && sparse_index <= self.Z.len() - 2 && self.Z[sparse_index + 1].1 == dense_index + 1 { + if self.Z.len() >= 2 + && sparse_index <= self.Z.len() - 2 + && self.Z[sparse_index + 1].1 == dense_index + 1 + { let upper = self.Z[sparse_index + 1].0; - let eval = *value + *r * (upper - value); + let eval = *value + *r * (upper - value); new_Z.push((eval, new_dense_index)); } else { new_Z.push(((F::one() - r) * value, new_dense_index)); @@ -139,20 +156,26 @@ impl SparsePolynomial { let count_span = tracing::span!(tracing::Level::DEBUG, "counting"); let count_enter = count_span.enter(); let (chunks, _range) = self.chunk_no_orphans(rayon::current_num_threads() * 8); - let 
chunk_sizes: Vec = chunks.par_iter().map(|chunk| { - let mut chunk_size = 0; - let mut i = 0; - while i < chunk.len() { - chunk_size += 1; - - // If they're siblings, avoid double counting - if chunk[i].1.is_even() && i + 1 < chunk.len() && chunk[i].1 + 1 == chunk[i + 1].1 { + let chunk_sizes: Vec = chunks + .par_iter() + .map(|chunk| { + let mut chunk_size = 0; + let mut i = 0; + while i < chunk.len() { + chunk_size += 1; + + // If they're siblings, avoid double counting + if chunk[i].1.is_even() + && i + 1 < chunk.len() + && chunk[i].1 + 1 == chunk[i + 1].1 + { + i += 1; + } i += 1; } - i += 1; - } - chunk_size - }).collect(); + chunk_size + }) + .collect(); drop(count_enter); let alloc_span = tracing::span!(tracing::Level::DEBUG, "alloc_new_Z"); @@ -170,33 +193,40 @@ impl SparsePolynomial { } assert_eq!(mutable_chunks.len(), chunks.len()); - chunks.into_par_iter().zip(mutable_chunks.par_iter_mut()).for_each(|(chunk, mutable)| { - let span = tracing::span!(tracing::Level::DEBUG, "chunk"); - let _enter = span.enter(); - let mut write_index = 0; - for (sparse_index, (value, dense_index)) in chunk.iter().enumerate() { - if dense_index.is_even() { - let new_dense_index = dense_index / 2; - if chunk.len() >= 2 && sparse_index <= chunk.len() - 2 && chunk[sparse_index + 1].1 == dense_index + 1 { - let upper = chunk[sparse_index + 1].0; - let eval = *value + mul_0_1_optimized(r, &(upper - value)); - mutable[write_index] = (eval, new_dense_index); - write_index += 1; - } else { - mutable[write_index] = (mul_0_1_optimized(&(F::one() - r), value), new_dense_index); - write_index += 1; - } - } else { - if sparse_index > 0 && chunk[sparse_index - 1].1 == dense_index - 1 { - continue; + chunks + .into_par_iter() + .zip(mutable_chunks.par_iter_mut()) + .for_each(|(chunk, mutable)| { + let span = tracing::span!(tracing::Level::DEBUG, "chunk"); + let _enter = span.enter(); + let mut write_index = 0; + for (sparse_index, (value, dense_index)) in chunk.iter().enumerate() { + if 
dense_index.is_even() { + let new_dense_index = dense_index / 2; + if chunk.len() >= 2 + && sparse_index <= chunk.len() - 2 + && chunk[sparse_index + 1].1 == dense_index + 1 + { + let upper = chunk[sparse_index + 1].0; + let eval = *value + mul_0_1_optimized(r, &(upper - value)); + mutable[write_index] = (eval, new_dense_index); + write_index += 1; + } else { + mutable[write_index] = + (mul_0_1_optimized(&(F::one() - r), value), new_dense_index); + write_index += 1; + } } else { - let new_dense_index = (dense_index - 1) / 2; - mutable[write_index] = (mul_0_1_optimized(r, value), new_dense_index); - write_index += 1; + if sparse_index > 0 && chunk[sparse_index - 1].1 == dense_index - 1 { + continue; + } else { + let new_dense_index = (dense_index - 1) / 2; + mutable[write_index] = (mul_0_1_optimized(r, value), new_dense_index); + write_index += 1; + } } } - } - }); + }); let old_Z = std::mem::replace(&mut self.Z, new_Z); drop_in_background_thread(old_Z); @@ -240,7 +270,12 @@ pub struct SparseTripleIterator<'a, F: JoltField> { impl<'a, F: JoltField> SparseTripleIterator<'a, F> { #[tracing::instrument(skip_all)] - pub fn chunks(a: &'a SparsePolynomial, b: &'a SparsePolynomial, c: &'a SparsePolynomial, n: usize) -> Vec { + pub fn chunks( + a: &'a SparsePolynomial, + b: &'a SparsePolynomial, + c: &'a SparsePolynomial, + n: usize, + ) -> Vec { // When the instance is small enough, don't worry about parallelism let total_len = a.num_vars.pow2(); if n * 2 > b.Z.len() { @@ -249,11 +284,14 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { end_index: total_len, a: &a.Z, b: &b.Z, - c: &c.Z + c: &c.Z, }]; } // Can be made more generic, but this is an optimization / simplification. 
- assert!(b.Z.len() >= a.Z.len() && b.Z.len() >= c.Z.len(), "b.Z.len() assumed to be longest of a, b, and c"); + assert!( + b.Z.len() >= a.Z.len() && b.Z.len() >= c.Z.len(), + "b.Z.len() assumed to be longest of a, b, and c" + ); // TODO(sragss): Explain the strategy @@ -281,7 +319,12 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { let a_last = a.Z.last().map(|&(_, index)| index); let b_last = b.Z.last().map(|&(_, index)| index); let c_last = c.Z.last().map(|&(_, index)| index); - *a_last.iter().chain(b_last.iter()).chain(c_last.iter()).max().unwrap() + *a_last + .iter() + .chain(b_last.iter()) + .chain(c_last.iter()) + .max() + .unwrap() }; dense_ranges.push((dense_start_index, highest_non_zero + 1)); assert_eq!(b_chunks.len(), n); @@ -315,11 +358,10 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { c_chunks[chunk_index - 1] = &c.Z[c_start..c_i]; } - } drop(_enter); - a_chunks[n-1] = &a.Z[a_i..]; - c_chunks[n-1] = &c.Z[c_i..]; + a_chunks[n - 1] = &a.Z[a_i..]; + c_chunks[n - 1] = &c.Z[c_i..]; #[cfg(test)] { @@ -329,19 +371,30 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { } let mut iterators: Vec> = Vec::with_capacity(n); - for (((a_chunk, b_chunk), c_chunk), range) in a_chunks.iter().zip(b_chunks.iter()).zip(c_chunks.iter()).zip(dense_ranges.iter()) { + for (((a_chunk, b_chunk), c_chunk), range) in a_chunks + .iter() + .zip(b_chunks.iter()) + .zip(c_chunks.iter()) + .zip(dense_ranges.iter()) + { #[cfg(test)] { - assert!(a_chunk.iter().all(|(_, index)| *index >= range.0 && *index <= range.1)); - assert!(b_chunk.iter().all(|(_, index)| *index >= range.0 && *index <= range.1)); - assert!(c_chunk.iter().all(|(_, index)| *index >= range.0 && *index <= range.1)); + assert!(a_chunk + .iter() + .all(|(_, index)| *index >= range.0 && *index <= range.1)); + assert!(b_chunk + .iter() + .all(|(_, index)| *index >= range.0 && *index <= range.1)); + assert!(c_chunk + .iter() + .all(|(_, index)| *index >= range.0 && *index <= range.1)); } let iter = 
SparseTripleIterator { dense_index: range.0, end_index: range.1, a: a_chunk, b: b_chunk, - c: c_chunk + c: c_chunk, }; iterators.push(iter); } @@ -379,7 +432,15 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { let c_upper_val = match_and_advance(&mut self.c, self.dense_index); self.dense_index += 1; - (low_index, a_lower_val, a_upper_val, b_lower_val, b_upper_val, c_lower_val, c_upper_val) + ( + low_index, + a_lower_val, + a_upper_val, + b_lower_val, + b_upper_val, + c_lower_val, + c_upper_val, + ) } } @@ -544,8 +605,23 @@ mod tests { #[test] fn sparse_bound_bot_mixed() { - let dense_evals = vec![Fr::zero(), Fr::from(10), Fr::zero(), Fr::from(20), Fr::from(30), Fr::from(40), Fr::zero(), Fr::from(50)]; - let sparse_evals = vec![(Fr::from(10), 1), (Fr::from(20), 3), (Fr::from(30), 4), (Fr::from(40), 5), (Fr::from(50), 7)]; + let dense_evals = vec![ + Fr::zero(), + Fr::from(10), + Fr::zero(), + Fr::from(20), + Fr::from(30), + Fr::from(40), + Fr::zero(), + Fr::from(50), + ]; + let sparse_evals = vec![ + (Fr::from(10), 1), + (Fr::from(20), 3), + (Fr::from(30), 4), + (Fr::from(40), 5), + (Fr::from(50), 7), + ]; let mut dense = DensePolynomial::new(dense_evals); let mut sparse = SparsePolynomial::new(3, sparse_evals); @@ -561,7 +637,24 @@ mod tests { #[test] fn sparse_triple_iterator() { let a = vec![(Fr::from(9), 9), (Fr::from(10), 10), (Fr::from(12), 12)]; - let b = vec![(Fr::from(100), 0), (Fr::from(1), 1), (Fr::from(2), 2), (Fr::from(3), 3), (Fr::from(4), 4), (Fr::from(5), 5), (Fr::from(6), 6), (Fr::from(7), 7), (Fr::from(8), 8), (Fr::from(9), 9), (Fr::from(10), 10), (Fr::from(11), 11), (Fr::from(12), 12), (Fr::from(13), 13), (Fr::from(14), 14), (Fr::from(15), 15)]; + let b = vec![ + (Fr::from(100), 0), + (Fr::from(1), 1), + (Fr::from(2), 2), + (Fr::from(3), 3), + (Fr::from(4), 4), + (Fr::from(5), 5), + (Fr::from(6), 6), + (Fr::from(7), 7), + (Fr::from(8), 8), + (Fr::from(9), 9), + (Fr::from(10), 10), + (Fr::from(11), 11), + (Fr::from(12), 12), + (Fr::from(13), 
13), + (Fr::from(14), 14), + (Fr::from(15), 15), + ]; let c = vec![(Fr::from(12), 0), (Fr::from(3), 3)]; let a_poly = SparsePolynomial::new(4, a); @@ -610,16 +703,17 @@ mod tests { let mut expected_dense_index = 0; for iterator in iterators.iter_mut() { while iterator.has_next() { - let (dense_index, a_low, a_high, b_low, b_high, c_low, c_high) = iterator.next_pairs(); + let (dense_index, a_low, a_high, b_low, b_high, c_low, c_high) = + iterator.next_pairs(); new_a[dense_index] = a_low; - new_a[dense_index+1] = a_high; + new_a[dense_index + 1] = a_high; new_b[dense_index] = b_low; - new_b[dense_index+1] = b_high; + new_b[dense_index + 1] = b_high; new_c[dense_index] = c_low; - new_c[dense_index+1] = c_high; + new_c[dense_index + 1] = c_high; assert_eq!(dense_index, expected_dense_index); expected_dense_index += 2; @@ -649,7 +743,9 @@ mod tests { let mut a_poly = SparsePolynomial::new(num_vars, a); let r = Fr::from(100); - assert_eq!(a_poly.clone().bound_poly_var_bot(&r), a_poly.bound_poly_var_bot_par(&r)); - + assert_eq!( + a_poly.clone().bound_poly_var_bot(&r), + a_poly.bound_poly_var_bot_par(&r) + ); } -} \ No newline at end of file +} diff --git a/jolt-core/src/utils/thread.rs b/jolt-core/src/utils/thread.rs index 4678022a7..92754a095 100644 --- a/jolt-core/src/utils/thread.rs +++ b/jolt-core/src/utils/thread.rs @@ -1,5 +1,5 @@ -use std::thread::{self, JoinHandle}; use rayon::prelude::*; +use std::thread::{self, JoinHandle}; use crate::field::JoltField; @@ -45,7 +45,7 @@ pub fn unsafe_allocate_zero_vec(size: usize) -> Vec { result } -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip_all)] pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec<(F, usize)> { // Check for safety of 0 allocation unsafe { @@ -72,18 +72,16 @@ pub fn unsafe_allocate_sparse_zero_vec(size: usize) -> Vec #[tracing::instrument(skip_all)] pub fn par_flatten_triple Vec>( - triple: Vec<(Vec, Vec, Vec)>, - allocate: F, - excess_alloc: usize) -> (Vec, Vec, Vec) { + triple: 
Vec<(Vec, Vec, Vec)>, + allocate: F, + excess_alloc: usize, +) -> (Vec, Vec, Vec) { let az_len: usize = triple.iter().map(|item| item.0.len()).sum(); let bz_len: usize = triple.iter().map(|item| item.1.len()).sum(); let cz_len: usize = triple.iter().map(|item| item.2.len()).sum(); - let (mut a_sparse, mut b_sparse, mut c_sparse): (Vec, Vec, Vec) = ( - allocate(az_len), - allocate(bz_len), - allocate(cz_len), - ); + let (mut a_sparse, mut b_sparse, mut c_sparse): (Vec, Vec, Vec) = + (allocate(az_len), allocate(bz_len), allocate(cz_len)); let mut a_slices = Vec::with_capacity(triple.len() + excess_alloc); let mut b_slices = Vec::with_capacity(triple.len() + excess_alloc); @@ -107,9 +105,20 @@ pub fn par_flatten_triple Vec>( c_rest = c_new_rest; } - triple.into_par_iter().zip(a_slices.par_iter_mut().zip(b_slices.par_iter_mut().zip(c_slices.par_iter_mut()))).for_each(|(chunk, (a, (b, c)))| { - join_triple(|| a.copy_from_slice(&chunk.0), || b.copy_from_slice(&chunk.1), || c.copy_from_slice(&chunk.2)); - }); + triple + .into_par_iter() + .zip( + a_slices + .par_iter_mut() + .zip(b_slices.par_iter_mut().zip(c_slices.par_iter_mut())), + ) + .for_each(|(chunk, (a, (b, c)))| { + join_triple( + || a.copy_from_slice(&chunk.0), + || b.copy_from_slice(&chunk.1), + || c.copy_from_slice(&chunk.2), + ); + }); (a_sparse, b_sparse, c_sparse) } From 9f73a04bfbb8be512d891ebde13e8626d350a43b Mon Sep 17 00:00:00 2001 From: sragss Date: Wed, 26 Jun 2024 10:38:27 -0700 Subject: [PATCH 06/17] rm non sparse Builder::compute_spartan --- jolt-core/src/jolt/vm/mod.rs | 7 +- jolt-core/src/r1cs/builder.rs | 253 ++++++------------------- jolt-core/src/r1cs/jolt_constraints.rs | 4 +- jolt-core/src/r1cs/key.rs | 2 +- jolt-core/src/r1cs/spartan.rs | 16 +- 5 files changed, 65 insertions(+), 217 deletions(-) diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 06f63c004..e0d0188d6 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -403,10 +403,8 @@ 
pub trait Jolt, const C: usize, c &mut transcript, ); - // drop_in_background_thread(jolt_polynomials); - drop(jolt_polynomials); + drop_in_background_thread(jolt_polynomials); - profiling::print_current_memory_usage("pre_spartan"); let spartan_proof = UniformSpartanProof::::prove_precommitted( &preprocessing.generators, r1cs_builder, @@ -415,7 +413,6 @@ pub trait Jolt, const C: usize, c &mut transcript, ) .expect("r1cs proof failed"); - profiling::print_current_memory_usage("post_spartan"); let r1cs_proof = R1CSProof { key: spartan_key, proof: spartan_proof, @@ -593,7 +590,7 @@ pub trait Jolt, const C: usize, c #[cfg(test)] { - let (az, bz, cz) = builder.compute_spartan_Az_Bz_Cz(&inputs_flat, &aux); + let (az, bz, cz) = builder.compute_spartan_Az_Bz_Cz_sparse(&inputs_flat, &aux); builder.assert_valid(&az, &bz, &cz); } diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 0a5bc5859..124ec573d 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -28,9 +28,9 @@ pub trait R1CSConstraintBuilder { } #[derive(Debug, Clone, Copy, PartialEq)] -pub enum EvaluationHint { +pub enum EvalHint { Zero = 0, - Other = 2, + Other = 1, } /// Constraints over a single row. Each variable points to a single item in Z and the corresponding coefficient. 
@@ -39,7 +39,9 @@ struct Constraint { a: LC, b: LC, c: LC, - evaluation_hint: (EvaluationHint, EvaluationHint, EvaluationHint), + + /// Shortcut for evaluation of a, b, c for an honest prover + eval_hint: (EvalHint, EvalHint, EvalHint), } impl Constraint { @@ -237,11 +239,7 @@ impl R1CSBuilder { a, b, c: LC::zero(), - evaluation_hint: ( - EvaluationHint::Zero, - EvaluationHint::Other, - EvaluationHint::Zero, - ), + eval_hint: (EvalHint::Zero, EvalHint::Other, EvalHint::Zero), }; self.constraints.push(constraint); } @@ -264,11 +262,7 @@ impl R1CSBuilder { a, b, c, - evaluation_hint: ( - EvaluationHint::Other, - EvaluationHint::Other, - EvaluationHint::Zero, - ), + eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Zero), }; // TODO(sragss): Can do better on middle term. self.constraints.push(constraint); } @@ -282,11 +276,7 @@ impl R1CSBuilder { a, b, c: LC::zero(), - evaluation_hint: ( - EvaluationHint::Other, - EvaluationHint::Other, - EvaluationHint::Zero, - ), + eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Zero), }; self.constraints.push(constraint); } @@ -310,11 +300,7 @@ impl R1CSBuilder { a: condition.clone(), b: (result_true - result_false.clone()), c: (alleged_result - result_false), - evaluation_hint: ( - EvaluationHint::Other, - EvaluationHint::Other, - EvaluationHint::Other, - ), // TODO(sragss): Is this the best we can do? + eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Other), // TODO(sragss): Is this the best we can do? }; self.constraints.push(constraint); } @@ -460,11 +446,7 @@ impl R1CSBuilder { a: x.into(), b: y.into(), c: z.into(), - evaluation_hint: ( - EvaluationHint::Other, - EvaluationHint::Other, - EvaluationHint::Other, - ), + eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Other), }; self.constraints.push(constraint); } @@ -787,108 +769,6 @@ impl CombinedUniformBuilder { NonUniformR1CS::new(eq, condition) } - /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... 
[I::N, I::N]] - /// aux should be of the format [[Aux(0), Aux(0), ...], ... [Aux(self.next_aux - 1), ...]] - #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan")] - pub fn compute_spartan_Az_Bz_Cz( - &self, - inputs: &[Vec], - aux: &[Vec], - ) -> (Vec, Vec, Vec) { - assert_eq!(inputs.len(), I::COUNT); - let num_aux = self.uniform_builder.num_aux(); - assert_eq!(aux.len(), num_aux); - assert!(inputs - .iter() - .chain(aux.iter()) - .all(|inner_input| inner_input.len() == self.uniform_repeat)); - - let uniform_constraint_rows = self.uniform_repeat_constraint_rows(); - // TODO(sragss): Allocation can overshoot by up to a factor of 2, Spartan could handle non-pow-2 Az,Bz,Cz - let constraint_rows = self.constraint_rows().next_power_of_two(); - let (mut Az, mut Bz, mut Cz) = ( - unsafe_allocate_zero_vec(constraint_rows), - unsafe_allocate_zero_vec(constraint_rows), - unsafe_allocate_zero_vec(constraint_rows), - ); - - let batch_inputs = |lc: &LC| batch_inputs(lc, inputs, aux); - - // uniform_constraints: Xz[0..uniform_constraint_rows] - // TODO(sragss): Attempt moving onto key and computing from materialized rows rather than linear combos - let span = tracing::span!(tracing::Level::DEBUG, "compute_constraints"); - let enter = span.enter(); - let az_chunks = Az.par_chunks_mut(self.uniform_repeat); - let bz_chunks = Bz.par_chunks_mut(self.uniform_repeat); - let cz_chunks = Cz.par_chunks_mut(self.uniform_repeat); - - self.uniform_builder - .constraints - .par_iter() - .zip(az_chunks.zip(bz_chunks.zip(cz_chunks))) - .for_each(|(constraint, (az_chunk, (bz_chunk, cz_chunk)))| { - let a_inputs = batch_inputs(&constraint.a); - let b_inputs = batch_inputs(&constraint.b); - let c_inputs = batch_inputs(&constraint.c); - - constraint.a.evaluate_batch_mut(&a_inputs, az_chunk); - constraint.b.evaluate_batch_mut(&b_inputs, bz_chunk); - constraint.c.evaluate_batch_mut(&c_inputs, cz_chunk); - }); - drop(enter); - - // offset_equality_constraints: 
Xz[uniform_constraint_rows..uniform_constraint_rows + 1] - // (a - b) * condition == 0 - // For the final step we will not compute the offset terms, and will assume the condition to be set to 0 - let span = tracing::span!(tracing::Level::DEBUG, "offset_eq"); - let _enter = span.enter(); - - let constr = &self.offset_equality_constraint; - let condition_evals = constr - .cond - .1 - .evaluate_batch(&batch_inputs(&constr.cond.1), self.uniform_repeat); - let eq_a_evals = constr - .a - .1 - .evaluate_batch(&batch_inputs(&constr.a.1), self.uniform_repeat); - let eq_b_evals = constr - .b - .1 - .evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat); - - let Az_off = Az[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat] - .par_iter_mut(); - let Bz_off = Bz[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat] - .par_iter_mut(); - - (0..self.uniform_repeat) - .into_par_iter() - .zip(Az_off.zip(Bz_off)) - .for_each(|(step_index, (az, bz))| { - // Write corresponding values, if outside the step range, only include the constant. - let a_step = step_index + if constr.a.0 { 1 } else { 0 }; - let b_step = step_index + if constr.b.0 { 1 } else { 0 }; - let a = eq_a_evals - .get(a_step) - .cloned() - .unwrap_or(constr.a.1.constant_term_field()); - let b = eq_b_evals - .get(b_step) - .cloned() - .unwrap_or(constr.b.1.constant_term_field()); - *az = a - b; - - let condition_step = step_index + if constr.cond.0 { 1 } else { 0 }; - *bz = condition_evals - .get(condition_step) - .cloned() - .unwrap_or(constr.cond.1.constant_term_field()); - }); - - (Az, Bz, Cz) - } - /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]] /// aux should be of the format [[Aux(0), Aux(0), ...], ... 
[Aux(self.next_aux - 1), ...]] #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan_sparse")] @@ -913,36 +793,6 @@ impl CombinedUniformBuilder { let batch_inputs = |lc: &LC| batch_inputs(lc, inputs, aux); - // Enforce correctness of hints. - // TODO(sragss): Can be moved into assert_valid. - #[cfg(test)] - self.uniform_builder - .constraints - .iter() - .for_each(|constraint| { - let assert_hint = |constraint: &Constraint| { - let a_inputs = batch_inputs(&constraint.a); - let b_inputs = batch_inputs(&constraint.b); - let c_inputs = batch_inputs(&constraint.c); - - let a = constraint.a.evaluate_batch(&a_inputs, self.uniform_repeat); - let b = constraint.b.evaluate_batch(&b_inputs, self.uniform_repeat); - let c = constraint.c.evaluate_batch(&c_inputs, self.uniform_repeat); - - if constraint.evaluation_hint.0 == EvaluationHint::Zero { - a.iter().for_each(|item| assert_eq!(*item, F::zero(),)); - } - if constraint.evaluation_hint.1 == EvaluationHint::Zero { - b.iter().for_each(|item| assert_eq!(*item, F::zero(),)); - } - if constraint.evaluation_hint.2 == EvaluationHint::Zero { - c.iter().for_each(|item| assert_eq!(*item, F::zero(),)); - } - }; - - assert_hint(constraint); - }); - // uniform_constraints: Xz[0..uniform_constraint_rows] let span = tracing::span!(tracing::Level::DEBUG, "uniform_evals"); let _enter = span.enter(); @@ -955,7 +805,7 @@ impl CombinedUniformBuilder { let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); let mut evaluate_lc_chunk = |hint, lc: &LC| { - if hint != EvaluationHint::Zero { + if hint != EvalHint::Zero { let inputs = batch_inputs(lc); lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); @@ -977,11 +827,11 @@ impl CombinedUniformBuilder { }; let a_chunk: Vec<(F, usize)> = - evaluate_lc_chunk(constraint.evaluation_hint.0, &constraint.a); + evaluate_lc_chunk(constraint.eval_hint.0, &constraint.a); let b_chunk: Vec<(F, usize)> = - evaluate_lc_chunk(constraint.evaluation_hint.1, 
&constraint.b); + evaluate_lc_chunk(constraint.eval_hint.1, &constraint.b); let c_chunk: Vec<(F, usize)> = - evaluate_lc_chunk(constraint.evaluation_hint.2, &constraint.c); + evaluate_lc_chunk(constraint.eval_hint.2, &constraint.c); (a_chunk, b_chunk, c_chunk) }) @@ -990,7 +840,7 @@ impl CombinedUniformBuilder { let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple( uni_constraint_evals, unsafe_allocate_sparse_zero_vec, - self.uniform_repeat, + self.uniform_repeat, // Capacity overhead for offset_eq constraints. ); // offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + 1] @@ -1055,23 +905,29 @@ impl CombinedUniformBuilder { let cz_poly = SparsePolynomial::new(num_vars, cz_sparse); #[cfg(test)] - self.assert_valid( - &az_poly.clone().to_dense().evals_ref(), - &bz_poly.clone().to_dense().evals_ref(), - &cz_poly.clone().to_dense().evals_ref(), - ); + self.assert_valid(&az_poly, &bz_poly, &cz_poly); (az_poly, bz_poly, cz_poly) } #[cfg(test)] - pub fn assert_valid(&self, az: &[F], bz: &[F], cz: &[F]) { + pub fn assert_valid( + &self, + az: &SparsePolynomial, + bz: &SparsePolynomial, + cz: &SparsePolynomial, + ) { + let az = az.clone().to_dense(); + let bz = bz.clone().to_dense(); + let cz = cz.clone().to_dense(); + let rows = az.len(); assert_eq!(bz.len(), rows); assert_eq!(cz.len(), rows); + for constraint_index in 0..rows { + let uniform_constraint_index = constraint_index / self.uniform_repeat; if az[constraint_index] * bz[constraint_index] != cz[constraint_index] { - let uniform_constraint_index = constraint_index / self.uniform_repeat; let step_index = constraint_index % self.uniform_repeat; panic!( "Mismatch at global constraint {constraint_index} => {:?}\n\ @@ -1080,6 +936,20 @@ impl CombinedUniformBuilder { self.uniform_builder.constraints[uniform_constraint_index] ); } + // Verify hints + if constraint_index < self.uniform_repeat_constraint_rows() { + let (hint_a, hint_b, hint_c) = + 
self.uniform_builder.constraints[uniform_constraint_index].eval_hint; + if hint_a == EvalHint::Zero { + assert_eq!(az[constraint_index], F::zero(), "Mismatch at global constraint {constraint_index} uniform constraint: {uniform_constraint_index}"); + } + if hint_b == EvalHint::Zero { + assert_eq!(bz[constraint_index], F::zero(), "Mismatch at global constraint {constraint_index} uniform constraint: {uniform_constraint_index}"); + } + if hint_c == EvalHint::Zero { + assert_eq!(cz[constraint_index], F::zero(), "Mismatch at global constraint {constraint_index} uniform constraint: {uniform_constraint_index}"); + } + } } } } @@ -1482,11 +1352,7 @@ mod tests { let aux = combined_builder.compute_aux(&inputs); assert_eq!(aux, vec![vec![Fr::from(5 * 7), Fr::from(11 * 13)]]); - let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); - assert_eq!(az.len(), 4); - assert_eq!(bz.len(), 4); - assert_eq!(cz.len(), 4); - + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); combined_builder.assert_valid(&az, &bz, &cz); } @@ -1546,11 +1412,7 @@ mod tests { ] ); - let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); - assert_eq!(az.len(), 16); - assert_eq!(bz.len(), 16); - assert_eq!(cz.len(), 16); - + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); combined_builder.assert_valid(&az, &bz, &cz); } @@ -1594,11 +1456,7 @@ mod tests { let aux = combined_builder.compute_aux(&inputs); assert_eq!(aux, vec![vec![Fr::from(5 * 7), Fr::from(5 * 13)]]); - let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); - assert_eq!(az.len(), 4); - assert_eq!(bz.len(), 4); - assert_eq!(cz.len(), 4); - + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); combined_builder.assert_valid(&az, &bz, &cz); } @@ -1669,11 +1527,16 @@ mod tests { flat_witness.resize(flat_witness.len().next_power_of_two(), Fr::zero()); flat_witness.push(Fr::one()); 
flat_witness.resize(flat_witness.len().next_power_of_two(), Fr::zero()); - let (mut builder_az, mut builder_bz, mut builder_cz) = - builder.compute_spartan_Az_Bz_Cz(&witness_segments, &[]); - builder_az.resize(key.num_rows_total(), Fr::zero()); - builder_bz.resize(key.num_rows_total(), Fr::zero()); - builder_cz.resize(key.num_rows_total(), Fr::zero()); + + let (builder_az, builder_bz, builder_cz) = + builder.compute_spartan_Az_Bz_Cz_sparse(&witness_segments, &[]); + let mut dense_az = builder_az.to_dense().evals(); + let mut dense_bz = builder_bz.to_dense().evals(); + let mut dense_cz = builder_cz.to_dense().evals(); + dense_az.resize(key.num_rows_total(), Fr::zero()); + dense_bz.resize(key.num_rows_total(), Fr::zero()); + dense_cz.resize(key.num_rows_total(), Fr::zero()); + for row in 0..key.num_rows_total() { let mut az_eval = Fr::zero(); let mut bz_eval = Fr::zero(); @@ -1685,9 +1548,9 @@ mod tests { } // Row 11 is the problem! Builder thinks this row should be 0. big_a thinks this row should be 17 (13 + 4) - assert_eq!(builder_az[row], az_eval, "Row {row} failed in az_eval."); - assert_eq!(builder_bz[row], bz_eval, "Row {row} failed in bz_eval."); - assert_eq!(builder_cz[row], cz_eval, "Row {row} failed in cz_eval."); + assert_eq!(dense_az[row], az_eval, "Row {row} failed in az_eval."); + assert_eq!(dense_bz[row], bz_eval, "Row {row} failed in bz_eval."); + assert_eq!(dense_cz[row], cz_eval, "Row {row} failed in cz_eval."); } } } diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index 2df3cdd28..e59207e1d 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -334,8 +334,8 @@ mod tests { inputs[JoltIn::OpFlags_IsImm as usize][0] = Fr::zero(); // second_operand = rs2 => immediate let aux = combined_builder.compute_aux(&inputs); - let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); - combined_builder.assert_valid(&az, &bz, &cz); + // Implicitly asserts 
validity + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); } } diff --git a/jolt-core/src/r1cs/key.rs b/jolt-core/src/r1cs/key.rs index f4dcfb166..400524af6 100644 --- a/jolt-core/src/r1cs/key.rs +++ b/jolt-core/src/r1cs/key.rs @@ -566,7 +566,7 @@ mod test { inputs[TestInputs::OpFlags1 as usize][3] = Fr::from(3); // Confirms validity of constraints - let (_az, _bz, _cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &[]); + let (_az, _bz, _cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &[]); let key = UniformSpartanKey::from_builder(&combined_builder); diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index b436769e3..f6e1d5ee7 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -6,7 +6,6 @@ use crate::poly::commitment::commitment_scheme::CommitmentScheme; use crate::r1cs::key::UniformSpartanKey; use crate::r1cs::special_polys::SegmentedPaddedWitness; use crate::utils::math::Math; -use crate::utils::profiling; use crate::utils::thread::drop_in_background_thread; use crate::utils::transcript::ProofTranscript; @@ -108,16 +107,12 @@ impl> UniformSpartanProof { let tau = (0..num_rounds_x) .map(|_i| transcript.challenge_scalar(b"t")) .collect::>(); - profiling::print_current_memory_usage("pre_poly_tau"); let mut poly_tau = DensePolynomial::new(EqPolynomial::evals(&tau)); - profiling::print_current_memory_usage("post_poly_tau"); let inputs = &segmented_padded_witness.segments[0..I::COUNT]; let aux = &segmented_padded_witness.segments[I::COUNT..]; - profiling::print_current_memory_usage("pre_az_bz_cz"); let (mut az, mut bz, mut cz) = constraint_builder.compute_spartan_Az_Bz_Cz_sparse(inputs, aux); - profiling::print_current_memory_usage("post_az_bz_cz"); let comb_func_outer = |A: &F, B: &F, C: &F, D: &F| -> F { // Below is an optimized form of: *A * (*B * *C - *D) @@ -137,8 +132,6 @@ impl> UniformSpartanProof { } }; - // 
profiling::start_memory_tracing_span("outersumcheck"); - profiling::print_current_memory_usage("pre_outersumcheck"); let (outer_sumcheck_proof, outer_sumcheck_r, outer_sumcheck_claims) = SumcheckInstanceProof::prove_spartan_cubic::<_>( &F::zero(), // claim is zero @@ -151,9 +144,7 @@ impl> UniformSpartanProof { transcript, ); let outer_sumcheck_r: Vec = outer_sumcheck_r.into_iter().rev().collect(); - // drop_in_background_thread((poly_Az, poly_Bz, poly_Cz, poly_tau)); - drop((az, bz, cz, poly_tau)); - profiling::print_current_memory_usage("post_outersumcheck"); + drop_in_background_thread((az, bz, cz, poly_tau)); // claims from the end of sum-check // claim_Az is the (scalar) value v_A = \sum_y A(r_x, y) * z(r_x) where r_x is the sumcheck randomness @@ -181,10 +172,8 @@ impl> UniformSpartanProof { .ilog2(); let (rx_con, rx_ts) = outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize); - profiling::print_current_memory_usage("pre_poly_ABC"); let mut poly_ABC = DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC)); - profiling::print_current_memory_usage("post_poly_ABC"); let (inner_sumcheck_proof, inner_sumcheck_r, _claims_inner) = SumcheckInstanceProof::prove_spartan_quadratic::>( @@ -194,8 +183,7 @@ impl> UniformSpartanProof { &segmented_padded_witness, transcript, ); - // drop_in_background_thread(poly_ABC); - drop(poly_ABC); + drop_in_background_thread(poly_ABC); // Requires 'r_col_segment_bits' to index the (const, segment). 
Within that segment we index the step using 'r_col_step' let r_col_segment_bits = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1; From fa4fefb1e2d7acbfc80ec440ca660c831ec16e76 Mon Sep 17 00:00:00 2001 From: sragss Date: Wed, 26 Jun 2024 11:39:39 -0700 Subject: [PATCH 07/17] remove duplicate code (chunk_no_ophans) --- jolt-core/src/jolt/vm/mod.rs | 3 +- jolt-core/src/r1cs/builder.rs | 10 +-- jolt-core/src/r1cs/jolt_constraints.rs | 2 +- jolt-core/src/r1cs/key.rs | 2 +- jolt-core/src/r1cs/spartan.rs | 3 +- jolt-core/src/r1cs/special_polys.rs | 120 +++++++++---------------- jolt-core/src/subprotocols/sumcheck.rs | 14 --- 7 files changed, 53 insertions(+), 101 deletions(-) diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index e0d0188d6..fb3b5e55d 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -4,7 +4,6 @@ use crate::field::JoltField; use crate::r1cs::builder::CombinedUniformBuilder; use crate::r1cs::jolt_constraints::{construct_jolt_constraints, JoltIn}; use crate::r1cs::spartan::{self, UniformSpartanProof}; -use crate::utils::profiling; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::log2; use common::constants::RAM_START_ADDRESS; @@ -590,7 +589,7 @@ pub trait Jolt, const C: usize, c #[cfg(test)] { - let (az, bz, cz) = builder.compute_spartan_Az_Bz_Cz_sparse(&inputs_flat, &aux); + let (az, bz, cz) = builder.compute_spartan_Az_Bz_Cz(&inputs_flat, &aux); builder.assert_valid(&az, &bz, &cz); } diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 124ec573d..e2cf70af9 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -772,7 +772,7 @@ impl CombinedUniformBuilder { /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]] /// aux should be of the format [[Aux(0), Aux(0), ...], ... 
[Aux(self.next_aux - 1), ...]] #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan_sparse")] - pub fn compute_spartan_Az_Bz_Cz_sparse( + pub fn compute_spartan_Az_Bz_Cz( &self, inputs: &[Vec], aux: &[Vec], @@ -1352,7 +1352,7 @@ mod tests { let aux = combined_builder.compute_aux(&inputs); assert_eq!(aux, vec![vec![Fr::from(5 * 7), Fr::from(11 * 13)]]); - let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); combined_builder.assert_valid(&az, &bz, &cz); } @@ -1412,7 +1412,7 @@ mod tests { ] ); - let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); combined_builder.assert_valid(&az, &bz, &cz); } @@ -1456,7 +1456,7 @@ mod tests { let aux = combined_builder.compute_aux(&inputs); assert_eq!(aux, vec![vec![Fr::from(5 * 7), Fr::from(5 * 13)]]); - let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); combined_builder.assert_valid(&az, &bz, &cz); } @@ -1529,7 +1529,7 @@ mod tests { flat_witness.resize(flat_witness.len().next_power_of_two(), Fr::zero()); let (builder_az, builder_bz, builder_cz) = - builder.compute_spartan_Az_Bz_Cz_sparse(&witness_segments, &[]); + builder.compute_spartan_Az_Bz_Cz(&witness_segments, &[]); let mut dense_az = builder_az.to_dense().evals(); let mut dense_bz = builder_bz.to_dense().evals(); let mut dense_cz = builder_cz.to_dense().evals(); diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index e59207e1d..6b1ef1125 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -336,6 +336,6 @@ mod tests { let aux = combined_builder.compute_aux(&inputs); // Implicitly asserts validity - let (az, bz, cz) = 
combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &aux); + let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); } } diff --git a/jolt-core/src/r1cs/key.rs b/jolt-core/src/r1cs/key.rs index 400524af6..f4dcfb166 100644 --- a/jolt-core/src/r1cs/key.rs +++ b/jolt-core/src/r1cs/key.rs @@ -566,7 +566,7 @@ mod test { inputs[TestInputs::OpFlags1 as usize][3] = Fr::from(3); // Confirms validity of constraints - let (_az, _bz, _cz) = combined_builder.compute_spartan_Az_Bz_Cz_sparse(&inputs, &[]); + let (_az, _bz, _cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &[]); let key = UniformSpartanKey::from_builder(&combined_builder); diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index f6e1d5ee7..7e015c3f5 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -111,8 +111,7 @@ impl> UniformSpartanProof { let inputs = &segmented_padded_witness.segments[0..I::COUNT]; let aux = &segmented_padded_witness.segments[I::COUNT..]; - let (mut az, mut bz, mut cz) = - constraint_builder.compute_spartan_Az_Bz_Cz_sparse(inputs, aux); + let (mut az, mut bz, mut cz) = constraint_builder.compute_spartan_Az_Bz_Cz(inputs, aux); let comb_func_outer = |A: &F, B: &F, C: &F, D: &F| -> F { // Below is an optimized form of: *A * (*B * *C - *D) diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index d757a24c7..8a4361f04 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -5,9 +5,7 @@ use crate::{ compute_dotproduct_low_optimized, math::Math, mul_0_1_optimized, - thread::{ - drop_in_background_thread, unsafe_allocate_sparse_zero_vec, unsafe_allocate_zero_vec, - }, + thread::{drop_in_background_thread, unsafe_allocate_sparse_zero_vec}, }, }; use num_integer::Integer; @@ -80,9 +78,9 @@ impl SparsePolynomial { .sum() } - /// Returns n chunks of roughly even size without orphaning siblings (adjacent dense indices). 
Additionally returns a vector of (low, high] dense index ranges. + /// Returns `n` chunks of roughly even size without separating siblings (adjacent dense indices). Additionally returns a vector of [low, high) dense index ranges. #[tracing::instrument(skip_all)] - fn chunk_no_orphans(&self, n: usize) -> (Vec<&[(F, usize)]>, Vec<(usize, usize)>) { + fn chunk_no_split_siblings(&self, n: usize) -> (Vec<&[(F, usize)]>, Vec<(usize, usize)>) { if self.Z.len() < n * 2 { return (vec![(&self.Z)], vec![(0, self.num_vars.pow2())]); } @@ -153,38 +151,34 @@ impl SparsePolynomial { #[tracing::instrument(skip_all)] pub fn bound_poly_var_bot_par(&mut self, r: &F) { // TODO(sragss): better parallelism. + let (chunks, _range) = self.chunk_no_split_siblings(rayon::current_num_threads() * 8); + + // Calc chunk sizes post-binding for pre-allocation. let count_span = tracing::span!(tracing::Level::DEBUG, "counting"); let count_enter = count_span.enter(); - let (chunks, _range) = self.chunk_no_orphans(rayon::current_num_threads() * 8); let chunk_sizes: Vec = chunks .par_iter() .map(|chunk| { - let mut chunk_size = 0; - let mut i = 0; - while i < chunk.len() { - chunk_size += 1; - - // If they're siblings, avoid double counting - if chunk[i].1.is_even() - && i + 1 < chunk.len() - && chunk[i].1 + 1 == chunk[i + 1].1 - { - i += 1; - } - i += 1; - } - chunk_size + // Count each pair of siblings if at least one is present. + chunk + .iter() + .enumerate() + .filter(|(i, (_value, index))| { + // Always count odd, only count even indices when the paired odd index is not present. 
+ !index.is_even() || i + 1 >= chunk.len() || index + 1 != chunk[i + 1].1 + }) + .count() }) .collect(); drop(count_enter); - let alloc_span = tracing::span!(tracing::Level::DEBUG, "alloc_new_Z"); + let alloc_span = tracing::span!(tracing::Level::DEBUG, "alloc"); let alloc_enter = alloc_span.enter(); let total_len: usize = chunk_sizes.iter().sum(); let mut new_Z: Vec<(F, usize)> = unsafe_allocate_sparse_zero_vec(total_len); drop(alloc_enter); - let mut mutable_chunks: Vec<&mut [(F, usize)]> = vec![]; + let mut mutable_chunks: Vec<&mut [(F, usize)]> = Vec::with_capacity(chunk_sizes.len()); let mut remainder = new_Z.as_mut_slice(); for chunk_size in chunk_sizes { let (first, second) = remainder.split_at_mut(chunk_size); @@ -193,6 +187,7 @@ impl SparsePolynomial { } assert_eq!(mutable_chunks.len(), chunks.len()); + // Bind each chunk in parallel chunks .into_par_iter() .zip(mutable_chunks.par_iter_mut()) @@ -203,23 +198,28 @@ impl SparsePolynomial { for (sparse_index, (value, dense_index)) in chunk.iter().enumerate() { if dense_index.is_even() { let new_dense_index = dense_index / 2; + if chunk.len() >= 2 && sparse_index <= chunk.len() - 2 && chunk[sparse_index + 1].1 == dense_index + 1 { + // (low, high) present let upper = chunk[sparse_index + 1].0; let eval = *value + mul_0_1_optimized(r, &(upper - value)); mutable[write_index] = (eval, new_dense_index); write_index += 1; } else { + // (low, _) present mutable[write_index] = (mul_0_1_optimized(&(F::one() - r), value), new_dense_index); write_index += 1; } } else { if sparse_index > 0 && chunk[sparse_index - 1].1 == dense_index - 1 { + // (low, high) present, but handeled prior continue; } else { + // (_, high) present let new_dense_index = (dense_index - 1) / 2; mutable[write_index] = (mul_0_1_optimized(r, value), new_dense_index); write_index += 1; @@ -248,7 +248,7 @@ impl SparsePolynomial { #[cfg(test)] #[tracing::instrument(skip_all)] pub fn to_dense(self) -> DensePolynomial { - use crate::utils::math::Math; + 
use crate::utils::{math::Math, thread::unsafe_allocate_zero_vec}; let mut evals = unsafe_allocate_zero_vec(self.num_vars.pow2()); @@ -276,9 +276,9 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { c: &'a SparsePolynomial, n: usize, ) -> Vec { - // When the instance is small enough, don't worry about parallelism + // Don't chunk for small instances let total_len = a.num_vars.pow2(); - if n * 2 > b.Z.len() { + if b.Z.len() < n * 2 { return vec![SparseTripleIterator { dense_index: 0, end_index: total_len, @@ -287,55 +287,28 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { c: &c.Z, }]; } - // Can be made more generic, but this is an optimization / simplification. - assert!( - b.Z.len() >= a.Z.len() && b.Z.len() >= c.Z.len(), - "b.Z.len() assumed to be longest of a, b, and c" - ); - // TODO(sragss): Explain the strategy + // B is assumed most dense. Parallelism depends on evenly distributing B across threads. + assert!(b.Z.len() >= a.Z.len() && b.Z.len() >= c.Z.len()); - let target_chunk_size = b.Z.len() / n; - let mut b_chunks: Vec<&[(F, usize)]> = Vec::with_capacity(n); - let mut dense_ranges: Vec<(usize, usize)> = Vec::with_capacity(n); - let mut dense_start_index = 0; - let mut sparse_start_index = 0; - let mut sparse_end_index = target_chunk_size; - for _ in 1..n { - let mut dense_end_index = b.Z[sparse_end_index].1; - if dense_end_index % 2 != 0 { - dense_end_index += 1; - sparse_end_index += 1; - } - b_chunks.push(&b.Z[sparse_start_index..sparse_end_index]); - dense_ranges.push((dense_start_index, dense_end_index)); - dense_start_index = dense_end_index; + // TODO(sragss): Explain the strategy - sparse_start_index = sparse_end_index; - sparse_end_index = std::cmp::min(sparse_end_index + target_chunk_size, b.Z.len() - 1); - } - b_chunks.push(&b.Z[sparse_start_index..]); - let highest_non_zero = { - let a_last = a.Z.last().map(|&(_, index)| index); - let b_last = b.Z.last().map(|&(_, index)| index); - let c_last = c.Z.last().map(|&(_, index)| index); - 
*a_last - .iter() - .chain(b_last.iter()) - .chain(c_last.iter()) - .max() - .unwrap() - }; - dense_ranges.push((dense_start_index, highest_non_zero + 1)); + let (b_chunks, mut dense_ranges) = b.chunk_no_split_siblings(n); + let highest_non_zero = [&a.Z, &b.Z, &c.Z] + .iter() + .filter_map(|z| z.last().map(|&(_, index)| index)) + .max() + .unwrap(); + dense_ranges.last_mut().unwrap().1 = highest_non_zero + 1; assert_eq!(b_chunks.len(), n); assert_eq!(dense_ranges.len(), n); - // Create chunks which overlap with b's sparse indices + // Create chunks of (a, c) which overlap with b's sparse indices let mut a_chunks: Vec<&[(F, usize)]> = vec![&[]; n]; let mut c_chunks: Vec<&[(F, usize)]> = vec![&[]; n]; let mut a_i = 0; let mut c_i = 0; - let span = tracing::span!(tracing::Level::DEBUG, "a, c scanning"); + let span = tracing::span!(tracing::Level::DEBUG, "a_c_chunking"); let _enter = span.enter(); for (chunk_index, range) in dense_ranges.iter().enumerate().skip(1) { // Find the corresponding a, c chunks @@ -378,17 +351,12 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { .zip(dense_ranges.iter()) { #[cfg(test)] - { - assert!(a_chunk - .iter() - .all(|(_, index)| *index >= range.0 && *index <= range.1)); - assert!(b_chunk - .iter() - .all(|(_, index)| *index >= range.0 && *index <= range.1)); - assert!(c_chunk - .iter() - .all(|(_, index)| *index >= range.0 && *index <= range.1)); + for chunk in &[a_chunk, b_chunk, c_chunk] { + for (_, index) in chunk.iter() { + assert!(*index >= range.0 && *index <= range.1); + } } + let iter = SparseTripleIterator { dense_index: range.0, end_index: range.1, @@ -672,7 +640,7 @@ mod tests { let mut rng = rand::thread_rng(); let prob_exists = 0.32; - let num_vars = 10; + let num_vars = 5; let total_len = 1 << num_vars; let mut a = vec![]; diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index ab5d919c9..a405fbfa3 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ 
b/jolt-core/src/subprotocols/sumcheck.rs @@ -298,20 +298,6 @@ impl SumcheckInstanceProof { poly_A.bound_poly_var_bot_par(&r_i); poly_B.bound_poly_var_bot_par(&r_i); poly_C.bound_poly_var_bot_par(&r_i); - // rayon::join( - // || poly_eq.bound_poly_var_bot(&r_i), - // || { - // rayon::join( - // || poly_A.bound_poly_var_bot_par(&r_i), - // || { - // rayon::join( - // || poly_B.bound_poly_var_bot_par(&r_i), - // || poly_C.bound_poly_var_bot_par(&r_i), - // ) - // }, - // ) - // }, - // ); } ( From 403b66a252c55df9a1f265d6213cf0e2c5eaadc5 Mon Sep 17 00:00:00 2001 From: sragss Date: Wed, 26 Jun 2024 11:42:59 -0700 Subject: [PATCH 08/17] warnings --- jolt-core/src/r1cs/spartan.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 7e015c3f5..0058f91e2 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -11,7 +11,6 @@ use crate::utils::thread::drop_in_background_thread; use crate::utils::transcript::ProofTranscript; use ark_serialize::CanonicalDeserialize; use ark_serialize::CanonicalSerialize; -use rayon::prelude::*; use thiserror::Error; From 5d51a234e0d72fdabfb19d5137f478a93871a39d Mon Sep 17 00:00:00 2001 From: sragss Date: Wed, 26 Jun 2024 11:47:30 -0700 Subject: [PATCH 09/17] clippy --- jolt-core/src/lib.rs | 2 +- jolt-core/src/poly/commitment/mock.rs | 4 +-- jolt-core/src/r1cs/builder.rs | 5 ++-- jolt-core/src/r1cs/jolt_constraints.rs | 2 +- jolt-core/src/r1cs/special_polys.rs | 29 +++++++++---------- .../src/subprotocols/grand_product_quarks.rs | 4 +-- 6 files changed, 21 insertions(+), 25 deletions(-) diff --git a/jolt-core/src/lib.rs b/jolt-core/src/lib.rs index f808ed946..e2dc21197 100644 --- a/jolt-core/src/lib.rs +++ b/jolt-core/src/lib.rs @@ -8,7 +8,7 @@ #![feature(generic_const_exprs)] #![feature(iter_next_chunk)] #![allow(long_running_const_eval)] - +#[allow(clippy::len_without_is_empty)] #[cfg(feature = "host")] pub mod benches; diff --git 
a/jolt-core/src/poly/commitment/mock.rs b/jolt-core/src/poly/commitment/mock.rs index e191ff454..ab283a53f 100644 --- a/jolt-core/src/poly/commitment/mock.rs +++ b/jolt-core/src/poly/commitment/mock.rs @@ -41,9 +41,7 @@ impl CommitmentScheme for MockCommitScheme { type Proof = MockProof; type BatchedProof = MockProof; - fn setup(_shapes: &[CommitShape]) -> Self::Setup { - () - } + fn setup(_shapes: &[CommitShape]) -> Self::Setup {} fn commit(poly: &DensePolynomial, _setup: &Self::Setup) -> Self::Commitment { MockCommitment { poly: poly.to_owned(), diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index e2cf70af9..f0297fbe9 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -46,7 +46,7 @@ struct Constraint { impl Constraint { #[cfg(test)] - fn is_sat(&self, inputs: &Vec) -> bool { + fn is_sat(&self, inputs: &[i64]) -> bool { // Find the number of variables and the number of aux. Inputs should be equal to this combined length let num_inputs = I::COUNT; @@ -772,6 +772,7 @@ impl CombinedUniformBuilder { /// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]] /// aux should be of the format [[Aux(0), Aux(0), ...], ... 
[Aux(self.next_aux - 1), ...]] #[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan_sparse")] + #[allow(clippy::type_complexity)] pub fn compute_spartan_Az_Bz_Cz( &self, inputs: &[Vec], @@ -967,7 +968,7 @@ mod tests { ) -> F { let multi_step_inputs: Vec> = single_step_inputs .iter() - .map(|input| vec![input.clone()]) + .map(|input| vec![*input]) .collect(); let multi_step_inputs_ref: Vec<&[F]> = multi_step_inputs.iter().map(|v| v.as_slice()).collect(); diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index 6b1ef1125..221d9a7c8 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -335,7 +335,7 @@ mod tests { let aux = combined_builder.compute_aux(&inputs); - // Implicitly asserts validity let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux); + combined_builder.assert_valid(&az, &bz, &cz); } } diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index 8a4361f04..47f291ce9 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -80,6 +80,7 @@ impl SparsePolynomial { /// Returns `n` chunks of roughly even size without separating siblings (adjacent dense indices). Additionally returns a vector of [low, high) dense index ranges. 
#[tracing::instrument(skip_all)] + #[allow(clippy::type_complexity)] fn chunk_no_split_siblings(&self, n: usize) -> (Vec<&[(F, usize)]>, Vec<(usize, usize)>) { if self.Z.len() < n * 2 { return (vec![(&self.Z)], vec![(0, self.num_vars.pow2())]); @@ -135,13 +136,11 @@ impl SparsePolynomial { } else { new_Z.push(((F::one() - r) * value, new_dense_index)); } + } else if sparse_index > 0 && self.Z[sparse_index - 1].1 == dense_index - 1 { + continue; } else { - if sparse_index > 0 && self.Z[sparse_index - 1].1 == dense_index - 1 { - continue; - } else { - let new_dense_index = (dense_index - 1) / 2; - new_Z.push((*r * value, new_dense_index)); - } + let new_dense_index = (dense_index - 1) / 2; + new_Z.push((*r * value, new_dense_index)); } } self.Z = new_Z; @@ -214,16 +213,14 @@ impl SparsePolynomial { (mul_0_1_optimized(&(F::one() - r), value), new_dense_index); write_index += 1; } + } else if sparse_index > 0 && chunk[sparse_index - 1].1 == dense_index - 1 { + // (low, high) present, but handeled prior + continue; } else { - if sparse_index > 0 && chunk[sparse_index - 1].1 == dense_index - 1 { - // (low, high) present, but handeled prior - continue; - } else { - // (_, high) present - let new_dense_index = (dense_index - 1) / 2; - mutable[write_index] = (mul_0_1_optimized(r, value), new_dense_index); - write_index += 1; - } + // (_, high) present + let new_dense_index = (dense_index - 1) / 2; + mutable[write_index] = (mul_0_1_optimized(r, value), new_dense_index); + write_index += 1; } } }); @@ -235,7 +232,7 @@ impl SparsePolynomial { pub fn final_eval(&self) -> F { assert_eq!(self.num_vars, 0); - if self.Z.len() == 0 { + if self.Z.is_empty() { F::zero() } else { assert_eq!(self.Z.len(), 1); diff --git a/jolt-core/src/subprotocols/grand_product_quarks.rs b/jolt-core/src/subprotocols/grand_product_quarks.rs index 2ceb7ac9e..e74a79f72 100644 --- a/jolt-core/src/subprotocols/grand_product_quarks.rs +++ b/jolt-core/src/subprotocols/grand_product_quarks.rs @@ -618,7 +618,7 
@@ mod quark_grand_product_tests { fn quark_e2e() { const LAYER_SIZE: usize = 1 << 8; - let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9 as u64); + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9_u64); let leaves_1: Vec = std::iter::repeat_with(|| Fr::random(&mut rng)) .take(LAYER_SIZE) @@ -647,7 +647,7 @@ mod quark_grand_product_tests { fn quark_hybrid_e2e() { const LAYER_SIZE: usize = 1 << 8; - let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9 as u64); + let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9_u64); let leaves_1: Vec = std::iter::repeat_with(|| Fr::random(&mut rng)) .take(LAYER_SIZE) From fbeba2e6f7742aefb77e85eb5502e2629b90015e Mon Sep 17 00:00:00 2001 From: sragss Date: Wed, 26 Jun 2024 11:50:04 -0700 Subject: [PATCH 10/17] more clippy --- jolt-core/src/lib.rs | 3 ++- jolt-core/src/r1cs/special_polys.rs | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/jolt-core/src/lib.rs b/jolt-core/src/lib.rs index e2dc21197..965f9f090 100644 --- a/jolt-core/src/lib.rs +++ b/jolt-core/src/lib.rs @@ -8,7 +8,8 @@ #![feature(generic_const_exprs)] #![feature(iter_next_chunk)] #![allow(long_running_const_eval)] -#[allow(clippy::len_without_is_empty)] +#![allow(clippy::len_without_is_empty)] + #[cfg(feature = "host")] pub mod benches; diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index 47f291ce9..a42b01dab 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -705,12 +705,14 @@ mod tests { } } - let mut a_poly = SparsePolynomial::new(num_vars, a); - + let a_poly = SparsePolynomial::new(num_vars, a); let r = Fr::from(100); - assert_eq!( - a_poly.clone().bound_poly_var_bot(&r), - a_poly.bound_poly_var_bot_par(&r) - ); + + let mut regular = a_poly.clone(); + regular.bound_poly_var_bot(&r); + + let mut par = a_poly.clone(); + par.bound_poly_var_bot_par(&r); + assert_eq!(regular, par); } } From 4f0293c44e56a2748ea9210f25c8039a9ebd2ec2 Mon
Sep 17 00:00:00 2001 From: sragss Date: Wed, 26 Jun 2024 12:09:46 -0700 Subject: [PATCH 11/17] rm eval_hint --- jolt-core/src/r1cs/builder.rs | 47 +++++------------------------------ 1 file changed, 6 insertions(+), 41 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index f0297fbe9..e69cdecc8 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -27,21 +27,12 @@ pub trait R1CSConstraintBuilder { fn build_constraints(&self, builder: &mut R1CSBuilder); } -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum EvalHint { - Zero = 0, - Other = 1, -} - /// Constraints over a single row. Each variable points to a single item in Z and the corresponding coefficient. #[derive(Clone, Debug)] struct Constraint { a: LC, b: LC, c: LC, - - /// Shortcut for evaluation of a, b, c for an honest prover - eval_hint: (EvalHint, EvalHint, EvalHint), } impl Constraint { @@ -239,7 +230,6 @@ impl R1CSBuilder { a, b, c: LC::zero(), - eval_hint: (EvalHint::Zero, EvalHint::Other, EvalHint::Zero), }; self.constraints.push(constraint); } @@ -258,12 +248,7 @@ impl R1CSBuilder { let a = condition; let b = left - right; let c = LC::zero(); - let constraint = Constraint { - a, - b, - c, - eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Zero), - }; // TODO(sragss): Can do better on middle term. + let constraint = Constraint { a, b, c }; // TODO(sragss): Can do better on middle term. self.constraints.push(constraint); } @@ -276,7 +261,6 @@ impl R1CSBuilder { a, b, c: LC::zero(), - eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Zero), }; self.constraints.push(constraint); } @@ -300,7 +284,6 @@ impl R1CSBuilder { a: condition.clone(), b: (result_true - result_false.clone()), c: (alleged_result - result_false), - eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Other), // TODO(sragss): Is this the best we can do? 
}; self.constraints.push(constraint); } @@ -446,7 +429,6 @@ impl R1CSBuilder { a: x.into(), b: y.into(), c: z.into(), - eval_hint: (EvalHint::Other, EvalHint::Other, EvalHint::Other), }; self.constraints.push(constraint); } @@ -805,8 +787,8 @@ impl CombinedUniformBuilder { .map(|(constraint_index, constraint)| { let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); - let mut evaluate_lc_chunk = |hint, lc: &LC| { - if hint != EvalHint::Zero { + let mut evaluate_lc_chunk = |lc: &LC| { + if lc.terms().len() != 0 { let inputs = batch_inputs(lc); lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); @@ -827,12 +809,9 @@ impl CombinedUniformBuilder { } }; - let a_chunk: Vec<(F, usize)> = - evaluate_lc_chunk(constraint.eval_hint.0, &constraint.a); - let b_chunk: Vec<(F, usize)> = - evaluate_lc_chunk(constraint.eval_hint.1, &constraint.b); - let c_chunk: Vec<(F, usize)> = - evaluate_lc_chunk(constraint.eval_hint.2, &constraint.c); + let a_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.a); + let b_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.b); + let c_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.c); (a_chunk, b_chunk, c_chunk) }) @@ -937,20 +916,6 @@ impl CombinedUniformBuilder { self.uniform_builder.constraints[uniform_constraint_index] ); } - // Verify hints - if constraint_index < self.uniform_repeat_constraint_rows() { - let (hint_a, hint_b, hint_c) = - self.uniform_builder.constraints[uniform_constraint_index].eval_hint; - if hint_a == EvalHint::Zero { - assert_eq!(az[constraint_index], F::zero(), "Mismatch at global constraint {constraint_index} uniform constraint: {uniform_constraint_index}"); - } - if hint_b == EvalHint::Zero { - assert_eq!(bz[constraint_index], F::zero(), "Mismatch at global constraint {constraint_index} uniform constraint: {uniform_constraint_index}"); - } - if hint_c == EvalHint::Zero { - assert_eq!(cz[constraint_index], F::zero(), "Mismatch at global constraint {constraint_index} uniform 
constraint: {uniform_constraint_index}"); - } - } } } } From 55592e1eca74ef1ed06a5b40e7f867576ba9c4b2 Mon Sep 17 00:00:00 2001 From: sragss Date: Wed, 26 Jun 2024 12:15:12 -0700 Subject: [PATCH 12/17] clippy fix --- jolt-core/src/r1cs/builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index e69cdecc8..440e7ab66 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -788,7 +788,7 @@ impl CombinedUniformBuilder { let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); let mut evaluate_lc_chunk = |lc: &LC| { - if lc.terms().len() != 0 { + if !lc.terms().is_empty() { let inputs = batch_inputs(lc); lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); From 77fe6a184b7b781cd83a2fcc693d95a146b9e19b Mon Sep 17 00:00:00 2001 From: sragss Date: Tue, 16 Jul 2024 21:40:28 -0700 Subject: [PATCH 13/17] address comments round 1 --- jolt-core/src/poly/dense_mlpoly.rs | 10 ++++ jolt-core/src/r1cs/builder.rs | 66 ++++++++++++------------ jolt-core/src/r1cs/spartan.rs | 12 ++--- jolt-core/src/r1cs/special_polys.rs | 70 ++++++++++---------------- jolt-core/src/subprotocols/sumcheck.rs | 2 +- 5 files changed, 76 insertions(+), 84 deletions(-) diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs index 6bfd3aab2..34b9932c5 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -203,6 +203,16 @@ impl DensePolynomial { #[tracing::instrument(skip_all)] pub fn bound_poly_var_bot(&mut self, r: &F) { + let n = self.len() / 2; + for i in 0..n { + self.Z[i] = self.Z[2 * i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]); + } + + self.num_vars -= 1; + self.len = n; + } + + pub fn bound_poly_var_bot_01_optimized(&mut self, r: &F) { let n = self.len() / 2; let mut new_z = unsafe_allocate_zero_vec(n); new_z.par_iter_mut().enumerate().for_each(|(i, z)| { diff --git a/jolt-core/src/r1cs/builder.rs 
b/jolt-core/src/r1cs/builder.rs index 440e7ab66..c6e69c62e 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -779,43 +779,41 @@ impl CombinedUniformBuilder { // uniform_constraints: Xz[0..uniform_constraint_rows] let span = tracing::span!(tracing::Level::DEBUG, "uniform_evals"); let _enter = span.enter(); - let uni_constraint_evals: Vec<(Vec<(F, usize)>, Vec<(F, usize)>, Vec<(F, usize)>)> = - self.uniform_builder - .constraints - .par_iter() - .enumerate() - .map(|(constraint_index, constraint)| { - let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); - - let mut evaluate_lc_chunk = |lc: &LC| { - if !lc.terms().is_empty() { - let inputs = batch_inputs(lc); - lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); - - // Take only the non-zero elements and represent them as sparse tuples (eval, dense_index) - let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot - dense_output_buffer.iter().enumerate().for_each( - |(local_index, item)| { - if !item.is_zero() { - let global_index = - constraint_index * self.uniform_repeat + local_index; - sparse.push((*item, global_index)); - } - }, - ); - sparse - } else { - vec![] + let uni_constraint_evals: Vec<(Vec<(F, usize)>, Vec<(F, usize)>, Vec<(F, usize)>)> = self + .uniform_builder + .constraints + .par_iter() + .enumerate() + .map(|(constraint_index, constraint)| { + let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat); + + let mut evaluate_lc_chunk = |lc: &LC| { + if !lc.terms().is_empty() { + let inputs = batch_inputs(lc); + lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer); + + // Take only the non-zero elements and represent them as sparse tuples (eval, dense_index) + let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot + for (local_index, item) in dense_output_buffer.iter().enumerate() { + if !item.is_zero() { + let global_index = + constraint_index * self.uniform_repeat + local_index; + 
sparse.push((*item, global_index)); + } } - }; + sparse + } else { + vec![] + } + }; - let a_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.a); - let b_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.b); - let c_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.c); + let a_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.a); + let b_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.b); + let c_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.c); - (a_chunk, b_chunk, c_chunk) - }) - .collect(); + (a_chunk, b_chunk, c_chunk) + }) + .collect(); let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple( uni_constraint_evals, diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 0058f91e2..4e603a574 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -112,20 +112,20 @@ impl> UniformSpartanProof { let aux = &segmented_padded_witness.segments[I::COUNT..]; let (mut az, mut bz, mut cz) = constraint_builder.compute_spartan_Az_Bz_Cz(inputs, aux); - let comb_func_outer = |A: &F, B: &F, C: &F, D: &F| -> F { + let comb_func_outer = |eq: &F, az: &F, bz: &F, cz: &F| -> F { // Below is an optimized form of: *A * (*B * *C - *D) - if B.is_zero() || C.is_zero() { - if D.is_zero() { + if az.is_zero() || bz.is_zero() { + if cz.is_zero() { F::zero() } else { - *A * (-(*D)) + *eq * (-(*cz)) } } else { - let inner = *B * *C - *D; + let inner = *az * *bz - *cz; if inner.is_zero() { F::zero() } else { - *A * inner + *eq * inner } } }; diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index a42b01dab..9dfb637e3 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -23,33 +23,6 @@ impl SparsePolynomial { SparsePolynomial { num_vars, Z } } - // TODO(sragss): rm - #[tracing::instrument(skip_all)] - pub fn from_dense_evals(num_vars: usize, evals: Vec) -> Self { - assert!(num_vars.pow2() >= evals.len()); - let 
non_zero_count: usize = evals - .par_chunks(10_000) - .map(|chunk| chunk.iter().filter(|f| !f.is_zero()).count()) - .sum(); - - let span_allocate = tracing::span!(tracing::Level::DEBUG, "allocate"); - let _enter_allocate = span_allocate.enter(); - let mut sparse: Vec<(F, usize)> = unsafe_allocate_sparse_zero_vec(non_zero_count); - drop(_enter_allocate); - - let span_copy = tracing::span!(tracing::Level::DEBUG, "copy"); - let _enter_copy = span_copy.enter(); - let mut sparse_index = 0; - for (dense_index, dense) in evals.iter().enumerate() { - if !dense.is_zero() { - sparse[sparse_index] = (*dense, dense_index); - sparse_index += 1; - } - } - drop(_enter_copy); - Self::new(num_vars, sparse) - } - /// Computes the $\tilde{eq}$ extension polynomial. /// return 1 when a == r, otherwise return 0. fn compute_chi(a: &[bool], r: &[F]) -> F { @@ -288,7 +261,14 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { // B is assumed most dense. Parallelism depends on evenly distributing B across threads. assert!(b.Z.len() >= a.Z.len() && b.Z.len() >= c.Z.len()); - // TODO(sragss): Explain the strategy + // We'd like to scan over 3 SparsePolynomials (a,b,c) in `n` chunks for parallelism. + // With dense polynomials we could split directly by index, with SparsePolynomials we don't + // know the distribution of indices in the polynomials in advance. + // Further, the dense indices do not match: a[i].dense_index != b[i].dense_index != c[i].dense_index + // We expect b.len() >> max(a.len(), c.len()), so we'll split b first and use as a guide for (a,c). + // We'll split it into `n` chunks of roughly even length, but we will not split "sibling" dense indices across + // chunks as the presence of the pair is relevant to downstream algos. + // Dense siblings: (0,1), (2,3), ... 
let (b_chunks, mut dense_ranges) = b.chunk_no_split_siblings(n); let highest_non_zero = [&a.Z, &b.Z, &c.Z] @@ -303,35 +283,38 @@ impl<'a, F: JoltField> SparseTripleIterator<'a, F> { // Create chunks of (a, c) which overlap with b's sparse indices let mut a_chunks: Vec<&[(F, usize)]> = vec![&[]; n]; let mut c_chunks: Vec<&[(F, usize)]> = vec![&[]; n]; - let mut a_i = 0; - let mut c_i = 0; + let mut a_sparse_i = 0; + let mut c_sparse_i = 0; let span = tracing::span!(tracing::Level::DEBUG, "a_c_chunking"); let _enter = span.enter(); + // Using b's dense_ranges as a guide, fill out (a_chunks, c_chunks) for (chunk_index, range) in dense_ranges.iter().enumerate().skip(1) { // Find the corresponding a, c chunks - let prev_chunk_end = range.0; + let dense_range_end = range.0; - if a_i < a.Z.len() && a.Z[a_i].1 < prev_chunk_end { - let a_start = a_i; - while a_i < a.Z.len() && a.Z[a_i].1 < prev_chunk_end { - a_i += 1; + if a_sparse_i < a.Z.len() && a.Z[a_sparse_i].1 < dense_range_end { + let a_start = a_sparse_i; + // Scan over a until the corresponding dense index is out of range + while a_sparse_i < a.Z.len() && a.Z[a_sparse_i].1 < dense_range_end { + a_sparse_i += 1; } - a_chunks[chunk_index - 1] = &a.Z[a_start..a_i]; + a_chunks[chunk_index - 1] = &a.Z[a_start..a_sparse_i]; } - if c_i < c.Z.len() && c.Z[c_i].1 < prev_chunk_end { - let c_start = c_i; - while c_i < c.Z.len() && c.Z[c_i].1 < prev_chunk_end { - c_i += 1; + if c_sparse_i < c.Z.len() && c.Z[c_sparse_i].1 < dense_range_end { + let c_start = c_sparse_i; + // Scan over c until the corresponding dense index is out of range + while c_sparse_i < c.Z.len() && c.Z[c_sparse_i].1 < dense_range_end { + c_sparse_i += 1; } - c_chunks[chunk_index - 1] = &c.Z[c_start..c_i]; + c_chunks[chunk_index - 1] = &c.Z[c_start..c_sparse_i]; } } drop(_enter); - a_chunks[n - 1] = &a.Z[a_i..]; - c_chunks[n - 1] = &c.Z[c_i..]; + a_chunks[n - 1] = &a.Z[a_sparse_i..]; + c_chunks[n - 1] = &c.Z[c_sparse_i..]; #[cfg(test)] { @@ -340,6 +323,7 @@ 
impl<'a, F: JoltField> SparseTripleIterator<'a, F> { assert_eq!(c_chunks.concat(), c.Z); } + // Assemble the triple iterator objects let mut iterators: Vec> = Vec::with_capacity(n); for (((a_chunk, b_chunk), c_chunk), range) in a_chunks .iter() diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index a405fbfa3..151c73334 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -294,7 +294,7 @@ impl SumcheckInstanceProof { claim_per_round = poly.evaluate(&r_i); // bound all tables to the verifier's challenege - poly_eq.bound_poly_var_bot(&r_i); + poly_eq.bound_poly_var_bot_01_optimized(&r_i); poly_A.bound_poly_var_bot_par(&r_i); poly_B.bound_poly_var_bot_par(&r_i); poly_C.bound_poly_var_bot_par(&r_i); From 193dcf04dc65fac738ec2029e5423cb16eca0d2d Mon Sep 17 00:00:00 2001 From: sragss Date: Tue, 16 Jul 2024 23:06:59 -0700 Subject: [PATCH 14/17] clean up sparse non-uniform calc --- jolt-core/src/r1cs/builder.rs | 56 ++++++++++++++++------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index c6e69c62e..b823f0b88 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -841,41 +841,35 @@ impl CombinedUniformBuilder { .1 .evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat); - let dense_az_bz: Vec<(F, F)> = (0..self.uniform_repeat) - .into_par_iter() - .map(|step_index| { - // Write corresponding values, if outside the step range, only include the constant. 
- let a_step = step_index + if constr.a.0 { 1 } else { 0 }; - let b_step = step_index + if constr.b.0 { 1 } else { 0 }; - let a = eq_a_evals - .get(a_step) - .cloned() - .unwrap_or(constr.a.1.constant_term_field()); - let b = eq_b_evals - .get(b_step) - .cloned() - .unwrap_or(constr.b.1.constant_term_field()); - let az = a - b; - - let condition_step = step_index + if constr.cond.0 { 1 } else { 0 }; - let bz = condition_evals - .get(condition_step) - .cloned() - .unwrap_or(constr.cond.1.constant_term_field()); - (az, bz) - }) - .collect(); - - // Sparsify: take only the non-zero elements - for (local_index, (az, bz)) in dense_az_bz.iter().enumerate() { - let global_index = uniform_constraint_rows + local_index; + (0..self.uniform_repeat).into_iter().for_each(|step_index| { + // Write corresponding values, if outside the step range, only include the constant. + let a_step = step_index + constr.a.0 as usize; + let b_step = step_index + constr.b.0 as usize; + let a = eq_a_evals + .get(a_step) + .cloned() + .unwrap_or(constr.a.1.constant_term_field()); + let b = eq_b_evals + .get(b_step) + .cloned() + .unwrap_or(constr.b.1.constant_term_field()); + let az = a - b; + + let global_index = uniform_constraint_rows + step_index; if !az.is_zero() { - az_sparse.push((*az, global_index)); + az_sparse.push((az, global_index)); } + + let condition_step = step_index + constr.cond.0 as usize; + let bz = condition_evals + .get(condition_step) + .cloned() + .unwrap_or(constr.cond.1.constant_term_field()); if !bz.is_zero() { - bz_sparse.push((*bz, global_index)); + bz_sparse.push((bz, global_index)); } - } + }); + drop(_enter); let num_vars = self.constraint_rows().next_power_of_two().log_2(); let az_poly = SparsePolynomial::new(num_vars, az_sparse); From be5e655def748df82d7f4e84b68d44f8f64c4664 Mon Sep 17 00:00:00 2001 From: sragss Date: Tue, 16 Jul 2024 23:36:33 -0700 Subject: [PATCH 15/17] fix tests --- jolt-core/src/poly/dense_mlpoly.rs | 1 + jolt-core/src/r1cs/special_polys.rs | 
6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs index 34b9932c5..f28d4ff2a 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -201,6 +201,7 @@ impl DensePolynomial { } } + /// Note: does not truncate #[tracing::instrument(skip_all)] pub fn bound_poly_var_bot(&mut self, r: &F) { let n = self.len() / 2; diff --git a/jolt-core/src/r1cs/special_polys.rs b/jolt-core/src/r1cs/special_polys.rs index 9dfb637e3..e643d728f 100644 --- a/jolt-core/src/r1cs/special_polys.rs +++ b/jolt-core/src/r1cs/special_polys.rs @@ -532,7 +532,7 @@ mod tests { let r = Fr::from(121); sparse.bound_poly_var_bot(&r); - dense.bound_poly_var_bot(&r); + dense.bound_poly_var_bot_01_optimized(&r); assert_eq!(sparse.to_dense(), dense); } @@ -548,7 +548,7 @@ mod tests { let r = Fr::from(121); sparse.bound_poly_var_bot(&r); - dense.bound_poly_var_bot(&r); + dense.bound_poly_var_bot_01_optimized(&r); assert_eq!(sparse.to_dense(), dense); } @@ -579,7 +579,7 @@ mod tests { let r = Fr::from(121); sparse.bound_poly_var_bot(&r); - dense.bound_poly_var_bot(&r); + dense.bound_poly_var_bot_01_optimized(&r); assert_eq!(sparse.to_dense(), dense); } From cc5598c4dabf0be9ee359044453861fa9f638d4a Mon Sep 17 00:00:00 2001 From: sragss Date: Tue, 16 Jul 2024 23:40:40 -0700 Subject: [PATCH 16/17] fix clippy --- jolt-core/src/r1cs/builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index b823f0b88..c09aa5dea 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -841,7 +841,7 @@ impl CombinedUniformBuilder { .1 .evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat); - (0..self.uniform_repeat).into_iter().for_each(|step_index| { + (0..self.uniform_repeat).for_each(|step_index| { // Write corresponding values, if outside the step range, only include the constant. 
let a_step = step_index + constr.a.0 as usize; let b_step = step_index + constr.b.0 as usize; From 829aa49fc66a7bc6322739a7529f14da416f1092 Mon Sep 17 00:00:00 2001 From: sragss Date: Thu, 18 Jul 2024 17:09:16 -0700 Subject: [PATCH 17/17] rename sumcheck A,B,C,D -> eq,A,B,C --- jolt-core/src/subprotocols/sumcheck.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index 151c73334..b3e2f011f 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -215,27 +215,27 @@ impl SumcheckInstanceProof { let m_C = c_high - c_low; // eval 2 - let poly_A_bound_point = poly_eq[dense_index + 1] + m_eq; - let poly_B_bound_point = a_high + m_A; - let poly_C_bound_point = b_high + m_B; - let poly_D_bound_point = c_high + m_C; + let poly_eq_bound_point = poly_eq[dense_index + 1] + m_eq; + let poly_A_bound_point = a_high + m_A; + let poly_B_bound_point = b_high + m_B; + let poly_C_bound_point = c_high + m_C; eval_point_2 += comb_func( + &poly_eq_bound_point, &poly_A_bound_point, &poly_B_bound_point, &poly_C_bound_point, - &poly_D_bound_point, ); // eval 3 - let poly_A_bound_point = poly_A_bound_point + m_eq; - let poly_B_bound_point = poly_B_bound_point + m_A; - let poly_C_bound_point = poly_C_bound_point + m_B; - let poly_D_bound_point = poly_D_bound_point + m_C; + let poly_eq_bound_point = poly_eq_bound_point + m_eq; + let poly_A_bound_point = poly_A_bound_point + m_A; + let poly_B_bound_point = poly_B_bound_point + m_B; + let poly_C_bound_point = poly_C_bound_point + m_C; eval_point_3 += comb_func( + &poly_eq_bound_point, &poly_A_bound_point, &poly_B_bound_point, &poly_C_bound_point, - &poly_D_bound_point, ); } (eval_point_0, eval_point_2, eval_point_3)