Skip to content

Commit

Permalink
Merge pull request #403 from a16z/sragss/sparse-spartan
Browse files Browse the repository at this point in the history
Sparsify Spartan
  • Loading branch information
sragss authored Jul 19, 2024
2 parents 7a18b97 + 829aa49 commit a0f8fbb
Show file tree
Hide file tree
Showing 11 changed files with 958 additions and 287 deletions.
1 change: 1 addition & 0 deletions jolt-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#![feature(generic_const_exprs)]
#![feature(iter_next_chunk)]
#![allow(long_running_const_eval)]
#![allow(clippy::len_without_is_empty)]

#[cfg(feature = "host")]
pub mod benches;
Expand Down
4 changes: 1 addition & 3 deletions jolt-core/src/poly/commitment/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ impl<F: JoltField> CommitmentScheme for MockCommitScheme<F> {
type Proof = MockProof<F>;
type BatchedProof = MockProof<F>;

fn setup(_shapes: &[CommitShape]) -> Self::Setup {
()
}
fn setup(_shapes: &[CommitShape]) -> Self::Setup {}
fn commit(poly: &DensePolynomial<Self::Field>, _setup: &Self::Setup) -> Self::Commitment {
MockCommitment {
poly: poly.to_owned(),
Expand Down
26 changes: 25 additions & 1 deletion jolt-core/src/poly/dense_mlpoly.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::too_many_arguments)]
use crate::poly::eq_poly::EqPolynomial;
use crate::utils::thread::unsafe_allocate_zero_vec;
use crate::utils::thread::{drop_in_background_thread, unsafe_allocate_zero_vec};
use crate::utils::{self, compute_dotproduct, compute_dotproduct_low_optimized};

use crate::field::JoltField;
Expand Down Expand Up @@ -201,11 +201,35 @@ impl<F: JoltField> DensePolynomial<F> {
}
}

/// Note: does not truncate
#[tracing::instrument(skip_all)]
pub fn bound_poly_var_bot(&mut self, r: &F) {
let n = self.len() / 2;
for i in 0..n {
self.Z[i] = self.Z[2 * i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]);
}

self.num_vars -= 1;
self.len = n;
}

pub fn bound_poly_var_bot_01_optimized(&mut self, r: &F) {
let n = self.len() / 2;
let mut new_z = unsafe_allocate_zero_vec(n);
new_z.par_iter_mut().enumerate().for_each(|(i, z)| {
let m = self.Z[2 * i + 1] - self.Z[2 * i];
*z = if m.is_zero() {
self.Z[2 * i]
} else if m.is_one() {
self.Z[2 * i] + r
} else {
self.Z[2 * i] + *r * m
}
});

let old_Z = std::mem::replace(&mut self.Z, new_z);
drop_in_background_thread(old_Z);

self.num_vars -= 1;
self.len = n;
}
Expand Down
209 changes: 122 additions & 87 deletions jolt-core/src/r1cs/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ use crate::{
field::{JoltField, OptimizedMul},
r1cs::key::{SparseConstraints, UniformR1CS},
utils::{
math::Math,
mul_0_1_optimized,
thread::{drop_in_background_thread, unsafe_allocate_zero_vec},
thread::{
drop_in_background_thread, par_flatten_triple, unsafe_allocate_sparse_zero_vec,
unsafe_allocate_zero_vec,
},
},
};
#[allow(unused_imports)] // clippy thinks these aren't needed lol
Expand All @@ -14,6 +18,7 @@ use std::{collections::HashMap, fmt::Debug};
use super::{
key::{NonUniformR1CS, SparseEqualityItem},
ops::{ConstraintInput, Term, Variable, LC},
special_polys::SparsePolynomial,
};

pub trait R1CSConstraintBuilder<F: JoltField> {
Expand All @@ -32,7 +37,7 @@ struct Constraint<I: ConstraintInput> {

impl<I: ConstraintInput> Constraint<I> {
#[cfg(test)]
fn is_sat(&self, inputs: &Vec<i64>) -> bool {
fn is_sat(&self, inputs: &[i64]) -> bool {
// Find the number of variables and the number of aux. Inputs should be equal to this combined length
let num_inputs = I::COUNT;

Expand Down Expand Up @@ -243,15 +248,15 @@ impl<F: JoltField, I: ConstraintInput> R1CSBuilder<F, I> {
let a = condition;
let b = left - right;
let c = LC::zero();
let constraint = Constraint { a, b, c };
let constraint = Constraint { a, b, c }; // TODO(sragss): Can do better on middle term.
self.constraints.push(constraint);
}

pub fn constrain_binary(&mut self, value: impl Into<LC<I>>) {
let one: LC<I> = Variable::Constant.into();
let a: LC<I> = value.into();
let b = one - a.clone();
// value * (1 - value)
// value * (1 - value) == 0
let constraint = Constraint {
a,
b,
Expand Down Expand Up @@ -748,12 +753,17 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {

/// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]]
/// aux should be of the format [[Aux(0), Aux(0), ...], ... [Aux(self.next_aux - 1), ...]]
#[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan")]
#[tracing::instrument(skip_all, name = "CombinedUniformBuilder::compute_spartan_sparse")]
#[allow(clippy::type_complexity)]
pub fn compute_spartan_Az_Bz_Cz(
&self,
inputs: &[Vec<F>],
aux: &[Vec<F>],
) -> (Vec<F>, Vec<F>, Vec<F>) {
) -> (
SparsePolynomial<F>,
SparsePolynomial<F>,
SparsePolynomial<F>,
) {
assert_eq!(inputs.len(), I::COUNT);
let num_aux = self.uniform_builder.num_aux();
assert_eq!(aux.len(), num_aux);
Expand All @@ -763,38 +773,53 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
.all(|inner_input| inner_input.len() == self.uniform_repeat));

let uniform_constraint_rows = self.uniform_repeat_constraint_rows();
// TODO(sragss): Allocation can overshoot by up to a factor of 2, Spartan could handle non-pow-2 Az,Bz,Cz
let constraint_rows = self.constraint_rows().next_power_of_two();
let (mut Az, mut Bz, mut Cz) = (
unsafe_allocate_zero_vec(constraint_rows),
unsafe_allocate_zero_vec(constraint_rows),
unsafe_allocate_zero_vec(constraint_rows),
);

let batch_inputs = |lc: &LC<I>| batch_inputs(lc, inputs, aux);

// uniform_constraints: Xz[0..uniform_constraint_rows]
// TODO(sragss): Attempt moving onto key and computing from materialized rows rather than linear combos
let span = tracing::span!(tracing::Level::DEBUG, "compute_constraints");
let enter = span.enter();
let az_chunks = Az.par_chunks_mut(self.uniform_repeat);
let bz_chunks = Bz.par_chunks_mut(self.uniform_repeat);
let cz_chunks = Cz.par_chunks_mut(self.uniform_repeat);

self.uniform_builder
let span = tracing::span!(tracing::Level::DEBUG, "uniform_evals");
let _enter = span.enter();
let uni_constraint_evals: Vec<(Vec<(F, usize)>, Vec<(F, usize)>, Vec<(F, usize)>)> = self
.uniform_builder
.constraints
.par_iter()
.zip(az_chunks.zip(bz_chunks.zip(cz_chunks)))
.for_each(|(constraint, (az_chunk, (bz_chunk, cz_chunk)))| {
let a_inputs = batch_inputs(&constraint.a);
let b_inputs = batch_inputs(&constraint.b);
let c_inputs = batch_inputs(&constraint.c);

constraint.a.evaluate_batch_mut(&a_inputs, az_chunk);
constraint.b.evaluate_batch_mut(&b_inputs, bz_chunk);
constraint.c.evaluate_batch_mut(&c_inputs, cz_chunk);
});
drop(enter);
.enumerate()
.map(|(constraint_index, constraint)| {
let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat);

let mut evaluate_lc_chunk = |lc: &LC<I>| {
if !lc.terms().is_empty() {
let inputs = batch_inputs(lc);
lc.evaluate_batch_mut(&inputs, &mut dense_output_buffer);

// Take only the non-zero elements and represent them as sparse tuples (eval, dense_index)
let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot
for (local_index, item) in dense_output_buffer.iter().enumerate() {
if !item.is_zero() {
let global_index =
constraint_index * self.uniform_repeat + local_index;
sparse.push((*item, global_index));
}
}
sparse
} else {
vec![]
}
};

let a_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.a);
let b_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.b);
let c_chunk: Vec<(F, usize)> = evaluate_lc_chunk(&constraint.c);

(a_chunk, b_chunk, c_chunk)
})
.collect();

let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple(
uni_constraint_evals,
unsafe_allocate_sparse_zero_vec,
self.uniform_repeat, // Capacity overhead for offset_eq constraints.
);

// offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + 1]
// (a - b) * condition == 0
Expand All @@ -816,48 +841,65 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
.1
.evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat);

let Az_off = Az[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat]
.par_iter_mut();
let Bz_off = Bz[uniform_constraint_rows..uniform_constraint_rows + self.uniform_repeat]
.par_iter_mut();

(0..self.uniform_repeat)
.into_par_iter()
.zip(Az_off.zip(Bz_off))
.for_each(|(step_index, (az, bz))| {
// Write corresponding values, if outside the step range, only include the constant.
let a_step = step_index + if constr.a.0 { 1 } else { 0 };
let b_step = step_index + if constr.b.0 { 1 } else { 0 };
let a = eq_a_evals
.get(a_step)
.cloned()
.unwrap_or(constr.a.1.constant_term_field());
let b = eq_b_evals
.get(b_step)
.cloned()
.unwrap_or(constr.b.1.constant_term_field());
*az = a - b;

let condition_step = step_index + if constr.cond.0 { 1 } else { 0 };
*bz = condition_evals
.get(condition_step)
.cloned()
.unwrap_or(constr.cond.1.constant_term_field());
});
(0..self.uniform_repeat).for_each(|step_index| {
// Write corresponding values, if outside the step range, only include the constant.
let a_step = step_index + constr.a.0 as usize;
let b_step = step_index + constr.b.0 as usize;
let a = eq_a_evals
.get(a_step)
.cloned()
.unwrap_or(constr.a.1.constant_term_field());
let b = eq_b_evals
.get(b_step)
.cloned()
.unwrap_or(constr.b.1.constant_term_field());
let az = a - b;

let global_index = uniform_constraint_rows + step_index;
if !az.is_zero() {
az_sparse.push((az, global_index));
}

(Az, Bz, Cz)
let condition_step = step_index + constr.cond.0 as usize;
let bz = condition_evals
.get(condition_step)
.cloned()
.unwrap_or(constr.cond.1.constant_term_field());
if !bz.is_zero() {
bz_sparse.push((bz, global_index));
}
});
drop(_enter);

let num_vars = self.constraint_rows().next_power_of_two().log_2();
let az_poly = SparsePolynomial::new(num_vars, az_sparse);
let bz_poly = SparsePolynomial::new(num_vars, bz_sparse);
let cz_poly = SparsePolynomial::new(num_vars, cz_sparse);

#[cfg(test)]
self.assert_valid(&az_poly, &bz_poly, &cz_poly);

(az_poly, bz_poly, cz_poly)
}

#[cfg(test)]
pub fn assert_valid(&self, az: &[F], bz: &[F], cz: &[F]) {
pub fn assert_valid(
&self,
az: &SparsePolynomial<F>,
bz: &SparsePolynomial<F>,
cz: &SparsePolynomial<F>,
) {
let az = az.clone().to_dense();
let bz = bz.clone().to_dense();
let cz = cz.clone().to_dense();

let rows = az.len();
let expected_rows = self.constraint_rows().next_power_of_two();
assert_eq!(az.len(), expected_rows);
assert_eq!(bz.len(), expected_rows);
assert_eq!(cz.len(), expected_rows);
assert_eq!(bz.len(), rows);
assert_eq!(cz.len(), rows);

for constraint_index in 0..rows {
let uniform_constraint_index = constraint_index / self.uniform_repeat;
if az[constraint_index] * bz[constraint_index] != cz[constraint_index] {
let uniform_constraint_index = constraint_index / self.uniform_repeat;
let step_index = constraint_index % self.uniform_repeat;
panic!(
"Mismatch at global constraint {constraint_index} => {:?}\n\
Expand All @@ -883,7 +925,7 @@ mod tests {
) -> F {
let multi_step_inputs: Vec<Vec<F>> = single_step_inputs
.iter()
.map(|input| vec![input.clone()])
.map(|input| vec![*input])
.collect();
let multi_step_inputs_ref: Vec<&[F]> =
multi_step_inputs.iter().map(|v| v.as_slice()).collect();
Expand Down Expand Up @@ -1269,10 +1311,6 @@ mod tests {
assert_eq!(aux, vec![vec![Fr::from(5 * 7), Fr::from(11 * 13)]]);

let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux);
assert_eq!(az.len(), 4);
assert_eq!(bz.len(), 4);
assert_eq!(cz.len(), 4);

combined_builder.assert_valid(&az, &bz, &cz);
}

Expand Down Expand Up @@ -1333,10 +1371,6 @@ mod tests {
);

let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux);
assert_eq!(az.len(), 16);
assert_eq!(bz.len(), 16);
assert_eq!(cz.len(), 16);

combined_builder.assert_valid(&az, &bz, &cz);
}

Expand Down Expand Up @@ -1381,10 +1415,6 @@ mod tests {
assert_eq!(aux, vec![vec![Fr::from(5 * 7), Fr::from(5 * 13)]]);

let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux);
assert_eq!(az.len(), 4);
assert_eq!(bz.len(), 4);
assert_eq!(cz.len(), 4);

combined_builder.assert_valid(&az, &bz, &cz);
}

Expand Down Expand Up @@ -1455,11 +1485,16 @@ mod tests {
flat_witness.resize(flat_witness.len().next_power_of_two(), Fr::zero());
flat_witness.push(Fr::one());
flat_witness.resize(flat_witness.len().next_power_of_two(), Fr::zero());
let (mut builder_az, mut builder_bz, mut builder_cz) =

let (builder_az, builder_bz, builder_cz) =
builder.compute_spartan_Az_Bz_Cz(&witness_segments, &[]);
builder_az.resize(key.num_rows_total(), Fr::zero());
builder_bz.resize(key.num_rows_total(), Fr::zero());
builder_cz.resize(key.num_rows_total(), Fr::zero());
let mut dense_az = builder_az.to_dense().evals();
let mut dense_bz = builder_bz.to_dense().evals();
let mut dense_cz = builder_cz.to_dense().evals();
dense_az.resize(key.num_rows_total(), Fr::zero());
dense_bz.resize(key.num_rows_total(), Fr::zero());
dense_cz.resize(key.num_rows_total(), Fr::zero());

for row in 0..key.num_rows_total() {
let mut az_eval = Fr::zero();
let mut bz_eval = Fr::zero();
Expand All @@ -1471,9 +1506,9 @@ mod tests {
}

// Row 11 is the problem! Builder thinks this row should be 0. big_a thinks this row should be 17 (13 + 4)
assert_eq!(builder_az[row], az_eval, "Row {row} failed in az_eval.");
assert_eq!(builder_bz[row], bz_eval, "Row {row} failed in bz_eval.");
assert_eq!(builder_cz[row], cz_eval, "Row {row} failed in cz_eval.");
assert_eq!(dense_az[row], az_eval, "Row {row} failed in az_eval.");
assert_eq!(dense_bz[row], bz_eval, "Row {row} failed in bz_eval.");
assert_eq!(dense_cz[row], cz_eval, "Row {row} failed in cz_eval.");
}
}
}
2 changes: 1 addition & 1 deletion jolt-core/src/r1cs/jolt_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ mod tests {
inputs[JoltIn::OpFlags_IsImm as usize][0] = Fr::zero(); // second_operand = rs2 => immediate

let aux = combined_builder.compute_aux(&inputs);
let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux);

let (az, bz, cz) = combined_builder.compute_spartan_Az_Bz_Cz(&inputs, &aux);
combined_builder.assert_valid(&az, &bz, &cz);
}
}
Loading

0 comments on commit a0f8fbb

Please sign in to comment.