From b3385e8b33ac4d4d36d08b03f2fa8c062ff759c4 Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Sat, 2 Nov 2024 21:14:04 -0400 Subject: [PATCH 1/8] add sanity check to sumcheck --- jolt-core/src/subprotocols/sumcheck.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index 008f892be..95553caa4 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -105,6 +105,17 @@ impl SumcheckInstanceProof = Vec::new(); let mut compressed_polys: Vec> = Vec::new(); + #[cfg(test)] + { + let total_evals = 1 << num_rounds; + let mut sum = F::zero(); + for i in 0..total_evals { + let params: Vec = polys.iter().map(|poly| poly[i]).collect(); + sum += comb_func(¶ms); + } + assert_eq!(&sum, _claim, "Sumcheck claim is wrong"); + } + for _round in 0..num_rounds { // Vector storing evaluations of combined polynomials g(x) = P_0(x) * ... P_{num_polys} (x) // for points {0, ..., |g(x)|} From c4d0bbfccce0af3b87fa7944dca29b0304511089 Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Sun, 3 Nov 2024 21:19:35 -0500 Subject: [PATCH 2/8] working crush optimization with only uniform constraints --- jolt-core/src/poly/dense_mlpoly.rs | 2 +- jolt-core/src/r1cs/key.rs | 45 +++++++++++----- jolt-core/src/r1cs/spartan.rs | 83 ++++++++++++++++++++++-------- 3 files changed, 94 insertions(+), 36 deletions(-) diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs index 550840005..aec80cd73 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -240,7 +240,7 @@ impl DensePolynomial { } pub fn evals(&self) -> Vec { - self.Z.clone() + self.Z[..self.len].to_owned() } pub fn evals_ref(&self) -> &[F] { diff --git a/jolt-core/src/r1cs/key.rs b/jolt-core/src/r1cs/key.rs index 4c777d3d9..afd5bb584 100644 --- a/jolt-core/src/r1cs/key.rs +++ b/jolt-core/src/r1cs/key.rs @@ -171,6 +171,10 @@ impl UniformSpartanKey usize { + self.uniform_r1cs.num_vars + } + /// Evaluates A(r_x, y) + r_rlc * B(r_x, y) + r_rlc^2 * C(r_x, y) where r_x = r_constr || r_step for all y. #[tracing::instrument(skip_all, name = "UniformSpartanKey::evaluate_r1cs_mle_rlc")] pub fn evaluate_r1cs_mle_rlc(&self, r_constr: &[F], r_step: &[F], r_rlc: F) -> Vec { @@ -183,7 +187,7 @@ impl UniformSpartanKey UniformSpartanKey, non_uni_constants: Option>| -> Vec { - // +1 for constant - let mut evals = unsafe_allocate_zero_vec(self.uniform_r1cs.num_vars + 1); + // evals: [inputs, aux ... 1, ...] where ... indicates padding to next power of 2 + let mut evals = + unsafe_allocate_zero_vec(self.uniform_r1cs.num_vars.next_power_of_two() * 2); for (row, col, val) in constraints.vars.iter() { evals[*col] += mul_0_1_optimized(val, &eq_rx_constr[*row]); } @@ -218,13 +223,14 @@ impl UniformSpartanKey>(); + /* Crush: not needed let mut rlc = unsafe_allocate_zero_vec(self.num_cols_total()); { @@ -243,6 +249,7 @@ impl UniformSpartanKey, @@ -271,12 +278,12 @@ impl UniformSpartanKey UniformSpartanKey UniformSpartanKey| -> F { let mut full_mle_evaluation: F = constraints @@ -334,7 +351,9 @@ impl UniformSpartanKey() - * eq_rx_ry_step; + ; + // Crush: + // * eq_rx_ry_step; full_mle_evaluation += constraints .consts diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 1be9ee2ed..38eebeabb 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -155,6 +155,30 @@ where + r_inner_sumcheck_RLC * claim_Bz + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * claim_Cz; + // Crush: + let num_steps_bits_ = constraint_builder + .uniform_repeat() + .next_power_of_two() + .ilog2() as usize; + let num_constraints_bits = key.num_cons_total.log_2() - num_steps_bits_; + let r_x_step = &outer_sumcheck_r[num_constraints_bits..]; + + let mut z: Vec = flattened_polys.clone().into_iter().map(|poly| { + let mut resized = poly.Z.clone(); + resized.resize(poly.len().next_power_of_two(), F::zero()); + resized + }).flatten().collect(); + z.resize(z.len().next_power_of_two(), F::zero()); + + let mut poly_z = DensePolynomial::new(z.clone()); + for r_s in r_x_step.iter().rev() { + poly_z.bound_poly_var_bot(r_s); + } + let mut evals = poly_z.evals(); + evals.push(F::one()); + evals.resize(evals.len().next_power_of_two(), F::zero()); + poly_z = DensePolynomial::new(evals); + // this is the polynomial extended from the vector r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y let num_steps_bits = constraint_builder .uniform_repeat() @@ -162,24 +186,32 @@ where .ilog2(); let (rx_con, rx_ts) = outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize); - let mut poly_ABC = + let poly_ABC = DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC)); - + assert_eq!(poly_z.len(), poly_ABC.len()); + + // Crush: second sumcheck call + let num_rounds = (key.num_vars_uniform() * 2).next_power_of_two().log_2(); + let mut polys = vec![poly_ABC, poly_z]; + let comb_func = |poly_evals: &[F]| -> F { + assert_eq!(poly_evals.len(), 2); + poly_evals[0] * poly_evals[1] + }; let (inner_sumcheck_proof, inner_sumcheck_r, _claims_inner) = - SumcheckInstanceProof::prove_spartan_quadratic( - &claim_inner_joint, // r_A * v_A + r_B * v_B + r_C * v_C - num_rounds_y, - &mut poly_ABC, // r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y - &flattened_polys, - transcript, - ); - drop_in_background_thread(poly_ABC); - - // Requires 'r_col_segment_bits' to index the (const, segment). Within that segment we index the step using 'r_col_step' - let r_col_segment_bits = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1; - let r_col_step = &inner_sumcheck_r[r_col_segment_bits..]; - - let chi = EqPolynomial::evals(r_col_step); + SumcheckInstanceProof::prove_arbitrary( + &claim_inner_joint, + num_rounds, + &mut polys, + comb_func, + 2, + transcript); + + drop_in_background_thread(polys); + + // Crush: + let r_z = r_x_step; + + let chi = EqPolynomial::evals(r_z); let claimed_witness_evals: Vec<_> = flattened_polys .par_iter() .map(|poly| poly.evaluate_at_chi_low_optimized(&chi)) @@ -188,7 +220,7 @@ where opening_accumulator.append( &flattened_polys, DensePolynomial::new(chi), - r_col_step.to_vec(), + r_z.to_vec(), &claimed_witness_evals.iter().collect::>(), transcript, ); @@ -260,18 +292,23 @@ where + r_inner_sumcheck_RLC * self.outer_sumcheck_claims.1 + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * self.outer_sumcheck_claims.2; + let num_rounds = (key.num_vars_uniform() * 2).next_power_of_two().log_2(); let (claim_inner_final, inner_sumcheck_r) = self .inner_sumcheck_proof - .verify(claim_inner_joint, num_rounds_y, 2, transcript) + .verify(claim_inner_joint, num_rounds, 2, transcript) .map_err(|_| SpartanError::InvalidInnerSumcheckProof)?; // n_prefix = n_segments + 1 let n_prefix = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1; - let eval_Z = key.evaluate_z_mle(&self.claimed_witness_evals, &inner_sumcheck_r); + // Crush: + let n_constraint_bits_uniform = key.uniform_r1cs.num_rows.next_power_of_two().log_2(); + let outer_sumcheck_r_step = &r_x[n_constraint_bits_uniform..]; + let y_prime = [inner_sumcheck_r.to_owned(), outer_sumcheck_r_step.to_owned()].concat(); + let eval_Z = key.evaluate_z_mle(&self.claimed_witness_evals, &y_prime); - let r_y = inner_sumcheck_r.clone(); - let r = [r_x, r_y].concat(); + // Crush: + let r = [r_x.clone(), y_prime].concat(); let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r); let left_expected = eval_a @@ -279,6 +316,7 @@ where + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * eval_c; let right_expected = eval_Z; let claim_inner_final_expected = left_expected * right_expected; + if claim_inner_final != claim_inner_final_expected { return Err(SpartanError::InvalidInnerSumcheckClaim); } @@ -287,7 +325,8 @@ where .iter() .map(|var| var.get_ref(commitments)) .collect(); - let r_y_point = &inner_sumcheck_r[n_prefix..]; + // Crush: + let r_y_point = &r_x[n_constraint_bits_uniform..]; opening_accumulator.append( &flattened_commitments, r_y_point.to_vec(), From 1a32c3b8553f6c35e5ea68b62a977308f0e586c4 Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Tue, 19 Nov 2024 22:28:07 -0500 Subject: [PATCH 3/8] working prover second sumcheck --- jolt-core/src/r1cs/builder.rs | 14 ++-- jolt-core/src/r1cs/key.rs | 89 ++++++++++---------- jolt-core/src/r1cs/spartan.rs | 152 +++++++++++++++++++++++++++++++--- 3 files changed, 191 insertions(+), 64 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 85be551c5..08a934f3a 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -671,7 +671,9 @@ impl CombinedUniformBuilder CombinedUniformBuilder CombinedUniformBuilder UniformSpartanKey UniformSpartanKey, non_uni_constants: Option>| -> Vec { // evals: [inputs, aux ... 1, ...] where ... indicates padding to next power of 2 let mut evals = - unsafe_allocate_zero_vec(self.uniform_r1cs.num_vars.next_power_of_two() * 2); + unsafe_allocate_zero_vec(self.uniform_r1cs.num_vars.next_power_of_two() * 4); // *4 to accommodate cross-step constraints for (row, col, val) in constraints.vars.iter() { evals[*col] += mul_0_1_optimized(val, &eq_rx_constr[*row]); } @@ -209,6 +209,8 @@ impl UniformSpartanKey UniformSpartanKey UniformSpartanKey = [r_row_constr, r_row_step].concat(); -// for i in 0..key.num_cols_total() { -// let col_coordinate = index_to_field_bitvector(i, col_coordinate_len); - -// let coordinate: Vec = [row_coordinate.clone(), col_coordinate].concat(); -// let expected_rlc = a.evaluate(&coordinate) -// + r_rlc * b.evaluate(&coordinate) -// + r_rlc * r_rlc * c.evaluate(&coordinate); - -// assert_eq!(expected_rlc, rlc[i], "Failed at {i}"); -// } -// } + // #[test] + // fn evaluate_r1cs_mle_rlc() { + // let (_builder, key) = simp_test_builder_key(); + // let (a, b, c) = simp_test_big_matrices(); + // let a = DensePolynomial::new(a); + // let b = DensePolynomial::new(b); + // let c = DensePolynomial::new(c); + + // let r_row_constr_len = (key.uniform_r1cs.num_rows + 1).next_power_of_two().log_2(); + // let r_col_step_len = key.num_steps.log_2(); + + // let r_row_constr = vec![Fr::from(100), Fr::from(200)]; + // let r_row_step = vec![Fr::from(100), Fr::from(200)]; + // assert_eq!(r_row_constr.len(), r_row_constr_len); + // assert_eq!(r_row_step.len(), r_col_step_len); + // let r_rlc = Fr::from(1000); + + // let rlc = key.evaluate_r1cs_mle_rlc(&r_row_constr, &r_row_step, r_rlc); + + // // let row_coordinate_len = key.num_rows_total().log_2(); + // let col_coordinate_len = key.num_cols_total().log_2(); + // let row_coordinate: Vec = [r_row_constr, r_row_step].concat(); + // for i in 0..key.num_cols_total() { + // let col_coordinate = index_to_field_bitvector(i, col_coordinate_len); + + // let coordinate: Vec = [row_coordinate.clone(), col_coordinate].concat(); + // let expected_rlc = a.evaluate(&coordinate) + // + r_rlc * b.evaluate(&coordinate) + // + r_rlc * r_rlc * c.evaluate(&coordinate); + + // assert_eq!(expected_rlc, rlc[i], "Failed at {i}"); + // } + // } // #[test] // fn r1cs_matrix_mles_offset_constraints() { diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 38eebeabb..63d83f797 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -1,5 +1,6 @@ #![allow(clippy::len_without_is_empty)] +// use core::range; use std::marker::PhantomData; use crate::field::JoltField; @@ -23,6 +24,7 @@ use thiserror::Error; use crate::{ poly::{dense_mlpoly::DensePolynomial, eq_poly::EqPolynomial}, subprotocols::sumcheck::SumcheckInstanceProof, + r1cs::special_polys::eq_plus_one, }; use super::builder::CombinedUniformBuilder; @@ -156,6 +158,7 @@ where + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * claim_Cz; // Crush: + let num_steps_padded = constraint_builder.uniform_repeat().next_power_of_two(); let num_steps_bits_ = constraint_builder .uniform_repeat() .next_power_of_two() @@ -170,14 +173,46 @@ where }).flatten().collect(); z.resize(z.len().next_power_of_two(), F::zero()); - let mut poly_z = DensePolynomial::new(z.clone()); - for r_s in r_x_step.iter().rev() { - poly_z.bound_poly_var_bot(r_s); - } - let mut evals = poly_z.evals(); - evals.push(F::one()); + let is_last_step = EqPolynomial::new(r_x_step.to_vec()).evaluate(&vec![F::one(); r_x_step.len()]); + let eq_rx_step = EqPolynomial::evals(r_x_step); + + // // Evals with binding. Doesn't ignore when r_x_step is the last step. + // for r_s in r_x_step.iter().rev() { + // poly_z.bound_poly_var_bot(r_s); + // } + // let mut evals = poly_z.evals(); + // evals.push(F::one() - is_last_step); // ARASU: IGNORE LAST STEP? + // evals.resize(evals.len().next_power_of_two() * 1, F::zero()); + + // Evals straightfoward + let mut evals: Vec = (0..key.num_vars_uniform()) // until the constant (which is not included) + .map(|y_var| { + (0..(num_steps_padded-1)) // Ignore the last step + .map(|t| z[y_var * num_steps_padded + t] * eq_rx_step[t]) + .sum() + }) + .collect(); + evals.resize(evals.len().next_power_of_two(), F::zero()); + evals.push(F::one() - is_last_step); // Constant, ignores the last step. evals.resize(evals.len().next_power_of_two(), F::zero()); - poly_z = DensePolynomial::new(evals); + + let n_bits_ts = r_x_step.len(); + let eq_plus_one_rx_step: Vec = (0..num_steps_padded) + .map(|t| eq_plus_one(r_x_step, &crate::utils::index_to_field_bitvector(t, n_bits_ts), n_bits_ts)) + .collect(); + + let mut evals_shifted = (0..key.num_vars_uniform()) + .map(|y_var: usize| { + (0..num_steps_padded-1) + .map(|t| + z[y_var * num_steps_padded + t] * eq_plus_one_rx_step[t] + ) + .sum::() + }) + .collect::>(); + evals_shifted.resize(evals.len(), F::zero()); + + let poly_z = DensePolynomial::new(evals.into_iter().chain(evals_shifted.into_iter()).collect()); // this is the polynomial extended from the vector r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y let num_steps_bits = constraint_builder @@ -189,9 +224,9 @@ where let poly_ABC = DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC)); assert_eq!(poly_z.len(), poly_ABC.len()); + assert_eq!(poly_ABC.len(), key.num_vars_uniform().next_power_of_two() * 4); // *4 to support cross_step constraints - // Crush: second sumcheck call - let num_rounds = (key.num_vars_uniform() * 2).next_power_of_two().log_2(); + let num_rounds = poly_ABC.len().log_2(); let mut polys = vec![poly_ABC, poly_z]; let comb_func = |poly_evals: &[F]| -> F { assert_eq!(poly_evals.len(), 2); @@ -211,7 +246,7 @@ where // Crush: let r_z = r_x_step; - let chi = EqPolynomial::evals(r_z); + let chi = EqPolynomial::evals(&r_z); let claimed_witness_evals: Vec<_> = flattened_polys .par_iter() .map(|poly| poly.evaluate_at_chi_low_optimized(&chi)) @@ -292,10 +327,11 @@ where + r_inner_sumcheck_RLC * self.outer_sumcheck_claims.1 + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * self.outer_sumcheck_claims.2; + let num_rounds = (key.num_vars_uniform() * 2).next_power_of_two().log_2(); let (claim_inner_final, inner_sumcheck_r) = self .inner_sumcheck_proof - .verify(claim_inner_joint, num_rounds, 2, transcript) + .verify(claim_inner_joint, num_rounds, 2, transcript) .map_err(|_| SpartanError::InvalidInnerSumcheckProof)?; // n_prefix = n_segments + 1 @@ -317,6 +353,10 @@ where let right_expected = eval_Z; let claim_inner_final_expected = left_expected * right_expected; + assert_eq!(claim_inner_final, claim_inner_final_expected); + println!("claim_inner_final: {:?}", claim_inner_final); + println!("claim_inner_final_expected: {:?}", claim_inner_final_expected); + assert!(false); if claim_inner_final != claim_inner_final_expected { return Err(SpartanError::InvalidInnerSumcheckClaim); } @@ -392,3 +432,93 @@ where // .expect("Spartan verifier failed"); // } // } + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + use rand_core::{RngCore, CryptoRng}; + use ark_bn254::Fr as F; // Add this line to import the field type + use ark_ff::{Zero, One}; // Import the Zero trait + + #[test] + fn test_shifted_polynomial_evaluations() { + // Generate a vector z of random field elements of length 128 + let mut rng = rand::thread_rng(); + let z: Vec = (0..128).map(|_| F::from(rng.gen::())).collect(); + + // Resize z to the next power of two + let mut z_resized = z.clone(); + z_resized.resize(z.len().next_power_of_two() * 2, F::zero()); + + println!("z_resized.len(): {:?}", z_resized.len()); + + let r_x_step: Vec = vec![F::zero(), F::zero(), F::one(), F::zero()]; + + // Create the polynomial from z + let mut poly_z = DensePolynomial::new(z_resized.clone()); + for r_s in r_x_step.iter().rev() { + poly_z.bound_poly_var_bot(r_s); + } + let evals = poly_z.evals(); + + // Create the shifted polynomial from z + let mut z_shifted: Vec = z[1..].to_vec(); + z_shifted.resize(z.len().next_power_of_two(), F::zero()); + + let mut poly_z_shifted = DensePolynomial::new(z_shifted.clone()); + for r_s in r_x_step.iter().rev() { + poly_z_shifted.bound_poly_var_bot(r_s); + } + let evals_shifted = poly_z_shifted.evals(); + + // // print the first 10 lines of evals and evals_shifted + // for i in 0..4 { + // println!("z: {:?}", z); + // println!("evals_shifted: {:?}", evals_shifted); + // // println!("evals[{}]: {:?}, evals_shifted[{}]: {:?}", i, evals[i], i, evals_shifted[i]); + // println!("z[{}]: {:?}, evals_shifted[{}]: {:?}", i+1, z[i+1], i, evals_shifted[i]); + // // println!("z[{}]: {:?}", i+1, z[i+1]); + + // } + + // print each element of z preceded by index: + for i in 0..z.len() { + println!("z[{}]: {:?}", i, z[i]); + } + println!("evals_shifted: {:?}", evals_shifted); + + + + // // Evaluate the polynomials at a random point k + // let k: F = F::random(&mut rng); + // let eval_at_k = poly_z.evaluate(&k); + // let eval_shifted_at_k_minus_1 = poly_z_shifted.evaluate(&(k - F::one())); + + // // Check if the evaluations are correct + // assert_eq!(eval_at_k, eval_shifted_at_k_minus_1); + } + #[test] + fn test_eq_polynomial_evals() { + // Generate a random vector of length 8 + let mut rng = rand::thread_rng(); + let random_vector: Vec = (0..8).map(|_| F::from(rng.gen::())).collect(); + + // generate all 1s vector of lenght 8 + let all_ones_vector: Vec = (0..8).map(|_| F::one()).collect(); + + // Run EqPolynomial::evals on the random vector + let eq_evals = EqPolynomial::evals(&random_vector); + let all_ones_evals = EqPolynomial::evals(&all_ones_vector); + + // // Print the random vector and its evaluations + // for i in 0..random_vector.len() { + // println!("random_vector[{}]: {:?}", i, random_vector[i]); + // } + // println!("eq_evals: {:?}", eq_evals); + println!("all_ones_evals.last(): {:?}", all_ones_evals[2]); + + // // Check if the evaluations are correct (this is a placeholder, you should replace it with actual checks) + // assert_eq!(eq_evals.len(), random_vector.len().next_power_of_two()); + } +} \ No newline at end of file From 424688bf42460120962c2fa770647352bd24db09 Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Sat, 14 Dec 2024 17:18:12 -0500 Subject: [PATCH 4/8] partial crush optimization with shift sumcheck --- jolt-core/src/r1cs/builder.rs | 10 ++++ jolt-core/src/r1cs/key.rs | 95 ++++++++++++++++------------------- jolt-core/src/r1cs/spartan.rs | 86 ++++++++++++++++++++----------- 3 files changed, 112 insertions(+), 79 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 08a934f3a..747e3fe7b 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -165,6 +165,16 @@ impl AuxComputation { }); }); + /* Hack(arasuarun): Set all variables in the last step to 0. + Needed for the crush-second-sumcheck optimization. + There should be a better way to do this instead of iterating over all witness segments. + */ + let mut last_index = batch_size - 1; + while last_index < aux_poly.len() { + aux_poly[last_index] = F::zero(); + last_index += batch_size; + } + DensePolynomial::new(aux_poly) } diff --git a/jolt-core/src/r1cs/key.rs b/jolt-core/src/r1cs/key.rs index 2bf747693..99252229e 100644 --- a/jolt-core/src/r1cs/key.rs +++ b/jolt-core/src/r1cs/key.rs @@ -232,27 +232,6 @@ impl UniformSpartanKey>(); - /* Crush: not needed - let mut rlc = unsafe_allocate_zero_vec(self.num_cols_total()); - - { - let span = tracing::span!(tracing::Level::INFO, "big_rlc_computation"); - let _guard = span.enter(); - rlc.par_chunks_mut(self.num_steps) - .take(self.uniform_r1cs.num_vars) - .enumerate() - .for_each(|(var_index, var_chunk)| { - if !sm_rlc[var_index].is_zero() { - for (step_index, item) in var_chunk.iter_mut().enumerate() { - *item = mul_0_1_optimized(&eq_rx_step[step_index], &sm_rlc[var_index]); - } - } - }); - } - - rlc[self.num_vars_total()] = sm_rlc[self.uniform_r1cs.num_vars]; // constant - */ - // Handle non-uniform constraints let update_non_uni = |rlc: &mut Vec, offset: &SparseEqualityItem, @@ -285,23 +264,25 @@ impl UniformSpartanKey F { assert_eq!(self.uniform_r1cs.num_vars, segment_evals.len()); - assert_eq!(r.len(), self.full_z_len().log_2()); + assert_eq!(r.len(), self.full_z_len().log_2()); // Z can be computed in two halves, [Variables, (constant) 1, 0 , ...] indexed by the first bit. let r_const = r[0]; let r_rest = &r[1..]; - assert_eq!(r_rest.len(), self.num_vars_total().log_2()); // Don't need the last log2(num_steps) bits, they've been evaluated already. let var_bits = self.uniform_r1cs.num_vars.next_power_of_two().log_2(); let r_var = &r_rest[..var_bits]; + let r_x_step = &r_rest[var_bits..]; + + let eq_last_step = EqPolynomial::new(r_x_step.to_vec()).evaluate(&vec![F::one(); r_x_step.len()]); let r_var_eq = EqPolynomial::evals(r_var); let eval_variables: F = (0..self.uniform_r1cs.num_vars) .map(|var_index| r_var_eq[var_index] * segment_evals[var_index]) .sum(); - // Crush: + // If r_const = 1, only the constant position (with all other index bits are 0) has a non-zero value let var_and_const_bits: usize = var_bits + 1; let eq_consts = EqPolynomial::new(r[..var_and_const_bits].to_vec()); let eq_const = eq_consts.evaluate(&index_to_field_bitvector( @@ -309,12 +290,12 @@ impl UniformSpartanKey (F, F, F) { + pub fn evaluate_r1cs_matrix_mles(&self, r: &[F], r_choice: &F) -> (F, F, F) { let total_rows_bits = self.num_rows_total().log_2(); let total_cols_bits = self.num_cols_total().log_2(); let steps_bits: usize = self.num_steps.log_2(); @@ -329,7 +310,6 @@ impl UniformSpartanKey UniformSpartanKey() ; - // Crush: - // * eq_rx_ry_step; full_mle_evaluation += constraints .consts @@ -364,31 +342,23 @@ impl UniformSpartanKey| -> F { - let mut non_uni_mle = non_uni - .offset_vars - .iter() - .map(|(col, offset, coeff)| { - if !offset { - *coeff * eq_ry_var[*col] * eq_rx_ry_step - } else { - *coeff * eq_ry_var[*col] * eq_step_offset_1 - } - }) - .sum::(); - - non_uni_mle += non_uni.constant * col_eq_constant; - - non_uni_mle + let mut non_uni_a_mle = F::zero(); + let mut non_uni_b_mle = F::zero(); + + let compute_non_uniform = |uni_mle: &mut F, non_uni_mle: &mut F, non_uni: &SparseEqualityItem, eq_rx: F| { + for (col, offset, coeff) in &non_uni.offset_vars { + if !offset { + *uni_mle += *coeff * eq_ry_var[*col] * eq_rx; + } else { + *non_uni_mle += *coeff * eq_ry_var[*col] * eq_rx; + } + } }; for (i, constraint) in self.offset_eq_r1cs.constraints.iter().enumerate() { - let non_uni_a = compute_non_uniform(&constraint.eq); - let non_uni_b = compute_non_uniform(&constraint.condition); let non_uni_constraint_index = index_to_field_bitvector(self.uniform_r1cs.num_rows + i, constraint_rows_bits); @@ -400,10 +370,33 @@ impl UniformSpartanKey>| { + if let Some(non_uni_constants) = non_uni_constants { + for (i, non_uni_constant) in non_uni_constants.iter().enumerate() { + // The matrix values are present even in the last step. + // It's the role of the evaluation of the z mle to ignore the last step. + let first_non_uniform_row = self.uniform_r1cs.num_rows; + *uni_mle += eq_rx_constr[first_non_uniform_row + i] * non_uni_constant * col_eq_constant; + } + } + }; + + let (eq_constants, condition_constants) = self.offset_eq_r1cs.constants(); + compute_non_uni_constants(&mut a_mle, Some(eq_constants)); + compute_non_uni_constants(&mut b_mle, Some(condition_constants)); + + a_mle = (F::one() - r_choice) * a_mle + + *r_choice * non_uni_a_mle; + b_mle = (F::one() - r_choice) * b_mle + + *r_choice * non_uni_b_mle; + c_mle = (F::one() - r_choice) * c_mle; + (a_mle, b_mle, c_mle) } diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 63d83f797..10d471572 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -79,6 +79,8 @@ pub struct UniformSpartanProof< pub(crate) outer_sumcheck_proof: SumcheckInstanceProof, pub(crate) outer_sumcheck_claims: (F, F, F), pub(crate) inner_sumcheck_proof: SumcheckInstanceProof, + pub(crate) shift_sumcheck_proof: SumcheckInstanceProof, + pub(crate) shift_sumcheck_claim: F, pub(crate) claimed_witness_evals: Vec, _marker: PhantomData, } @@ -157,13 +159,10 @@ where + r_inner_sumcheck_RLC * claim_Bz + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * claim_Cz; - // Crush: let num_steps_padded = constraint_builder.uniform_repeat().next_power_of_two(); - let num_steps_bits_ = constraint_builder - .uniform_repeat() - .next_power_of_two() - .ilog2() as usize; - let num_constraints_bits = key.num_cons_total.log_2() - num_steps_bits_; + let num_steps_bits = num_steps_padded.ilog2() as usize; + let num_constraints_bits = key.num_cons_total.log_2() - num_steps_bits; + let r_x_step = &outer_sumcheck_r[num_constraints_bits..]; let mut z: Vec = flattened_polys.clone().into_iter().map(|poly| { @@ -196,14 +195,13 @@ where evals.push(F::one() - is_last_step); // Constant, ignores the last step. evals.resize(evals.len().next_power_of_two(), F::zero()); - let n_bits_ts = r_x_step.len(); let eq_plus_one_rx_step: Vec = (0..num_steps_padded) - .map(|t| eq_plus_one(r_x_step, &crate::utils::index_to_field_bitvector(t, n_bits_ts), n_bits_ts)) + .map(|t| eq_plus_one(r_x_step, &crate::utils::index_to_field_bitvector(t, num_steps_bits), num_steps_bits)) .collect(); let mut evals_shifted = (0..key.num_vars_uniform()) .map(|y_var: usize| { - (0..num_steps_padded-1) + (0..num_steps_padded-1) // Ignore the last step .map(|t| z[y_var * num_steps_padded + t] * eq_plus_one_rx_step[t] ) @@ -215,10 +213,6 @@ where let poly_z = DensePolynomial::new(evals.into_iter().chain(evals_shifted.into_iter()).collect()); // this is the polynomial extended from the vector r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y - let num_steps_bits = constraint_builder - .uniform_repeat() - .next_power_of_two() - .ilog2(); let (rx_con, rx_ts) = outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize); let poly_ABC = @@ -243,9 +237,40 @@ where drop_in_background_thread(polys); - // Crush: let r_z = r_x_step; + let r_y_var = inner_sumcheck_r[1..].to_vec(); + assert_eq!(r_y_var.len(), key.num_vars_uniform().next_power_of_two().log_2() + 1); + + // Third sumcheck: the shift sumcheck + let mut poly_z2 = DensePolynomial::new( + z.clone().into_iter().chain(vec![F::zero(); z.len()].into_iter()).collect() + ); + for r_s in r_y_var.iter() { + poly_z2.bound_poly_var_top(r_s); + } + let evals_z_r_y_var= poly_z2.evals(); + + let num_rounds_shift_sumcheck = num_steps_bits; + let mut shift_sumcheck_polys = vec![DensePolynomial::new(evals_z_r_y_var), DensePolynomial::new(eq_plus_one_rx_step.clone())]; + + let shift_sumcheck_claim = (0..((1 << num_rounds_shift_sumcheck) - 1)) + .map(|i| { + let params: Vec = shift_sumcheck_polys.iter().map(|poly| poly[i]).collect(); + comb_func(¶ms) + }) + .fold(F::zero(), |acc, x| acc + x); + + let (shift_sumcheck_proof, _, _) = + SumcheckInstanceProof::prove_arbitrary( + &shift_sumcheck_claim, + num_rounds_shift_sumcheck, + &mut shift_sumcheck_polys, + comb_func, + 2, + transcript); + drop_in_background_thread(shift_sumcheck_polys); + let chi = EqPolynomial::evals(&r_z); let claimed_witness_evals: Vec<_> = flattened_polys .par_iter() @@ -271,6 +296,8 @@ where outer_sumcheck_proof, outer_sumcheck_claims, inner_sumcheck_proof, + shift_sumcheck_proof, + shift_sumcheck_claim, claimed_witness_evals, _marker: PhantomData, }) @@ -327,45 +354,48 @@ where + r_inner_sumcheck_RLC * self.outer_sumcheck_claims.1 + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * self.outer_sumcheck_claims.2; - - let num_rounds = (key.num_vars_uniform() * 2).next_power_of_two().log_2(); + let num_rounds = (key.num_vars_uniform() * 2).next_power_of_two().log_2() + 1; // +1 for cross-step let (claim_inner_final, inner_sumcheck_r) = self .inner_sumcheck_proof .verify(claim_inner_joint, num_rounds, 2, transcript) .map_err(|_| SpartanError::InvalidInnerSumcheckProof)?; - // n_prefix = n_segments + 1 - let n_prefix = key.uniform_r1cs.num_vars.next_power_of_two().log_2() + 1; - - // Crush: let n_constraint_bits_uniform = key.uniform_r1cs.num_rows.next_power_of_two().log_2(); let outer_sumcheck_r_step = &r_x[n_constraint_bits_uniform..]; - let y_prime = [inner_sumcheck_r.to_owned(), outer_sumcheck_r_step.to_owned()].concat(); + + let r_choice = inner_sumcheck_r[0]; + let r_y_var = inner_sumcheck_r[1..].to_vec(); + let y_prime = [r_y_var, outer_sumcheck_r_step.to_owned()].concat(); let eval_Z = key.evaluate_z_mle(&self.claimed_witness_evals, &y_prime); - // Crush: let r = [r_x.clone(), y_prime].concat(); - let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r); + let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r, &r_choice); let left_expected = eval_a + r_inner_sumcheck_RLC * eval_b + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * eval_c; - let right_expected = eval_Z; + let right_expected = + (F::one() - r_choice) * eval_Z + + r_choice * self.shift_sumcheck_claim; let claim_inner_final_expected = left_expected * right_expected; assert_eq!(claim_inner_final, claim_inner_final_expected); - println!("claim_inner_final: {:?}", claim_inner_final); - println!("claim_inner_final_expected: {:?}", claim_inner_final_expected); - assert!(false); if claim_inner_final != claim_inner_final_expected { return Err(SpartanError::InvalidInnerSumcheckClaim); } + let num_steps_bits = outer_sumcheck_r_step.len(); + let num_rounds_shift_sumcheck = num_steps_bits; + let (claim_shift_final, shift_sumcheck_r) = self + .shift_sumcheck_proof + .verify(self.shift_sumcheck_claim, num_rounds_shift_sumcheck, 2, transcript) + .map_err(|_| SpartanError::InvalidInnerSumcheckProof)?; + let flattened_commitments: Vec<_> = I::flatten::() .iter() .map(|var| var.get_ref(commitments)) .collect(); - // Crush: + let r_y_point = &r_x[n_constraint_bits_uniform..]; opening_accumulator.append( &flattened_commitments, From b8e371954754e903f240b664095d66d6c412dd79 Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Sat, 14 Dec 2024 18:33:19 -0500 Subject: [PATCH 5/8] final verifier pc checks done; crush finished --- jolt-core/src/r1cs/key.rs | 7 +++--- jolt-core/src/r1cs/spartan.rs | 45 +++++++++++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/jolt-core/src/r1cs/key.rs b/jolt-core/src/r1cs/key.rs index 99252229e..f453e93b4 100644 --- a/jolt-core/src/r1cs/key.rs +++ b/jolt-core/src/r1cs/key.rs @@ -262,7 +262,7 @@ impl UniformSpartanKey F { + pub fn evaluate_z_mle(&self, segment_evals: &[F], r: &[F], with_const: bool) -> F { assert_eq!(self.uniform_r1cs.num_vars, segment_evals.len()); assert_eq!(r.len(), self.full_z_len().log_2()); @@ -285,12 +285,13 @@ impl UniformSpartanKey, pub(crate) shift_sumcheck_claim: F, pub(crate) claimed_witness_evals: Vec, + pub(crate) claimed_witness_evals_shift_sumcheck: Vec, _marker: PhantomData, } @@ -261,7 +262,7 @@ where }) .fold(F::zero(), |acc, x| acc + x); - let (shift_sumcheck_proof, _, _) = + let (shift_sumcheck_proof, shift_sumcheck_r, shift_sumcheck_claims) = SumcheckInstanceProof::prove_arbitrary( &shift_sumcheck_claim, num_rounds_shift_sumcheck, @@ -285,7 +286,26 @@ where transcript, ); +<<<<<<< HEAD // Outer sumcheck claims: [A(r_x), B(r_x), C(r_x)] +======= + // Polynomial evals for shift sumcheck + let chi2 = EqPolynomial::evals(&shift_sumcheck_r); + let claimed_witness_evals_shift_sumcheck: Vec<_> = flattened_polys + .par_iter() + .map(|poly| poly.evaluate_at_chi_low_optimized(&chi2)) + .collect(); + + opening_accumulator.append( + &flattened_polys, + DensePolynomial::new(chi2), + shift_sumcheck_r.to_vec(), + &claimed_witness_evals_shift_sumcheck.iter().collect::>(), + transcript, + ); + + // Outer sumcheck claims: [eq(r_x), A(r_x), B(r_x), C(r_x)] +>>>>>>> 0cd27f42 (final verifier pc checks done; crush finished) let outer_sumcheck_claims = ( outer_sumcheck_claims[0], outer_sumcheck_claims[1], @@ -299,6 +319,7 @@ where shift_sumcheck_proof, shift_sumcheck_claim, claimed_witness_evals, + claimed_witness_evals_shift_sumcheck, _marker: PhantomData, }) } @@ -365,8 +386,8 @@ where let r_choice = inner_sumcheck_r[0]; let r_y_var = inner_sumcheck_r[1..].to_vec(); - let y_prime = [r_y_var, outer_sumcheck_r_step.to_owned()].concat(); - let eval_Z = key.evaluate_z_mle(&self.claimed_witness_evals, &y_prime); + let y_prime = [r_y_var.clone(), outer_sumcheck_r_step.to_owned()].concat(); + let eval_z = key.evaluate_z_mle(&self.claimed_witness_evals, &y_prime, true); let r = [r_x.clone(), y_prime].concat(); let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r, &r_choice); @@ -375,7 +396,7 @@ where + r_inner_sumcheck_RLC * eval_b + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * eval_c; let right_expected = - (F::one() - r_choice) * eval_Z + + (F::one() - r_choice) * eval_z + r_choice * self.shift_sumcheck_claim; let claim_inner_final_expected = left_expected * right_expected; @@ -391,6 +412,15 @@ where .verify(self.shift_sumcheck_claim, num_rounds_shift_sumcheck, 2, transcript) .map_err(|_| SpartanError::InvalidInnerSumcheckProof)?; + let y_prime_shift_sumcheck = [r_y_var, shift_sumcheck_r.to_owned()].concat(); + let eval_z_shift_sumcheck = key.evaluate_z_mle(&self.claimed_witness_evals_shift_sumcheck, &y_prime_shift_sumcheck, false); + let eq_plus_one_shift_sumcheck = eq_plus_one(&outer_sumcheck_r_step, &shift_sumcheck_r, num_steps_bits); + let claim_shift_sumcheck_expected = eval_z_shift_sumcheck * eq_plus_one_shift_sumcheck; + assert_eq!(claim_shift_final, claim_shift_sumcheck_expected); + if claim_shift_final != claim_shift_sumcheck_expected { + return Err(SpartanError::InvalidInnerSumcheckClaim); + } + let flattened_commitments: Vec<_> = I::flatten::() .iter() .map(|var| var.get_ref(commitments)) @@ -404,6 +434,13 @@ where transcript, ); + opening_accumulator.append( + &flattened_commitments, + shift_sumcheck_r.to_vec(), + &self.claimed_witness_evals_shift_sumcheck.iter().collect::>(), + transcript, + ); + Ok(()) } } From d35b2a600e78e62d71fb0d70c9832a7b311394f6 Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Mon, 16 Dec 2024 17:09:42 -0500 Subject: [PATCH 6/8] optimize away unnecessary z clones --- jolt-core/src/r1cs/spartan.rs | 95 ++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 40 deletions(-) diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index e83dc15b7..9dc11842f 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -166,49 +166,50 @@ where let r_x_step = &outer_sumcheck_r[num_constraints_bits..]; - let mut z: Vec = flattened_polys.clone().into_iter().map(|poly| { - let mut resized = poly.Z.clone(); - resized.resize(poly.len().next_power_of_two(), F::zero()); - resized - }).flatten().collect(); - z.resize(z.len().next_power_of_two(), F::zero()); - + // Binding 1: evaluating z on r_x_step let is_last_step = EqPolynomial::new(r_x_step.to_vec()).evaluate(&vec![F::one(); r_x_step.len()]); let eq_rx_step = EqPolynomial::evals(r_x_step); - // // Evals with binding. Doesn't ignore when r_x_step is the last step. - // for r_s in r_x_step.iter().rev() { - // poly_z.bound_poly_var_bot(r_s); - // } - // let mut evals = poly_z.evals(); - // evals.push(F::one() - is_last_step); // ARASU: IGNORE LAST STEP? - // evals.resize(evals.len().next_power_of_two() * 1, F::zero()); - - // Evals straightfoward - let mut evals: Vec = (0..key.num_vars_uniform()) // until the constant (which is not included) - .map(|y_var| { - (0..(num_steps_padded-1)) // Ignore the last step - .map(|t| z[y_var * num_steps_padded + t] * eq_rx_step[t]) - .sum() - }) - .collect(); + let mut evals: Vec = flattened_polys + .par_iter() + .map(|poly| { + poly.Z + .par_iter() + .enumerate() + .map(|(t, &val)| { + if t == num_steps_padded - 1 { // ignore last step + F::zero() + } else { + val * eq_rx_step[t] + } + }) + .sum() + }) + .collect(); evals.resize(evals.len().next_power_of_two(), F::zero()); - evals.push(F::one() - is_last_step); // Constant, ignores the last step. + evals.push(F::one() - is_last_step); // Constant, ignores the last step. evals.resize(evals.len().next_power_of_two(), F::zero()); let eq_plus_one_rx_step: Vec = (0..num_steps_padded) .map(|t| eq_plus_one(r_x_step, &crate::utils::index_to_field_bitvector(t, num_steps_bits), num_steps_bits)) .collect(); - let mut evals_shifted = (0..key.num_vars_uniform()) - .map(|y_var: usize| { - (0..num_steps_padded-1) // Ignore the last step - .map(|t| - z[y_var * num_steps_padded + t] * eq_plus_one_rx_step[t] - ) - .sum::() + let mut evals_shifted: Vec = flattened_polys + .par_iter() + .map(|poly| { + poly.Z + .par_iter() + .enumerate() + .map(|(t, &val)| { + if t == num_steps_padded - 1 { // ignore last step + F::zero() + } else { + val * eq_plus_one_rx_step[t] + } + }) + .sum() }) - .collect::>(); + .collect(); evals_shifted.resize(evals.len(), F::zero()); let poly_z = DensePolynomial::new(evals.into_iter().chain(evals_shifted.into_iter()).collect()); @@ -243,14 +244,28 @@ where let r_y_var = inner_sumcheck_r[1..].to_vec(); assert_eq!(r_y_var.len(), key.num_vars_uniform().next_power_of_two().log_2() + 1); - // Third sumcheck: the shift sumcheck - let mut poly_z2 = DensePolynomial::new( - z.clone().into_iter().chain(vec![F::zero(); z.len()].into_iter()).collect() - ); - for r_s in r_y_var.iter() { - poly_z2.bound_poly_var_top(r_s); - } - let evals_z_r_y_var= poly_z2.evals(); + let eq_ry_var = EqPolynomial::evals(&r_y_var); + + // Binding 2: evaluating z on r_y_var + /* TODO(arasuarun): this might lead to inefficient memory paging + as we access each poly in flattened_poly num_steps_padded-many times. + */ + let mut evals_z_r_y_var: Vec = (0..constraint_builder.uniform_repeat()) + .map(|t| { + flattened_polys + .par_iter() + .enumerate() + .map(|(i, poly)| { + if t < poly.Z.len() { + poly.Z[t] * eq_ry_var[i] + } else { + F::zero() + } + }) + .sum() + }) + .collect(); + evals_z_r_y_var.resize(num_steps_padded, F::zero()); let num_rounds_shift_sumcheck = num_steps_bits; let mut shift_sumcheck_polys = vec![DensePolynomial::new(evals_z_r_y_var), DensePolynomial::new(eq_plus_one_rx_step.clone())]; From 81a011d6369c4d0280cf61f75ddefff0608534c5 Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Mon, 16 Dec 2024 17:22:52 -0500 Subject: [PATCH 7/8] parallelize calculation in shift sumcheck --- jolt-core/src/r1cs/spartan.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 9dc11842f..ee150b95f 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -239,15 +239,16 @@ where drop_in_background_thread(polys); - let r_z = r_x_step; - let r_y_var = inner_sumcheck_r[1..].to_vec(); assert_eq!(r_y_var.len(), key.num_vars_uniform().next_power_of_two().log_2() + 1); let eq_ry_var = EqPolynomial::evals(&r_y_var); - // Binding 2: evaluating z on r_y_var - /* TODO(arasuarun): this might lead to inefficient memory paging + /* Sumcheck 3: the shift sumcheck */ + + /* Binding 2: evaluating z on r_y_var + + TODO(arasuarun): this might lead to inefficient memory paging as we access each poly in flattened_poly num_steps_padded-many times. */ let mut evals_z_r_y_var: Vec = (0..constraint_builder.uniform_repeat()) @@ -271,11 +272,12 @@ where let mut shift_sumcheck_polys = vec![DensePolynomial::new(evals_z_r_y_var), DensePolynomial::new(eq_plus_one_rx_step.clone())]; let shift_sumcheck_claim = (0..((1 << num_rounds_shift_sumcheck) - 1)) + .into_par_iter() .map(|i| { let params: Vec = shift_sumcheck_polys.iter().map(|poly| poly[i]).collect(); comb_func(¶ms) }) - .fold(F::zero(), |acc, x| acc + x); + .reduce(|| F::zero(), |acc, x| acc + x); let (shift_sumcheck_proof, shift_sumcheck_r, shift_sumcheck_claims) = SumcheckInstanceProof::prove_arbitrary( @@ -287,7 +289,8 @@ where transcript); drop_in_background_thread(shift_sumcheck_polys); - let chi = EqPolynomial::evals(&r_z); + // Polynomial evals for inner sumcheck + let chi = EqPolynomial::evals(&r_x_step); let claimed_witness_evals: Vec<_> = flattened_polys .par_iter() .map(|poly| poly.evaluate_at_chi_low_optimized(&chi)) @@ -296,7 +299,7 @@ where opening_accumulator.append( &flattened_polys, DensePolynomial::new(chi), - r_z.to_vec(), + r_x_step.to_vec(), &claimed_witness_evals.iter().collect::>(), transcript, ); From 80faf15d87c5331c63e250943ec9844dcc761acf Mon Sep 17 00:00:00 2001 From: Arasu Arun Date: Mon, 16 Dec 2024 17:36:45 -0500 Subject: [PATCH 8/8] finished rebase --- jolt-core/src/r1cs/spartan.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index ee150b95f..1e5abc594 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -304,9 +304,6 @@ where transcript, ); -<<<<<<< HEAD - // Outer sumcheck claims: [A(r_x), B(r_x), C(r_x)] -======= // Polynomial evals for shift sumcheck let chi2 = EqPolynomial::evals(&shift_sumcheck_r); let claimed_witness_evals_shift_sumcheck: Vec<_> = flattened_polys @@ -323,7 +320,6 @@ where ); // Outer sumcheck claims: [eq(r_x), A(r_x), B(r_x), C(r_x)] ->>>>>>> 0cd27f42 (final verifier pc checks done; crush finished) let outer_sumcheck_claims = ( outer_sumcheck_claims[0], outer_sumcheck_claims[1],