From c0b6fa8032662c4e272eeb4fd585ac70f69bda1c Mon Sep 17 00:00:00 2001 From: Oba Date: Fri, 17 Jan 2025 15:21:10 +0100 Subject: [PATCH] feat: commit transaction --- cairo/ethereum/cancun/state.cairo | 115 ++++++++++++++++++++++ cairo/ethereum/cancun/vm/exceptions.cairo | 1 + cairo/tests/ethereum/cancun/test_state.py | 27 ++++- 3 files changed, 141 insertions(+), 2 deletions(-) diff --git a/cairo/ethereum/cancun/state.cairo b/cairo/ethereum/cancun/state.cairo index 8dbe2c4d..63d22325 100644 --- a/cairo/ethereum/cancun/state.cairo +++ b/cairo/ethereum/cancun/state.cairo @@ -30,9 +30,11 @@ from ethereum.cancun.trie import ( copy_TrieAddressOptionalAccount, copy_trieBytes32U256, ) +from ethereum.cancun.vm.exceptions import ExceptionalHalt, IndexError from ethereum_types.bytes import Bytes, Bytes32 from ethereum_types.numeric import U256, U256Struct, Bool, bool from ethereum.utils.numeric import is_zero +from starkware.cairo.common.alloc import alloc from src.utils.dict import hashdict_read, hashdict_write, hashdict_get, dict_new_empty @@ -970,3 +972,116 @@ func set_account_balance{poseidon_ptr: PoseidonBuiltin*, state: State}( set_account(address, new_account); return (); } + +func commit_transaction{ + poseidon_ptr: PoseidonBuiltin*, state: State, transient_storage: TransientStorage +}() -> ExceptionalHalt* { + alloc_locals; + + // Handle State snapshot + let snapshots = state.value._snapshots; + if (snapshots.value.len == 0) { + tempvar err = new ExceptionalHalt(IndexError); + return err; + } + + let (local new_snapshots_raw_ptr) = alloc(); + tempvar end_len = snapshots.value.len - 1; + tempvar i = 0; + + loop: + let new_snapshots_raw_ptr = cast([fp], felt*); + let i = [ap - 1]; + let end_len = [ap - 2]; + tempvar continue = 1 - is_zero(end_len - i); + jmp end if continue != 0; + + assert new_snapshots_raw_ptr[i] = cast(snapshots.value.data[i].value, felt); + tempvar end_len = end_len; + tempvar i = i + 1; + jmp loop; + + end: + let new_snapshots_ptr = cast( + [fp], TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct* + ); + + let new_snapshots_len = snapshots.value.len - 1; + tempvar new_snapshots = ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256( + new ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct( + data=new TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256( + new_snapshots_ptr + ), + len=new_snapshots_len, + ), + ); + + // Clear created_accounts if no snapshots remain + if (new_snapshots_len == 0) { + let (empty_dict_ptr) = dict_new_empty(); + tempvar created_accounts = SetAddress( + new SetAddressStruct( + dict_ptr_start=cast(empty_dict_ptr, SetAddressDictAccess*), + dict_ptr=cast(empty_dict_ptr, SetAddressDictAccess*), + ), + ); + } else { + tempvar created_accounts = state.value.created_accounts; + } + + tempvar state = State( + new StateStruct( + _main_trie=state.value._main_trie, + _storage_tries=state.value._storage_tries, + _snapshots=new_snapshots, + created_accounts=created_accounts, + original_storage_tries=state.value.original_storage_tries, + ), + ); + + // Handle transient_storage snapshot + let transient_snapshots = transient_storage.value._snapshots; + if (transient_snapshots.value.len == 0) { + tempvar err = new ExceptionalHalt(IndexError); + return err; + } + + let (local new_transient_snapshots_raw_ptr) = alloc(); + tempvar transient_end_len = transient_snapshots.value.len - 1; + tempvar j = 0; + + transient_loop: + let new_transient_snapshots_raw_ptr = cast([fp], felt*); + let j = [ap - 1]; + let transient_end_len = [ap - 2]; + tempvar continue = 1 - is_zero(transient_end_len - j); + jmp transient_end if continue != 0; + + assert new_transient_snapshots_raw_ptr[j] = cast(transient_snapshots.value.data[j].value, felt); + tempvar transient_end_len = transient_end_len; + tempvar j = j + 1; + jmp transient_loop; + + transient_end: + let new_transient_snapshots_ptr = cast( + new_transient_snapshots_raw_ptr, MappingAddressTrieBytes32U256Struct* + ); + let new_transient_snapshots_len = transient_snapshots.value.len - 1; + + tempvar new_transient_snapshots = TransientStorageSnapshots( + new TransientStorageSnapshotsStruct( + data=new MappingAddressTrieBytes32U256(new_transient_snapshots_ptr), + len=new_transient_snapshots_len, + ), + ); + + // Update transient storage with new snapshots + tempvar transient_storage = TransientStorage( + new TransientStorageStruct( + _tries=transient_storage.value._tries, _snapshots=new_transient_snapshots + ), + ); + + tempvar ok = cast(0, ExceptionalHalt*); + return ok; +} diff --git a/cairo/ethereum/cancun/vm/exceptions.cairo b/cairo/ethereum/cancun/vm/exceptions.cairo index b2cf7bfb..8da8b980 100644 --- a/cairo/ethereum/cancun/vm/exceptions.cairo +++ b/cairo/ethereum/cancun/vm/exceptions.cairo @@ -5,6 +5,7 @@ struct ExceptionalHalt { const StackUnderflowError = 'StackUnderflowError'; const StackOverflowError = 'StackOverflowError'; const OutOfGasError = 'OutOfGasError'; +const IndexError = 'IndexError'; const InvalidOpcode = 'InvalidOpcode'; const InvalidJumpDestError = 'InvalidJumpDestError'; const StackDepthLimitError = 'StackDepthLimitError'; diff --git a/cairo/tests/ethereum/cancun/test_state.py b/cairo/tests/ethereum/cancun/test_state.py index d26ac552..d8d37c91 100644 --- a/cairo/tests/ethereum/cancun/test_state.py +++ b/cairo/tests/ethereum/cancun/test_state.py @@ -2,7 +2,7 @@ import pytest from ethereum_types.numeric import U256 -from hypothesis import given +from hypothesis import given, reproduce_failure from hypothesis import strategies as st from hypothesis.strategies import composite @@ -12,6 +12,7 @@ account_exists_and_is_empty, account_has_code_or_nonce, begin_transaction, + commit_transaction, destroy_account, destroy_storage, get_account, @@ -29,6 +30,7 @@ set_transient_storage, ) from tests.utils.args_gen import State, TransientStorage +from tests.utils.errors import strict_raises from tests.utils.strategies import address, bytes32, code, state, transient_storage @@ -271,7 +273,7 @@ def test_set_transient_storage( assert transient_storage_cairo == transient_storage -class TestBeginTransaction: +class TestSnapshots: @given(state=..., transient_storage=...) def test_begin_transaction( self, cairo_run, state: State, transient_storage: TransientStorage @@ -284,3 +286,24 @@ def test_begin_transaction( begin_transaction(state, transient_storage) assert state_cairo == state assert transient_storage_cairo == transient_storage + + @reproduce_failure("6.123.17", b"AABBAAAAAA==") + @given(state=..., transient_storage=...) + def test_commit_transaction( + self, cairo_run_py, state: State, transient_storage: TransientStorage + ): + try: + state_cairo, transient_storage_cairo, err = cairo_run_py( + "commit_transaction", + state, + transient_storage, + ) + assert len(err) == 0 + except Exception as cairo_error: + with strict_raises(type(cairo_error)): + commit_transaction(state, transient_storage) + return + + commit_transaction(state, transient_storage) + assert state_cairo == state + assert transient_storage_cairo == transient_storage