From 2208e0ace44f75cfd524cab5b14bc0aca4dbfed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20Gonz=C3=A1lez=20Calder=C3=B3n?= Date: Fri, 4 Oct 2024 15:58:10 -0300 Subject: [PATCH] Implement trace dump --- crates/blockifier/Cargo.toml | 1 + .../execution/native/entry_point_execution.rs | 70 ++++++++++++++++++- .../blockifier/src/execution/native/utils.rs | 39 +++++++++-- 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/crates/blockifier/Cargo.toml b/crates/blockifier/Cargo.toml index 2bcd38b52a..e8339979a6 100644 --- a/crates/blockifier/Cargo.toml +++ b/crates/blockifier/Cargo.toml @@ -14,6 +14,7 @@ concurrency = [] jemalloc = ["dep:tikv-jemallocator"] testing = ["rand", "rstest"] use-sierra-emu = [] +with-trace-dump = ["cairo-native/with-trace-dump"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/crates/blockifier/src/execution/native/entry_point_execution.rs b/crates/blockifier/src/execution/native/entry_point_execution.rs index 5c7d05e28d..a54c6adb14 100644 --- a/crates/blockifier/src/execution/native/entry_point_execution.rs +++ b/crates/blockifier/src/execution/native/entry_point_execution.rs @@ -47,7 +47,75 @@ pub fn execute_entry_point_call( ); run_sierra_emu_executor(vm, function_id, call.clone()) } else { - run_native_executor(&contract_class.executor, function_id, call, syscall_handler) + #[cfg(feature = "with-trace-dump")] + let counter_value = { + use std::collections::HashMap; + use std::sync::atomic::AtomicUsize; + use std::sync::Mutex; + + use cairo_lang_sierra::program_registry::ProgramRegistry; + use cairo_native::runtime::trace_dump::TraceDump; + use cairo_native::types::TypeBuilder; + + // Since the library is statically linked, then dynamically loaded, each instance of + // `TRACE_DUMP` for each contract is separate (probably). That's why we need this + // getter and cannot use `cairo_native::runtime::TRACE_DUMP` directly. + let trace_dump = unsafe { + let fn_ptr = contract_class + .executor + .library + .get:: &'static Mutex>>( + b"get_trace_dump_ptr\0", + ) + .unwrap(); + + fn_ptr() + }; + let mut trace_dump = trace_dump.lock().unwrap(); + + static COUNTER: AtomicUsize = AtomicUsize::new(0); + let counter_value = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + trace_dump.insert( + u64::try_from(counter_value).unwrap(), + TraceDump::new( + ProgramRegistry::new(&contract_class.program).unwrap(), + |x, registry| x.layout(registry).unwrap(), + ), + ); + + // Set the active trace id. + let trace_id_ref = unsafe { + contract_class + .executor + .library + .get::(b"TRACE_DUMP__TRACE_ID\0") + .unwrap() + .try_as_raw_ptr() + .unwrap() + .cast::() + .as_mut() + .unwrap() + }; + *trace_id_ref = u64::try_from(counter_value).unwrap(); + + println!("Execution started for trace #{counter_value}."); + dbg!(trace_dump.keys().collect::>()); + counter_value + }; + + let x = run_native_executor( + &contract_class.executor, + function_id, + call, + syscall_handler, + #[cfg(feature = "with-trace-dump")] + counter_value, + ); + + #[cfg(feature = "with-trace-dump")] + println!("Execution finished for trace #{counter_value}."); + + x }; let execution_time = pre_execution_instant.elapsed().as_millis(); tracing::info!(time = execution_time, "native contract execution finished"); diff --git a/crates/blockifier/src/execution/native/utils.rs b/crates/blockifier/src/execution/native/utils.rs index fb7e35a674..a34d88b1e0 100644 --- a/crates/blockifier/src/execution/native/utils.rs +++ b/crates/blockifier/src/execution/native/utils.rs @@ -51,6 +51,7 @@ pub fn run_native_executor( function_id: &FunctionId, call: CallEntryPoint, mut syscall_handler: NativeSyscallHandler<'_>, + #[cfg(feature = "with-trace-dump")] trace_id: usize, ) -> EntryPointExecutionResult { let execution_result = native_executor.run( function_id, @@ -59,6 +60,33 @@ pub fn run_native_executor( &mut syscall_handler, ); + #[cfg(feature = "with-trace-dump")] + #[allow(warnings)] + { + use std::sync::Mutex; + + use cairo_native::runtime::trace_dump::TraceDump; + + let trace = serde_json::to_string_pretty(&{ + let trace_dump = unsafe { + let fn_ptr = native_executor + .library + .get:: &'static Mutex>>( + b"get_trace_dump_ptr\0", + ) + .unwrap(); + + fn_ptr() + }; + let mut trace_dump = trace_dump.lock().unwrap(); + + trace_dump.remove(&u64::try_from(trace_id).unwrap()).unwrap().trace + }) + .unwrap(); + std::fs::create_dir_all("traces/native/").unwrap(); + std::fs::write(&format!("traces/native/trace_{}.json", trace_id), trace).unwrap(); + } + let run_result = match execution_result { Ok(res) if res.failure_flag => Err(EntryPointExecutionError::NativeExecutionError { info: if !res.return_values.is_empty() { @@ -100,6 +128,9 @@ pub fn run_sierra_emu_executor( std::fs::create_dir_all("traces/emu/").unwrap(); std::fs::write(format!("traces/emu/trace_{}.json", counter_value), trace).unwrap(); + std::fs::write(format!("traces/program_{}.sierra", counter_value), format!("{}", vm.program)) + .unwrap(); + if execution_result.failure_flag { Err(EntryPointExecutionError::NativeExecutionError { info: if !execution_result.return_values.is_empty() { @@ -129,8 +160,8 @@ fn create_callinfo( syscall_handler: NativeSyscallHandler<'_>, ) -> Result { let gas_consumed = { - let low: u64 = run_result.remaining_gas.try_into().unwrap(); - let high: u64 = (run_result.remaining_gas >> 64).try_into().unwrap(); + let low = u64::try_from(run_result.remaining_gas & u128::from(u64::MAX)).unwrap(); + let high = u64::try_from(run_result.remaining_gas >> 64).unwrap(); if high != 0 { return Err(EntryPointExecutionError::NativeExecutionError { info: "Overflow: gas consumed bigger than 64 bit".into(), @@ -169,8 +200,8 @@ pub fn create_callinfo_emu( accessed_storage_keys: HashSet, ) -> Result { let gas_consumed = { - let low: u64 = run_result.remaining_gas.try_into().unwrap(); - let high: u64 = (run_result.remaining_gas >> 64).try_into().unwrap(); + let low = u64::try_from(run_result.remaining_gas & u128::from(u64::MAX)).unwrap(); + let high = u64::try_from(run_result.remaining_gas >> 64).unwrap(); if high != 0 { return Err(EntryPointExecutionError::NativeExecutionError { info: "Overflow: gas consumed bigger than 64 bit".into(),