Skip to content

Commit

Permalink
Merge pull request #10 from worldcoin/remco/concurrent-proof
Browse files Browse the repository at this point in the history
Work around concurrent witness calculator bug
  • Loading branch information
philsippl authored Mar 21, 2022
2 parents 8766cf8 + f89a4e3 commit d861a73
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 50 deletions.
5 changes: 3 additions & 2 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use core::include_bytes;
use once_cell::sync::Lazy;
use std::io::{Cursor, Write};
use tempfile::NamedTempFile;
use std::sync::Mutex;

const ZKEY_BYTES: &[u8] = include_bytes!("../semaphore/build/snark/semaphore_final.zkey");
const WASM: &[u8] = include_bytes!("../semaphore/build/snark/semaphore.wasm");
Expand All @@ -15,13 +16,13 @@ pub static ZKEY: Lazy<(ProvingKey<Bn254>, ConstraintMatrices<Fr>)> = Lazy::new(|
read_zkey(&mut reader).expect("zkey should be valid")
});

pub static WITNESS_CALCULATOR: Lazy<WitnessCalculator> = Lazy::new(|| {
pub static WITNESS_CALCULATOR: Lazy<Mutex<WitnessCalculator>> = Lazy::new(|| {
// HACK: ark-circom requires a file, so we make one!
let mut tmpfile = NamedTempFile::new().expect("Failed to create temp file");
let written = tmpfile.write(WASM).expect("Failed to write to temp file");
assert_eq!(written, WASM.len());
let path = tmpfile.into_temp_path();
let result = WitnessCalculator::new(&path).expect("Failed to create witness calculator");
path.close().expect("Could not remove tempfile");
result
Mutex::new(result)
});
26 changes: 23 additions & 3 deletions src/field.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use crate::util::{bytes_from_hex, deserialize_bytes, keccak256, serialize_bytes};
use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, keccak256, serialize_bytes};
use ark_bn254::Fr as ArkField;
use ark_ff::{BigInteger as _, PrimeField as _};
use core::{str, str::FromStr};
use core::{
fmt::{Debug, Display},
str,
str::FromStr,
};
use ff::{PrimeField as _, PrimeFieldRepr as _};
use num_bigint::{BigInt, Sign};
use poseidon_rs::Fr as PosField;
Expand All @@ -10,7 +14,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// An element of the BN254 scalar field Fr.
///
/// Represented as a big-endian byte vector without Montgomery reduction.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
// TODO: Make sure value is always reduced.
pub struct Field([u8; 32]);

Expand Down Expand Up @@ -69,6 +73,22 @@ impl From<Field> for BigInt {
}
}

impl Debug for Field {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "Field({})", hex_str)
}
}

impl Display for Field {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "{}", hex_str)
}
}

/// Serialize a field element.
///
/// For human readable formats a `0x` prefixed lower case hex string is used.
Expand Down
41 changes: 22 additions & 19 deletions src/hash.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::util::{bytes_from_hex, deserialize_bytes, serialize_bytes};
use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, serialize_bytes};
use core::{
fmt::{Debug, Display},
str,
str::FromStr,
};
use ethers_core::types::U256;
use num_bigint::{BigInt, Sign};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{
fmt::{Debug, Display, Formatter, Result as FmtResult},
str::FromStr,
};

/// Container for 256-bit hash values.
#[derive(Clone, Copy, PartialEq, Eq, Default)]
Expand All @@ -23,20 +24,6 @@ impl Hash {
}
}

/// Debug print hashes using `hex!(..)` literals.
impl Debug for Hash {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "Hash(hex!(\"{}\"))", hex::encode(&self.0))
}
}

/// Display print hashes as `0x...`.
impl Display for Hash {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "0x{}", hex::encode(&self.0))
}
}

/// Conversion from Ether U256
impl From<&Hash> for U256 {
fn from(hash: &Hash) -> Self {
Expand Down Expand Up @@ -75,6 +62,22 @@ impl From<&Hash> for BigInt {
}
}

impl Debug for Hash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "Field({})", hex_str)
}
}

impl Display for Hash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let hex = bytes_to_hex::<32, 66>(&self.0);
let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8");
write!(f, "{}", hex_str)
}
}

/// Parse Hash from hex string.
/// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix
/// but they must always be exactly 32 bytes.
Expand Down
46 changes: 30 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ mod test {
protocol::{generate_nullifier_hash, generate_proof, verify_proof},
Field,
};
use std::thread::spawn;

#[test]
fn test_field_serde() {
Expand All @@ -48,26 +49,22 @@ mod test {
assert_eq!(value, deserialized);
}

#[test]
fn test_end_to_end() {
fn test_end_to_end(identity: &[u8], external_nullifier: &[u8], signal: &[u8]) {
// const LEAF: Hash = Hash::from_bytes_be(hex!(
// "0000000000000000000000000000000000000000000000000000000000000000"
// ));
let leaf = Field::from(0);

// generate identity
let id = Identity::from_seed(b"hello");
let id = Identity::from_seed(identity);

// generate merkle tree
let mut tree = PoseidonTree::new(21, leaf);
tree.set(0, id.commitment());

let merkle_proof = tree.proof(0).expect("proof should exist");
let root = tree.root();

// change signal and external_nullifier here
let signal = b"xxx";
let external_nullifier = b"appId";
dbg!(root);

let signal_hash = hash_to_field(signal);
let external_nullifier_hash = hash_to_field(external_nullifier);
Expand All @@ -76,16 +73,33 @@ mod test {
let proof =
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap();

let success = verify_proof(
root,
nullifier_hash,
signal_hash,
external_nullifier_hash,
&proof,
)
.unwrap();
for _ in 0..5 {
let success = verify_proof(
root,
nullifier_hash,
signal_hash,
external_nullifier_hash,
&proof,
)
.unwrap();
assert!(success);
}
}
#[test]
fn test_single() {
// Note that rust will still run tests in parallel
test_end_to_end(b"hello", b"appId", b"xxx");
}

assert!(success);
#[test]
fn test_parallel() {
// Note that this does not guarantee a concurrency issue will be detected.
// For that we need much more sophisticated static analysis tooling like
// loom. See <https://github.com/tokio-rs/loom>
let a = spawn(|| test_end_to_end(b"hello", b"appId", b"xxx"));
let b = spawn(|| test_end_to_end(b"secret", b"test", b"signal"));
a.join().unwrap();
b.join().unwrap();
}
}

Expand Down
24 changes: 18 additions & 6 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ pub fn generate_proof(

let now = Instant::now();

let full_assignment = WITNESS_CALCULATOR
.clone()
let full_assignment =
WITNESS_CALCULATOR
.lock()
.expect("witness_calculator mutex should not get poisoned")
.calculate_witness_element::<Bn254, _>(inputs, false)
.map_err(ProofError::WitnessError)?;

Expand Down Expand Up @@ -178,8 +180,7 @@ mod test {
use super::*;
use crate::{hash_to_field, poseidon_tree::PoseidonTree};

#[test]
fn test_proof_serialize() {
fn arb_proof() -> Proof {
// generate identity
let id = Identity::from_seed(b"secret");

Expand All @@ -194,9 +195,20 @@ mod test {
let signal_hash = hash_to_field(b"xxx");
let external_nullifier_hash = hash_to_field(b"appId");

let proof =
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap();
generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap()
}

#[test]
fn test_proof_cast_roundtrip() {
let proof = arb_proof();
let ark_proof: ArkProof<Bn<Parameters>> = proof.into();
let result: Proof = ark_proof.into();
assert_eq!(proof, result);
}

#[test]
fn test_proof_serialize() {
let proof = arb_proof();
let _json = serde_json::to_value(&proof).unwrap();

// TODO: Ideally we would check the output against an expected value,
Expand Down
15 changes: 11 additions & 4 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ pub(crate) fn keccak256(bytes: &[u8]) -> [u8; 32] {
output
}

pub(crate) fn bytes_to_hex<const N: usize, const M: usize>(bytes: &[u8; N]) -> [u8; M] {
// TODO: Replace `M` with a const expression once it's stable.
debug_assert_eq!(M, 2 * N + 2);
let mut result = [0u8; M];
result[0] = b'0';
result[1] = b'x';
hex::encode_to_slice(&bytes[..], &mut result[2..]).expect("the buffer is correctly sized");
result
}

/// Helper to serialize byte arrays
pub(crate) fn serialize_bytes<const N: usize, const M: usize, S: Serializer>(
serializer: S,
Expand All @@ -25,10 +35,7 @@ pub(crate) fn serialize_bytes<const N: usize, const M: usize, S: Serializer>(
debug_assert_eq!(M, 2 * N + 2);
if serializer.is_human_readable() {
// Write as a 0x prefixed lower-case hex string
let mut buffer = [0u8; M];
buffer[0] = b'0';
buffer[1] = b'x';
hex::encode_to_slice(&bytes[..], &mut buffer[2..]).expect("the buffer is correctly sized");
let buffer = bytes_to_hex::<N, M>(bytes);
let string = str::from_utf8(&buffer).expect("the buffer is valid UTF-8");
serializer.serialize_str(string)
} else {
Expand Down

0 comments on commit d861a73

Please sign in to comment.