Skip to content

Commit

Permalink
hydra: small fixes/checks
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Jul 4, 2024
1 parent 3a2b7a5 commit b24e630
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 39 deletions.
2 changes: 1 addition & 1 deletion hydra/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def one(self) -> PyFelt:
return PyFelt(1, self.p)

def random(self) -> PyFelt:
return PyFelt(random.randint(0, self.p), self.p)
return PyFelt(random.randint(0, self.p - 1), self.p)


@dataclass(slots=True, frozen=True)
Expand Down
2 changes: 1 addition & 1 deletion hydra/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def gen_random_point(curve_id: CurveID) -> "G2Point":
from tools.gnark_cli import GnarkCLI

cli = GnarkCLI(curve_id)
ng1ng2 = cli.nG1nG2_operation(scalar, 1, raw=True)
ng1ng2 = cli.nG1nG2_operation(1, scalar, raw=True)
return G2Point((ng1ng2[2], ng1ng2[3]), (ng1ng2[4], ng1ng2[5]), curve_id)
else:
raise NotImplementedError(
Expand Down
1 change: 1 addition & 0 deletions hydra/hints/bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def get_root_and_scaling_factor_bls(f: E12) -> tuple[E12, E12]:
w_full = wp_shift * w27_shift
f_shifted = f * w_full
root = find_nth_root(f_shifted, lam)
assert f_shifted == root**lam
return root, w_full


Expand Down
1 change: 1 addition & 0 deletions hydra/modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def write_elements(
return vals

def write_cairo_native_felt(self, native_felt: PyFelt):
assert type(native_felt) == PyFelt, f"Expected PyFelt, got {type(native_felt)}"
assert 0 <= native_felt.value < STARK
res = self.write_element(elmt=native_felt, write_source=WriteOps.FELT)
return res
Expand Down
6 changes: 5 additions & 1 deletion hydra/precompiled_circuits/multi_miller_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ def __init__(
n_pairs: int,
hash_input: bool = True,
init_hash: int = None,
compilation_mode: int = 0,
):
super().__init__(
name=name,
curve_id=curve_id,
extension_degree=12,
hash_input=hash_input,
init_hash=init_hash,
compilation_mode=compilation_mode,
)
self.curve = CURVES[curve_id]
self.line_sparsity: list[int] = self.curve.line_function_sparsity
Expand All @@ -40,6 +42,8 @@ def __init__(
self.set_or_get_constant(self.field(-9))
self.P = []
self.Q = []
self.yInv = []
self.xNegOverY = []
self.loop_counter = CURVES[self.curve_id].loop_counter
self.ops_counter.update(
{
Expand Down Expand Up @@ -140,7 +144,7 @@ def build_sparse_line(
yInv: ModuloCircuitElement,
xNegOverY: ModuloCircuitElement,
) -> list[ModuloCircuitElement]:
ZERO, ONE = self.get_constant(0), self.get_constant(1)
ZERO, ONE = self.set_or_get_constant(0), self.set_or_get_constant(1)

if self.curve_id == BN254_ID:
return [
Expand Down
65 changes: 29 additions & 36 deletions hydra/precompiled_circuits/multi_pairing_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_sparsity,
CurveID,
generate_frobenius_maps,
CURVES,
)
from hydra.hints.multi_miller_witness import get_final_exp_witness
from hydra.modulo_circuit import WriteOps, ModuloCircuitElement, PyFelt
Expand Down Expand Up @@ -53,6 +54,8 @@ def get_root_and_scaling_factor(
)
c.write_p_and_q(c_input)
f = E12.from_direct(c.miller_loop(len(P)), curve_id)
h = (CURVES[curve_id].p ** 12 - 1) // CURVES[curve_id].n
assert f**h == E12.one(curve_id)
lambda_root_e12, scaling_factor_e12 = get_final_exp_witness(curve_id, f)

lambda_root: list[PyFelt]
Expand All @@ -63,13 +66,9 @@ def get_root_and_scaling_factor(
scaling_factor_e12.to_direct(),
)

print(f"here")
e6_subfield = E12([E6.random(curve_id), E6.zero(curve_id)], curve_id)
print(f"e6_subfield: {e6_subfield.value_coeffs}")
print(f"e6 direct: {e6_subfield.to_direct()}")
scaling_factor_sparsity = get_sparsity(e6_subfield.to_direct())

print(f"done: {scaling_factor_sparsity}")
# Assert sparsity is correct: for every index where the sparsity is 0, the coefficient must 0 in scaling factor
for i in range(len(scaling_factor_sparsity)):
if scaling_factor_sparsity[i] == 0:
Expand All @@ -87,6 +86,7 @@ def __init__(
n_pairs: int,
hash_input: bool = True,
init_hash: int = None,
compilation_mode: int = 0,
):
assert n_pairs >= 2, "n_pairs must be >= 2 for pairing checks"
super().__init__(
Expand All @@ -95,6 +95,7 @@ def __init__(
n_pairs=n_pairs,
hash_input=hash_input,
init_hash=init_hash,
compilation_mode=compilation_mode,
)
self.frobenius_maps = {}
for i in [1, 2, 3]:
Expand Down Expand Up @@ -300,36 +301,23 @@ def multi_pairing_check(
f = self.extf_mul([f, w], 12, Ps_sparsities=[None, scaling_factor_sparsity])
c_inv_frob_1 = self.frobenius(c_inv, 1)
f = self.extf_mul([f, c_inv_frob_1], 12)

# Conjugate f
f = [
f[0],
self.neg(f[1]),
f[2],
self.neg(f[3]),
f[4],
self.neg(f[5]),
f[6],
self.neg(f[7]),
f[8],
self.neg(f[9]),
f[10],
self.neg(f[11]),
]
else:
raise NotImplementedError(f"Curve {self.curve_id} not implemented")

assert [fi.value for fi in f] == [1] + [0] * 11, f"f: {f}"
return f


def get_pairing_check_input(curve_id: CurveID) -> list[PyFelt]:
def get_pairing_check_input(curve_id: CurveID, n_pairs: int) -> list[PyFelt]:
assert n_pairs >= 2, "n_pairs must be >= 2 for pairing checks"
field = get_base_field(curve_id.value)
p = G1Point.gen_random_point(curve_id)
q = G2Point.gen_random_point(curve_id)

P = [p, -p]
Q = [q, q]
P = [p] * n_pairs
Q = [q] * n_pairs

P[-1] = p.scalar_mul(-(n_pairs - 1))
c_input = []
for p, q in zip(P, Q):
c_input.append(field(p.x))
Expand All @@ -342,16 +330,21 @@ def get_pairing_check_input(curve_id: CurveID) -> list[PyFelt]:


if __name__ == "__main__":
c = MultiPairingCheckCircuit(name="mock", curve_id=CurveID.BN254.value, n_pairs=2)
c.write_p_and_q(get_pairing_check_input(CurveID.BN254))
c.multi_pairing_check(2)
c.finalize_circuit()
print(c.summarize())

c = MultiPairingCheckCircuit(
name="mock2", curve_id=CurveID.BLS12_381.value, n_pairs=2
)
c.write_p_and_q(get_pairing_check_input(CurveID.BLS12_381))
c.multi_pairing_check(2)
c.finalize_circuit()
print(c.summarize())

def test_mpcheck(curve_id: CurveID, n_pairs: int):
c = MultiPairingCheckCircuit(
name="mock", curve_id=curve_id.value, n_pairs=n_pairs
)
c.write_p_and_q(get_pairing_check_input(curve_id, n_pairs))
c.multi_pairing_check(n_pairs)
c.finalize_circuit()
print(c.summarize())
print(f"Test {curve_id.name} {n_pairs=} passed")

test_mpcheck(CurveID.BN254, 2)
test_mpcheck(CurveID.BN254, 3)
test_mpcheck(CurveID.BN254, 4)
test_mpcheck(CurveID.BN254, 5)
test_mpcheck(CurveID.BLS12_381, 2)
test_mpcheck(CurveID.BLS12_381, 3)
test_mpcheck(CurveID.BLS12_381, 4)

0 comments on commit b24e630

Please sign in to comment.