Skip to content

Commit

Permalink
Implement batch_inverse on M31,CM31
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Nov 14, 2024
1 parent b70127f commit a58d998
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 65 deletions.
1 change: 1 addition & 0 deletions stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use stwo_cairo_verifier::circle_mul_table::{
M31_CIRCLE_GEN_MUL_TABLE_BITS_12_TO_17, M31_CIRCLE_GEN_MUL_TABLE_BITS_6_TO_11,
M31_CIRCLE_GEN_MUL_TABLE_BITS_0_TO_5
};
use stwo_cairo_verifier::fields::Field;
use stwo_cairo_verifier::fields::cm31::CM31;
use stwo_cairo_verifier::fields::m31::{M31, M31Impl};
use stwo_cairo_verifier::fields::qm31::{QM31Impl, QM31One, QM31, QM31Trait};
Expand Down
85 changes: 85 additions & 0 deletions stwo_cairo_verifier/src/fields.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,88 @@ pub mod qm31;

pub type BaseField = m31::M31;
pub type SecureField = qm31::QM31;

pub trait Field<T> {
fn inverse(self: T) -> T;
}

pub trait FieldBatchInverse<T, +Field<T>, +Copy<T>, +Drop<T>, +Mul<T>> {
/// Computes all `1/arr[i]` with a single call to `inverse()` using Montgomery batch inversion.
fn batch_inverse(
arr: Array<T>
) -> Array<
T
> {
if arr.is_empty() {
return array![];
}

// Collect array `z, zy, ..., zy..b`.
let mut prefix_product_rev = array![];
let mut cumulative_product = *arr[arr.len() - 1];

let mut i = arr.len() - 1;
while i != 0 {
i -= 1;
prefix_product_rev.append(cumulative_product);
cumulative_product = cumulative_product * *arr[i];
};

// Compute `1/zy..a`.
let mut cumulative_product_inv = cumulative_product.inverse();
// Collect all `1/a = zy..b/zy..a, 1/b = zy..c/zy..b, ..., 1/y = z/zy`.
let mut inverses = array![];
let mut arr = arr;

let mut i = prefix_product_rev.len();
while i != 0 {
i -= 1;
inverses.append(cumulative_product_inv * *prefix_product_rev[i]);
cumulative_product_inv = cumulative_product_inv * arr.pop_front().unwrap();
};

// Append final `1/z`.
inverses.append(cumulative_product_inv);

inverses
}
}

#[cfg(test)]
mod tests {
use super::m31::{M31, m31};
use super::{Field, FieldBatchInverse};

#[test]
fn test_batch_inverse() {
let arr = array![m31(2), m31(3), m31(5), m31(7)];
let mut arr_inv = array![];
for v in arr.span() {
arr_inv.append((*v).inverse());
};

let res = FieldBatchInverse::batch_inverse(arr);

assert_eq!(res, arr_inv);
}

#[test]
fn test_batch_inverse_with_empty_array() {
let arr: Array<M31> = array![];

let res = FieldBatchInverse::batch_inverse(arr);

assert_eq!(res, array![]);
}

#[test]
fn test_batch_inverse_with_single_value() {
let two = m31(2);
let two_inv = two.inverse();
let arr = array![two];

let res = FieldBatchInverse::batch_inverse(arr);

assert_eq!(res, array![two_inv]);
}
}
37 changes: 7 additions & 30 deletions stwo_cairo_verifier/src/fields/cm31.cairo
Original file line number Diff line number Diff line change
@@ -1,49 +1,26 @@
use core::num::traits::{One, Zero};
use core::ops::{AddAssign, MulAssign, SubAssign};
use super::m31::{M31, M31Impl, m31, M31Trait};
use super::m31::{M31, M31Impl, m31};
use super::{Field, FieldBatchInverse};

#[derive(Copy, Drop, Debug, PartialEq)]
pub struct CM31 {
pub a: M31,
pub b: M31,
}

#[generate_trait]
pub impl CM31Impl of CM31Trait {
pub impl CM31FieldImpl of Field<CM31> {
fn inverse(self: CM31) -> CM31 {
assert!(self.is_non_zero());
let denom_inverse: M31 = (self.a * self.a + self.b * self.b).inverse();
CM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}
}

/// Computes all `1/arr[i]` with a single call to `inverse()` using Montgomery batch inversion.
fn batch_inverse(arr: Array<CM31>) -> Array<CM31> {
// Construct array `1, z, zy, ..., zy..b`.
let mut prefix_product_rev = array![];
let mut cumulative_product: CM31 = One::one();

let mut i = arr.len();
while i != 0 {
i -= 1;
prefix_product_rev.append(cumulative_product);
cumulative_product *= *arr[i];
};

// Compute `1/zy..a`.
let mut cumulative_product_inv = cumulative_product.inverse();
// Compute all `1/a = zy..b/zy..a, 1/b = zy..c/zy..b, ...`.
let mut inverses = array![];

let mut i = prefix_product_rev.len();
for v in arr {
i -= 1;
inverses.append(cumulative_product_inv * *prefix_product_rev[i]);
cumulative_product_inv *= v;
};

inverses
}
pub impl CM31FieldBatchInverseImpl of FieldBatchInverse<CM31> {}

#[generate_trait]
pub impl CM31Impl of CM31Trait {
// TODO(andrew): When associated types are supported, support `Mul<CM31, M31>`.
#[inline]
fn mul_m31(self: CM31, rhs: M31) -> CM31 {
Expand Down
47 changes: 27 additions & 20 deletions stwo_cairo_verifier/src/fields/m31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::num::traits::{WideMul, CheckedSub};
use core::ops::{AddAssign, MulAssign, SubAssign};
use core::option::OptionTrait;
use core::traits::TryInto;
use super::{Field, FieldBatchInverse};

/// Equals `2^31 - 1`.
pub const P: u32 = 0x7fffffff;
Expand All @@ -20,6 +21,21 @@ pub struct M31 {
pub inner: u32
}

pub impl M31FieldImpl of Field<M31> {
fn inverse(self: M31) -> M31 {
assert!(self.is_non_zero());
let t0 = sqn(self, 2) * self;
let t1 = sqn(t0, 1) * t0;
let t2 = sqn(t1, 3) * t0;
let t3 = sqn(t2, 1) * t0;
let t4 = sqn(t3, 8) * t3;
let t5 = sqn(t4, 8) * t3;
sqn(t5, 7) * t2
}
}

pub impl M31FieldBatchInverseImpl of FieldBatchInverse<M31> {}

#[generate_trait]
pub impl M31Impl of M31Trait {
#[inline]
Expand All @@ -39,25 +55,6 @@ pub impl M31Impl of M31Trait {
let (_, res) = core::integer::u128_safe_divmod(val, P128NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn sqn(v: M31, n: usize) -> M31 {
if n == 0 {
return v;
}
Self::sqn(v * v, n - 1)
}

fn inverse(self: M31) -> M31 {
assert!(self.is_non_zero());
let t0 = Self::sqn(self, 2) * self;
let t1 = Self::sqn(t0, 1) * t0;
let t2 = Self::sqn(t1, 3) * t0;
let t3 = Self::sqn(t2, 1) * t0;
let t4 = Self::sqn(t3, 8) * t3;
let t5 = Self::sqn(t4, 8) * t3;
Self::sqn(t5, 7) * t2
}
}
pub impl M31Add of core::traits::Add<M31> {
#[inline]
Expand Down Expand Up @@ -192,9 +189,19 @@ impl M31IntoUnreducedM31 of Into<M31, UnreducedM31> {
}
}

/// Returns `v^(2^n)`.
fn sqn(v: M31, n: usize) -> M31 {
if n == 0 {
return v;
}
sqn(v * v, n - 1)
}

#[cfg(test)]
mod tests {
use super::{m31, P, M31Trait};
use super::super::Field;
use super::{m31, P, M31FieldImpl};

const POW2_15: u32 = 0b1000000000000000;
const POW2_16: u32 = 0b10000000000000000;

Expand Down
26 changes: 15 additions & 11 deletions stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use core::num::traits::one::One;
use core::num::traits::zero::Zero;
use core::ops::{AddAssign, MulAssign, SubAssign};
use super::cm31::{CM31, cm31, CM31Trait};
use super::Field;
use super::cm31::{CM31, cm31};
use super::m31::{M31, M31Impl, UnreducedM31};

/// Equals `(2^31 - 1)^4`.
Expand All @@ -17,6 +18,17 @@ pub struct QM31 {
pub b: CM31,
}

impl QM31Field of Field<QM31> {
fn inverse(self: QM31) -> QM31 {
assert!(self.is_non_zero());
let b2 = self.b * self.b;
let ib2 = CM31 { a: -b2.b, b: b2.a };
let denom = self.a * self.a - (b2 + b2 + ib2);
let denom_inverse = denom.inverse();
QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}
}

#[generate_trait]
pub impl QM31Impl of QM31Trait {
#[inline]
Expand All @@ -30,15 +42,6 @@ pub impl QM31Impl of QM31Trait {
[self.a.a, self.a.b, self.b.a, self.b.b]
}

fn inverse(self: QM31) -> QM31 {
assert!(self.is_non_zero());
let b2 = self.b * self.b;
let ib2 = CM31 { a: -b2.b, b: b2.a };
let denom = self.a * self.a - (b2 + b2 + ib2);
let denom_inverse = denom.inverse();
QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}

#[inline]
fn mul_m31(self: QM31, multiplier: M31) -> QM31 {
QM31 {
Expand Down Expand Up @@ -422,7 +425,8 @@ pub impl CM31IntoPackedUnreducedCM31 of Into<CM31, PackedUnreducedCM31> {

#[cfg(test)]
mod tests {
use super::super::m31::{m31, P, M31Trait};
use super::super::Field;
use super::super::m31::{m31, P};
use super::{
QM31, qm31, QM31Trait, QM31Impl, QM31IntoPackedUnreducedQM31, PackedUnreducedQM31Impl
};
Expand Down
2 changes: 1 addition & 1 deletion stwo_cairo_verifier/src/fri.cairo
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use core::dict::Felt252Dict;
use stwo_cairo_verifier::channel::{Channel, ChannelTrait};
use stwo_cairo_verifier::circle::CosetImpl;
use stwo_cairo_verifier::fields::Field;
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::fields::m31::M31Trait;
use stwo_cairo_verifier::fields::qm31::{QM31_EXTENSION_DEGREE, QM31, QM31Zero, QM31Trait};
use stwo_cairo_verifier::poly::circle::CircleDomainImpl;
use stwo_cairo_verifier::poly::circle::{
Expand Down
3 changes: 2 additions & 1 deletion stwo_cairo_verifier/src/pcs/quotients.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use core::num::traits::{One, Zero};
use stwo_cairo_verifier::circle::{
CosetImpl, CirclePointIndexImpl, CirclePoint, M31_CIRCLE_LOG_ORDER
};
use stwo_cairo_verifier::fields::FieldBatchInverse;
use stwo_cairo_verifier::fields::cm31::{CM31, CM31Impl};
use stwo_cairo_verifier::fields::m31::{M31, UnreducedM31};
use stwo_cairo_verifier::fields::qm31::{
Expand Down Expand Up @@ -391,7 +392,7 @@ fn quotient_denominator_inverses(
};
};

CM31Impl::batch_inverse(flat_denominators)
FieldBatchInverse::batch_inverse(flat_denominators)
}

/// A batch of column samplings at a point.
Expand Down
2 changes: 1 addition & 1 deletion stwo_cairo_verifier/src/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub impl OptionImpl<T> of OptionExTrait<T> {

#[generate_trait]
pub impl ArrayImpl<T, +Drop<T>> of ArrayExTrait<T> {
fn pop_n(ref self: Array<T>, mut n: usize) -> Array<T> {
fn pop_front_n(ref self: Array<T>, mut n: usize) -> Array<T> {
let mut res = array![];
while n != 0 {
if let Option::Some(value) = self.pop_front() {
Expand Down
2 changes: 1 addition & 1 deletion stwo_cairo_verifier/src/vcs/verifier.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ impl MerkleVerifierImpl<
col_query_index += 1;
res
} else {
column_witness.pop_n(n_columns_in_layer)
column_witness.pop_front_n(n_columns_in_layer)
};

if column_values.len() != n_columns_in_layer {
Expand Down

0 comments on commit a58d998

Please sign in to comment.