From 86cdad0c13f67718844beeb1ae6e6806c3518667 Mon Sep 17 00:00:00 2001 From: sragss Date: Mon, 18 Mar 2024 14:25:30 -0600 Subject: [PATCH 1/6] move snark.inputs to a legible struct --- jolt-core/src/jolt/vm/mod.rs | 27 +++- jolt-core/src/r1cs/constraints.rs | 9 +- jolt-core/src/r1cs/snark.rs | 226 +++++++++++++++++++++++------- 3 files changed, 199 insertions(+), 63 deletions(-) diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 4aa0f8757..78bd16e62 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -15,7 +15,7 @@ use crate::poly::dense_mlpoly::DensePolynomial; use crate::poly::hyrax::{HyraxCommitment, HyraxGenerators}; use crate::poly::pedersen::PedersenGenerators; use crate::poly::structured_poly::BatchablePolynomials; -use crate::r1cs::snark::R1CSProof; +use crate::r1cs::snark::{R1CSInputs, R1CSProof}; use crate::utils::errors::ProofVerifyError; use crate::{ jolt::{ @@ -451,9 +451,22 @@ where // Flattening this out into a Vec and chunking into PADDED_TRACE_LEN-sized chunks // will be the exact witness vector to feed into the R1CS // after pre-pending IO and appending the AUX - let inputs: Vec> = vec![ - bytecode_a, // prog_a_rw, - bytecode_v, // prog_v_rw (with circuit_flags_packed) + // let inputs: Vec> = vec![ + // bytecode_a, // prog_a_rw, + // bytecode_v, // prog_v_rw (with circuit_flags_packed) + // memreg_a_rw, + // memreg_v_reads, + // memreg_v_writes, + // chunks_x.clone(), + // chunks_y.clone(), + // chunks_query, + // lookup_outputs.clone(), + // circuit_flags_bits.clone(), + // ]; + + let inputs: R1CSInputs = R1CSInputs::from_ark( + bytecode_a, + bytecode_v, memreg_a_rw, memreg_v_reads, memreg_v_writes, @@ -461,8 +474,8 @@ where chunks_y.clone(), chunks_query, lookup_outputs.clone(), - circuit_flags_bits.clone(), - ]; + circuit_flags_bits.clone() + ); // Assemble the commitments let span = tracing::span!(tracing::Level::INFO, "bytecode_commitment_conversions"); @@ -510,7 +523,7 @@ where circuit_flags_comm ].concat(); - R1CSProof::prove( + R1CSProof::prove::( 32, C, PADDED_TRACE_LEN, inputs, preprocessing.spartan_generators, diff --git a/jolt-core/src/r1cs/constraints.rs b/jolt-core/src/r1cs/constraints.rs index 41b920f50..6438d273c 100644 --- a/jolt-core/src/r1cs/constraints.rs +++ b/jolt-core/src/r1cs/constraints.rs @@ -3,6 +3,7 @@ use smallvec::SmallVec; use smallvec::smallvec; use ff::PrimeField; +use rayon::prelude::*; /* Compiler Variables */ const C: usize = 4; @@ -126,6 +127,7 @@ fn concat_constraint_vecs(mut x: SmallVec<[(usize, i64); SMALLVEC_SIZE]>, y: Sma } fn i64_to_f(num: i64) -> F { + // TODO(sragss): Make from_u64 if num < 0 { F::ZERO - F::from((-num) as u64) } else { @@ -771,11 +773,12 @@ impl R1CSBuilder<'_, F> { || modify_matrix(&mut self.C))); } + #[tracing::instrument(skip_all, name = "Shape::convert_to_field")] pub fn convert_to_field(&self) -> (Vec<(usize, usize, F)>, Vec<(usize, usize, F)>, Vec<(usize, usize, F)>) { ( - self.A.iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::(*val))).collect(), - self.B.iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::(*val))).collect(), - self.C.iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::(*val))).collect(), + self.A.par_iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::(*val))).collect(), + self.B.par_iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::(*val))).collect(), + self.C.par_iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::(*val))).collect(), ) } } diff --git a/jolt-core/src/r1cs/snark.rs b/jolt-core/src/r1cs/snark.rs index 
de61f65e0..8112808a1 100644 --- a/jolt-core/src/r1cs/snark.rs +++ b/jolt-core/src/r1cs/snark.rs @@ -1,4 +1,4 @@ -use crate::jolt; +use crate::{jolt, utils::thread::drop_in_background_thread}; use super::constraints::R1CSBuilder; @@ -63,17 +63,7 @@ fn reassemble_segments_partial(jolt_witnesses: Vec>, num_f #[derive(Clone, Debug, Default)] pub struct JoltCircuit> { num_steps: usize, - inputs: Vec>, - // prog_a_rw: Vec, - // prog_v_rw: Vec, - // memreg_a_rw: Vec, - // memreg_v_reads: Vec, - // memreg_v_writes: Vec, - // chunks_x: Vec, - // chunks_y: Vec, - // chunks_query: Vec, - // lookup_outputs: Vec, - // op_flags: Vec, + inputs: R1CSInputs, } // This is a placeholder trait to satisfy Spartan's requirements. @@ -85,7 +75,7 @@ impl> Circuit for JoltCircuit { } impl> JoltCircuit { - pub fn new_from_inputs(num_steps: usize, inputs: Vec>) -> Self { + pub fn new_from_inputs(num_steps: usize, inputs: R1CSInputs) -> Self { JoltCircuit{ num_steps: num_steps, inputs: inputs, @@ -94,41 +84,16 @@ impl> JoltCircuit { #[tracing::instrument(name = "JoltCircuit::get_witnesses_by_step", skip_all)] fn synthesize_witnesses(&self) -> Result>, SynthesisError> { - let TRACE_LEN = self.inputs[0].len(); - let compute_witness_span = tracing::span!(tracing::Level::INFO, "compute_witness_loop"); - let _compute_witness_guard = compute_witness_span.enter(); + let mut step_z = self.inputs.clone_to_stepwise(); - // Allocate memory - let mut step_z: Vec> = Vec::with_capacity(TRACE_LEN); - for _ in 0..TRACE_LEN { - let step_i: Vec = Vec::with_capacity(self.inputs.len()); - step_z.push(step_i); - } - - // Allocate the inputs - step_z.par_iter_mut().enumerate().for_each(|(i, step)| { - let program_counter = if i > 0 && self.inputs[0][i] == F::from(0) { - F::from(0) - } else { - self.inputs[0][i] * F::from(4u64) + F::from(RAM_START_ADDRESS) - }; - step.extend([F::from(1), F::from(0), F::from(0), F::from(i as u64), program_counter]); - for j in 0..self.inputs.len() { - let max_k = (self.inputs[j].len() - i - 1) / TRACE_LEN; - for k in 0..=max_k { - step.push(self.inputs[j][i + k * TRACE_LEN]); - } - } - }); - - // Allocate the aux + // Compute the aux + let span = tracing::span!(tracing::Level::INFO, "calc_aux"); + let _guard = span.enter(); step_z.par_iter_mut().enumerate().for_each(|(i, step)| { R1CSBuilder::::calculate_aux(step); }); - drop(_compute_witness_guard); - Ok(step_z) } @@ -145,13 +110,167 @@ pub struct R1CSProof { vk: VerifierKey>>, } +#[derive(Clone, Debug, Default)] +pub struct R1CSInputs> { + bytecode_a: Vec, + bytecode_v: Vec, + memreg_a_rw: Vec, + memreg_v_reads: Vec, + memreg_v_writes: Vec, + chunks_x: Vec, + chunks_y: Vec, + chunks_query: Vec, + lookup_outputs: Vec, + circuit_flags_bits: Vec, +} + +impl> R1CSInputs { + #[tracing::instrument(skip_all, name = "R1CSInputs::from_ark")] + pub fn from_ark( + bytecode_a: Vec, + bytecode_v: Vec, + memreg_a_rw: Vec, + memreg_v_reads: Vec, + memreg_v_writes: Vec, + chunks_x: Vec, + chunks_y: Vec, + chunks_query: Vec, + lookup_outputs: Vec, + circuit_flag_bits: Vec + ) -> Self { + let bytecode_a: Vec = bytecode_a.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let bytecode_v: Vec = bytecode_v.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let memreg_a_rw: Vec = memreg_a_rw.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let memreg_v_reads: Vec = memreg_v_reads.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let memreg_v_writes: Vec = 
memreg_v_writes.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let chunks_x: Vec = chunks_x.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let chunks_y: Vec = chunks_y.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let chunks_query: Vec = chunks_query.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let lookup_outputs: Vec = lookup_outputs.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + let circuit_flags_bits: Vec = circuit_flag_bits.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect(); + + Self { + bytecode_a, + bytecode_v, + memreg_a_rw, + memreg_v_reads, + memreg_v_writes, + chunks_x, + chunks_y, + chunks_query, + lookup_outputs, + circuit_flags_bits, + } + } + + #[tracing::instrument(skip_all, name = "R1CSInputs::clone_to_stepwise")] + pub fn clone_to_stepwise(&self) -> Vec> { + const PREFIX_VARS_PER_STEP: usize = 5; + const AUX_VARS_PER_STEP: usize = 20; + let num_inputs_per_step = self.num_vars_per_step() + PREFIX_VARS_PER_STEP; + + let stepwise = (0..self.trace_len()).into_par_iter().map(|step_index| { + let mut step: Vec = Vec::with_capacity(num_inputs_per_step + AUX_VARS_PER_STEP); + let program_counter = if step_index > 0 && self.bytecode_a[step_index] == F::ZERO { + F::ZERO + } else { + self.bytecode_a[step_index] * F::from(4u64) + F::from(RAM_START_ADDRESS) + }; + // TODO(sragss): This indexing strategy is stolen from old -- but self.trace_len here is self.bytecode_a.len() -- not sure why we're using that to split inputs. + step.extend([F::from(1), F::from(0), F::from(0), F::from(step_index as u64), program_counter]); + let bytecode_a_num_vals = self.bytecode_a.len() / self.trace_len(); + for var_index in 0..bytecode_a_num_vals { + step.push(self.bytecode_a[var_index * self.trace_len() + step_index]); + } + let bytecode_v_num_vals = self.bytecode_v.len() / self.trace_len(); + for var_index in 0..bytecode_v_num_vals { + step.push(self.bytecode_v[var_index * self.trace_len() + step_index]); + } + let memreg_a_rw_num_vals = self.memreg_a_rw.len() / self.trace_len(); + for var_index in 0..memreg_a_rw_num_vals { + step.push(self.memreg_a_rw[var_index * self.trace_len() + step_index]); + } + let memreg_v_reads_num_vals = self.memreg_v_reads.len() / self.trace_len(); + for var_index in 0..memreg_v_reads_num_vals { + step.push(self.memreg_v_reads[var_index * self.trace_len() + step_index]); + } + let memreg_v_writes_num_vals = self.memreg_v_writes.len() / self.trace_len(); + for var_index in 0..memreg_v_writes_num_vals { + step.push(self.memreg_v_writes[var_index * self.trace_len() + step_index]); + } + let chunks_x_num_vals = self.chunks_x.len() / self.trace_len(); + for var_index in 0..chunks_x_num_vals { + step.push(self.chunks_x[var_index * self.trace_len() + step_index]); + } + let chunks_y_num_vals = self.chunks_y.len() / self.trace_len(); + for var_index in 0..chunks_y_num_vals { + step.push(self.chunks_y[var_index * self.trace_len() + step_index]); + } + let chunks_query_num_vals = self.chunks_query.len() / self.trace_len(); + for var_index in 0..chunks_query_num_vals { + step.push(self.chunks_query[var_index * self.trace_len() + step_index]); + } + let lookup_outputs_num_vals = self.lookup_outputs.len() / self.trace_len(); + for var_index in 0..lookup_outputs_num_vals { + step.push(self.lookup_outputs[var_index * self.trace_len() + step_index]); + } + let circuit_flags_bits_num_vals = 
self.circuit_flags_bits.len() / self.trace_len(); + for var_index in 0..circuit_flags_bits_num_vals { + step.push(self.circuit_flags_bits[var_index * self.trace_len() + step_index]); + } + + assert_eq!(num_inputs_per_step, step.len()); + + step + }).collect(); + + stepwise + } + + + pub fn trace_len(&self) -> usize { + self.bytecode_a.len() + } + + pub fn num_vars_per_step(&self) -> usize { + let trace_len = self.trace_len(); + self.bytecode_a.len() / trace_len + + self.bytecode_v.len() / trace_len + + self.memreg_a_rw.len() / trace_len + + self.memreg_v_reads.len() / trace_len + + self.memreg_v_writes.len() / trace_len + + self.chunks_x.len() / trace_len + + self.chunks_y.len() / trace_len + + self.chunks_query.len() / trace_len + + self.lookup_outputs.len() / trace_len + + self.circuit_flags_bits.len() / trace_len + } + + #[tracing::instrument(skip_all, name = "R1CSInputs::trace_len_chunks")] + pub fn trace_len_chunks(&self, padded_trace_len: usize) -> Vec> { + // TODO(sragss / arasuarun): Explain why non-trace-len relevant stuff gets chunked to trace_len + let mut chunks: Vec> = Vec::new(); + chunks.extend(self.bytecode_a.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.bytecode_v.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.memreg_a_rw.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.memreg_v_reads.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.memreg_v_writes.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.chunks_x.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.chunks_y.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.chunks_query.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.lookup_outputs.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.extend(self.circuit_flags_bits.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks + } +} + impl R1CSProof { #[tracing::instrument(skip_all, name = "R1CSProof::prove")] pub fn prove ( _W: usize, _C: usize, - TRACE_LEN: usize, - inputs_ark: Vec>, + padded_trace_len: usize, + inputs: R1CSInputs, generators: Vec, jolt_commitments: &Vec>, ) -> Result { @@ -160,16 +279,15 @@ impl R1CSProof { type S = spartan2::spartan::upsnark::R1CSSNARK; type F = Spartan2Fr; - let NUM_STEPS = TRACE_LEN; + let NUM_STEPS = padded_trace_len; - let span = tracing::span!(tracing::Level::TRACE, "convert_ark_to_spartan_fr"); + let span = tracing::span!(tracing::Level::TRACE, "JoltCircuit::new_from_inputs"); let _enter = span.enter(); - let inputs: Vec> = inputs_ark.into_par_iter().map(|vec| vec.into_par_iter().map(|ark_item| ark_to_spartan_unsafe::(ark_item)).collect()).collect(); - drop(_enter); - drop(span); - let jolt_circuit = JoltCircuit::::new_from_inputs(NUM_STEPS, inputs.clone()); + drop(_enter); + let span = tracing::span!(tracing::Level::TRACE, "shape_stuff"); + let _enter = span.enter(); let mut jolt_shape = R1CSBuilder::::default(); R1CSBuilder::::get_matrices(&mut jolt_shape); let constraints_F = jolt_shape.convert_to_field(); @@ -181,11 +299,15 @@ impl R1CSProof { num_vars: jolt_shape.num_aux, // shouldn't include 1 or IO num_io: jolt_shape.num_inputs, }; + drop(_enter); // Obtain public key + let span = tracing::span!(tracing::Level::TRACE, "convert_ck_to_spartan"); + let _enter = span.enter(); let hyrax_ck = HyraxCommitmentKey:: { ck: spartan2::provider::pedersen::from_gens_bn256(generators) }; + drop(_enter); 
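// Illustration (not part of this patch): a minimal, self-contained sketch of the
// span-scoping pattern the hunk above introduces ("shape_stuff",
// "convert_ck_to_spartan"). Each phase is wrapped in a named `tracing` span and
// the guard is dropped as soon as the phase ends, so every phase shows up as its
// own region in a profile. The phase bodies and the `simulate_work` and
// `profiled_phases` helpers below are illustrative stand-ins, not code from the
// Jolt repository.
use tracing::Level;
use tracing_subscriber::fmt::format::FmtSpan;

fn simulate_work(iters: u64) -> u64 {
    // Placeholder for an expensive phase such as building the R1CS shape.
    (0..iters).sum()
}

fn profiled_phases() {
    let span = tracing::span!(Level::TRACE, "shape_stuff");
    let _enter = span.enter();
    let _shape = simulate_work(1_000_000);
    drop(_enter); // exit the span here rather than at the end of the function

    let span = tracing::span!(Level::TRACE, "convert_ck_to_spartan");
    let _enter = span.enter();
    let _ck = simulate_work(1_000_000);
    drop(_enter);
}

fn main() {
    tracing_subscriber::fmt()
        .with_max_level(Level::TRACE)
        .with_span_events(FmtSpan::CLOSE) // log busy/idle time when each span closes
        .init();
    profiled_phases();
}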
// let w_segments_from_circuit = jolt_circuit.synthesize_witness_segments().unwrap(); let (io_segments, aux_segments) = jolt_circuit.synthesize_state_aux_segments(4, jolt_shape.num_internal).unwrap(); @@ -193,9 +315,7 @@ impl R1CSProof { let cloning_stuff_span = tracing::span!(tracing::Level::TRACE, "cloning_stuff"); let _enter = cloning_stuff_span.enter(); - let inputs_segments: Vec> = inputs.into_iter().flat_map(|input| { - input.chunks(TRACE_LEN).map(|chunk| chunk.to_vec()).collect::>() - }).collect(); + let inputs_segments = inputs.trace_len_chunks(padded_trace_len); let w_segments = io_segments.clone().into_iter() .chain(inputs_segments.iter().cloned()) From 97d0ccb47ace5e8befa7855adf24ce74182fd7c7 Mon Sep 17 00:00:00 2001 From: sragss Date: Mon, 18 Mar 2024 15:26:33 -0600 Subject: [PATCH 2/6] cleanup; better comments; less clones; faster extends --- jolt-core/src/r1cs/snark.rs | 66 +++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/jolt-core/src/r1cs/snark.rs b/jolt-core/src/r1cs/snark.rs index 8112808a1..0b0d439ca 100644 --- a/jolt-core/src/r1cs/snark.rs +++ b/jolt-core/src/r1cs/snark.rs @@ -35,8 +35,7 @@ fn reassemble_segments(jolt_witnesses: Vec>) -> Vec } /// Reorder and drop first element [[a1, b1, c1], [a2, b2, c2]] => [[a2], [b2], [c2]] -/// Works -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip_all, name = "reassemble_segments_partial")] fn reassemble_segments_partial(jolt_witnesses: Vec>, num_front: usize, num_back: usize) -> (Vec>, Vec>) { let trace_len = jolt_witnesses.len(); let total_length = jolt_witnesses[0].len(); @@ -57,6 +56,8 @@ fn reassemble_segments_partial(jolt_witnesses: Vec>, num_f } }); + drop_in_background_thread(jolt_witnesses); + (front_result, back_result) } @@ -82,8 +83,15 @@ impl> JoltCircuit { } } - #[tracing::instrument(name = "JoltCircuit::get_witnesses_by_step", skip_all)] - fn synthesize_witnesses(&self) -> Result>, SynthesisError> { + #[tracing::instrument(name = "synthesize_state_aux_segments", skip_all)] + pub fn synthesize_state_aux_segments(&self, num_state: usize, num_aux: usize) -> (Vec>, Vec>) { + let jolt_witnesses = self.synthesize_witnesses(); + // TODO(sragss / arasuarun): Synthsize witnesses should just return (io, aux) + reassemble_segments_partial(jolt_witnesses, num_state, num_aux) + } + + #[tracing::instrument(name = "JoltCircuit::synthesize_witnesses", skip_all)] + fn synthesize_witnesses(&self) -> Vec> { let mut step_z = self.inputs.clone_to_stepwise(); @@ -94,13 +102,7 @@ impl> JoltCircuit { R1CSBuilder::::calculate_aux(step); }); - Ok(step_z) - } - - #[tracing::instrument(name = "synthesize_witness_segments", skip_all)] - pub fn synthesize_state_aux_segments(&self, num_state: usize, num_aux: usize) -> Result<(Vec>, Vec>), SynthesisError> { - let jolt_witnesses = self.synthesize_witnesses()?; - Ok(reassemble_segments_partial(jolt_witnesses, num_state, num_aux)) + step_z } } @@ -176,7 +178,9 @@ impl> R1CSInputs { } else { self.bytecode_a[step_index] * F::from(4u64) + F::from(RAM_START_ADDRESS) }; - // TODO(sragss): This indexing strategy is stolen from old -- but self.trace_len here is self.bytecode_a.len() -- not sure why we're using that to split inputs. + // TODO(sragss / arasu arun): This indexing strategy is stolen from old -- but self.trace_len here is self.bytecode_a.len() -- not sure why we're using that to split inputs. 
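// Illustration (not part of this patch): the indexing questioned in the TODO above
// walks a flat, column-major buffer -- each input vector stores all trace_len
// values of its first variable, then all values of its second variable, and so
// on -- so the value of variable `v` at step `s` lives at `v * trace_len + s`,
// which is exactly what clone_to_stepwise gathers per step. The sketch below is
// a self-contained example of that gather with made-up data; `gather_step` is a
// hypothetical helper, not code from the Jolt repository.

/// Gather one step's values out of a flat column-major buffer.
fn gather_step(flat: &[u64], trace_len: usize, step: usize) -> Vec<u64> {
    assert_eq!(flat.len() % trace_len, 0, "buffer must hold whole columns");
    let num_vars = flat.len() / trace_len;
    (0..num_vars).map(|var| flat[var * trace_len + step]).collect()
}

fn main() {
    // Two variables over a trace of length 3, stored column-major:
    // var 0 = [10, 11, 12], var 1 = [20, 21, 22].
    let flat = vec![10, 11, 12, 20, 21, 22];
    assert_eq!(gather_step(&flat, 3, 1), vec![11, 21]);
    println!("step 1 row: {:?}", gather_step(&flat, 3, 1));
}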
+ + // 1 is constant, 0s in slots 1, 2 are filled by aux computation step.extend([F::from(1), F::from(0), F::from(0), F::from(step_index as u64), program_counter]); let bytecode_a_num_vals = self.bytecode_a.len() / self.trace_len(); for var_index in 0..bytecode_a_num_vals { @@ -248,18 +252,18 @@ impl> R1CSInputs { #[tracing::instrument(skip_all, name = "R1CSInputs::trace_len_chunks")] pub fn trace_len_chunks(&self, padded_trace_len: usize) -> Vec> { - // TODO(sragss / arasuarun): Explain why non-trace-len relevant stuff gets chunked to trace_len + // TODO(sragss / arasuarun): Explain why non-trace-len relevant stuff (ex: bytecode) gets chunked to trace_len let mut chunks: Vec> = Vec::new(); - chunks.extend(self.bytecode_a.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.bytecode_v.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.memreg_a_rw.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.memreg_v_reads.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.memreg_v_writes.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.chunks_x.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.chunks_y.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.chunks_query.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.lookup_outputs.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); - chunks.extend(self.circuit_flags_bits.chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.bytecode_a.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.bytecode_v.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.memreg_a_rw.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.memreg_v_reads.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.memreg_v_writes.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.chunks_x.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.chunks_y.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.chunks_query.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.lookup_outputs.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); + chunks.par_extend(self.circuit_flags_bits.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec())); chunks } } @@ -310,17 +314,23 @@ impl R1CSProof { drop(_enter); // let w_segments_from_circuit = jolt_circuit.synthesize_witness_segments().unwrap(); - let (io_segments, aux_segments) = jolt_circuit.synthesize_state_aux_segments(4, jolt_shape.num_internal).unwrap(); + let (io_segments, aux_segments) = jolt_circuit.synthesize_state_aux_segments(4, jolt_shape.num_internal); let cloning_stuff_span = tracing::span!(tracing::Level::TRACE, "cloning_stuff"); let _enter = cloning_stuff_span.enter(); let inputs_segments = inputs.trace_len_chunks(padded_trace_len); - let w_segments = io_segments.clone().into_iter() - .chain(inputs_segments.iter().cloned()) - .chain(aux_segments.clone().into_iter()) - .collect::>(); + let mut w_segments: Vec> = Vec::with_capacity(io_segments.len() + inputs_segments.len() + aux_segments.len()); + // TODO(sragss / arasuarun): rm clones + w_segments.par_extend(io_segments.par_iter().cloned()); + w_segments.par_extend(inputs_segments.into_par_iter()); + 
w_segments.par_extend(aux_segments.par_iter().cloned()); + + // let w_segments = io_segments.clone().into_iter() + // .chain(inputs_segments.iter().cloned()) + // .chain(aux_segments.clone().into_iter()) + // .collect::>(); drop(_enter); drop(cloning_stuff_span); From 2da204af4a34b7f7007550c520e4c8d42c0580fc Mon Sep 17 00:00:00 2001 From: sragss Date: Mon, 18 Mar 2024 15:33:38 -0600 Subject: [PATCH 3/6] cosmetics --- jolt-core/src/r1cs/snark.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/jolt-core/src/r1cs/snark.rs b/jolt-core/src/r1cs/snark.rs index 0b0d439ca..285814354 100644 --- a/jolt-core/src/r1cs/snark.rs +++ b/jolt-core/src/r1cs/snark.rs @@ -283,11 +283,12 @@ impl R1CSProof { type S = spartan2::spartan::upsnark::R1CSSNARK; type F = Spartan2Fr; - let NUM_STEPS = padded_trace_len; + let num_steps = padded_trace_len; let span = tracing::span!(tracing::Level::TRACE, "JoltCircuit::new_from_inputs"); let _enter = span.enter(); - let jolt_circuit = JoltCircuit::::new_from_inputs(NUM_STEPS, inputs.clone()); + // TODO(sragss / arasuarun): After Spartan is merged we don't need to clone these inputs anymore + let jolt_circuit = JoltCircuit::::new_from_inputs(num_steps, inputs.clone()); drop(_enter); let span = tracing::span!(tracing::Level::TRACE, "shape_stuff"); @@ -313,7 +314,6 @@ impl R1CSProof { }; drop(_enter); - // let w_segments_from_circuit = jolt_circuit.synthesize_witness_segments().unwrap(); let (io_segments, aux_segments) = jolt_circuit.synthesize_state_aux_segments(4, jolt_shape.num_internal); let cloning_stuff_span = tracing::span!(tracing::Level::TRACE, "cloning_stuff"); @@ -322,16 +322,11 @@ impl R1CSProof { let inputs_segments = inputs.trace_len_chunks(padded_trace_len); let mut w_segments: Vec> = Vec::with_capacity(io_segments.len() + inputs_segments.len() + aux_segments.len()); - // TODO(sragss / arasuarun): rm clones + // TODO(sragss / arasuarun): rm clones in favor of references w_segments.par_extend(io_segments.par_iter().cloned()); w_segments.par_extend(inputs_segments.into_par_iter()); w_segments.par_extend(aux_segments.par_iter().cloned()); - // let w_segments = io_segments.clone().into_iter() - // .chain(inputs_segments.iter().cloned()) - // .chain(aux_segments.clone().into_iter()) - // .collect::>(); - drop(_enter); drop(cloning_stuff_span); @@ -353,7 +348,7 @@ impl R1CSProof { .chain(aux_comms.into_iter()) .collect::>(); - let (pk, vk) = SNARK::>::setup_precommitted(shape_single, NUM_STEPS, hyrax_ck).unwrap(); + let (pk, vk) = SNARK::>::setup_precommitted(shape_single, num_steps, hyrax_ck).unwrap(); SNARK::prove_precommitted(&pk, w_segments, comm_w_vec).map(|snark| Self { proof: snark, From 2cad1f8e67a51a0bfa5a7603865fed9937af4625 Mon Sep 17 00:00:00 2001 From: sragss Date: Mon, 18 Mar 2024 15:51:15 -0600 Subject: [PATCH 4/6] 2x speedup compute_chunks_operands --- jolt-core/src/jolt/vm/mod.rs | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 78bd16e62..378760191 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -400,26 +400,22 @@ where drop(span); // Derive chunks_x and chunks_y - let span = tracing::span!(tracing::Level::INFO, "compute chunks operands"); + let span = tracing::span!(tracing::Level::INFO, "compute_chunks_operands"); let _guard = span.enter(); - let mut chunks_x_vecs: Vec> = vec![Vec::with_capacity(PADDED_TRACE_LEN); C]; - let mut chunks_y_vecs: Vec> = 
vec![Vec::with_capacity(PADDED_TRACE_LEN); C]; + let num_chunks = PADDED_TRACE_LEN * C; + let mut chunks_x: Vec = vec![F::zero(); num_chunks]; + let mut chunks_y: Vec = vec![F::zero(); num_chunks]; - for (i, op) in instructions.iter().enumerate() { + for (instruction_index, op) in instructions.iter().enumerate() { let [chunks_x_op, chunks_y_op] = op.operand_chunks(C, log_M); - for (j, (x, y)) in chunks_x_op.into_iter().zip(chunks_y_op.into_iter()).enumerate() { - chunks_x_vecs[j].push(F::from_u64(x as u64).unwrap()); - chunks_y_vecs[j].push(F::from_u64(y as u64).unwrap()); + for (chunk_index, (x, y)) in chunks_x_op.into_iter().zip(chunks_y_op.into_iter()).enumerate() { + let flat_chunk_index = instruction_index + chunk_index * PADDED_TRACE_LEN; + chunks_x[flat_chunk_index] = F::from_u64(x as u64).unwrap(); + chunks_y[flat_chunk_index] = F::from_u64(y as u64).unwrap(); } } - chunks_x_vecs.iter_mut().for_each(|vec| vec.resize(PADDED_TRACE_LEN, F::zero())); - chunks_y_vecs.iter_mut().for_each(|vec| vec.resize(PADDED_TRACE_LEN, F::zero())); - - let chunks_x: Vec = chunks_x_vecs.into_iter().flatten().collect(); - let chunks_y: Vec = chunks_y_vecs.into_iter().flatten().collect(); - drop(_guard); drop(span); From af4853c471e67f9446b73fdbf49d0f7d8d429fef Mon Sep 17 00:00:00 2001 From: sragss Date: Mon, 18 Mar 2024 20:04:36 -0600 Subject: [PATCH 5/6] faster pre-r1cs transformations --- jolt-core/src/jolt/vm/bytecode.rs | 1 + jolt-core/src/jolt/vm/mod.rs | 38 ++++++++++++---------- jolt-core/src/jolt/vm/read_write_memory.rs | 1 + 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/jolt-core/src/jolt/vm/bytecode.rs b/jolt-core/src/jolt/vm/bytecode.rs index f33f609d7..f45f091ef 100644 --- a/jolt-core/src/jolt/vm/bytecode.rs +++ b/jolt-core/src/jolt/vm/bytecode.rs @@ -312,6 +312,7 @@ impl> BytecodePolynomials { } } + #[tracing::instrument(skip_all, name = "BytecodePolynomials::get_polys_r1cs")] pub fn get_polys_r1cs(&self) -> (Vec, Vec) { let a_read_write_evals = self.a_read_write.evals().clone(); let v_read_write_evals = [ diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 378760191..feaa9c7cf 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -438,27 +438,29 @@ where // Assemble the polynomials let (bytecode_a, mut bytecode_v) = jolt_polynomials.bytecode.get_polys_r1cs(); - bytecode_v.extend(packed_flags.iter()); + bytecode_v.par_extend(packed_flags.par_iter()); let (memreg_a_rw, memreg_v_reads, memreg_v_writes) = jolt_polynomials.read_write_memory.get_polys_r1cs(); - let chunks_query: Vec = jolt_polynomials.instruction_lookups.dim.par_iter().take(C).flat_map(|poly| poly.evals()).collect(); + let span = tracing::span!(tracing::Level::INFO, "chunks_query"); + let _guard = span.enter(); + let mut chunks_query: Vec = Vec::with_capacity(C * jolt_polynomials.instruction_lookups.dim[0].len()); + for i in 0..C { + chunks_query.par_extend(jolt_polynomials.instruction_lookups.dim[i].evals_ref().par_iter()); + } + drop(_guard); // Flattening this out into a Vec and chunking into PADDED_TRACE_LEN-sized chunks // will be the exact witness vector to feed into the R1CS // after pre-pending IO and appending the AUX - // let inputs: Vec> = vec![ - // bytecode_a, // prog_a_rw, - // bytecode_v, // prog_v_rw (with circuit_flags_packed) - // memreg_a_rw, - // memreg_v_reads, - // memreg_v_writes, - // chunks_x.clone(), - // chunks_y.clone(), - // chunks_query, - // lookup_outputs.clone(), - // circuit_flags_bits.clone(), - // ]; + + let span = 
tracing::span!(tracing::Level::INFO, "input_cloning"); + let _guard = span.enter(); + let input_chunks_x = chunks_x.clone(); + let input_chunks_y = chunks_y.clone(); + let input_lookup_outputs = lookup_outputs.clone(); + let input_circuit_flags_bits = circuit_flags_bits.clone(); + drop(_guard); let inputs: R1CSInputs = R1CSInputs::from_ark( bytecode_a, @@ -466,11 +468,11 @@ where memreg_a_rw, memreg_v_reads, memreg_v_writes, - chunks_x.clone(), - chunks_y.clone(), + input_chunks_x, + input_chunks_y, chunks_query, - lookup_outputs.clone(), - circuit_flags_bits.clone() + input_lookup_outputs, + input_circuit_flags_bits ); // Assemble the commitments diff --git a/jolt-core/src/jolt/vm/read_write_memory.rs b/jolt-core/src/jolt/vm/read_write_memory.rs index 6ad8be671..f227de27b 100644 --- a/jolt-core/src/jolt/vm/read_write_memory.rs +++ b/jolt-core/src/jolt/vm/read_write_memory.rs @@ -504,6 +504,7 @@ impl> ReadWriteMemory { ] } + #[tracing::instrument(skip_all, name = "ReadWriteMemory::get_polys_r1cs")] pub fn get_polys_r1cs(&self) -> (Vec, Vec, Vec) { let a_polys = self.a_read_write.iter().flat_map(|poly| poly.evals()).collect::>(); let v_read_polys = self.v_read.iter().flat_map(|poly| poly.evals()).collect::>(); From ce906b6c458daff787098b50a0f8b8354d40f078 Mon Sep 17 00:00:00 2001 From: sragss Date: Mon, 18 Mar 2024 20:06:03 -0600 Subject: [PATCH 6/6] adtnl comment --- jolt-core/src/r1cs/snark.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jolt-core/src/r1cs/snark.rs b/jolt-core/src/r1cs/snark.rs index 285814354..1ee1611b9 100644 --- a/jolt-core/src/r1cs/snark.rs +++ b/jolt-core/src/r1cs/snark.rs @@ -168,7 +168,9 @@ impl> R1CSInputs { #[tracing::instrument(skip_all, name = "R1CSInputs::clone_to_stepwise")] pub fn clone_to_stepwise(&self) -> Vec> { const PREFIX_VARS_PER_STEP: usize = 5; - const AUX_VARS_PER_STEP: usize = 20; + + // AUX_VARS_PER_STEP has to be greater than the number of additional vars pushed by the constraint system + const AUX_VARS_PER_STEP: usize = 20; let num_inputs_per_step = self.num_vars_per_step() + PREFIX_VARS_PER_STEP; let stepwise = (0..self.trace_len()).into_par_iter().map(|step_index| {
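// Illustration (not part of this patch series): a minimal sketch of the rayon
// pattern the later commits lean on -- a flat, column-major witness column (the
// layout the reworked chunks_x/chunks_y construction also writes into) is split
// back into padded_trace_len-sized per-variable segments with `par_chunks` and
// collected in parallel with `par_extend`, mirroring `trace_len_chunks` and the
// `w_segments` assembly. The data, sizes, and the `segment` helper are
// illustrative, not code from the Jolt repository.
use rayon::prelude::*;

fn segment(flat: &[u64], padded_trace_len: usize) -> Vec<Vec<u64>> {
    let mut segments: Vec<Vec<u64>> = Vec::with_capacity(flat.len() / padded_trace_len);
    // Each chunk is one variable's values over the padded trace.
    segments.par_extend(flat.par_chunks(padded_trace_len).map(|chunk| chunk.to_vec()));
    segments
}

fn main() {
    // Three variables over a padded trace of length 4, stored column-major.
    let flat: Vec<u64> = (0..12).collect();
    let segments = segment(&flat, 4);
    assert_eq!(segments.len(), 3);
    assert_eq!(segments[1], vec![4, 5, 6, 7]);
}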