From 16f372f31d2ff534c6f21eb30b68432c7aaba628 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Mon, 23 Sep 2024 18:33:52 -1000 Subject: [PATCH] Make CommitmentSchemeProver::prove_values take ownership --- crates/prover/src/core/fri.rs | 2 +- crates/prover/src/core/pcs/prover.rs | 8 +- crates/prover/src/core/prover/mod.rs | 103 +++++++++--------- crates/prover/src/core/vcs/prover.rs | 11 +- crates/prover/src/core/vcs/test_utils.rs | 2 +- crates/prover/src/examples/blake/air.rs | 2 +- crates/prover/src/examples/plonk/mod.rs | 4 +- crates/prover/src/examples/poseidon/mod.rs | 4 +- .../prover/src/examples/wide_fibonacci/mod.rs | 12 +- 9 files changed, 68 insertions(+), 80 deletions(-) diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index ce5d4bddb..528d457dc 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -821,7 +821,7 @@ impl, H: MerkleHasher> FriLayerProver { let commitment = self.merkle_tree.root(); // TODO(andrew): Use _evals. let (_evals, decommitment) = self.merkle_tree.decommit( - [(self.evaluation.len().ilog2(), decommit_positions)] + &[(self.evaluation.len().ilog2(), decommit_positions)] .into_iter() .collect(), self.evaluation.values.columns.iter().collect_vec(), diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index e5ae1b266..aed9ffbfe 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -82,7 +82,7 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, } pub fn prove_values( - &self, + self, sampled_points: TreeVec>>>, channel: &mut MC::C, ) -> CommitmentSchemeProof { @@ -134,13 +134,14 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, .iter() .map(|(&log_size, domain)| (log_size, domain.flatten())) .collect(); - tree.decommit(queries) + tree.decommit(&queries) }); let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone()); let decommitments = decommitment_results.map(|(_, d)| d); CommitmentSchemeProof { + commitments: self.roots(), sampled_values, decommitments, queried_values, @@ -152,6 +153,7 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, #[derive(Debug, Serialize, Deserialize)] pub struct CommitmentSchemeProof { + pub commitments: TreeVec, pub sampled_values: TreeVec>>, pub decommitments: TreeVec>, pub queried_values: TreeVec>>, @@ -243,7 +245,7 @@ impl, MC: MerkleChannel> CommitmentTreeProver { /// positions on each column of that size. fn decommit( &self, - queries: BTreeMap>, + queries: &BTreeMap>, ) -> (ColumnVec>, MerkleDecommitment) { let eval_vec = self .evaluations diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index acc7f6c01..53b2265f3 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -1,3 +1,4 @@ +use std::ops::Deref; use std::{array, mem}; use serde::{Deserialize, Serialize}; @@ -9,7 +10,7 @@ use super::backend::BackendForChannel; use super::channel::MerkleChannel; use super::fields::secure_column::SECURE_EXTENSION_DEGREE; use super::fri::FriVerificationError; -use super::pcs::{CommitmentSchemeProof, TreeVec}; +use super::pcs::CommitmentSchemeProof; use super::vcs::ops::MerkleHasher; use crate::core::channel::Channel; use crate::core::circle::CirclePoint; @@ -21,17 +22,11 @@ use crate::core::vcs::hash::Hash; use crate::core::vcs::prover::MerkleDecommitment; use crate::core::vcs::verifier::MerkleVerificationError; -#[derive(Debug, Serialize, Deserialize)] -pub struct StarkProof { - pub commitments: TreeVec, - pub commitment_scheme_proof: CommitmentSchemeProof, -} - #[instrument(skip_all)] pub fn prove, MC: MerkleChannel>( components: &[&dyn ComponentProver], channel: &mut MC::C, - commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>, + mut commitment_scheme: CommitmentSchemeProver<'_, B, MC>, ) -> Result, ProvingError> { let component_provers = ComponentProvers(components.to_vec()); let trace = commitment_scheme.trace(); @@ -59,25 +54,19 @@ pub fn prove, MC: MerkleChannel>( // Prove the trace and composition OODS values, and retrieve them. let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel); - - let sampled_oods_values = &commitment_scheme_proof.sampled_values; - let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap(); + let proof = StarkProof(commitment_scheme_proof); + info!(proof_size_estimate = proof.size_estimate()); // Evaluate composition polynomial at OODS point and check that it matches the trace OODS // values. This is a sanity check. - if composition_oods_eval + if proof.extract_composition_oods_eval().unwrap() != component_provers .components() - .eval_composition_polynomial_at_point(oods_point, sampled_oods_values, random_coeff) + .eval_composition_polynomial_at_point(oods_point, &proof.sampled_values, random_coeff) { return Err(ProvingError::ConstraintsNotSatisfied); } - let proof = StarkProof { - commitments: commitment_scheme.roots(), - commitment_scheme_proof, - }; - info!(proof_size_estimate = proof.size_estimate()); Ok(proof) } @@ -105,42 +94,21 @@ pub fn verify( // Add the composition polynomial mask points. sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); - let sampled_oods_values = &proof.commitment_scheme_proof.sampled_values; - let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| { + let composition_oods_eval = proof.extract_composition_oods_eval().map_err(|_| { VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) })?; if composition_oods_eval != components.eval_composition_polynomial_at_point( oods_point, - sampled_oods_values, + &proof.sampled_values, random_coeff, ) { return Err(VerificationError::OodsNotMatching); } - commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel) -} - -/// Extracts the composition trace evaluation from the mask. -fn extract_composition_eval( - mask: &TreeVec>>, -) -> Result { - let mut composition_cols = mask.last().into_iter().flatten(); - - let coordinate_evals = array::try_from_fn(|_| { - let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; - let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?; - Ok(eval) - })?; - - // Too many columns. - if composition_cols.next().is_some() { - return Err(InvalidOodsSampleStructure); - } - - Ok(SecureField::from_partial_evals(coordinate_evals)) + commitment_scheme.verify_values(sample_points, proof.0, channel) } /// Error when the sampled values have an invalid structure. @@ -172,7 +140,33 @@ pub enum VerificationError { ProofOfWork, } +#[derive(Debug, Serialize, Deserialize)] +pub struct StarkProof(pub CommitmentSchemeProof); + impl StarkProof { + /// Extracts the composition trace Out-Of-Domain-Sample evaluation from the mask. + fn extract_composition_oods_eval(&self) -> Result { + // TODO(andrew): `[.., composition_mask, _quotients_mask]` when add quotients commitment. + let [.., composition_mask] = &**self.sampled_values else { + return Err(InvalidOodsSampleStructure); + }; + + let mut composition_cols = composition_mask.iter(); + + let coordinate_evals = array::try_from_fn(|_| { + let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; + let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?; + Ok(eval) + })?; + + // Too many columns. + if composition_cols.next().is_some() { + return Err(InvalidOodsSampleStructure); + } + + Ok(SecureField::from_partial_evals(coordinate_evals)) + } + /// Returns the estimate size (in bytes) of the proof. pub fn size_estimate(&self) -> usize { SizeEstimate::size_estimate(self) @@ -180,12 +174,10 @@ impl StarkProof { /// Returns size estimates (in bytes) for different parts of the proof. pub fn size_breakdown_estimate(&self) -> StarkProofSizeBreakdown { - let Self { - commitments, - commitment_scheme_proof, - } = self; + let Self(commitment_scheme_proof) = self; let CommitmentSchemeProof { + commitments, sampled_values, decommitments, queried_values, @@ -221,6 +213,14 @@ impl StarkProof { } } +impl Deref for StarkProof { + type Target = CommitmentSchemeProof; + + fn deref(&self) -> &CommitmentSchemeProof { + &self.0 + } +} + /// Size estimate (in bytes) for different parts of the proof. pub struct StarkProofSizeBreakdown { pub oods_samples: usize, @@ -298,13 +298,15 @@ impl SizeEstimate for FriProof { impl SizeEstimate for CommitmentSchemeProof { fn size_estimate(&self) -> usize { let Self { + commitments, sampled_values, decommitments, queried_values, proof_of_work, fri_proof, } = self; - sampled_values.size_estimate() + commitments.size_estimate() + + sampled_values.size_estimate() + decommitments.size_estimate() + queried_values.size_estimate() + mem::size_of_val(proof_of_work) @@ -314,11 +316,8 @@ impl SizeEstimate for CommitmentSchemeProof { impl SizeEstimate for StarkProof { fn size_estimate(&self) -> usize { - let Self { - commitments, - commitment_scheme_proof, - } = self; - commitments.size_estimate() + commitment_scheme_proof.size_estimate() + let Self(commitment_scheme_proof) = self; + commitment_scheme_proof.size_estimate() } } diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index 6312de114..136821079 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -75,18 +75,9 @@ impl, H: MerkleHasher> MerkleProver { /// * A `MerkleDecommitment` containing the hash and column witnesses. pub fn decommit( &self, - queries_per_log_size: BTreeMap>, + queries_per_log_size: &BTreeMap>, columns: Vec<&Col>, ) -> (ColumnVec>, MerkleDecommitment) { - // Check that queries are sorted and deduped. - // TODO(andrew): Consider using a Queries struct to prevent this. - for queries in queries_per_log_size.values() { - assert!( - queries.windows(2).all(|w| w[0] < w[1]), - "Queries are not sorted." - ); - } - // Prepare output buffers. let mut queried_values_by_layer = vec![]; let mut decommitment = MerkleDecommitment::empty(); diff --git a/crates/prover/src/core/vcs/test_utils.rs b/crates/prover/src/core/vcs/test_utils.rs index 8fc535b05..b92f9e971 100644 --- a/crates/prover/src/core/vcs/test_utils.rs +++ b/crates/prover/src/core/vcs/test_utils.rs @@ -50,7 +50,7 @@ where queries.insert(log_size, layer_queries); } - let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec()); + let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec()); let verifier = MerkleVerifier { root: merkle.root(), diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index ec4de8c0b..fa918b8b4 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -250,7 +250,7 @@ where // Setup protocol. let channel = &mut MC::C::default(); - let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + let mut commitment_scheme = CommitmentSchemeProver::new(config, &twiddles); let span = span!(Level::INFO, "Trace").entered(); diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 2832d0cca..803396899 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -189,8 +189,8 @@ pub fn prove_fibonacci_plonk( // Setup protocol. let channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + let mut commitment_scheme = + CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); // Trace. let span = span!(Level::INFO, "Trace").entered(); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 269214ca1..66635841a 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -347,8 +347,8 @@ pub fn prove_poseidon( // Setup protocol. let channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + let mut commitment_scheme = + CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); // Trace. let span = span!(Level::INFO, "Trace").entered(); diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 2d9672b5e..23d9c7e20 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -180,10 +180,8 @@ mod tests { // Setup protocol. let prover_channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::::new( - config, &twiddles, - ); + let mut commitment_scheme = + CommitmentSchemeProver::::new(config, &twiddles); // Trace. let trace = generate_test_trace(log_n_instances); @@ -232,10 +230,8 @@ mod tests { // Setup protocol. let prover_channel = &mut Poseidon252Channel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::::new( - config, &twiddles, - ); + let mut commitment_scheme = + CommitmentSchemeProver::::new(config, &twiddles); // Trace. let trace = generate_test_trace(LOG_N_INSTANCES);