diff --git a/crates/prover/src/constraint_framework/expr/evaluator.rs b/crates/prover/src/constraint_framework/expr/evaluator.rs index 404ef600e..e3230f131 100644 --- a/crates/prover/src/constraint_framework/expr/evaluator.rs +++ b/crates/prover/src/constraint_framework/expr/evaluator.rs @@ -4,7 +4,8 @@ use super::{BaseExpr, ExtExpr}; use crate::constraint_framework::expr::ColumnExpr; use crate::constraint_framework::preprocessed_columns::PreProcessedColumnId; use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; -use crate::core::fields::m31; +use crate::core::fields::m31::{self, M31}; +use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; pub struct FormalLogupAtRow { @@ -14,6 +15,7 @@ pub struct FormalLogupAtRow { pub fracs: Vec>, pub is_finalized: bool, pub is_first: BaseExpr, + pub cumsum_shift: ExtExpr, pub log_size: u32, } @@ -29,12 +31,17 @@ impl FormalLogupAtRow { Self { interaction, // TODO(alont): Should these be Expr::SecureField? - total_sum: ExtExpr::Param(total_sum_name), + total_sum: ExtExpr::Param(total_sum_name.clone()), claimed_sum: has_partial_sum .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), fracs: vec![], is_finalized: true, is_first: BaseExpr::zero(), + cumsum_shift: ExtExpr::Param(total_sum_name) + * BaseExpr::Inv(Box::new(BaseExpr::pow( + &BaseExpr::Const(M31(2)), + log_size as u128, + ))), log_size, } } @@ -207,8 +214,8 @@ mod tests { \ let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0, trace_2_column_6_offset_0]) \ - - (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1]) \ - - ((total_sum) * (preprocessed_is_first_16)))) \ + - (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1])) \ + + (total_sum) * (qm31(32768, 0, 0, 0))) \ * (intermediate1) \ - (qm31(1, 0, 0, 0));" .to_string(); diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index e2c2dff3c..8432e2470 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -5,7 +5,7 @@ use num_traits::{One, Zero}; use super::EvalAtRow; use crate::core::backend::simd::column::SecureColumn; -use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; @@ -42,8 +42,8 @@ impl LogupSumsExt for LogupSums { pub struct LogupAtRow { /// The index of the interaction used for the cumulative sum columns. pub interaction: usize, - /// The total sum of all the fractions. - pub total_sum: SecureField, + /// The total sum of all the fractions divided by n_rows. + pub cumsum_shift: SecureField, /// The claimed sum of the relevant fractions. /// This is used for padding the component with default rows. Padding should be in bit-reverse. /// None if the claimed_sum is the total_sum. @@ -51,9 +51,6 @@ pub struct LogupAtRow { /// The evaluation of the last cumulative sum column. pub fracs: Vec>, pub is_finalized: bool, - /// The value of the `is_first` constant column at current row. - /// See [`super::preprocessed_columns::IsFirst`]. - pub is_first: E::F, pub log_size: u32, } @@ -69,13 +66,14 @@ impl LogupAtRow { claimed_sum: Option, log_size: u32, ) -> Self { + // TODO(ShaharS): remove once claimed sum at internal index is supported. + assert!(claimed_sum.is_none(), "Partial prefix-sum is not supported"); Self { interaction, - total_sum, + cumsum_shift: total_sum / BaseField::from_u32_unchecked(1 << log_size), claimed_sum, fracs: vec![], is_finalized: true, - is_first: E::F::zero(), log_size, } } @@ -84,11 +82,10 @@ impl LogupAtRow { pub fn dummy() -> Self { Self { interaction: 100, - total_sum: SecureField::one(), + cumsum_shift: SecureField::one(), claimed_sum: None, fracs: vec![], is_finalized: true, - is_first: E::F::zero(), log_size: 10, } } @@ -183,14 +180,46 @@ impl LogupTraceGenerator { } /// Finalize the trace. Returns the trace and the total sum of the last column. + /// The last column is shifted by the cumsum_shift. pub fn finalize_last( - self, + mut self, ) -> ( ColumnVec>, SecureField, ) { - let log_size = self.log_size; - let (trace, [total_sum]) = self.finalize_at([(1 << log_size) - 1]); + let mut last_col_coords = self.trace.pop().unwrap().columns; + + // Compute cumsum_shift. + let coordinate_sums = last_col_coords.each_ref().map(|c| { + c.data + .iter() + .copied() + .sum::() + .pointwise_sum() + }); + let total_sum = SecureField::from_m31_array(coordinate_sums); + let cumsum_shift = total_sum / BaseField::from_u32_unchecked(1 << self.log_size); + let packed_cumsum_shift = PackedSecureField::broadcast(cumsum_shift); + + last_col_coords.iter_mut().enumerate().for_each(|(i, c)| { + c.data + .iter_mut() + .for_each(|x| *x -= packed_cumsum_shift.into_packed_m31s()[i]) + }); + let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum); + let secure_prefix_sum = SecureColumnByCoords { + columns: coord_prefix_sum, + }; + self.trace.push(secure_prefix_sum); + let trace = self + .trace + .into_iter() + .flat_map(|eval| { + eval.columns.map(|col| { + CircleEvaluation::new(CanonicCoset::new(self.log_size).circle_domain(), col) + }) + }) + .collect_vec(); (trace, total_sum) } diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index ca037ec18..f423e459a 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -172,12 +172,6 @@ macro_rules! logup_proxy { () => { fn write_logup_frac(&mut self, fraction: Fraction) { if self.logup.fracs.is_empty() { - self.logup.is_first = self.get_preprocessed_column( - crate::constraint_framework::preprocessed_columns::IsFirst::new( - self.logup.log_size, - ) - .id(), - ); self.logup.is_finalized = false; } self.logup.fracs.push(fraction.clone()); @@ -188,6 +182,11 @@ macro_rules! logup_proxy { /// `batching` should contain the batch into which every logup entry should be inserted. fn finalize_logup_batched(&mut self, batching: &crate::constraint_framework::Batching) { assert!(!self.logup.is_finalized, "LogupAtRow was already finalized"); + + assert!( + self.logup.claimed_sum.is_none(), + "Partial prefix-sum is not supported" + ); assert_eq!( batching.len(), self.logup.fracs.len(), @@ -227,35 +226,16 @@ macro_rules! logup_proxy { } let frac: Fraction<_, _> = fracs_by_batch[&last_batch].clone().into_iter().sum(); + let [prev_row_cumsum, cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]); + + let diff = cur_cumsum - prev_row_cumsum - prev_col_cumsum.clone(); + // Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift. + // This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint + // uniform - apply on all rows. + let fixed_diff = diff + self.logup.cumsum_shift.clone(); - // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted - // offset from the is_first column when constant columns are supported. - let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum.clone() { - Some((claimed_sum, claimed_row_index)) => { - let [prev_row_cumsum, cur_cumsum, claimed_cumsum] = self - .next_extension_interaction_mask( - self.logup.interaction, - [-1, 0, claimed_row_index as isize], - ); - - // Constrain that the claimed_sum in case that it is not equal to the total_sum. - self.add_constraint( - (claimed_cumsum - claimed_sum) * self.logup.is_first.clone(), - ); - (cur_cumsum, prev_row_cumsum) - } - None => { - let [prev_row_cumsum, cur_cumsum] = - self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]); - (cur_cumsum, prev_row_cumsum) - } - }; - // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. - let fixed_prev_row_cumsum = - prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone(); - let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone(); - - self.add_constraint(diff * frac.denominator - frac.numerator); + self.add_constraint(fixed_diff * frac.denominator - frac.numerator); self.logup.is_finalized = true; } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 841079c8d..9926e7c81 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -3,7 +3,7 @@ use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId}; +use crate::constraint_framework::preprocessed_columns::PreProcessedColumnId; use crate::constraint_framework::{ assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, TraceLocationAllocator, @@ -192,7 +192,6 @@ pub fn prove_fibonacci_plonk( // Preprocessed trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - let is_first = IsFirst::new(log_n_rows).gen_column_simd(); let mut constant_trace = [ circuit.a_wire.clone(), circuit.b_wire.clone(), @@ -207,7 +206,6 @@ pub fn prove_fibonacci_plonk( ) }) .collect_vec(); - constant_trace.push(is_first); let constants_trace_location = tree_builder.extend_evals(constant_trace); tree_builder.commit(channel); span.exit(); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 3f45e329c..69fc0aa91 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -7,7 +7,6 @@ use num_traits::One; use tracing::{info, span, Level}; use crate::constraint_framework::logup::LogupTraceGenerator; -use crate::constraint_framework::preprocessed_columns::IsFirst; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, Relation, RelationEntry, TraceLocationAllocator, @@ -349,7 +348,7 @@ pub fn prove_poseidon( // Preprocessed trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - let constant_trace = vec![IsFirst::new(log_n_rows).gen_column_simd()]; + let constant_trace = vec![]; tree_builder.extend_evals(constant_trace); tree_builder.commit(channel); span.exit(); @@ -397,7 +396,6 @@ mod tests { use num_traits::One; use crate::constraint_framework::assert_constraints; - use crate::constraint_framework::preprocessed_columns::IsFirst; use crate::core::air::Component; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; @@ -472,11 +470,7 @@ mod tests { let (trace1, total_sum) = gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements); - let traces = TreeVec::new(vec![ - vec![IsFirst::new(LOG_N_ROWS).gen_column_simd()], - trace0, - trace1, - ]); + let traces = TreeVec::new(vec![vec![], trace0, trace1]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); assert_constraints( diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index f7afc675d..3b0cc57fa 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -83,7 +83,7 @@ impl StateMachineStatement0 { .map_cols(|_| self.m), ]; let mut log_sizes = TreeVec::concat_cols(sizes.into_iter()); - log_sizes[PREPROCESSED_TRACE_IDX] = vec![self.n, self.m]; + log_sizes[PREPROCESSED_TRACE_IDX] = vec![]; log_sizes } pub fn mix_into(&self, channel: &mut impl Channel) { diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index d892621af..5b81762bf 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -11,7 +11,6 @@ use components::{ use gen::{gen_interaction_trace, gen_trace}; use itertools::{chain, Itertools}; -use crate::constraint_framework::preprocessed_columns::IsFirst; use crate::constraint_framework::TraceLocationAllocator; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::SimdBackend; @@ -54,18 +53,6 @@ pub fn prove_state_machine( let mut commitment_scheme = CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); - let preprocessed_columns = [IsFirst::new(x_axis_log_rows), IsFirst::new(y_axis_log_rows)]; - - // Preprocessed trace. - let preprocessed_trace = preprocessed_columns - .into_iter() - .map(|col| col.gen_column_simd()) - .collect(); - let preprocessed_columns = [ - IsFirst::new(x_axis_log_rows).id(), - IsFirst::new(y_axis_log_rows).id(), - ]; - // Trace. let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0); let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1); @@ -76,7 +63,7 @@ pub fn prove_state_machine( false => None, true => Some(RelationSummary::summarize_relations( &track_state_machine_relations( - &TreeVec(vec![&preprocessed_trace, &trace]), + &TreeVec(vec![&vec![], &trace]), x_axis_log_rows, y_axis_log_rows, ), @@ -85,7 +72,6 @@ pub fn prove_state_machine( // Commitments. let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(preprocessed_trace); tree_builder.commit(channel); let stmt0 = StateMachineStatement0 { @@ -138,8 +124,6 @@ pub fn prove_state_machine( (total_sum_op1, None), ); - tree_span_provider.validate_preprocessed_columns(&preprocessed_columns); - let components = StateMachineComponents { component0, component1, @@ -204,7 +188,6 @@ mod tests { use super::gen::{gen_interaction_trace, gen_trace}; use super::{prove_state_machine, verify_state_machine}; use crate::constraint_framework::expr::ExprEvaluator; - use crate::constraint_framework::preprocessed_columns::IsFirst; use crate::constraint_framework::{ assert_constraints, FrameworkEval, Relation, TraceLocationAllocator, }; @@ -236,11 +219,7 @@ mod tests { (total_sum, None), ); - let trace = TreeVec::new(vec![ - vec![IsFirst::new(log_n_rows).gen_column_simd()], - trace, - interaction_trace, - ]); + let trace = TreeVec::new(vec![vec![], trace, interaction_trace]); let trace_polys = trace.map_cols(|c| c.interpolate()); assert_constraints( &trace_polys, @@ -352,7 +331,7 @@ mod tests { (total_sum, None), ); - let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); + let eval = component.evaluate(ExprEvaluator::new(log_n_rows, false)); let expected = "let intermediate0 = (StateMachineElements_alpha0) * (trace_1_column_0_offset_0) \ + (StateMachineElements_alpha1) * (trace_1_column_1_offset_0) \ - (StateMachineElements_z); @@ -363,18 +342,9 @@ mod tests { - (StateMachineElements_z); \ - let constraint_0 = (QM31Impl::from_partial_evals([\ - trace_2_column_2_offset_claimed_sum, \ - trace_2_column_3_offset_claimed_sum, \ - trace_2_column_4_offset_claimed_sum, \ - trace_2_column_5_offset_claimed_sum\ - ]) - (claimed_sum)) \ - * (preprocessed_is_first_8); - -\ - let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_2_offset_0, trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0]) \ - - (QM31Impl::from_partial_evals([trace_2_column_2_offset_neg_1, trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1]) \ - - ((total_sum) * (preprocessed_is_first_8)))\ + let constraint_0 = (QM31Impl::from_partial_evals([trace_2_column_2_offset_0, trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0]) \ + - (QM31Impl::from_partial_evals([trace_2_column_2_offset_neg_1, trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1])) \ + + (total_sum) * (qm31(8388608, 0, 0, 0))\ ) \ * ((intermediate0) * (intermediate1)) \ - (intermediate1 - (intermediate0));"