Skip to content

Commit

Permalink
unify adapter functionality (#274)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo-cairo/274)
<!-- Reviewable:end -->
  • Loading branch information
Stavbe authored Dec 31, 2024
1 parent 74f4e37 commit 778bb22
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 46 deletions.
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/adapted_prover/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::process::ExitCode;
use clap::Parser;
use stwo_cairo_prover::cairo_air::air::CairoProof;
use stwo_cairo_prover::cairo_air::prove_cairo;
use stwo_cairo_prover::input::vm_import::{import_from_vm_output, VmImportError};
use stwo_cairo_prover::input::vm_import::{adapt_vm_output, VmImportError};
use stwo_cairo_prover::input::CairoInput;
use stwo_cairo_utils::binary_utils::run_binary;
use stwo_prover::core::prover::ProvingError;
Expand Down Expand Up @@ -56,7 +56,7 @@ fn run(args: impl Iterator<Item = String>) -> Result<CairoProof<Blake2sMerkleHas
let args = Args::try_parse_from(args)?;

let vm_output: CairoInput =
import_from_vm_output(args.pub_json.as_path(), args.priv_json.as_path(), true)?;
adapt_vm_output(args.pub_json.as_path(), args.priv_json.as_path(), true)?;

let casm_states_by_opcode_count = &vm_output.state_transitions.casm_states_by_opcode.counts();
log::info!("Casm states by opcode count: {casm_states_by_opcode_count:?}");
Expand Down
3 changes: 2 additions & 1 deletion stwo_cairo_prover/crates/prover/src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ pub mod vm_import;

pub const N_REGISTERS: usize = 3;

// Externally provided inputs.
// TODO(Stav): rename to StwoInput.
/// Externally provided inputs for the Stwo prover.
#[derive(Debug)]
pub struct CairoInput {
pub state_transitions: StateTransitions,
Expand Down
54 changes: 30 additions & 24 deletions stwo_cairo_prover/crates/prover/src/input/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ use cairo_vm::vm::runners::cairo_runner::CairoRunner;
use itertools::Itertools;

use super::memory::{MemoryBuilder, MemoryConfig};
use super::state_transitions::StateTransitions;
use super::vm_import::MemoryEntry;
use super::{BuiltinSegments, CairoInput};
use super::vm_import::{adapt_to_stwo_input, MemoryEntry};
use super::CairoInput;

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Translates a plain casm into a CairoInput by running the program and extracting the memory and
Expand Down Expand Up @@ -50,18 +49,20 @@ pub fn input_from_plain_casm(
)
.expect("Run failed");
runner.relocate(true).unwrap();
input_from_finished_runner(runner, dev_mode)
adapt_finished_runner(runner, dev_mode)
}

// TODO(yuval): consider returning a result instead of panicking.
// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Assumes memory and trace are already relocated. Otherwise panics.
/// Translates a CairoRunner that finished its run into a CairoInput by extracting the relevant
/// input to the adapter.
/// When dev mod is enabled, the opcodes generated from the plain casm will be mapped to the generic
/// component only.
pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoInput {
let _span = tracing::info_span!("input_from_finished_runner").entered();
let program_len = runner.get_program().iter_data().count();
let memory = runner
/// # Panics
/// - if the memory or the trace are not relocated.
pub fn adapt_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoInput {
let _span = tracing::info_span!("adapt_finished_runner").entered();
let memory_iter = runner
.relocated_memory
.iter()
.enumerate()
Expand All @@ -72,25 +73,30 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI
})
});

let memory_segments = &runner
let public_input = runner
.get_air_public_input()
.expect("Unable to get public input from the runner")
.memory_segments;
let builtins_segments = BuiltinSegments::from_memory_segments(memory_segments);
.expect("Unable to get public input from the runner");

let trace = runner.relocated_trace.unwrap();
let trace = trace.iter().map(|t| t.clone().into());
let trace_iter = match runner.relocated_trace {
Some(ref trace) => trace.iter().map(|t| t.clone().into()),
None => panic!("Trace is not relocated"),
};

let memory_config = MemoryConfig::default();
let mut memory = MemoryBuilder::from_iter(memory_config, memory);
let state_transitions = StateTransitions::from_iter(trace, &mut memory, dev_mode);
let memory_segments = &public_input.memory_segments;

let public_memory_addresses = public_input
.public_memory
.iter()
.map(|s| s.address as u32)
.collect_vec();

// TODO(spapini): Add output builtin to public memory.
let public_memory_addresses = (0..(program_len as u32)).collect_vec();
CairoInput {
state_transitions,
memory: memory.build(),
adapt_to_stwo_input(
trace_iter,
MemoryBuilder::from_iter(MemoryConfig::default(), memory_iter),
public_memory_addresses,
builtins_segments,
}
memory_segments,
dev_mode,
)
.unwrap()
}
45 changes: 30 additions & 15 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
mod json;

use std::collections::HashMap;
use std::io::Read;
use std::path::Path;

use bytemuck::{bytes_of_mut, Pod, Zeroable};
use cairo_vm::air_public_input::PublicInput;
use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput};
use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry;
use json::PrivateInput;
use thiserror::Error;
use tracing::{span, Level};

use super::builtin_segments::BuiltinSegments;
use super::memory::MemoryConfig;
use super::state_transitions::StateTransitions;
use super::CairoInput;
use crate::input::memory::MemoryBuilder;
use crate::input::BuiltinSegments;

#[derive(Debug, Error)]
pub enum VmImportError {
Expand All @@ -27,12 +28,13 @@ pub enum VmImportError {
}

// TODO(Ohad): remove dev_mode after adding the rest of the instructions.
pub fn import_from_vm_output(
/// Adapts the VM's output files to the Cairo input of the prover.
pub fn adapt_vm_output(
public_input_json: &Path,
private_input_json: &Path,
dev_mode: bool,
) -> Result<CairoInput, VmImportError> {
let _span = span!(Level::INFO, "import_from_vm_output").entered();
let _span = span!(Level::INFO, "adapt_vm_output").entered();
let public_input_string = std::fs::read_to_string(public_input_json)?;
let public_input: PublicInput<'_> = sonic_rs::from_str(&public_input_string)?;
let private_input: PrivateInput =
Expand All @@ -45,7 +47,6 @@ pub fn import_from_vm_output(
.max()
.ok_or(VmImportError::NoMemorySegments)?;
assert!(end_addr < (1 << 32));
let memory_config = MemoryConfig::default();

let memory_path = private_input_json
.parent()
Expand All @@ -56,25 +57,39 @@ pub fn import_from_vm_output(
.unwrap()
.join(&private_input.trace_path);

let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?);
let mut memory_file = std::io::BufReader::new(std::fs::File::open(memory_path)?);
let mut memory = MemoryBuilder::from_iter(memory_config, MemoryEntryIter(&mut memory_file));
let state_transitions =
StateTransitions::from_iter(TraceIter(&mut trace_file), &mut memory, dev_mode);
let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?);

let public_memory_addresses = public_input
.public_memory
.iter()
.map(|entry| entry.address as u32)
.collect();
adapt_to_stwo_input(
TraceIter(&mut trace_file),
MemoryBuilder::from_iter(MemoryConfig::default(), MemoryEntryIter(&mut memory_file)),
public_memory_addresses,
&public_input.memory_segments,
dev_mode,
)
}

let builtins_segments = BuiltinSegments::from_memory_segments(&public_input.memory_segments);

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Creates Cairo input for Stwo, utilized by:
/// - `adapt_vm_output` in the prover.
/// - `adapt_finished_runner` in the validator.
pub fn adapt_to_stwo_input(
trace_iter: impl Iterator<Item = TraceEntry>,
mut memory: MemoryBuilder,
public_memory_addresses: Vec<u32>,
memory_segments: &HashMap<&str, MemorySegmentAddresses>,
dev_mode: bool,
) -> Result<CairoInput, VmImportError> {
Ok(CairoInput {
state_transitions,
state_transitions: StateTransitions::from_iter(trace_iter, &mut memory, dev_mode),
memory: memory.build(),
public_memory_addresses,
builtins_segments,
builtins_segments: BuiltinSegments::from_memory_segments(memory_segments),
})
}

Expand Down Expand Up @@ -143,7 +158,7 @@ pub mod tests {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/test_read_from_large_files");

import_from_vm_output(
adapt_vm_output(
d.join("pub.json").as_path(),
d.join("priv.json").as_path(),
false,
Expand All @@ -157,7 +172,7 @@ pub mod tests {
pub fn small_cairo_input() -> CairoInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/test_read_from_small_files");
import_from_vm_output(
adapt_vm_output(
d.join("pub.json").as_path(),
d.join("priv.json").as_path(),
false,
Expand Down
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/run_and_prove/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::process::ExitCode;
use clap::Parser;
use stwo_cairo_prover::cairo_air::air::CairoProof;
use stwo_cairo_prover::cairo_air::prove_cairo;
use stwo_cairo_prover::input::plain::input_from_finished_runner;
use stwo_cairo_prover::input::plain::adapt_finished_runner;
use stwo_cairo_prover::input::vm_import::VmImportError;
use stwo_cairo_utils::binary_utils::run_binary;
use stwo_cairo_utils::vm_utils::{run_vm, VmArgs, VmError};
Expand Down Expand Up @@ -57,7 +57,7 @@ fn run(args: impl Iterator<Item = String>) -> Result<CairoProof<Blake2sMerkleHas
let _span = span!(Level::INFO, "run").entered();
let args = Args::try_parse_from(args)?;
let cairo_runner = run_vm(&args.vm_args)?;
let cairo_input = input_from_finished_runner(cairo_runner, false);
let cairo_input = adapt_finished_runner(cairo_runner, false);

let casm_states_by_opcode_count = &cairo_input.state_transitions.casm_states_by_opcode.counts();
log::info!("Casm states by opcode count: {casm_states_by_opcode_count:?}");
Expand Down
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/vm_runner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::PathBuf;
use std::process::ExitCode;

use clap::Parser;
use stwo_cairo_prover::input::plain::input_from_finished_runner;
use stwo_cairo_prover::input::plain::adapt_finished_runner;
use stwo_cairo_prover::input::CairoInput;
use stwo_cairo_utils::binary_utils::run_binary;
use stwo_cairo_utils::vm_utils::{run_vm, VmArgs, VmError};
Expand Down Expand Up @@ -45,7 +45,7 @@ fn run(args: impl Iterator<Item = String>) -> Result<CairoInput, Error> {
let _span = span!(Level::INFO, "run").entered();
let args = Args::try_parse_from(args)?;
let cairo_runner = run_vm(&args.vm_args)?;
let cairo_input = input_from_finished_runner(cairo_runner, false);
let cairo_input = adapt_finished_runner(cairo_runner, false);

let execution_resources = &cairo_input.state_transitions.casm_states_by_opcode.counts();
std::fs::write(
Expand Down

0 comments on commit 778bb22

Please sign in to comment.