From e88854cf484eb8d3ad2ed7a2c7494a0f821dd503 Mon Sep 17 00:00:00 2001 From: ohad nir Date: Thu, 9 Jan 2025 11:24:46 +0200 Subject: [PATCH] fill memory segments of builtins to the next power of 2 instances --- .../prover/src/input/builtin_segments.rs | 76 +++++++++++++++++++ .../crates/prover/src/input/memory.rs | 5 ++ .../crates/prover/src/input/vm_import/mod.rs | 61 ++++++++++++++- 3 files changed, 139 insertions(+), 3 deletions(-) diff --git a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs index 488cd72e..b93b692d 100644 --- a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs +++ b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs @@ -1,8 +1,11 @@ use cairo_vm::air_public_input::MemorySegmentAddresses; use cairo_vm::stdlib::collections::HashMap; use cairo_vm::types::builtin_name::BuiltinName; +use num_traits::Euclid; use serde::{Deserialize, Serialize}; +use super::memory::MemoryBuilder; + /// This struct holds the builtins used in a Cairo program. #[derive(Debug, Default, Serialize, Deserialize)] pub struct BuiltinSegments { @@ -19,6 +22,7 @@ pub struct BuiltinSegments { } impl BuiltinSegments { + /// Adds a segment to the builtin segments. pub fn add_segment( &mut self, builtin_name: BuiltinName, @@ -40,6 +44,78 @@ impl BuiltinSegments { } } + /// Returns the segment for a given builtin name. + fn get_segment(&mut self, builtin_name: BuiltinName) -> &Option { + match builtin_name { + BuiltinName::range_check => &self.range_check_bits_128, + BuiltinName::pedersen => &self.pedersen, + BuiltinName::ecdsa => &self.ecdsa, + BuiltinName::keccak => &self.keccak, + BuiltinName::bitwise => &self.bitwise, + BuiltinName::ec_op => &self.ec_op, + BuiltinName::poseidon => &self.poseidon, + BuiltinName::range_check96 => &self.range_check_bits_96, + BuiltinName::add_mod => &self.add_mod, + BuiltinName::mul_mod => &self.mul_mod, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => &None, + } + } + + /// Returns the number of memory cells per instance for a given builtin name. + pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize { + match builtin_name { + BuiltinName::range_check => 1, + BuiltinName::pedersen => 3, + BuiltinName::ecdsa => 2, + BuiltinName::keccak => 16, + BuiltinName::bitwise => 5, + BuiltinName::ec_op => 7, + BuiltinName::poseidon => 6, + BuiltinName::range_check96 => 1, + BuiltinName::add_mod => 7, + BuiltinName::mul_mod => 7, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => 0, + } + } + + /// Pads a builtin segment with copies of it's first instance. + /// The segment is padded to the nearest power of two number of instances. + pub fn fill_builtin_segment( + &mut self, + mut memory: MemoryBuilder, + builtin_name: BuiltinName, + ) -> MemoryBuilder { + let &Some(MemorySegmentAddresses { + begin_addr, + stop_ptr, + }) = self.get_segment(builtin_name) + else { + return memory; + }; + let initial_length = stop_ptr - begin_addr; + let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name); + let (num_instances, remainder) = initial_length.div_rem_euclid(&cells_per_instance); + assert!(remainder == 0); + let nearest_power_of_two = num_instances.next_power_of_two(); + let mut address_to_fill = (begin_addr + num_instances * cells_per_instance) as u64; + for _ in num_instances..nearest_power_of_two { + for j in 0..cells_per_instance { + memory.copy_value((begin_addr + j) as u32, address_to_fill as u32); + address_to_fill += 1; + } + } + self.add_segment( + builtin_name, + Some(MemorySegmentAddresses { + begin_addr, + stop_ptr: begin_addr + cells_per_instance * nearest_power_of_two, + }), + ); + memory + } + /// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses. pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self { let mut res = BuiltinSegments::default(); diff --git a/stwo_cairo_prover/crates/prover/src/input/memory.rs b/stwo_cairo_prover/crates/prover/src/input/memory.rs index 5d814b23..2b6a72d2 100644 --- a/stwo_cairo_prover/crates/prover/src/input/memory.rs +++ b/stwo_cairo_prover/crates/prover/src/input/memory.rs @@ -168,6 +168,11 @@ impl MemoryBuilder { }); self.address_to_id[addr as usize] = res; } + + pub fn copy_value(&mut self, src_addr: u32, dst_addr: u32) { + self.set(dst_addr as u64, self.memory.get(src_addr)); + } + pub fn build(self) -> Memory { self.memory } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index f98842f3..7374bff9 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -6,6 +6,7 @@ use std::path::Path; use bytemuck::{bytes_of_mut, Pod, Zeroable}; use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput}; use cairo_vm::stdlib::collections::HashMap; +use cairo_vm::types::builtin_name::BuiltinName; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; use json::PrivateInput; use thiserror::Error; @@ -88,12 +89,14 @@ pub fn adapt_to_stwo_input( ) -> Result { let (state_transitions, instruction_by_pc) = StateTransitions::from_iter(trace_iter, &mut memory, dev_mode); + let mut builtins_segments = BuiltinSegments::from_memory_segments(memory_segments); + memory = builtins_segments.fill_builtin_segment(memory, BuiltinName::range_check); Ok(ProverInput { state_transitions, instruction_by_pc, memory: memory.build(), public_memory_addresses, - builtins_segments: BuiltinSegments::from_memory_segments(memory_segments), + builtins_segments, }) } @@ -156,7 +159,11 @@ impl Iterator for MemoryEntryIter<'_, R> { pub mod tests { use std::path::PathBuf; + use cairo_vm::types::builtin_name::BuiltinName; + use num_traits::Euclid; + use super::*; + use crate::input::memory::Memory; pub fn large_cairo_input() -> ProverInput { let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); @@ -187,6 +194,42 @@ pub mod tests { ) } + /// Verifies that the builtin segment is padded with copies of the first instance to the next + /// power of 2 instances. + pub fn verify_segment_is_padded( + segment: &Option, + builtin_name: BuiltinName, + memory: &Memory, + original_stop_ptr: usize, + ) { + if let Some(segment) = segment { + let cells_per_instance = + BuiltinSegments::builtin_memory_cells_per_instance(builtin_name); + let segment_length = segment.stop_ptr - segment.begin_addr; + let (num_instances, remainder) = segment_length.div_rem_euclid(&cells_per_instance); + assert!(remainder == 0); + // assert that num_instances is a power of 2. + assert!((num_instances & (num_instances - 1)) == 0); + + let original_segment_length = original_stop_ptr - segment.begin_addr; + let (original_num_instances, remainder) = + original_segment_length.div_rem_euclid(&cells_per_instance); + assert!(remainder == 0); + assert!(original_num_instances * 2 > num_instances); // the next power of 2. + + assert!(segment.stop_ptr - 1 <= memory.address_to_id.len()); + for instance in original_num_instances..num_instances { + for j in 0..cells_per_instance { + let address = segment.begin_addr + instance * cells_per_instance + j; + assert!( + memory.address_to_id[address] + == memory.address_to_id[segment.begin_addr + j] + ); + } + } + } + } + #[test] #[cfg(feature = "slow-tests")] fn test_read_from_large_files() { @@ -270,7 +313,13 @@ pub mod tests { assert_eq!(builtins_segments.range_check_bits_96, None); assert_eq!( builtins_segments.range_check_bits_128, - Some((1715768, 1757348).into()) + Some((1715768, 1781304).into()) + ); + verify_segment_is_padded( + &builtins_segments.range_check_bits_128, + BuiltinName::range_check, + &input.memory, + 1757348, ); } @@ -354,7 +403,13 @@ pub mod tests { ); assert_eq!( builtins_segments.range_check_bits_128, - Some((6000, 6050).into()) + Some((6000, 6064).into()) + ); + verify_segment_is_padded( + &builtins_segments.range_check_bits_128, + BuiltinName::range_check, + &input.memory, + 6050, ); } }