Skip to content

Commit

Permalink
bootloader gets actual the job
Browse files Browse the repository at this point in the history
  • Loading branch information
Okm165 committed Apr 23, 2024
1 parent 23da5fc commit 2e09710
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 146 deletions.
32 changes: 25 additions & 7 deletions cairo/bootloader/hash_program.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import argparse
import json

from starkware.cairo.common.hash_chain import compute_hash_chain
from starkware.cairo.lang.compiler.program import Program, ProgramBase
from starkware.cairo.lang.version import __version__
from starkware.cairo.lang.vm.crypto import get_crypto_lib_context_manager, poseidon_hash_many
from starkware.cairo.lang.vm.crypto import (
get_crypto_lib_context_manager,
poseidon_hash_many,
)
from starkware.python.utils import from_bytes


def compute_program_hash_chain(program: ProgramBase, use_poseidon: bool, bootloader_version=0):
def compute_program_hash_chain(
program: ProgramBase, use_poseidon: bool, bootloader_version=0
):
"""
Computes a hash chain over a program, including the length of the data chain.
"""
builtin_list = [from_bytes(builtin.encode("ascii")) for builtin in program.builtins]
# The program header below is missing the data length, which is later added to the data_chain.
program_header = [bootloader_version, program.main, len(program.builtins)] + builtin_list
program_header = [
bootloader_version,
program.main,
len(program.builtins),
] + builtin_list
data_chain = program_header + program.data

if use_poseidon:
Expand All @@ -23,8 +31,12 @@ def compute_program_hash_chain(program: ProgramBase, use_poseidon: bool, bootloa


def main():
parser = argparse.ArgumentParser(description="A tool to compute the hash of a cairo program")
parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}")
parser = argparse.ArgumentParser(
description="A tool to compute the hash of a cairo program"
)
parser.add_argument(
"-v", "--version", action="version", version=f"%(prog)s {__version__}"
)
parser.add_argument(
"--program",
type=argparse.FileType("r"),
Expand All @@ -48,7 +60,13 @@ def main():

with get_crypto_lib_context_manager(args.flavor):
program = Program.Schema().load(json.load(args.program))
print(hex(compute_program_hash_chain(program=program, use_poseidon=args.use_poseidon)))
print(
hex(
compute_program_hash_chain(
program=program, use_poseidon=args.use_poseidon
)
)
)


if __name__ == "__main__":
Expand Down
37 changes: 18 additions & 19 deletions cairo/bootloader/objects.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import dataclasses
from abc import abstractmethod
from typing import Optional

from typing import List, Optional
import marshmallow_dataclass

from starkware.cairo.lang.compiler.program import ProgramBase, StrippedProgram
from starkware.cairo.lang.vm.cairo_pie import CairoPie
from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass


class TaskSpec(ValidatedMarshmallowDataclass):
"""
Contains task's specification.
"""

@abstractmethod
def load_task(self, memory=None, args_start=None, args_len=None) -> "Task":
def load_task(self) -> "Task":
"""
Returns the corresponding task.
"""
Expand All @@ -39,27 +33,32 @@ def get_program(self) -> StrippedProgram:


@dataclasses.dataclass(frozen=True)
class Job(Task):
class JobData(Task):
reward: int
num_of_steps: int
cairo_pie: bytearray
registry_address: bytearray
public_key: bytearray
signature: bytearray
cairo_pie_compressed: List[int]
registry_address: str

def load_task(self) -> "CairoPieTask":
"""
Loads the PIE to memory.
"""
return CairoPieTask(
cairo_pie=CairoPie.deserialize(self.cairo_pie),
use_poseidon=self.use_poseidon,
cairo_pie=CairoPie.deserialize(bytes(self.cairo_pie_compressed)),
use_poseidon=True,
)


@dataclasses.dataclass(frozen=True)
class Job(Task):
job_data: JobData
public_key: List[int]
signature: List[int]

def load_task(self) -> "CairoPieTask":
return self.job_data.load_task()


@marshmallow_dataclass.dataclass(frozen=True)
class SimpleBootloaderInput(ValidatedMarshmallowDataclass):
identity: bytearray
identity: str
job: Job

fact_topologies_path: Optional[str]
Expand Down
2 changes: 1 addition & 1 deletion cairo/bootloader/recursive_with_poseidon/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
BITWISE_BUILTIN,
POSEIDON_BUILTIN,
]
)
)
16 changes: 2 additions & 14 deletions cairo/bootloader/recursive_with_poseidon/execute_task.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,17 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
%{
from bootloader.objects import (
CairoPieTask,
RunProgramTask,
Task,
)
from bootloader.utils import (
load_cairo_pie,
prepare_output_runner,
)
assert isinstance(task, Task)
n_builtins = len(task.get_program().builtins)
new_task_locals = {}
if isinstance(task, RunProgramTask):
new_task_locals['program_input'] = task.program_input
new_task_locals['WITH_BOOTLOADER'] = True
vm_load_program(task.program, program_address)
elif isinstance(task, CairoPieTask):
if isinstance(task, CairoPieTask):
ret_pc = ids.ret_pc_label.instruction_offset_ - ids.call_task.instruction_offset_ + pc
load_cairo_pie(
task=task.cairo_pie, memory=memory, segments=segments,
Expand All @@ -169,10 +163,6 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
else:
raise NotImplementedError(f'Unexpected task type: {type(task).__name__}.')
output_runner_data = prepare_output_runner(
task=task,
output_builtin=output_builtin,
output_ptr=ids.pre_execution_builtin_ptrs.output)
vm_enter_scope(new_task_locals)
%}

Expand Down Expand Up @@ -243,8 +233,6 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
fact_topologies.append(get_task_fact_topology(
output_size=output_end - output_start,
task=task,
output_builtin=output_builtin,
output_runner_data=output_runner_data,
))
%}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func run_simple_bootloader{
local task_range_check_ptr;

%{
n_tasks = len(simple_bootloader_input.tasks)
n_tasks = 1
memory[ids.output_ptr] = n_tasks

# Task range checks are located right after simple bootloader validation range checks, and
Expand Down Expand Up @@ -65,7 +65,6 @@ func run_simple_bootloader{
// Call execute_tasks.
let (__fp__, _) = get_fp_and_pc();

%{ tasks = simple_bootloader_input.tasks %}
let builtin_ptrs = &builtin_ptrs_before;
let self_range_check_ptr = range_check_ptr;
with builtin_ptrs, self_range_check_ptr {
Expand Down Expand Up @@ -141,9 +140,7 @@ func execute_tasks{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
from bootloader.objects import Task

# Pass current task to execute_task.
task_id = len(simple_bootloader_input.tasks) - ids.n_tasks
task_obj = simple_bootloader_input.tasks[task_id]
task = task_obj.load_task()
task = simple_bootloader_input.job.load_task()
%}
tempvar use_poseidon = nondet %{ 1 if task.use_poseidon else 0 %};
// Call execute_task to execute the current task.
Expand Down
Loading

0 comments on commit 2e09710

Please sign in to comment.