Skip to content

Commit

Permalink
change CairoInput to StwoInput
Browse files Browse the repository at this point in the history
  • Loading branch information
Stavbe committed Dec 31, 2024
1 parent 778bb22 commit 8d8521f
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 23 deletions.
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/adapted_prover/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use clap::Parser;
use stwo_cairo_prover::cairo_air::air::CairoProof;
use stwo_cairo_prover::cairo_air::prove_cairo;
use stwo_cairo_prover::input::vm_import::{adapt_vm_output, VmImportError};
use stwo_cairo_prover::input::CairoInput;
use stwo_cairo_prover::input::ProverInput;
use stwo_cairo_utils::binary_utils::run_binary;
use stwo_prover::core::prover::ProvingError;
use stwo_prover::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};
Expand Down Expand Up @@ -55,7 +55,7 @@ fn run(args: impl Iterator<Item = String>) -> Result<CairoProof<Blake2sMerkleHas
let _span = span!(Level::INFO, "run").entered();
let args = Args::try_parse_from(args)?;

let vm_output: CairoInput =
let vm_output: ProverInput =
adapt_vm_output(args.pub_json.as_path(), args.priv_json.as_path(), true)?;

let casm_states_by_opcode_count = &vm_output.state_transitions.casm_states_by_opcode.counts();
Expand Down
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/prover/src/cairo_air/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::components::range_check_vector::{
};
use crate::components::verify_instruction;
use crate::felt::split_f252;
use crate::input::CairoInput;
use crate::input::ProverInput;
use crate::relations;

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -166,7 +166,7 @@ pub struct CairoClaimGenerator {
// ...
}
impl CairoClaimGenerator {
pub fn new(input: CairoInput) -> Self {
pub fn new(input: ProverInput) -> Self {
let initial_state = input.state_transitions.initial_state;
let final_state = input.state_transitions.final_state;
let opcodes = OpcodesClaimGenerator::new(input.state_transitions);
Expand Down
8 changes: 4 additions & 4 deletions stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use stwo_prover::core::prover::{prove, verify, ProvingError, VerificationError};
use thiserror::Error;
use tracing::{span, Level};

use crate::input::CairoInput;
use crate::input::ProverInput;

const LOG_MAX_ROWS: u32 = 22;

Expand All @@ -27,7 +27,7 @@ const IS_FIRST_LOG_SIZES: [u32; 19] = [
];

pub fn prove_cairo<MC: MerkleChannel>(
input: CairoInput,
input: ProverInput,
// TODO(Ohad): wrap these flags in a struct.
track_relations: bool,
display_components: bool,
Expand Down Expand Up @@ -185,11 +185,11 @@ mod tests {
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
use stwo_prover::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;

use crate::cairo_air::{prove_cairo, verify_cairo, CairoInput};
use crate::cairo_air::{prove_cairo, verify_cairo, ProverInput};
use crate::input::plain::input_from_plain_casm;
use crate::input::vm_import::tests::small_cairo_input;

fn test_input() -> CairoInput {
fn test_input() -> ProverInput {
let u128_max = u128::MAX;
let instructions = casm! {
// TODO(AlonH): Add actual range check segment.
Expand Down
3 changes: 1 addition & 2 deletions stwo_cairo_prover/crates/prover/src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ pub mod vm_import;

pub const N_REGISTERS: usize = 3;

// TODO(Stav): rename to StwoInput.
/// Externally provided inputs for the Stwo prover.
#[derive(Debug)]
pub struct CairoInput {
pub struct ProverInput {
pub state_transitions: StateTransitions,
pub memory: Memory,
pub public_memory_addresses: Vec<u32>,
Expand Down
10 changes: 5 additions & 5 deletions stwo_cairo_prover/crates/prover/src/input/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ use itertools::Itertools;

use super::memory::{MemoryBuilder, MemoryConfig};
use super::vm_import::{adapt_to_stwo_input, MemoryEntry};
use super::CairoInput;
use super::ProverInput;

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Translates a plain casm into a CairoInput by running the program and extracting the memory and
/// Translates a plain casm into a ProverInput by running the program and extracting the memory and
/// the state transitions.
/// When dev mod is enabled, the opcodes generated from the plain casm will
/// be mapped to the generic component only.
pub fn input_from_plain_casm(
casm: Vec<cairo_lang_casm::instructions::Instruction>,
dev_mode: bool,
) -> CairoInput {
) -> ProverInput {
let felt_code = casm
.into_iter()
.flat_map(|instruction| instruction.assemble().encode())
Expand Down Expand Up @@ -54,13 +54,13 @@ pub fn input_from_plain_casm(

// TODO(yuval): consider returning a result instead of panicking.
// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Translates a CairoRunner that finished its run into a CairoInput by extracting the relevant
/// Translates a CairoRunner that finished its run into a ProverInput by extracting the relevant
/// input to the adapter.
/// When dev mod is enabled, the opcodes generated from the plain casm will be mapped to the generic
/// component only.
/// # Panics
/// - if the memory or the trace are not relocated.
pub fn adapt_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoInput {
pub fn adapt_finished_runner(runner: CairoRunner, dev_mode: bool) -> ProverInput {
let _span = tracing::info_span!("adapt_finished_runner").entered();
let memory_iter = runner
.relocated_memory
Expand Down
12 changes: 6 additions & 6 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use tracing::{span, Level};
use super::builtin_segments::BuiltinSegments;
use super::memory::MemoryConfig;
use super::state_transitions::StateTransitions;
use super::CairoInput;
use super::ProverInput;
use crate::input::memory::MemoryBuilder;

#[derive(Debug, Error)]
Expand All @@ -33,7 +33,7 @@ pub fn adapt_vm_output(
public_input_json: &Path,
private_input_json: &Path,
dev_mode: bool,
) -> Result<CairoInput, VmImportError> {
) -> Result<ProverInput, VmImportError> {
let _span = span!(Level::INFO, "adapt_vm_output").entered();
let public_input_string = std::fs::read_to_string(public_input_json)?;
let public_input: PublicInput<'_> = sonic_rs::from_str(&public_input_string)?;
Expand Down Expand Up @@ -84,8 +84,8 @@ pub fn adapt_to_stwo_input(
public_memory_addresses: Vec<u32>,
memory_segments: &HashMap<&str, MemorySegmentAddresses>,
dev_mode: bool,
) -> Result<CairoInput, VmImportError> {
Ok(CairoInput {
) -> Result<ProverInput, VmImportError> {
Ok(ProverInput {
state_transitions: StateTransitions::from_iter(trace_iter, &mut memory, dev_mode),
memory: memory.build(),
public_memory_addresses,
Expand Down Expand Up @@ -154,7 +154,7 @@ pub mod tests {

use super::*;

pub fn large_cairo_input() -> CairoInput {
pub fn large_cairo_input() -> ProverInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/test_read_from_large_files");

Expand All @@ -169,7 +169,7 @@ pub mod tests {
)
}

pub fn small_cairo_input() -> CairoInput {
pub fn small_cairo_input() -> ProverInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/test_read_from_small_files");
adapt_vm_output(
Expand Down
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/vm_runner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::process::ExitCode;

use clap::Parser;
use stwo_cairo_prover::input::plain::adapt_finished_runner;
use stwo_cairo_prover::input::CairoInput;
use stwo_cairo_prover::input::ProverInput;
use stwo_cairo_utils::binary_utils::run_binary;
use stwo_cairo_utils::vm_utils::{run_vm, VmArgs, VmError};
use thiserror::Error;
Expand Down Expand Up @@ -41,7 +41,7 @@ fn main() -> ExitCode {
run_binary(run)
}

fn run(args: impl Iterator<Item = String>) -> Result<CairoInput, Error> {
fn run(args: impl Iterator<Item = String>) -> Result<ProverInput, Error> {
let _span = span!(Level::INFO, "run").entered();
let args = Args::try_parse_from(args)?;
let cairo_runner = run_vm(&args.vm_args)?;
Expand Down

0 comments on commit 8d8521f

Please sign in to comment.