Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logup cumsum constraint with cumsum_shift #978

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ColumnExpr;
use crate::constraint_framework::preprocessed_columns::PreProcessedColumnId;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31;
use crate::core::fields::m31::{self, M31};
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

pub struct FormalLogupAtRow {
Expand All @@ -14,6 +15,7 @@ pub struct FormalLogupAtRow {
pub fracs: Vec<Fraction<ExtExpr, ExtExpr>>,
pub is_finalized: bool,
pub is_first: BaseExpr,
pub cumsum_shift: ExtExpr,
pub log_size: u32,
}

Expand All @@ -29,12 +31,17 @@ impl FormalLogupAtRow {
Self {
interaction,
// TODO(alont): Should these be Expr::SecureField?
total_sum: ExtExpr::Param(total_sum_name),
total_sum: ExtExpr::Param(total_sum_name.clone()),
claimed_sum: has_partial_sum
.then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)),
fracs: vec![],
is_finalized: true,
is_first: BaseExpr::zero(),
cumsum_shift: ExtExpr::Param(total_sum_name)
* BaseExpr::Inv(Box::new(BaseExpr::pow(
&BaseExpr::Const(M31(2)),
log_size as u128,
))),
log_size,
}
}
Expand Down Expand Up @@ -207,8 +214,8 @@ mod tests {

\
let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0, trace_2_column_6_offset_0]) \
- (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1]) \
- ((total_sum) * (preprocessed_is_first_16)))) \
- (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1])) \
+ (total_sum) * (qm31(32768, 0, 0, 0))) \
* (intermediate1) \
- (qm31(1, 0, 0, 0));"
.to_string();
Expand Down
55 changes: 42 additions & 13 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use num_traits::{One, Zero};

use super::EvalAtRow;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
Expand Down Expand Up @@ -42,18 +42,15 @@ impl LogupSumsExt for LogupSums {
pub struct LogupAtRow<E: EvalAtRow> {
/// The index of the interaction used for the cumulative sum columns.
pub interaction: usize,
/// The total sum of all the fractions.
pub total_sum: SecureField,
/// The total sum of all the fractions divided by n_rows.
pub cumsum_shift: SecureField,
/// The claimed sum of the relevant fractions.
/// This is used for padding the component with default rows. Padding should be in bit-reverse.
/// None if the claimed_sum is the total_sum.
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub fracs: Vec<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::IsFirst`].
pub is_first: E::F,
pub log_size: u32,
}

Expand All @@ -69,13 +66,14 @@ impl<E: EvalAtRow> LogupAtRow<E> {
claimed_sum: Option<ClaimedPrefixSum>,
log_size: u32,
) -> Self {
// TODO(ShaharS): remove once claimed sum at internal index is supported.
assert!(claimed_sum.is_none(), "Partial prefix-sum is not supported");
Self {
interaction,
total_sum,
cumsum_shift: total_sum / BaseField::from_u32_unchecked(1 << log_size),
claimed_sum,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size,
}
}
Expand All @@ -84,11 +82,10 @@ impl<E: EvalAtRow> LogupAtRow<E> {
pub fn dummy() -> Self {
Self {
interaction: 100,
total_sum: SecureField::one(),
cumsum_shift: SecureField::one(),
claimed_sum: None,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size: 10,
}
}
Expand Down Expand Up @@ -183,14 +180,46 @@ impl LogupTraceGenerator {
}

/// Finalize the trace. Returns the trace and the total sum of the last column.
/// The last column is shifted by the cumsum_shift.
pub fn finalize_last(
self,
mut self,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
) {
let log_size = self.log_size;
let (trace, [total_sum]) = self.finalize_at([(1 << log_size) - 1]);
let mut last_col_coords = self.trace.pop().unwrap().columns;

// Compute cumsum_shift.
let coordinate_sums = last_col_coords.each_ref().map(|c| {
c.data
.iter()
.copied()
.sum::<PackedBaseField>()
.pointwise_sum()
});
let total_sum = SecureField::from_m31_array(coordinate_sums);
let cumsum_shift = total_sum / BaseField::from_u32_unchecked(1 << self.log_size);
let packed_cumsum_shift = PackedSecureField::broadcast(cumsum_shift);

last_col_coords.iter_mut().enumerate().for_each(|(i, c)| {
c.data
.iter_mut()
.for_each(|x| *x -= packed_cumsum_shift.into_packed_m31s()[i])
});
let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum);
let secure_prefix_sum = SecureColumnByCoords {
columns: coord_prefix_sum,
};
self.trace.push(secure_prefix_sum);
let trace = self
.trace
.into_iter()
.flat_map(|eval| {
eval.columns.map(|col| {
CircleEvaluation::new(CanonicCoset::new(self.log_size).circle_domain(), col)
})
})
.collect_vec();
(trace, total_sum)
}

Expand Down
48 changes: 14 additions & 34 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,6 @@ macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::IsFirst::new(
self.logup.log_size,
)
.id(),
);
self.logup.is_finalized = false;
}
self.logup.fracs.push(fraction.clone());
Expand All @@ -188,6 +182,11 @@ macro_rules! logup_proxy {
/// `batching` should contain the batch into which every logup entry should be inserted.
fn finalize_logup_batched(&mut self, batching: &crate::constraint_framework::Batching) {
assert!(!self.logup.is_finalized, "LogupAtRow was already finalized");

assert!(
self.logup.claimed_sum.is_none(),
"Partial prefix-sum is not supported"
);
assert_eq!(
batching.len(),
self.logup.fracs.len(),
Expand Down Expand Up @@ -227,35 +226,16 @@ macro_rules! logup_proxy {
}

let frac: Fraction<_, _> = fracs_by_batch[&last_batch].clone().into_iter().sum();
let [prev_row_cumsum, cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]);

let diff = cur_cumsum - prev_row_cumsum - prev_col_cumsum.clone();
// Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift.
// This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint
// uniform - apply on all rows.
let fixed_diff = diff + self.logup.cumsum_shift.clone();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted
// offset from the is_first column when constant columns are supported.
let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum.clone() {
Some((claimed_sum, claimed_row_index)) => {
let [prev_row_cumsum, cur_cumsum, claimed_cumsum] = self
.next_extension_interaction_mask(
self.logup.interaction,
[-1, 0, claimed_row_index as isize],
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
self.add_constraint(
(claimed_cumsum - claimed_sum) * self.logup.is_first.clone(),
);
(cur_cumsum, prev_row_cumsum)
}
None => {
let [prev_row_cumsum, cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [-1, 0]);
(cur_cumsum, prev_row_cumsum)
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum =
prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone();
let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone();

self.add_constraint(diff * frac.denominator - frac.numerator);
self.add_constraint(fixed_diff * frac.denominator - frac.numerator);

self.logup.is_finalized = true;
}
Expand Down
4 changes: 1 addition & 3 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use num_traits::One;
use tracing::{span, Level};

use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId};
use crate::constraint_framework::preprocessed_columns::PreProcessedColumnId;
use crate::constraint_framework::{
assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
TraceLocationAllocator,
Expand Down Expand Up @@ -192,7 +192,6 @@ pub fn prove_fibonacci_plonk(
// Preprocessed trace.
let span = span!(Level::INFO, "Constant").entered();
let mut tree_builder = commitment_scheme.tree_builder();
let is_first = IsFirst::new(log_n_rows).gen_column_simd();
let mut constant_trace = [
circuit.a_wire.clone(),
circuit.b_wire.clone(),
Expand All @@ -207,7 +206,6 @@ pub fn prove_fibonacci_plonk(
)
})
.collect_vec();
constant_trace.push(is_first);
let constants_trace_location = tree_builder.extend_evals(constant_trace);
tree_builder.commit(channel);
span.exit();
Expand Down
10 changes: 2 additions & 8 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use num_traits::One;
use tracing::{info, span, Level};

use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::preprocessed_columns::IsFirst;
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, Relation, RelationEntry,
TraceLocationAllocator,
Expand Down Expand Up @@ -349,7 +348,7 @@ pub fn prove_poseidon(
// Preprocessed trace.
let span = span!(Level::INFO, "Constant").entered();
let mut tree_builder = commitment_scheme.tree_builder();
let constant_trace = vec![IsFirst::new(log_n_rows).gen_column_simd()];
let constant_trace = vec![];
tree_builder.extend_evals(constant_trace);
tree_builder.commit(channel);
span.exit();
Expand Down Expand Up @@ -397,7 +396,6 @@ mod tests {
use num_traits::One;

use crate::constraint_framework::assert_constraints;
use crate::constraint_framework::preprocessed_columns::IsFirst;
use crate::core::air::Component;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -472,11 +470,7 @@ mod tests {
let (trace1, total_sum) =
gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements);

let traces = TreeVec::new(vec![
vec![IsFirst::new(LOG_N_ROWS).gen_column_simd()],
trace0,
trace1,
]);
let traces = TreeVec::new(vec![vec![], trace0, trace1]);
let trace_polys =
traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec());
assert_constraints(
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl StateMachineStatement0 {
.map_cols(|_| self.m),
];
let mut log_sizes = TreeVec::concat_cols(sizes.into_iter());
log_sizes[PREPROCESSED_TRACE_IDX] = vec![self.n, self.m];
log_sizes[PREPROCESSED_TRACE_IDX] = vec![];
log_sizes
}
pub fn mix_into(&self, channel: &mut impl Channel) {
Expand Down
Loading
Loading