Skip to content

Commit

Permalink
Wip : Arbitrary Msm
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Jul 23, 2024
1 parent 6504305 commit 4e6be8f
Show file tree
Hide file tree
Showing 8 changed files with 5,022 additions and 1,510 deletions.
10 changes: 4 additions & 6 deletions hydra/modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,7 @@ def __init__(
self.generic_circuit = generic_circuit
self.compilation_mode = compilation_mode
self.exact_output_refs_needed = None
self.input_structs: list[Cairo1SerializableStruct] | None = (
[] if compilation_mode == 1 else None
)
self.input_structs: list[Cairo1SerializableStruct] = []

@property
def values_offset(self) -> int:
Expand Down Expand Up @@ -686,7 +684,7 @@ def eval_poly(
return acc

def extend_output(self, elmts: list[ModuloCircuitElement]):
assert isinstance(elmts, list)
assert isinstance(elmts, (list, tuple))
assert all(isinstance(x, ModuloCircuitElement) for x in elmts)
self.output.extend(elmts)
return
Expand Down Expand Up @@ -980,8 +978,8 @@ def compile_circuit_cairo_1(
"""
else:
code += f"""
let p = get_p(curve_index);
let modulus = TryInto::<_, CircuitModulus>::try_into([p.limb0, p.limb1, p.limb2, p.limb3])
let modulus = get_p(curve_index);
let modulus = TryInto::<_, CircuitModulus>::try_into([modulus.limb0, modulus.limb1, modulus.limb2, modulus.limb3])
.unwrap();
"""

Expand Down
119 changes: 103 additions & 16 deletions hydra/modulo_circuit_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Cairo1SerializableStruct(ABC):

def __post_init__(self):
assert type(self.name) == str
if isinstance(self.elmts, list):
if isinstance(self.elmts, (list, tuple)):
if isinstance(self.elmts[0], Cairo1SerializableStruct):
assert all(
isinstance(elmt, self.elmts[0].__class__) for elmt in self.elmts
Expand All @@ -25,7 +25,7 @@ def __post_init__(self):
for elmt in self.elmts
), f"All elements of {self.name} must be of type ModuloCircuitElement or PyFelt"
else:
assert self.elmts == None
assert self.elmts == None, f"Elmts must be a list or None, got {self.elmts}"

@property
def struct_name(self) -> str:
Expand Down Expand Up @@ -143,6 +143,40 @@ def __len__(self) -> int:
return None


class u384Span(Cairo1SerializableStruct):
def serialize(self, raw: bool = False) -> str:
raw_struct = f"{int_array_to_u384_array(self.elmts)}.span()"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.struct_name} = {raw_struct};\n"

@property
def struct_name(self) -> str:
return "Span<u384>"

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 1
return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}].span();"

def dump_to_circuit_input(self) -> str:
code = f"""
let mut {self.name} = {self.name};
while let Option::Some(val) = {self.name}.pop_front() {{
circuit_inputs = circuit_inputs.next(*val);
}};
"""
return code

def __len__(self) -> int:
if self.elmts is not None:
return len(self.elmts)
else:
return None


class BLSProcessedPair(Cairo1SerializableStruct):
def __init__(self, name: str, elmts: list[ModuloCircuitElement]):
super().__init__(name, elmts)
Expand All @@ -156,9 +190,6 @@ def serialize(self) -> str:
assert len(self.elmts) == 2
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{yInv: {int_to_u384(self.elmts[0].value)}, xNegOverY: {int_to_u384(self.elmts[1].value)}}};"

def serialize_input_signature(self):
return f"{self.name}:{self.struct_name}"

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
Expand Down Expand Up @@ -200,9 +231,6 @@ def serialize(self) -> str:
f"let {self.name}:{self.struct_name} = {self.struct_name} {{{members}}};\n"
)

def serialize_input_signature(self):
return f"{self.name}:{self.struct_name}"

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
Expand Down Expand Up @@ -236,9 +264,6 @@ def serialize(self) -> str:
assert len(self.elmts) == 2
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{x: {int_to_u384(self.elmts[0].value)}, y: {int_to_u384(self.elmts[1].value)}}};\n"

def serialize_input_signature(self):
return f"{self.name}:{self.struct_name}"

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
Expand Down Expand Up @@ -272,8 +297,34 @@ def serialize(self) -> str:
assert len(self.elmts) == 4
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{x0: {int_to_u384(self.elmts[0].value)}, x1: {int_to_u384(self.elmts[1].value)}, y0: {int_to_u384(self.elmts[2].value)}, y1: {int_to_u384(self.elmts[3].value)}}};\n"

def serialize_input_signature(self):
return f"{self.name}:{self.struct_name}"
def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 4
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(4)])} }};"

def dump_to_circuit_input(self) -> str:
code = ""
for mem_name in self.members_names:
code += f"circuit_inputs = circuit_inputs.next({self.name}.{mem_name});\n"
return code

def __len__(self) -> int:
if self.elmts is not None:
assert len(self.elmts) == 4
return 4
else:
return 4


class FunctionFeltEvaluations(Cairo1SerializableStruct):
def __init__(self, name: str, elmts: list[ModuloCircuitElement]):
super().__init__(name, elmts)
self.members_names = ("a_num", "a_den", "b_num", "b_den")

def serialize(self) -> str:
assert len(self.elmts) == 4
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{a_num: {int_to_u384(self.elmts[0].value)}, a_den: {int_to_u384(self.elmts[1].value)}, b_num: {int_to_u384(self.elmts[2].value)}, b_den: {int_to_u384(self.elmts[3].value)}}};\n"

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
Expand Down Expand Up @@ -311,9 +362,6 @@ def serialize(self) -> str:
f"q: G2Point{{ x0:{int_to_u384(self.elmts[2].value)}, x1: {int_to_u384(self.elmts[3].value)}, y0: {int_to_u384(self.elmts[4].value)}, y1: {int_to_u384(self.elmts[5].value)}}}}};\n"
)

def serialize_input_signature(self):
return f"{self.name}:{self.struct_name}"

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
Expand Down Expand Up @@ -456,3 +504,42 @@ def __len__(self) -> int:
return 6
else:
return 6


class SlopeInterceptOutput(Cairo1SerializableStruct):
def __init__(self, name: str, elmts: list[ModuloCircuitElement]):
super().__init__(name, elmts)
self.members_names = ("m_A0", "b_A0", "x_A2", "y_A2", "coeff0", "coeff2")

def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 6
raw_struct = f"{self.__class__.__name__}{{{','.join([f'{self.members_names[i]}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.__class__.__name__} = {raw_struct};\n"

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 6
code = (
f"let {self.name}:{self.__class__.__name__} = {self.__class__.__name__}{{\n"
)
for mem_name, elmt in zip(self.members_names, self.elmts):
code += f"{mem_name}: outputs.get_output({offset_to_reference_map[elmt.offset]}),\n"
code += "};"
return code

def dump_to_circuit_input(self) -> str:
code = ""
for mem_name in self.members_names:
code += f"circuit_inputs = circuit_inputs.next({self.name}.{mem_name});\n"
return code

def __len__(self) -> int:
if self.elmts is not None:
assert len(self.elmts) == 6
return 6
else:
return 6
Loading

0 comments on commit 4e6be8f

Please sign in to comment.