diff --git a/src/circuit.rs b/src/circuit.rs index 175c1bf..274b9ba 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -6,6 +6,7 @@ use core::include_bytes; use once_cell::sync::Lazy; use std::io::{Cursor, Write}; use tempfile::NamedTempFile; +use std::sync::Mutex; const ZKEY_BYTES: &[u8] = include_bytes!("../semaphore/build/snark/semaphore_final.zkey"); const WASM: &[u8] = include_bytes!("../semaphore/build/snark/semaphore.wasm"); @@ -15,7 +16,7 @@ pub static ZKEY: Lazy<(ProvingKey, ConstraintMatrices)> = Lazy::new(| read_zkey(&mut reader).expect("zkey should be valid") }); -pub static WITNESS_CALCULATOR: Lazy = Lazy::new(|| { +pub static WITNESS_CALCULATOR: Lazy> = Lazy::new(|| { // HACK: ark-circom requires a file, so we make one! let mut tmpfile = NamedTempFile::new().expect("Failed to create temp file"); let written = tmpfile.write(WASM).expect("Failed to write to temp file"); @@ -23,5 +24,5 @@ pub static WITNESS_CALCULATOR: Lazy = Lazy::new(|| { let path = tmpfile.into_temp_path(); let result = WitnessCalculator::new(&path).expect("Failed to create witness calculator"); path.close().expect("Could not remove tempfile"); - result + Mutex::new(result) }); diff --git a/src/field.rs b/src/field.rs index e488879..0001ca2 100644 --- a/src/field.rs +++ b/src/field.rs @@ -1,7 +1,11 @@ -use crate::util::{bytes_from_hex, deserialize_bytes, keccak256, serialize_bytes}; +use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, keccak256, serialize_bytes}; use ark_bn254::Fr as ArkField; use ark_ff::{BigInteger as _, PrimeField as _}; -use core::{str, str::FromStr}; +use core::{ + fmt::{Debug, Display}, + str, + str::FromStr, +}; use ff::{PrimeField as _, PrimeFieldRepr as _}; use num_bigint::{BigInt, Sign}; use poseidon_rs::Fr as PosField; @@ -10,7 +14,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// An element of the BN254 scalar field Fr. /// /// Represented as a big-endian byte vector without Montgomery reduction. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] // TODO: Make sure value is always reduced. pub struct Field([u8; 32]); @@ -69,6 +73,22 @@ impl From for BigInt { } } +impl Debug for Field { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let hex = bytes_to_hex::<32, 66>(&self.0); + let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8"); + write!(f, "Field({})", hex_str) + } +} + +impl Display for Field { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let hex = bytes_to_hex::<32, 66>(&self.0); + let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8"); + write!(f, "{}", hex_str) + } +} + /// Serialize a field element. /// /// For human readable formats a `0x` prefixed lower case hex string is used. diff --git a/src/hash.rs b/src/hash.rs index 790d0c1..da35dfe 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -1,11 +1,12 @@ -use crate::util::{bytes_from_hex, deserialize_bytes, serialize_bytes}; +use crate::util::{bytes_from_hex, bytes_to_hex, deserialize_bytes, serialize_bytes}; +use core::{ + fmt::{Debug, Display}, + str, + str::FromStr, +}; use ethers_core::types::U256; use num_bigint::{BigInt, Sign}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::{ - fmt::{Debug, Display, Formatter, Result as FmtResult}, - str::FromStr, -}; /// Container for 256-bit hash values. #[derive(Clone, Copy, PartialEq, Eq, Default)] @@ -23,20 +24,6 @@ impl Hash { } } -/// Debug print hashes using `hex!(..)` literals. -impl Debug for Hash { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "Hash(hex!(\"{}\"))", hex::encode(&self.0)) - } -} - -/// Display print hashes as `0x...`. -impl Display for Hash { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "0x{}", hex::encode(&self.0)) - } -} - /// Conversion from Ether U256 impl From<&Hash> for U256 { fn from(hash: &Hash) -> Self { @@ -75,6 +62,22 @@ impl From<&Hash> for BigInt { } } +impl Debug for Hash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let hex = bytes_to_hex::<32, 66>(&self.0); + let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8"); + write!(f, "Field({})", hex_str) + } +} + +impl Display for Hash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let hex = bytes_to_hex::<32, 66>(&self.0); + let hex_str = str::from_utf8(&hex).expect("hex is always valid utf8"); + write!(f, "{}", hex_str) + } +} + /// Parse Hash from hex string. /// Hex strings can be upper/lower/mixed case and have an optional `0x` prefix /// but they must always be exactly 32 bytes. diff --git a/src/lib.rs b/src/lib.rs index 1e387c7..2f049fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,6 +39,7 @@ mod test { protocol::{generate_nullifier_hash, generate_proof, verify_proof}, Field, }; + use std::thread::spawn; #[test] fn test_field_serde() { @@ -48,15 +49,14 @@ mod test { assert_eq!(value, deserialized); } - #[test] - fn test_end_to_end() { + fn test_end_to_end(identity: &[u8], external_nullifier: &[u8], signal: &[u8]) { // const LEAF: Hash = Hash::from_bytes_be(hex!( // "0000000000000000000000000000000000000000000000000000000000000000" // )); let leaf = Field::from(0); // generate identity - let id = Identity::from_seed(b"hello"); + let id = Identity::from_seed(identity); // generate merkle tree let mut tree = PoseidonTree::new(21, leaf); @@ -64,10 +64,7 @@ mod test { let merkle_proof = tree.proof(0).expect("proof should exist"); let root = tree.root(); - - // change signal and external_nullifier here - let signal = b"xxx"; - let external_nullifier = b"appId"; + dbg!(root); let signal_hash = hash_to_field(signal); let external_nullifier_hash = hash_to_field(external_nullifier); @@ -76,16 +73,33 @@ mod test { let proof = generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap(); - let success = verify_proof( - root, - nullifier_hash, - signal_hash, - external_nullifier_hash, - &proof, - ) - .unwrap(); + for _ in 0..5 { + let success = verify_proof( + root, + nullifier_hash, + signal_hash, + external_nullifier_hash, + &proof, + ) + .unwrap(); + assert!(success); + } + } + #[test] + fn test_single() { + // Note that rust will still run tests in parallel + test_end_to_end(b"hello", b"appId", b"xxx"); + } - assert!(success); + #[test] + fn test_parallel() { + // Note that this does not guarantee a concurrency issue will be detected. + // For that we need much more sophisticated static analysis tooling like + // loom. See + let a = spawn(|| test_end_to_end(b"hello", b"appId", b"xxx")); + let b = spawn(|| test_end_to_end(b"secret", b"test", b"signal")); + a.join().unwrap(); + b.join().unwrap(); } } diff --git a/src/protocol.rs b/src/protocol.rs index edc5402..a75eb67 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -116,8 +116,10 @@ pub fn generate_proof( let now = Instant::now(); - let full_assignment = WITNESS_CALCULATOR - .clone() + let full_assignment = + WITNESS_CALCULATOR + .lock() + .expect("witness_calculator mutex should not get poisoned") .calculate_witness_element::(inputs, false) .map_err(ProofError::WitnessError)?; @@ -178,8 +180,7 @@ mod test { use super::*; use crate::{hash_to_field, poseidon_tree::PoseidonTree}; - #[test] - fn test_proof_serialize() { + fn arb_proof() -> Proof { // generate identity let id = Identity::from_seed(b"secret"); @@ -194,9 +195,20 @@ mod test { let signal_hash = hash_to_field(b"xxx"); let external_nullifier_hash = hash_to_field(b"appId"); - let proof = - generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap(); + generate_proof(&id, &merkle_proof, external_nullifier_hash, signal_hash).unwrap() + } + + #[test] + fn test_proof_cast_roundtrip() { + let proof = arb_proof(); + let ark_proof: ArkProof> = proof.into(); + let result: Proof = ark_proof.into(); + assert_eq!(proof, result); + } + #[test] + fn test_proof_serialize() { + let proof = arb_proof(); let _json = serde_json::to_value(&proof).unwrap(); // TODO: Ideally we would check the output against an expected value, diff --git a/src/util.rs b/src/util.rs index e78dd81..dff0b28 100644 --- a/src/util.rs +++ b/src/util.rs @@ -16,6 +16,16 @@ pub(crate) fn keccak256(bytes: &[u8]) -> [u8; 32] { output } +pub(crate) fn bytes_to_hex(bytes: &[u8; N]) -> [u8; M] { + // TODO: Replace `M` with a const expression once it's stable. + debug_assert_eq!(M, 2 * N + 2); + let mut result = [0u8; M]; + result[0] = b'0'; + result[1] = b'x'; + hex::encode_to_slice(&bytes[..], &mut result[2..]).expect("the buffer is correctly sized"); + result +} + /// Helper to serialize byte arrays pub(crate) fn serialize_bytes( serializer: S, @@ -25,10 +35,7 @@ pub(crate) fn serialize_bytes( debug_assert_eq!(M, 2 * N + 2); if serializer.is_human_readable() { // Write as a 0x prefixed lower-case hex string - let mut buffer = [0u8; M]; - buffer[0] = b'0'; - buffer[1] = b'x'; - hex::encode_to_slice(&bytes[..], &mut buffer[2..]).expect("the buffer is correctly sized"); + let buffer = bytes_to_hex::(bytes); let string = str::from_utf8(&buffer).expect("the buffer is valid UTF-8"); serializer.serialize_str(string) } else {