Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for periodic columns in LogUp-GKR #307

Merged
merged 27 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8ac25c7
feat: math utilities needed for sum-check protocol
Al-Kindi-0 Aug 6, 2024
5e06378
feat: add sum-check prover and verifier
Al-Kindi-0 Aug 6, 2024
16389d6
tests: add sanity tests for utils
Al-Kindi-0 Aug 6, 2024
380aa1a
doc: document sumcheck_round
Al-Kindi-0 Aug 6, 2024
7a1a99e
feat: use SmallVec
Al-Kindi-0 Aug 7, 2024
1901066
docs: improve documentation of sum-check
Al-Kindi-0 Aug 8, 2024
8a57216
feat: add remaining functions for sum-check verifier
Al-Kindi-0 Aug 9, 2024
ff9e6fa
chore: move prover into sub-mod
Al-Kindi-0 Aug 9, 2024
7e24f8f
chore: remove utils mod
Al-Kindi-0 Aug 9, 2024
23044e8
chore: remove utils mod
Al-Kindi-0 Aug 9, 2024
ad0497d
chore: move logup evaluator trait to separate file
Al-Kindi-0 Aug 9, 2024
a0272ea
feat: add GKR backend for LogUp-GKR
Al-Kindi-0 Aug 9, 2024
7b8caff
chore: remove old way of handling Lagrange kernel
Al-Kindi-0 Aug 12, 2024
e2b8c12
wip: add s-column constraints
Al-Kindi-0 Aug 12, 2024
b813916
chore: correct header
Al-Kindi-0 Aug 12, 2024
492f247
wip
Al-Kindi-0 Aug 12, 2024
0d664e0
wip: add support for periodic columns in gkr backend
Al-Kindi-0 Aug 14, 2024
8617308
Merge branch 'logup-gkr' into al-gkr-periodic
Al-Kindi-0 Sep 3, 2024
ed781d8
chore: fix post merge issues
Al-Kindi-0 Sep 3, 2024
98c0e71
chore: fix issues
Al-Kindi-0 Sep 3, 2024
807aba1
doc: add comment about periodic values table
Al-Kindi-0 Sep 3, 2024
4e6d3ab
chore: address feedback
Al-Kindi-0 Sep 4, 2024
c93ec35
chore: fix concurrent portion
Al-Kindi-0 Sep 4, 2024
873345a
chore: address feedback
Al-Kindi-0 Sep 5, 2024
d94794a
chore: address feedback
Al-Kindi-0 Sep 9, 2024
b484e04
chore: remove unnecessary mut
Al-Kindi-0 Sep 10, 2024
dec6589
chore: remove unnecessary mut
Al-Kindi-0 Sep 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions air/src/air/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ pub trait LogUpGkrEvaluator: Clone + Sync {
) -> SColumnConstraint<E> {
SColumnConstraint::new(gkr_data, composition_coefficient)
}

/// Returns the periodic values used in the LogUp-GKR statement, either as base field element
/// during circuit evaluation or as extension field element during the run of sum-check for
/// the input layer.
fn build_periodic_values<F, E>(&self) -> PeriodicTable<F>
where
F: FieldElement<BaseField = Self::BaseField>,
E: FieldElement<BaseField = Self::BaseField> + ExtensionOf<F>,
{
let mut table = Vec::new();

let oracles = self.get_oracles();

for oracle in oracles {
if let LogUpGkrOracle::PeriodicValue(values) = oracle {
let values = embed_in_extension(values.to_vec());
table.push(values)
}
}
PeriodicTable { table }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I would rewrite this as
let table: Vec<Vec<F>> = self
    .get_oracles()
    .iter()
    .filter_map(|oracle| {
        if let LogUpGkrOracle::PeriodicValue(values) = oracle {
            Some(values.into_iter().copied().map(F::from).collect())
        } else {
            None
        }
    })
    .collect();

PeriodicTable{ table }
  1. E is not used, and it is redundant, right? It seems like we could always make F::BaseField the basefield, and F the extension field (or F::BaseField and F both be the base field).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, switched to it
Indeed, the generic is redundant now, it is a leftover from a previous iteration. Removed now

}
}

#[derive(Clone, Default)]
Expand Down Expand Up @@ -229,3 +250,72 @@ pub enum LogUpGkrOracle<B: StarkField> {
/// must be a power of 2.
PeriodicValue(Vec<B>),
}

// PERIODIC COLUMNS FOR LOGUP
// =================================================================================================

/// Stores the periodic columns used in a LogUp-GKR statement.
///
/// Each stored periodic column is interpreted as a multi-linear extension polynomial of the column
/// with the given periodic values. Due to the periodic nature of the values, storing, binding of
/// an argument and evaluating the said multi-linear extension can be all done linearly in the size
/// of the smallest cycle defining the periodic values. Hence we only store the values of this
/// smallest cycle. The cycle is assumed throughout to be a power of 2.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
pub struct PeriodicTable<E: FieldElement> {
pub table: Vec<Vec<E>>,
}

impl<E> PeriodicTable<E>
where
E: FieldElement,
{
pub fn new(table: Vec<Vec<E::BaseField>>) -> Self {
let mut result = vec![];
for col in table.iter() {
let res = embed_in_extension(col.to_vec());
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
result.push(res)
}

Self { table: result }
}

pub fn num_columns(&self) -> usize {
self.table.len()
}

pub fn table(&self) -> &[Vec<E>] {
&self.table
}

pub fn get_periodic_values_at(&self, row: usize, values: &mut [E]) {
self.table
.iter()
.zip(values.iter_mut())
.for_each(|(col, value)| *value = col[row % col.len()])
}
irakliyk marked this conversation as resolved.
Show resolved Hide resolved

pub fn bind_least_significant_variable(&mut self, round_challenge: E) {
for col in self.table.iter_mut() {
if col.len() > 1 {
let num_evals = col.len() >> 1;
for i in 0..num_evals {
col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]);
}
col.truncate(num_evals)
}
}
}
}

// HELPER
// =================================================================================================

fn embed_in_extension<E: FieldElement>(values: Vec<E::BaseField>) -> Vec<E> {
let mut res = vec![];
for v in values {
res.push(E::from(v))
}

res
}
2 changes: 1 addition & 1 deletion air/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use logup_gkr::PhantomLogUpGkrEval;
pub use logup_gkr::{
LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame,
LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator,
LogUpGkrOracle,
LogUpGkrOracle, PeriodicTable,
};

mod coefficients;
Expand Down
4 changes: 2 additions & 2 deletions air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ pub use air::{
DeepCompositionCoefficients, EvaluationFrame, GkrData,
LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint,
LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo,
TransitionConstraintDegree, TransitionConstraints,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable,
TraceInfo, TransitionConstraintDegree, TransitionConstraints,
};
5 changes: 5 additions & 0 deletions prover/src/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,21 @@ impl<E: FieldElement> EvaluatedCircuit<E> {
log_up_randomness: &[E],
) -> CircuitLayer<E> {
let num_fractions = evaluator.get_num_fractions();
let periodic_values = evaluator.build_periodic_values::<E::BaseField, E>();

let mut input_layer_wires =
Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions);
let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols());

let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()];
let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()];
let mut numerators = vec![E::ZERO; num_fractions];
let mut denominators = vec![E::ZERO; num_fractions];
for i in 0..main_trace.main_segment().num_rows() {
let wires_from_trace_row = {
main_trace.read_main_frame(i, &mut main_frame);
periodic_values.get_periodic_values_at(i, &mut periodic_values_row);
evaluator.build_query(&main_frame, &periodic_values_row, &mut query);

evaluator.build_query(&main_frame, &[], &mut query);
irakliyk marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
24 changes: 17 additions & 7 deletions prover/src/logup_gkr/prover.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::vec::Vec;

use air::{LogUpGkrEvaluator, LogUpGkrOracle};
use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};
use crypto::{ElementHasher, RandomCoin};
use math::FieldElement;
use sumcheck::{
Expand Down Expand Up @@ -75,11 +75,18 @@ pub fn prove_gkr<E: FieldElement>(
let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?;

// build the MLEs of the relevant main trace columns
let main_trace_mls =
let (main_trace_mls, mut periodic_table) =
build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?;

let final_layer_proof =
prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?;
// run the GKR prover for the input layer
let final_layer_proof = prove_input_layer(
evaluator,
logup_randomness,
main_trace_mls,
&mut periodic_table,
gkr_claim,
public_coin,
)?;

Ok(GkrCircuitProof {
circuit_outputs: CircuitOutput { numerators, denominators },
Expand All @@ -97,6 +104,7 @@ fn prove_input_layer<
evaluator: &impl LogUpGkrEvaluator<BaseField = E::BaseField>,
log_up_randomness: Vec<E>,
multi_linear_ext_polys: Vec<MultiLinearPoly<E>>,
periodic_table: &mut PeriodicTable<E>,
claim: GkrClaim<E>,
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
transcript: &mut C,
) -> Result<FinalLayerProof<E>, GkrProverError> {
Expand All @@ -114,6 +122,7 @@ fn prove_input_layer<
r_batch,
log_up_randomness,
multi_linear_ext_polys,
periodic_table,
transcript,
)?;

Expand All @@ -125,8 +134,9 @@ fn prove_input_layer<
fn build_mls_from_main_trace_segment<E: FieldElement>(
oracles: &[LogUpGkrOracle<E::BaseField>],
main_trace: &ColMatrix<<E as FieldElement>::BaseField>,
) -> Result<Vec<MultiLinearPoly<E>>, GkrProverError> {
) -> Result<(Vec<MultiLinearPoly<E>>, PeriodicTable<E>), GkrProverError> {
let mut mls = vec![];
let mut periodic_values = vec![];

for oracle in oracles {
match oracle {
Expand All @@ -146,10 +156,10 @@ fn build_mls_from_main_trace_segment<E: FieldElement>(
let ml = MultiLinearPoly::from_evaluations(values);
mls.push(ml)
},
LogUpGkrOracle::PeriodicValue(_) => unimplemented!(),
LogUpGkrOracle::PeriodicValue(values) => periodic_values.push(values.to_vec()),
};
}
Ok(mls)
Ok((mls, PeriodicTable::new(periodic_values)))
}

/// Proves all GKR layers except for input layer.
Expand Down
1 change: 0 additions & 1 deletion sumcheck/Cargo.toml
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ utils = { version = "0.9", path = "../utils/core", package = "winter-utils", def
rayon = { version = "1.8", optional = true }
smallvec = { version = "1.13", default-features = false }
thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false }

[dev-dependencies]
criterion = "0.5"
rand-utils = { version = "0.9", path = "../utils/rand", package = "winter-rand-utils" }
16 changes: 13 additions & 3 deletions sumcheck/benches/sum_check_high_degree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use std::{marker::PhantomData, time::Duration};

use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle};
use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin};
use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField};
Expand Down Expand Up @@ -37,13 +37,14 @@ fn sum_check_high_degree(c: &mut Criterion) {
)
},
|(
(claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4)),
(claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4), periodic_table),
evaluator,
logup_randomness,
transcript,
)| {
let mls = vec![ml0, ml1, ml2, ml3, ml4];
let mut transcript = transcript;
let mut periodic_table = periodic_table;

sum_check_prove_higher_degree(
&evaluator,
Expand All @@ -52,6 +53,7 @@ fn sum_check_high_degree(c: &mut Criterion) {
r_batch,
logup_randomness,
mls,
&mut periodic_table,
&mut transcript,
)
},
Expand All @@ -76,21 +78,29 @@ fn setup_sum_check<E: FieldElement>(
MultiLinearPoly<E>,
MultiLinearPoly<E>,
),
PeriodicTable<E>,
) {
let n = 1 << log_size;
let table = MultiLinearPoly::from_evaluations(rand_vector(n));
let multiplicity = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_0 = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_1 = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_2 = MultiLinearPoly::from_evaluations(rand_vector(n));
let periodic_table = PeriodicTable::default();

// this will not generate the correct claim with overwhelming probability but should be fine
// for benchmarking
let rand_pt: Vec<E> = rand_vector(log_size + 2);
let r_batch: E = rand_value();
let claim: E = rand_value();

(claim, r_batch, rand_pt, (table, multiplicity, values_0, values_1, values_2))
(
claim,
r_batch,
rand_pt,
(table, multiplicity, values_0, values_1, values_2),
periodic_table,
)
}

#[derive(Clone, Default)]
Expand Down
Loading