From 45022701b50002a9d0fcff2001aed9d53634a470 Mon Sep 17 00:00:00 2001 From: ohad-starkware Date: Mon, 13 Jan 2025 10:59:57 +0200 Subject: [PATCH] parallel batch inverse --- crates/prover/src/core/fields/mod.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/crates/prover/src/core/fields/mod.rs b/crates/prover/src/core/fields/mod.rs index 9a0833c7e..67cc294dd 100644 --- a/crates/prover/src/core/fields/mod.rs +++ b/crates/prover/src/core/fields/mod.rs @@ -3,6 +3,8 @@ use std::iter::{Product, Sum}; use std::ops::{Mul, MulAssign, Neg}; use num_traits::{NumAssign, NumAssignOps, NumOps, One}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; pub mod cm31; pub mod m31; @@ -101,10 +103,20 @@ pub fn batch_inverse(column: &[F]) -> Vec { dst } -// TODO(Ohad): parallelize. -pub fn batch_inverse_chunked(column: &[T], chunk_size: usize) -> Vec { +pub fn batch_inverse_chunked( + column: &[T], + chunk_size: usize, +) -> Vec { let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()]; + + #[cfg(not(feature = "parallel"))] let iter = dst.chunks_mut(chunk_size).zip(column.chunks(chunk_size)); + + #[cfg(feature = "parallel")] + let iter = dst + .par_chunks_mut(chunk_size) + .zip(column.par_chunks(chunk_size)); + iter.for_each(|(dst, column)| { batch_inverse_in_place(column, dst); });