Skip to content

Commit

Permalink
add adapter tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Stavbe committed Dec 22, 2024
1 parent 7b89d37 commit c7d080d
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 11 deletions.
2 changes: 1 addition & 1 deletion stwo_cairo_prover/crates/adapted_prover/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fn run(args: impl Iterator<Item = String>) -> Result<CairoProof<Blake2sMerkleHas
let args = Args::try_parse_from(args)?;

let vm_output: CairoInput =
import_from_vm_output(args.pub_json.as_path(), args.priv_json.as_path())?;
import_from_vm_output(args.pub_json.as_path(), args.priv_json.as_path(), true)?;

let casm_states_by_opcode_count = &vm_output.state_transitions.casm_states_by_opcode.counts();
log::info!("Casm states by opcode count: {casm_states_by_opcode_count:?}");
Expand Down
138 changes: 134 additions & 4 deletions stwo_cairo_prover/crates/prover/src/input/state_transitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,13 @@ fn is_small_mul(op0: MemoryValue, op_1: MemoryValue) -> bool {
})
}

/// Tests instructions mapping.
/// Every opcode is tested except:
/// - `jmp rel [ap/fp + offset]`
/// - `jmp abs [[ap/fp + offset1] + offset2]`
// TODO(Stav): Find a way to check those without the casm macro.
#[cfg(test)]
mod tests {
mod mappings_tests {
use cairo_lang_casm::casm;

use crate::input::plain::input_from_plain_casm;
Expand Down Expand Up @@ -679,13 +684,32 @@ mod tests {
);
}

#[test]
fn test_jmp_rel() {
let instructions = casm! {
jmp rel 2;
[ap] = [ap-1] + 3, ap++;
}
.instructions;

let input = input_from_plain_casm(instructions, false);
let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode;
assert_eq!(
casm_states_by_opcode
.jump_opcode_is_rel_t_is_imm_t_is_double_deref_f
.len(),
1
);
}

#[test]
fn test_add_ap() {
let instructions = casm! {
[ap] = 38, ap++;
[ap] = 12, ap++;
ap += [ap -2];
ap += [fp + 1];
ap += 1;
[ap] = 1, ap++;
}
.instructions;
Expand All @@ -700,7 +724,13 @@ mod tests {
);
assert_eq!(
casm_states_by_opcode
.add_ap_opcode_is_imm_f_op_1_base_fp_f
.add_ap_opcode_is_imm_f_op_1_base_fp_t
.len(),
1
);
assert_eq!(
casm_states_by_opcode
.add_ap_opcode_is_imm_t_op_1_base_fp_f
.len(),
1
);
Expand Down Expand Up @@ -751,7 +781,7 @@ mod tests {
}

#[test]
fn test_jnz_taken() {
fn test_jnz_not_taken_ap() {
let instructions = casm! {
[ap] = 0, ap++;
jmp rel 2 if [ap-1] != 0;
Expand All @@ -770,7 +800,27 @@ mod tests {
}

#[test]
fn test_jnz_not_taken() {
fn test_jnz_not_taken_fp() {
let instructions = casm! {
call rel 2;
[ap] = 0, ap++;
jmp rel 2 if [fp] != 0;
[ap] = 1, ap++;
}
.instructions;

let input = input_from_plain_casm(instructions, false);
let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode;
assert_eq!(
casm_states_by_opcode
.jnz_opcode_is_taken_f_dst_base_fp_t
.len(),
1
);
}

#[test]
fn test_jnz_taken_fp() {
let instructions = casm! {
call rel 2;
jmp rel 2 if [fp-1] != 0;
Expand All @@ -788,6 +838,25 @@ mod tests {
);
}

#[test]
fn test_jnz_taken_ap() {
let instructions = casm! {
[ap] = 5, ap++;
jmp rel 2 if [ap-1] != 0;
[ap] = 1, ap++;
}
.instructions;

let input = input_from_plain_casm(instructions, false);
let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode;
assert_eq!(
casm_states_by_opcode
.jnz_opcode_is_taken_t_dst_base_fp_f
.len(),
1
);
}

#[test]
fn test_assert_equal() {
let instructions = casm! {
Expand Down Expand Up @@ -828,6 +897,12 @@ mod tests {
casm_states_by_opcode.add_opcode_is_small_t_is_imm_f.len(),
1
);
assert_eq!(
casm_states_by_opcode
.assert_eq_opcode_is_double_deref_f_is_imm_t
.len(),
2
);
assert_eq!(
casm_states_by_opcode.add_opcode_is_small_t_is_imm_t.len(),
1
Expand Down Expand Up @@ -909,4 +984,59 @@ mod tests {
1
);
}

#[test]
fn test_generic() {
let instructions = casm! {
[ap]=1, ap++;
[ap]=2, ap++;
jmp rel [ap-2] if [ap-1] != 0;
[ap]=1, ap++;
}
.instructions;

let input = input_from_plain_casm(instructions, false);
let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode;
assert_eq!(casm_states_by_opcode.generic_opcode.len(), 1);
}

#[test]
fn test_ret() {
let instructions = casm! {
[ap] = 10, ap++;
call rel 4;
jmp rel 11;

jmp rel 4 if [fp-3] != 0;
jmp rel 6;
[ap] = [fp-3] + (-1), ap++;
call rel (-6);
ret;
}
.instructions;

let input = input_from_plain_casm(instructions, false);
let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode;
assert_eq!(casm_states_by_opcode.ret_opcode.len(), 11);
}

#[test]
fn test_assert_eq_double_deref() {
let instructions = casm! {
call rel 2;
[ap] = 100, ap++;
[ap] = [[fp - 2] + 2], ap++; // [fp - 2] is the old fp.
[ap] = 5;
}
.instructions;

let input = input_from_plain_casm(instructions, false);
let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode;
assert_eq!(
casm_states_by_opcode
.assert_eq_opcode_is_double_deref_t_is_imm_f
.len(),
1
);
}
}
22 changes: 16 additions & 6 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ pub enum VmImportError {
NoMemorySegments,
}

// TODO(Ohad): remove dev_mode after adding the rest of the instructions.
pub fn import_from_vm_output(
pub_json: &Path,
priv_json: &Path,
dev_mod: bool,
) -> Result<CairoInput, VmImportError> {
let _span = span!(Level::INFO, "import_from_vm_output").entered();
let pub_data: PublicInput = sonic_rs::from_str(&std::fs::read_to_string(pub_json)?)?;
Expand All @@ -48,7 +50,8 @@ pub fn import_from_vm_output(
let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?);
let mut mem_file = std::io::BufReader::new(std::fs::File::open(mem_path)?);
let mut mem = MemoryBuilder::from_iter(mem_config, MemEntryIter(&mut mem_file));
let state_transitions = StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, true);
let state_transitions =
StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, dev_mod);

let public_mem_addresses = pub_data
.public_memory
Expand Down Expand Up @@ -133,7 +136,12 @@ pub mod tests {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/large_cairo_input");

import_from_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
import_from_vm_output(
d.join("pub.json").as_path(),
d.join("priv.json").as_path(),
false,
)
.expect(
"
Failed to read test files. Maybe git-lfs is not installed? Checkout README.md.",
)
Expand All @@ -142,15 +150,18 @@ pub mod tests {
pub fn small_cairo_input() -> CairoInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/small_cairo_input");
import_from_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
import_from_vm_output(
d.join("pub.json").as_path(),
d.join("priv.json").as_path(),
false,
)
.expect(
"
Failed to read test files. Maybe git-lfs is not installed? Checkout README.md.",
)
}

// TODO (Stav): Once all the components are in, verify the proof to ensure the sort was correct.
// TODO (Ohad): remove the following doc after deleting dev_mod.
/// When not ignored, the test passes only with dev_mod = false.
#[ignore]
#[test]
fn test_read_from_large_files() {
Expand Down Expand Up @@ -217,7 +228,6 @@ pub mod tests {
assert_eq!(components.ret_opcode.len(), 49472);
}

// When not ignored, the test passes only with dev_mod = false.
#[ignore]
#[test]
fn test_read_from_small_files() {
Expand Down

0 comments on commit c7d080d

Please sign in to comment.