diff --git a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs index 488cd72e..42874c6d 100644 --- a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs +++ b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs @@ -3,6 +3,8 @@ use cairo_vm::stdlib::collections::HashMap; use cairo_vm::types::builtin_name::BuiltinName; use serde::{Deserialize, Serialize}; +use super::memory::*; + /// This struct holds the builtins used in a Cairo program. #[derive(Debug, Default, Serialize, Deserialize)] pub struct BuiltinSegments { @@ -40,6 +42,75 @@ impl BuiltinSegments { } } + fn get_segment(&mut self, builtin_name: BuiltinName) -> &Option { + match builtin_name { + BuiltinName::range_check => &self.range_check_bits_128, + BuiltinName::pedersen => &self.pedersen, + BuiltinName::ecdsa => &self.ecdsa, + BuiltinName::keccak => &self.keccak, + BuiltinName::bitwise => &self.bitwise, + BuiltinName::ec_op => &self.ec_op, + BuiltinName::poseidon => &self.poseidon, + BuiltinName::range_check96 => &self.range_check_bits_96, + BuiltinName::add_mod => &self.add_mod, + BuiltinName::mul_mod => &self.mul_mod, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => &None, + } + } + + pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize { + match builtin_name { + BuiltinName::range_check => 1, + BuiltinName::pedersen => 3, + BuiltinName::ecdsa => 2, + BuiltinName::keccak => 16, + BuiltinName::bitwise => 5, + BuiltinName::ec_op => 7, + BuiltinName::poseidon => 6, + BuiltinName::range_check96 => 1, + BuiltinName::add_mod => 7, + BuiltinName::mul_mod => 7, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => 0, + } + } + + pub fn fill_builtin_segment( + &mut self, + mut memory: MemoryBuilder, + builtin_name: BuiltinName, + ) -> MemoryBuilder { + if let &Some(MemorySegmentAddresses { + begin_addr, + stop_ptr, + }) = self.get_segment(builtin_name) + { + let initial_length = stop_ptr - begin_addr; + let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name); + assert!(initial_length % cells_per_instance == 0); + let num_instances = initial_length / cells_per_instance; + let nearest_power_of_two = num_instances.next_power_of_two(); + for i in num_instances..nearest_power_of_two { + for j in 0..cells_per_instance { + let address_to_fill = (begin_addr + i * cells_per_instance + j) as u64; + let value_to_fill = memory.get((begin_addr + j) as u32); + memory.set(address_to_fill, value_to_fill); + } + } + self.add_segment( + builtin_name, + Some(MemorySegmentAddresses { + begin_addr, + stop_ptr: begin_addr + cells_per_instance * nearest_power_of_two, + }), + ); + memory + } else { + memory + } + } + /// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses. pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self { let mut res = BuiltinSegments::default(); diff --git a/stwo_cairo_prover/crates/prover/src/input/memory.rs b/stwo_cairo_prover/crates/prover/src/input/memory.rs index 5d814b23..385ecc62 100644 --- a/stwo_cairo_prover/crates/prover/src/input/memory.rs +++ b/stwo_cairo_prover/crates/prover/src/input/memory.rs @@ -168,6 +168,11 @@ impl MemoryBuilder { }); self.address_to_id[addr as usize] = res; } + + pub fn get(&self, addr: u32) -> MemoryValue { + self.memory.get(addr) + } + pub fn build(self) -> Memory { self.memory } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index f98842f3..7e39c3e1 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -6,6 +6,7 @@ use std::path::Path; use bytemuck::{bytes_of_mut, Pod, Zeroable}; use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput}; use cairo_vm::stdlib::collections::HashMap; +use cairo_vm::types::builtin_name::BuiltinName; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; use json::PrivateInput; use thiserror::Error; @@ -88,12 +89,14 @@ pub fn adapt_to_stwo_input( ) -> Result { let (state_transitions, instruction_by_pc) = StateTransitions::from_iter(trace_iter, &mut memory, dev_mode); + let mut builtins_segments = BuiltinSegments::from_memory_segments(memory_segments); + memory = builtins_segments.fill_builtin_segment(memory, BuiltinName::range_check); Ok(ProverInput { state_transitions, instruction_by_pc, memory: memory.build(), public_memory_addresses, - builtins_segments: BuiltinSegments::from_memory_segments(memory_segments), + builtins_segments, }) } @@ -156,7 +159,10 @@ impl Iterator for MemoryEntryIter<'_, R> { pub mod tests { use std::path::PathBuf; + use cairo_vm::types::builtin_name::BuiltinName; + use super::*; + use crate::input::memory::{EncodedMemoryValueId, Memory}; pub fn large_cairo_input() -> ProverInput { let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); @@ -187,6 +193,25 @@ pub mod tests { ) } + pub fn verify_segment_is_full( + segment: Option, + builtin_name: BuiltinName, + memory: &Memory, + ) { + if let Some(segment) = segment { + let cells_per_instance = + BuiltinSegments::builtin_memory_cells_per_instance(builtin_name); + let segment_length = segment.stop_ptr - segment.begin_addr; + assert!(segment_length % cells_per_instance == 0); + let num_instances = segment_length / cells_per_instance; + assert!((num_instances & (num_instances - 1)) == 0); // num_instances is a power of 2. + assert!(segment.stop_ptr - 1 <= memory.address_to_id.len()); + for address in segment.begin_addr..segment.stop_ptr { + assert!(memory.address_to_id[address] != EncodedMemoryValueId::default()); + } + } + } + #[test] #[cfg(feature = "slow-tests")] fn test_read_from_large_files() { @@ -270,7 +295,12 @@ pub mod tests { assert_eq!(builtins_segments.range_check_bits_96, None); assert_eq!( builtins_segments.range_check_bits_128, - Some((1715768, 1757348).into()) + Some((1715768, 1781304).into()) + ); + verify_segment_is_full( + builtins_segments.range_check_bits_128, + BuiltinName::range_check, + &input.memory, ); } @@ -354,7 +384,12 @@ pub mod tests { ); assert_eq!( builtins_segments.range_check_bits_128, - Some((6000, 6050).into()) + Some((6000, 6064).into()) + ); + verify_segment_is_full( + builtins_segments.range_check_bits_128, + BuiltinName::range_check, + &input.memory, ); } }