Skip to content

Commit

Permalink
Merge pull request #6 from atgrosso/10-01-add_line_folding
Browse files Browse the repository at this point in the history
Add line folding
  • Loading branch information
atgrosso authored Oct 14, 2024
2 parents f43f285 + 33a865f commit b547253
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 11 deletions.
23 changes: 22 additions & 1 deletion stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use stwo_cairo_verifier::fields::m31::{M31, m31};
use stwo_cairo_verifier::fields::m31::{M31, m31, M31One};
use super::utils::pow;

pub const M31_CIRCLE_GEN: CirclePointM31 =
Expand Down Expand Up @@ -28,6 +28,27 @@ pub impl CirclePointM31Impl of CirclePointM31Trait {
CirclePointM31 { x: m31(1), y: m31(0) }
}

/// Applies the circle's x-coordinate doubling map.
fn double_x(x: M31) -> M31 {
let sqx = x * x;
sqx + sqx - M31One::one()
}

/// Returns the log order of a point.
///
/// All points have an order of the form `2^k`.
fn log_order(self: @CirclePointM31) -> u32 {
// we only need the x-coordinate to check order since the only point
// with x=1 is the circle's identity
let mut res = 0;
let mut cur = self.x.clone();
while cur != M31One::one() {
cur = Self::double_x(cur);
res += 1;
};
res
}

fn mul(self: @CirclePointM31, mut scalar: u32) -> CirclePointM31 {
let mut result = Self::zero();
let mut cur = *self;
Expand Down
7 changes: 7 additions & 0 deletions stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ pub impl QM31Impl of QM31Trait {
let denom_inverse = denom.inverse();
QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}
fn mul_m31(self: QM31, multiplier: M31) -> QM31 {
QM31 {
a: CM31 { a: self.a.a * multiplier, b: self.a.b * multiplier },
b: CM31 { a: self.b.a * multiplier, b: self.b.b * multiplier }
}
}
}

pub impl QM31Add of core::traits::Add<QM31> {
Expand Down Expand Up @@ -113,5 +119,6 @@ mod tests {
assert_eq!(qm1 - m.into(), qm1 - qm);
assert_eq!(qm0_x_qm1 * qm1.inverse(), qm31(1, 2, 3, 4));
assert_eq!(qm1 * m.inverse().into(), qm1 * qm.inverse());
assert_eq!(qm1.mul_m31(m), qm1 * m.into());
}
}
82 changes: 82 additions & 0 deletions stwo_cairo_verifier/src/fri.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use stwo_cairo_verifier::fields::m31::M31Trait;
use stwo_cairo_verifier::circle::{Coset, CosetImpl};
use stwo_cairo_verifier::poly::line::{LineDomain, LineDomainImpl};
use stwo_cairo_verifier::fields::qm31::{QM31, qm31, QM31Trait};
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::utils::{bit_reverse_index, pow};
use stwo_cairo_verifier::poly::line::{LineEvaluation, LineEvaluationImpl};
pub const CIRCLE_TO_LINE_FOLD_STEP: u32 = 1;
pub const FOLD_STEP: u32 = 1;

/// Folds a degree `d` polynomial into a degree `d/2` polynomial.
///
/// Let `eval` be a polynomial evaluated on line domain `E`, `alpha` be a random field
/// element and `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function
/// returns `f' = f0 + alpha * f1` evaluated on `pi(E)` such that `2f(x) = f0(pi(x)) + x *
/// f1(pi(x))`.
pub fn fold_line(eval: @LineEvaluation, alpha: QM31) -> LineEvaluation {
let domain = eval.domain;
let mut values = array![];
let mut i = 0;
while i < eval.values.len() / 2 {
let x = domain.at(bit_reverse_index(i * pow(2, FOLD_STEP), domain.log_size()));
let f_x = eval.values[2 * i];
let f_neg_x = eval.values[2 * i + 1];
let (f0, f1) = ibutterfly(*f_x, *f_neg_x, x.inverse());
values.append(f0 + alpha * f1);
i += 1;
};
LineEvaluationImpl::new(domain.double(), values)
}

pub fn ibutterfly(v0: QM31, v1: QM31, itwid: M31) -> (QM31, QM31) {
(v0 + v1, (v0 - v1).mul_m31(itwid))
}

#[cfg(test)]
mod test {
use stwo_cairo_verifier::poly::line::{
LineEvaluation, SparseLineEvaluation, SparseLineEvaluationImpl
};
use stwo_cairo_verifier::fields::m31::M31Trait;
use stwo_cairo_verifier::circle::{Coset, CosetImpl};
use stwo_cairo_verifier::poly::line::{LineDomain, LineDomainImpl};
use stwo_cairo_verifier::fields::qm31::{QM31, qm31};
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::utils::{bit_reverse_index, pow};
use stwo_cairo_verifier::fri::{FOLD_STEP, CIRCLE_TO_LINE_FOLD_STEP};

#[test]
fn test_fold_line_1() {
let domain = LineDomainImpl::new(CosetImpl::new(67108864, 1));
let values = array![
qm31(440443058, 1252794525, 1129773609, 1309365757),
qm31(974589897, 1592795796, 649052897, 863407657)
];
let sparse_line_evaluation = SparseLineEvaluation {
subline_evals: array![LineEvaluation { values, domain }]
};
let alpha = qm31(1047716961, 506143067, 1065078409, 990356067);

let result = sparse_line_evaluation.fold(alpha);

assert_eq!(result, array![qm31(515899232, 1030391528, 1006544539, 11142505)]);
}

#[test]
fn test_fold_line_2() {
let domain = LineDomainImpl::new(CosetImpl::new(553648128, 1));
let values = array![
qm31(730692421, 1363821003, 2146256633, 106012305),
qm31(1387266930, 149259209, 1148988082, 1930518101)
];
let sparse_line_evaluation = SparseLineEvaluation {
subline_evals: array![LineEvaluation { values, domain }]
};
let alpha = qm31(2084521793, 1326105747, 548635876, 1532708504);

let result = sparse_line_evaluation.fold(alpha);

assert_eq!(result, array![qm31(1379727866, 1083096056, 1409020369, 1977903500)]);
}
}
1 change: 1 addition & 0 deletions stwo_cairo_verifier/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod fields;
mod poly;
mod utils;
mod vcs;
mod fri;

pub use fields::{BaseField, SecureField};

Expand Down
5 changes: 3 additions & 2 deletions stwo_cairo_verifier/src/poly.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod circle;
mod line;
mod utils;
pub mod line;
pub mod utils;

145 changes: 139 additions & 6 deletions stwo_cairo_verifier/src/poly/line.cairo
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
use stwo_cairo_verifier::fields::SecureField;
use stwo_cairo_verifier::fields::m31::m31;
use core::option::OptionTrait;
use core::clone::Clone;
use core::result::ResultTrait;
use stwo_cairo_verifier::poly::utils::fold;
use stwo_cairo_verifier::fields::SecureField;
use stwo_cairo_verifier::fields::m31::{M31, m31, M31Trait};
use stwo_cairo_verifier::fields::qm31::{QM31, qm31, QM31Zero};
use stwo_cairo_verifier::utils::pow;
use stwo_cairo_verifier::circle::{Coset, CosetImpl, CirclePointM31Trait, M31_CIRCLE_GEN};
use stwo_cairo_verifier::fri::fold_line;

/// A univariate polynomial represented by its coefficients in the line part of the FFT-basis
/// in bit reversed order.
#[derive(Drop, Clone)]
/// A univariate polynomial defined on a [LineDomain].
#[derive(Debug, Drop, Clone)]
pub struct LinePoly {
/// The coefficients of the polynomial stored in bit-reversed order.
///
/// The coefficients are in a basis relating to the circle's x-coordinate doubling
/// map `pi(x) = 2x^2 - 1` i.e.
///
/// ```text
/// B = { 1 } * { x } * { pi(x) } * { pi(pi(x)) } * ...
/// = { 1, x, pi(x), pi(x) * x, pi(pi(x)), pi(pi(x)) * x, pi(pi(x)) * pi(x), ... }
/// ```
pub coeffs: Array<SecureField>,
/// The number of coefficients stored as `log2(len(coeffs))`.
pub log_size: u32,
}

#[generate_trait]
pub impl LinePolyImpl of LinePolyTrait {
/// Returns the number of coefficients.
fn len(self: @LinePoly) -> usize {
self.coeffs.len()
}

/// Evaluates the polynomial at a single point.
fn eval_at_point(self: @LinePoly, mut x: SecureField) -> SecureField {
let mut doublings = array![];
let mut i = 0;
Expand All @@ -30,12 +48,127 @@ pub impl LinePolyImpl of LinePolyTrait {
}
}

/// Domain comprising of the x-coordinates of points in a [Coset].
///
/// For use with univariate polynomials.
#[derive(Copy, Drop, Debug)]
pub struct LineDomain {
pub coset: Coset,
}

#[generate_trait]
pub impl LineDomainImpl of LineDomainTrait {
/// Returns a domain comprising of the x-coordinates of points in a coset.
fn new(coset: Coset) -> LineDomain {
let coset_size = coset.size();
if (coset_size == 2) {
// If the coset with two points contains (0, y) then the coset is {(0, y), (0, -y)}.
assert!(!coset.at(0).x.is_zero(), "coset x-coordinates not unique");
} else if (coset_size > 2) {
// Let our coset be `E = c + <G>` with `|E| > 2` then:
// 1. if `ord(c) <= ord(G)` the coset contains two points at x=0
// 2. if `ord(c) = 2 * ord(G)` then `c` and `-c` are in our coset
let coset_step = M31_CIRCLE_GEN.mul(coset.step_size);
assert!(
coset.at(0).log_order() >= coset_step.log_order() + 2,
"coset x-coordinates not unique"
);
}
LineDomain { coset: coset }
}

/// Returns the `i`th domain element.
fn at(self: @LineDomain, index: usize) -> M31 {
self.coset.at(index).x
}

/// Returns the size of the domain.
fn size(self: @LineDomain) -> usize {
self.coset.size()
}

/// Returns the log size of the domain.
fn log_size(self: @LineDomain) -> usize {
*self.coset.log_size
}

/// Returns a new domain comprising of all points in current domain doubled.
fn double(self: @LineDomain) -> LineDomain {
LineDomain { coset: self.coset.double() }
}
}

/// Evaluations of a univariate polynomial on a [LineDomain].
#[derive(Drop)]
pub struct LineEvaluation {
pub values: Array<QM31>,
pub domain: LineDomain
}

#[generate_trait]
pub impl LineEvaluationImpl of LineEvaluationTrait {
/// Creates new [LineEvaluation] from a set of polynomial evaluations over a [LineDomain].
fn new(domain: LineDomain, values: Array<QM31>) -> LineEvaluation {
assert_eq!(values.len(), domain.size());
LineEvaluation { values: values, domain: domain }
}
}

/// Holds a small foldable subset of univariate SecureField polynomial evaluations.
#[derive(Drop)]
pub struct SparseLineEvaluation {
pub subline_evals: Array<LineEvaluation>,
}

#[generate_trait]
pub impl SparseLineEvaluationImpl of SparseLineEvaluationTrait {
fn fold(self: @SparseLineEvaluation, alpha: QM31) -> Array<QM31> {
let mut i = 0;
let mut res: Array<QM31> = array![];
while i < self.subline_evals.len() {
let line_evaluation = fold_line(self.subline_evals[i], alpha);
res.append(*line_evaluation.values.at(0));
i += 1;
};
res
}
}


#[cfg(test)]
mod tests {
use super::{LinePoly, LinePolyTrait};
use super::{LinePoly, LinePolyTrait, LineDomain, LineDomainImpl};
use stwo_cairo_verifier::fields::qm31::qm31;
use stwo_cairo_verifier::fields::m31::m31;
use stwo_cairo_verifier::circle::{Coset, CosetImpl, CIRCLE_LOG_ORDER};
use stwo_cairo_verifier::utils::pow;

#[test]
#[should_panic]
fn bad_line_domain() {
// This coset doesn't have points with unique x-coordinates.
let LOG_SIZE = 2;
let initial_index = pow(2, CIRCLE_LOG_ORDER - (LOG_SIZE + 1));
let coset = CosetImpl::new(initial_index, LOG_SIZE);

LineDomainImpl::new(coset);
}

#[test]
fn line_domain_of_size_two_works() {
let LOG_SIZE: u32 = 1;
let coset = CosetImpl::new(0, LOG_SIZE);

LineDomainImpl::new(coset);
}

#[test]
fn line_domain_of_size_one_works() {
let LOG_SIZE: u32 = 0;
let coset = CosetImpl::new(0, LOG_SIZE);

LineDomainImpl::new(coset);
}

#[test]
fn test_eval_at_point_1() {
Expand Down
45 changes: 43 additions & 2 deletions stwo_cairo_verifier/src/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use core::box::BoxTrait;
use core::dict::Felt252DictEntryTrait;
use core::dict::Felt252DictTrait;
use core::iter::Iterator;

use core::num::traits::BitSize;
use stwo_cairo_verifier::BaseField;
use core::traits::DivRem;

#[generate_trait]
pub impl DictImpl<T, +Felt252DictValue<T>, +PanicDestruct<T>> of DictTrait<T> {
Expand Down Expand Up @@ -97,9 +98,24 @@ pub fn pow(base: u32, mut exponent: u32) -> u32 {
result
}

pub fn bit_reverse_index(mut index: usize, mut bits: u32) -> usize {
assert!(bits <= BitSize::<usize>::bits());

let NZ2: NonZero<usize> = 2;

let mut result = 0;
while bits > 0 {
let (next_index, bit) = DivRem::div_rem(index, NZ2);
result = (result * 2) | bit;
index = next_index;
bits -= 1;
};
result
}

#[cfg(test)]
mod tests {
use super::pow;
use super::{pow, bit_reverse_index};

#[test]
fn test_pow() {
Expand All @@ -109,5 +125,30 @@ mod tests {
assert_eq!(4096, pow(2, 12));
assert_eq!(1048576, pow(2, 20));
}

#[test]
fn test_bit_reverse() {
// 1 bit
assert_eq!(0, bit_reverse_index(0, 1));
assert_eq!(1, bit_reverse_index(1, 1));

// 2 bits
assert_eq!(0, bit_reverse_index(0, 2));
assert_eq!(2, bit_reverse_index(1, 2));
assert_eq!(1, bit_reverse_index(2, 2));
assert_eq!(3, bit_reverse_index(3, 2));

// 3 bits
assert_eq!(0, bit_reverse_index(0, 3));
assert_eq!(4, bit_reverse_index(1, 3));
assert_eq!(2, bit_reverse_index(2, 3));
assert_eq!(6, bit_reverse_index(3, 3));

// 16 bits
assert_eq!(24415, bit_reverse_index(64250, 16));

// 31 bits
assert_eq!(16448250, bit_reverse_index(800042880, 31));
}
}

0 comments on commit b547253

Please sign in to comment.