Skip to content

Commit

Permalink
Logup cumsum constraint with cumsum_shift
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Jan 13, 2025
1 parent 1f7dbdd commit bbbe8a7
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 92 deletions.
15 changes: 11 additions & 4 deletions crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ColumnExpr;
use crate::constraint_framework::preprocessed_columns::PreprocessedColumn;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31;
use crate::core::fields::m31::{self, M31};
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

pub struct FormalLogupAtRow {
Expand All @@ -14,6 +15,7 @@ pub struct FormalLogupAtRow {
pub fracs: Vec<Fraction<ExtExpr, ExtExpr>>,
pub is_finalized: bool,
pub is_first: BaseExpr,
pub cumsum_shift: ExtExpr,
pub log_size: u32,
}

Expand All @@ -29,12 +31,17 @@ impl FormalLogupAtRow {
Self {
interaction,
// TODO(alont): Should these be Expr::SecureField?
total_sum: ExtExpr::Param(total_sum_name),
total_sum: ExtExpr::Param(total_sum_name.clone()),
claimed_sum: has_partial_sum
.then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)),
fracs: vec![],
is_finalized: true,
is_first: BaseExpr::zero(),
cumsum_shift: ExtExpr::Param(total_sum_name)
* BaseExpr::Inv(Box::new(BaseExpr::pow(
&BaseExpr::Const(M31(2)),
log_size as u128,
))),
log_size,
}
}
Expand Down Expand Up @@ -207,8 +214,8 @@ mod tests {
\
let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0, trace_2_column_6_offset_0]) \
- (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1]) \
- ((total_sum) * (preprocessed_is_first)))) \
- (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1])) \
+ (total_sum) * (qm31(32768, 0, 0, 0))) \
* (intermediate1) \
- (qm31(1, 0, 0, 0));"
.to_string();
Expand Down
55 changes: 42 additions & 13 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use num_traits::{One, Zero};

use super::EvalAtRow;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
Expand Down Expand Up @@ -42,18 +42,15 @@ impl LogupSumsExt for LogupSums {
pub struct LogupAtRow<E: EvalAtRow> {
/// The index of the interaction used for the cumulative sum columns.
pub interaction: usize,
/// The total sum of all the fractions.
pub total_sum: SecureField,
/// The total sum of all the fractions divided by n_rows.
pub cumsum_shift: SecureField,
/// The claimed sum of the relevant fractions.
/// This is used for padding the component with default rows. Padding should be in bit-reverse.
/// None if the claimed_sum is the total_sum.
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub fracs: Vec<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
pub is_first: E::F,
pub log_size: u32,
}

Expand All @@ -69,13 +66,14 @@ impl<E: EvalAtRow> LogupAtRow<E> {
claimed_sum: Option<ClaimedPrefixSum>,
log_size: u32,
) -> Self {
// TODO(ShaharS): remove once claimed sum at internal index is supported.
assert!(claimed_sum.is_none(), "Partial prefix-sum is not supported");
Self {
interaction,
total_sum,
cumsum_shift: total_sum / BaseField::from_u32_unchecked(1 << log_size),
claimed_sum,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size,
}
}
Expand All @@ -84,11 +82,10 @@ impl<E: EvalAtRow> LogupAtRow<E> {
pub fn dummy() -> Self {
Self {
interaction: 100,
total_sum: SecureField::one(),
cumsum_shift: SecureField::one(),
claimed_sum: None,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size: 10,
}
}
Expand Down Expand Up @@ -183,14 +180,46 @@ impl LogupTraceGenerator {
}

/// Finalize the trace. Returns the trace and the total sum of the last column.
/// The last column is shifted by the cumsum_shift.
pub fn finalize_last(
self,
mut self,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
) {
let log_size = self.log_size;
let (trace, [total_sum]) = self.finalize_at([(1 << log_size) - 1]);
let mut last_col_coords = self.trace.pop().unwrap().columns;

// Compute cumsum_shift.
let coordinate_sums = last_col_coords.each_ref().map(|c| {
c.data
.iter()
.copied()
.sum::<PackedBaseField>()
.pointwise_sum()
});
let total_sum = SecureField::from_m31_array(coordinate_sums);
let cumsum_shift = total_sum / BaseField::from_u32_unchecked(1 << self.log_size);
let packed_cumsum_shift = PackedSecureField::broadcast(cumsum_shift);

last_col_coords.iter_mut().enumerate().for_each(|(i, c)| {
c.data
.iter_mut()
.for_each(|x| *x -= packed_cumsum_shift.into_packed_m31s()[i])
});
let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum);
let secure_prefix_sum = SecureColumnByCoords {
columns: coord_prefix_sum,
};
self.trace.push(secure_prefix_sum);
let trace = self
.trace
.into_iter()
.flat_map(|eval| {
eval.columns.map(|col| {
CircleEvaluation::new(CanonicCoset::new(self.log_size).circle_domain(), col)
})
})
.collect_vec();
(trace, total_sum)
}

Expand Down
47 changes: 14 additions & 33 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,6 @@ macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst(
self.logup.log_size,
),
);
self.logup.is_finalized = false;
}
self.logup.fracs.push(fraction.clone());
Expand All @@ -187,6 +182,11 @@ macro_rules! logup_proxy {
/// `batching` should contain the batch into which every logup entry should be inserted.
fn finalize_logup_batched(&mut self, batching: &crate::constraint_framework::Batching) {
assert!(!self.logup.is_finalized, "LogupAtRow was already finalized");

assert!(
self.logup.claimed_sum.is_none(),
"Partial prefix-sum is not supported"
);
assert_eq!(
batching.len(),
self.logup.fracs.len(),
Expand Down Expand Up @@ -226,35 +226,16 @@ macro_rules! logup_proxy {
}

let frac: Fraction<_, _> = fracs_by_batch[&last_batch].clone().into_iter().sum();
let [prev_row_cumsum, cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]);

let diff = cur_cumsum - prev_row_cumsum - prev_col_cumsum.clone();
// Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift.
// This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint
// uniform - apply on all rows.
let fixed_diff = diff + self.logup.cumsum_shift.clone();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted
// offset from the is_first column when constant columns are supported.
let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum.clone() {
Some((claimed_sum, claimed_row_index)) => {
let [prev_row_cumsum, cur_cumsum, claimed_cumsum] = self
.next_extension_interaction_mask(
self.logup.interaction,
[-1, 0, claimed_row_index as isize],
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
self.add_constraint(
(claimed_cumsum - claimed_sum) * self.logup.is_first.clone(),
);
(cur_cumsum, prev_row_cumsum)
}
None => {
let [prev_row_cumsum, cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]);
(cur_cumsum, prev_row_cumsum)
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum =
prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone();
let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone();

self.add_constraint(diff * frac.denominator - frac.numerator);
self.add_constraint(fixed_diff * frac.denominator - frac.numerator);

self.logup.is_finalized = true;
}
Expand Down
4 changes: 1 addition & 3 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use num_traits::One;
use tracing::{span, Level};

use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn};
use crate::constraint_framework::preprocessed_columns::PreprocessedColumn;
use crate::constraint_framework::{
assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
TraceLocationAllocator,
Expand Down Expand Up @@ -192,7 +192,6 @@ pub fn prove_fibonacci_plonk(
// Preprocessed trace.
let span = span!(Level::INFO, "Constant").entered();
let mut tree_builder = commitment_scheme.tree_builder();
let is_first = gen_is_first(log_n_rows);
let mut constant_trace = [
circuit.a_wire.clone(),
circuit.b_wire.clone(),
Expand All @@ -207,7 +206,6 @@ pub fn prove_fibonacci_plonk(
)
})
.collect_vec();
constant_trace.push(is_first);
let constants_trace_location = tree_builder.extend_evals(constant_trace);
tree_builder.commit(channel);
span.exit();
Expand Down
6 changes: 2 additions & 4 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use num_traits::One;
use tracing::{info, span, Level};

use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, Relation, RelationEntry,
TraceLocationAllocator,
Expand Down Expand Up @@ -349,7 +348,7 @@ pub fn prove_poseidon(
// Preprocessed trace.
let span = span!(Level::INFO, "Constant").entered();
let mut tree_builder = commitment_scheme.tree_builder();
let constant_trace = vec![gen_is_first(log_n_rows)];
let constant_trace = vec![];
tree_builder.extend_evals(constant_trace);
tree_builder.commit(channel);
span.exit();
Expand Down Expand Up @@ -397,7 +396,6 @@ mod tests {
use num_traits::One;

use crate::constraint_framework::assert_constraints;
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::core::air::Component;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -472,7 +470,7 @@ mod tests {
let (trace1, total_sum) =
gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements);

let traces = TreeVec::new(vec![vec![gen_is_first(LOG_N_ROWS)], trace0, trace1]);
let traces = TreeVec::new(vec![vec![], trace0, trace1]);
let trace_polys =
traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec());
assert_constraints(
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl StateMachineStatement0 {
.map_cols(|_| self.m),
];
let mut log_sizes = TreeVec::concat_cols(sizes.into_iter());
log_sizes[PREPROCESSED_TRACE_IDX] = vec![self.n, self.m];
log_sizes[PREPROCESSED_TRACE_IDX] = vec![];
log_sizes
}
pub fn mix_into(&self, channel: &mut impl Channel) {
Expand Down
40 changes: 6 additions & 34 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ use components::{
use gen::{gen_interaction_trace, gen_trace};
use itertools::{chain, Itertools};

use crate::constraint_framework::preprocessed_columns::{
gen_preprocessed_columns, PreprocessedColumn,
};
use crate::constraint_framework::TraceLocationAllocator;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::SimdBackend;
Expand Down Expand Up @@ -56,14 +53,6 @@ pub fn prove_state_machine(
let mut commitment_scheme =
CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

let preprocessed_columns = [
PreprocessedColumn::IsFirst(x_axis_log_rows),
PreprocessedColumn::IsFirst(y_axis_log_rows),
];

// Preprocessed trace.
let preprocessed_trace = gen_preprocessed_columns(preprocessed_columns.iter());

// Trace.
let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0);
let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1);
Expand All @@ -74,7 +63,7 @@ pub fn prove_state_machine(
false => None,
true => Some(RelationSummary::summarize_relations(
&track_state_machine_relations(
&TreeVec(vec![&preprocessed_trace, &trace]),
&TreeVec(vec![&vec![], &trace]),
x_axis_log_rows,
y_axis_log_rows,
),
Expand All @@ -83,7 +72,6 @@ pub fn prove_state_machine(

// Commitments.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(preprocessed_trace);
tree_builder.commit(channel);

let stmt0 = StateMachineStatement0 {
Expand Down Expand Up @@ -136,8 +124,6 @@ pub fn prove_state_machine(
(total_sum_op1, None),
);

tree_span_provider.validate_preprocessed_columns(&preprocessed_columns);

let components = StateMachineComponents {
component0,
component1,
Expand Down Expand Up @@ -202,7 +188,6 @@ mod tests {
use super::gen::{gen_interaction_trace, gen_trace};
use super::{prove_state_machine, verify_state_machine};
use crate::constraint_framework::expr::ExprEvaluator;
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::constraint_framework::{
assert_constraints, FrameworkEval, Relation, TraceLocationAllocator,
};
Expand Down Expand Up @@ -234,11 +219,7 @@ mod tests {
(total_sum, None),
);

let trace = TreeVec::new(vec![
vec![gen_is_first(log_n_rows)],
trace,
interaction_trace,
]);
let trace = TreeVec::new(vec![vec![], trace, interaction_trace]);
let trace_polys = trace.map_cols(|c| c.interpolate());
assert_constraints(
&trace_polys,
Expand Down Expand Up @@ -350,7 +331,7 @@ mod tests {
(total_sum, None),
);

let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));
let eval = component.evaluate(ExprEvaluator::new(log_n_rows, false));
let expected = "let intermediate0 = (StateMachineElements_alpha0) * (trace_1_column_0_offset_0) \
+ (StateMachineElements_alpha1) * (trace_1_column_1_offset_0) \
- (StateMachineElements_z);
Expand All @@ -361,18 +342,9 @@ mod tests {
- (StateMachineElements_z);
\
let constraint_0 = (QM31Impl::from_partial_evals([\
trace_2_column_2_offset_claimed_sum, \
trace_2_column_3_offset_claimed_sum, \
trace_2_column_4_offset_claimed_sum, \
trace_2_column_5_offset_claimed_sum\
]) - (claimed_sum)) \
* (preprocessed_is_first);
\
let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_2_offset_0, trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0]) \
- (QM31Impl::from_partial_evals([trace_2_column_2_offset_neg_1, trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1]) \
- ((total_sum) * (preprocessed_is_first)))\
let constraint_0 = (QM31Impl::from_partial_evals([trace_2_column_2_offset_0, trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0]) \
- (QM31Impl::from_partial_evals([trace_2_column_2_offset_neg_1, trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1])) \
+ (total_sum) * (qm31(8388608, 0, 0, 0))\
) \
* ((intermediate0) * (intermediate1)) \
- (intermediate1 - (intermediate0));"
Expand Down

0 comments on commit bbbe8a7

Please sign in to comment.