Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made builder inject output-ptr for provable mode. #6643

Open
wants to merge 1 commit into
base: spr/main/63508b7b
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 96 additions & 9 deletions crates/cairo-lang-runnable-utils/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,25 @@ impl RunnableBuilder {
pub struct EntryCodeConfig {
/// Whether to finalize the segment arena after calling the function.
pub finalize_segment_arena: bool,
/// Whether the wrapped function is expecting the output builtin.
///
/// If true, will expect the function signature to be `(Span<felt252>, Array<felt252>) ->
/// Array<felt252>`. And will inject the output builtin to be the supplied array input, and
/// to be the result of the output.
pub outputting_function: bool,
}
impl EntryCodeConfig {
/// Returns a configuration for testing purposes.
///
/// This configuration will not finalize the segment arena after calling the function, to
/// prevent failure in case of functions returning values.
pub fn testing() -> Self {
Self { finalize_segment_arena: false }
Self { finalize_segment_arena: false, outputting_function: false }
}

/// Returns a configuration for proving purposes.
pub fn provable() -> Self {
Self { finalize_segment_arena: true }
Self { finalize_segment_arena: true, outputting_function: true }
}
}

Expand All @@ -272,6 +278,7 @@ pub fn create_entry_code_from_params(
let mut ctx = CasmBuilder::default();
let mut builtin_offset = 3;
let mut builtin_vars = UnorderedHashMap::<_, _>::default();
let mut builtin_ty_to_vm_name = UnorderedHashMap::<_, _>::default();
let mut builtins = vec![];
for (builtin_name, builtin_ty) in [
(BuiltinName::mul_mod, MulModType::ID),
Expand All @@ -286,13 +293,19 @@ pub fn create_entry_code_from_params(
if param_types.iter().any(|(ty, _)| ty == &builtin_ty) {
// The offset [fp - i] for each of this builtins in this configuration.
builtin_vars.insert(
builtin_ty,
builtin_name,
ctx.add_var(CellExpression::Deref(deref!([fp - builtin_offset]))),
);
builtin_ty_to_vm_name.insert(builtin_ty.clone(), builtin_name);
builtin_offset += 1;
builtins.push(builtin_name);
}
}
if config.outputting_function {
let output_builtin_var = ctx.add_var(CellExpression::Deref(deref!([fp - builtin_offset])));
builtin_vars.insert(BuiltinName::output, output_builtin_var);
builtins.push(BuiltinName::output);
}
builtins.reverse();

let emulated_builtins = UnorderedHashSet::<_>::from_iter([SystemType::ID]);
Expand Down Expand Up @@ -328,20 +341,22 @@ pub fn create_entry_code_from_params(
assert zero = *(segment_arena++);
}
// Adding the segment arena to the builtins var map.
builtin_vars.insert(SegmentArenaType::ID, segment_arena);
builtin_vars.insert(BuiltinName::segment_arena, segment_arena);
builtin_ty_to_vm_name.insert(SegmentArenaType::ID, BuiltinName::segment_arena);
}
let mut unallocated_count = 0;
let mut param_index = 0;
for (generic_ty, ty_size) in param_types {
if let Some(var) = builtin_vars.get(generic_ty).cloned() {
if let Some(name) = builtin_ty_to_vm_name.get(generic_ty).cloned() {
let var = builtin_vars[&name];
casm_build_extend!(ctx, tempvar _builtin = var;);
} else if emulated_builtins.contains(generic_ty) {
casm_build_extend! {ctx,
tempvar system;
hint AllocSegment into {dst: system};
};
unallocated_count += ty_size;
} else {
} else if !config.outputting_function {
if *ty_size > 0 {
casm_build_extend! { ctx,
tempvar first;
Expand All @@ -354,7 +369,20 @@ pub fn create_entry_code_from_params(
}
param_index += 1;
unallocated_count += ty_size;
}
}
if config.outputting_function {
let output_ptr = builtin_vars[&BuiltinName::output];
casm_build_extend! { ctx,
tempvar input_start;
tempvar _input_end;
const param_index = 0;
hint ExternalHint::WriteRunParam { index: param_index } into { dst: input_start };
const user_data_offset = 1;
tempvar output_start = output_ptr + user_data_offset;
tempvar output_end = output_start;
};
unallocated_count += 4;
}
if unallocated_count > 0 {
casm_build_extend!(ctx, ap += unallocated_count.into_or_panic::<usize>(););
Expand All @@ -363,9 +391,51 @@ pub fn create_entry_code_from_params(
let mut unprocessed_return_size = return_types.iter().map(|(_, size)| size).sum::<i16>();
let mut return_data = vec![];
for (ret_ty, size) in return_types {
if let Some(var) = builtin_vars.get_mut(ret_ty) {
*var = ctx.add_var(CellExpression::Deref(deref!([ap - unprocessed_return_size])));
if let Some(name) = builtin_ty_to_vm_name.get(ret_ty) {
*builtin_vars.get_mut(name).unwrap() =
ctx.add_var(CellExpression::Deref(deref!([ap - unprocessed_return_size])));
unprocessed_return_size -= 1;
} else if config.outputting_function {
let output_ptr_var = builtin_vars[&BuiltinName::output];
// The output builtin values.
let new_output_ptr = if *size == 3 {
let panic_indicator =
ctx.add_var(CellExpression::Deref(deref!([ap - unprocessed_return_size])));
unprocessed_return_size -= 2;
// The output ptr in the case of successful run.
let output_ptr_end =
ctx.add_var(CellExpression::Deref(deref!([ap - unprocessed_return_size])));
unprocessed_return_size -= 1;
casm_build_extend! {ctx,
tempvar new_output_ptr;
jump PANIC if panic_indicator != 0;
// SUCCESS:
assert new_output_ptr = output_ptr_end;
jump AFTER_PANIC_HANDLING;
PANIC:
const one = 1;
// In the case of an error, we assume no values are written to the output_ptr.
assert new_output_ptr = output_ptr_var + one;
AFTER_PANIC_HANDLING:
assert panic_indicator = output_ptr_var[0];
};
new_output_ptr
} else if *size == 2 {
// No panic possible.
unprocessed_return_size -= 1;
let output_ptr_end =
ctx.add_var(CellExpression::Deref(deref!([ap - unprocessed_return_size])));
unprocessed_return_size -= 1;
casm_build_extend! {ctx,
const czero = 0;
tempvar zero = czero;
assert zero = *(output_ptr_var++);
};
output_ptr_end
} else {
panic!("Unexpected output size: {size}",);
};
*builtin_vars.get_mut(&BuiltinName::output).unwrap() = new_output_ptr;
} else {
for _ in 0..*size {
return_data.push(
Expand All @@ -384,7 +454,7 @@ pub fn create_entry_code_from_params(
}
}
if got_segment_arena && config.finalize_segment_arena {
let segment_arena = builtin_vars[&SegmentArenaType::ID];
let segment_arena = builtin_vars[&BuiltinName::segment_arena];
// Validating the segment arena's segments are one after the other.
casm_build_extend! {ctx,
tempvar n_segments = segment_arena[-2];
Expand Down Expand Up @@ -426,6 +496,23 @@ pub fn create_entry_code_from_params(
};
}
}
if config.outputting_function {
for name in [
BuiltinName::output,
BuiltinName::mul_mod,
BuiltinName::add_mod,
BuiltinName::range_check96,
BuiltinName::poseidon,
BuiltinName::ec_op,
BuiltinName::bitwise,
BuiltinName::range_check,
BuiltinName::pedersen,
] {
if let Some(var) = builtin_vars.get(&name).copied() {
casm_build_extend!(ctx, tempvar cell = var;);
}
}
}
casm_build_extend! (ctx, ret;);
ctx.future_label("FUNCTION".into(), code_offset);
Ok((ctx.build([]).instructions, builtins))
Expand Down
50 changes: 38 additions & 12 deletions crates/cairo-lang-runnable/src/compile_test_data/basic
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,19 @@ CompileRunnableTestRunner(expect_diagnostics: false)
fn main() {}

//! > generated_casm_code
# builtins:
# builtins: output
# header #
%{ raise NotImplementedError("memory[ap + 0].. = params[0])") %}
%{ raise NotImplementedError("memory[ap + 2].. = params[1])") %}
[ap + 2] = [fp + -3] + 1, ap++;
[ap + 2] = [ap + 1], ap++;
ap += 4;
call rel 3;
call rel 12;
jmp rel 5 if [ap + -3] != 0, ap++;
[ap + -1] = [ap + -2];
jmp rel 4;
[ap + -1] = [fp + -3] + 1;
[ap + -4] = [[fp + -3] + 0];
[ap + 0] = [ap + -1], ap++;
ret;
# sierra based code #
[fp + -5] = [ap + 0] + [fp + -6], ap++;
Expand Down Expand Up @@ -51,12 +58,19 @@ fn main(a: felt252, b: felt252) -> felt252 {
}

//! > generated_casm_code
# builtins:
# builtins: output
# header #
%{ raise NotImplementedError("memory[ap + 0].. = params[0])") %}
%{ raise NotImplementedError("memory[ap + 2].. = params[1])") %}
[ap + 2] = [fp + -3] + 1, ap++;
[ap + 2] = [ap + 1], ap++;
ap += 4;
call rel 3;
call rel 12;
jmp rel 5 if [ap + -3] != 0, ap++;
[ap + -1] = [ap + -2];
jmp rel 4;
[ap + -1] = [fp + -3] + 1;
[ap + -4] = [[fp + -3] + 0];
[ap + 0] = [ap + -1], ap++;
ret;
# sierra based code #
[fp + -5] = [ap + 0] + [fp + -6], ap++;
Expand Down Expand Up @@ -130,13 +144,21 @@ fn fib(a: u128, b: u128, n: u128) -> u128 {
}

//! > generated_casm_code
# builtins: range_check
# builtins: output, range_check
# header #
[ap + 0] = [fp + -3], ap++;
%{ raise NotImplementedError("memory[ap + 0].. = params[0])") %}
%{ raise NotImplementedError("memory[ap + 2].. = params[1])") %}
[ap + 2] = [fp + -4] + 1, ap++;
[ap + 2] = [ap + 1], ap++;
ap += 4;
call rel 3;
call rel 13;
jmp rel 5 if [ap + -3] != 0, ap++;
[ap + -1] = [ap + -2];
jmp rel 4;
[ap + -1] = [fp + -4] + 1;
[ap + -4] = [[fp + -4] + 0];
[ap + 0] = [ap + -1], ap++;
[ap + 0] = [ap + -6], ap++;
ret;
# sierra based code #
[fp + -5] = [ap + 0] + [fp + -6], ap++;
Expand Down Expand Up @@ -359,12 +381,16 @@ CompileRunnableTestRunner(expect_diagnostics: false)
fn main(mut _input: Span<felt252>, ref _output: Array<felt252>) {}

//! > generated_casm_code
# builtins:
# builtins: output
# header #
%{ raise NotImplementedError("memory[ap + 0].. = params[0])") %}
%{ raise NotImplementedError("memory[ap + 2].. = params[1])") %}
[ap + 2] = [fp + -3] + 1, ap++;
[ap + 2] = [ap + 1], ap++;
ap += 4;
call rel 3;
call rel 7;
[ap + 0] = 0, ap++;
[ap + -1] = [[fp + -3] + 0];
[ap + 0] = [ap + -2], ap++;
ret;
# sierra based code #
[ap + 0] = [fp + -4], ap++;
Expand Down