Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: commit transaction #487

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ from ethereum.cancun.trie import (
copy_TrieAddressOptionalAccount,
copy_trieBytes32U256,
)
from ethereum.cancun.vm.exceptions import ExceptionalHalt, IndexError
from ethereum_types.bytes import Bytes, Bytes32
from ethereum_types.numeric import U256, U256Struct, Bool, bool
from ethereum.utils.numeric import is_zero
from starkware.cairo.common.alloc import alloc

from src.utils.dict import hashdict_read, hashdict_write, hashdict_get, dict_new_empty

Expand Down Expand Up @@ -970,3 +972,116 @@ func set_account_balance{poseidon_ptr: PoseidonBuiltin*, state: State}(
set_account(address, new_account);
return ();
}

func commit_transaction{
poseidon_ptr: PoseidonBuiltin*, state: State, transient_storage: TransientStorage
}() -> ExceptionalHalt* {
alloc_locals;

// Handle State snapshot
let snapshots = state.value._snapshots;
if (snapshots.value.len == 0) {
tempvar err = new ExceptionalHalt(IndexError);
return err;
}

let (local new_snapshots_raw_ptr) = alloc();
tempvar end_len = snapshots.value.len - 1;
tempvar i = 0;

loop:
let new_snapshots_raw_ptr = cast([fp], felt*);
let i = [ap - 1];
let end_len = [ap - 2];
tempvar continue = 1 - is_zero(end_len - i);
jmp end if continue != 0;

assert new_snapshots_raw_ptr[i] = cast(snapshots.value.data[i].value, felt);
tempvar end_len = end_len;
tempvar i = i + 1;
jmp loop;

end:
let new_snapshots_ptr = cast(
[fp], TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct*
);

let new_snapshots_len = snapshots.value.len - 1;
tempvar new_snapshots = ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256(
new ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct(
data=new TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256(
new_snapshots_ptr
),
len=new_snapshots_len,
),
);

// Clear created_accounts if no snapshots remain
if (new_snapshots_len == 0) {
let (empty_dict_ptr) = dict_new_empty();
tempvar created_accounts = SetAddress(
new SetAddressStruct(
dict_ptr_start=cast(empty_dict_ptr, SetAddressDictAccess*),
dict_ptr=cast(empty_dict_ptr, SetAddressDictAccess*),
),
);
} else {
tempvar created_accounts = state.value.created_accounts;
}

tempvar state = State(
new StateStruct(
_main_trie=state.value._main_trie,
_storage_tries=state.value._storage_tries,
_snapshots=new_snapshots,
created_accounts=created_accounts,
original_storage_tries=state.value.original_storage_tries,
),
);

// Handle transient_storage snapshot
let transient_snapshots = transient_storage.value._snapshots;
if (transient_snapshots.value.len == 0) {
tempvar err = new ExceptionalHalt(IndexError);
return err;
}

let (local new_transient_snapshots_raw_ptr) = alloc();
tempvar transient_end_len = transient_snapshots.value.len - 1;
tempvar j = 0;

transient_loop:
let new_transient_snapshots_raw_ptr = cast([fp], felt*);
let j = [ap - 1];
let transient_end_len = [ap - 2];
tempvar continue = 1 - is_zero(transient_end_len - j);
jmp transient_end if continue != 0;

assert new_transient_snapshots_raw_ptr[j] = cast(transient_snapshots.value.data[j].value, felt);
tempvar transient_end_len = transient_end_len;
tempvar j = j + 1;
jmp transient_loop;

transient_end:
let new_transient_snapshots_ptr = cast(
new_transient_snapshots_raw_ptr, MappingAddressTrieBytes32U256Struct*
);
let new_transient_snapshots_len = transient_snapshots.value.len - 1;

tempvar new_transient_snapshots = TransientStorageSnapshots(
new TransientStorageSnapshotsStruct(
data=new MappingAddressTrieBytes32U256(new_transient_snapshots_ptr),
len=new_transient_snapshots_len,
),
);

// Update transient storage with new snapshots
tempvar transient_storage = TransientStorage(
new TransientStorageStruct(
_tries=transient_storage.value._tries, _snapshots=new_transient_snapshots
),
);

tempvar ok = cast(0, ExceptionalHalt*);
return ok;
}
1 change: 1 addition & 0 deletions cairo/ethereum/cancun/vm/exceptions.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ struct ExceptionalHalt {
const StackUnderflowError = 'StackUnderflowError';
const StackOverflowError = 'StackOverflowError';
const OutOfGasError = 'OutOfGasError';
const IndexError = 'IndexError';
const InvalidOpcode = 'InvalidOpcode';
const InvalidJumpDestError = 'InvalidJumpDestError';
const StackDepthLimitError = 'StackDepthLimitError';
Expand Down
27 changes: 25 additions & 2 deletions cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import given, reproduce_failure
from hypothesis import strategies as st
from hypothesis.strategies import composite

Expand All @@ -12,6 +12,7 @@
account_exists_and_is_empty,
account_has_code_or_nonce,
begin_transaction,
commit_transaction,
destroy_account,
destroy_storage,
get_account,
Expand All @@ -29,6 +30,7 @@
set_transient_storage,
)
from tests.utils.args_gen import State, TransientStorage
from tests.utils.errors import strict_raises
from tests.utils.strategies import address, bytes32, code, state, transient_storage


Expand Down Expand Up @@ -271,7 +273,7 @@ def test_set_transient_storage(
assert transient_storage_cairo == transient_storage


class TestBeginTransaction:
class TestSnapshots:
@given(state=..., transient_storage=...)
def test_begin_transaction(
self, cairo_run, state: State, transient_storage: TransientStorage
Expand All @@ -284,3 +286,24 @@ def test_begin_transaction(
begin_transaction(state, transient_storage)
assert state_cairo == state
assert transient_storage_cairo == transient_storage

@reproduce_failure("6.123.17", b"AABBAAAAAA==")
@given(state=..., transient_storage=...)
def test_commit_transaction(
self, cairo_run_py, state: State, transient_storage: TransientStorage
):
try:
state_cairo, transient_storage_cairo, err = cairo_run_py(
"commit_transaction",
state,
transient_storage,
)
assert len(err) == 0
except Exception as cairo_error:
with strict_raises(type(cairo_error)):
commit_transaction(state, transient_storage)
return

commit_transaction(state, transient_storage)
assert state_cairo == state
assert transient_storage_cairo == transient_storage
Loading