Skip to content

Commit

Permalink
allocating batch inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Jan 14, 2025
1 parent 29d124e commit aa3c697
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 64 deletions.
4 changes: 2 additions & 2 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::channel::Channel;
use crate::core::fields::batch_inverse_in_place;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -259,7 +259,7 @@ impl LogupColGenerator<'_> {

/// Finalizes generating the column.
pub fn finalize_col(mut self) {
FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data);
batch_inverse_in_place(&self.gen.denom.data, &mut self.gen.denom_inv.data);

for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) {
unsafe {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::core::circle::{CirclePoint, Coset};
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::fields::{batch_inverse_in_place, ExtensionOf};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
Expand Down Expand Up @@ -172,7 +172,7 @@ impl PolyOps for CpuBackend {
.array_chunks::<CHUNK_SIZE>()
.zip(itwiddles.array_chunks_mut::<CHUNK_SIZE>())
.for_each(|(src, dst)| {
BaseField::batch_inverse(src, dst);
batch_inverse_in_place(src, dst);
});

TwiddleTree {
Expand Down
7 changes: 4 additions & 3 deletions crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ mod tests {
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::Column;
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
use crate::core::fields::{batch_inverse_in_place, FieldExpOps};

#[test]
fn bit_reverse_works() {
Expand All @@ -106,14 +106,15 @@ mod tests {
bit_reverse(&mut data);
}

// TODO(Ohad): remove.
#[test]
fn batch_inverse_test() {
fn batch_inverse_in_place_test() {
let mut rng = SmallRng::seed_from_u64(0);
let column = rng.gen::<[QM31; 16]>().to_vec();
let expected = column.iter().map(|e| e.inverse()).collect_vec();
let mut dst = Vec::zeros(column.len());

FieldExpOps::batch_inverse(&column, &mut dst);
batch_inverse_in_place(&column, &mut dst);

assert_eq!(expected, dst);
}
Expand Down
5 changes: 1 addition & 4 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ fn denominator_inverses(
denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix);
}

let mut denominator_inverses = vec![CM31::zero(); denominators.len()];
CM31::batch_inverse(&denominators, &mut denominator_inverses);

denominator_inverses
CM31::batch_inverse(&denominators)
}

pub fn quotient_constants(
Expand Down
6 changes: 2 additions & 4 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ impl SimdBackend {
denominators.push(denominators[i - 1] * mappings[i]);
}

let mut denom_inverses = vec![F::zero(); denominators.len()];
F::batch_inverse(&denominators, &mut denom_inverses);
let denom_inverses = F::batch_inverse(&denominators);

let mut steps = vec![mappings[0]];

Expand Down Expand Up @@ -311,8 +310,7 @@ impl PolyOps for SimdBackend {
remaining_twiddles.try_into().unwrap(),
));

let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
PackedBaseField::batch_inverse(&twiddles, &mut itwiddles);
let itwiddles = PackedBaseField::batch_inverse(&twiddles);

let dbl_twiddles = twiddles
.into_iter()
Expand Down
10 changes: 2 additions & 8 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs};
use crate::core::backend::{Column, CpuBackend};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
Expand Down Expand Up @@ -251,15 +251,9 @@ fn denominator_inverses(
})
.collect();

let mut flat_denominator_inverses =
unsafe { CM31Column::uninitialized(flat_denominators.len()) };
FieldExpOps::batch_inverse(
&flat_denominators.data,
&mut flat_denominator_inverses.data[..],
);
let flat_denominator_inverses = PackedCM31::batch_inverse(&flat_denominators.data);

flat_denominator_inverses
.data
.chunks(domain.size() / N_LANES)
.map(|denominator_inverses| denominator_inverses.iter().copied().collect())
.collect()
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/simd/very_packed_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::qm31::PackedQM31;
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
use crate::core::fields::{batch_inverse_in_place, FieldExpOps};

pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1;
pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS;
Expand Down Expand Up @@ -247,7 +247,7 @@ impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
impl<A: FieldExpOps + Zero + Copy, const N: usize> FieldExpOps for Vectorized<A, N> {
fn inverse(&self) -> Self {
let mut dst = [A::zero(); N];
A::batch_inverse(&self.0, &mut dst);
batch_inverse_in_place(&self.0, &mut dst);
dst.into()
}
}
85 changes: 48 additions & 37 deletions crates/prover/src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,8 @@ pub trait FieldExpOps: Mul<Output = Self> + MulAssign + Sized + One + Clone {

fn inverse(&self) -> Self;

/// Inverts a batch of elements using Montgomery's trick.
fn batch_inverse(column: &[Self], dst: &mut [Self]) {
const WIDTH: usize = 4;
let n = column.len();
debug_assert!(dst.len() >= n);

if n <= WIDTH || n % WIDTH != 0 {
batch_inverse_classic(column, dst);
return;
}

// First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing
// instruction dependency and allowing better pipelining.
let mut cum_prod: [Self; WIDTH] = std::array::from_fn(|_| Self::one());
dst[..WIDTH].clone_from_slice(&cum_prod);
for i in 0..n {
cum_prod[i % WIDTH] *= column[i].clone();
dst[i] = cum_prod[i % WIDTH].clone();
}

// Inverse cumulative products.
// Use classic batch inversion.
let mut tail_inverses: [Self; WIDTH] = std::array::from_fn(|_| Self::one());
batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses);

// Second pass.
for i in (WIDTH..n).rev() {
dst[i] = dst[i - WIDTH].clone() * tail_inverses[i % WIDTH].clone();
tail_inverses[i % WIDTH] *= column[i].clone();
}
dst[0..WIDTH].clone_from_slice(&tail_inverses);
fn batch_inverse(column: &[Self]) -> Vec<Self> {
batch_inverse(column)
}
}

Expand Down Expand Up @@ -91,6 +62,46 @@ fn batch_inverse_classic<T: FieldExpOps>(column: &[T], dst: &mut [T]) {
dst[0] = curr_inverse;
}

/// Inverts a batch of elements using Montgomery's trick.
pub fn batch_inverse_in_place<F: FieldExpOps>(column: &[F], dst: &mut [F]) {
const WIDTH: usize = 4;
let n = column.len();
debug_assert!(dst.len() >= n);

if n <= WIDTH || n % WIDTH != 0 {
batch_inverse_classic(column, dst);
return;
}

// First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing
// instruction dependency and allowing better pipelining.
let mut cum_prod: [F; WIDTH] = std::array::from_fn(|_| F::one());
dst[..WIDTH].clone_from_slice(&cum_prod);
for i in 0..n {
cum_prod[i % WIDTH] *= column[i].clone();
dst[i] = cum_prod[i % WIDTH].clone();
}

// Inverse cumulative products.
// Use classic batch inversion.
let mut tail_inverses: [F; WIDTH] = std::array::from_fn(|_| F::one());
batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses);

// Second pass.
for i in (WIDTH..n).rev() {
dst[i] = dst[i - WIDTH].clone() * tail_inverses[i % WIDTH].clone();
tail_inverses[i % WIDTH] *= column[i].clone();
}
dst[0..WIDTH].clone_from_slice(&tail_inverses);
}

// TODO(Ohad): chunks, parallelize.
pub fn batch_inverse<F: FieldExpOps>(column: &[F]) -> Vec<F> {
let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()];
batch_inverse_in_place(column, &mut dst);
dst
}

pub trait Field:
NumAssign
+ Neg<Output = Self>
Expand Down Expand Up @@ -460,19 +471,19 @@ mod tests {
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::batch_inverse_in_place;
use crate::core::fields::batch_inverse;
use crate::core::fields::m31::M31;
use crate::core::fields::FieldExpOps;

#[test]
fn test_slice_batch_inverse() {
fn test_batch_inverse() {
let mut rng = SmallRng::seed_from_u64(0);
let elements: [M31; 16] = rng.gen();
let expected = elements.iter().map(|e| e.inverse()).collect::<Vec<_>>();
let mut dst = [M31::zero(); 16];

M31::batch_inverse(&elements, &mut dst);
let actual = batch_inverse(&elements);

assert_eq!(expected, dst);
assert_eq!(expected, actual);
}

#[test]
Expand All @@ -482,6 +493,6 @@ mod tests {
let elements: [M31; 16] = rng.gen();
let mut dst = [M31::zero(); 15];

M31::batch_inverse(&elements, &mut dst);
batch_inverse_in_place(&elements, &mut dst);
}
}
3 changes: 1 addition & 2 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,7 @@ fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint<SecureField>)
vanish_at_log_step.reverse();
// We only need the first `log_step` many values.
vanish_at_log_step.truncate(log_step as usize);
let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()];
SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv);
let vanish_at_log_step_inv = SecureField::batch_inverse(&vanish_at_log_step);

let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square();
let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::<SecureField>();
Expand Down

0 comments on commit aa3c697

Please sign in to comment.