Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parallel add inputs #353

Open
wants to merge 1 commit into
base: ohad/add_single_input
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 59 additions & 58 deletions stwo_cairo_prover/Cargo.lock

Large diffs are not rendered by default.

333 changes: 146 additions & 187 deletions stwo_cairo_prover/crates/prover/src/components/add_ap_opcode/prover.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#![allow(unused_imports)]
use std::iter::zip;

use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand All @@ -28,7 +27,9 @@ use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeBuilder;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order;
use stwo_prover::core::utils::{
bit_reverse_coset_to_circle_domain_order, bit_reverse_index, coset_index_to_circle_domain_index,
};

use super::component::{Claim, InteractionClaim};
use crate::components::utils::pack_values;
Expand Down Expand Up @@ -69,35 +70,14 @@ impl ClaimGenerator {
}

let packed_inputs = pack_values(&self.inputs);
let (trace, mut sub_components_inputs, lookup_data) = write_trace_simd(
let (trace, lookup_data) = write_trace_simd(
n_rows,
packed_inputs,
memory_address_to_id_state,
memory_id_to_big_state,
verify_instruction_state,
);

if need_padding {
sub_components_inputs.bit_reverse_coset_to_circle_domain_order();
}
sub_components_inputs
.memory_address_to_id_inputs
.iter()
.for_each(|inputs| {
memory_address_to_id_state.add_inputs(&inputs[..n_rows]);
});
sub_components_inputs
.memory_id_to_big_inputs
.iter()
.for_each(|inputs| {
memory_id_to_big_state.add_inputs(&inputs[..n_rows]);
});
sub_components_inputs
.verify_instruction_inputs
.iter()
.for_each(|inputs| {
verify_instruction_state.add_inputs(&inputs[..n_rows]);
});

tree_builder.extend_evals(trace.to_evals());

(
Expand All @@ -108,17 +88,6 @@ impl ClaimGenerator {
},
)
}

pub fn add_inputs(&self, _inputs: &[InputType]) {
unimplemented!("Implement manually");
}
}

#[derive(SubComponentInputs, Uninitialized, IterMut, ParIterMut)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}

#[allow(clippy::useless_conversion)]
Expand All @@ -130,18 +99,14 @@ fn write_trace_simd(
inputs: Vec<PackedInputType>,
memory_address_to_id_state: &memory_address_to_id::ClaimGenerator,
memory_id_to_big_state: &memory_id_to_big::ClaimGenerator,
) -> (
ComponentTrace<N_TRACE_COLUMNS>,
SubComponentInputs,
LookupData,
) {
verify_instruction_state: &verify_instruction::ClaimGenerator,
) -> (ComponentTrace<N_TRACE_COLUMNS>, LookupData) {
let log_n_packed_rows = inputs.len().ilog2();
let log_size = log_n_packed_rows + LOG_N_LANES;
let (mut trace, mut lookup_data, mut sub_components_inputs) = unsafe {
let (mut trace, mut lookup_data) = unsafe {
(
ComponentTrace::<N_TRACE_COLUMNS>::uninitialized(log_size),
LookupData::uninitialized(log_n_packed_rows),
SubComponentInputs::uninitialized(log_size),
)
};

Expand All @@ -162,12 +127,8 @@ fn write_trace_simd(
.enumerate()
.zip(inputs.into_par_iter())
.zip(lookup_data.par_iter_mut())
.zip(sub_components_inputs.par_iter_mut().chunks(N_LANES))
.for_each(
|(
(((row_index, row), add_ap_opcode_imm_input), lookup_data),
mut sub_components_inputs,
)| {
|(((row_index, row), add_ap_opcode_imm_input), lookup_data)| {
let input_tmp_f4f1f_0 = add_ap_opcode_imm_input;
let input_pc_col0 = input_tmp_f4f1f_0.pc;
*row[0] = input_pc_col0;
Expand All @@ -182,20 +143,15 @@ fn write_trace_simd(
memory_address_to_id_state.deduce_output(input_pc_col0);
let memory_id_to_big_value_tmp_f4f1f_2 =
memory_id_to_big_state.deduce_output(memory_address_to_id_value_tmp_f4f1f_1);
for (i, &input) in (
let verify_instruction_inputs_0 = (
input_pc_col0,
[M31_32767, M31_32767, M31_32769],
[
M31_1, M31_1, M31_1, M31_0, M31_0, M31_0, M31_0, M31_0, M31_0, M31_0,
M31_1, M31_0, M31_0, M31_0, M31_0,
],
)
.unpack()
.iter()
.enumerate()
{
*sub_components_inputs[i].verify_instruction_inputs[0] = input;
}
.unpack();
*lookup_data.verify_instruction_0 = [
input_pc_col0,
M31_32767,
Expand Down Expand Up @@ -226,9 +182,7 @@ fn write_trace_simd(
memory_id_to_big_state.deduce_output(memory_address_to_id_value_tmp_f4f1f_3);
let op1_id_col3 = memory_address_to_id_value_tmp_f4f1f_3;
*row[3] = op1_id_col3;
for (i, &input) in ((input_pc_col0) + (M31_1)).unpack().iter().enumerate() {
*sub_components_inputs[i].memory_address_to_id_inputs[0] = input;
}
let memory_address_to_id_inputs_0 = ((input_pc_col0) + (M31_1)).unpack();
*lookup_data.memory_address_to_id_0 = [((input_pc_col0) + (M31_1)), op1_id_col3];

// Cond Decode Small Sign.
Expand All @@ -247,9 +201,7 @@ fn write_trace_simd(
*row[7] = op1_limb_1_col7;
let op1_limb_2_col8 = memory_id_to_big_value_tmp_f4f1f_4.get_m31(2);
*row[8] = op1_limb_2_col8;
for (i, &input) in op1_id_col3.unpack().iter().enumerate() {
*sub_components_inputs[i].memory_id_to_big_inputs[0] = input;
}
let memory_id_to_big_inputs_0 = op1_id_col3.unpack();
*lookup_data.memory_id_to_big_0 = [
op1_id_col3,
op1_limb_0_col6,
Expand Down Expand Up @@ -292,10 +244,24 @@ fn write_trace_simd(
- ((M31_134217728) * (mid_limbs_set_col5)))),
input_fp_col2,
];

// Add sub-components inputs.
#[allow(clippy::needless_range_loop)]
for i in 0..N_LANES {
if bit_reverse_index(
coset_index_to_circle_domain_index(row_index * N_LANES + i, log_size),
log_size,
) < n_rows
{
verify_instruction_state.add_input(&verify_instruction_inputs_0[i]);
memory_address_to_id_state.add_input(&memory_address_to_id_inputs_0[i]);
memory_id_to_big_state.add_input(&memory_id_to_big_inputs_0[i]);
}
}
},
);

(trace, sub_components_inputs, lookup_data)
(trace, lookup_data)
}

#[derive(Uninitialized, IterMut, ParIterMut)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#![allow(unused_imports)]
use std::iter::zip;

use air_structs_derive::SubComponentInputs;
use itertools::{chain, zip_eq, Itertools};
use num_traits::{One, Zero};
use prover_types::cpu::*;
Expand All @@ -28,7 +27,9 @@ use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::pcs::TreeBuilder;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order;
use stwo_prover::core::utils::{
bit_reverse_coset_to_circle_domain_order, bit_reverse_index, coset_index_to_circle_domain_index,
};

use super::component::{Claim, InteractionClaim};
use crate::components::utils::pack_values;
Expand Down Expand Up @@ -69,35 +70,14 @@ impl ClaimGenerator {
}

let packed_inputs = pack_values(&self.inputs);
let (trace, mut sub_components_inputs, lookup_data) = write_trace_simd(
let (trace, lookup_data) = write_trace_simd(
n_rows,
packed_inputs,
memory_address_to_id_state,
memory_id_to_big_state,
verify_instruction_state,
);

if need_padding {
sub_components_inputs.bit_reverse_coset_to_circle_domain_order();
}
sub_components_inputs
.memory_address_to_id_inputs
.iter()
.for_each(|inputs| {
memory_address_to_id_state.add_inputs(&inputs[..n_rows]);
});
sub_components_inputs
.memory_id_to_big_inputs
.iter()
.for_each(|inputs| {
memory_id_to_big_state.add_inputs(&inputs[..n_rows]);
});
sub_components_inputs
.verify_instruction_inputs
.iter()
.for_each(|inputs| {
verify_instruction_state.add_inputs(&inputs[..n_rows]);
});

tree_builder.extend_evals(trace.to_evals());

(
Expand All @@ -108,17 +88,6 @@ impl ClaimGenerator {
},
)
}

pub fn add_inputs(&self, _inputs: &[InputType]) {
unimplemented!("Implement manually");
}
}

#[derive(SubComponentInputs, Uninitialized, IterMut, ParIterMut)]
pub struct SubComponentInputs {
pub memory_address_to_id_inputs: [Vec<memory_address_to_id::InputType>; 1],
pub memory_id_to_big_inputs: [Vec<memory_id_to_big::InputType>; 1],
pub verify_instruction_inputs: [Vec<verify_instruction::InputType>; 1],
}

#[allow(clippy::useless_conversion)]
Expand All @@ -130,18 +99,14 @@ fn write_trace_simd(
inputs: Vec<PackedInputType>,
memory_address_to_id_state: &memory_address_to_id::ClaimGenerator,
memory_id_to_big_state: &memory_id_to_big::ClaimGenerator,
) -> (
ComponentTrace<N_TRACE_COLUMNS>,
SubComponentInputs,
LookupData,
) {
verify_instruction_state: &verify_instruction::ClaimGenerator,
) -> (ComponentTrace<N_TRACE_COLUMNS>, LookupData) {
let log_n_packed_rows = inputs.len().ilog2();
let log_size = log_n_packed_rows + LOG_N_LANES;
let (mut trace, mut lookup_data, mut sub_components_inputs) = unsafe {
let (mut trace, mut lookup_data) = unsafe {
(
ComponentTrace::<N_TRACE_COLUMNS>::uninitialized(log_size),
LookupData::uninitialized(log_n_packed_rows),
SubComponentInputs::uninitialized(log_size),
)
};

Expand All @@ -165,12 +130,8 @@ fn write_trace_simd(
.enumerate()
.zip(inputs.into_par_iter())
.zip(lookup_data.par_iter_mut())
.zip(sub_components_inputs.par_iter_mut().chunks(N_LANES))
.for_each(
|(
(((row_index, row), add_ap_opcode_op_1_base_fp_input), lookup_data),
mut sub_components_inputs,
)| {
|(((row_index, row), add_ap_opcode_op_1_base_fp_input), lookup_data)| {
let input_tmp_fc5da_0 = add_ap_opcode_op_1_base_fp_input;
let input_pc_col0 = input_tmp_fc5da_0.pc;
*row[0] = input_pc_col0;
Expand All @@ -197,20 +158,15 @@ fn write_trace_simd(
<< (UInt16_13)));
let offset2_col3 = offset2_tmp_fc5da_3.as_m31();
*row[3] = offset2_col3;
for (i, &input) in (
let verify_instruction_inputs_0 = (
input_pc_col0,
[M31_32767, M31_32767, offset2_col3],
[
M31_1, M31_1, M31_0, M31_1, M31_0, M31_0, M31_0, M31_0, M31_0, M31_0,
M31_1, M31_0, M31_0, M31_0, M31_0,
],
)
.unpack()
.iter()
.enumerate()
{
*sub_components_inputs[i].verify_instruction_inputs[0] = input;
}
.unpack();
*lookup_data.verify_instruction_0 = [
input_pc_col0,
M31_32767,
Expand Down Expand Up @@ -241,13 +197,8 @@ fn write_trace_simd(
memory_id_to_big_state.deduce_output(memory_address_to_id_value_tmp_fc5da_4);
let op1_id_col4 = memory_address_to_id_value_tmp_fc5da_4;
*row[4] = op1_id_col4;
for (i, &input) in ((input_fp_col2) + ((offset2_col3) - (M31_32768)))
.unpack()
.iter()
.enumerate()
{
*sub_components_inputs[i].memory_address_to_id_inputs[0] = input;
}
let memory_address_to_id_inputs_0 =
((input_fp_col2) + ((offset2_col3) - (M31_32768))).unpack();
*lookup_data.memory_address_to_id_0 = [
((input_fp_col2) + ((offset2_col3) - (M31_32768))),
op1_id_col4,
Expand All @@ -269,9 +220,7 @@ fn write_trace_simd(
*row[8] = op1_limb_1_col8;
let op1_limb_2_col9 = memory_id_to_big_value_tmp_fc5da_5.get_m31(2);
*row[9] = op1_limb_2_col9;
for (i, &input) in op1_id_col4.unpack().iter().enumerate() {
*sub_components_inputs[i].memory_id_to_big_inputs[0] = input;
}
let memory_id_to_big_inputs_0 = op1_id_col4.unpack();
*lookup_data.memory_id_to_big_0 = [
op1_id_col4,
op1_limb_0_col7,
Expand Down Expand Up @@ -314,10 +263,24 @@ fn write_trace_simd(
- ((M31_134217728) * (mid_limbs_set_col6)))),
input_fp_col2,
];

// Add sub-components inputs.
#[allow(clippy::needless_range_loop)]
for i in 0..N_LANES {
if bit_reverse_index(
coset_index_to_circle_domain_index(row_index * N_LANES + i, log_size),
log_size,
) < n_rows
{
verify_instruction_state.add_input(&verify_instruction_inputs_0[i]);
memory_address_to_id_state.add_input(&memory_address_to_id_inputs_0[i]);
memory_id_to_big_state.add_input(&memory_id_to_big_inputs_0[i]);
}
}
},
);

(trace, sub_components_inputs, lookup_data)
(trace, lookup_data)
}

#[derive(Uninitialized, IterMut, ParIterMut)]
Expand Down
Loading
Loading