diff --git a/core/state/state_prove.go b/core/state/state_prove.go index 95c54988dc18..668e20eb74b4 100644 --- a/core/state/state_prove.go +++ b/core/state/state_prove.go @@ -74,15 +74,18 @@ func (s *StateDB) GetStorageTrieForProof(addr common.Address) (Trie, error) { // GetSecureTrieProof handle any interface with Prove (should be a Trie in most case) and // deliver the proof in bytes -func (s *StateDB) GetSecureTrieProof(trieProve TrieProve, key common.Hash) ([][]byte, error) { +func (s *StateDB) GetSecureTrieProof(trieProve TrieProve, key common.Hash) (FullProofList, common.Hash, error) { - var proof proofList + var proof FullProofList + var hash common.Hash var err error if s.IsZktrie() { key_s, _ := zkt.ToSecureKeyBytes(key.Bytes()) - err = trieProve.Prove(key_s.Bytes(), 0, &proof) + hash = common.BytesToHash(key_s.Bytes()) + err = trieProve.Prove(hash.Bytes(), 0, &proof) } else { - err = trieProve.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) + hash = common.BytesToHash(crypto.Keccak256(key.Bytes())) + err = trieProve.Prove(hash.Bytes(), 0, &proof) } - return proof, err + return proof, hash, err } diff --git a/core/state/statedb.go b/core/state/statedb.go index 6629a50eae57..ec876bc7882a 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -54,6 +54,33 @@ func (n *proofList) Delete(key []byte) error { panic("not supported") } +type fullProof struct { + Key []byte + Value []byte +} + +type FullProofList []fullProof + +func (n *FullProofList) Put(key []byte, value []byte) error { + *n = append(*n, fullProof{ + Key: key, + Value: value, + }) + return nil +} + +func (n *FullProofList) Delete(key []byte) error { + panic("not supported") +} + +func (n FullProofList) GetData() (out [][]byte) { + out = make([][]byte, 0, len(n)) + for _, i := range n { + out = append(out, i.Value) + } + return +} + // StateDB structs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: @@ -343,6 +370,21 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { return proof, err } +// GetFullProof returns the Merkle proof for a given account, with both node data and key +// also the key for address is provided +func (s *StateDB) GetFullProof(addr common.Address) (FullProofList, common.Hash, error) { + var hash common.Hash + if s.IsZktrie() { + addr_s, _ := zkt.ToSecureKeyBytes(addr.Bytes()) + hash = common.BytesToHash(addr_s.Bytes()) + } else { + hash = crypto.Keccak256Hash(addr.Bytes()) + } + var proof FullProofList + err := s.trie.Prove(hash[:], 0, &proof) + return proof, hash, err +} + func (s *StateDB) GetLiveStateAccount(addr common.Address) *types.StateAccount { obj, ok := s.stateObjects[addr] if !ok { @@ -361,7 +403,11 @@ func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, if trie == nil { return nil, errors.New("storage trie for requested address does not exist") } - return s.GetSecureTrieProof(trie, key) + proof, _, err := s.GetSecureTrieProof(trie, key) + if err != nil { + return nil, err + } + return proof.GetData(), nil } // GetCommittedState retrieves a value from the given account's committed storage trie. diff --git a/core/types/l2trace.go b/core/types/l2trace.go index b5a1ebd8d053..6c125def6f64 100644 --- a/core/types/l2trace.go +++ b/core/types/l2trace.go @@ -45,11 +45,37 @@ type StorageTrace struct { // All storage proofs BEFORE execution StorageProofs map[string]map[string][]hexutil.Bytes `json:"storageProofs,omitempty"` + // The "flatten" db nodes + FlattenProofs map[common.Hash]hexutil.Bytes `json:"flattenProofs,omitempty"` + + // The hash of secured addresses + AddressHashes map[common.Address]common.Hash `json:"addressHashes,omitempty"` + // The hash of secured store key + StoreKeyHashes map[common.Hash]common.Hash `json:"storeKeyHashes,omitempty"` + // Node entries for deletion, no need to distinguish what it is from, just read them // into the partial db DeletionProofs []hexutil.Bytes `json:"deletionProofs,omitempty"` } +func (tr *StorageTrace) ApplyFilter(legacy bool) { + if legacy { + tr.FlattenProofs = nil + tr.AddressHashes = nil + tr.StoreKeyHashes = nil + } else { + for k := range tr.Proofs { + tr.Proofs[k] = []hexutil.Bytes{} + } + for _, st := range tr.StorageProofs { + for k := range st { + st[k] = []hexutil.Bytes{} + } + } + tr.DeletionProofs = []hexutil.Bytes{} + } +} + // ExecutionResult groups all structured logs emitted by the EVM // while replaying a transaction in debug mode as well as transaction // execution status, the amount of gas used and the return value diff --git a/core/vm/logger.go b/core/vm/logger.go index 740c7b93cc05..156aba21f140 100644 --- a/core/vm/logger.go +++ b/core/vm/logger.go @@ -50,12 +50,15 @@ func (s Storage) Copy() Storage { // LogConfig are the configuration options for structured logger the EVM type LogConfig struct { - EnableMemory bool // enable memory capture - DisableStack bool // disable stack capture - DisableStorage bool // disable storage capture - EnableReturnData bool // enable return data capture - Debug bool // print output during capture end - Limit int // maximum length of output, but zero means unlimited + EnableMemory bool // enable memory capture + DisableStack bool // disable stack capture + DisableStorage bool // disable storage capture + EnableReturnData bool // enable return data capture + Debug bool // print output during capture end + Limit int // maximum length of output, but zero means unlimited + StorageProofFormat *string // format of storage proofs, can be + // "legacy" (use the legacy proof format) or + // "union" (output both flatten and legacy proof) // Chain overrides, can be used to execute a trace using future fork rules Overrides *params.ChainConfig `json:"overrides,omitempty"` } diff --git a/eth/tracers/api_blocktrace.go b/eth/tracers/api_blocktrace.go index a52daa29983d..c69b9f9fbec2 100644 --- a/eth/tracers/api_blocktrace.go +++ b/eth/tracers/api_blocktrace.go @@ -22,8 +22,14 @@ type TraceBlock interface { GetTxBlockTraceOnTopOfBlock(ctx context.Context, tx *types.Transaction, blockNrOrHash rpc.BlockNumberOrHash, config *TraceConfig) (*types.BlockTrace, error) } +type TracerEnv interface { + ResetForPartialTrace(*types.Block) error + GetBlockTrace(*types.Block) (*types.BlockTrace, error) +} + type scrollTracerWrapper interface { CreateTraceEnvAndGetBlockTrace(*params.ChainConfig, core.ChainContext, consensus.Engine, ethdb.Database, *state.StateDB, *types.Block, *types.Block, bool) (*types.BlockTrace, error) + CreateTraceEnv(*params.ChainConfig, core.ChainContext, consensus.Engine, ethdb.Database, *state.StateDB, *types.Block, *types.Block, bool) (TracerEnv, error) } // GetBlockTraceByNumberOrHash replays the block and returns the structured BlockTrace by hash or number. @@ -79,8 +85,82 @@ func (api *API) GetTxBlockTraceOnTopOfBlock(ctx context.Context, tx *types.Trans return api.createTraceEnvAndGetBlockTrace(ctx, config, block) } +func (api *API) GetTxByTxBlockTrace(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash, config *TraceConfig) ([]*types.BlockTrace, error) { + if api.scrollTracerWrapper == nil { + return nil, errNoScrollTracerWrapper + } + + // Try to retrieve the specified block + var ( + err error + block *types.Block + ) + if number, ok := blockNrOrHash.Number(); ok { + block, err = api.blockByNumber(ctx, number) + } else if hash, ok := blockNrOrHash.Hash(); ok { + block, err = api.blockByHash(ctx, hash) + } else { + return nil, errors.New("invalid arguments; neither block number nor hash specified") + } + if err != nil { + return nil, err + } + if block.NumberU64() == 0 { + return nil, errors.New("genesis is not traceable") + } + + if config == nil { + config = &TraceConfig{ + LogConfig: &vm.LogConfig{ + DisableStorage: true, + DisableStack: true, + EnableMemory: false, + EnableReturnData: true, + }, + } + } else if config.Tracer != nil { + config.Tracer = nil + log.Warn("Tracer params is unsupported") + } + + parent, err := api.blockByNumberAndHash(ctx, rpc.BlockNumber(block.NumberU64()-1), block.ParentHash()) + if err != nil { + return nil, err + } + reexec := defaultTraceReexec + if config != nil && config.Reexec != nil { + reexec = *config.Reexec + } + statedb, err := api.backend.StateAtBlock(ctx, parent, reexec, nil, true, true) + if err != nil { + return nil, err + } + + chaindb := api.backend.ChainDb() + traces := []*types.BlockTrace{} + traceEnv, err := api.scrollTracerWrapper.CreateTraceEnv(api.backend.ChainConfig(), api.chainContext(ctx), api.backend.Engine(), chaindb, statedb, parent, block, true) + if err != nil { + return nil, err + } + for _, tx := range block.Transactions() { + singleTxBlock := types.NewBlockWithHeader(block.Header()).WithBody([]*types.Transaction{tx}, nil) + if err := traceEnv.ResetForPartialTrace(singleTxBlock); err != nil { + return nil, err + } + trace, err := traceEnv.GetBlockTrace(singleTxBlock) + if err != nil { + return nil, err + } + // trace.StorageTrace.ApplyFilter(false) + traces = append(traces, trace) + } + return traces, nil +} + // Make trace environment for current block, and then get the trace for the block. func (api *API) createTraceEnvAndGetBlockTrace(ctx context.Context, config *TraceConfig, block *types.Block) (*types.BlockTrace, error) { + legacyStorageTrace := true + unionStorageTrace := false if config == nil { config = &TraceConfig{ LogConfig: &vm.LogConfig{ @@ -95,6 +175,14 @@ func (api *API) createTraceEnvAndGetBlockTrace(ctx context.Context, config *Trac log.Warn("Tracer params is unsupported") } + if config.LogConfig != nil && config.StorageProofFormat != nil { + if *config.StorageProofFormat == "flatten" { + legacyStorageTrace = false + } else if *config.StorageProofFormat == "union" { + unionStorageTrace = true + } + } + parent, err := api.blockByNumberAndHash(ctx, rpc.BlockNumber(block.NumberU64()-1), block.ParentHash()) if err != nil { return nil, err @@ -109,5 +197,15 @@ func (api *API) createTraceEnvAndGetBlockTrace(ctx context.Context, config *Trac } chaindb := api.backend.ChainDb() - return api.scrollTracerWrapper.CreateTraceEnvAndGetBlockTrace(api.backend.ChainConfig(), api.chainContext(ctx), api.backend.Engine(), chaindb, statedb, parent, block, true) + l2Trace, err := api.scrollTracerWrapper.CreateTraceEnvAndGetBlockTrace(api.backend.ChainConfig(), api.chainContext(ctx), api.backend.Engine(), chaindb, statedb, parent, block, true) + if err != nil { + return nil, err + } + if !unionStorageTrace { + l2Trace.StorageTrace.ApplyFilter(legacyStorageTrace) + for _, st := range l2Trace.TxStorageTraces { + st.ApplyFilter(legacyStorageTrace) + } + } + return l2Trace, nil } diff --git a/rollup/tracing/tracing.go b/rollup/tracing/tracing.go index 1e667a4a9c29..b6693bedf490 100644 --- a/rollup/tracing/tracing.go +++ b/rollup/tracing/tracing.go @@ -8,6 +8,8 @@ import ( "sync" "time" + zktrie "github.com/scroll-tech/zktrie/types" + "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/common/hexutil" "github.com/scroll-tech/go-ethereum/consensus" @@ -40,11 +42,19 @@ var ( // TracerWrapper implements ScrollTracerWrapper interface type TracerWrapper struct{} +// alias for proof list +type proofList = state.FullProofList + // NewTracerWrapper TracerWrapper creates a new TracerWrapper func NewTracerWrapper() *TracerWrapper { return &TracerWrapper{} } +func (tw *TracerWrapper) CreateTraceEnv(chainConfig *params.ChainConfig, chainContext core.ChainContext, engine consensus.Engine, chaindb ethdb.Database, statedb *state.StateDB, parent *types.Block, block *types.Block, commitAfterApply bool) (tracers.TracerEnv, error) { + traceEnv, err := CreateTraceEnv(chainConfig, chainContext, engine, chaindb, statedb, parent, block, commitAfterApply) + return traceEnv, err +} + // CreateTraceEnvAndGetBlockTrace wraps the whole block tracing logic for a block func (tw *TracerWrapper) CreateTraceEnvAndGetBlockTrace(chainConfig *params.ChainConfig, chainContext core.ChainContext, engine consensus.Engine, chaindb ethdb.Database, statedb *state.StateDB, parent *types.Block, block *types.Block, commitAfterApply bool) (*types.BlockTrace, error) { traceEnv, err := CreateTraceEnv(chainConfig, chainContext, engine, chaindb, statedb, parent, block, commitAfterApply) @@ -69,7 +79,7 @@ type TraceEnv struct { // The following Mutexes are used to protect against parallel read/write, // since txs are executed in parallel. pMu sync.Mutex // for `TraceEnv.StorageTrace.Proofs` - sMu sync.Mutex // for `TraceEnv.state`` + sMu sync.Mutex // for `TraceEnv.state` cMu sync.Mutex // for `TraceEnv.Codes` *types.StorageTrace @@ -109,10 +119,13 @@ func CreateTraceEnvHelper(chainConfig *params.ChainConfig, logConfig *vm.LogConf state: statedb, blockCtx: blockCtx, StorageTrace: &types.StorageTrace{ - RootBefore: rootBefore, - RootAfter: block.Root(), - Proofs: make(map[string][]hexutil.Bytes), - StorageProofs: make(map[string]map[string][]hexutil.Bytes), + RootBefore: rootBefore, + RootAfter: block.Root(), + Proofs: make(map[string][]hexutil.Bytes), + StorageProofs: make(map[string]map[string][]hexutil.Bytes), + FlattenProofs: make(map[common.Hash]hexutil.Bytes), + AddressHashes: make(map[common.Address]common.Hash), + StoreKeyHashes: make(map[common.Hash]common.Hash), }, Codes: make(map[common.Hash]vm.CodeInfo), ZkTrieTracer: make(map[string]state.ZktrieProofTracer), @@ -170,17 +183,55 @@ func CreateTraceEnv(chainConfig *params.ChainConfig, chainContext core.ChainCont key := coinbase.String() if _, exist := env.Proofs[key]; !exist { - proof, err := env.state.GetProof(coinbase) + proof, addrHash, err := env.state.GetFullProof(coinbase) if err != nil { log.Error("Proof for coinbase not available", "coinbase", coinbase, "error", err) // but we still mark the proofs map with nil array } - env.Proofs[key] = types.WrapProof(proof) + // TODO: + env.AddressHashes[coinbase] = addrHash + env.fillFlattenStorageProof(nil, proof) + env.Proofs[key] = types.WrapProof(proof.GetData()) } return env, nil } +func (env *TraceEnv) ResetForPartialTrace(partialBlk *types.Block) error { + + if env.StorageTrace == nil { + return fmt.Errorf("not init") + } + + // TODO: can we chained the RootBefore / After? + oldStorage := env.StorageTrace + + // only reset which can be reset + env.signer = types.MakeSigner(env.chainConfig, partialBlk.Number()) + env.StorageTrace = &types.StorageTrace{ + RootBefore: oldStorage.RootBefore, + RootAfter: partialBlk.Root(), + Proofs: make(map[string][]hexutil.Bytes), + StorageProofs: make(map[string]map[string][]hexutil.Bytes), + FlattenProofs: make(map[common.Hash]hexutil.Bytes), + AddressHashes: make(map[common.Address]common.Hash), + StoreKeyHashes: make(map[common.Hash]common.Hash), + } + env.Codes = make(map[common.Hash]vm.CodeInfo) + env.ExecutionResults = make([]*types.ExecutionResult, partialBlk.Transactions().Len()) + env.TxStorageTraces = make([]*types.StorageTrace, partialBlk.Transactions().Len()) + + // still need to restore coinbase's proof .... + proof, addrHash, err := env.state.GetFullProof(env.coinbase) + if err == nil { + env.AddressHashes[env.coinbase] = addrHash + env.fillFlattenStorageProof(nil, proof) + env.Proofs[env.coinbase.String()] = types.WrapProof(proof.GetData()) + } + + return nil +} + func (env *TraceEnv) GetBlockTrace(block *types.Block) (*types.BlockTrace, error) { // Execute all the transaction contained within the block concurrently var ( @@ -249,21 +300,32 @@ func (env *TraceEnv) GetBlockTrace(block *types.Block) (*types.BlockTrace, error pend.Wait() // after all tx has been traced, collect "deletion proof" for zktrie + deleteionProofs := make(map[common.Hash]hexutil.Bytes) + for _, tracer := range env.ZkTrieTracer { delProofs, err := tracer.GetDeletionProofs() if err != nil { log.Error("deletion proof failure", "error", err) } else { - for _, proof := range delProofs { + for key, proof := range delProofs { + deleteionProofs[common.BytesToHash(key.Bytes())] = proof env.DeletionProofs = append(env.DeletionProofs, proof) } } } + //TODO: merge deletion proof + for k, v := range deleteionProofs { + env.FlattenProofs[k] = v + } + // build dummy per-tx deletion proof for _, txStorageTrace := range env.TxStorageTraces { if txStorageTrace != nil { txStorageTrace.DeletionProofs = env.DeletionProofs + for k, v := range deleteionProofs { + txStorageTrace.FlattenProofs[k] = v + } } } @@ -372,8 +434,11 @@ func (env *TraceEnv) getTxResult(state *state.StateDB, index int, block *types.B } txStorageTrace := &types.StorageTrace{ - Proofs: make(map[string][]hexutil.Bytes), - StorageProofs: make(map[string]map[string][]hexutil.Bytes), + Proofs: make(map[string][]hexutil.Bytes), + StorageProofs: make(map[string]map[string][]hexutil.Bytes), + FlattenProofs: make(map[common.Hash]hexutil.Bytes), + AddressHashes: make(map[common.Address]common.Hash), + StoreKeyHashes: make(map[common.Hash]common.Hash), } // still we have no state root for per tx, only set the head and tail if index == 0 { @@ -407,13 +472,17 @@ func (env *TraceEnv) getTxResult(state *state.StateDB, index int, block *types.B if existed { continue } - proof, err := state.GetProof(addr) + proof, addrHash, err := state.GetFullProof(addr) if err != nil { log.Error("Proof not available", "address", addrStr, "error", err) // but we still mark the proofs map with nil array } - wrappedProof := types.WrapProof(proof) + wrappedProof := types.WrapProof(proof.GetData()) env.pMu.Lock() + // TODO: + env.fillFlattenStorageProof(txStorageTrace, proof) + txStorageTrace.AddressHashes[addr] = addrHash + env.AddressHashes[addr] = addrHash env.Proofs[addrStr] = wrappedProof txStorageTrace.Proofs[addrStr] = wrappedProof env.pMu.Unlock() @@ -465,18 +534,27 @@ func (env *TraceEnv) getTxResult(state *state.StateDB, index int, block *types.B } env.sMu.Unlock() - var proof [][]byte + var proof proofList + var keyHash common.Hash var err error if zktrieTracer.Available() { - proof, err = state.GetSecureTrieProof(zktrieTracer, key) + proof, keyHash, err = state.GetSecureTrieProof(zktrieTracer, key) } else { - proof, err = state.GetSecureTrieProof(trie, key) + proof, keyHash, err = state.GetSecureTrieProof(trie, key) } if err != nil { log.Error("Storage proof not available", "error", err, "address", addrStr, "key", keyStr) // but we still mark the proofs map with nil array } - wrappedProof := types.WrapProof(proof) + + env.pMu.Lock() + // TODO: + env.fillFlattenStorageProof(txStorageTrace, proof) + txStorageTrace.StoreKeyHashes[key] = keyHash + env.StoreKeyHashes[key] = keyHash + env.pMu.Unlock() + + wrappedProof := types.WrapProof(proof.GetData()) env.sMu.Lock() txm[keyStr] = wrappedProof m[keyStr] = wrappedProof @@ -515,6 +593,18 @@ func (env *TraceEnv) getTxResult(state *state.StateDB, index int, block *types.B return nil } +func (env *TraceEnv) fillFlattenStorageProof(trace *types.StorageTrace, proof proofList) { + for _, i := range proof { + // the "raw key" is in fact a zktrie.Hash (bytes stored with little-endian) + // we need to convert it into big-endian + hash := common.BytesToHash(zktrie.NewHashFromBytes(i.Key)[:]) + env.FlattenProofs[hash] = i.Value + if trace != nil { + trace.FlattenProofs[hash] = i.Value + } + } +} + // fillBlockTrace content after all the txs are finished running. func (env *TraceEnv) fillBlockTrace(block *types.Block) (*types.BlockTrace, error) { defer func(t time.Time) { @@ -543,10 +633,13 @@ func (env *TraceEnv) fillBlockTrace(block *types.Block) (*types.BlockTrace, erro for addr, storages := range intrinsicStorageProofs { if _, existed := env.Proofs[addr.String()]; !existed { - if proof, err := statedb.GetProof(addr); err != nil { + if proof, addrHash, err := statedb.GetFullProof(addr); err != nil { log.Error("Proof for intrinstic address not available", "error", err, "address", addr) } else { - env.Proofs[addr.String()] = types.WrapProof(proof) + // TODO: + env.fillFlattenStorageProof(nil, proof) + env.AddressHashes[addr] = addrHash + env.Proofs[addr.String()] = types.WrapProof(proof.GetData()) } } @@ -558,10 +651,13 @@ func (env *TraceEnv) fillBlockTrace(block *types.Block) (*types.BlockTrace, erro if _, existed := env.StorageProofs[addr.String()][slot.String()]; !existed { if trie, err := statedb.GetStorageTrieForProof(addr); err != nil { log.Error("Storage proof for intrinstic address not available", "error", err, "address", addr) - } else if proof, err := statedb.GetSecureTrieProof(trie, slot); err != nil { + } else if proof, keyHash, err := statedb.GetSecureTrieProof(trie, slot); err != nil { log.Error("Get storage proof for intrinstic address failed", "error", err, "address", addr, "slot", slot) } else { - env.StorageProofs[addr.String()][slot.String()] = types.WrapProof(proof) + // TODO: + env.fillFlattenStorageProof(nil, proof) + env.StoreKeyHashes[slot] = keyHash + env.StorageProofs[addr.String()][slot.String()] = types.WrapProof(proof.GetData()) } } } diff --git a/trie/zk_trie_proof_test.go b/trie/zk_trie_proof_test.go index aec28fde5aad..f494bbf64a88 100644 --- a/trie/zk_trie_proof_test.go +++ b/trie/zk_trie_proof_test.go @@ -243,12 +243,12 @@ func TestProofWithDeletion(t *testing.T) { assert.NoError(t, err) //assert.Equal(t, len(sibling1), len(delTracer.GetProofs())) - siblings, err := proofTracer.GetDeletionProofs() + siblings, err := proofTracer.GetDeletionProofNodes() assert.NoError(t, err) assert.Equal(t, 0, len(siblings)) proofTracer.MarkDeletion(s_key1.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() + siblings, err = proofTracer.GetDeletionProofNodes() assert.NoError(t, err) assert.Equal(t, 1, len(siblings)) l := len(siblings[0]) @@ -259,7 +259,7 @@ func TestProofWithDeletion(t *testing.T) { // Marking a key that is currently not hit (but terminated by an empty node) // also causes it to be added to the deletion proof proofTracer.MarkDeletion(s_key2.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() + siblings, err = proofTracer.GetDeletionProofNodes() assert.NoError(t, err) assert.Equal(t, 2, len(siblings)) @@ -277,12 +277,12 @@ func TestProofWithDeletion(t *testing.T) { assert.NoError(t, err) proofTracer.MarkDeletion(s_key1.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() + siblings, err = proofTracer.GetDeletionProofNodes() assert.NoError(t, err) assert.Equal(t, 1, len(siblings)) proofTracer.MarkDeletion(s_key2.Bytes()) - siblings, err = proofTracer.GetDeletionProofs() + siblings, err = proofTracer.GetDeletionProofNodes() assert.NoError(t, err) assert.Equal(t, 2, len(siblings)) diff --git a/trie/zktrie_deletionproof.go b/trie/zktrie_deletionproof.go index 7ae2a11ff87b..ff221674b772 100644 --- a/trie/zktrie_deletionproof.go +++ b/trie/zktrie_deletionproof.go @@ -51,13 +51,28 @@ func (t *ProofTracer) Merge(another *ProofTracer) *ProofTracer { return t } +// GetDeletionProofNodes extract the value part from deletion proofs +func (t *ProofTracer) GetDeletionProofNodes() ([][]byte, error) { + + retMap, err := t.GetDeletionProofs() + if err != nil { + return nil, err + } + + var ret [][]byte + for _, bt := range retMap { + ret = append(ret, bt) + } + return ret, nil +} + // GetDeletionProofs generate current deletionTracer and collect deletion proofs // which is possible to be used from all rawPaths, which enabling witness generator // to predict the final state root after executing any deletion // along any of the rawpath, no matter of the deletion occurs in any position of the mpt ops // Note the collected sibling node has no key along with it since witness generator would // always decode the node for its purpose -func (t *ProofTracer) GetDeletionProofs() ([][]byte, error) { +func (t *ProofTracer) GetDeletionProofs() (map[zkt.Hash][]byte, error) { retMap := map[zkt.Hash][]byte{} @@ -93,12 +108,7 @@ func (t *ProofTracer) GetDeletionProofs() ([][]byte, error) { } } - var ret [][]byte - for _, bt := range retMap { - ret = append(ret, bt) - } - - return ret, nil + return retMap, nil }