diff --git a/.github/workflows/polygonid_flutter_sdk.yml b/.github/workflows/polygonid_flutter_sdk.yml index 27efe91ca..7af2c61ed 100644 --- a/.github/workflows/polygonid_flutter_sdk.yml +++ b/.github/workflows/polygonid_flutter_sdk.yml @@ -50,6 +50,9 @@ jobs: - name: Run tests run: flutter test --coverage + - name: Run Rust library test + run: cd rust && cargo test + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 with: diff --git a/rust/src/eddsa/mod.rs b/rust/src/eddsa/mod.rs index 8841498e9..67f0902e3 100644 --- a/rust/src/eddsa/mod.rs +++ b/rust/src/eddsa/mod.rs @@ -20,7 +20,6 @@ use rand6::Rng; //extern crate blake; // compatible version with Blake used at circomlib //#[macro_use] //use blake_hash::Digest; -use blake::Blake; use std::cmp::min; @@ -33,9 +32,9 @@ pub mod utils; lazy_static! { static ref D: Fr = Fr::from_str("168696").unwrap(); - static ref D_big: BigInt = BigInt::parse_bytes(b"168696", 10).unwrap(); + static ref D_BIG: BigInt = BigInt::parse_bytes(b"168696", 10).unwrap(); static ref A: Fr = Fr::from_str("168700").unwrap(); - static ref A_big: BigInt = BigInt::parse_bytes(b"168700", 10).unwrap(); + static ref A_BIG: BigInt = BigInt::parse_bytes(b"168700", 10).unwrap(); pub static ref Q: BigInt = BigInt::parse_bytes( b"21888242871839275222246405745257275088548364400416034343698204186575808495617",10 ) @@ -62,7 +61,7 @@ lazy_static! { ) .unwrap() >> 3; - static ref poseidon: poseidon_rs::Poseidon = Poseidon::new(); + static ref POSEIDON: poseidon_rs::Poseidon = Poseidon::new(); } #[derive(Clone, Debug)] @@ -213,7 +212,7 @@ pub fn decompress_point(bb: [u8; 32]) -> Result { // x^2 = (1 - y^2) / (a - d * y^2) (mod p) let den = utils::modinv( &utils::modulus( - &(&A_big.clone() - utils::modulus(&(&D_big.clone() * (&y * &y)), &Q)), + &(&A_BIG.clone() - utils::modulus(&(&D_BIG.clone() * (&y * &y)), &Q)), &Q, ), &Q, @@ -333,7 +332,7 @@ impl PrivateKey { let a = &self.public(); let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg_fr]; - let hm = poseidon.hash(hm_input)?; + let hm = POSEIDON.hash(hm_input)?; let mut s = &self.scalar_key() << 3; let hm_b = BigInt::parse_bytes(to_hex(&hm).as_bytes(), 16).unwrap(); @@ -372,7 +371,7 @@ pub fn schnorr_hash(pk: &Point, msg: BigInt, c: &Point) -> Result bool { pk.y.clone(), msg_fr, ]; - let hm = match poseidon.hash(hm_input) { + let hm = match POSEIDON.hash(hm_input) { Result::Err(_) => return false, Result::Ok(hm) => hm, }; diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 6868db534..c9d9aafa9 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -5,7 +5,6 @@ pub mod eddsa; use poseidon_rs::Poseidon; pub type Fr = poseidon_rs::Fr; -#[macro_use] extern crate ff; #[macro_use] @@ -22,18 +21,15 @@ extern crate blake; // compatible version with Blake used at circomlib extern crate lazy_static; use ff::*; -use std::str; - -use crate::eddsa::{Signature, decompress_point, Point, PrivateKey, verify, decompress_signature, /*compress_point,*/ PointProjective, Q, B8, new_key}; -use num_bigint::{Sign, BigInt, ToBigInt}; +use crate::eddsa::{Signature, decompress_point, Point, PrivateKey, verify, decompress_signature}; +use num_bigint::{Sign, BigInt}; +use std::convert::TryInto; use std::os::raw::{c_char}; use std::ffi::{CStr, CString}; use std::cmp::min; use std::str::FromStr; -use num_traits::{Num, ToPrimitive}; use rustc_hex::{FromHex, ToHex}; -use num::Zero; use std::panic::catch_unwind; /*lazy_static! { @@ -180,15 +176,6 @@ pub extern fn unpack_signature(compressed_signature: *const c_char) -> *mut c_ch } } - -fn vector_as_u8_64_array(vector: Vec) -> [u8; 64] { - let mut arr = [0u8;64]; - for (place, element) in arr.iter_mut().zip(vector.iter()) { - *place = *element; - } - arr -} - #[no_mangle] pub /*extern*/ fn pack_point_internal(point_x: *const c_char, point_y: *const c_char) -> *mut c_char { let point_x_cstr = unsafe { CStr::from_ptr(point_x) }; @@ -537,30 +524,48 @@ pub extern fn sign_poseidon(private_key: *const c_char, msg: *const c_char) -> * } } -#[no_mangle] -pub /*extern*/ fn verify_poseidon_internal(private_key: *const c_char, compressed_signature: *const c_char, message: *const c_char) -> *mut c_char { - let private_key_str = unsafe { CStr::from_ptr(private_key) }.to_str().unwrap(); - // let pk_bigint = BigInt::from_str(private_key_str).unwrap(); - let pk_bytes_raw = private_key_str.from_hex().unwrap(); - let mut pk_bytes: [u8; 32] = [0; 32]; - pk_bytes.copy_from_slice(&pk_bytes_raw); - let pk = PrivateKey { key: pk_bytes }; - let compressed_signature_str = unsafe { CStr::from_ptr(compressed_signature) }.to_str().unwrap(); - let signature_bytes_raw = compressed_signature_str.from_hex().unwrap(); - let mut signature_bytes: [u8; 64] = [0; 64]; - signature_bytes.copy_from_slice(&signature_bytes_raw); - let sig = decompress_signature(&signature_bytes).unwrap(); - let message_c_str = unsafe { CStr::from_ptr(message) }; - let message_str = match message_c_str.to_str() { - Err(_) => "there", - Ok(string) => string, +fn bytes_from_str(s: *const c_char) -> Result, String> { + if s.is_null() { + return Err("str pointer is null".to_owned()); }; - let message_bigint = match message_str.parse::() { - Ok(n) => BigInt::from(n), - Err(e) => BigInt::zero(), + let s = unsafe { CStr::from_ptr(s) }.to_str() + .map_err(|e| format!("utf8 string error: {}", e.to_string()))?; + s.from_hex().map_err(|e| format!("hex decode error: {}", e.to_string())) +} + +fn bigint_from_str(s: *const c_char) -> Result { + if s.is_null() { + return Err("str pointer is null".to_owned()); }; + let s = unsafe { CStr::from_ptr(s) }.to_str() + .map_err(|e| format!("utf8 string error: {}", e.to_string()))?; + BigInt::from_str(s).map_err(|e| format!("bigint parse error: {}", e.to_string())) +} + +fn priv_key(private_key: *const c_char) -> Result { + let pk_bytes = bytes_from_str(private_key).map_err(|e| format!("private key error: {}", e))?; + Ok(PrivateKey { key: pk_bytes.try_into() + .map_err(|_| "private key should be exactly 32 bytes long".to_owned())? }) +} - if verify(pk.public(), sig.clone(), message_bigint.clone()) { +fn unpack_sig(compressed_signature: *const c_char) -> Result { + let signature_bytes = bytes_from_str(compressed_signature) + .map_err(|e| format!("signature error: {}", e.to_string()))?; + let signature_bytes: [u8; 64] = signature_bytes.try_into() + .map_err(|_| "signature should be exactly 64 bytes long".to_owned())?; + decompress_signature(&signature_bytes) +} + +#[no_mangle] +pub fn verify_poseidon_internal(private_key: *const c_char, compressed_signature: *const c_char, message: *const c_char) -> *mut c_char { + let pk = priv_key(private_key) + .unwrap_or_else(|err_msg| panic!("{}", err_msg)); + let sig = unpack_sig(compressed_signature) + .unwrap_or_else(|err_msg| panic!("{}", err_msg)); + let message_bigint = bigint_from_str(message) + .unwrap_or_else(|err_msg| panic!("message parse error: {}", err_msg)); + + if verify(pk.public(), sig, message_bigint) { CString::new("1".to_owned()).unwrap().into_raw() } else { CString::new("0".to_owned()).unwrap().into_raw() @@ -584,6 +589,97 @@ pub extern fn verify_poseidon(private_key: *const c_char, compressed_signature: pub extern fn cstring_free(str: *mut c_char) { unsafe { if str.is_null() { return } - CString::from_raw(str) + drop(CString::from_raw(str)); }; } + +#[cfg(test)] +mod tests { + use std::ptr::null; + use super::*; + + #[test] + #[should_panic(expected = "private key error: str pointer is null")] + fn test_verify_poseidon_internal_with_null_private_key_should_panic() { + verify_poseidon_internal(null(), null(), null()); + } + + #[test] + fn test_verify_poseidon_with_null_private_key_should_panic() { + let x = verify_poseidon(null(), null(), null()); + assert!(x.is_null()); + } + + #[test] + #[should_panic(expected = "private key error: hex decode error: Invalid character 'p' at position 0")] + fn test_verify_poseidon_internal_with_incorrect_hex_private_key_should_panic() { + let pk = CString::new("pk").unwrap(); + verify_poseidon_internal(pk.into_raw(), null(), null()); + } + + #[test] + #[should_panic(expected = "signature error: str pointer is null")] + fn test_verify_poseidon_null_sig() { + let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap(); + let msg = CString::new("184467440737095516150").unwrap(); + verify_poseidon_internal(pk.into_raw(), null(), msg.into_raw()); + } + + #[test] + #[should_panic(expected = "message parse error: str pointer is null")] + fn test_verify_poseidon_null_msg() { + let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap(); + let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb305").unwrap(); + verify_poseidon_internal(pk.into_raw(), sig.into_raw(), null()); + } + + #[test] + fn test_verify_poseidon_ok() { + let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap(); + let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb305").unwrap(); + let msg = CString::new("184467440737095516150").unwrap(); + let r = verify_poseidon(pk.into_raw(), sig.into_raw(), msg.into_raw()); + + let r = unsafe { CStr::from_ptr(r) }.to_str().unwrap(); + assert_eq!(r, "1"); + } + + #[test] + fn test_verify_poseidon_invalid_sig() { + let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap(); + let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb307").unwrap(); + let msg = CString::new("184467440737095516150").unwrap(); + let r = verify_poseidon(pk.into_raw(), sig.into_raw(), msg.into_raw()); + + let r = unsafe { CStr::from_ptr(r) }.to_str().unwrap(); + assert_eq!(r, "0"); + } + + #[test] + #[should_panic(expected = "message parse error: bigint parse error: invalid digit found in string")] + fn test_verify_poseidon_internal_invalid_msg() { + let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap(); + let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb307").unwrap(); + let msg = CString::new("abc").unwrap(); + verify_poseidon_internal(pk.into_raw(), sig.into_raw(), msg.into_raw()); + } + + #[test] + #[should_panic(expected = "signature should be exactly 64 bytes long")] + fn test_verify_poseidon_internal_sig_len_error() { + let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap(); + let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb3").unwrap(); + let msg = CString::new("abc").unwrap(); + verify_poseidon_internal(pk.into_raw(), sig.into_raw(), msg.into_raw()); + } + + #[test] + // #[should_panic(expected = "signature should be exactly 64 bytes long")] + fn test_verify_poseidon_sig_len_error() { + let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap(); + let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb3").unwrap(); + let msg = CString::new("abc").unwrap(); + let r = verify_poseidon(pk.into_raw(), sig.into_raw(), msg.into_raw()); + assert_eq!(std::ptr::null_mut(), r); + } +} \ No newline at end of file