diff --git a/air/src/air/context.rs b/air/src/air/context.rs index c7054a565..a90c6637e 100644 --- a/air/src/air/context.rs +++ b/air/src/air/context.rs @@ -254,7 +254,7 @@ impl AirContext { /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any pub fn lagrange_kernel_aux_column_idx(&self) -> Option { if self.logup_gkr_enabled() { - Some(self.trace_info().aux_segment_width() - 1) + Some(self.trace_info().aux_segment_width() - LAGRANGE_KERNEL_OFFSET) } else { None } diff --git a/air/src/air/logup_gkr.rs b/air/src/air/logup_gkr.rs index 63338e6ad..5a9db4ff6 100644 --- a/air/src/air/logup_gkr.rs +++ b/air/src/air/logup_gkr.rs @@ -10,6 +10,11 @@ use math::{ExtensionOf, FieldElement, StarkField, ToElements}; use super::EvaluationFrame; +// CONSTANTS +// =============================================================================================== +pub const LAGRANGE_KERNEL_OFFSET: usize = 1; +pub const S_COLUMN_OFFSET: usize = 2; + /// A trait containing the necessary information in order to run the LogUp-GKR protocol of [1]. /// /// The trait contains useful information for running the GKR protocol as well as for implementing diff --git a/air/src/air/mod.rs b/air/src/air/mod.rs index 1ae4aa771..9d5b01cba 100644 --- a/air/src/air/mod.rs +++ b/air/src/air/mod.rs @@ -35,7 +35,9 @@ pub use lagrange::{ }; mod logup_gkr; -pub use logup_gkr::{LogUpGkrEvaluator, LogUpGkrOracle, PhantomLogUpGkrEval, LAGRANGE_KERNEL_OFFSET, S_COLUMN_OFFSET}; +pub use logup_gkr::{ + LogUpGkrEvaluator, LogUpGkrOracle, PhantomLogUpGkrEval, LAGRANGE_KERNEL_OFFSET, S_COLUMN_OFFSET, +}; mod coefficients; pub use coefficients::{ @@ -599,7 +601,7 @@ pub trait Air: Send + Sync { None }; - let s_col = if self.context().logup_gkr_enabled() { + let s_col_cc = if self.context().logup_gkr_enabled() { Some(public_coin.draw()?) } else { None diff --git a/air/src/lib.rs b/air/src/lib.rs index 4b6c5b914..6c82362a9 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -48,5 +48,6 @@ pub use air::{ LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PhantomLogUpGkrEval, - TraceInfo, TransitionConstraintDegree, TransitionConstraints, + TraceInfo, TransitionConstraintDegree, TransitionConstraints, LAGRANGE_KERNEL_OFFSET, + S_COLUMN_OFFSET, }; diff --git a/prover/src/constraints/evaluator/default.rs b/prover/src/constraints/evaluator/default.rs index fd4bb1c42..4373494f9 100644 --- a/prover/src/constraints/evaluator/default.rs +++ b/prover/src/constraints/evaluator/default.rs @@ -158,7 +158,7 @@ where &composition_coefficients.boundary, ); - let lagrange_constraints_evaluator = if air.context().logup_gkr_enabled() { + let logup_gkr_constraints_evaluator = if air.context().logup_gkr_enabled() { let aux_rand_elements = aux_rand_elements.as_ref().expect("expected aux rand elements to be present"); diff --git a/prover/src/constraints/evaluator/logup_gkr.rs b/prover/src/constraints/evaluator/logup_gkr.rs index ef45a6816..0ed5b2c75 100644 --- a/prover/src/constraints/evaluator/logup_gkr.rs +++ b/prover/src/constraints/evaluator/logup_gkr.rs @@ -71,12 +71,12 @@ where let mut lagrange_frame = LagrangeKernelEvaluationFrame::new_empty(); - let evaluator = self.air.get_logup_gkr_evaluator::(); + let evaluator = self.air.get_logup_gkr_evaluator::(); let s_col_constraint_divisor = compute_s_col_divisor::(domain.ce_domain_size(), domain, self.air.trace_length()); let s_col_idx = trace.trace_info().aux_segment_width() - S_COLUMN_OFFSET; let l_col_idx = trace.trace_info().aux_segment_width() - LAGRANGE_KERNEL_OFFSET; - let mut main_frame = EvaluationFrame::new(trace.trace_info().main_trace_width()); + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width()); let c = self.gkr_data.compute_batched_claim(); diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index a7d87d6b9..691195925 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -17,9 +17,6 @@ use crate::{ MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, }; -#[cfg(feature = "concurrent")] -pub use rayon::prelude::*; - /// A sum-check prover for the input layer which can accommodate non-linear expressions in /// the numerators of the LogUp relation. /// diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index e69de29bb..e0092cf10 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -0,0 +1,216 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use smallvec::smallvec; + +use super::SumCheckProverError; +use crate::{ + comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, + SumCheckProof, +}; + +/// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. +/// +/// More specifically, the following function implements the logic of the sum-check prover as +/// described in Section 3.2 in [1], that is, given verifier challenges , the following implements +/// the sum-check prover for the following two statements +/// $$ +/// p_{\nu - \kappa}\left(v_{\kappa+1}, \cdots, v_{\nu}\right) = \sum_{w_i} +/// EQ\left(\left(v_{\kappa+1}, \cdots, v_{\nu}\right), \left(w_{\kappa+1}, \cdots, +/// w_{\nu}\right)\right) \cdot +/// \left( p_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right) + +/// p_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// $$ +/// q_{\nu -k}\left(v_{\kappa+1}, \cdots, v_{\nu}\right) = \sum_{w_i}EQ\left(\left(v_{\kappa+1}, +/// \cdots, v_{\nu}\right), \left(w_{\kappa+1}, \cdots, w_{\nu }\right)\right) \cdot +/// \left( q_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) +/// $$ +/// +/// for $k = 1, \cdots, \nu - 1$ +/// +/// Instead of executing two runs of the sum-check protocol, a batching randomness `r_batch` is +/// sent by the verifier at the outset in order to batch the two statments. +/// +/// Note that the degree of the non-linear composition polynomial is 3. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +#[allow(clippy::too_many_arguments)] +pub fn sumcheck_prove_plain>( + mut claim: E, + r_batch: E, + p: MultiLinearPoly, + q: MultiLinearPoly, + eq: &mut MultiLinearPoly, + transcript: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let mut round_proofs = vec![]; + + let mut challenges = vec![]; + + // construct the vector of multi-linear polynomials + let (mut p0, mut p1) = p.project_least_significant_variable(); + let (mut q0, mut q1) = q.project_least_significant_variable(); + + for _ in 0..p0.num_variables() { + let len = p0.num_evaluations() / 2; + + #[cfg(not(feature = "concurrent"))] + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( + (E::ZERO, E::ZERO, E::ZERO), + |(acc_point_1, acc_point_2, acc_point_3), i| { + let round_poly_eval_at_1 = comb_func( + p0[2 * i + 1], + p1[2 * i + 1], + q0[2 * i + 1], + q1[2 * i + 1], + eq[2 * i + 1], + r_batch, + ); + + let p0_delta = p0[2 * i + 1] - p0[2 * i]; + let p1_delta = p1[2 * i + 1] - p1[2 * i]; + let q0_delta = q0[2 * i + 1] - q0[2 * i]; + let q1_delta = q1[2 * i + 1] - q1[2 * i]; + let eq_delta = eq[2 * i + 1] - eq[2 * i]; + + let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; + let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; + let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; + let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; + let mut eq_evx = eq[2 * i + 1] + eq_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + ( + round_poly_eval_at_1 + acc_point_1, + round_poly_eval_at_2 + acc_point_2, + round_poly_eval_at_3 + acc_point_3, + ) + }, + ); + + #[cfg(feature = "concurrent")] + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len) + .into_par_iter() + .fold( + || (E::ZERO, E::ZERO, E::ZERO), + |(a, b, c), i| { + let round_poly_eval_at_1 = comb_func( + p0[2 * i + 1], + p1[2 * i + 1], + q0[2 * i + 1], + q1[2 * i + 1], + eq[2 * i + 1], + r_batch, + ); + + let p0_delta = p0[2 * i + 1] - p0[2 * i]; + let p1_delta = p1[2 * i + 1] - p1[2 * i]; + let q0_delta = q0[2 * i + 1] - q0[2 * i]; + let q1_delta = q1[2 * i + 1] - q1[2 * i]; + let eq_delta = eq[2 * i + 1] - eq[2 * i]; + + let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; + let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; + let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; + let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; + let mut eq_evx = eq[2 * i + 1] + eq_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + (round_poly_eval_at_1 + a, round_poly_eval_at_2 + b, round_poly_eval_at_3 + c) + }, + ) + .reduce( + || (E::ZERO, E::ZERO, E::ZERO), + |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), + ); + + let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; + let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); + let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); + + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the round challenge + p0.bind_least_significant_variable(round_challenge); + p1.bind_least_significant_variable(round_challenge); + q0.bind_least_significant_variable(round_challenge); + q1.bind_least_significant_variable(round_challenge); + eq.bind_least_significant_variable(round_challenge); + + // compute the new reduced round claim + claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); + + round_proofs.push(round_proof); + challenges.push(round_challenge); + } + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { + eval_point: challenges, + openings: vec![p0[0], p1[0], q0[0], q1[0]], + }, + round_proofs, + }) +} diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 7173c31bc..5cc338bb7 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -136,130 +136,6 @@ where }) } -/// Verifies sum-check proofs, as part of the GKR proof, for all GKR layers except for the last one -/// i.e., the circuit input layer. -pub fn verify_sum_check_intermediate_layers< - E: FieldElement, - H: ElementHasher, ->( - proof: &SumCheckProof, - gkr_eval_point: &[E], - claim: (E, E), - transcript: &mut impl RandomCoin, -) -> Result, SumCheckVerifierError> { - // generate challenge to batch sum-checks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); - let r_batch: E = transcript - .draw() - .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; - - // compute the claim for the batched sum-check - let reduced_claim = claim.0 + claim.1 * r_batch; - - let SumCheckProof { openings_claim, round_proofs } = proof; - - let final_round_claim = verify_rounds(reduced_claim, round_proofs, transcript)?; - assert_eq!(openings_claim.eval_point, final_round_claim.eval_point); - - let p0 = openings_claim.openings[0]; - let p1 = openings_claim.openings[1]; - let q0 = openings_claim.openings[2]; - let q1 = openings_claim.openings[3]; - - let eq = EqFunction::new(gkr_eval_point.into()).evaluate(&openings_claim.eval_point); - - if comb_func(p0, p1, q0, q1, eq, r_batch) != final_round_claim.claim { - return Err(SumCheckVerifierError::FinalEvaluationCheckFailed); - } - - Ok(openings_claim.clone()) -} - -/// Verifies the final sum-check proof i.e., the one for the input layer, including the final check, -/// and returns a [`FinalOpeningClaim`] to the STARK verifier in order to verify the correctness of -/// the openings. -pub fn verify_sum_check_input_layer>( - evaluator: &impl LogUpGkrEvaluator, - proof: &FinalLayerProof, - log_up_randomness: Vec, - gkr_eval_point: &[E], - claim: (E, E), - transcript: &mut impl RandomCoin, -) -> Result, SumCheckVerifierError> { - let FinalLayerProof { proof } = proof; - - // generate challenge to batch sum-checks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); - let r_batch: E = transcript - .draw() - .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; - - // compute the claim for the batched sum-check - let reduced_claim = claim.0 + claim.1 * r_batch; - - // verify the sum-check proof - let SumCheckRoundClaim { eval_point, claim } = - verify_rounds(reduced_claim, &proof.round_proofs, transcript)?; - - // execute the final evaluation check - if proof.openings_claim.eval_point != eval_point { - return Err(SumCheckVerifierError::WrongOpeningPoint); - } - - let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; - let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; - evaluator.evaluate_query( - &proof.openings_claim.openings, - &log_up_randomness, - &mut numerators, - &mut denominators, - ); - - let mu = evaluator.get_num_fractions().trailing_zeros() - 1; - let (evaluation_point_mu, evaluation_point_nu) = gkr_eval_point.split_at(mu as usize); - - let eq_mu = EqFunction::new(evaluation_point_mu.into()).evaluations(); - let eq_nu = EqFunction::new(evaluation_point_nu.into()); - - let eq_nu_eval = eq_nu.evaluate(&proof.openings_claim.eval_point); - let expected_evaluation = - evaluate_composition_poly(&eq_mu, &numerators, &denominators, eq_nu_eval, r_batch); - - if expected_evaluation != claim { - Err(SumCheckVerifierError::FinalEvaluationCheckFailed) - } else { - Ok(proof.openings_claim.clone()) - } -} - -/// Verifies a round of the sum-check protocol without executing the final check. -fn verify_rounds( - claim: E, - round_proofs: &[RoundProof], - coin: &mut impl RandomCoin, -) -> Result, SumCheckVerifierError> -where - E: FieldElement, - H: ElementHasher, -{ - let mut round_claim = claim; - let mut evaluation_point = vec![]; - for round_proof in round_proofs { - let round_poly_coefs = round_proof.round_poly_coefs.clone(); - coin.reseed(H::hash_elements(&round_poly_coefs.0)); - - let r = coin.draw().map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; - - round_claim = round_proof.round_poly_coefs.evaluate_using_claim(&round_claim, &r); - evaluation_point.push(r); - } - - Ok(SumCheckRoundClaim { - eval_point: evaluation_point, - claim: round_claim, - }) -} - #[derive(Debug, thiserror::Error)] pub enum SumCheckVerifierError { #[error("the final evaluation check of sum-check failed")] diff --git a/verifier/src/evaluator.rs b/verifier/src/evaluator.rs index 636124bd6..b66eb931f 100644 --- a/verifier/src/evaluator.rs +++ b/verifier/src/evaluator.rs @@ -131,9 +131,13 @@ pub fn evaluate_constraints>( let mean = batched_claim .mul_base(E::BaseField::ONE / E::BaseField::from(air.trace_length() as u32)); - let mut query = vec![E::ZERO; air.get_logup_gkr_evaluator::().get_oracles().len()]; - air.get_logup_gkr_evaluator::() - .build_query(main_trace_frame, &[], &mut query); + let mut query = + vec![E::ZERO; air.get_logup_gkr_evaluator::().get_oracles().len()]; + air.get_logup_gkr_evaluator::().build_query( + main_trace_frame, + &[], + &mut query, + ); let batched_claim_at_query = gkr_data.compute_batched_query::(&query); let rhs = s_cur - mean + batched_claim_at_query * l_cur; let lhs = s_nxt;