Skip to content

Commit

Permalink
Handle big numbers in message signature validation (#420)
Browse files Browse the repository at this point in the history
* Fix the issue where parsing a BigInt larger than an i32 results in failure.

* Handle errors more consistently in  func. Add tests for rust components.

---------

Co-authored-by: yushihang <[email protected]>
  • Loading branch information
olomix and yushihang authored Jul 10, 2024
1 parent ee9d66d commit 3943670
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 46 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/polygonid_flutter_sdk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ jobs:
- name: Run tests
run: flutter test --coverage

- name: Run Rust library test
run: cd rust && cargo test

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
with:
Expand Down
15 changes: 7 additions & 8 deletions rust/src/eddsa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use rand6::Rng;
//extern crate blake; // compatible version with Blake used at circomlib
//#[macro_use]
//use blake_hash::Digest;
use blake::Blake;

use std::cmp::min;

Expand All @@ -33,9 +32,9 @@ pub mod utils;

lazy_static! {
static ref D: Fr = Fr::from_str("168696").unwrap();
static ref D_big: BigInt = BigInt::parse_bytes(b"168696", 10).unwrap();
static ref D_BIG: BigInt = BigInt::parse_bytes(b"168696", 10).unwrap();
static ref A: Fr = Fr::from_str("168700").unwrap();
static ref A_big: BigInt = BigInt::parse_bytes(b"168700", 10).unwrap();
static ref A_BIG: BigInt = BigInt::parse_bytes(b"168700", 10).unwrap();
pub static ref Q: BigInt = BigInt::parse_bytes(
b"21888242871839275222246405745257275088548364400416034343698204186575808495617",10
)
Expand All @@ -62,7 +61,7 @@ lazy_static! {
)
.unwrap()
>> 3;
static ref poseidon: poseidon_rs::Poseidon = Poseidon::new();
static ref POSEIDON: poseidon_rs::Poseidon = Poseidon::new();
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -213,7 +212,7 @@ pub fn decompress_point(bb: [u8; 32]) -> Result<Point, String> {
// x^2 = (1 - y^2) / (a - d * y^2) (mod p)
let den = utils::modinv(
&utils::modulus(
&(&A_big.clone() - utils::modulus(&(&D_big.clone() * (&y * &y)), &Q)),
&(&A_BIG.clone() - utils::modulus(&(&D_BIG.clone() * (&y * &y)), &Q)),
&Q,
),
&Q,
Expand Down Expand Up @@ -333,7 +332,7 @@ impl PrivateKey {
let a = &self.public();

let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg_fr];
let hm = poseidon.hash(hm_input)?;
let hm = POSEIDON.hash(hm_input)?;

let mut s = &self.scalar_key() << 3;
let hm_b = BigInt::parse_bytes(to_hex(&hm).as_bytes(), 16).unwrap();
Expand Down Expand Up @@ -372,7 +371,7 @@ pub fn schnorr_hash(pk: &Point, msg: BigInt, c: &Point) -> Result<BigInt, String
}
let msg_fr: Fr = Fr::from_str(&msg.to_string()).unwrap();
let hm_input = vec![pk.x.clone(), pk.y.clone(), c.x.clone(), c.y.clone(), msg_fr];
let h = poseidon.hash(hm_input)?;
let h = POSEIDON.hash(hm_input)?;
let h_b = BigInt::parse_bytes(to_hex(&h).as_bytes(), 16).unwrap();
Ok(h_b)
}
Expand Down Expand Up @@ -409,7 +408,7 @@ pub fn verify(pk: Point, sig: Signature, msg: BigInt) -> bool {
pk.y.clone(),
msg_fr,
];
let hm = match poseidon.hash(hm_input) {
let hm = match POSEIDON.hash(hm_input) {
Result::Err(_) => return false,
Result::Ok(hm) => hm,
};
Expand Down
172 changes: 134 additions & 38 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ pub mod eddsa;
use poseidon_rs::Poseidon;
pub type Fr = poseidon_rs::Fr;

#[macro_use]
extern crate ff;

#[macro_use]
Expand All @@ -22,18 +21,15 @@ extern crate blake; // compatible version with Blake used at circomlib
extern crate lazy_static;

use ff::*;
use std::str;


use crate::eddsa::{Signature, decompress_point, Point, PrivateKey, verify, decompress_signature, /*compress_point,*/ PointProjective, Q, B8, new_key};
use num_bigint::{Sign, BigInt, ToBigInt};
use crate::eddsa::{Signature, decompress_point, Point, PrivateKey, verify, decompress_signature};
use num_bigint::{Sign, BigInt};
use std::convert::TryInto;
use std::os::raw::{c_char};
use std::ffi::{CStr, CString};
use std::cmp::min;
use std::str::FromStr;
use num_traits::{Num, ToPrimitive};
use rustc_hex::{FromHex, ToHex};
use num::Zero;
use std::panic::catch_unwind;

/*lazy_static! {
Expand Down Expand Up @@ -180,15 +176,6 @@ pub extern fn unpack_signature(compressed_signature: *const c_char) -> *mut c_ch
}
}


fn vector_as_u8_64_array(vector: Vec<u8>) -> [u8; 64] {
let mut arr = [0u8;64];
for (place, element) in arr.iter_mut().zip(vector.iter()) {
*place = *element;
}
arr
}

#[no_mangle]
pub /*extern*/ fn pack_point_internal(point_x: *const c_char, point_y: *const c_char) -> *mut c_char {
let point_x_cstr = unsafe { CStr::from_ptr(point_x) };
Expand Down Expand Up @@ -537,30 +524,48 @@ pub extern fn sign_poseidon(private_key: *const c_char, msg: *const c_char) -> *
}
}

#[no_mangle]
pub /*extern*/ fn verify_poseidon_internal(private_key: *const c_char, compressed_signature: *const c_char, message: *const c_char) -> *mut c_char {
let private_key_str = unsafe { CStr::from_ptr(private_key) }.to_str().unwrap();
// let pk_bigint = BigInt::from_str(private_key_str).unwrap();
let pk_bytes_raw = private_key_str.from_hex().unwrap();
let mut pk_bytes: [u8; 32] = [0; 32];
pk_bytes.copy_from_slice(&pk_bytes_raw);
let pk = PrivateKey { key: pk_bytes };
let compressed_signature_str = unsafe { CStr::from_ptr(compressed_signature) }.to_str().unwrap();
let signature_bytes_raw = compressed_signature_str.from_hex().unwrap();
let mut signature_bytes: [u8; 64] = [0; 64];
signature_bytes.copy_from_slice(&signature_bytes_raw);
let sig = decompress_signature(&signature_bytes).unwrap();
let message_c_str = unsafe { CStr::from_ptr(message) };
let message_str = match message_c_str.to_str() {
Err(_) => "there",
Ok(string) => string,
fn bytes_from_str(s: *const c_char) -> Result<Vec<u8>, String> {
if s.is_null() {
return Err("str pointer is null".to_owned());
};
let message_bigint = match message_str.parse::<i32>() {
Ok(n) => BigInt::from(n),
Err(e) => BigInt::zero(),
let s = unsafe { CStr::from_ptr(s) }.to_str()
.map_err(|e| format!("utf8 string error: {}", e.to_string()))?;
s.from_hex().map_err(|e| format!("hex decode error: {}", e.to_string()))
}

fn bigint_from_str(s: *const c_char) -> Result<BigInt, String> {
if s.is_null() {
return Err("str pointer is null".to_owned());
};
let s = unsafe { CStr::from_ptr(s) }.to_str()
.map_err(|e| format!("utf8 string error: {}", e.to_string()))?;
BigInt::from_str(s).map_err(|e| format!("bigint parse error: {}", e.to_string()))
}

fn priv_key(private_key: *const c_char) -> Result<PrivateKey, String> {
let pk_bytes = bytes_from_str(private_key).map_err(|e| format!("private key error: {}", e))?;
Ok(PrivateKey { key: pk_bytes.try_into()
.map_err(|_| "private key should be exactly 32 bytes long".to_owned())? })
}

if verify(pk.public(), sig.clone(), message_bigint.clone()) {
fn unpack_sig(compressed_signature: *const c_char) -> Result<Signature, String> {
let signature_bytes = bytes_from_str(compressed_signature)
.map_err(|e| format!("signature error: {}", e.to_string()))?;
let signature_bytes: [u8; 64] = signature_bytes.try_into()
.map_err(|_| "signature should be exactly 64 bytes long".to_owned())?;
decompress_signature(&signature_bytes)
}

#[no_mangle]
pub fn verify_poseidon_internal(private_key: *const c_char, compressed_signature: *const c_char, message: *const c_char) -> *mut c_char {
let pk = priv_key(private_key)
.unwrap_or_else(|err_msg| panic!("{}", err_msg));
let sig = unpack_sig(compressed_signature)
.unwrap_or_else(|err_msg| panic!("{}", err_msg));
let message_bigint = bigint_from_str(message)
.unwrap_or_else(|err_msg| panic!("message parse error: {}", err_msg));

if verify(pk.public(), sig, message_bigint) {
CString::new("1".to_owned()).unwrap().into_raw()
} else {
CString::new("0".to_owned()).unwrap().into_raw()
Expand All @@ -584,6 +589,97 @@ pub extern fn verify_poseidon(private_key: *const c_char, compressed_signature:
pub extern fn cstring_free(str: *mut c_char) {
unsafe {
if str.is_null() { return }
CString::from_raw(str)
drop(CString::from_raw(str));
};
}

#[cfg(test)]
mod tests {
use std::ptr::null;
use super::*;

#[test]
#[should_panic(expected = "private key error: str pointer is null")]
fn test_verify_poseidon_internal_with_null_private_key_should_panic() {
verify_poseidon_internal(null(), null(), null());
}

#[test]
fn test_verify_poseidon_with_null_private_key_should_panic() {
let x = verify_poseidon(null(), null(), null());
assert!(x.is_null());
}

#[test]
#[should_panic(expected = "private key error: hex decode error: Invalid character 'p' at position 0")]
fn test_verify_poseidon_internal_with_incorrect_hex_private_key_should_panic() {
let pk = CString::new("pk").unwrap();
verify_poseidon_internal(pk.into_raw(), null(), null());
}

#[test]
#[should_panic(expected = "signature error: str pointer is null")]
fn test_verify_poseidon_null_sig() {
let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap();
let msg = CString::new("184467440737095516150").unwrap();
verify_poseidon_internal(pk.into_raw(), null(), msg.into_raw());
}

#[test]
#[should_panic(expected = "message parse error: str pointer is null")]
fn test_verify_poseidon_null_msg() {
let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap();
let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb305").unwrap();
verify_poseidon_internal(pk.into_raw(), sig.into_raw(), null());
}

#[test]
fn test_verify_poseidon_ok() {
let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap();
let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb305").unwrap();
let msg = CString::new("184467440737095516150").unwrap();
let r = verify_poseidon(pk.into_raw(), sig.into_raw(), msg.into_raw());

let r = unsafe { CStr::from_ptr(r) }.to_str().unwrap();
assert_eq!(r, "1");
}

#[test]
fn test_verify_poseidon_invalid_sig() {
let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap();
let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb307").unwrap();
let msg = CString::new("184467440737095516150").unwrap();
let r = verify_poseidon(pk.into_raw(), sig.into_raw(), msg.into_raw());

let r = unsafe { CStr::from_ptr(r) }.to_str().unwrap();
assert_eq!(r, "0");
}

#[test]
#[should_panic(expected = "message parse error: bigint parse error: invalid digit found in string")]
fn test_verify_poseidon_internal_invalid_msg() {
let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap();
let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb307").unwrap();
let msg = CString::new("abc").unwrap();
verify_poseidon_internal(pk.into_raw(), sig.into_raw(), msg.into_raw());
}

#[test]
#[should_panic(expected = "signature should be exactly 64 bytes long")]
fn test_verify_poseidon_internal_sig_len_error() {
let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap();
let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb3").unwrap();
let msg = CString::new("abc").unwrap();
verify_poseidon_internal(pk.into_raw(), sig.into_raw(), msg.into_raw());
}

#[test]
// #[should_panic(expected = "signature should be exactly 64 bytes long")]
fn test_verify_poseidon_sig_len_error() {
let pk = CString::new("459a964f864b613e0fae29bd5395cb7e5cb16d9501d898a5630d25dc56ab87aa").unwrap();
let sig = CString::new("aac24e561679c387a075ea22a153d8d060ee751555da44484f96ef3721537c9cf436f9668439cc183382a0ec1445ca594c8b626041bba1c28870c318e41cb3").unwrap();
let msg = CString::new("abc").unwrap();
let r = verify_poseidon(pk.into_raw(), sig.into_raw(), msg.into_raw());
assert_eq!(std::ptr::null_mut(), r);
}
}

0 comments on commit 3943670

Please sign in to comment.