From 41afb4066bea7cb0256b8b160f7399df20e1fd2f Mon Sep 17 00:00:00 2001 From: Tom Grosso Date: Tue, 1 Oct 2024 16:26:17 -0300 Subject: [PATCH] Generalize field in CirclePoint --- stwo_cairo_verifier/src/circle.cairo | 251 ++++++++++++++++------ stwo_cairo_verifier/src/fields/qm31.cairo | 8 +- stwo_cairo_verifier/src/poly/circle.cairo | 14 +- stwo_cairo_verifier/src/poly/line.cairo | 5 +- 4 files changed, 204 insertions(+), 74 deletions(-) diff --git a/stwo_cairo_verifier/src/circle.cairo b/stwo_cairo_verifier/src/circle.cairo index b464b11b6..8a77a5041 100644 --- a/stwo_cairo_verifier/src/circle.cairo +++ b/stwo_cairo_verifier/src/circle.cairo @@ -1,8 +1,23 @@ -use stwo_cairo_verifier::fields::m31::{M31, m31, M31One}; +use stwo_cairo_verifier::fields::m31::{M31, M31Impl}; +use stwo_cairo_verifier::fields::cm31::CM31; +use stwo_cairo_verifier::fields::qm31::{QM31Impl, QM31, QM31Trait}; use super::utils::pow; - -pub const M31_CIRCLE_GEN: CirclePointM31 = - CirclePointM31 { x: M31 { inner: 2 }, y: M31 { inner: 1268011823 }, }; +use core::num::traits::zero::Zero; +use core::num::traits::one::One; + +pub const M31_CIRCLE_GEN: CirclePoint = + CirclePoint { x: M31 { inner: 2 }, y: M31 { inner: 1268011823 }, }; + +pub const QM31_CIRCLE_GEN: CirclePoint = CirclePoint { + x: QM31 { + a: CM31 { a: M31 { inner: 1 }, b: M31 { inner: 0 }, }, + b: CM31 { a: M31 { inner: 478637715 }, b: M31 { inner: 513582971 } } + }, + y: QM31 { + a: CM31 { a: M31 { inner: 992285211 }, b: M31 { inner: 649143431 } }, + b: CM31 { a: M31 { inner: 740191619 }, b: M31 { inner: 1186584352 } } + }, +}; pub const CIRCLE_LOG_ORDER: u32 = 31; @@ -15,41 +30,49 @@ pub const CIRCLE_ORDER_BIT_MASK: u32 = 0x7fffffff; // `U32_BIT_MASK` equals 2^32 - 1 pub const U32_BIT_MASK: u64 = 0xffffffff; +/// A point on the complex circle. Treated as an additive group. #[derive(Drop, Copy, Debug, PartialEq, Eq)] -pub struct CirclePointM31 { - pub x: M31, - pub y: M31, +pub struct CirclePoint { + pub x: F, + pub y: F } -#[generate_trait] -pub impl CirclePointM31Impl of CirclePointM31Trait { +pub trait CirclePointTrait< + F, +Add, +Sub, +Mul, +Drop, +Copy, +Zero, +One, +PartialEq +> { // Returns the neutral element of the circle. - fn zero() -> CirclePointM31 { - CirclePointM31 { x: m31(1), y: m31(0) } + fn zero() -> CirclePoint { + CirclePoint { x: One::one(), y: Zero::zero() } } /// Applies the circle's x-coordinate doubling map. - fn double_x(x: M31) -> M31 { + fn double_x(x: F) -> F { let sqx = x * x; - sqx + sqx - M31One::one() + sqx + sqx - One::one() } /// Returns the log order of a point. /// /// All points have an order of the form `2^k`. - fn log_order(self: @CirclePointM31) -> u32 { + fn log_order( + self: @CirclePoint + ) -> u32 { // we only need the x-coordinate to check order since the only point // with x=1 is the circle's identity let mut res = 0; let mut cur = self.x.clone(); - while cur != M31One::one() { + while cur != One::one() { cur = Self::double_x(cur); res += 1; }; res } - fn mul(self: @CirclePointM31, mut scalar: u32) -> CirclePointM31 { + fn mul( + self: @CirclePoint, ref scalar: u128 + ) -> CirclePoint< + F + > { let mut result = Self::zero(); let mut cur = *self; while scalar > 0 { @@ -63,13 +86,25 @@ pub impl CirclePointM31Impl of CirclePointM31Trait { } } -impl CirclePointM31Add of Add { - // The operation of the circle as a group with additive notation. - fn add(lhs: CirclePointM31, rhs: CirclePointM31) -> CirclePointM31 { - CirclePointM31 { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x } +impl CirclePointAdd, +Sub, +Mul, +Drop, +Copy> of Add> { + /// Performs the operation of the circle as a group with additive notation. + fn add(lhs: CirclePoint, rhs: CirclePoint) -> CirclePoint { + CirclePoint { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x } + } +} + +pub impl CirclePointM31Impl of CirclePointTrait {} + +pub impl CirclePointQM31Impl of CirclePointTrait {} + +#[generate_trait] +pub impl ComplexConjugateImpl of ComplexConjugateTrait { + fn complex_conjugate(self: CirclePoint) -> CirclePoint { + CirclePoint { x: self.x.complex_conjugate(), y: self.y.complex_conjugate() } } } +/// Represents the coset `initial + `. #[derive(Copy, Clone, Debug, PartialEq, Eq, Drop)] pub struct Coset { // This is an index in the range [0, 2^31) @@ -107,109 +142,141 @@ pub impl CosetImpl of CosetTrait { } } - fn at(self: @Coset, index: usize) -> CirclePointM31 { - M31_CIRCLE_GEN.mul(self.index_at(index)) + fn at(self: @Coset, index: usize) -> CirclePoint { + let mut scalar = self.index_at(index).into(); + M31_CIRCLE_GEN.mul(ref scalar) } + /// Returns the size of the coset. fn size(self: @Coset) -> usize { pow(2, *self.log_size) } + + /// Creates a coset of the form `G_2n + `. + /// + /// For example, for `n=8`, we get the point indices `[1,3,5,7,9,11,13,15]`. + fn odds(log_size: u32) -> Coset { + let subgroup_generator_index = Self::subgroup_generator_index(log_size); + Self::new(subgroup_generator_index, log_size) + } + + /// Creates a coset of the form `G_4n + `. + /// + /// For example, for `n=8`, we get the point indices `[1,5,9,13,17,21,25,29]`. + /// Its conjugate will be `[3,7,11,15,19,23,27,31]`. + fn half_odds(log_size: u32) -> Coset { + Self::new(Self::subgroup_generator_index(log_size + 2), log_size) + } + + fn subgroup_generator_index(log_size: u32) -> u32 { + assert!(log_size <= CIRCLE_LOG_ORDER); + pow(2, CIRCLE_LOG_ORDER - log_size) + } } #[cfg(test)] mod tests { - use super::{M31_CIRCLE_GEN, CIRCLE_ORDER, CirclePointM31, CirclePointM31Impl, Coset, CosetImpl}; - use stwo_cairo_verifier::fields::m31::m31; + use super::{M31_CIRCLE_GEN, CIRCLE_ORDER, CirclePoint, CirclePointM31Impl, Coset, CosetImpl}; + use core::option::OptionTrait; + use core::array::ArrayTrait; + use core::traits::TryInto; + use super::{CirclePointQM31Impl, QM31_CIRCLE_GEN}; + use stwo_cairo_verifier::fields::m31::{m31, M31}; + use stwo_cairo_verifier::fields::qm31::{qm31, QM31, QM31One}; + use stwo_cairo_verifier::utils::pow; #[test] fn test_add_1() { - let i = CirclePointM31 { x: m31(0), y: m31(1) }; + let i = CirclePoint { x: m31(0), y: m31(1) }; let result = i + i; - let expected_result = CirclePointM31 { x: -m31(1), y: m31(0) }; - assert_eq!(result, expected_result); + assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) }); } #[test] fn test_add_2() { - let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) }; - let point_2 = CirclePointM31 { x: m31(1737427771), y: m31(309481134) }; + let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; + let point_2 = CirclePoint { x: m31(1737427771), y: m31(309481134) }; let result = point_1 + point_2; - let expected_result = CirclePointM31 { x: m31(1476625263), y: m31(1040927458) }; - assert_eq!(result, expected_result); + assert_eq!(result, CirclePoint { x: m31(1476625263), y: m31(1040927458) }); } #[test] fn test_zero_1() { let result = CirclePointM31Impl::zero(); - let expected_result = CirclePointM31 { x: m31(1), y: m31(0) }; - assert_eq!(result, expected_result); + + assert_eq!(result, CirclePoint { x: m31(1), y: m31(0) }); } #[test] fn test_zero_2() { - let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) }; + let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; let point_2 = CirclePointM31Impl::zero(); - let expected_result = point_1.clone(); let result = point_1 + point_2; - assert_eq!(result, expected_result); + assert_eq!(result, point_1.clone()); } #[test] fn test_mul_1() { - let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) }; - let result = point_1.mul(5); - let expected_result = point_1 + point_1 + point_1 + point_1 + point_1; + let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; + let mut scalar = 5; + let result = point_1.mul(ref scalar); - assert_eq!(result, expected_result); + assert_eq!(result, point_1 + point_1 + point_1 + point_1 + point_1); } #[test] fn test_mul_2() { - let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) }; - let result = point_1.mul(8); - let mut expected_result = point_1 + point_1; - expected_result = expected_result + expected_result; - expected_result = expected_result + expected_result; + let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; + let mut scalar = 8; + let result = point_1.mul(ref scalar); - assert_eq!(result, expected_result); + assert_eq!(result, point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1); } #[test] fn test_mul_3() { - let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) }; - let result = point_1.mul(418776494); - let expected_result = CirclePointM31 { x: m31(1987283985), y: m31(1500510905) }; + let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; + let mut scalar = 418776494; + let result = point_1.mul(ref scalar); - assert_eq!(result, expected_result); + assert_eq!(result, CirclePoint { x: m31(1987283985), y: m31(1500510905) }); } #[test] fn test_generator_order() { let half_order = CIRCLE_ORDER / 2; - let mut result = M31_CIRCLE_GEN.mul(half_order); - let expected_result = CirclePointM31 { x: -m31(1), y: m31(0) }; + let mut scalar = half_order.into(); + let mut result = M31_CIRCLE_GEN.mul(ref scalar); // Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`. - assert_eq!(expected_result, result); + assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) }); + } + + #[test] + fn test_generator() { + let mut scalar = pow(2, 30).try_into().unwrap(); + let mut result = M31_CIRCLE_GEN.mul(ref scalar); + + assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) }); } #[test] fn test_coset_index_at() { let coset = Coset { initial_index: 16777216, log_size: 5, step_size: 67108864 }; let result = coset.index_at(8); - let expected_result = 553648128; - assert_eq!(expected_result, result); + + assert_eq!(result, 553648128); } #[test] fn test_coset_constructor() { let result = CosetImpl::new(16777216, 5); - let expected_result = Coset { initial_index: 16777216, log_size: 5, step_size: 67108864 }; - assert_eq!(expected_result, result); + + assert_eq!(result, Coset { initial_index: 16777216, log_size: 5, step_size: 67108864 }); } #[test] @@ -217,24 +284,82 @@ mod tests { let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 }; let result = coset.double(); - let expected_result = Coset { initial_index: 33554432, step_size: 134217728, log_size: 4 }; - assert_eq!(expected_result, result); + assert_eq!(result, Coset { initial_index: 33554432, step_size: 134217728, log_size: 4 }); } #[test] fn test_coset_at() { let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 }; let result = coset.at(17); - let expected_result = CirclePointM31 { x: m31(7144319), y: m31(1742797653) }; - assert_eq!(expected_result, result); + + assert_eq!(result, CirclePoint:: { x: m31(7144319), y: m31(1742797653) }); } #[test] fn test_coset_size() { let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 }; let result = coset.size(); - let expected_result = 32; - assert_eq!(result, expected_result); + + assert_eq!(result, 32); + } + + #[test] + fn test_qm31_circle_gen() { + let P4: u128 = 21267647892944572736998860269687930881; + + let first_prime = 2; + let last_prime = 368140581013; + let prime_factors: Array<(u128, u32)> = array![ + (first_prime, 33), + (3, 2), + (5, 1), + (7, 1), + (11, 1), + (31, 1), + (151, 1), + (331, 1), + (733, 1), + (1709, 1), + (last_prime, 1), + ]; + + let product = iter_product(first_prime, @prime_factors, last_prime); + + assert_eq!(product, P4 - 1); + + assert_eq!( + QM31_CIRCLE_GEN.x * QM31_CIRCLE_GEN.x + QM31_CIRCLE_GEN.y * QM31_CIRCLE_GEN.y, + QM31One::one() + ); + + let mut scalar = P4 - 1; + assert_eq!(QM31_CIRCLE_GEN.mul(ref scalar), CirclePointQM31Impl::zero()); + + let mut i = 0; + while i < prime_factors.len() { + let (p, _) = *prime_factors.at(i); + let mut scalar = (P4 - 1) / p.into(); + assert_ne!(QM31_CIRCLE_GEN.mul(ref scalar), CirclePointQM31Impl::zero()); + + i = i + 1; + } + } + + fn iter_product( + first_prime: u128, prime_factors: @Array<(u128, u32)>, last_prime: u128 + ) -> u128 { + let mut accum_product: u128 = 1; + accum_product = accum_product + * pow(first_prime.try_into().unwrap(), 31).into() + * 4; // * 2^33 + let mut i = 1; + while i < prime_factors.len() - 1 { + let (prime, exponent): (u128, u32) = *prime_factors.at(i); + accum_product = accum_product * pow(prime.try_into().unwrap(), exponent).into(); + i = i + 1; + }; + accum_product = accum_product * last_prime; + accum_product } } diff --git a/stwo_cairo_verifier/src/fields/qm31.cairo b/stwo_cairo_verifier/src/fields/qm31.cairo index 03f3ba915..3418eb631 100644 --- a/stwo_cairo_verifier/src/fields/qm31.cairo +++ b/stwo_cairo_verifier/src/fields/qm31.cairo @@ -7,8 +7,8 @@ pub const R: CM31 = CM31 { a: M31 { inner: 2 }, b: M31 { inner: 1 } }; #[derive(Copy, Drop, Debug, PartialEq, Eq)] pub struct QM31 { - a: CM31, - b: CM31, + pub a: CM31, + pub b: CM31, } #[generate_trait] @@ -36,6 +36,10 @@ pub impl QM31Impl of QM31Trait { b: CM31 { a: self.b.a * multiplier, b: self.b.b * multiplier } } } + + fn complex_conjugate(self: QM31) -> QM31 { + QM31 { a: self.a, b: -self.b } + } } pub impl QM31Add of core::traits::Add { diff --git a/stwo_cairo_verifier/src/poly/circle.cairo b/stwo_cairo_verifier/src/poly/circle.cairo index bd8dd34ff..39898389c 100644 --- a/stwo_cairo_verifier/src/poly/circle.cairo +++ b/stwo_cairo_verifier/src/poly/circle.cairo @@ -1,11 +1,10 @@ -use stwo_cairo_verifier::circle::CirclePointM31Trait; use core::option::OptionTrait; use core::clone::Clone; use core::result::ResultTrait; use stwo_cairo_verifier::fields::m31::{M31, m31}; use stwo_cairo_verifier::utils::pow; use stwo_cairo_verifier::circle::{ - Coset, CosetImpl, CirclePointM31, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER + Coset, CosetImpl, CirclePoint, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER }; /// A valid domain for circle polynomial interpolation and evaluation. @@ -31,8 +30,9 @@ pub impl CircleDomainImpl of CircleDomainTrait { } } - fn at(self: @CircleDomain, index: usize) -> CirclePointM31 { - M31_CIRCLE_GEN.mul(self.index_at(index)) + fn at(self: @CircleDomain, index: usize) -> CirclePoint:: { + let mut scalar = self.index_at(index).into(); + M31_CIRCLE_GEN.mul(ref scalar) } } @@ -41,7 +41,7 @@ pub impl CircleDomainImpl of CircleDomainTrait { mod tests { use super::{CircleDomain, CircleDomainTrait}; use stwo_cairo_verifier::circle::{ - Coset, CosetImpl, CirclePointM31, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER + Coset, CosetImpl, CirclePoint, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER }; use stwo_cairo_verifier::fields::m31::{M31, m31}; @@ -51,7 +51,7 @@ mod tests { let domain = CircleDomain { half_coset }; let index = 17; let result = domain.at(index); - let expected_result = CirclePointM31 { x: m31(7144319), y: m31(1742797653) }; + let expected_result = CirclePoint:: { x: m31(7144319), y: m31(1742797653) }; assert_eq!(expected_result, result); } @@ -61,7 +61,7 @@ mod tests { let domain = CircleDomain { half_coset }; let index = 37; let result = domain.at(index); - let expected_result = CirclePointM31 { x: m31(9803698), y: m31(2079025011) }; + let expected_result = CirclePoint:: { x: m31(9803698), y: m31(2079025011) }; assert_eq!(expected_result, result); } } diff --git a/stwo_cairo_verifier/src/poly/line.cairo b/stwo_cairo_verifier/src/poly/line.cairo index b0fec5b9d..83af45b6d 100644 --- a/stwo_cairo_verifier/src/poly/line.cairo +++ b/stwo_cairo_verifier/src/poly/line.cairo @@ -6,7 +6,7 @@ use stwo_cairo_verifier::fields::SecureField; use stwo_cairo_verifier::fields::m31::{M31, m31, M31Trait}; use stwo_cairo_verifier::fields::qm31::{QM31, qm31, QM31Zero}; use stwo_cairo_verifier::utils::pow; -use stwo_cairo_verifier::circle::{Coset, CosetImpl, CirclePointM31Trait, M31_CIRCLE_GEN}; +use stwo_cairo_verifier::circle::{Coset, CosetImpl, CirclePointTrait, M31_CIRCLE_GEN}; use stwo_cairo_verifier::fri::fold_line; /// A univariate polynomial defined on a [LineDomain]. @@ -68,7 +68,8 @@ pub impl LineDomainImpl of LineDomainTrait { // Let our coset be `E = c + ` with `|E| > 2` then: // 1. if `ord(c) <= ord(G)` the coset contains two points at x=0 // 2. if `ord(c) = 2 * ord(G)` then `c` and `-c` are in our coset - let coset_step = M31_CIRCLE_GEN.mul(coset.step_size); + let mut scalar = coset.step_size.into(); + let coset_step = M31_CIRCLE_GEN.mul(ref scalar); assert!( coset.at(0).log_order() >= coset_step.log_order() + 2, "coset x-coordinates not unique"