Skip to content

Commit

Permalink
Merge pull request #200 from a16z/remove-circom-cleanup
Browse files Browse the repository at this point in the history
Accelerate non-Circom R1CS Constraints
  • Loading branch information
sragss authored Mar 21, 2024
2 parents 2ba7ae4 + d287690 commit 67cb6d7
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 101 deletions.
1 change: 1 addition & 0 deletions jolt-core/src/jolt/vm/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ impl<F: PrimeField, G: CurveGroup<ScalarField = F>> BytecodePolynomials<F, G> {
}
}

#[tracing::instrument(skip_all, name = "BytecodePolynomials::get_polys_r1cs")]
pub fn get_polys_r1cs(&self) -> (Vec<F>, Vec<F>) {
let a_read_write_evals = self.a_read_write.evals().clone();
let v_read_write_evals = [
Expand Down
61 changes: 36 additions & 25 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::poly::dense_mlpoly::DensePolynomial;
use crate::poly::hyrax::{HyraxCommitment, HyraxGenerators};
use crate::poly::pedersen::PedersenGenerators;
use crate::poly::structured_poly::BatchablePolynomials;
use crate::r1cs::snark::R1CSProof;
use crate::r1cs::snark::{R1CSInputs, R1CSProof};
use crate::utils::errors::ProofVerifyError;
use crate::{
jolt::{
Expand Down Expand Up @@ -400,26 +400,22 @@ where
drop(span);

// Derive chunks_x and chunks_y
let span = tracing::span!(tracing::Level::INFO, "compute chunks operands");
let span = tracing::span!(tracing::Level::INFO, "compute_chunks_operands");
let _guard = span.enter();

let mut chunks_x_vecs: Vec<Vec<F>> = vec![Vec::with_capacity(PADDED_TRACE_LEN); C];
let mut chunks_y_vecs: Vec<Vec<F>> = vec![Vec::with_capacity(PADDED_TRACE_LEN); C];
let num_chunks = PADDED_TRACE_LEN * C;
let mut chunks_x: Vec<F> = vec![F::zero(); num_chunks];
let mut chunks_y: Vec<F> = vec![F::zero(); num_chunks];

for (i, op) in instructions.iter().enumerate() {
for (instruction_index, op) in instructions.iter().enumerate() {
let [chunks_x_op, chunks_y_op] = op.operand_chunks(C, log_M);
for (j, (x, y)) in chunks_x_op.into_iter().zip(chunks_y_op.into_iter()).enumerate() {
chunks_x_vecs[j].push(F::from_u64(x as u64).unwrap());
chunks_y_vecs[j].push(F::from_u64(y as u64).unwrap());
for (chunk_index, (x, y)) in chunks_x_op.into_iter().zip(chunks_y_op.into_iter()).enumerate() {
let flat_chunk_index = instruction_index + chunk_index * PADDED_TRACE_LEN;
chunks_x[flat_chunk_index] = F::from_u64(x as u64).unwrap();
chunks_y[flat_chunk_index] = F::from_u64(y as u64).unwrap();
}
}

chunks_x_vecs.iter_mut().for_each(|vec| vec.resize(PADDED_TRACE_LEN, F::zero()));
chunks_y_vecs.iter_mut().for_each(|vec| vec.resize(PADDED_TRACE_LEN, F::zero()));

let chunks_x: Vec<F> = chunks_x_vecs.into_iter().flatten().collect();
let chunks_y: Vec<F> = chunks_y_vecs.into_iter().flatten().collect();

drop(_guard);
drop(span);

Expand All @@ -442,27 +438,42 @@ where

// Assemble the polynomials
let (bytecode_a, mut bytecode_v) = jolt_polynomials.bytecode.get_polys_r1cs();
bytecode_v.extend(packed_flags.iter());
bytecode_v.par_extend(packed_flags.par_iter());

let (memreg_a_rw, memreg_v_reads, memreg_v_writes) = jolt_polynomials.read_write_memory.get_polys_r1cs();

let chunks_query: Vec<F> = jolt_polynomials.instruction_lookups.dim.par_iter().take(C).flat_map(|poly| poly.evals()).collect();
let span = tracing::span!(tracing::Level::INFO, "chunks_query");
let _guard = span.enter();
let mut chunks_query: Vec<F> = Vec::with_capacity(C * jolt_polynomials.instruction_lookups.dim[0].len());
for i in 0..C {
chunks_query.par_extend(jolt_polynomials.instruction_lookups.dim[i].evals_ref().par_iter());
}
drop(_guard);

// Flattening this out into a Vec<F> and chunking into PADDED_TRACE_LEN-sized chunks
// will be the exact witness vector to feed into the R1CS
// after pre-pending IO and appending the AUX
let inputs: Vec<Vec<F>> = vec![
bytecode_a, // prog_a_rw,
bytecode_v, // prog_v_rw (with circuit_flags_packed)

let span = tracing::span!(tracing::Level::INFO, "input_cloning");
let _guard = span.enter();
let input_chunks_x = chunks_x.clone();
let input_chunks_y = chunks_y.clone();
let input_lookup_outputs = lookup_outputs.clone();
let input_circuit_flags_bits = circuit_flags_bits.clone();
drop(_guard);

let inputs: R1CSInputs<spartan2::provider::bn256_grumpkin::bn256::Scalar> = R1CSInputs::from_ark(
bytecode_a,
bytecode_v,
memreg_a_rw,
memreg_v_reads,
memreg_v_writes,
chunks_x.clone(),
chunks_y.clone(),
input_chunks_x,
input_chunks_y,
chunks_query,
lookup_outputs.clone(),
circuit_flags_bits.clone(),
];
input_lookup_outputs,
input_circuit_flags_bits
);

// Assemble the commitments
let span = tracing::span!(tracing::Level::INFO, "bytecode_commitment_conversions");
Expand Down Expand Up @@ -510,7 +521,7 @@ where
circuit_flags_comm
].concat();

R1CSProof::prove(
R1CSProof::prove::<F>(
32, C, PADDED_TRACE_LEN,
inputs,
preprocessing.spartan_generators,
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/jolt/vm/read_write_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ impl<F: PrimeField, G: CurveGroup<ScalarField = F>> ReadWriteMemory<F, G> {
]
}

#[tracing::instrument(skip_all, name = "ReadWriteMemory::get_polys_r1cs")]
pub fn get_polys_r1cs(&self) -> (Vec<F>, Vec<F>, Vec<F>) {
let a_polys = self.a_read_write.iter().flat_map(|poly| poly.evals()).collect::<Vec<F>>();
let v_read_polys = self.v_read.iter().flat_map(|poly| poly.evals()).collect::<Vec<F>>();
Expand Down
9 changes: 6 additions & 3 deletions jolt-core/src/r1cs/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use smallvec::SmallVec;
use smallvec::smallvec;
use ff::PrimeField;
use rayon::prelude::*;

/* Compiler Variables */
const C: usize = 4;
Expand Down Expand Up @@ -125,6 +126,7 @@ fn concat_constraint_vecs(mut x: SmallVec<[(usize, i64); SMALLVEC_SIZE]>, y: Sma
}

fn i64_to_f<F: PrimeField>(num: i64) -> F {
// TODO(sragss): Make from_u64
if num < 0 {
F::ZERO - F::from((-num) as u64)
} else {
Expand Down Expand Up @@ -721,11 +723,12 @@ impl R1CSBuilder {
|| modify_matrix(&mut self.C)));
}

#[tracing::instrument(skip_all, name = "Shape::convert_to_field")]
pub fn convert_to_field<F: PrimeField>(&self) -> (Vec<(usize, usize, F)>, Vec<(usize, usize, F)>, Vec<(usize, usize, F)>) {
(
self.A.iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::<F>(*val))).collect(),
self.B.iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::<F>(*val))).collect(),
self.C.iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::<F>(*val))).collect(),
self.A.par_iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::<F>(*val))).collect(),
self.B.par_iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::<F>(*val))).collect(),
self.C.par_iter().map(|(row, idx, val)| (*row, *idx, i64_to_f::<F>(*val))).collect(),
)
}
}
Expand Down
Loading

0 comments on commit 67cb6d7

Please sign in to comment.