From 8aee306793c72dea854acf34c8abff21a57ff49d Mon Sep 17 00:00:00 2001 From: ohad-starkware Date: Sun, 12 Jan 2025 15:40:32 +0200 Subject: [PATCH] allocating batch inverse --- .../prover/src/constraint_framework/logup.rs | 4 +- crates/prover/src/core/backend/cpu/circle.rs | 4 +- crates/prover/src/core/backend/cpu/mod.rs | 7 +- .../prover/src/core/backend/cpu/quotients.rs | 5 +- crates/prover/src/core/backend/simd/circle.rs | 6 +- .../prover/src/core/backend/simd/quotients.rs | 10 +-- .../src/core/backend/simd/very_packed_m31.rs | 4 +- crates/prover/src/core/fields/mod.rs | 81 +++++++++++-------- .../src/examples/xor/gkr_lookups/mle_eval.rs | 3 +- 9 files changed, 62 insertions(+), 62 deletions(-) diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 370987e4c9..ba97fc4635 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -11,10 +11,10 @@ use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; use crate::core::backend::Column; use crate::core::channel::Channel; +use crate::core::fields::batch_inverse_in_place; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; @@ -259,7 +259,7 @@ impl LogupColGenerator<'_> { /// Finalizes generating the column. pub fn finalize_col(mut self) { - FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data); + batch_inverse_in_place(&self.gen.denom.data, &mut self.gen.denom_inv.data); for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) { unsafe { diff --git a/crates/prover/src/core/backend/cpu/circle.rs b/crates/prover/src/core/backend/cpu/circle.rs index 21351d1647..c040138bd1 100644 --- a/crates/prover/src/core/backend/cpu/circle.rs +++ b/crates/prover/src/core/backend/cpu/circle.rs @@ -7,7 +7,7 @@ use crate::core::circle::{CirclePoint, Coset}; use crate::core::fft::{butterfly, ibutterfly}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::{ExtensionOf, FieldExpOps}; +use crate::core::fields::{batch_inverse_in_place, ExtensionOf}; use crate::core::poly::circle::{ CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, }; @@ -172,7 +172,7 @@ impl PolyOps for CpuBackend { .array_chunks::() .zip(itwiddles.array_chunks_mut::()) .for_each(|(src, dst)| { - BaseField::batch_inverse(src, dst); + batch_inverse_in_place(src, dst); }); TwiddleTree { diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs index 4ae7c3e5c8..5f2aefdab0 100644 --- a/crates/prover/src/core/backend/cpu/mod.rs +++ b/crates/prover/src/core/backend/cpu/mod.rs @@ -90,7 +90,7 @@ mod tests { use crate::core::backend::cpu::bit_reverse; use crate::core::backend::Column; use crate::core::fields::qm31::QM31; - use crate::core::fields::FieldExpOps; + use crate::core::fields::{batch_inverse_in_place, FieldExpOps}; #[test] fn bit_reverse_works() { @@ -106,14 +106,15 @@ mod tests { bit_reverse(&mut data); } + // TODO(Ohad): remove. #[test] - fn batch_inverse_test() { + fn batch_inverse_in_place_test() { let mut rng = SmallRng::seed_from_u64(0); let column = rng.gen::<[QM31; 16]>().to_vec(); let expected = column.iter().map(|e| e.inverse()).collect_vec(); let mut dst = Vec::zeros(column.len()); - FieldExpOps::batch_inverse(&column, &mut dst); + batch_inverse_in_place(&column, &mut dst); assert_eq!(expected, dst); } diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index 16f0647b64..b11a830a48 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -132,10 +132,7 @@ fn denominator_inverses( denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix); } - let mut denominator_inverses = vec![CM31::zero(); denominators.len()]; - CM31::batch_inverse(&denominators, &mut denominator_inverses); - - denominator_inverses + CM31::invert_many(&denominators) } pub fn quotient_constants( diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index 61588ffe3f..4d51ec3626 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -96,8 +96,7 @@ impl SimdBackend { denominators.push(denominators[i - 1] * mappings[i]); } - let mut denom_inverses = vec![F::zero(); denominators.len()]; - F::batch_inverse(&denominators, &mut denom_inverses); + let denom_inverses = F::invert_many(&denominators); let mut steps = vec![mappings[0]]; @@ -311,8 +310,7 @@ impl PolyOps for SimdBackend { remaining_twiddles.try_into().unwrap(), )); - let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data; - PackedBaseField::batch_inverse(&twiddles, &mut itwiddles); + let itwiddles = PackedBaseField::invert_many(&twiddles); let dbl_twiddles = twiddles .into_iter() diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index f0155ebb12..890facfc77 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -12,7 +12,7 @@ use super::qm31::PackedSecureField; use super::SimdBackend; use crate::core::backend::cpu::bit_reverse; use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs}; -use crate::core::backend::{Column, CpuBackend}; +use crate::core::backend::CpuBackend; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; @@ -243,15 +243,9 @@ fn denominator_inverses( }) .collect(); - let mut flat_denominator_inverses = - unsafe { CM31Column::uninitialized(flat_denominators.len()) }; - FieldExpOps::batch_inverse( - &flat_denominators.data, - &mut flat_denominator_inverses.data[..], - ); + let flat_denominator_inverses = PackedCM31::invert_many(&flat_denominators.data); flat_denominator_inverses - .data .chunks(domain.size() / N_LANES) .map(|denominator_inverses| denominator_inverses.iter().copied().collect()) .collect() diff --git a/crates/prover/src/core/backend/simd/very_packed_m31.rs b/crates/prover/src/core/backend/simd/very_packed_m31.rs index 781212d6fe..fbafcb93b0 100644 --- a/crates/prover/src/core/backend/simd/very_packed_m31.rs +++ b/crates/prover/src/core/backend/simd/very_packed_m31.rs @@ -9,7 +9,7 @@ use super::qm31::PackedQM31; use crate::core::fields::cm31::CM31; use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; -use crate::core::fields::FieldExpOps; +use crate::core::fields::{batch_inverse_in_place, FieldExpOps}; pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1; pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS; @@ -247,7 +247,7 @@ impl One for Vectorized { impl FieldExpOps for Vectorized { fn inverse(&self) -> Self { let mut dst = [A::zero(); N]; - A::batch_inverse(&self.0, &mut dst); + batch_inverse_in_place(&self.0, &mut dst); dst.into() } } diff --git a/crates/prover/src/core/fields/mod.rs b/crates/prover/src/core/fields/mod.rs index b19ea9bc9a..deb66269ad 100644 --- a/crates/prover/src/core/fields/mod.rs +++ b/crates/prover/src/core/fields/mod.rs @@ -30,37 +30,8 @@ pub trait FieldExpOps: Mul + MulAssign + Sized + One + Clone { fn inverse(&self) -> Self; - /// Inverts a batch of elements using Montgomery's trick. - fn batch_inverse(column: &[Self], dst: &mut [Self]) { - const WIDTH: usize = 4; - let n = column.len(); - debug_assert!(dst.len() >= n); - - if n <= WIDTH || n % WIDTH != 0 { - batch_inverse_classic(column, dst); - return; - } - - // First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing - // instruction dependency and allowing better pipelining. - let mut cum_prod: [Self; WIDTH] = std::array::from_fn(|_| Self::one()); - dst[..WIDTH].clone_from_slice(&cum_prod); - for i in 0..n { - cum_prod[i % WIDTH] *= column[i].clone(); - dst[i] = cum_prod[i % WIDTH].clone(); - } - - // Inverse cumulative products. - // Use classic batch inversion. - let mut tail_inverses: [Self; WIDTH] = std::array::from_fn(|_| Self::one()); - batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses); - - // Second pass. - for i in (WIDTH..n).rev() { - dst[i] = dst[i - WIDTH].clone() * tail_inverses[i % WIDTH].clone(); - tail_inverses[i % WIDTH] *= column[i].clone(); - } - dst[0..WIDTH].clone_from_slice(&tail_inverses); + fn invert_many(column: &[Self]) -> Vec { + batch_inverse(column) } } @@ -91,6 +62,46 @@ fn batch_inverse_classic(column: &[T], dst: &mut [T]) { dst[0] = curr_inverse; } +/// Inverts a batch of elements using Montgomery's trick. +pub fn batch_inverse_in_place(column: &[F], dst: &mut [F]) { + const WIDTH: usize = 4; + let n = column.len(); + debug_assert!(dst.len() >= n); + + if n <= WIDTH || n % WIDTH != 0 { + batch_inverse_classic(column, dst); + return; + } + + // First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing + // instruction dependency and allowing better pipelining. + let mut cum_prod: [F; WIDTH] = std::array::from_fn(|_| F::one()); + dst[..WIDTH].clone_from_slice(&cum_prod); + for i in 0..n { + cum_prod[i % WIDTH] *= column[i].clone(); + dst[i] = cum_prod[i % WIDTH].clone(); + } + + // Inverse cumulative products. + // Use classic batch inversion. + let mut tail_inverses: [F; WIDTH] = std::array::from_fn(|_| F::one()); + batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses); + + // Second pass. + for i in (WIDTH..n).rev() { + dst[i] = dst[i - WIDTH].clone() * tail_inverses[i % WIDTH].clone(); + tail_inverses[i % WIDTH] *= column[i].clone(); + } + dst[0..WIDTH].clone_from_slice(&tail_inverses); +} + +// TODO(Ohad): chunks, parallelize. +pub fn batch_inverse(column: &[F]) -> Vec { + let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()]; + batch_inverse_in_place(column, &mut dst); + dst +} + pub trait Field: NumAssign + Neg @@ -460,17 +471,17 @@ mod tests { use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; + use super::batch_inverse_in_place; use crate::core::fields::m31::M31; - use crate::core::fields::FieldExpOps; #[test] - fn test_slice_batch_inverse() { + fn test_slice_batch_inverse_in_place() { let mut rng = SmallRng::seed_from_u64(0); let elements: [M31; 16] = rng.gen(); let expected = elements.iter().map(|e| e.inverse()).collect::>(); let mut dst = [M31::zero(); 16]; - M31::batch_inverse(&elements, &mut dst); + batch_inverse_in_place(&elements, &mut dst); assert_eq!(expected, dst); } @@ -482,6 +493,6 @@ mod tests { let elements: [M31; 16] = rng.gen(); let mut dst = [M31::zero(); 15]; - M31::batch_inverse(&elements, &mut dst); + batch_inverse_in_place(&elements, &mut dst); } } diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 3c0c17d0cd..7b4b22feed 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -687,8 +687,7 @@ fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) vanish_at_log_step.reverse(); // We only need the first `log_step` many values. vanish_at_log_step.truncate(log_step as usize); - let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()]; - SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv); + let vanish_at_log_step_inv = SecureField::invert_many(&vanish_at_log_step); let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square(); let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::();