Skip to content

Commit

Permalink
Make garaga-rs an importable Rust lib and add get_groth16_calldata
Browse files Browse the repository at this point in the history
…rust function (#229)
  • Loading branch information
feltroidprime authored Oct 25, 2024
1 parent 7aaaec0 commit 0ea5289
Show file tree
Hide file tree
Showing 28 changed files with 498 additions and 59 deletions.
24 changes: 23 additions & 1 deletion hydra/garaga/starknet/groth16_contract_generator/calldata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from garaga import garaga_rs
from garaga.definitions import G1G2Pair, G1Point
from garaga.starknet.groth16_contract_generator.parsing_utils import (
Groth16Proof,
Expand All @@ -6,10 +7,15 @@
from garaga.starknet.tests_and_calldata_generators.mpcheck import MPCheckCalldataBuilder
from garaga.starknet.tests_and_calldata_generators.msm import MSMCalldataBuilder

garaga_rs.get_groth16_calldata


def groth16_calldata_from_vk_and_proof(
vk: Groth16VerifyingKey, proof: Groth16Proof
vk: Groth16VerifyingKey, proof: Groth16Proof, use_rust: bool = True
) -> list[int]:
if use_rust:
return _groth16_calldata_from_vk_and_proof_rust(vk, proof)

assert (
vk.curve_id == proof.curve_id
), f"Curve ID mismatch: {vk.curve_id} != {proof.curve_id}"
Expand Down Expand Up @@ -68,6 +74,22 @@ def groth16_calldata_from_vk_and_proof(
return [len(calldata)] + calldata


def _groth16_calldata_from_vk_and_proof_rust(
vk: Groth16VerifyingKey, proof: Groth16Proof
) -> list[int]:
assert (
vk.curve_id == proof.curve_id
), f"Curve ID mismatch: {vk.curve_id} != {proof.curve_id}"

return garaga_rs.get_groth16_calldata(
proof.flatten(),
vk.flatten(),
proof.curve_id.value,
proof.image_id,
proof.journal,
)


if __name__ == "__main__":
VK_PATH = "hydra/garaga/starknet/groth16_contract_generator/examples/snarkjs_vk_bn254.json"
PROOF_PATH = "hydra/garaga/starknet/groth16_contract_generator/examples/snarkjs_proof_bn254.json"
Expand Down
20 changes: 20 additions & 0 deletions hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,16 @@ def serialize_to_cairo(self) -> str:
"""
return code

def flatten(self) -> list[int]:
lst = []
lst.extend([self.alpha.x, self.alpha.y])
lst.extend([self.beta.x[0], self.beta.x[1], self.beta.y[0], self.beta.y[1]])
lst.extend([self.gamma.x[0], self.gamma.x[1], self.gamma.y[0], self.gamma.y[1]])
lst.extend([self.delta.x[0], self.delta.x[1], self.delta.y[0], self.delta.y[1]])
for point in self.ic:
lst.extend([point.x, point.y])
return lst


def reverse_byte_order_uint256(value: int | bytes) -> int:
if isinstance(value, int):
Expand Down Expand Up @@ -469,6 +479,16 @@ def serialize_to_calldata(self) -> list[int]:
cd.extend(io.bigint_split(pub, 2, 2**128))
return cd

def flatten(self, include_public_inputs: bool = True) -> list[int]:
lst = []
lst.extend([self.a.x, self.a.y])
lst.extend([self.b.x[0], self.b.x[1]])
lst.extend([self.b.y[0], self.b.y[1]])
lst.extend([self.c.x, self.c.y])
if include_public_inputs:
lst.extend(self.public_inputs)
return lst


class ExitCode:
def __init__(self, system, user):
Expand Down
41 changes: 41 additions & 0 deletions tests/hydra/starknet/test_groth16_vk_proof_parsing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

from garaga.starknet.groth16_contract_generator.calldata import (
groth16_calldata_from_vk_and_proof,
)
from garaga.starknet.groth16_contract_generator.parsing_utils import (
Groth16Proof,
Groth16VerifyingKey,
Expand Down Expand Up @@ -45,3 +48,41 @@ def test_proof_parsing_with_public_input(proof_path: str, pub_inputs_path: str):
proof = Groth16Proof.from_json(proof_path, pub_inputs_path)

print(proof)


@pytest.mark.parametrize(
"proof_path, vk_path, pub_inputs_path",
[
(f"{PATH}/proof_bn254.json", f"{PATH}/vk_bn254.json", None),
(f"{PATH}/proof_bls.json", f"{PATH}/vk_bls.json", None),
(
f"{PATH}/gnark_proof_bn254.json",
f"{PATH}/gnark_vk_bn254.json",
f"{PATH}/gnark_public_bn254.json",
),
(
f"{PATH}/snarkjs_proof_bn254.json",
f"{PATH}/snarkjs_vk_bn254.json",
f"{PATH}/snarkjs_public_bn254.json",
),
(f"{PATH}/proof_risc0.json", f"{PATH}/vk_risc0.json", None),
],
)
def test_calldata_generation(
proof_path: str, vk_path: str, pub_inputs_path: str | None
):
import time

vk = Groth16VerifyingKey.from_json(vk_path)
proof = Groth16Proof.from_json(proof_path, pub_inputs_path)

start = time.time()
calldata = groth16_calldata_from_vk_and_proof(vk, proof, use_rust=False)
end = time.time()
print(f"Python time: {end - start}")

start = time.time()
calldata_rust = groth16_calldata_from_vk_and_proof(vk, proof, use_rust=True)
end = time.time()
print(f"Rust time: {end - start}")
assert calldata == calldata_rust
8 changes: 4 additions & 4 deletions tools/garaga_rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tools/garaga_rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "garaga_rs"
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]

[profile.release]
lto = true
Expand Down
20 changes: 13 additions & 7 deletions tools/garaga_rs/src/algebra/g1point.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::definitions::{CurveParamsProvider, FieldElement};
use lambdaworks_math::field::traits::IsPrimeField;
use num_bigint::{BigInt, BigUint, Sign};

#[derive(Debug, Clone)]
pub struct G1Point<F: IsPrimeField> {
pub x: FieldElement<F>,
Expand Down Expand Up @@ -48,15 +47,16 @@ impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {

let lambda = if self.eq(other) {
let alpha = F::get_curve_params().a;

(FieldElement::<F>::from(3_u64) * self.x.square() + alpha)
/ (FieldElement::<F>::from(2_u64) * self.y.clone())
let numerator = FieldElement::<F>::from(3_u64) * &self.x.square() + alpha;
let denominator = FieldElement::<F>::from(2_u64) * &self.y;
numerator / denominator
} else {
(other.y.clone() - self.y.clone()) / (other.x.clone() - self.x.clone())
(&other.y - &self.y) / (&other.x - &self.x)
};

let x3 = lambda.square() - self.x.clone() - other.x.clone();
let y3 = lambda * (self.x.clone() - x3.clone()) - self.y.clone();
let x3 = &lambda.square() - &self.x - &other.x;

let y3 = &lambda * &(self.x.clone() - &x3) - &self.y;

G1Point::new_unchecked(x3, y3)
}
Expand Down Expand Up @@ -128,6 +128,12 @@ impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {
self.y.representative().to_string()
);
}
pub fn generator() -> Self {
let curve_params = F::get_curve_params();
let generator_x = curve_params.g_x;
let generator_y = curve_params.g_y;
G1Point::new(generator_x, generator_y).unwrap()
}
}

impl<F: IsPrimeField> PartialEq for G1Point<F> {
Expand Down
33 changes: 16 additions & 17 deletions tools/garaga_rs/src/algebra/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,32 +229,31 @@ impl<F: IsPrimeField> std::ops::Add<&Polynomial<F>> for &Polynomial<F> {

fn add(self, a_polynomial: &Polynomial<F>) -> Self::Output {
let (pa, pb) = pad_with_zero_coefficients(self, a_polynomial);
let iter_coeff_pa = pa.coefficients.iter();
let iter_coeff_pb = pb.coefficients.iter();
let new_coefficients = iter_coeff_pa.zip(iter_coeff_pb).map(|(x, y)| x + y);
let new_coefficients_vec = new_coefficients.collect::<Vec<FieldElement<F>>>();
Polynomial::new(new_coefficients_vec)
let new_coefficients = pa
.coefficients
.iter()
.zip(pb.coefficients.iter())
.map(|(x, y)| x + y)
.collect();
Polynomial::new(new_coefficients)
}
}

impl<F: IsPrimeField> std::ops::Add for Polynomial<F> {
type Output = Polynomial<F>;

fn add(self, other: Polynomial<F>) -> Polynomial<F> {
let (ns, no) = (self.coefficients.len(), other.coefficients.len());
if ns >= no {
let mut coeffs = self.coefficients.clone();
for (i, coeff) in other.coefficients.iter().enumerate() {
coeffs[i] += coeff.clone();
}
Polynomial::new(coeffs)
let (mut longer, shorter) = if self.coefficients.len() >= other.coefficients.len() {
(self.coefficients, other.coefficients)
} else {
let mut coeffs = other.coefficients.clone();
for (i, coeff) in self.coefficients.iter().enumerate() {
coeffs[i] += coeff.clone();
}
Polynomial::new(coeffs)
(other.coefficients, self.coefficients)
};

for (i, coeff) in shorter.into_iter().enumerate() {
longer[i] += coeff;
}

Polynomial::new(longer)
}
}

Expand Down
1 change: 1 addition & 0 deletions tools/garaga_rs/src/calldata/full_proof_with_hints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod groth16;
Loading

0 comments on commit 0ea5289

Please sign in to comment.