diff --git a/stwo_cairo_verifier/src/circle.cairo b/stwo_cairo_verifier/src/circle.cairo index 51d48663..cf7b3ccc 100644 --- a/stwo_cairo_verifier/src/circle.cairo +++ b/stwo_cairo_verifier/src/circle.cairo @@ -7,6 +7,7 @@ use stwo_cairo_verifier::circle_mul_table::{ M31_CIRCLE_GEN_MUL_TABLE_BITS_12_TO_17, M31_CIRCLE_GEN_MUL_TABLE_BITS_6_TO_11, M31_CIRCLE_GEN_MUL_TABLE_BITS_0_TO_5 }; +use stwo_cairo_verifier::fields::Invertible; use stwo_cairo_verifier::fields::cm31::CM31; use stwo_cairo_verifier::fields::m31::{M31, M31Impl}; use stwo_cairo_verifier::fields::qm31::{QM31Impl, QM31One, QM31, QM31Trait}; diff --git a/stwo_cairo_verifier/src/fields.cairo b/stwo_cairo_verifier/src/fields.cairo index 32c0ba18..a5ac4226 100644 --- a/stwo_cairo_verifier/src/fields.cairo +++ b/stwo_cairo_verifier/src/fields.cairo @@ -4,3 +4,88 @@ pub mod qm31; pub type BaseField = m31::M31; pub type SecureField = qm31::QM31; + +pub trait Invertible { + fn inverse(self: T) -> T; +} + +pub trait BatchInvertible, +Copy, +Drop, +Mul> { + /// Computes all `1/arr[i]` with a single call to `inverse()` using Montgomery batch inversion. + fn batch_inverse( + values: Array + ) -> Array< + T + > { + if values.is_empty() { + return array![]; + } + + // Collect array `z, zy, ..., zy..b`. + let mut prefix_product_rev = array![]; + let mut values_span = values.span(); + let mut cumulative_product = *values_span.pop_back().unwrap(); + + while let Option::Some(value) = values_span.pop_back() { + prefix_product_rev.append(cumulative_product); + cumulative_product = cumulative_product * *value; + }; + + // Compute `1/zy..a`. + let mut cumulative_product_inv = cumulative_product.inverse(); + + // Collect all `1/a = zy..b/zy..a, 1/b = zy..c/zy..b, ..., 1/y = z/zy`. + let mut inverses = array![]; + let mut values = values; + let mut prefix_product_rev_span = prefix_product_rev.span(); + + while let (Option::Some(prefix_product), Option::Some(value)) = + (prefix_product_rev_span.pop_back(), values.pop_front()) { + inverses.append(cumulative_product_inv * *prefix_product); + cumulative_product_inv = cumulative_product_inv * value; + }; + + // Append final `1/z`. + inverses.append(cumulative_product_inv); + + inverses + } +} + +#[cfg(test)] +mod tests { + use super::m31::{M31, m31}; + use super::{Invertible, BatchInvertible}; + + #[test] + fn test_batch_inverse() { + let arr = array![m31(2), m31(3), m31(5), m31(7)]; + let mut arr_inv = array![]; + for v in arr.span() { + arr_inv.append((*v).inverse()); + }; + + let res = BatchInvertible::batch_inverse(arr); + + assert_eq!(res, arr_inv); + } + + #[test] + fn test_batch_inverse_with_empty_array() { + let arr: Array = array![]; + + let res = BatchInvertible::batch_inverse(arr); + + assert_eq!(res, array![]); + } + + #[test] + fn test_batch_inverse_with_single_value() { + let two = m31(2); + let two_inv = two.inverse(); + let arr = array![two]; + + let res = BatchInvertible::batch_inverse(arr); + + assert_eq!(res, array![two_inv]); + } +} diff --git a/stwo_cairo_verifier/src/fields/cm31.cairo b/stwo_cairo_verifier/src/fields/cm31.cairo index 2d75ce1d..c0d38555 100644 --- a/stwo_cairo_verifier/src/fields/cm31.cairo +++ b/stwo_cairo_verifier/src/fields/cm31.cairo @@ -1,6 +1,7 @@ use core::num::traits::{One, Zero}; use core::ops::{AddAssign, MulAssign, SubAssign}; -use super::m31::{M31, M31Impl, m31, M31Trait}; +use super::m31::{M31, M31Impl, m31}; +use super::{Invertible, BatchInvertible}; #[derive(Copy, Drop, Debug, PartialEq)] pub struct CM31 { @@ -8,42 +9,18 @@ pub struct CM31 { pub b: M31, } -#[generate_trait] -pub impl CM31Impl of CM31Trait { +pub impl CM31InvertibleImpl of Invertible { fn inverse(self: CM31) -> CM31 { assert!(self.is_non_zero()); let denom_inverse: M31 = (self.a * self.a + self.b * self.b).inverse(); CM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse } } +} - /// Computes all `1/arr[i]` with a single call to `inverse()` using Montgomery batch inversion. - fn batch_inverse(arr: Array) -> Array { - // Construct array `1, z, zy, ..., zy..b`. - let mut prefix_product_rev = array![]; - let mut cumulative_product: CM31 = One::one(); - - let mut i = arr.len(); - while i != 0 { - i -= 1; - prefix_product_rev.append(cumulative_product); - cumulative_product *= *arr[i]; - }; - - // Compute `1/zy..a`. - let mut cumulative_product_inv = cumulative_product.inverse(); - // Compute all `1/a = zy..b/zy..a, 1/b = zy..c/zy..b, ...`. - let mut inverses = array![]; - - let mut i = prefix_product_rev.len(); - for v in arr { - i -= 1; - inverses.append(cumulative_product_inv * *prefix_product_rev[i]); - cumulative_product_inv *= v; - }; - - inverses - } +pub impl CM31BatchInvertibleImpl of BatchInvertible {} +#[generate_trait] +pub impl CM31Impl of CM31Trait { // TODO(andrew): When associated types are supported, support `Mul`. #[inline] fn mul_m31(self: CM31, rhs: M31) -> CM31 { diff --git a/stwo_cairo_verifier/src/fields/m31.cairo b/stwo_cairo_verifier/src/fields/m31.cairo index f56c8f27..ecb97093 100644 --- a/stwo_cairo_verifier/src/fields/m31.cairo +++ b/stwo_cairo_verifier/src/fields/m31.cairo @@ -2,6 +2,7 @@ use core::num::traits::{WideMul, CheckedSub}; use core::ops::{AddAssign, MulAssign, SubAssign}; use core::option::OptionTrait; use core::traits::TryInto; +use super::{Invertible, BatchInvertible}; /// Equals `2^31 - 1`. pub const P: u32 = 0x7fffffff; @@ -20,6 +21,21 @@ pub struct M31 { pub inner: u32 } +pub impl M31InvertibleImpl of Invertible { + fn inverse(self: M31) -> M31 { + assert!(self.is_non_zero()); + let t0 = sqn(self, 2) * self; + let t1 = sqn(t0, 1) * t0; + let t2 = sqn(t1, 3) * t0; + let t3 = sqn(t2, 1) * t0; + let t4 = sqn(t3, 8) * t3; + let t5 = sqn(t4, 8) * t3; + sqn(t5, 7) * t2 + } +} + +pub impl M31BatchInvertibleImpl of BatchInvertible {} + #[generate_trait] pub impl M31Impl of M31Trait { #[inline] @@ -39,25 +55,6 @@ pub impl M31Impl of M31Trait { let (_, res) = core::integer::u128_safe_divmod(val, P128NZ); M31 { inner: res.try_into().unwrap() } } - - #[inline] - fn sqn(v: M31, n: usize) -> M31 { - if n == 0 { - return v; - } - Self::sqn(v * v, n - 1) - } - - fn inverse(self: M31) -> M31 { - assert!(self.is_non_zero()); - let t0 = Self::sqn(self, 2) * self; - let t1 = Self::sqn(t0, 1) * t0; - let t2 = Self::sqn(t1, 3) * t0; - let t3 = Self::sqn(t2, 1) * t0; - let t4 = Self::sqn(t3, 8) * t3; - let t5 = Self::sqn(t4, 8) * t3; - Self::sqn(t5, 7) * t2 - } } pub impl M31Add of core::traits::Add { #[inline] @@ -192,9 +189,19 @@ impl M31IntoUnreducedM31 of Into { } } +/// Returns `v^(2^n)`. +fn sqn(v: M31, n: usize) -> M31 { + if n == 0 { + return v; + } + sqn(v * v, n - 1) +} + #[cfg(test)] mod tests { - use super::{m31, P, M31Trait}; + use super::super::Invertible; + use super::{m31, P}; + const POW2_15: u32 = 0b1000000000000000; const POW2_16: u32 = 0b10000000000000000; diff --git a/stwo_cairo_verifier/src/fields/qm31.cairo b/stwo_cairo_verifier/src/fields/qm31.cairo index 0aef19a9..32c8021b 100644 --- a/stwo_cairo_verifier/src/fields/qm31.cairo +++ b/stwo_cairo_verifier/src/fields/qm31.cairo @@ -1,7 +1,8 @@ use core::num::traits::one::One; use core::num::traits::zero::Zero; use core::ops::{AddAssign, MulAssign, SubAssign}; -use super::cm31::{CM31, cm31, CM31Trait}; +use super::Invertible; +use super::cm31::{CM31, cm31}; use super::m31::{M31, M31Impl, UnreducedM31}; /// Equals `(2^31 - 1)^4`. @@ -17,6 +18,17 @@ pub struct QM31 { pub b: CM31, } +impl QM31InvertibleImpl of Invertible { + fn inverse(self: QM31) -> QM31 { + assert!(self.is_non_zero()); + let b2 = self.b * self.b; + let ib2 = CM31 { a: -b2.b, b: b2.a }; + let denom = self.a * self.a - (b2 + b2 + ib2); + let denom_inverse = denom.inverse(); + QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse } + } +} + #[generate_trait] pub impl QM31Impl of QM31Trait { #[inline] @@ -30,15 +42,6 @@ pub impl QM31Impl of QM31Trait { [self.a.a, self.a.b, self.b.a, self.b.b] } - fn inverse(self: QM31) -> QM31 { - assert!(self.is_non_zero()); - let b2 = self.b * self.b; - let ib2 = CM31 { a: -b2.b, b: b2.a }; - let denom = self.a * self.a - (b2 + b2 + ib2); - let denom_inverse = denom.inverse(); - QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse } - } - #[inline] fn mul_m31(self: QM31, multiplier: M31) -> QM31 { QM31 { @@ -422,7 +425,8 @@ pub impl CM31IntoPackedUnreducedCM31 of Into { #[cfg(test)] mod tests { - use super::super::m31::{m31, P, M31Trait}; + use super::super::Invertible; + use super::super::m31::{m31, P}; use super::{ QM31, qm31, QM31Trait, QM31Impl, QM31IntoPackedUnreducedQM31, PackedUnreducedQM31Impl }; diff --git a/stwo_cairo_verifier/src/fri.cairo b/stwo_cairo_verifier/src/fri.cairo index 4669c318..4a1f6831 100644 --- a/stwo_cairo_verifier/src/fri.cairo +++ b/stwo_cairo_verifier/src/fri.cairo @@ -1,8 +1,8 @@ use core::dict::Felt252Dict; use stwo_cairo_verifier::channel::{Channel, ChannelTrait}; use stwo_cairo_verifier::circle::CosetImpl; +use stwo_cairo_verifier::fields::Invertible; use stwo_cairo_verifier::fields::m31::M31; -use stwo_cairo_verifier::fields::m31::M31Trait; use stwo_cairo_verifier::fields::qm31::{QM31_EXTENSION_DEGREE, QM31, QM31Zero, QM31Trait}; use stwo_cairo_verifier::poly::circle::CircleDomainImpl; use stwo_cairo_verifier::poly::circle::{ diff --git a/stwo_cairo_verifier/src/pcs/quotients.cairo b/stwo_cairo_verifier/src/pcs/quotients.cairo index df9007de..f8fa4a95 100644 --- a/stwo_cairo_verifier/src/pcs/quotients.cairo +++ b/stwo_cairo_verifier/src/pcs/quotients.cairo @@ -6,6 +6,7 @@ use core::num::traits::{One, Zero}; use stwo_cairo_verifier::circle::{ CosetImpl, CirclePointIndexImpl, CirclePoint, M31_CIRCLE_LOG_ORDER }; +use stwo_cairo_verifier::fields::BatchInvertible; use stwo_cairo_verifier::fields::cm31::{CM31, CM31Impl}; use stwo_cairo_verifier::fields::m31::{M31, UnreducedM31}; use stwo_cairo_verifier::fields::qm31::{ @@ -391,7 +392,7 @@ fn quotient_denominator_inverses( }; }; - CM31Impl::batch_inverse(flat_denominators) + BatchInvertible::batch_inverse(flat_denominators) } /// A batch of column samplings at a point. diff --git a/stwo_cairo_verifier/src/utils.cairo b/stwo_cairo_verifier/src/utils.cairo index 70da7e8f..04c553a5 100644 --- a/stwo_cairo_verifier/src/utils.cairo +++ b/stwo_cairo_verifier/src/utils.cairo @@ -42,7 +42,7 @@ pub impl OptionImpl of OptionExTrait { #[generate_trait] pub impl ArrayImpl> of ArrayExTrait { - fn pop_n(ref self: Array, mut n: usize) -> Array { + fn pop_front_n(ref self: Array, mut n: usize) -> Array { let mut res = array![]; while n != 0 { if let Option::Some(value) = self.pop_front() { diff --git a/stwo_cairo_verifier/src/vcs/verifier.cairo b/stwo_cairo_verifier/src/vcs/verifier.cairo index 396c9b17..8a80b678 100644 --- a/stwo_cairo_verifier/src/vcs/verifier.cairo +++ b/stwo_cairo_verifier/src/vcs/verifier.cairo @@ -169,7 +169,7 @@ impl MerkleVerifierImpl< col_query_index += 1; res } else { - column_witness.pop_n(n_columns_in_layer) + column_witness.pop_front_n(n_columns_in_layer) }; if column_values.len() != n_columns_in_layer {