Skip to content

Commit

Permalink
fill memory segments of builtins to the next power of 2 instances
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-nir-starkware committed Jan 10, 2025
1 parent 07d3a9f commit e3ddac8
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 3 deletions.
71 changes: 71 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use cairo_vm::stdlib::collections::HashMap;
use cairo_vm::types::builtin_name::BuiltinName;
use serde::{Deserialize, Serialize};

use super::memory::*;

/// This struct holds the builtins used in a Cairo program.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct BuiltinSegments {
Expand Down Expand Up @@ -40,6 +42,75 @@ impl BuiltinSegments {
}
}

fn get_segment(&mut self, builtin_name: BuiltinName) -> &Option<MemorySegmentAddresses> {
match builtin_name {
BuiltinName::range_check => &self.range_check_bits_128,
BuiltinName::pedersen => &self.pedersen,
BuiltinName::ecdsa => &self.ecdsa,
BuiltinName::keccak => &self.keccak,
BuiltinName::bitwise => &self.bitwise,
BuiltinName::ec_op => &self.ec_op,
BuiltinName::poseidon => &self.poseidon,
BuiltinName::range_check96 => &self.range_check_bits_96,
BuiltinName::add_mod => &self.add_mod,
BuiltinName::mul_mod => &self.mul_mod,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => &None,
}
}

pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize {
match builtin_name {
BuiltinName::range_check => 1,
BuiltinName::pedersen => 3,
BuiltinName::ecdsa => 2,
BuiltinName::keccak => 16,
BuiltinName::bitwise => 5,
BuiltinName::ec_op => 7,
BuiltinName::poseidon => 6,
BuiltinName::range_check96 => 1,
BuiltinName::add_mod => 7,
BuiltinName::mul_mod => 7,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => 0,
}
}

pub fn fill_builtin_segment(
&mut self,
mut memory: MemoryBuilder,
builtin_name: BuiltinName,
) -> MemoryBuilder {
if let &Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}) = self.get_segment(builtin_name)
{
let initial_length = stop_ptr - begin_addr;
let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name);
assert!(initial_length % cells_per_instance == 0);
let num_instances = initial_length / cells_per_instance;
let nearest_power_of_two = num_instances.next_power_of_two();
for i in num_instances..nearest_power_of_two {
for j in 0..cells_per_instance {
let address_to_fill = (begin_addr + i * cells_per_instance + j) as u64;
let value_to_fill = memory.get((begin_addr + j) as u32);
memory.set(address_to_fill, value_to_fill);
}
}
self.add_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr: begin_addr + cells_per_instance * nearest_power_of_two,
}),
);
memory
} else {
memory
}
}

/// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses.
pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self {
let mut res = BuiltinSegments::default();
Expand Down
5 changes: 5 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ impl MemoryBuilder {
});
self.address_to_id[addr as usize] = res;
}

pub fn get(&self, addr: u32) -> MemoryValue {
self.memory.get(addr)
}

pub fn build(self) -> Memory {
self.memory
}
Expand Down
41 changes: 38 additions & 3 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::path::Path;
use bytemuck::{bytes_of_mut, Pod, Zeroable};
use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput};
use cairo_vm::stdlib::collections::HashMap;
use cairo_vm::types::builtin_name::BuiltinName;
use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry;
use json::PrivateInput;
use thiserror::Error;
Expand Down Expand Up @@ -88,12 +89,14 @@ pub fn adapt_to_stwo_input(
) -> Result<ProverInput, VmImportError> {
let (state_transitions, instruction_by_pc) =
StateTransitions::from_iter(trace_iter, &mut memory, dev_mode);
let mut builtins_segments = BuiltinSegments::from_memory_segments(memory_segments);
memory = builtins_segments.fill_builtin_segment(memory, BuiltinName::range_check);
Ok(ProverInput {
state_transitions,
instruction_by_pc,
memory: memory.build(),
public_memory_addresses,
builtins_segments: BuiltinSegments::from_memory_segments(memory_segments),
builtins_segments,
})
}

Expand Down Expand Up @@ -156,7 +159,10 @@ impl<R: Read> Iterator for MemoryEntryIter<'_, R> {
pub mod tests {
use std::path::PathBuf;

use cairo_vm::types::builtin_name::BuiltinName;

use super::*;
use crate::input::memory::{EncodedMemoryValueId, Memory};

pub fn large_cairo_input() -> ProverInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
Expand Down Expand Up @@ -187,6 +193,25 @@ pub mod tests {
)
}

pub fn verify_segment_is_full(
segment: Option<MemorySegmentAddresses>,
builtin_name: BuiltinName,
memory: &Memory,
) {
if let Some(segment) = segment {
let cells_per_instance =
BuiltinSegments::builtin_memory_cells_per_instance(builtin_name);
let segment_length = segment.stop_ptr - segment.begin_addr;
assert!(segment_length % cells_per_instance == 0);
let num_instances = segment_length / cells_per_instance;
assert!((num_instances & (num_instances - 1)) == 0); // num_instances is a power of 2.
assert!(segment.stop_ptr - 1 <= memory.address_to_id.len());
for address in segment.begin_addr..segment.stop_ptr {
assert!(memory.address_to_id[address] != EncodedMemoryValueId::default());
}
}
}

#[test]
#[cfg(feature = "slow-tests")]
fn test_read_from_large_files() {
Expand Down Expand Up @@ -270,7 +295,12 @@ pub mod tests {
assert_eq!(builtins_segments.range_check_bits_96, None);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((1715768, 1757348).into())
Some((1715768, 1781304).into())
);
verify_segment_is_full(
builtins_segments.range_check_bits_128,
BuiltinName::range_check,
&input.memory,
);
}

Expand Down Expand Up @@ -354,7 +384,12 @@ pub mod tests {
);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((6000, 6050).into())
Some((6000, 6064).into())
);
verify_segment_is_full(
builtins_segments.range_check_bits_128,
BuiltinName::range_check,
&input.memory,
);
}
}

0 comments on commit e3ddac8

Please sign in to comment.