Skip to content

Commit

Permalink
Enabler struct
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Jan 16, 2025
1 parent 535753b commit 1bdfdc9
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions stwo_cairo_prover/crates/prover/src/components/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::array;
use std::mem::transmute;
use std::simd::Simd;
use std::sync::atomic::{AtomicU32, Ordering};

use num_traits::One;
use num_traits::{One, Zero};
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::conversion::Pack;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
Expand Down Expand Up @@ -83,17 +84,37 @@ impl AtomicMultiplicityColumn {
}
}

/// Generates a column for the given padding offset.
/// The enabler column is a column of length next_power_of_two(padding_offset) where
/// The enabler column is a column of length `padding_offset.next_power_of_two()` where
/// 1. The first `padding_offset` elements set to 1;
/// 2. Otherwise set to 0.
pub fn gen_enabler_column(padding_offset: usize) -> BaseColumn {
let log_size = padding_offset.next_power_of_two().ilog2();
let mut res = BaseColumn::zeros(1 << log_size);
for i in 0..padding_offset {
res.set(i, M31::one());
#[derive(Debug, Clone)]
pub struct Enabler {
pub padding_offset: usize,
}
impl Enabler {
pub const fn new(padding_offset: usize) -> Self {
Self { padding_offset }
}

pub fn packed_at(&self, vec_row: usize) -> PackedM31 {
let packed_row_offset = vec_row * N_LANES;
PackedM31::from_array(array::from_fn(|i| {
if i < self.padding_offset - packed_row_offset {
M31::one()
} else {
M31::zero()
}
}))
}

pub fn gen_column_simd(&self) -> BaseColumn {
let log_size = self.padding_offset.next_power_of_two().ilog2();
let mut col = BaseColumn::zeros(1 << log_size);
for i in 0..self.padding_offset {
col.set(i, M31::one());
}
col
}
res
}

#[cfg(test)]
Expand Down Expand Up @@ -124,10 +145,11 @@ mod tests {
}

#[test]
fn test_gen_enabler_column() {
fn test_enabler_column() {
let n_calls = 30;
let padding = super::Enabler::new(n_calls);

let enabler_column = super::gen_enabler_column(n_calls);
let enabler_column = padding.gen_column_simd();

for (i, val) in enabler_column.into_cpu_vec().into_iter().enumerate() {
if i < n_calls {
Expand All @@ -137,4 +159,16 @@ mod tests {
}
}
}

#[test]
fn test_enabler_packed_at() {
let n_calls = 30;
let padding = super::Enabler::new(n_calls);

assert_eq!(padding.packed_at(0).to_array(), [1; N_LANES].map(M31::from));
assert_eq!(
padding.packed_at(1).to_array(),
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0].map(M31::from)
);
}
}

0 comments on commit 1bdfdc9

Please sign in to comment.