Skip to content

Commit

Permalink
Use PrimeField as generic bound across the codebase (#67)
Browse files Browse the repository at this point in the history
Co-authored-by: Cesar Descalzo <[email protected]>
Co-authored-by: Antonio Mejías Gil <[email protected]>
  • Loading branch information
3 people authored Jul 18, 2024
1 parent fa6262a commit a573c15
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 137 deletions.
128 changes: 64 additions & 64 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions benches/groth16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};
use ark_circom::{read_zkey, CircomReduction, WitnessCalculator};
use ark_std::rand::thread_rng;

use ark_bn254::Bn254;
use ark_bn254::{Bn254, Fr};
use ark_groth16::Groth16;
use wasmer::Store;

Expand Down Expand Up @@ -39,7 +39,7 @@ fn bench_groth(c: &mut Criterion, num_validators: u32, num_constraints: u32) {
)
.unwrap();
let full_assignment = wtns
.calculate_witness_element::<Bn254, _>(&mut store, inputs, false)
.calculate_witness_element::<Fr, _>(&mut store, inputs, false)
.unwrap();

let mut rng = thread_rng();
Expand Down
25 changes: 13 additions & 12 deletions src/circom/builder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use ark_ec::pairing::Pairing;
use std::{fs::File, path::Path};
use wasmer::Store;

use ark_ff::PrimeField;

use super::{CircomCircuit, R1CS};

use num_bigint::BigInt;
Expand All @@ -14,21 +15,21 @@ use crate::{
use color_eyre::Result;

#[derive(Debug)]
pub struct CircomBuilder<E: Pairing> {
pub cfg: CircomConfig<E>,
pub struct CircomBuilder<F: PrimeField> {
pub cfg: CircomConfig<F>,
pub inputs: HashMap<String, Vec<BigInt>>,
}

// Add utils for creating this from files / directly from bytes
#[derive(Debug)]
pub struct CircomConfig<E: Pairing> {
pub r1cs: R1CS<E>,
pub struct CircomConfig<F: PrimeField> {
pub r1cs: R1CS<F>,
pub wtns: WitnessCalculator,
pub store: Store,
pub sanity_check: bool,
}

impl<E: Pairing> CircomConfig<E> {
impl<F: PrimeField> CircomConfig<F> {
pub fn new(wtns: impl AsRef<Path>, r1cs: impl AsRef<Path>) -> Result<Self> {
let mut store = Store::default();
let wtns = WitnessCalculator::new(&mut store, wtns).unwrap();
Expand Down Expand Up @@ -56,10 +57,10 @@ impl<E: Pairing> CircomConfig<E> {
}
}

impl<E: Pairing> CircomBuilder<E> {
impl<F: PrimeField> CircomBuilder<F> {
/// Instantiates a new builder using the provided WitnessGenerator and R1CS files
/// for your circuit
pub fn new(cfg: CircomConfig<E>) -> Self {
pub fn new(cfg: CircomConfig<F>) -> Self {
Self {
cfg,
inputs: HashMap::new(),
Expand All @@ -74,7 +75,7 @@ impl<E: Pairing> CircomBuilder<E> {

/// Generates an empty circom circuit with no witness set, to be used for
/// generation of the trusted setup parameters
pub fn setup(&self) -> CircomCircuit<E> {
pub fn setup(&self) -> CircomCircuit<F> {
let mut circom = CircomCircuit {
r1cs: self.cfg.r1cs.clone(),
witness: None,
Expand All @@ -88,11 +89,11 @@ impl<E: Pairing> CircomBuilder<E> {

/// Creates the circuit populated with the witness corresponding to the previously
/// provided inputs
pub fn build(mut self) -> Result<CircomCircuit<E>> {
pub fn build(mut self) -> Result<CircomCircuit<F>> {
let mut circom = self.setup();

// calculate the witness
let witness = self.cfg.wtns.calculate_witness_element::<E, _>(
let witness = self.cfg.wtns.calculate_witness_element::<F, _>(
&mut self.cfg.store,
self.inputs,
self.cfg.sanity_check,
Expand All @@ -102,7 +103,7 @@ impl<E: Pairing> CircomBuilder<E> {
// sanity check
debug_assert!({
use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystem};
let cs = ConstraintSystem::<E::ScalarField>::new_ref();
let cs = ConstraintSystem::<F>::new_ref();
circom.clone().generate_constraints(cs.clone()).unwrap();
let is_satisfied = cs.is_satisfied().unwrap();
if !is_satisfied {
Expand Down
36 changes: 16 additions & 20 deletions src/circom/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use ark_ec::pairing::Pairing;
use ark_relations::r1cs::{
ConstraintSynthesizer, ConstraintSystemRef, LinearCombination, SynthesisError, Variable,
};

use ark_ff::PrimeField;

use super::R1CS;

use color_eyre::Result;

#[derive(Clone, Debug)]
pub struct CircomCircuit<E: Pairing> {
pub r1cs: R1CS<E>,
pub witness: Option<Vec<E::ScalarField>>,
pub struct CircomCircuit<F: PrimeField> {
pub r1cs: R1CS<F>,
pub witness: Option<Vec<F>>,
}

impl<E: Pairing> CircomCircuit<E> {
pub fn get_public_inputs(&self) -> Option<Vec<E::ScalarField>> {
impl<F: PrimeField> CircomCircuit<F> {
pub fn get_public_inputs(&self) -> Option<Vec<F>> {
match &self.witness {
None => None,
Some(w) => match &self.r1cs.wire_mapping {
Expand All @@ -25,19 +26,16 @@ impl<E: Pairing> CircomCircuit<E> {
}
}

impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
fn generate_constraints(
self,
cs: ConstraintSystemRef<E::ScalarField>,
) -> Result<(), SynthesisError> {
impl<F: PrimeField> ConstraintSynthesizer<F> for CircomCircuit<F> {
fn generate_constraints(self, cs: ConstraintSystemRef<F>) -> Result<(), SynthesisError> {
let witness = &self.witness;
let wire_mapping = &self.r1cs.wire_mapping;

// Start from 1 because Arkworks implicitly allocates One for the first input
for i in 1..self.r1cs.num_inputs {
cs.new_input_variable(|| {
Ok(match witness {
None => E::ScalarField::from(1u32),
None => F::from(1u32),
Some(w) => match wire_mapping {
Some(m) => w[m[i]],
None => w[i],
Expand All @@ -49,7 +47,7 @@ impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
for i in 0..self.r1cs.num_aux {
cs.new_witness_variable(|| {
Ok(match witness {
None => E::ScalarField::from(1u32),
None => F::from(1u32),
Some(w) => match wire_mapping {
Some(m) => w[m[i + self.r1cs.num_inputs]],
None => w[i + self.r1cs.num_inputs],
Expand All @@ -65,12 +63,10 @@ impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
Variable::Witness(index - self.r1cs.num_inputs)
}
};
let make_lc = |lc_data: &[(usize, E::ScalarField)]| {
let make_lc = |lc_data: &[(usize, F)]| {
lc_data.iter().fold(
LinearCombination::<E::ScalarField>::zero(),
|lc: LinearCombination<E::ScalarField>, (index, coeff)| {
lc + (*coeff, make_index(*index))
},
LinearCombination::<F>::zero(),
|lc: LinearCombination<F>, (index, coeff)| lc + (*coeff, make_index(*index)),
)
};

Expand All @@ -90,12 +86,12 @@ impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
mod tests {
use super::*;
use crate::{CircomBuilder, CircomConfig};
use ark_bn254::{Bn254, Fr};
use ark_bn254::Fr;
use ark_relations::r1cs::ConstraintSystem;

#[tokio::test]
async fn satisfied() {
let cfg = CircomConfig::<Bn254>::new(
let cfg = CircomConfig::<Fr>::new(
"./test-vectors/mycircuit.wasm",
"./test-vectors/mycircuit.r1cs",
)
Expand Down
6 changes: 2 additions & 4 deletions src/circom/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use ark_ec::pairing::Pairing;

pub mod r1cs_reader;
pub use r1cs_reader::{R1CSFile, R1CS};

Expand All @@ -12,5 +10,5 @@ pub use builder::{CircomBuilder, CircomConfig};
mod qap;
pub use qap::CircomReduction;

pub type Constraints<E> = (ConstraintVec<E>, ConstraintVec<E>, ConstraintVec<E>);
pub type ConstraintVec<E> = Vec<(usize, <E as Pairing>::ScalarField)>;
pub type Constraints<F> = (ConstraintVec<F>, ConstraintVec<F>, ConstraintVec<F>);
pub type ConstraintVec<F> = Vec<(usize, F)>;
40 changes: 20 additions & 20 deletions src/circom/r1cs_reader.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! R1CS circom file reader
//! Copied from <https://github.com/poma/zkutil>
//! Spec: <https://github.com/iden3/r1csfile/blob/master/doc/r1cs_bin_format.md>
use ark_ff::PrimeField;
use byteorder::{LittleEndian, ReadBytesExt};
use std::io::{Error, ErrorKind};

use ark_ec::pairing::Pairing;
use ark_serialize::{CanonicalDeserialize, SerializationError, SerializationError::IoError};
use ark_serialize::{SerializationError, SerializationError::IoError};
use ark_std::io::{Read, Seek, SeekFrom};

use std::collections::HashMap;
Expand All @@ -15,16 +15,16 @@ type IoResult<T> = Result<T, SerializationError>;
use super::{ConstraintVec, Constraints};

#[derive(Clone, Debug)]
pub struct R1CS<E: Pairing> {
pub struct R1CS<F> {
pub num_inputs: usize,
pub num_aux: usize,
pub num_variables: usize,
pub constraints: Vec<Constraints<E>>,
pub constraints: Vec<Constraints<F>>,
pub wire_mapping: Option<Vec<usize>>,
}

impl<E: Pairing> From<R1CSFile<E>> for R1CS<E> {
fn from(file: R1CSFile<E>) -> Self {
impl<F: PrimeField> From<R1CSFile<F>> for R1CS<F> {
fn from(file: R1CSFile<F>) -> Self {
let num_inputs = (1 + file.header.n_pub_in + file.header.n_pub_out) as usize;
let num_variables = file.header.n_wires as usize;
let num_aux = num_variables - num_inputs;
Expand All @@ -38,20 +38,20 @@ impl<E: Pairing> From<R1CSFile<E>> for R1CS<E> {
}
}

pub struct R1CSFile<E: Pairing> {
pub struct R1CSFile<F: PrimeField> {
pub version: u32,
pub header: Header,
pub constraints: Vec<Constraints<E>>,
pub constraints: Vec<Constraints<F>>,
pub wire_mapping: Vec<u64>,
}

impl<E: Pairing> R1CSFile<E> {
impl<F: PrimeField> R1CSFile<F> {
/// reader must implement the Seek trait, for example with a Cursor
///
/// ```rust,ignore
/// let reader = BufReader::new(Cursor::new(&data[..]));
/// ```
pub fn new<R: Read + Seek>(mut reader: R) -> IoResult<R1CSFile<E>> {
pub fn new<R: Read + Seek>(mut reader: R) -> IoResult<R1CSFile<F>> {
let mut magic = [0u8; 4];
reader.read_exact(&mut magic)?;
if magic != [0x72, 0x31, 0x63, 0x73] {
Expand Down Expand Up @@ -117,7 +117,7 @@ impl<E: Pairing> R1CSFile<E> {

reader.seek(SeekFrom::Start(*constraint_offset?))?;

let constraints = read_constraints::<&mut R, E>(&mut reader, &header)?;
let constraints = read_constraints::<&mut R, F>(&mut reader, &header)?;

let wire2label_offset = sec_offsets.get(&wire2label_type).ok_or_else(|| {
Error::new(
Expand Down Expand Up @@ -200,29 +200,29 @@ impl Header {
}
}

fn read_constraint_vec<R: Read, E: Pairing>(mut reader: R) -> IoResult<ConstraintVec<E>> {
fn read_constraint_vec<R: Read, F: PrimeField>(mut reader: R) -> IoResult<ConstraintVec<F>> {
let n_vec = reader.read_u32::<LittleEndian>()? as usize;
let mut vec = Vec::with_capacity(n_vec);
for _ in 0..n_vec {
vec.push((
reader.read_u32::<LittleEndian>()? as usize,
E::ScalarField::deserialize_uncompressed(&mut reader)?,
F::deserialize_uncompressed(&mut reader)?,
));
}
Ok(vec)
}

fn read_constraints<R: Read, E: Pairing>(
fn read_constraints<R: Read, F: PrimeField>(
mut reader: R,
header: &Header,
) -> IoResult<Vec<Constraints<E>>> {
) -> IoResult<Vec<Constraints<F>>> {
// todo check section size
let mut vec = Vec::with_capacity(header.n_constraints as usize);
for _ in 0..header.n_constraints {
vec.push((
read_constraint_vec::<&mut R, E>(&mut reader)?,
read_constraint_vec::<&mut R, E>(&mut reader)?,
read_constraint_vec::<&mut R, E>(&mut reader)?,
read_constraint_vec::<&mut R, F>(&mut reader)?,
read_constraint_vec::<&mut R, F>(&mut reader)?,
read_constraint_vec::<&mut R, F>(&mut reader)?,
));
}
Ok(vec)
Expand Down Expand Up @@ -251,7 +251,7 @@ fn read_map<R: Read>(mut reader: R, size: u64, header: &Header) -> IoResult<Vec<
#[cfg(test)]
mod tests {
use super::*;
use ark_bn254::{Bn254, Fr};
use ark_bn254::Fr;
use ark_std::io::{BufReader, Cursor};

#[test]
Expand Down Expand Up @@ -309,7 +309,7 @@ mod tests {
);

let reader = BufReader::new(Cursor::new(&data[..]));
let file = R1CSFile::<Bn254>::new(reader).unwrap();
let file = R1CSFile::<Fr>::new(reader).unwrap();
assert_eq!(file.version, 1);

assert_eq!(file.header.field_size, 32);
Expand Down
10 changes: 5 additions & 5 deletions src/witness/witness_calculator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{fnv, CircomBase, SafeMemory, Wasm};
use ark_ff::PrimeField;
use color_eyre::Result;
use num_bigint::BigInt;
use num_traits::Zero;
Expand Down Expand Up @@ -284,17 +285,16 @@ impl WitnessCalculator {
}

pub fn calculate_witness_element<
E: ark_ec::pairing::Pairing,
F: PrimeField,
I: IntoIterator<Item = (String, Vec<BigInt>)>,
>(
&mut self,
store: &mut Store,
inputs: I,
sanity_check: bool,
) -> Result<Vec<E::ScalarField>> {
use ark_ff::PrimeField;
) -> Result<Vec<F>> {
let modulus = F::MODULUS;
let witness = self.calculate_witness(store, inputs, sanity_check)?;
let modulus = <E::ScalarField as PrimeField>::MODULUS;

// convert it to field elements
use num_traits::Signed;
Expand All @@ -307,7 +307,7 @@ impl WitnessCalculator {
} else {
w.to_biguint().unwrap()
};
E::ScalarField::from(w)
F::from(w)
})
.collect::<Vec<_>>();

Expand Down
4 changes: 2 additions & 2 deletions src/zkey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ mod tests {
let mut file = File::open(path).unwrap();
let (params, _matrices) = read_zkey(&mut file).unwrap(); // binfile.proving_key().unwrap();

let cfg = CircomConfig::<Bn254>::new(
let cfg = CircomConfig::<Fr>::new(
"./test-vectors/mycircuit.wasm",
"./test-vectors/mycircuit.r1cs",
)
Expand Down Expand Up @@ -896,7 +896,7 @@ mod tests {
let s = ark_bn254::Fr::rand(rng);

let full_assignment = wtns
.calculate_witness_element::<Bn254, _>(&mut store, inputs, false)
.calculate_witness_element::<Fr, _>(&mut store, inputs, false)
.unwrap();
let proof = Groth16::<Bn254, CircomReduction>::create_proof_with_reduction_and_matrices(
&params,
Expand Down
Loading

0 comments on commit a573c15

Please sign in to comment.