Skip to content

Commit

Permalink
wip: cairo1 compiler^3
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Jun 27, 2024
1 parent 14b51c4 commit 183834a
Show file tree
Hide file tree
Showing 19 changed files with 208,108 additions and 207,133 deletions.
147 changes: 141 additions & 6 deletions hydra/extension_field_modulo_circuit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from hydra.modulo_circuit import ModuloCircuit, ModuloCircuitElement, WriteOps
from hydra.modulo_circuit import (
ModuloCircuit,
ModuloCircuitElement,
WriteOps,
ModBuiltinOps,
)
from hydra.algebra import BaseField, PyFelt, Polynomial
from hydra.poseidon_transcript import CairoPoseidonTranscript
from hydra.hints.extf_mul import (
Expand Down Expand Up @@ -79,8 +84,11 @@ def __init__(
extension_degree: int,
init_hash: int = None,
hash_input: bool = True,
compilation_mode: int = 0,
) -> None:
super().__init__(name, curve_id)
super().__init__(
name=name, curve_id=curve_id, compilation_mode=compilation_mode
)
self.class_name = "ExtensionFieldModuloCircuit"
self.extension_degree = extension_degree
self.z_powers: list[ModuloCircuitElement] = []
Expand Down Expand Up @@ -502,7 +510,9 @@ def finalize_circuit(
assert (
lhs.value == rhs.value
), f"{lhs.value} != {rhs.value}, {acc_index}"
self.sub_and_assert(lhs, rhs, self.get_constant(0))
self.sub_and_assert(
lhs, rhs, self.set_or_get_constant(self.field.zero())
)
return True

def summarize(self):
Expand All @@ -523,7 +533,14 @@ def summarize(self):

return summary

def compile_circuit(
def compile_circuit(self, function_name: str = None):
self.values_segment = self.values_segment.non_interactive_transform()
if self.compilation_mode == 0:
return self.compile_circuit_cairo_zero(function_name)
elif self.compilation_mode == 1:
return self.compile_circuit_cairo_1(function_name)

def compile_circuit_cairo_zero(
self,
function_name: str = None,
returns: dict[str] = {
Expand Down Expand Up @@ -627,6 +644,120 @@ def compile_circuit(
code += "}\n"
return code

def compile_circuit_cairo_1(
self,
function_name: str = None,
) -> str:
name = function_name or self.values_segment.name
function_name = f"get_{name}_circuit"
if self.generic_circuit:
code = (
f"fn {function_name}(mut input: Array<u384>, curve_index:usize)->Array<u384>"
+ "{"
+ "\n"
)
else:
code = (
f"fn {function_name}(mut input: Array<u384>)->Array<u384>" + "{" + "\n"
)

def write_stack(
write_ops: WriteOps,
code: str,
offset_to_reference_map: dict[int, str],
start_index: int,
) -> tuple:
if len(self.values_segment.segment_stacks[write_ops]) > 0:
code += f"\n // {write_ops.name} stack\n"
for i, offset in enumerate(
self.values_segment.segment_stacks[write_ops].keys()
):

code += f"\t let in{start_index+i} = CircuitElement::<CircuitInput<{start_index+i}>> {{}};\n"
offset_to_reference_map[offset] = f"in{start_index+i}"
return (
code,
offset_to_reference_map,
start_index + len(self.values_segment.segment_stacks[write_ops]),
)
else:
return code, offset_to_reference_map, start_index

code, offset_to_reference_map, start_index = write_stack(
WriteOps.CONSTANT, code, {}, 0
)

code, offset_to_reference_map, commit_start_index = write_stack(
WriteOps.INPUT, code, offset_to_reference_map, start_index
)
code, offset_to_reference_map, commit_end_index = write_stack(
WriteOps.COMMIT, code, offset_to_reference_map, commit_start_index
)
code, offset_to_reference_map, start_index = write_stack(
WriteOps.WITNESS, code, offset_to_reference_map, commit_end_index
)
code, offset_to_reference_map, start_index = write_stack(
WriteOps.FELT, code, offset_to_reference_map, start_index
)
for i, (offset, vs_item) in enumerate(
self.values_segment.segment_stacks[WriteOps.BUILTIN].items()
):
op = vs_item.instruction.operation
left_offset = vs_item.instruction.left_offset
right_offset = vs_item.instruction.right_offset
result_offset = vs_item.instruction.result_offset
# print(op, offset_to_reference_map, left_offset, right_offset, result_offset)
match op:
case ModBuiltinOps.ADD:
if right_offset > result_offset:
# Case sub
code += f"let t{i} = circuit_sub({offset_to_reference_map[result_offset]}, {offset_to_reference_map[left_offset]});\n"
offset_to_reference_map[offset] = f"t{i}"
assert offset == right_offset
else:
code += f"let t{i} = circuit_add({offset_to_reference_map[left_offset]}, {offset_to_reference_map[right_offset]});\n"
offset_to_reference_map[offset] = f"t{i}"
assert offset == result_offset

case ModBuiltinOps.MUL:
if right_offset == result_offset == offset:
# Case inv
# print(f"\t INV {left_offset} {right_offset} {result_offset}")
code += f"let t{i} = circuit_inverse({offset_to_reference_map[left_offset]});\n"
offset_to_reference_map[offset] = f"t{i}"
else:
# print(f"MUL {left_offset} {right_offset} {result_offset}")
code += f"let t{i} = circuit_mul({offset_to_reference_map[left_offset]}, {offset_to_reference_map[right_offset]});\n"
offset_to_reference_map[offset] = f"t{i}"
assert offset == result_offset

outputs_refs = [offset_to_reference_map[out.offset] for out in self.output]

code += f"// {commit_start_index=}, {commit_end_index-1=}"
code += f"""
let p = get_p(curve_index);
let modulus = TryInto::<_, CircuitModulus>::try_into([p.limb0, p.limb1, p.limb2, p.limb3])
.unwrap();
let mut circuit_inputs = ({','.join(outputs_refs)},).new_inputs();
while let Option::Some(val) = input.pop_front() {{
circuit_inputs = circuit_inputs.next(val);
}};
let outputs = match circuit_inputs.done().eval(modulus) {{
EvalCircuitResult::Success(outputs) => {{ outputs }},
EvalCircuitResult::Failure((_, _)) => {{ panic!("Expected success") }}
}};
"""
for i, ref in enumerate(outputs_refs):
code += f"\t let o{i} = outputs.get_output({ref});\n"
code += "\n"
code += f"let res=array![{','.join(['o'+str(i) for i, _ in enumerate(outputs_refs)])}];\n"
code += "return res;\n"
code += "}\n"
return code


if __name__ == "__main__":
from hydra.definitions import CURVES, CurveID
Expand All @@ -638,7 +769,9 @@ def init_z_circuit(z: int = 2):

def test_eval():
c = init_z_circuit()
X = c.write_elements([PyFelt(1, c.field.p) for _ in range(6)])
X = c.write_elements(
[PyFelt(1, c.field.p) for _ in range(6)], operation=WriteOps.INPUT
)
print("X(z)", [x.value for x in X])
X = c.eval_poly_in_precomputed_Z(X)
print("X(z)", X.value)
Expand All @@ -649,7 +782,9 @@ def test_eval():

def test_eval_sparse():
c = init_z_circuit()
X = c.write_elements([c.field.one(), c.field.zero(), c.field.one()])
X = c.write_elements(
[c.field.one(), c.field.zero(), c.field.one()], operation=WriteOps.INPUT
)
X = c.eval_poly_in_precomputed_Z(X, sparsity=[1, 0, 1])
print("X(z)", X.value)
c.print_value_segment()
Expand Down
61 changes: 33 additions & 28 deletions hydra/modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,9 @@ def write_sparse_elements(
elements.append(self.get_constant(elmt.value))
return elements, sparsity

def set_or_get_constant(self, val: PyFelt) -> None:
def set_or_get_constant(self, val: PyFelt | int) -> ModuloCircuitElement:
if type(val) == int:
val = self.field(val)
if val.value in self.constants:
# print((f"/!\ Constant '{hex(val.value)}' already exists."))
return self.constants[val.value]
Expand All @@ -378,21 +380,21 @@ def add(
b: ModuloCircuitElement,
) -> ModuloCircuitElement:

# if a is None:
# return b
# elif b is None:
# return a
# else:
assert (
type(a) == type(b) == ModuloCircuitElement
), f"Expected ModuloElement, got {type(a)}, {a} and {type(b)}, {b}"
if a is None and type(b) == ModuloCircuitElement:
return b
elif b is None and type(a) == ModuloCircuitElement:
return a
else:
assert (
type(a) == type(b) == ModuloCircuitElement
), f"Expected ModuloElement, got {type(a)}, {a} and {type(b)}, {b}"

instruction = ModuloCircuitInstruction(
ModBuiltinOps.ADD, a.offset, b.offset, self.values_offset
)
return self.write_element(
a.emulated_felt + b.emulated_felt, WriteOps.BUILTIN, instruction
)
instruction = ModuloCircuitInstruction(
ModBuiltinOps.ADD, a.offset, b.offset, self.values_offset
)
return self.write_element(
a.emulated_felt + b.emulated_felt, WriteOps.BUILTIN, instruction
)

def double(self, a: ModuloCircuitElement) -> ModuloCircuitElement:
return self.add(a, a)
Expand All @@ -411,7 +413,7 @@ def mul(
)

def neg(self, a: ModuloCircuitElement) -> ModuloCircuitElement:
res = self.sub(self.set_or_get_constant(0), a)
res = self.sub(self.set_or_get_constant(self.field.zero()), a)
return res

def sub(self, a: ModuloCircuitElement, b: ModuloCircuitElement):
Expand All @@ -422,11 +424,15 @@ def sub(self, a: ModuloCircuitElement, b: ModuloCircuitElement):

def inv(self, a: ModuloCircuitElement):
if self.compilation_mode == 0:
one = self.set_or_get_constant(
1
) # Write one before accessing its offset so self.values_offset is correctly updated.

instruction = ModuloCircuitInstruction(
ModBuiltinOps.MUL,
a.offset,
self.values_offset,
self.set_or_get_constant(1).offset,
one.offset,
)
elif self.compilation_mode == 1:
instruction = ModuloCircuitInstruction(
Expand Down Expand Up @@ -695,29 +701,27 @@ def write_stack(
code, offset_to_reference_map, start_index = write_stack(
WriteOps.CONSTANT, code, {}, 0
)
print(offset_to_reference_map)

code, offset_to_reference_map, start_index = write_stack(
code, offset_to_reference_map, commit_start_index = write_stack(
WriteOps.INPUT, code, offset_to_reference_map, start_index
)
code, offset_to_reference_map, start_index = write_stack(
WriteOps.COMMIT, code, offset_to_reference_map, start_index
code, offset_to_reference_map, commit_end_index = write_stack(
WriteOps.COMMIT, code, offset_to_reference_map, commit_start_index
)
code, offset_to_reference_map, start_index = write_stack(
WriteOps.WITNESS, code, offset_to_reference_map, start_index
WriteOps.WITNESS, code, offset_to_reference_map, commit_end_index
)
code, offset_to_reference_map, start_index = write_stack(
WriteOps.FELT, code, offset_to_reference_map, start_index
)
print(offset_to_reference_map)
for i, (offset, vs_item) in enumerate(
self.values_segment.segment_stacks[WriteOps.BUILTIN].items()
):
op = vs_item.instruction.operation
left_offset = vs_item.instruction.left_offset
right_offset = vs_item.instruction.right_offset
result_offset = vs_item.instruction.result_offset
print(op, offset_to_reference_map, left_offset, right_offset, result_offset)
# print(op, offset_to_reference_map, left_offset, right_offset, result_offset)
match op:
case ModBuiltinOps.ADD:
if right_offset > result_offset:
Expand All @@ -733,22 +737,24 @@ def write_stack(
case ModBuiltinOps.MUL:
if right_offset == result_offset == offset:
# Case inv
print(f"\t INV {left_offset} {right_offset} {result_offset}")
# print(f"\t INV {left_offset} {right_offset} {result_offset}")
code += f"let t{i} = circuit_inverse({offset_to_reference_map[left_offset]});\n"
offset_to_reference_map[offset] = f"t{i}"
else:
print(f"MUL {left_offset} {right_offset} {result_offset}")
# print(f"MUL {left_offset} {right_offset} {result_offset}")
code += f"let t{i} = circuit_mul({offset_to_reference_map[left_offset]}, {offset_to_reference_map[right_offset]});\n"
offset_to_reference_map[offset] = f"t{i}"
assert offset == result_offset

outputs_refs = [offset_to_reference_map[out.offset] for out in self.output]

last_t_index = len(self.values_segment.segment_stacks[WriteOps.BUILTIN]) - 1
code += f"""
let p = get_p(curve_index);
let modulus = TryInto::<_, CircuitModulus>::try_into([p.limb0, p.limb1, p.limb2, p.limb3])
.unwrap();
let mut circuit_inputs = (t{last_t_index},).new_inputs();
let mut circuit_inputs = ({','.join(outputs_refs)},).new_inputs();
while let Option::Some(val) = input.pop_front() {{
circuit_inputs = circuit_inputs.next(val);
Expand All @@ -759,7 +765,6 @@ def write_stack(
EvalCircuitResult::Failure((_, _)) => {{ panic!("Expected success") }}
}};
"""
outputs_refs = [offset_to_reference_map[out.offset] for out in self.output]
for i, ref in enumerate(outputs_refs):
code += f"\t let o{i} = outputs.get_output({ref});\n"
code += "\n"
Expand Down
Loading

0 comments on commit 183834a

Please sign in to comment.