Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust implementation of extension field elements multiplication #167

Merged
merged 19 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions hydra/garaga/hints/ecip.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,8 @@ def zk_ecip_hint(
if ec_group_class == G1Point and use_rust:
pts = []
c_id = Bs[0].curve_id
if c_id == CurveID.BLS12_381:
nb = 48
else:
nb = 32
for pt in Bs:
pts.extend([pt.x.to_bytes(nb, "big"), pt.y.to_bytes(nb, "big")])
pts.extend([pt.x, pt.y])
field_type = get_field_type_from_ec_point(Bs[0])
field = get_base_field(c_id.value, field_type)

Expand Down
29 changes: 7 additions & 22 deletions hydra/garaga/hints/extf_mul.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import operator
from functools import reduce

from garaga import garaga_rs
from garaga.algebra import ModuloCircuitElement, Polynomial, PyFelt
from garaga.definitions import (
direct_to_tower,
Expand All @@ -19,29 +20,13 @@ def nondeterministic_extension_field_mul_divmod(
curve_id: int,
extension_degree: int,
) -> tuple[list[PyFelt], list[PyFelt]]:

Ps = [Polynomial(P) for P in Ps]
field = get_base_field(curve_id)

P_irr = get_irreducible_poly(curve_id, extension_degree)

z_poly = reduce(operator.mul, Ps) # Π(Pi)
z_polyq, z_polyr = divmod(z_poly, P_irr)

z_polyr_coeffs = z_polyr.get_coeffs()
z_polyq_coeffs = z_polyq.get_coeffs()
# assert len(z_polyq_coeffs) <= (
# extension_degree - 1
# ), f"len z_polyq_coeffs={len(z_polyq_coeffs)}, degree: {z_polyq.degree()}"
assert (
len(z_polyr_coeffs) <= extension_degree
), f"len z_polyr_coeffs={len(z_polyr_coeffs)}, degree: {z_polyr.degree()}"

# Extend polynomials with 0 coefficients to match the expected lengths.
# TODO : pass exact expected max degree when len(Ps)>2.
z_polyq_coeffs += [field(0)] * (extension_degree - 1 - len(z_polyq_coeffs))
z_polyr_coeffs += [field(0)] * (extension_degree - len(z_polyr_coeffs))

ps = [[c.value for c in P] for P in Ps]
q, r = garaga_rs.nondeterministic_extension_field_mul_divmod(
curve_id, extension_degree, ps
)
z_polyq_coeffs = [field(c) for c in q] if len(q) > 0 else [field.zero()]
z_polyr_coeffs = [field(c) for c in r] if len(r) > 0 else [field.zero()]
return (z_polyq_coeffs, z_polyr_coeffs)


Expand Down
71 changes: 21 additions & 50 deletions tools/garaga_rs/src/ecip/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,118 +3,89 @@ use lambdaworks_math::elliptic_curve::short_weierstrass::curves::bls12_381::fiel
use lambdaworks_math::elliptic_curve::short_weierstrass::curves::bn_254::field_extension::BN254PrimeField;
use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::IsPrimeField;
use lambdaworks_math::traits::ByteConversion;

use crate::ecip::curve::{SECP256K1PrimeField, SECP256R1PrimeField, X25519PrimeField};
use crate::ecip::ff::FF;
use crate::ecip::g1point::G1Point;
use crate::ecip::rational_function::FunctionFelt;
use crate::ecip::rational_function::RationalFunction;
use crate::io::parse_field_elements_from_list;

use num_bigint::{BigInt, BigUint, ToBigInt};

use super::curve::CurveParamsProvider;

pub fn zk_ecip_hint(
list_bytes: Vec<Vec<u8>>,
list_values: Vec<BigUint>,
list_scalars: Vec<BigUint>,
curve_id: usize,
) -> Result<[Vec<String>; 5], String> {
match curve_id {
0 => {
let list_felts: Vec<FieldElement<BN254PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<BN254PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<BN254PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<BN254PrimeField>(&list_values)?;

let points: Vec<G1Point<BN254PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<BN254PrimeField>(list_scalars);
Ok(run_ecip::<BN254PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<BN254PrimeField>(list_scalars);
Ok(run_ecip::<BN254PrimeField>(points, dss))
}
1 => {
let list_felts: Vec<FieldElement<BLS12381PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<BLS12381PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<BLS12381PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<BLS12381PrimeField>(&list_values)?;

let points: Vec<G1Point<BLS12381PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<BLS12381PrimeField>(list_scalars);
Ok(run_ecip::<BLS12381PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<BLS12381PrimeField>(list_scalars);
Ok(run_ecip::<BLS12381PrimeField>(points, dss))
}
2 => {
let list_felts: Vec<FieldElement<SECP256K1PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<SECP256K1PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<SECP256K1PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<SECP256K1PrimeField>(&list_values)?;

let points: Vec<G1Point<SECP256K1PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256K1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256K1PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<SECP256K1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256K1PrimeField>(points, dss))
}
3 => {
let list_felts: Vec<FieldElement<SECP256R1PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<SECP256R1PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<SECP256R1PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<SECP256R1PrimeField>(&list_values)?;

let points: Vec<G1Point<SECP256R1PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256R1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256R1PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<SECP256R1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256R1PrimeField>(points, dss))
}
4 => {
let list_felts: Vec<FieldElement<X25519PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<X25519PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<X25519PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<X25519PrimeField>(&list_values)?;

let points: Vec<G1Point<X25519PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<X25519PrimeField>(list_scalars);
Ok(run_ecip::<X25519PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<X25519PrimeField>(list_scalars);
Ok(run_ecip::<X25519PrimeField>(points, dss))
}
_ => Err(String::from("Invalid curve ID")),
}
}

fn extract_scalars<F: IsPrimeField + CurveParamsProvider<F>>(list: Vec<BigUint>) -> Vec<Vec<i8>> {
fn construct_digits_vectors<F: IsPrimeField + CurveParamsProvider<F>>(
list: Vec<BigUint>,
) -> Vec<Vec<i8>> {
let mut dss_ = Vec::new();

for i in 0..list.len() {
let scalar_biguint = list[i].clone();
for scalar_biguint in list {
let neg_3_digits = neg_3_base_le(scalar_biguint);
dss_.push(neg_3_digits);
}
Expand Down
28 changes: 28 additions & 0 deletions tools/garaga_rs/src/ecip/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ use lambdaworks_math::field::fields::montgomery_backed_prime_fields::{
IsModulus, MontgomeryBackendPrimeField,
};

use crate::ecip::polynomial::Polynomial;
use lambdaworks_math::field::traits::IsPrimeField;
use lambdaworks_math::unsigned_integer::element::U256;
use num_bigint::BigUint;
use std::cmp::PartialEq;
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CurveID {
Expand Down Expand Up @@ -72,6 +74,21 @@ pub struct CurveParams<F: IsPrimeField> {
pub g_y: FieldElement<F>,
pub n: FieldElement<F>, // Order of the curve
pub h: u32, // Cofactor
pub irreducible_polys: HashMap<usize, &'static [i8]>,
}

pub fn get_irreducible_poly<F: IsPrimeField + CurveParamsProvider<F>>(
ext_degree: usize,
) -> Polynomial<F> {
let coeffs = (F::get_curve_params().irreducible_polys)[&ext_degree];
fn lift<F: IsPrimeField>(c: i8) -> FieldElement<F> {
if c >= 0 {
FieldElement::from(c as u64)
} else {
-FieldElement::from(-c as u64)
}
}
return Polynomial::new(coeffs.iter().map(|x| lift::<F>(*x)).collect());
}

/// A trait that provides curve parameters for a specific field type.
Expand Down Expand Up @@ -99,6 +116,7 @@ impl CurveParamsProvider<SECP256K1PrimeField> for SECP256K1PrimeField {
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
),
h: 1,
irreducible_polys: HashMap::from([]), // Provide appropriate values here
}
}
}
Expand All @@ -122,6 +140,7 @@ impl CurveParamsProvider<SECP256R1PrimeField> for SECP256R1PrimeField {
"FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
),
h: 1,
irreducible_polys: HashMap::from([]), // Provide appropriate values here
}
}
}
Expand All @@ -143,6 +162,7 @@ impl CurveParamsProvider<X25519PrimeField> for X25519PrimeField {
"1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED",
),
h: 8,
irreducible_polys: HashMap::from([]), // Provide appropriate values here
}
}
}
Expand All @@ -158,6 +178,10 @@ impl CurveParamsProvider<BN254PrimeField> for BN254PrimeField {
g_y: FieldElement::from_hex_unchecked("2"), // Replace with actual 'g_y'
n: FieldElement::from_hex_unchecked("1"), // Replace with actual 'n'
h: 1, // Replace with actual 'h'
irreducible_polys: HashMap::from([
(6, [82, 0, 0, -18, 0, 0, 1].as_slice()),
(12, [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0, 1].as_slice()),
]),
}
}
}
Expand All @@ -173,6 +197,10 @@ impl CurveParamsProvider<BLS12381PrimeField> for BLS12381PrimeField {
g_y: FieldElement::from_hex_unchecked("2"), // Replace with actual 'g_y'
n: FieldElement::from_hex_unchecked("1"), // Replace with actual 'n'
h: 1, // Replace with actual 'h'
irreducible_polys: HashMap::from([
(6, [2, 0, 0, -2, 0, 0, 1].as_slice()),
(12, [2, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 1].as_slice()),
]),
}
}
}
6 changes: 5 additions & 1 deletion tools/garaga_rs/src/ecip/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl<F: IsPrimeField> Polynomial<F> {
Polynomial::new(vec![FieldElement::<F>::zero()])
}

pub fn one() -> Self {
Polynomial::new(vec![FieldElement::<F>::one()])
}

pub fn mul_with_ref(&self, other: &Polynomial<F>) -> Polynomial<F> {
if self.degree() == -1 || other.degree() == -1 {
return Polynomial::zero();
Expand Down Expand Up @@ -142,7 +146,7 @@ impl<F: IsPrimeField> Polynomial<F> {
for (i, coeff) in self.coefficients.iter().enumerate().skip(1) {
let u_64 = i as u64;
let degree = &FieldElement::<F>::from(u_64);
new_coeffs[i - 1] = *(&coeff) * degree;
new_coeffs[i - 1] = coeff * degree;
}
Polynomial::new(new_coeffs)
}
Expand Down
30 changes: 30 additions & 0 deletions tools/garaga_rs/src/extf_mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::ecip::{
curve::{get_irreducible_poly, CurveParamsProvider},
polynomial::{pad_with_zero_coefficients_to_length, Polynomial},
};
use lambdaworks_math::field::traits::IsPrimeField;

// Returns (Q(X), R(X)) such that Π(Pi)(X) = Q(X) * P_irr(X) + R(X), for a given curve and extension degree.
// R(X) is the result of the multiplication in the extension field.
// Q(X) is used for verification.
pub fn nondeterministic_extension_field_mul_divmod<F: IsPrimeField + CurveParamsProvider<F>>(
ext_degree: usize,
ps: Vec<Polynomial<F>>,
) -> (Polynomial<F>, Polynomial<F>) {
let mut z_poly = Polynomial::one();
for poly in ps {
z_poly = z_poly.mul_with_ref(&poly);
}

let p_irr = get_irreducible_poly(ext_degree);

let (z_polyq, mut z_polyr) = z_poly.divmod(&p_irr);
assert!(z_polyr.coefficients.len() <= ext_degree);

// Extend polynomial with 0 coefficients to match the expected length.
if z_polyr.coefficients.len() < ext_degree {
pad_with_zero_coefficients_to_length(&mut z_polyr, ext_degree);
}

(z_polyq, z_polyr)
}
24 changes: 24 additions & 0 deletions tools/garaga_rs/src/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::IsPrimeField;
use lambdaworks_math::traits::ByteConversion;
use num_bigint::BigUint;

pub fn parse_field_elements_from_list<F: IsPrimeField>(
coeffs: &[BigUint],
) -> Result<Vec<FieldElement<F>>, String>
where
FieldElement<F>: ByteConversion,
{
let length = (F::field_bit_size() + 7) / 8;
coeffs
.iter()
.map(|x| {
let bytes = x.to_bytes_be();
let pad_length = length.saturating_sub(bytes.len());
let mut padded_bytes = vec![0u8; pad_length];
padded_bytes.extend(bytes);
FieldElement::from_bytes_be(&padded_bytes)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect()
}
Loading
Loading