Skip to content

Commit

Permalink
fill memory segments of builtins to the next power of 2 instances
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-nir-starkware committed Jan 12, 2025
1 parent 07d3a9f commit e88854c
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 3 deletions.
76 changes: 76 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use cairo_vm::air_public_input::MemorySegmentAddresses;
use cairo_vm::stdlib::collections::HashMap;
use cairo_vm::types::builtin_name::BuiltinName;
use num_traits::Euclid;
use serde::{Deserialize, Serialize};

use super::memory::MemoryBuilder;

/// This struct holds the builtins used in a Cairo program.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct BuiltinSegments {
Expand All @@ -19,6 +22,7 @@ pub struct BuiltinSegments {
}

impl BuiltinSegments {
/// Adds a segment to the builtin segments.
pub fn add_segment(
&mut self,
builtin_name: BuiltinName,
Expand All @@ -40,6 +44,78 @@ impl BuiltinSegments {
}
}

/// Returns the segment for a given builtin name.
fn get_segment(&mut self, builtin_name: BuiltinName) -> &Option<MemorySegmentAddresses> {
match builtin_name {
BuiltinName::range_check => &self.range_check_bits_128,
BuiltinName::pedersen => &self.pedersen,
BuiltinName::ecdsa => &self.ecdsa,
BuiltinName::keccak => &self.keccak,
BuiltinName::bitwise => &self.bitwise,
BuiltinName::ec_op => &self.ec_op,
BuiltinName::poseidon => &self.poseidon,
BuiltinName::range_check96 => &self.range_check_bits_96,
BuiltinName::add_mod => &self.add_mod,
BuiltinName::mul_mod => &self.mul_mod,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => &None,
}
}

/// Returns the number of memory cells per instance for a given builtin name.
pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize {
match builtin_name {
BuiltinName::range_check => 1,
BuiltinName::pedersen => 3,
BuiltinName::ecdsa => 2,
BuiltinName::keccak => 16,
BuiltinName::bitwise => 5,
BuiltinName::ec_op => 7,
BuiltinName::poseidon => 6,
BuiltinName::range_check96 => 1,
BuiltinName::add_mod => 7,
BuiltinName::mul_mod => 7,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => 0,
}
}

/// Pads a builtin segment with copies of it's first instance.
/// The segment is padded to the nearest power of two number of instances.
pub fn fill_builtin_segment(
&mut self,
mut memory: MemoryBuilder,
builtin_name: BuiltinName,
) -> MemoryBuilder {
let &Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}) = self.get_segment(builtin_name)
else {
return memory;
};
let initial_length = stop_ptr - begin_addr;
let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name);
let (num_instances, remainder) = initial_length.div_rem_euclid(&cells_per_instance);
assert!(remainder == 0);
let nearest_power_of_two = num_instances.next_power_of_two();
let mut address_to_fill = (begin_addr + num_instances * cells_per_instance) as u64;
for _ in num_instances..nearest_power_of_two {
for j in 0..cells_per_instance {
memory.copy_value((begin_addr + j) as u32, address_to_fill as u32);
address_to_fill += 1;
}
}
self.add_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr: begin_addr + cells_per_instance * nearest_power_of_two,
}),
);
memory
}

/// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses.
pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self {
let mut res = BuiltinSegments::default();
Expand Down
5 changes: 5 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ impl MemoryBuilder {
});
self.address_to_id[addr as usize] = res;
}

pub fn copy_value(&mut self, src_addr: u32, dst_addr: u32) {
self.set(dst_addr as u64, self.memory.get(src_addr));
}

pub fn build(self) -> Memory {
self.memory
}
Expand Down
61 changes: 58 additions & 3 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::path::Path;
use bytemuck::{bytes_of_mut, Pod, Zeroable};
use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput};
use cairo_vm::stdlib::collections::HashMap;
use cairo_vm::types::builtin_name::BuiltinName;
use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry;
use json::PrivateInput;
use thiserror::Error;
Expand Down Expand Up @@ -88,12 +89,14 @@ pub fn adapt_to_stwo_input(
) -> Result<ProverInput, VmImportError> {
let (state_transitions, instruction_by_pc) =
StateTransitions::from_iter(trace_iter, &mut memory, dev_mode);
let mut builtins_segments = BuiltinSegments::from_memory_segments(memory_segments);
memory = builtins_segments.fill_builtin_segment(memory, BuiltinName::range_check);
Ok(ProverInput {
state_transitions,
instruction_by_pc,
memory: memory.build(),
public_memory_addresses,
builtins_segments: BuiltinSegments::from_memory_segments(memory_segments),
builtins_segments,
})
}

Expand Down Expand Up @@ -156,7 +159,11 @@ impl<R: Read> Iterator for MemoryEntryIter<'_, R> {
pub mod tests {
use std::path::PathBuf;

use cairo_vm::types::builtin_name::BuiltinName;
use num_traits::Euclid;

use super::*;
use crate::input::memory::Memory;

pub fn large_cairo_input() -> ProverInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
Expand Down Expand Up @@ -187,6 +194,42 @@ pub mod tests {
)
}

/// Verifies that the builtin segment is padded with copies of the first instance to the next
/// power of 2 instances.
pub fn verify_segment_is_padded(
segment: &Option<MemorySegmentAddresses>,
builtin_name: BuiltinName,
memory: &Memory,
original_stop_ptr: usize,
) {
if let Some(segment) = segment {
let cells_per_instance =
BuiltinSegments::builtin_memory_cells_per_instance(builtin_name);
let segment_length = segment.stop_ptr - segment.begin_addr;
let (num_instances, remainder) = segment_length.div_rem_euclid(&cells_per_instance);
assert!(remainder == 0);
// assert that num_instances is a power of 2.
assert!((num_instances & (num_instances - 1)) == 0);

let original_segment_length = original_stop_ptr - segment.begin_addr;
let (original_num_instances, remainder) =
original_segment_length.div_rem_euclid(&cells_per_instance);
assert!(remainder == 0);
assert!(original_num_instances * 2 > num_instances); // the next power of 2.

assert!(segment.stop_ptr - 1 <= memory.address_to_id.len());
for instance in original_num_instances..num_instances {
for j in 0..cells_per_instance {
let address = segment.begin_addr + instance * cells_per_instance + j;
assert!(
memory.address_to_id[address]
== memory.address_to_id[segment.begin_addr + j]
);
}
}
}
}

#[test]
#[cfg(feature = "slow-tests")]
fn test_read_from_large_files() {
Expand Down Expand Up @@ -270,7 +313,13 @@ pub mod tests {
assert_eq!(builtins_segments.range_check_bits_96, None);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((1715768, 1757348).into())
Some((1715768, 1781304).into())
);
verify_segment_is_padded(
&builtins_segments.range_check_bits_128,
BuiltinName::range_check,
&input.memory,
1757348,
);
}

Expand Down Expand Up @@ -354,7 +403,13 @@ pub mod tests {
);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((6000, 6050).into())
Some((6000, 6064).into())
);
verify_segment_is_padded(
&builtins_segments.range_check_bits_128,
BuiltinName::range_check,
&input.memory,
6050,
);
}
}

0 comments on commit e88854c

Please sign in to comment.