From 32415d42db48baf2ba323c1c0ed8944e62487119 Mon Sep 17 00:00:00 2001 From: Stav Beno Date: Mon, 23 Dec 2024 16:07:19 +0200 Subject: [PATCH] unify adapter functionality --- .../crates/adapted_prover/src/main.rs | 4 +- .../crates/prover/src/input/mod.rs | 1 + .../crates/prover/src/input/plain.rs | 50 ++++++++++--------- .../crates/prover/src/input/vm_import/mod.rs | 47 +++++++++++------ .../crates/run_and_prove/src/main.rs | 4 +- .../crates/vm_runner/src/main.rs | 4 +- 6 files changed, 64 insertions(+), 46 deletions(-) diff --git a/stwo_cairo_prover/crates/adapted_prover/src/main.rs b/stwo_cairo_prover/crates/adapted_prover/src/main.rs index 55cdb1fa..2fada1eb 100644 --- a/stwo_cairo_prover/crates/adapted_prover/src/main.rs +++ b/stwo_cairo_prover/crates/adapted_prover/src/main.rs @@ -4,7 +4,7 @@ use std::process::ExitCode; use clap::Parser; use stwo_cairo_prover::cairo_air::air::CairoProof; use stwo_cairo_prover::cairo_air::prove_cairo; -use stwo_cairo_prover::input::vm_import::{import_from_vm_output, VmImportError}; +use stwo_cairo_prover::input::vm_import::{adapt_vm_output, VmImportError}; use stwo_cairo_prover::input::CairoInput; use stwo_cairo_utils::binary_utils::run_binary; use stwo_prover::core::prover::ProvingError; @@ -56,7 +56,7 @@ fn run(args: impl Iterator) -> Result CairoInput { - let _span = tracing::info_span!("input_from_finished_runner").entered(); - let program_len = runner.get_program().iter_data().count(); - let memory = runner +pub fn adapt_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoInput { + let _span = tracing::info_span!("adapt_finished_runner").entered(); + let memory_iter = runner .relocated_memory .iter() .enumerate() @@ -72,25 +71,28 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI }) }); - let memory_segments = &runner + let trace = runner.relocated_trace.clone().unwrap(); + let trace_iter = trace.iter().map(|t| t.clone().into()); + + let public_input = runner .get_air_public_input() - .expect("Unable to get public input from the runner") - .memory_segments; - let builtins_segments = BuiltinSegments::from_memory_segments(memory_segments); + .expect("Unable to get public input from the runner"); - let trace = runner.relocated_trace.unwrap(); - let trace = trace.iter().map(|t| t.clone().into()); + let memory_segments = &public_input.memory_segments; - let memory_config = MemoryConfig::default(); - let mut memory = MemoryBuilder::from_iter(memory_config, memory); - let state_transitions = StateTransitions::from_iter(trace, &mut memory, dev_mode); + let public_memory_addresses = public_input + .public_memory + .into_iter() + .map(|s| s.address as u32) + .collect::>(); // TODO(spapini): Add output builtin to public memory. - let public_memory_addresses = (0..(program_len as u32)).collect_vec(); - CairoInput { - state_transitions, - memory: memory.build(), + let cairo_input = adapt_to_stwo_input( + trace_iter, + MemoryBuilder::from_iter(MemoryConfig::default(), memory_iter), public_memory_addresses, - builtins_segments, - } + memory_segments, + dev_mode, + ); + cairo_input.unwrap() } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index e50d3d0c..f773241b 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -1,20 +1,21 @@ mod json; +use std::collections::HashMap; use std::io::Read; use std::path::Path; use bytemuck::{bytes_of_mut, Pod, Zeroable}; -use cairo_vm::air_public_input::PublicInput; +use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput}; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; use json::PrivateInput; use thiserror::Error; use tracing::{span, Level}; +use super::builtin_segments::BuiltinSegments; use super::memory::MemoryConfig; use super::state_transitions::StateTransitions; use super::CairoInput; use crate::input::memory::MemoryBuilder; -use crate::input::BuiltinSegments; #[derive(Debug, Error)] pub enum VmImportError { @@ -27,12 +28,13 @@ pub enum VmImportError { } // TODO(Ohad): remove dev_mode after adding the rest of the instructions. -pub fn import_from_vm_output( +/// Adapts the VM's output files to the Cairo input of the prover. +pub fn adapt_vm_output( public_input_json: &Path, private_input_json: &Path, - dev_mode: bool, + dev_mod: bool, ) -> Result { - let _span = span!(Level::INFO, "import_from_vm_output").entered(); + let _span = span!(Level::INFO, "adapt_vm_output").entered(); let public_input_string = std::fs::read_to_string(public_input_json)?; let public_input: PublicInput<'_> = sonic_rs::from_str(&public_input_string)?; let private_input: PrivateInput = @@ -45,7 +47,6 @@ pub fn import_from_vm_output( .max() .ok_or(VmImportError::NoMemorySegments)?; assert!(end_addr < (1 << 32)); - let memory_config = MemoryConfig::default(); let memory_path = private_input_json .parent() @@ -56,25 +57,39 @@ pub fn import_from_vm_output( .unwrap() .join(&private_input.trace_path); - let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?); let mut memory_file = std::io::BufReader::new(std::fs::File::open(memory_path)?); - let mut memory = MemoryBuilder::from_iter(memory_config, MemEntryIter(&mut memory_file)); - let state_transitions = - StateTransitions::from_iter(TraceIter(&mut trace_file), &mut memory, dev_mode); + let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?); let public_memory_addresses = public_input .public_memory .iter() .map(|entry| entry.address as u32) .collect(); + adapt_to_stwo_input( + TraceIter(&mut trace_file), + MemoryBuilder::from_iter(MemoryConfig::default(), MemEntryIter(&mut memory_file)), + public_memory_addresses, + &public_input.memory_segments, + dev_mod, + ) +} - let builtins_segments = BuiltinSegments::from_memory_segments(&public_input.memory_segments); - +// TODO(Ohad): remove dev_mode after adding the rest of the opcodes. +/// Creates Cairo input for Stwo, utilized by: +/// - `adapt_vm_output` in the prover. +/// - `adapt_finished_runner` in the validator. +pub fn adapt_to_stwo_input( + trace_iter: impl Iterator, + mut memory: MemoryBuilder, + public_memory_addresses: Vec, + memory_segments: &HashMap<&str, MemorySegmentAddresses>, + dev_mode: bool, +) -> Result { Ok(CairoInput { - state_transitions, + state_transitions: StateTransitions::from_iter(trace_iter, &mut memory, dev_mode), memory: memory.build(), public_memory_addresses, - builtins_segments, + builtins_segments: BuiltinSegments::from_memory_segments(memory_segments), }) } @@ -144,7 +159,7 @@ pub mod tests { let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); d.push("test_data/large_cairo_input"); - import_from_vm_output( + adapt_vm_output( d.join("pub.json").as_path(), d.join("priv.json").as_path(), false, @@ -158,7 +173,7 @@ pub mod tests { pub fn small_cairo_input() -> CairoInput { let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); d.push("test_data/small_cairo_input"); - import_from_vm_output( + adapt_vm_output( d.join("pub.json").as_path(), d.join("priv.json").as_path(), false, diff --git a/stwo_cairo_prover/crates/run_and_prove/src/main.rs b/stwo_cairo_prover/crates/run_and_prove/src/main.rs index 76168b4f..2bb5c0c7 100644 --- a/stwo_cairo_prover/crates/run_and_prove/src/main.rs +++ b/stwo_cairo_prover/crates/run_and_prove/src/main.rs @@ -4,7 +4,7 @@ use std::process::ExitCode; use clap::Parser; use stwo_cairo_prover::cairo_air::air::CairoProof; use stwo_cairo_prover::cairo_air::prove_cairo; -use stwo_cairo_prover::input::plain::input_from_finished_runner; +use stwo_cairo_prover::input::plain::adapt_finished_runner; use stwo_cairo_prover::input::vm_import::VmImportError; use stwo_cairo_utils::binary_utils::run_binary; use stwo_cairo_utils::vm_utils::{run_vm, VmArgs, VmError}; @@ -57,7 +57,7 @@ fn run(args: impl Iterator) -> Result) -> Result { let _span = span!(Level::INFO, "run").entered(); let args = Args::try_parse_from(args)?; let cairo_runner = run_vm(&args.vm_args)?; - let cairo_input = input_from_finished_runner(cairo_runner, false); + let cairo_input = adapt_finished_runner(cairo_runner, false); let execution_resources = &cairo_input.state_transitions.casm_states_by_opcode.counts(); std::fs::write(