Skip to content

Commit

Permalink
Merge pull request #1 from lita-xyz/dorebell-public
Browse files Browse the repository at this point in the history
Add public instance variables
  • Loading branch information
tess-eract authored Jul 31, 2024
2 parents bdd338d + dbb62c6 commit 623741c
Show file tree
Hide file tree
Showing 20 changed files with 595 additions and 72 deletions.
17 changes: 16 additions & 1 deletion air/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,19 @@ use p3_matrix::MatrixRowSlices;

/// An AIR (algebraic intermediate representation).
pub trait BaseAir<F>: Sync {
/// The number of columns (a.k.a. registers) in this AIR.
/// The number of private columns (a.k.a. registers) in this AIR.
fn width(&self) -> usize;

/// The number of preprocessed columns in this AIR.
/// Defaults to zero, i.e. no preprocessed trace.
fn preprocessed_width(&self) -> usize {
0
}

/// The number of public columns in this AIR.
/// Defaults to zero, i.e. no public values.
fn public_width(&self) -> usize {
0
}

/// The preprocessed trace for this AIR, if any; `None` by default.
fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
None
}
Expand Down Expand Up @@ -108,6 +118,10 @@ pub trait AirBuilder: Sized {
}
}

/// An [`AirBuilder`] that additionally exposes a row of public values to constraints.
pub trait AirBuilderWithPublicValues: AirBuilder {
/// The public values, as the builder's matrix type.
fn public_values(&self) -> Self::M;
}

/// An [`AirBuilder`] that additionally exposes the preprocessed trace to constraints.
pub trait PairBuilder: AirBuilder {
/// The preprocessed trace, as the builder's matrix type.
fn preprocessed(&self) -> Self::M;
}
Expand Down Expand Up @@ -213,6 +227,7 @@ mod tests {

use crate::{Air, AirBuilder, BaseAir};

#[allow(dead_code)]
struct FibonacciAir;

impl<F> BaseAir<F> for FibonacciAir {
Expand Down
37 changes: 34 additions & 3 deletions air/src/virtual_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@ use core::ops::Mul;
use p3_field::{AbstractField, Field};

/// An affine function over columns in a PAIR.
#[derive(Debug, Clone)]
pub struct VirtualPairCol<F: Field> {
// The linear part: each entry contributes `weight * column_value`.
column_weights: Vec<(PairCol, F)>,
// The constant term added to the weighted sum.
constant: F,
}

/// A column in a PAIR, i.e. either a preprocessed column or a main trace column.
/// A column in a PAIR: a preprocessed column, a public-value column, or a main trace column.
#[derive(Debug, Clone)]
pub enum PairCol {
/// An index into the preprocessed trace.
Preprocessed(usize),
/// An index into the public values.
Public(usize),
/// An index into the main trace.
Main(usize),
}

impl PairCol {
fn get<T: Copy>(&self, preprocessed: &[T], main: &[T]) -> T {
fn get<T: Copy>(&self, preprocessed: &[T], public: &[T], main: &[T]) -> T {
match self {
PairCol::Preprocessed(i) => preprocessed[*i],
PairCol::Main(i) => main[*i],
PairCol::Public(i) => public[*i],
}
}
}
Expand Down Expand Up @@ -53,6 +57,16 @@ impl<F: Field> VirtualPairCol<F> {
)
}

pub fn new_public(column_weights: Vec<(usize, F)>, constant: F) -> Self {
Self::new(
column_weights
.into_iter()
.map(|(i, w)| (PairCol::Public(i), w))
.collect(),
constant,
)
}

#[must_use]
pub fn one() -> Self {
Self::constant(F::one())
Expand Down Expand Up @@ -84,6 +98,11 @@ impl<F: Field> VirtualPairCol<F> {
Self::single(PairCol::Main(column))
}

/// A single public-value column, with weight one and no constant term.
#[must_use]
pub fn single_public(column: usize) -> Self {
Self::single(PairCol::Public(column))
}

#[must_use]
pub fn sum_main(columns: Vec<usize>) -> Self {
let column_weights = columns.into_iter().map(|col| (col, F::one())).collect();
Expand All @@ -96,6 +115,12 @@ impl<F: Field> VirtualPairCol<F> {
Self::new_preprocessed(column_weights, F::zero())
}

#[must_use]
pub fn sum_public(columns: Vec<usize>) -> Self {
let column_weights = columns.into_iter().map(|col| (col, F::one())).collect();
Self::new_public(column_weights, F::zero())
}

/// `a - b`, where `a` and `b` are columns in the preprocessed trace.
#[must_use]
pub fn diff_preprocessed(a_col: usize, b_col: usize) -> Self {
Expand All @@ -108,15 +133,21 @@ impl<F: Field> VirtualPairCol<F> {
Self::new_main(vec![(a_col, F::one()), (b_col, F::neg_one())], F::zero())
}

pub fn apply<Expr, Var>(&self, preprocessed: &[Var], main: &[Var]) -> Expr
/// `a - b`, where `a` and `b` are columns in the public values.
#[must_use]
pub fn diff_public(a_col: usize, b_col: usize) -> Self {
Self::new_public(vec![(a_col, F::one()), (b_col, F::neg_one())], F::zero())
}

pub fn apply<Expr, Var>(&self, preprocessed: &[Var], public: &[Var], main: &[Var]) -> Expr
where
F: Into<Expr>,
Expr: AbstractField + Mul<F, Output = Expr>,
Var: Into<Expr> + Copy,
{
let mut result = self.constant.into();
for (column, weight) in &self.column_weights {
result += column.get(preprocessed, main).into() * *weight;
result += column.get(preprocessed, public, main).into() * *weight;
}
result
}
Expand Down
1 change: 1 addition & 0 deletions commit/src/adapters/multi_from_uni_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use p3_matrix::MatrixRows;

use crate::pcs::UnivariatePcs;

#[allow(dead_code)]
pub struct MultiFromUniPcs<Val, EF, In, U, Challenger>
where
Val: Field,
Expand Down
1 change: 1 addition & 0 deletions commit/src/adapters/uni_from_multi_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use p3_matrix::MatrixRows;

use crate::pcs::MultivariatePcs;

#[allow(dead_code)]
pub struct UniFromMultiPcs<Val, EF, In, M, Challenger>
where
Val: Field,
Expand Down
13 changes: 13 additions & 0 deletions commit/src/pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use core::fmt::Debug;

use p3_challenger::FieldChallenger;
use p3_field::{ExtensionField, Field};
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::{Dimensions, MatrixGet, MatrixRows};
use serde::de::DeserializeOwned;
use serde::Serialize;
Expand Down Expand Up @@ -85,6 +86,18 @@ where
where
'a: 'b;

/// Compute the (shifted) low-degree extensions only, without computing a commitment.
/// One coset shift is supplied per polynomial batch.
fn compute_coset_ldes_batches(
&self,
polynomials: Vec<In>,
coset_shifts: Vec<Val>,
) -> Vec<RowMajorMatrix<Val>>;

/// Compute the LDE of a single batch over the trivially shifted (shift = 1) coset.
fn compute_lde_batch(&self, polynomials: In) -> RowMajorMatrix<Val> {
self.compute_coset_ldes_batches(vec![polynomials], vec![Val::one()])
.pop()
.expect("length of output of compute_coset_ldes_batches should be the same as length of the input")
}
// Commit to polys that are already defined over a coset.
fn commit_shifted_batches(
&self,
Expand Down
64 changes: 44 additions & 20 deletions fri/src/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use p3_field::{
};
use p3_interpolation::interpolate_coset;
use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView};
use p3_matrix::dense::RowMajorMatrixView;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::{Dimensions, Matrix, MatrixRows};
use p3_maybe_rayon::prelude::*;
use p3_util::linear_map::LinearMap;
Expand Down Expand Up @@ -162,26 +162,36 @@ where
.collect()
}

fn commit_shifted_batches(
fn compute_coset_ldes_batches(
&self,
polynomials: Vec<In>,
coset_shifts: &[C::Val],
) -> (Self::Commitment, Self::ProverData) {
let ldes = info_span!("compute all coset LDEs").in_scope(|| {
coset_shifts: Vec<C::Val>,
) -> Vec<RowMajorMatrix<C::Val>> {
info_span!("compute all coset LDEs").in_scope(|| {
polynomials
.par_iter()
.zip_eq(coset_shifts)
.map(|(poly, coset_shift)| {
let shift = C::Val::generator() / *coset_shift;
let shift = C::Val::generator() / coset_shift;
let input = ((*poly).clone()).to_row_major_matrix();
// Commit to the bit-reversed LDE.
self.dft
.coset_lde_batch(input, self.fri.log_blowup, shift)
.bit_reverse_rows()
.to_row_major_matrix()
})
.collect()
});
})
}

fn commit_shifted_batches(
&self,
polynomials: Vec<In>,
coset_shifts: &[C::Val],
) -> (Self::Commitment, Self::ProverData) {
let ldes = self
.compute_coset_ldes_batches(polynomials, coset_shifts.to_vec())
.into_iter()
.map(|x| x.bit_reverse_rows().to_row_major_matrix())
.collect();

self.mmcs.commit(ldes)
}
Expand Down Expand Up @@ -246,15 +256,16 @@ where
.iter()
.map(|(data, points)| (self.mmcs.get_matrices(data), *points))
.collect_vec();

let max_width = mats_and_points
let mats = mats_and_points
.iter()
.flat_map(|(mats, _)| mats)
.map(|mat| mat.width())
.max()
.unwrap();
.collect_vec();

let alpha_reducer = PowersReducer::<C::Val, C::Challenge>::new(alpha, max_width);
let global_max_width = mats.iter().map(|m| m.width()).max().unwrap();
let global_max_height = mats.iter().map(|m| m.height()).max().unwrap();
let log_global_max_height = log2_strict_usize(global_max_height);

let alpha_reducer = PowersReducer::<C::Val, C::Challenge>::new(alpha, global_max_width);

// For each unique opening point z, we will find the largest degree bound
// for that point, and precompute 1/(X - z) for the largest subgroup (in bitrev order).
Expand Down Expand Up @@ -348,7 +359,11 @@ where
prover_data_and_points
.iter()
.map(|(data, _)| {
let (opened_values, opening_proof) = self.mmcs.open_batch(index, data);
let log_max_height = log2_strict_usize(self.mmcs.get_max_height(data));
let bits_reduced = log_global_max_height - log_max_height;
let reduced_index = index >> bits_reduced;
let (opened_values, opening_proof) =
self.mmcs.open_batch(reduced_index, data);
BatchOpening {
opened_values,
opening_proof,
Expand Down Expand Up @@ -382,8 +397,8 @@ where
verifier::verify_shape_and_sample_challenges(&self.fri, &proof.fri_proof, challenger)
.map_err(VerificationError::FriError)?;

let log_max_height = proof.fri_proof.commit_phase_commits.len() + self.fri.log_blowup;

let log_global_max_height =
proof.fri_proof.commit_phase_commits.len() + self.fri.log_blowup;
let reduced_openings: Vec<[C::Challenge; 32]> = proof
.query_openings
.iter()
Expand All @@ -394,10 +409,19 @@ where
for (batch_opening, batch_dims, (batch_commit, batch_points), batch_at_z) in
izip!(query_opening, dims, commits_and_points, &values)
{
let batch_max_height = batch_dims
.iter()
.map(|dims| dims.height << self.fri.log_blowup)
.max()
.expect("Empty batch?");
let log_batch_max_height = log2_strict_usize(batch_max_height);
let bits_reduced = log_global_max_height - log_batch_max_height;
let reduced_index = index >> bits_reduced;

self.mmcs.verify_batch(
batch_commit,
batch_dims,
index,
reduced_index,
&batch_opening.opened_values,
&batch_opening.opening_proof,
)?;
Expand All @@ -409,7 +433,7 @@ where
) {
let log_height = log2_strict_usize(mat_dims.height) + self.fri.log_blowup;

let bits_reduced = log_max_height - log_height;
let bits_reduced = log_global_max_height - log_height;
let rev_reduced_index = reverse_bits_len(index >> bits_reduced, log_height);

let x = C::Val::generator()
Expand Down
1 change: 0 additions & 1 deletion fri/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ where
ro,
log_max_height,
)?;

if folded_eval != proof.final_poly {
return Err(FriError::FinalPolyMismatch);
}
Expand Down
19 changes: 16 additions & 3 deletions keccak-air/examples/prove_baby_bear_keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use p3_field::extension::BinomialExtensionField;
use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig};
use p3_keccak::Keccak256Hash;
use p3_keccak_air::{generate_trace_rows, KeccakAir};
use p3_matrix::dense::RowMajorMatrix;
use p3_merkle_tree::FieldMerkleTreeMmcs;
use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2};
use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32};
use p3_uni_stark::{prove, verify, StarkConfig, VerificationError};
use p3_uni_stark::{prove, verify, PublicRow, StarkConfig, VerificationError};
use rand::{random, thread_rng};
use tracing_forest::util::LevelFilter;
use tracing_forest::ForestLayer;
Expand Down Expand Up @@ -69,8 +70,20 @@ fn main() -> Result<(), VerificationError> {

let inputs = (0..NUM_HASHES).map(|_| random()).collect::<Vec<_>>();
let trace = generate_trace_rows::<Val>(inputs);
let proof = prove::<MyConfig, _>(&config, &KeccakAir {}, &mut challenger, trace);
let proof = prove::<MyConfig, _, PublicRow<Val>>(
&config,
&KeccakAir {},
&mut challenger,
trace,
&PublicRow::default(),
);

let mut challenger = Challenger::new(perm);
verify(&config, &KeccakAir {}, &mut challenger, &proof)
verify(
&config,
&KeccakAir {},
&mut challenger,
&proof,
&RowMajorMatrix::new(vec![], 0),
)
}
19 changes: 16 additions & 3 deletions keccak-air/examples/prove_baby_bear_poseidon2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use p3_field::extension::BinomialExtensionField;
use p3_field::Field;
use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig};
use p3_keccak_air::{generate_trace_rows, KeccakAir};
use p3_matrix::dense::RowMajorMatrix;
use p3_merkle_tree::FieldMerkleTreeMmcs;
use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2};
use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
use p3_uni_stark::{prove, verify, StarkConfig, VerificationError};
use p3_uni_stark::{prove, verify, PublicRow, StarkConfig, VerificationError};
use rand::{random, thread_rng};
use tracing_forest::util::LevelFilter;
use tracing_forest::ForestLayer;
Expand Down Expand Up @@ -69,8 +70,20 @@ fn main() -> Result<(), VerificationError> {

let inputs = (0..NUM_HASHES).map(|_| random()).collect::<Vec<_>>();
let trace = generate_trace_rows::<Val>(inputs);
let proof = prove::<MyConfig, _>(&config, &KeccakAir {}, &mut challenger, trace);
let proof = prove::<MyConfig, _, PublicRow<Val>>(
&config,
&KeccakAir {},
&mut challenger,
trace,
&PublicRow::default(),
);

let mut challenger = Challenger::new(perm);
verify(&config, &KeccakAir {}, &mut challenger, &proof)
verify(
&config,
&KeccakAir {},
&mut challenger,
&proof,
&RowMajorMatrix::new(vec![], 0),
)
}
Loading

0 comments on commit 623741c

Please sign in to comment.