From 78b58dbd8cbd94639ae4699fffd3d3157abb3db9 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 9 Jan 2025 12:54:25 +0200 Subject: [PATCH 1/8] resolve TODOs: add maxSize check for each dfsTrieIterator --- common/interface.go | 2 +- testscommon/state/testTrie.go | 16 ++++++ .../dfsTrieIterator/dfsTrieIterator.go | 8 ++- .../dfsTrieIterator/dfsTrieIterator_test.go | 56 +++++++++++++++---- trie/leavesRetriever/leavesRetriever.go | 24 ++++++-- trie/leavesRetriever/leavesRetriever_test.go | 14 +++++ 6 files changed, 99 insertions(+), 21 deletions(-) diff --git a/common/interface.go b/common/interface.go index 696d4b0182c..efa6b5116fd 100644 --- a/common/interface.go +++ b/common/interface.go @@ -385,7 +385,7 @@ type TrieNodeData interface { // DfsIterator is used to iterate the trie nodes in a depth-first search manner type DfsIterator interface { - GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error) + GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) GetIteratorId() []byte Clone() DfsIterator FinishedIteration() bool diff --git a/testscommon/state/testTrie.go b/testscommon/state/testTrie.go index 8744009aa18..bc33a5e2b6b 100644 --- a/testscommon/state/testTrie.go +++ b/testscommon/state/testTrie.go @@ -53,3 +53,19 @@ func AddDataToTrie(tr common.Trie, numLeaves int) { } _ = tr.Commit() } + +// GetTrieWithData returns a trie with some data. +// The added data builds a rootNode that is a branch with 2 leaves and 1 extension node which will have 4 leaves when traversed; +// this way the size of the iterator will be highest when the extension node is reached but 2 leaves will +// have already been retrieved +func GetTrieWithData() common.Trie { + tr := GetNewTrie() + _ = tr.Update([]byte("key1"), []byte("value1")) + _ = tr.Update([]byte("key2"), []byte("value2")) + _ = tr.Update([]byte("key13"), []byte("value3")) + _ = tr.Update([]byte("key23"), []byte("value4")) + _ = tr.Update([]byte("key33"), []byte("value4")) + _ = tr.Update([]byte("key43"), []byte("value4")) + _ = tr.Commit() + return tr +} diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go index 5b47e2c1dd2..2224416e282 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -53,11 +53,14 @@ func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller ma } // GetLeaves retrieves leaves from the trie. It stops either when the number of leaves is reached or the context is done. -// TODO add a maxSize that will stop the iteration when the size is reached -func (it *dfsIterator) GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error) { +func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) { retrievedLeaves := make(map[string]string) for { nextNodes := make([]common.TrieNodeData, 0) + if it.size >= maxSize { + return retrievedLeaves, nil + } + if len(retrievedLeaves) >= numLeaves { return retrievedLeaves, nil } @@ -140,7 +143,6 @@ func (it *dfsIterator) IsInterfaceNil() bool { return it == nil } -// TODO add context nil test func checkContextDone(ctx context.Context) bool { if ctx == nil { return false diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go index 4489a43a437..b8d71b40173 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go @@ -3,6 +3,7 @@ package dfsTrieIterator import ( "context" "fmt" + "math" "testing" "github.com/multiversx/mx-chain-go/testscommon" @@ -14,6 +15,8 @@ import ( "github.com/stretchr/testify/assert" ) +var maxSize = uint64(math.MaxUint64) + func TestNewIterator(t *testing.T) { t.Parallel() @@ -94,7 +97,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator(rootHash, dbWrapper, marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, ctx) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, ctx) assert.Nil(t, err) assert.Equal(t, expectedNumLeaves, len(trieData)) }) @@ -109,7 +112,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, context.Background()) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, numLeaves, len(trieData)) }) @@ -125,7 +128,22 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(17, context.Background()) + trieData, err := iterator.GetLeaves(17, maxSize, context.Background()) + assert.Nil(t, err) + assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + }) + t.Run("max size reached returns retrieved leaves and saves iterator context", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetTrieWithData() + expectedNumRetrievedLeaves := 2 + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + + iteratorMaxSize := uint64(100) + trieData, err := iterator.GetLeaves(5, iteratorMaxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) }) @@ -142,7 +160,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { numRetrievedLeaves := 0 numIterations := 0 for numRetrievedLeaves < numLeaves { - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -152,6 +170,22 @@ func TestDfsIterator_GetLeaves(t *testing.T) { assert.Equal(t, numLeaves, numRetrievedLeaves) assert.Equal(t, 5, numIterations) }) + t.Run("retrieve leaves with nil iterator does not panic", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + expectedNumRetrievedLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(numLeaves, maxSize, nil) + assert.Nil(t, err) + assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + }) } func TestDfsIterator_GetIteratorId(t *testing.T) { @@ -169,7 +203,7 @@ func TestDfsIterator_GetIteratorId(t *testing.T) { iteratorId := hasher.Compute(string(append(rootHash, iterator.nextNodes[0].GetData()...))) assert.Equal(t, iteratorId, iterator.GetIteratorId()) - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -211,7 +245,7 @@ func TestDfsIterator_FinishedIteration(t *testing.T) { numRetrievedLeaves := 0 for numRetrievedLeaves < numLeaves { assert.False(t, iterator.FinishedIteration()) - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -237,23 +271,23 @@ func TestDfsIterator_Size(t *testing.T) { iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) assert.Equal(t, uint64(362), iterator.Size()) // 10 branch nodes + 1 root hash - _, err := iterator.GetLeaves(5, context.Background()) + _, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(331), iterator.Size()) // 8 branch nodes + 1 leaf node + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(300), iterator.Size()) // 6 branch nodes + 2 leaf node + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(197), iterator.Size()) // 5 branch nodes + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(133), iterator.Size()) // 2 branch nodes + 1 leaf node + 1 root hash - _, err = iterator.GetLeaves(5, context.Background()) + _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) assert.Equal(t, uint64(32), iterator.Size()) // 1 root hash } diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go index 89a11569bc0..5630a3ce7e1 100644 --- a/trie/leavesRetriever/leavesRetriever.go +++ b/trie/leavesRetriever/leavesRetriever.go @@ -78,7 +78,7 @@ func (lr *leavesRetriever) getLeavesFromCheckpoint(numLeaves int, iterator commo } func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, numLeaves int, ctx context.Context) (map[string]string, []byte, error) { - leaves, err := iterator.GetLeaves(numLeaves, ctx) + leaves, err := iterator.GetLeaves(numLeaves, lr.maxSize, ctx) if err != nil { return nil, nil, err } @@ -92,27 +92,39 @@ func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, nu return leaves, nil, nil } - lr.manageIterators(iteratorId, iterator) + shouldReturnId := lr.manageIterators(iteratorId, iterator) + if !shouldReturnId { + return leaves, nil, nil + } return leaves, iteratorId, nil } -func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) { +func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) bool { lr.mutex.Lock() defer lr.mutex.Unlock() - lr.saveIterator(iteratorId, iterator) + newIteratorPresent := lr.saveIterator(iteratorId, iterator) + if !newIteratorPresent { + return false + } lr.removeIteratorsIfMaxSizeIsExceeded() + return true } -func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) { +func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) bool { _, isPresent := lr.iterators[string(iteratorId)] if isPresent { - return + return true + } + + if iterator.Size() >= lr.maxSize { + return false } lr.lruIteratorIDs = append(lr.lruIteratorIDs, iteratorId) lr.iterators[string(iteratorId)] = iterator lr.size += iterator.Size() + uint64(len(iteratorId)) + return true } func (lr *leavesRetriever) markIteratorAsRecentlyUsed(iteratorId []byte) { diff --git a/trie/leavesRetriever/leavesRetriever_test.go b/trie/leavesRetriever/leavesRetriever_test.go index 28dd6131475..1605aaf6fc4 100644 --- a/trie/leavesRetriever/leavesRetriever_test.go +++ b/trie/leavesRetriever/leavesRetriever_test.go @@ -202,6 +202,20 @@ func TestLeavesRetriever_GetLeaves(t *testing.T) { assert.Equal(t, 0, len(id)) assert.Equal(t, leavesRetriever.ErrIteratorNotFound, err) }) + t.Run("max size reached on the first iteration", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetTrieWithData() + rootHash, _ := tr.RootHash() + maxSize := uint64(100) + + lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) + leaves, id1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 2, len(leaves)) + assert.Equal(t, 0, len(id1)) + assert.Equal(t, 0, len(lr.GetIterators())) + }) } func TestLeavesRetriever_Concurrency(t *testing.T) { From 34cae3298b707683f52e149afc0b8ec472eb97d9 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 9 Jan 2025 15:39:29 +0200 Subject: [PATCH 2/8] add the trieLeavesRetriever to the stateComponents --- cmd/node/config/config.toml | 4 ++++ config/config.go | 19 +++++++++++------ errors/errors.go | 3 +++ factory/interface.go | 1 + factory/mock/stateComponentsHolderStub.go | 9 ++++++++ factory/state/stateComponents.go | 21 +++++++++++++++++++ factory/state/stateComponentsHandler.go | 15 +++++++++++++ .../components/stateComponents.go | 7 +++++++ testscommon/factory/stateComponentsMock.go | 7 +++++++ .../disabledLeavesRetriever.go | 20 ++++++++++++++++++ 10 files changed, 100 insertions(+), 6 deletions(-) create mode 100644 trie/leavesRetriever/disabledLeavesRetriever.go diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 6e1205d5f7e..7e40d31dbd8 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -663,6 +663,10 @@ MaxPeerTrieLevelInMemory = 5 StateStatisticsEnabled = false +[TrieLeavesRetrieverConfig] + Enabled = false + MaxSizeInBytes = 104857600 #100MB + [BlockSizeThrottleConfig] MinSizeInBytes = 104857 # 104857 is 10% from 1MB MaxSizeInBytes = 943718 # 943718 is 90% from 1MB diff --git a/config/config.go b/config/config.go index 49ef257c341..a9956dec5b2 100644 --- a/config/config.go +++ b/config/config.go @@ -161,12 +161,13 @@ type Config struct { BootstrapStorage StorageConfig MetaBlockStorage StorageConfig - AccountsTrieStorage StorageConfig - PeerAccountsTrieStorage StorageConfig - EvictionWaitingList EvictionWaitingListConfig - StateTriesConfig StateTriesConfig - TrieStorageManagerConfig TrieStorageManagerConfig - BadBlocksCache CacheConfig + AccountsTrieStorage StorageConfig + PeerAccountsTrieStorage StorageConfig + EvictionWaitingList EvictionWaitingListConfig + StateTriesConfig StateTriesConfig + TrieStorageManagerConfig TrieStorageManagerConfig + TrieLeavesRetrieverConfig TrieLeavesRetrieverConfig + BadBlocksCache CacheConfig TxBlockBodyDataPool CacheConfig PeerBlockBodyDataPool CacheConfig @@ -640,3 +641,9 @@ type PoolsCleanersConfig struct { type RedundancyConfig struct { MaxRoundsOfInactivityAccepted int } + +// TrieLeavesRetrieverConfig represents the config options to be used when setting up the trie leaves retriever +type TrieLeavesRetrieverConfig struct { + Enabled bool + MaxSizeInBytes uint64 +} diff --git a/errors/errors.go b/errors/errors.go index dd475327876..8071fffc219 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -598,3 +598,6 @@ var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") // ErrNilEpochSystemSCProcessor defines the error for setting a nil EpochSystemSCProcessor var ErrNilEpochSystemSCProcessor = errors.New("nil epoch system SC processor") + +// ErrNilTrieLeavesRetriever defines the error for setting a nil TrieLeavesRetriever +var ErrNilTrieLeavesRetriever = errors.New("nil trie leaves retriever") diff --git a/factory/interface.go b/factory/interface.go index 0f1c237d0d9..a452d640320 100644 --- a/factory/interface.go +++ b/factory/interface.go @@ -335,6 +335,7 @@ type StateComponentsHolder interface { TriesContainer() common.TriesHolder TrieStorageManagers() map[string]common.StorageManager MissingTrieNodesNotifier() common.MissingTrieNodesNotifier + TrieLeavesRetriever() common.TrieLeavesRetriever Close() error IsInterfaceNil() bool } diff --git a/factory/mock/stateComponentsHolderStub.go b/factory/mock/stateComponentsHolderStub.go index c851fdc6dac..e6b6b6b86cf 100644 --- a/factory/mock/stateComponentsHolderStub.go +++ b/factory/mock/stateComponentsHolderStub.go @@ -14,6 +14,7 @@ type StateComponentsHolderStub struct { TriesContainerCalled func() common.TriesHolder TrieStorageManagersCalled func() map[string]common.StorageManager MissingTrieNodesNotifierCalled func() common.MissingTrieNodesNotifier + TrieLeavesRetrieverCalled func() common.TrieLeavesRetriever } // PeerAccounts - @@ -79,6 +80,14 @@ func (s *StateComponentsHolderStub) MissingTrieNodesNotifier() common.MissingTri return nil } +// TrieLeavesRetriever - +func (s *StateComponentsHolderStub) TrieLeavesRetriever() common.TrieLeavesRetriever { + if s.TrieLeavesRetrieverCalled != nil { + return s.TrieLeavesRetrieverCalled() + } + return nil +} + // Close - func (s *StateComponentsHolderStub) Close() error { return nil diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index 8da3251e230..e09aae7b1c9 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/state/syncer" trieFactory "github.com/multiversx/mx-chain-go/trie/factory" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever" ) // TODO: merge this with data components @@ -53,6 +54,7 @@ type stateComponents struct { triesContainer common.TriesHolder trieStorageManagers map[string]common.StorageManager missingTrieNodesNotifier common.MissingTrieNodesNotifier + trieLeavesRetriever common.TrieLeavesRetriever } // NewStateComponentsFactory will return a new instance of stateComponentsFactory @@ -100,6 +102,11 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { return nil, err } + trieLeavesRetriever, err := scf.createTrieLeavesRetriever(trieStorageManagers[dataRetriever.UserAccountsUnit.String()]) + if err != nil { + return nil, err + } + return &stateComponents{ peerAccounts: peerAdapter, accountsAdapter: accountsAdapter, @@ -108,9 +115,23 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { triesContainer: triesContainer, trieStorageManagers: trieStorageManagers, missingTrieNodesNotifier: syncer.NewMissingTrieNodesNotifier(), + trieLeavesRetriever: trieLeavesRetriever, }, nil } +func (scf *stateComponentsFactory) createTrieLeavesRetriever(trieStorage common.TrieStorageInteractor) (common.TrieLeavesRetriever, error) { + if !scf.config.TrieLeavesRetrieverConfig.Enabled { + return leavesRetriever.NewDisabledLeavesRetriever(), nil + } + + return leavesRetriever.NewLeavesRetriever( + trieStorage, + scf.core.InternalMarshalizer(), + scf.core.Hasher(), + scf.config.TrieLeavesRetrieverConfig.MaxSizeInBytes, + ) +} + func (scf *stateComponentsFactory) createSnapshotManager( accountFactory state.AccountFactory, stateMetrics state.StateMetrics, diff --git a/factory/state/stateComponentsHandler.go b/factory/state/stateComponentsHandler.go index 78271a28ffe..e84c1f8b3b5 100644 --- a/factory/state/stateComponentsHandler.go +++ b/factory/state/stateComponentsHandler.go @@ -93,6 +93,9 @@ func (msc *managedStateComponents) CheckSubcomponents() error { if check.IfNil(msc.missingTrieNodesNotifier) { return errors.ErrNilMissingTrieNodesNotifier } + if check.IfNil(msc.trieLeavesRetriever) { + return errors.ErrNilTrieLeavesRetriever + } return nil } @@ -214,6 +217,18 @@ func (msc *managedStateComponents) MissingTrieNodesNotifier() common.MissingTrie return msc.stateComponents.missingTrieNodesNotifier } +// TrieLeavesRetriever returns the trie leaves retriever +func (msc *managedStateComponents) TrieLeavesRetriever() common.TrieLeavesRetriever { + msc.mutStateComponents.RLock() + defer msc.mutStateComponents.RUnlock() + + if msc.stateComponents == nil { + return nil + } + + return msc.stateComponents.trieLeavesRetriever +} + // IsInterfaceNil returns true if the interface is nil func (msc *managedStateComponents) IsInterfaceNil() bool { return msc == nil diff --git a/node/chainSimulator/components/stateComponents.go b/node/chainSimulator/components/stateComponents.go index b3fddf55f40..998263a8d7a 100644 --- a/node/chainSimulator/components/stateComponents.go +++ b/node/chainSimulator/components/stateComponents.go @@ -29,6 +29,7 @@ type stateComponentsHolder struct { triesContainer common.TriesHolder triesStorageManager map[string]common.StorageManager missingTrieNodesNotifier common.MissingTrieNodesNotifier + trieLeavesRetriever common.TrieLeavesRetriever stateComponentsCloser io.Closer } @@ -70,6 +71,7 @@ func CreateStateComponents(args ArgsStateComponents) (*stateComponentsHolder, er triesContainer: stateComp.TriesContainer(), triesStorageManager: stateComp.TrieStorageManagers(), missingTrieNodesNotifier: stateComp.MissingTrieNodesNotifier(), + trieLeavesRetriever: stateComp.TrieLeavesRetriever(), stateComponentsCloser: stateComp, }, nil } @@ -109,6 +111,11 @@ func (s *stateComponentsHolder) MissingTrieNodesNotifier() common.MissingTrieNod return s.missingTrieNodesNotifier } +// TrieLeavesRetriever will return the trie leaves retriever +func (s *stateComponentsHolder) TrieLeavesRetriever() common.TrieLeavesRetriever { + return s.trieLeavesRetriever +} + // Close will close the state components func (s *stateComponentsHolder) Close() error { return s.stateComponentsCloser.Close() diff --git a/testscommon/factory/stateComponentsMock.go b/testscommon/factory/stateComponentsMock.go index 5aa541dffa0..0adb3f3bc10 100644 --- a/testscommon/factory/stateComponentsMock.go +++ b/testscommon/factory/stateComponentsMock.go @@ -16,6 +16,7 @@ type StateComponentsMock struct { Tries common.TriesHolder StorageManagers map[string]common.StorageManager MissingNodesNotifier common.MissingTrieNodesNotifier + LeavesRetriever common.TrieLeavesRetriever } // NewStateComponentsMockFromRealComponent - @@ -28,6 +29,7 @@ func NewStateComponentsMockFromRealComponent(stateComponents factory.StateCompon Tries: stateComponents.TriesContainer(), StorageManagers: stateComponents.TrieStorageManagers(), MissingNodesNotifier: stateComponents.MissingTrieNodesNotifier(), + LeavesRetriever: stateComponents.TrieLeavesRetriever(), } } @@ -89,6 +91,11 @@ func (scm *StateComponentsMock) MissingTrieNodesNotifier() common.MissingTrieNod return scm.MissingNodesNotifier } +// TrieLeavesRetriever - +func (scm *StateComponentsMock) TrieLeavesRetriever() common.TrieLeavesRetriever { + return scm.LeavesRetriever +} + // IsInterfaceNil - func (scm *StateComponentsMock) IsInterfaceNil() bool { return scm == nil diff --git a/trie/leavesRetriever/disabledLeavesRetriever.go b/trie/leavesRetriever/disabledLeavesRetriever.go new file mode 100644 index 00000000000..b3143e377ff --- /dev/null +++ b/trie/leavesRetriever/disabledLeavesRetriever.go @@ -0,0 +1,20 @@ +package leavesRetriever + +import "context" + +type disabledLeavesRetriever struct{} + +// NewDisabledLeavesRetriever creates a new disabled leaves retriever +func NewDisabledLeavesRetriever() *disabledLeavesRetriever { + return &disabledLeavesRetriever{} +} + +// GetLeaves returns an empty map and a nil byte slice for this implementation +func (dlr *disabledLeavesRetriever) GetLeaves(_ int, _ []byte, _ []byte, _ context.Context) (map[string]string, []byte, error) { + return make(map[string]string), []byte{}, nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (dlr *disabledLeavesRetriever) IsInterfaceNil() bool { + return dlr == nil +} From 8bb0a039c2b6c87633c8ea49a274daf931583800 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Thu, 9 Jan 2025 16:58:08 +0200 Subject: [PATCH 3/8] add new API endpoint which uses the leavesRetriever --- api/errors/errors.go | 9 ++++ api/groups/addressGroup.go | 51 ++++++++++++++++++- api/mock/facadeStub.go | 10 ++++ api/shared/interface.go | 1 + facade/initial/initialNodeFacade.go | 5 ++ facade/interface.go | 3 ++ facade/mock/nodeStub.go | 10 ++++ facade/nodeFacade.go | 9 ++++ integrationTests/interface.go | 1 + node/node.go | 34 +++++++++++++ .../dfsTrieIterator/dfsTrieIterator.go | 3 +- 11 files changed, 134 insertions(+), 2 deletions(-) diff --git a/api/errors/errors.go b/api/errors/errors.go index 3f4e495b9d2..104413cb682 100644 --- a/api/errors/errors.go +++ b/api/errors/errors.go @@ -28,6 +28,9 @@ var ErrGetValueForKey = errors.New("get value for key error") // ErrGetKeyValuePairs signals an error in getting the key-value pairs of a key for an account var ErrGetKeyValuePairs = errors.New("get key-value pairs error") +// ErrGetKeyValuePairsWithCheckpoint signals an error in getting the key-value pairs of a key for an account with a checkpoint +var ErrGetKeyValuePairsWithCheckpoint = errors.New("get key-value pairs with checkpoint error") + // ErrGetESDTBalance signals an error in getting esdt balance for given address var ErrGetESDTBalance = errors.New("get esdt balance for account error") @@ -43,6 +46,12 @@ var ErrGetESDTNFTData = errors.New("get esdt nft data for account error") // ErrEmptyAddress signals that an empty address was provided var ErrEmptyAddress = errors.New("address is empty") +// ErrEmptyNumKeys signals that an empty numKeys was provided +var ErrEmptyNumKeys = errors.New("numKeys is empty") + +// ErrEmptyCheckpointId signals that an empty checkpointId was provided +var ErrEmptyCheckpointId = errors.New("checkpointId is empty") + // ErrEmptyKey signals that an empty key was provided var ErrEmptyKey = errors.New("key is empty") diff --git a/api/groups/addressGroup.go b/api/groups/addressGroup.go index 151b7f53372..a91a62f6757 100644 --- a/api/groups/addressGroup.go +++ b/api/groups/addressGroup.go @@ -5,6 +5,7 @@ import ( "fmt" "math/big" "net/http" + "strconv" "sync" "github.com/gin-gonic/gin" @@ -23,6 +24,7 @@ const ( getUsernamePath = "/:address/username" getCodeHashPath = "/:address/code-hash" getKeysPath = "/:address/keys" + getKeysWithCheckpointPath = "/:address/num-keys/:numKeys/checkpoint-id/:checkpointId" getKeyPath = "/:address/key/:key" getDataTrieMigrationStatusPath = "/:address/is-data-trie-migrated" getESDTTokensPath = "/:address/esdt" @@ -55,6 +57,7 @@ type addressFacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) IsInterfaceNil() bool @@ -134,6 +137,11 @@ func NewAddressGroup(facade addressFacadeHandler) (*addressGroup, error) { Method: http.MethodGet, Handler: ag.getKeyValuePairs, }, + { + Path: getKeysWithCheckpointPath, + Method: http.MethodGet, + Handler: ag.getKeyValuePairsWithCheckpoint, + }, { Path: getESDTBalancePath, Method: http.MethodGet, @@ -327,7 +335,7 @@ func (ag *addressGroup) getGuardianData(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"guardianData": guardianData, "blockInfo": blockInfo}) } -// addressGroup returns all the key-value pairs for the given address +// getKeyValuePairs returns all the key-value pairs for the given address func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { addr, options, err := extractBaseParams(c) if err != nil { @@ -344,6 +352,47 @@ func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"pairs": value, "blockInfo": blockInfo}) } +// getKeysWithCheckpoint returns all the key-value pairs for the given address +func (ag *addressGroup) getKeyValuePairsWithCheckpoint(c *gin.Context) { + addr := c.Param("address") + if addr == "" { + shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyAddress) + return + } + + options, err := extractAccountQueryOptions(c) + if err != nil { + shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) + return + } + + numLeavesAsString := c.Param("num-keys") + if numLeavesAsString == "" { + shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyNumKeys) + return + } + + numLeaves, err := strconv.Atoi(numLeavesAsString) + if err != nil { + shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) + return + } + + checkpointId := c.Param("checkpoint-id") + if checkpointId == "" { + shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyCheckpointId) + return + } + + value, blockInfo, newCheckpointId, err := ag.getFacade().GetKeyValuePairsWithCheckpoint(addr, checkpointId, numLeaves, options) + if err != nil { + shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairs, err) + return + } + + shared.RespondWithSuccess(c, gin.H{"pairs": value, "newCheckpointId": newCheckpointId, "blockInfo": blockInfo}) +} + // getESDTBalance returns the balance for the given address and esdt token func (ag *addressGroup) getESDTBalance(c *gin.Context) { addr, tokenIdentifier, options, err := extractGetESDTBalanceParams(c) diff --git a/api/mock/facadeStub.go b/api/mock/facadeStub.go index 62de2febc81..c471ccf21c2 100644 --- a/api/mock/facadeStub.go +++ b/api/mock/facadeStub.go @@ -49,6 +49,7 @@ type FacadeStub struct { GetUsernameCalled func(address string, options api.AccountQueryOptions) (string, api.BlockInfo, error) GetCodeHashCalled func(address string, options api.AccountQueryOptions) ([]byte, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + GetKeyValuePairsWithCheckpointCalled func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) SimulateTransactionExecutionHandler func(tx *transaction.Transaction) (*txSimData.SimulationResultsWithVMOutput, error) GetESDTDataCalled func(address string, key string, nonce uint64, options api.AccountQueryOptions) (*esdt.ESDigitalToken, api.BlockInfo, error) GetAllESDTTokensCalled func(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) @@ -241,6 +242,15 @@ func (f *FacadeStub) GetKeyValuePairs(address string, options api.AccountQueryOp return nil, api.BlockInfo{}, nil } +// GetKeyValuePairsWithCheckpoint - +func (f *FacadeStub) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { + if f.GetKeyValuePairsWithCheckpointCalled != nil { + return f.GetKeyValuePairsWithCheckpointCalled(address, checkpointId, numLeaves, options) + } + + return nil, api.BlockInfo{}, "", nil +} + // GetGuardianData - func (f *FacadeStub) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) { if f.GetGuardianDataCalled != nil { diff --git a/api/shared/interface.go b/api/shared/interface.go index 206cea6ee30..56a1cc70e19 100644 --- a/api/shared/interface.go +++ b/api/shared/interface.go @@ -74,6 +74,7 @@ type FacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*api.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*api.Block, error) diff --git a/facade/initial/initialNodeFacade.go b/facade/initial/initialNodeFacade.go index d6043dbcd62..626f77db816 100644 --- a/facade/initial/initialNodeFacade.go +++ b/facade/initial/initialNodeFacade.go @@ -346,6 +346,11 @@ func (inf *initialNodeFacade) GetKeyValuePairs(_ string, _ api.AccountQueryOptio return nil, api.BlockInfo{}, errNodeStarting } +// GetKeyValuePairsWithCheckpoint returns error +func (inf *initialNodeFacade) GetKeyValuePairsWithCheckpoint(_ string, _ string, _ int, _ api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { + return nil, api.BlockInfo{}, "", errNodeStarting +} + // GetGuardianData returns error func (inf *initialNodeFacade) GetGuardianData(_ string, _ api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) { return api.GuardianData{}, api.BlockInfo{}, errNodeStarting diff --git a/facade/interface.go b/facade/interface.go index 309f6c98d6f..413389cb1be 100644 --- a/facade/interface.go +++ b/facade/interface.go @@ -41,6 +41,9 @@ type NodeHandler interface { // GetKeyValuePairs returns the key-value pairs under a given address GetKeyValuePairs(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) + // GetKeyValuePairsWithCheckpoint returns the key-value pairs under a given address with a checkpoint + GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) + // GetAllIssuedESDTs returns all the issued esdt tokens from esdt system smart contract GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]string, error) diff --git a/facade/mock/nodeStub.go b/facade/mock/nodeStub.go index 1e779e0ebce..a9f289a6ff8 100644 --- a/facade/mock/nodeStub.go +++ b/facade/mock/nodeStub.go @@ -49,6 +49,7 @@ type NodeStub struct { GetESDTsWithRoleCalled func(address string, role string, options api.AccountQueryOptions, ctx context.Context) ([]string, api.BlockInfo, error) GetESDTsRolesCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string][]string, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) + GetKeyValuePairsWithCheckpointCalled func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) GetAllIssuedESDTsCalled func(tokenType string, ctx context.Context) ([]string, error) GetProofCalled func(rootHash string, key string) (*common.GetProofResponse, error) GetProofDataTrieCalled func(rootHash string, address string, key string) (*common.GetProofResponse, *common.GetProofResponse, error) @@ -112,6 +113,15 @@ func (ns *NodeStub) GetKeyValuePairs(address string, options api.AccountQueryOpt return nil, api.BlockInfo{}, nil } +// GetKeyValuePairsWithCheckpoint - +func (ns *NodeStub) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) { + if ns.GetKeyValuePairsWithCheckpointCalled != nil { + return ns.GetKeyValuePairsWithCheckpointCalled(address, checkpointId, numLeaves, options, ctx) + } + + return nil, api.BlockInfo{}, "", nil +} + // GetValueForKey - func (ns *NodeStub) GetValueForKey(address string, key string, options api.AccountQueryOptions) (string, api.BlockInfo, error) { if ns.GetValueForKeyCalled != nil { diff --git a/facade/nodeFacade.go b/facade/nodeFacade.go index c3a7f290edf..89ab8c99813 100644 --- a/facade/nodeFacade.go +++ b/facade/nodeFacade.go @@ -229,6 +229,15 @@ func (nf *nodeFacade) GetKeyValuePairs(address string, options apiData.AccountQu return nf.node.GetKeyValuePairs(address, options, ctx) } +// GetKeyValuePairsWithCheckpoint returns the given number of key-value pairs under the provided address. +// The iteration starts from the given checkpoint, and returns a new checkpoint for the next iteration. +func (nf *nodeFacade) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options apiData.AccountQueryOptions) (map[string]string, apiData.BlockInfo, string, error) { + ctx, cancel := nf.getContextForApiTrieRangeOperations() + defer cancel() + + return nf.node.GetKeyValuePairsWithCheckpoint(address, checkpointId, numLeaves, options, ctx) +} + // GetGuardianData returns the guardian data for the provided address func (nf *nodeFacade) GetGuardianData(address string, options apiData.AccountQueryOptions) (apiData.GuardianData, apiData.BlockInfo, error) { return nf.node.GetGuardianData(address, options) diff --git a/integrationTests/interface.go b/integrationTests/interface.go index ad90ffbb6a3..c63438fbc15 100644 --- a/integrationTests/interface.go +++ b/integrationTests/interface.go @@ -69,6 +69,7 @@ type Facade interface { GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetESDTsRoles(address string, options api.AccountQueryOptions) (map[string][]string, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*dataApi.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*dataApi.Block, error) diff --git a/node/node.go b/node/node.go index a652e80be60..72731c5bf78 100644 --- a/node/node.go +++ b/node/node.go @@ -308,6 +308,40 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, return mapToReturn, blockInfo, nil } +// GetKeyValuePairsWithCheckpoint returns the given number of key-value pairs under the provided address. +// The iteration starts from the given checkpoint, and returns a new checkpoint for the next iteration. +func (n *Node) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) { + userAccount, blockInfo, err := n.loadUserAccountHandlerByAddress(address, options) + if err != nil { + adaptedBlockInfo, isEmptyAccount := extractBlockInfoIfNewAccount(err) + if isEmptyAccount { + return make(map[string]string), adaptedBlockInfo, "", nil + } + + return nil, api.BlockInfo{}, "", err + } + + if check.IfNil(userAccount.DataTrie()) { + return map[string]string{}, blockInfo, "", nil + } + + checkpointIdBytes, err := hex.DecodeString(checkpointId) + if err != nil { + return nil, api.BlockInfo{}, "", fmt.Errorf("invalid checkpointId: %w", err) + } + + mapToReturn, newCheckpoint, err := n.stateComponents.TrieLeavesRetriever().GetLeaves(numLeaves, userAccount.GetRootHash(), checkpointIdBytes, ctx) + if err != nil { + return nil, api.BlockInfo{}, "", err + } + + if common.IsContextDone(ctx) { + return nil, api.BlockInfo{}, "", ErrTrieOperationsTimeout + } + + return mapToReturn, blockInfo, hex.EncodeToString(newCheckpoint), nil +} + func (n *Node) getKeys(userAccount state.UserAccountHandler, ctx context.Context) (map[string]string, error) { chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go index 2224416e282..7932e9a2ce4 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -2,6 +2,7 @@ package dfsTrieIterator import ( "context" + "encoding/hex" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" @@ -88,7 +89,7 @@ func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Cont return nil, err } - retrievedLeaves[string(key)] = string(childNode.GetData()) + retrievedLeaves[hex.EncodeToString(key)] = hex.EncodeToString(childNode.GetData()) continue } From 54b255eb9c2be9d88c509dcdeefb51cc181b9a9b Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 13 Jan 2025 10:41:50 +0200 Subject: [PATCH 4/8] add logging --- cmd/node/config/api.toml | 3 +++ trie/leavesRetriever/leavesRetriever.go | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/cmd/node/config/api.toml b/cmd/node/config/api.toml index fcf9cf7fc0b..af7ccd8a877 100644 --- a/cmd/node/config/api.toml +++ b/cmd/node/config/api.toml @@ -79,6 +79,9 @@ # /address/:address/keys will return all the key-value pairs of a given account { Name = "/:address/keys", Open = true }, + # address/:address/num-keys/:numKeys/checkpoint-id/:checkpointId will return the given num of key-value pairs for the given account + { Name = "/:address/num-keys/:numKeys/checkpoint-id/:checkpointId", Open = true }, + # /address/:address/key/:key will return the value of a key for a given account { Name = "/:address/key/:key", Open = true }, diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go index 5630a3ce7e1..544f9a0ff79 100644 --- a/trie/leavesRetriever/leavesRetriever.go +++ b/trie/leavesRetriever/leavesRetriever.go @@ -9,8 +9,11 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/trie/leavesRetriever/dfsTrieIterator" + logger "github.com/multiversx/mx-chain-logger-go" ) +var log = logger.GetOrCreate("trie/leavesRetriever") + type leavesRetriever struct { iterators map[string]common.DfsIterator lruIteratorIDs [][]byte @@ -47,6 +50,7 @@ func NewLeavesRetriever(db common.TrieStorageInteractor, marshaller marshal.Mars // GetLeaves retrieves the leaves from the trie. If there is a saved checkpoint for the iterator id, it will continue to iterate from the checkpoint. func (lr *leavesRetriever) GetLeaves(numLeaves int, rootHash []byte, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) { + defer log.Trace("leaves retriever stats", "size", lr.size, "numIterators", len(lr.iterators)) if len(iteratorID) == 0 { return lr.getLeavesFromNewInstance(numLeaves, rootHash, ctx) } @@ -62,6 +66,7 @@ func (lr *leavesRetriever) GetLeaves(numLeaves int, rootHash []byte, iteratorID } func (lr *leavesRetriever) getLeavesFromNewInstance(numLeaves int, rootHash []byte, ctx context.Context) (map[string]string, []byte, error) { + log.Trace("get leaves from new instance", "numLeaves", numLeaves, "rootHash", rootHash) iterator, err := dfsTrieIterator.NewIterator(rootHash, lr.db, lr.marshaller, lr.hasher) if err != nil { return nil, nil, err @@ -71,6 +76,7 @@ func (lr *leavesRetriever) getLeavesFromNewInstance(numLeaves int, rootHash []by } func (lr *leavesRetriever) getLeavesFromCheckpoint(numLeaves int, iterator common.DfsIterator, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) { + log.Trace("get leaves from checkpoint", "numLeaves", numLeaves, "iteratorID", iteratorID) lr.markIteratorAsRecentlyUsed(iteratorID) clonedIterator := iterator.Clone() From 143cc00a721d7a515836d18a400cfee9e5ef8d72 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 13 Jan 2025 14:04:49 +0200 Subject: [PATCH 5/8] added unit tests and small fixes --- api/groups/addressGroup.go | 19 +++----- api/groups/addressGroup_test.go | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 12 deletions(-) diff --git a/api/groups/addressGroup.go b/api/groups/addressGroup.go index a91a62f6757..3e577810adc 100644 --- a/api/groups/addressGroup.go +++ b/api/groups/addressGroup.go @@ -356,37 +356,32 @@ func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { func (ag *addressGroup) getKeyValuePairsWithCheckpoint(c *gin.Context) { addr := c.Param("address") if addr == "" { - shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyAddress) + shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyAddress) return } options, err := extractAccountQueryOptions(c) if err != nil { - shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) + shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) return } - numLeavesAsString := c.Param("num-keys") + numLeavesAsString := c.Param("numKeys") if numLeavesAsString == "" { - shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyNumKeys) + shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyNumKeys) return } numLeaves, err := strconv.Atoi(numLeavesAsString) if err != nil { - shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) - return - } - - checkpointId := c.Param("checkpoint-id") - if checkpointId == "" { - shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyCheckpointId) + shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) return } + checkpointId := c.Param("checkpointId") value, blockInfo, newCheckpointId, err := ag.getFacade().GetKeyValuePairsWithCheckpoint(addr, checkpointId, numLeaves, options) if err != nil { - shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairs, err) + shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) return } diff --git a/api/groups/addressGroup_test.go b/api/groups/addressGroup_test.go index bb19bb81d2c..8e0b8d8bc61 100644 --- a/api/groups/addressGroup_test.go +++ b/api/groups/addressGroup_test.go @@ -125,6 +125,16 @@ type keyValuePairsResponse struct { Code string } +type keyValuePairsWithCheckpointResponseData struct { + Pairs map[string]string `json:"pairs"` + NewCheckpointId string `json:"newCheckpointId"` +} +type keyValuePairsWithCheckpointResponse struct { + Data keyValuePairsWithCheckpointResponseData `json:"data"` + Error string `json:"error"` + Code string +} + type esdtRolesResponseData struct { Roles map[string][]string `json:"roles"` } @@ -662,6 +672,73 @@ func TestAddressGroup_getKeyValuePairs(t *testing.T) { }) } +func TestAddressGroup_getKeyValuePairsWithCheckpoint(t *testing.T) { + t.Parallel() + + t.Run("empty address should error", + testErrorScenario("/address//num-keys/10/checkpoint-id/abc", "GET", nil, + formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, apiErrors.ErrEmptyAddress))) + t.Run("invalid query options should error", + testErrorScenario("/address/erd1alice/num-keys/10/checkpoint-id/abc?blockNonce=not-uint64", "GET", nil, + formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, apiErrors.ErrBadUrlParams))) + t.Run("empty num-keys should error", + testErrorScenario("/address/erd1alice/num-keys//checkpoint-id/abc", "GET", nil, + formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, apiErrors.ErrEmptyNumKeys))) + t.Run("invalid num-keys should error", + testErrorScenario("/address/erd1alice/num-keys/not-uint64/checkpoint-id/abc", "GET", nil, + formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, errors.New("strconv.Atoi: parsing \"not-uint64\": invalid syntax")))) + t.Run("with node fail should err", func(t *testing.T) { + t.Parallel() + + facade := &mock.FacadeStub{ + GetKeyValuePairsWithCheckpointCalled: func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { + return nil, api.BlockInfo{}, "", expectedErr + }, + } + testAddressGroup( + t, + facade, + "/address/erd1alice/num-keys/10/checkpoint-id/abc", + "GET", + nil, + http.StatusInternalServerError, + formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, expectedErr), + ) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + pairs := map[string]string{ + "k1": "v1", + "k2": "v2", + } + originalCheckpointId := "abc" + newCheckpointId := "def" + numKeys := "10" + addr := "erd1alice" + facade := &mock.FacadeStub{ + GetKeyValuePairsWithCheckpointCalled: func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { + assert.Equal(t, addr, address) + assert.Equal(t, 10, numLeaves) + assert.Equal(t, originalCheckpointId, checkpointId) + return pairs, api.BlockInfo{}, newCheckpointId, nil + }, + } + + response := &keyValuePairsWithCheckpointResponse{} + loadAddressGroupResponse( + t, + facade, + "/address/"+addr+"/num-keys/"+numKeys+"/checkpoint-id/"+originalCheckpointId, + "GET", + nil, + response, + ) + assert.Equal(t, pairs, response.Data.Pairs) + assert.Equal(t, newCheckpointId, response.Data.NewCheckpointId) + }) +} + func TestAddressGroup_getESDTBalance(t *testing.T) { t.Parallel() @@ -1143,6 +1220,7 @@ func getAddressRoutesConfig() config.ApiRoutesConfig { {Name: "/:address/username", Open: true}, {Name: "/:address/code-hash", Open: true}, {Name: "/:address/keys", Open: true}, + {Name: "/:address/num-keys/:numKeys/checkpoint-id/:checkpointId", Open: true}, {Name: "/:address/key/:key", Open: true}, {Name: "/:address/esdt", Open: true}, {Name: "/:address/esdts/roles", Open: true}, From 65d9e3d0086403912b671a8df001dbdf3a7a4b0f Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 15 Jan 2025 12:07:30 +0200 Subject: [PATCH 6/8] refactor: make leaves retriever stateless --- api/errors/errors.go | 4 +- api/groups/addressGroup.go | 52 ++-- api/groups/addressGroup_test.go | 113 +++++--- api/mock/facadeStub.go | 12 +- api/shared/interface.go | 2 +- cmd/node/config/api.toml | 4 +- cmd/node/config/config.toml | 2 +- common/interface.go | 8 +- facade/initial/initialNodeFacade.go | 6 +- facade/interface.go | 4 +- facade/mock/nodeStub.go | 12 +- facade/nodeFacade.go | 7 +- integrationTests/interface.go | 2 +- node/node.go | 26 +- trie/errors.go | 6 + trie/keyBuilder/disabledKeyBuilder.go | 5 + trie/keyBuilder/keyBuilder.go | 5 + .../dfsTrieIterator/dfsTrieIterator.go | 75 ++++-- .../dfsTrieIterator/dfsTrieIterator_test.go | 141 +++++----- .../disabledLeavesRetriever.go | 4 +- trie/leavesRetriever/export_test.go | 30 +-- trie/leavesRetriever/leavesRetriever.go | 146 +---------- trie/leavesRetriever/leavesRetriever_test.go | 244 ++---------------- trie/mock/keyBuilderStub.go | 10 + 24 files changed, 328 insertions(+), 592 deletions(-) diff --git a/api/errors/errors.go b/api/errors/errors.go index 104413cb682..88ebeeec1c2 100644 --- a/api/errors/errors.go +++ b/api/errors/errors.go @@ -28,8 +28,8 @@ var ErrGetValueForKey = errors.New("get value for key error") // ErrGetKeyValuePairs signals an error in getting the key-value pairs of a key for an account var ErrGetKeyValuePairs = errors.New("get key-value pairs error") -// ErrGetKeyValuePairsWithCheckpoint signals an error in getting the key-value pairs of a key for an account with a checkpoint -var ErrGetKeyValuePairsWithCheckpoint = errors.New("get key-value pairs with checkpoint error") +// ErrIterateKeys signals an error in iterating over the keys of an account +var ErrIterateKeys = errors.New("iterate keys error") // ErrGetESDTBalance signals an error in getting esdt balance for given address var ErrGetESDTBalance = errors.New("get esdt balance for account error") diff --git a/api/groups/addressGroup.go b/api/groups/addressGroup.go index 3e577810adc..a9a15957328 100644 --- a/api/groups/addressGroup.go +++ b/api/groups/addressGroup.go @@ -5,7 +5,6 @@ import ( "fmt" "math/big" "net/http" - "strconv" "sync" "github.com/gin-gonic/gin" @@ -24,7 +23,6 @@ const ( getUsernamePath = "/:address/username" getCodeHashPath = "/:address/code-hash" getKeysPath = "/:address/keys" - getKeysWithCheckpointPath = "/:address/num-keys/:numKeys/checkpoint-id/:checkpointId" getKeyPath = "/:address/key/:key" getDataTrieMigrationStatusPath = "/:address/is-data-trie-migrated" getESDTTokensPath = "/:address/esdt" @@ -34,6 +32,7 @@ const ( getRegisteredNFTsPath = "/:address/registered-nfts" getESDTNFTDataPath = "/:address/nft/:tokenIdentifier/nonce/:nonce" getGuardianData = "/:address/guardian-data" + iterateKeysPath = "/iterate-keys" urlParamOnFinalBlock = "onFinalBlock" urlParamOnStartOfEpoch = "onStartOfEpoch" urlParamBlockNonce = "blockNonce" @@ -57,7 +56,7 @@ type addressFacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) - GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) IsInterfaceNil() bool @@ -138,9 +137,9 @@ func NewAddressGroup(facade addressFacadeHandler) (*addressGroup, error) { Handler: ag.getKeyValuePairs, }, { - Path: getKeysWithCheckpointPath, - Method: http.MethodGet, - Handler: ag.getKeyValuePairsWithCheckpoint, + Path: iterateKeysPath, + Method: http.MethodPost, + Handler: ag.iterateKeys, }, { Path: getESDTBalancePath, @@ -352,40 +351,45 @@ func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"pairs": value, "blockInfo": blockInfo}) } -// getKeysWithCheckpoint returns all the key-value pairs for the given address -func (ag *addressGroup) getKeyValuePairsWithCheckpoint(c *gin.Context) { - addr := c.Param("address") - if addr == "" { - shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyAddress) - return - } +// IterateKeysRequest defines the request structure for iterating keys +type IterateKeysRequest struct { + Address string `json:"address"` + NumKeys uint `json:"numKeys"` + IteratorState [][]byte `json:"iteratorState"` +} - options, err := extractAccountQueryOptions(c) +// iterateKeys iterates keys for the given address +func (ag *addressGroup) iterateKeys(c *gin.Context) { + var iterateKeysRequest = &IterateKeysRequest{} + err := c.ShouldBindJSON(&iterateKeysRequest) if err != nil { - shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) + shared.RespondWithValidationError(c, errors.ErrValidation, err) return } - numLeavesAsString := c.Param("numKeys") - if numLeavesAsString == "" { - shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, errors.ErrEmptyNumKeys) + if len(iterateKeysRequest.Address) == 0 { + shared.RespondWithValidationError(c, errors.ErrValidation, errors.ErrEmptyAddress) return } - numLeaves, err := strconv.Atoi(numLeavesAsString) + options, err := extractAccountQueryOptions(c) if err != nil { - shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) + shared.RespondWithValidationError(c, errors.ErrIterateKeys, err) return } - checkpointId := c.Param("checkpointId") - value, blockInfo, newCheckpointId, err := ag.getFacade().GetKeyValuePairsWithCheckpoint(addr, checkpointId, numLeaves, options) + value, newIteratorState, blockInfo, err := ag.getFacade().IterateKeys( + iterateKeysRequest.Address, + iterateKeysRequest.NumKeys, + iterateKeysRequest.IteratorState, + options, + ) if err != nil { - shared.RespondWithInternalError(c, errors.ErrGetKeyValuePairsWithCheckpoint, err) + shared.RespondWithInternalError(c, errors.ErrIterateKeys, err) return } - shared.RespondWithSuccess(c, gin.H{"pairs": value, "newCheckpointId": newCheckpointId, "blockInfo": blockInfo}) + shared.RespondWithSuccess(c, gin.H{"pairs": value, "newIteratorState": newIteratorState, "blockInfo": blockInfo}) } // getESDTBalance returns the balance for the given address and esdt token diff --git a/api/groups/addressGroup_test.go b/api/groups/addressGroup_test.go index 8e0b8d8bc61..03f4a1c5088 100644 --- a/api/groups/addressGroup_test.go +++ b/api/groups/addressGroup_test.go @@ -125,13 +125,13 @@ type keyValuePairsResponse struct { Code string } -type keyValuePairsWithCheckpointResponseData struct { - Pairs map[string]string `json:"pairs"` - NewCheckpointId string `json:"newCheckpointId"` +type iterateKeysResponseData struct { + Pairs map[string]string `json:"pairs"` + NewIteratorState [][]byte `json:"newIteratorState"` } -type keyValuePairsWithCheckpointResponse struct { - Data keyValuePairsWithCheckpointResponseData `json:"data"` - Error string `json:"error"` +type iterateKeysResponse struct { + Data iterateKeysResponseData `json:"data"` + Error string `json:"error"` Code string } @@ -672,37 +672,66 @@ func TestAddressGroup_getKeyValuePairs(t *testing.T) { }) } -func TestAddressGroup_getKeyValuePairsWithCheckpoint(t *testing.T) { +func TestAddressGroup_iterateKeys(t *testing.T) { t.Parallel() - t.Run("empty address should error", - testErrorScenario("/address//num-keys/10/checkpoint-id/abc", "GET", nil, - formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, apiErrors.ErrEmptyAddress))) - t.Run("invalid query options should error", - testErrorScenario("/address/erd1alice/num-keys/10/checkpoint-id/abc?blockNonce=not-uint64", "GET", nil, - formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, apiErrors.ErrBadUrlParams))) - t.Run("empty num-keys should error", - testErrorScenario("/address/erd1alice/num-keys//checkpoint-id/abc", "GET", nil, - formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, apiErrors.ErrEmptyNumKeys))) - t.Run("invalid num-keys should error", - testErrorScenario("/address/erd1alice/num-keys/not-uint64/checkpoint-id/abc", "GET", nil, - formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, errors.New("strconv.Atoi: parsing \"not-uint64\": invalid syntax")))) + t.Run("invalid body should error", + testErrorScenario("/address/iterate-keys", "POST", bytes.NewBuffer([]byte("invalid body")), + formatExpectedErr(apiErrors.ErrValidation, errors.New("invalid character 'i' looking for beginning of value")))) + t.Run("empty address should error", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "", + } + bodyBytes, _ := json.Marshal(body) + testAddressGroup( + t, + &mock.FacadeStub{}, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusBadRequest, + formatExpectedErr(apiErrors.ErrValidation, apiErrors.ErrEmptyAddress), + ) + }) + t.Run("invalid query options should error", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "erd1", + } + bodyBytes, _ := json.Marshal(body) + testAddressGroup( + t, + &mock.FacadeStub{}, + "/address/iterate-keys?blockNonce=not-uint64", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusBadRequest, + formatExpectedErr(apiErrors.ErrIterateKeys, apiErrors.ErrBadUrlParams), + ) + }) t.Run("with node fail should err", func(t *testing.T) { t.Parallel() + body := &groups.IterateKeysRequest{ + Address: "erd1", + } + bodyBytes, _ := json.Marshal(body) facade := &mock.FacadeStub{ - GetKeyValuePairsWithCheckpointCalled: func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { - return nil, api.BlockInfo{}, "", expectedErr + IterateKeysCalled: func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + return nil, nil, api.BlockInfo{}, expectedErr }, } testAddressGroup( t, facade, - "/address/erd1alice/num-keys/10/checkpoint-id/abc", - "GET", - nil, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), http.StatusInternalServerError, - formatExpectedErr(apiErrors.ErrGetKeyValuePairsWithCheckpoint, expectedErr), + formatExpectedErr(apiErrors.ErrIterateKeys, expectedErr), ) }) t.Run("should work", func(t *testing.T) { @@ -712,30 +741,34 @@ func TestAddressGroup_getKeyValuePairsWithCheckpoint(t *testing.T) { "k1": "v1", "k2": "v2", } - originalCheckpointId := "abc" - newCheckpointId := "def" - numKeys := "10" - addr := "erd1alice" + + body := &groups.IterateKeysRequest{ + Address: "erd1", + NumKeys: 10, + IteratorState: [][]byte{[]byte("starting"), []byte("state")}, + } + newIteratorState := [][]byte{[]byte("new"), []byte("state")} + bodyBytes, _ := json.Marshal(body) facade := &mock.FacadeStub{ - GetKeyValuePairsWithCheckpointCalled: func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { - assert.Equal(t, addr, address) - assert.Equal(t, 10, numLeaves) - assert.Equal(t, originalCheckpointId, checkpointId) - return pairs, api.BlockInfo{}, newCheckpointId, nil + IterateKeysCalled: func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + assert.Equal(t, body.Address, address) + assert.Equal(t, body.NumKeys, numKeys) + assert.Equal(t, body.IteratorState, iteratorState) + return pairs, newIteratorState, api.BlockInfo{}, nil }, } - response := &keyValuePairsWithCheckpointResponse{} + response := &iterateKeysResponse{} loadAddressGroupResponse( t, facade, - "/address/"+addr+"/num-keys/"+numKeys+"/checkpoint-id/"+originalCheckpointId, - "GET", - nil, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), response, ) assert.Equal(t, pairs, response.Data.Pairs) - assert.Equal(t, newCheckpointId, response.Data.NewCheckpointId) + assert.Equal(t, newIteratorState, response.Data.NewIteratorState) }) } @@ -1220,7 +1253,7 @@ func getAddressRoutesConfig() config.ApiRoutesConfig { {Name: "/:address/username", Open: true}, {Name: "/:address/code-hash", Open: true}, {Name: "/:address/keys", Open: true}, - {Name: "/:address/num-keys/:numKeys/checkpoint-id/:checkpointId", Open: true}, + {Name: "/iterate-keys", Open: true}, {Name: "/:address/key/:key", Open: true}, {Name: "/:address/esdt", Open: true}, {Name: "/:address/esdts/roles", Open: true}, diff --git a/api/mock/facadeStub.go b/api/mock/facadeStub.go index c471ccf21c2..94bc0551c76 100644 --- a/api/mock/facadeStub.go +++ b/api/mock/facadeStub.go @@ -49,7 +49,7 @@ type FacadeStub struct { GetUsernameCalled func(address string, options api.AccountQueryOptions) (string, api.BlockInfo, error) GetCodeHashCalled func(address string, options api.AccountQueryOptions) ([]byte, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) - GetKeyValuePairsWithCheckpointCalled func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) + IterateKeysCalled func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) SimulateTransactionExecutionHandler func(tx *transaction.Transaction) (*txSimData.SimulationResultsWithVMOutput, error) GetESDTDataCalled func(address string, key string, nonce uint64, options api.AccountQueryOptions) (*esdt.ESDigitalToken, api.BlockInfo, error) GetAllESDTTokensCalled func(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) @@ -242,13 +242,13 @@ func (f *FacadeStub) GetKeyValuePairs(address string, options api.AccountQueryOp return nil, api.BlockInfo{}, nil } -// GetKeyValuePairsWithCheckpoint - -func (f *FacadeStub) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { - if f.GetKeyValuePairsWithCheckpointCalled != nil { - return f.GetKeyValuePairsWithCheckpointCalled(address, checkpointId, numLeaves, options) +// IterateKeys - +func (f *FacadeStub) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + if f.IterateKeysCalled != nil { + return f.IterateKeysCalled(address, numKeys, iteratorState, options) } - return nil, api.BlockInfo{}, "", nil + return nil, nil, api.BlockInfo{}, nil } // GetGuardianData - diff --git a/api/shared/interface.go b/api/shared/interface.go index 56a1cc70e19..adedd6642af 100644 --- a/api/shared/interface.go +++ b/api/shared/interface.go @@ -74,7 +74,7 @@ type FacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) - GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*api.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*api.Block, error) diff --git a/cmd/node/config/api.toml b/cmd/node/config/api.toml index af7ccd8a877..378b4157e47 100644 --- a/cmd/node/config/api.toml +++ b/cmd/node/config/api.toml @@ -79,8 +79,8 @@ # /address/:address/keys will return all the key-value pairs of a given account { Name = "/:address/keys", Open = true }, - # address/:address/num-keys/:numKeys/checkpoint-id/:checkpointId will return the given num of key-value pairs for the given account - { Name = "/:address/num-keys/:numKeys/checkpoint-id/:checkpointId", Open = true }, + # address//iterate-keys will return the given num of key-value pairs for the given account. The iteration will start from the given starting state + { Name = "/iterate-keys", Open = true }, # /address/:address/key/:key will return the value of a key for a given account { Name = "/:address/key/:key", Open = true }, diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 7e40d31dbd8..7d0ffeb57fe 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -665,7 +665,7 @@ [TrieLeavesRetrieverConfig] Enabled = false - MaxSizeInBytes = 104857600 #100MB + MaxSizeInBytes = 10485760 #10MB [BlockSizeThrottleConfig] MinSizeInBytes = 104857 # 104857 is 10% from 1MB diff --git a/common/interface.go b/common/interface.go index efa6b5116fd..a72be45b1f4 100644 --- a/common/interface.go +++ b/common/interface.go @@ -80,6 +80,7 @@ type StorageMarker interface { type KeyBuilder interface { BuildKey(keyPart []byte) GetKey() ([]byte, error) + GetRawKey() []byte DeepClone() KeyBuilder ShallowClone() KeyBuilder Size() uint @@ -386,16 +387,13 @@ type TrieNodeData interface { // DfsIterator is used to iterate the trie nodes in a depth-first search manner type DfsIterator interface { GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) - GetIteratorId() []byte - Clone() DfsIterator - FinishedIteration() bool - Size() uint64 + GetIteratorState() [][]byte IsInterfaceNil() bool } // TrieLeavesRetriever is used to retrieve the leaves from the trie. If there is a saved checkpoint for the iterator id, // it will continue to iterate from the checkpoint. type TrieLeavesRetriever interface { - GetLeaves(numLeaves int, rootHash []byte, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) + GetLeaves(numLeaves int, iteratorState [][]byte, ctx context.Context) (map[string]string, [][]byte, error) IsInterfaceNil() bool } diff --git a/facade/initial/initialNodeFacade.go b/facade/initial/initialNodeFacade.go index 626f77db816..ea9268d0bde 100644 --- a/facade/initial/initialNodeFacade.go +++ b/facade/initial/initialNodeFacade.go @@ -346,9 +346,9 @@ func (inf *initialNodeFacade) GetKeyValuePairs(_ string, _ api.AccountQueryOptio return nil, api.BlockInfo{}, errNodeStarting } -// GetKeyValuePairsWithCheckpoint returns error -func (inf *initialNodeFacade) GetKeyValuePairsWithCheckpoint(_ string, _ string, _ int, _ api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) { - return nil, api.BlockInfo{}, "", errNodeStarting +// IterateKeys returns error +func (inf *initialNodeFacade) IterateKeys(_ string, _ uint, _ [][]byte, _ api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + return nil, nil, api.BlockInfo{}, errNodeStarting } // GetGuardianData returns error diff --git a/facade/interface.go b/facade/interface.go index 413389cb1be..2dfa8b503bd 100644 --- a/facade/interface.go +++ b/facade/interface.go @@ -41,8 +41,8 @@ type NodeHandler interface { // GetKeyValuePairs returns the key-value pairs under a given address GetKeyValuePairs(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) - // GetKeyValuePairsWithCheckpoint returns the key-value pairs under a given address with a checkpoint - GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) + // IterateKeys returns the key-value pairs under a given address starting from a given state + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) // GetAllIssuedESDTs returns all the issued esdt tokens from esdt system smart contract GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]string, error) diff --git a/facade/mock/nodeStub.go b/facade/mock/nodeStub.go index a9f289a6ff8..e7b2817a32e 100644 --- a/facade/mock/nodeStub.go +++ b/facade/mock/nodeStub.go @@ -49,7 +49,7 @@ type NodeStub struct { GetESDTsWithRoleCalled func(address string, role string, options api.AccountQueryOptions, ctx context.Context) ([]string, api.BlockInfo, error) GetESDTsRolesCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string][]string, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) - GetKeyValuePairsWithCheckpointCalled func(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) + IterateKeysCalled func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) GetAllIssuedESDTsCalled func(tokenType string, ctx context.Context) ([]string, error) GetProofCalled func(rootHash string, key string) (*common.GetProofResponse, error) GetProofDataTrieCalled func(rootHash string, address string, key string) (*common.GetProofResponse, *common.GetProofResponse, error) @@ -113,13 +113,13 @@ func (ns *NodeStub) GetKeyValuePairs(address string, options api.AccountQueryOpt return nil, api.BlockInfo{}, nil } -// GetKeyValuePairsWithCheckpoint - -func (ns *NodeStub) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) { - if ns.GetKeyValuePairsWithCheckpointCalled != nil { - return ns.GetKeyValuePairsWithCheckpointCalled(address, checkpointId, numLeaves, options, ctx) +// IterateKeys - +func (ns *NodeStub) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) { + if ns.IterateKeysCalled != nil { + return ns.IterateKeysCalled(address, numKeys, iteratorState, options, ctx) } - return nil, api.BlockInfo{}, "", nil + return nil, nil, api.BlockInfo{}, nil } // GetValueForKey - diff --git a/facade/nodeFacade.go b/facade/nodeFacade.go index 89ab8c99813..e516b506b52 100644 --- a/facade/nodeFacade.go +++ b/facade/nodeFacade.go @@ -229,13 +229,12 @@ func (nf *nodeFacade) GetKeyValuePairs(address string, options apiData.AccountQu return nf.node.GetKeyValuePairs(address, options, ctx) } -// GetKeyValuePairsWithCheckpoint returns the given number of key-value pairs under the provided address. -// The iteration starts from the given checkpoint, and returns a new checkpoint for the next iteration. -func (nf *nodeFacade) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options apiData.AccountQueryOptions) (map[string]string, apiData.BlockInfo, string, error) { +// IterateKeys starts from the given iteratorState and returns the next key-value pairs and the new iteratorState +func (nf *nodeFacade) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options apiData.AccountQueryOptions) (map[string]string, [][]byte, apiData.BlockInfo, error) { ctx, cancel := nf.getContextForApiTrieRangeOperations() defer cancel() - return nf.node.GetKeyValuePairsWithCheckpoint(address, checkpointId, numLeaves, options, ctx) + return nf.node.IterateKeys(address, numKeys, iteratorState, options, ctx) } // GetGuardianData returns the guardian data for the provided address diff --git a/integrationTests/interface.go b/integrationTests/interface.go index c63438fbc15..2b78eec1f0f 100644 --- a/integrationTests/interface.go +++ b/integrationTests/interface.go @@ -69,7 +69,7 @@ type Facade interface { GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetESDTsRoles(address string, options api.AccountQueryOptions) (map[string][]string, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) - GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, string, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*dataApi.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*dataApi.Block, error) diff --git a/node/node.go b/node/node.go index 72731c5bf78..73c1cc88da8 100644 --- a/node/node.go +++ b/node/node.go @@ -308,38 +308,32 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, return mapToReturn, blockInfo, nil } -// GetKeyValuePairsWithCheckpoint returns the given number of key-value pairs under the provided address. -// The iteration starts from the given checkpoint, and returns a new checkpoint for the next iteration. -func (n *Node) GetKeyValuePairsWithCheckpoint(address string, checkpointId string, numLeaves int, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, string, error) { +// IterateKeys starts from the given iteratorState and returns the next key-value pairs and the new iteratorState +func (n *Node) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) { userAccount, blockInfo, err := n.loadUserAccountHandlerByAddress(address, options) if err != nil { adaptedBlockInfo, isEmptyAccount := extractBlockInfoIfNewAccount(err) if isEmptyAccount { - return make(map[string]string), adaptedBlockInfo, "", nil + return make(map[string]string), nil, adaptedBlockInfo, nil } - return nil, api.BlockInfo{}, "", err + return nil, nil, api.BlockInfo{}, err } if check.IfNil(userAccount.DataTrie()) { - return map[string]string{}, blockInfo, "", nil + return map[string]string{}, nil, blockInfo, nil } - checkpointIdBytes, err := hex.DecodeString(checkpointId) - if err != nil { - return nil, api.BlockInfo{}, "", fmt.Errorf("invalid checkpointId: %w", err) + if len(iteratorState) == 0 { + iteratorState = append(iteratorState, userAccount.GetRootHash()) } - mapToReturn, newCheckpoint, err := n.stateComponents.TrieLeavesRetriever().GetLeaves(numLeaves, userAccount.GetRootHash(), checkpointIdBytes, ctx) + mapToReturn, newIteratorState, err := n.stateComponents.TrieLeavesRetriever().GetLeaves(int(numKeys), iteratorState, ctx) if err != nil { - return nil, api.BlockInfo{}, "", err - } - - if common.IsContextDone(ctx) { - return nil, api.BlockInfo{}, "", ErrTrieOperationsTimeout + return nil, nil, api.BlockInfo{}, err } - return mapToReturn, blockInfo, hex.EncodeToString(newCheckpoint), nil + return mapToReturn, newIteratorState, blockInfo, nil } func (n *Node) getKeys(userAccount state.UserAccountHandler, ctx context.Context) (map[string]string, error) { diff --git a/trie/errors.go b/trie/errors.go index 9cc2588e501..a879fd6c94c 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -123,3 +123,9 @@ var ErrNilTrieLeafParser = errors.New("nil trie leaf parser") // ErrInvalidNodeVersion signals that an invalid node version has been provided var ErrInvalidNodeVersion = errors.New("invalid node version provided") + +// ErrEmptyInitialIteratorState signals that an empty initial iterator state was provided +var ErrEmptyInitialIteratorState = errors.New("empty initial iterator state") + +// ErrInvalidIteratorState signals that an invalid iterator state was provided +var ErrInvalidIteratorState = errors.New("invalid iterator state") diff --git a/trie/keyBuilder/disabledKeyBuilder.go b/trie/keyBuilder/disabledKeyBuilder.go index 78a0350aa26..71c2022d372 100644 --- a/trie/keyBuilder/disabledKeyBuilder.go +++ b/trie/keyBuilder/disabledKeyBuilder.go @@ -22,6 +22,11 @@ func (dkb *disabledKeyBuilder) GetKey() ([]byte, error) { return []byte{}, nil } +// GetRawKey returns an empty byte array for this implementation +func (dkb *disabledKeyBuilder) GetRawKey() []byte { + return []byte{} +} + // ShallowClone returns a new disabled key builder func (dkb *disabledKeyBuilder) ShallowClone() common.KeyBuilder { return &disabledKeyBuilder{} diff --git a/trie/keyBuilder/keyBuilder.go b/trie/keyBuilder/keyBuilder.go index a455fd5c9e6..c1b7f78f62a 100644 --- a/trie/keyBuilder/keyBuilder.go +++ b/trie/keyBuilder/keyBuilder.go @@ -34,6 +34,11 @@ func (kb *keyBuilder) GetKey() ([]byte, error) { return hexToTrieKeyBytes(kb.key) } +// GetRawKey returns the key as it is, without transforming it +func (kb *keyBuilder) GetRawKey() []byte { + return kb.key +} + // ShallowClone returns a new KeyBuilder with the same key. The key slice points to the same memory location. func (kb *keyBuilder) ShallowClone() common.KeyBuilder { return &keyBuilder{ diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go index 7932e9a2ce4..b7dc34329a4 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -10,11 +10,11 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever/trieNodeData" ) type dfsIterator struct { nextNodes []common.TrieNodeData - rootHash []byte db common.TrieStorageInteractor marshaller marshal.Marshalizer hasher hashing.Hasher @@ -22,7 +22,7 @@ type dfsIterator struct { } // NewIterator creates a new DFS iterator for the trie. -func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller marshal.Marshalizer, hasher hashing.Hasher) (*dfsIterator, error) { +func NewIterator(initialState [][]byte, db common.TrieStorageInteractor, marshaller marshal.Marshalizer, hasher hashing.Hasher) (*dfsIterator, error) { if check.IfNil(db) { return nil, trie.ErrNilDatabase } @@ -32,20 +32,22 @@ func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller ma if check.IfNil(hasher) { return nil, trie.ErrNilHasher } + if len(initialState) == 0 { + return nil, trie.ErrEmptyInitialIteratorState + } - data, err := trie.GetNodeDataFromHash(rootHash, keyBuilder.NewKeyBuilder(), db, marshaller, hasher) + nextNodes, err := getNextNodesFromInitialState(initialState, uint(hasher.Size())) if err != nil { return nil, err } size := uint64(0) - for _, node := range data { + for _, node := range nextNodes { size += node.Size() } return &dfsIterator{ - nextNodes: data, - rootHash: rootHash, + nextNodes: nextNodes, db: db, marshaller: marshaller, hasher: hasher, @@ -53,6 +55,40 @@ func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller ma }, nil } +func getNextNodesFromInitialState(initialState [][]byte, hashSize uint) ([]common.TrieNodeData, error) { + nextNodes := make([]common.TrieNodeData, len(initialState)) + for i, state := range initialState { + if len(state) < int(hashSize) { + return nil, trie.ErrInvalidIteratorState + } + + nodeHash := state[:hashSize] + key := state[hashSize:] + + kb := keyBuilder.NewKeyBuilder() + kb.BuildKey(key) + nodeData, err := trieNodeData.NewIntermediaryNodeData(kb, nodeHash) + if err != nil { + return nil, err + } + nextNodes[i] = nodeData + } + + return nextNodes, nil +} + +func getIteratorStateFromNextNodes(nextNodes []common.TrieNodeData) [][]byte { + iteratorState := make([][]byte, len(nextNodes)) + for i, node := range nextNodes { + nodeHash := node.GetData() + key := node.GetKeyBuilder().GetRawKey() + + iteratorState[i] = append(nodeHash, key...) + } + + return iteratorState +} + // GetLeaves retrieves leaves from the trie. It stops either when the number of leaves is reached or the context is done. func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) { retrievedLeaves := make(map[string]string) @@ -103,30 +139,13 @@ func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Cont } } -// GetIteratorId returns the ID of the iterator. -func (it *dfsIterator) GetIteratorId() []byte { - if len(it.nextNodes) == 0 { +// GetIteratorState returns the state of the iterator from which it can be resumed by another call. +func (it *dfsIterator) GetIteratorState() [][]byte { + if it.FinishedIteration() { return nil } - nextNodeHash := it.nextNodes[0].GetData() - iteratorID := it.hasher.Compute(string(append(it.rootHash, nextNodeHash...))) - return iteratorID -} - -// Clone creates a copy of the iterator. -func (it *dfsIterator) Clone() common.DfsIterator { - nextNodes := make([]common.TrieNodeData, len(it.nextNodes)) - copy(nextNodes, it.nextNodes) - - return &dfsIterator{ - nextNodes: nextNodes, - rootHash: it.rootHash, - db: it.db, - marshaller: it.marshaller, - hasher: it.hasher, - size: it.size, - } + return getIteratorStateFromNextNodes(it.nextNodes) } // FinishedIteration checks if the iterator has finished the iteration. @@ -136,7 +155,7 @@ func (it *dfsIterator) FinishedIteration() bool { // Size returns the size of the iterator. func (it *dfsIterator) Size() uint64 { - return it.size + uint64(len(it.rootHash)) + return it.size } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go index b8d71b40173..657dc302bd2 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go @@ -1,8 +1,9 @@ package dfsTrieIterator import ( + "bytes" "context" - "fmt" + "encoding/hex" "math" "testing" @@ -23,46 +24,52 @@ func TestNewIterator(t *testing.T) { t.Run("nil db", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("hash"), nil, &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, nil, &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) assert.Nil(t, iterator) assert.Equal(t, trie.ErrNilDatabase, err) }) t.Run("nil marshaller", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("hash"), testscommon.NewMemDbMock(), nil, &hashingMocks.HasherMock{}) + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, testscommon.NewMemDbMock(), nil, &hashingMocks.HasherMock{}) assert.Nil(t, iterator) assert.Equal(t, trie.ErrNilMarshalizer, err) }) t.Run("nil hasher", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("hash"), testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, nil) + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, nil) assert.Nil(t, iterator) assert.Equal(t, trie.ErrNilHasher, err) }) - t.Run("invalid hash", func(t *testing.T) { + t.Run("empty initial state", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("invalid hash"), testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + iterator, err := NewIterator([][]byte{}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) assert.Nil(t, iterator) - assert.NotNil(t, err) + assert.Equal(t, trie.ErrEmptyInitialIteratorState, err) }) - t.Run("initialize iterator with a valid hash", func(t *testing.T) { + t.Run("invalid initial state", func(t *testing.T) { t.Parallel() - tr := trieTest.GetNewTrie() - _ = tr.Update([]byte("key1"), []byte("value1")) - _ = tr.Commit() - rootHash, _ := tr.RootHash() + iterator, err := NewIterator([][]byte{[]byte("invalid state")}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + assert.Nil(t, iterator) + assert.Equal(t, trie.ErrInvalidIteratorState, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() - _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, err := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + initialState := [][]byte{ + bytes.Repeat([]byte{0}, 40), + bytes.Repeat([]byte{1}, 40), + } + + db, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, err := NewIterator(initialState, db, marshaller, hasher) assert.Nil(t, err) - assert.Equal(t, rootHash, iterator.rootHash) - assert.Equal(t, uint64(15), iterator.size) - assert.Equal(t, 1, len(iterator.nextNodes)) + assert.Equal(t, uint64(80), iterator.size) + assert.Equal(t, 2, len(iterator.nextNodes)) }) } @@ -95,7 +102,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { }, } _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, dbWrapper, marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, dbWrapper, marshaller, hasher) trieData, err := iterator.GetLeaves(numLeaves, maxSize, ctx) assert.Nil(t, err) @@ -110,7 +117,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) trieData, err := iterator.GetLeaves(numLeaves, maxSize, context.Background()) assert.Nil(t, err) @@ -126,7 +133,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) trieData, err := iterator.GetLeaves(17, maxSize, context.Background()) assert.Nil(t, err) @@ -140,7 +147,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) iteratorMaxSize := uint64(100) trieData, err := iterator.GetLeaves(5, iteratorMaxSize, context.Background()) @@ -155,7 +162,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { trieTest.AddDataToTrie(tr, numLeaves) rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) numRetrievedLeaves := 0 numIterations := 0 @@ -170,7 +177,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { assert.Equal(t, numLeaves, numRetrievedLeaves) assert.Equal(t, 5, numIterations) }) - t.Run("retrieve leaves with nil iterator does not panic", func(t *testing.T) { + t.Run("retrieve leaves with nil context does not panic", func(t *testing.T) { t.Parallel() tr := trieTest.GetNewTrie() @@ -180,7 +187,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) trieData, err := iterator.GetLeaves(numLeaves, maxSize, nil) assert.Nil(t, err) @@ -188,48 +195,36 @@ func TestDfsIterator_GetLeaves(t *testing.T) { }) } -func TestDfsIterator_GetIteratorId(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - numLeaves := 25 - trieTest.AddDataToTrie(tr, numLeaves) - rootHash, _ := tr.RootHash() - _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - - numRetrievedLeaves := 0 - for numRetrievedLeaves < numLeaves { - iteratorId := hasher.Compute(string(append(rootHash, iterator.nextNodes[0].GetData()...))) - assert.Equal(t, iteratorId, iterator.GetIteratorId()) - - trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) - assert.Nil(t, err) - - numRetrievedLeaves += len(trieData) - } - - assert.Equal(t, numLeaves, numRetrievedLeaves) - assert.Nil(t, iterator.GetIteratorId()) -} - -func TestDfsIterator_Clone(t *testing.T) { +func TestDfsIterator_GetIteratorState(t *testing.T) { t.Parallel() tr := trieTest.GetNewTrie() - numLeaves := 25 - trieTest.AddDataToTrie(tr, numLeaves) + _ = tr.Update([]byte("doe"), []byte("reindeer")) + _ = tr.Update([]byte("dog"), []byte("puppy")) + _ = tr.Update([]byte("ddog"), []byte("cat")) + _ = tr.Commit() rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - clonedIterator := iterator.Clone() - - nextNodesMemAddr := fmt.Sprintf("%p", iterator.nextNodes) - clonedNextNodesMemAddr := fmt.Sprintf("%p", clonedIterator.(*dfsIterator).nextNodes) - assert.NotEqual(t, nextNodesMemAddr, clonedNextNodesMemAddr) - assert.Equal(t, iterator.rootHash, clonedIterator.(*dfsIterator).rootHash) - assert.Equal(t, iterator.size, clonedIterator.(*dfsIterator).size) + leaves, err := iterator.GetLeaves(2, maxSize, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 2, len(leaves)) + val, ok := leaves[hex.EncodeToString([]byte("doe"))] + assert.True(t, ok) + assert.Equal(t, hex.EncodeToString([]byte("reindeer")), val) + val, ok = leaves[hex.EncodeToString([]byte("ddog"))] + assert.True(t, ok) + assert.Equal(t, hex.EncodeToString([]byte("cat")), val) + + iteratorState := iterator.GetIteratorState() + assert.Equal(t, 1, len(iteratorState)) + hash := iteratorState[0][:hasher.Size()] + key := iteratorState[0][hasher.Size():] + assert.Equal(t, []byte{0x7, 0x6, 0xf, 0x6, 0x4, 0x6, 0x10}, key) + leafBytes, err := tr.GetStorageManager().Get(hash) + assert.Nil(t, err) + assert.NotNil(t, leafBytes) } func TestDfsIterator_FinishedIteration(t *testing.T) { @@ -240,7 +235,7 @@ func TestDfsIterator_FinishedIteration(t *testing.T) { trieTest.AddDataToTrie(tr, numLeaves) rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) numRetrievedLeaves := 0 for numRetrievedLeaves < numLeaves { @@ -264,30 +259,32 @@ func TestDfsIterator_Size(t *testing.T) { rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - // branch node size = 33 - // root hash size = 32 - // extension nodes size = 34 - // leaf nodes size = 35 - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - assert.Equal(t, uint64(362), iterator.Size()) // 10 branch nodes + 1 root hash + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + assert.Equal(t, uint64(32), iterator.Size()) // root hash + assert.False(t, iterator.FinishedIteration()) _, err := iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) - assert.Equal(t, uint64(331), iterator.Size()) // 8 branch nodes + 1 leaf node + 1 root hash + assert.Equal(t, uint64(299), iterator.Size()) // 9 hashes + leaf key(3) + 8 x intermediary nodes key(8 * 1) + assert.False(t, iterator.FinishedIteration()) _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) - assert.Equal(t, uint64(300), iterator.Size()) // 6 branch nodes + 2 leaf node + 1 root hash + assert.Equal(t, uint64(268), iterator.Size()) // 8 hashes + 2 x leaf keys(2 * 3) + 6 x intermediary nodes key(6*1) + assert.False(t, iterator.FinishedIteration()) _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) - assert.Equal(t, uint64(197), iterator.Size()) // 5 branch nodes + 1 root hash + assert.Equal(t, uint64(165), iterator.Size()) // 5 hashes + 5 x intermediary nodes key(5*1) + assert.False(t, iterator.FinishedIteration()) _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) - assert.Equal(t, uint64(133), iterator.Size()) // 2 branch nodes + 1 leaf node + 1 root hash + assert.Equal(t, uint64(101), iterator.Size()) // 3 hashes + leaf key(3) + 2 x intermediary nodes key(2*1) + assert.False(t, iterator.FinishedIteration()) _, err = iterator.GetLeaves(5, maxSize, context.Background()) assert.Nil(t, err) - assert.Equal(t, uint64(32), iterator.Size()) // 1 root hash + assert.Equal(t, uint64(0), iterator.Size()) + assert.True(t, iterator.FinishedIteration()) } diff --git a/trie/leavesRetriever/disabledLeavesRetriever.go b/trie/leavesRetriever/disabledLeavesRetriever.go index b3143e377ff..df087d7b47b 100644 --- a/trie/leavesRetriever/disabledLeavesRetriever.go +++ b/trie/leavesRetriever/disabledLeavesRetriever.go @@ -10,8 +10,8 @@ func NewDisabledLeavesRetriever() *disabledLeavesRetriever { } // GetLeaves returns an empty map and a nil byte slice for this implementation -func (dlr *disabledLeavesRetriever) GetLeaves(_ int, _ []byte, _ []byte, _ context.Context) (map[string]string, []byte, error) { - return make(map[string]string), []byte{}, nil +func (dlr *disabledLeavesRetriever) GetLeaves(_ int, _ [][]byte, _ context.Context) (map[string]string, [][]byte, error) { + return make(map[string]string), [][]byte{}, nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/leavesRetriever/export_test.go b/trie/leavesRetriever/export_test.go index 3135262e01a..3d9ce098b3e 100644 --- a/trie/leavesRetriever/export_test.go +++ b/trie/leavesRetriever/export_test.go @@ -1,18 +1,16 @@ package leavesRetriever -import "github.com/multiversx/mx-chain-go/common" - -// GetIterators - -func (lr *leavesRetriever) GetIterators() map[string]common.DfsIterator { - return lr.iterators -} - -// GetLruIteratorIDs - -func (lr *leavesRetriever) GetLruIteratorIDs() [][]byte { - return lr.lruIteratorIDs -} - -// Size - -func (lr *leavesRetriever) Size() uint64 { - return lr.size -} +//// GetIterators - +//func (lr *leavesRetriever) GetIterators() map[string]common.DfsIterator { +// return lr.iterators +//} +// +//// GetLruIteratorIDs - +//func (lr *leavesRetriever) GetLruIteratorIDs() [][]byte { +// return lr.lruIteratorIDs +//} +// +//// Size - +//func (lr *leavesRetriever) Size() uint64 { +// return lr.size +//} diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go index 544f9a0ff79..ee0937647c7 100644 --- a/trie/leavesRetriever/leavesRetriever.go +++ b/trie/leavesRetriever/leavesRetriever.go @@ -2,7 +2,6 @@ package leavesRetriever import ( "context" - "sync" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" @@ -15,14 +14,10 @@ import ( var log = logger.GetOrCreate("trie/leavesRetriever") type leavesRetriever struct { - iterators map[string]common.DfsIterator - lruIteratorIDs [][]byte - db common.TrieStorageInteractor - marshaller marshal.Marshalizer - hasher hashing.Hasher - size uint64 - maxSize uint64 - mutex sync.RWMutex + db common.TrieStorageInteractor + marshaller marshal.Marshalizer + hasher hashing.Hasher + maxSize uint64 } // NewLeavesRetriever creates a new leaves retriever @@ -38,140 +33,27 @@ func NewLeavesRetriever(db common.TrieStorageInteractor, marshaller marshal.Mars } return &leavesRetriever{ - iterators: make(map[string]common.DfsIterator), - lruIteratorIDs: make([][]byte, 0), - db: db, - marshaller: marshaller, - hasher: hasher, - size: 0, - maxSize: maxSize, + db: db, + marshaller: marshaller, + hasher: hasher, + maxSize: maxSize, }, nil } -// GetLeaves retrieves the leaves from the trie. If there is a saved checkpoint for the iterator id, it will continue to iterate from the checkpoint. -func (lr *leavesRetriever) GetLeaves(numLeaves int, rootHash []byte, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) { - defer log.Trace("leaves retriever stats", "size", lr.size, "numIterators", len(lr.iterators)) - if len(iteratorID) == 0 { - return lr.getLeavesFromNewInstance(numLeaves, rootHash, ctx) - } - - lr.mutex.RLock() - iterator, ok := lr.iterators[string(iteratorID)] - lr.mutex.RUnlock() - if !ok { - return nil, nil, ErrIteratorNotFound - } - - return lr.getLeavesFromCheckpoint(numLeaves, iterator, iteratorID, ctx) -} - -func (lr *leavesRetriever) getLeavesFromNewInstance(numLeaves int, rootHash []byte, ctx context.Context) (map[string]string, []byte, error) { - log.Trace("get leaves from new instance", "numLeaves", numLeaves, "rootHash", rootHash) - iterator, err := dfsTrieIterator.NewIterator(rootHash, lr.db, lr.marshaller, lr.hasher) +// GetLeaves retrieves leaves from the trie starting from the iterator state. It will also return the new iterator state +// from which one can continue the iteration. +func (lr *leavesRetriever) GetLeaves(numLeaves int, iteratorState [][]byte, ctx context.Context) (map[string]string, [][]byte, error) { + iterator, err := dfsTrieIterator.NewIterator(iteratorState, lr.db, lr.marshaller, lr.hasher) if err != nil { return nil, nil, err } - return lr.getLeavesFromIterator(iterator, numLeaves, ctx) -} - -func (lr *leavesRetriever) getLeavesFromCheckpoint(numLeaves int, iterator common.DfsIterator, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) { - log.Trace("get leaves from checkpoint", "numLeaves", numLeaves, "iteratorID", iteratorID) - lr.markIteratorAsRecentlyUsed(iteratorID) - clonedIterator := iterator.Clone() - - return lr.getLeavesFromIterator(clonedIterator, numLeaves, ctx) -} - -func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, numLeaves int, ctx context.Context) (map[string]string, []byte, error) { - leaves, err := iterator.GetLeaves(numLeaves, lr.maxSize, ctx) + leavesData, err := iterator.GetLeaves(numLeaves, lr.maxSize, ctx) if err != nil { return nil, nil, err } - if iterator.FinishedIteration() { - return leaves, nil, nil - } - - iteratorId := iterator.GetIteratorId() - if len(iteratorId) == 0 { - return leaves, nil, nil - } - - shouldReturnId := lr.manageIterators(iteratorId, iterator) - if !shouldReturnId { - return leaves, nil, nil - } - return leaves, iteratorId, nil -} - -func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) bool { - lr.mutex.Lock() - defer lr.mutex.Unlock() - - newIteratorPresent := lr.saveIterator(iteratorId, iterator) - if !newIteratorPresent { - return false - } - lr.removeIteratorsIfMaxSizeIsExceeded() - return true -} - -func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) bool { - _, isPresent := lr.iterators[string(iteratorId)] - if isPresent { - return true - } - - if iterator.Size() >= lr.maxSize { - return false - } - - lr.lruIteratorIDs = append(lr.lruIteratorIDs, iteratorId) - lr.iterators[string(iteratorId)] = iterator - lr.size += iterator.Size() + uint64(len(iteratorId)) - return true -} - -func (lr *leavesRetriever) markIteratorAsRecentlyUsed(iteratorId []byte) { - lr.mutex.Lock() - defer lr.mutex.Unlock() - - for i, id := range lr.lruIteratorIDs { - if string(id) == string(iteratorId) { - lr.lruIteratorIDs = append(lr.lruIteratorIDs[:i], lr.lruIteratorIDs[i+1:]...) - lr.lruIteratorIDs = append(lr.lruIteratorIDs, id) - return - } - } -} - -func (lr *leavesRetriever) removeIteratorsIfMaxSizeIsExceeded() { - if lr.size <= lr.maxSize { - return - } - - idsToRemove := make([][]byte, 0) - sizeOfRemoved := uint64(0) - numOfRemoved := 0 - - for i := 0; i < len(lr.lruIteratorIDs); i++ { - id := lr.lruIteratorIDs[i] - idsToRemove = append(idsToRemove, id) - iterator := lr.iterators[string(id)] - sizeOfRemoved += iterator.Size() + uint64(len(id)) - numOfRemoved++ - - if lr.size-sizeOfRemoved <= lr.maxSize { - break - } - } - - for _, id := range idsToRemove { - delete(lr.iterators, string(id)) - } - lr.lruIteratorIDs = lr.lruIteratorIDs[numOfRemoved:] - lr.size -= sizeOfRemoved + return leavesData, iterator.GetIteratorState(), nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/leavesRetriever/leavesRetriever_test.go b/trie/leavesRetriever/leavesRetriever_test.go index 1605aaf6fc4..e8dade186b6 100644 --- a/trie/leavesRetriever/leavesRetriever_test.go +++ b/trie/leavesRetriever/leavesRetriever_test.go @@ -2,20 +2,14 @@ package leavesRetriever_test import ( "context" - "crypto/rand" - "encoding/hex" - "fmt" - "sync" "testing" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieTest "github.com/multiversx/mx-chain-go/testscommon/state" - "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/leavesRetriever" "github.com/stretchr/testify/assert" ) @@ -59,227 +53,19 @@ func TestNewLeavesRetriever(t *testing.T) { func TestLeavesRetriever_GetLeaves(t *testing.T) { t.Parallel() - t.Run("get leaves from new instance", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) - leaves, iteratorId, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves)) - assert.Equal(t, 32, len(iteratorId)) - }) - t.Run("get leaves from existing instance", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 10000000) - leaves1, iteratorId1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves1)) - assert.Equal(t, 32, len(iteratorId1)) - assert.Equal(t, 1, len(lr.GetIterators())) - assert.Equal(t, 1, len(lr.GetLruIteratorIDs())) - - leaves2, iteratorId2, err := lr.GetLeaves(10, rootHash, iteratorId1, context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves2)) - assert.Equal(t, 32, len(iteratorId2)) - assert.Equal(t, 2, len(lr.GetIterators())) - assert.Equal(t, 2, len(lr.GetLruIteratorIDs())) - - assert.NotEqual(t, leaves1, leaves2) - assert.NotEqual(t, iteratorId1, iteratorId2) - }) - t.Run("traversing a trie saves all iterator instances", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 10000000) - leaves1, iteratorId1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves1)) - assert.Equal(t, 32, len(iteratorId1)) - assert.Equal(t, 1, len(lr.GetIterators())) - assert.Equal(t, 1, len(lr.GetLruIteratorIDs())) - - leaves2, iteratorId2, err := lr.GetLeaves(10, rootHash, iteratorId1, context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves2)) - assert.Equal(t, 32, len(iteratorId2)) - assert.Equal(t, 2, len(lr.GetIterators())) - assert.Equal(t, 2, len(lr.GetLruIteratorIDs())) - - leaves3, iteratorId3, err := lr.GetLeaves(10, rootHash, iteratorId2, context.Background()) - assert.Nil(t, err) - assert.Equal(t, 5, len(leaves3)) - assert.Equal(t, 0, len(iteratorId3)) - assert.Equal(t, 2, len(lr.GetIterators())) - assert.Equal(t, 2, len(lr.GetLruIteratorIDs())) - }) - t.Run("iterator instances are evicted in a lru manner", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - maxSize := uint64(1000) - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - iterators := make([][]byte, 0) - _, id1, _ := lr.GetLeaves(5, rootHash, []byte(""), context.Background()) - iterators = append(iterators, id1) - _, id2, _ := lr.GetLeaves(5, rootHash, id1, context.Background()) - iterators = append(iterators, id2) - _, id3, _ := lr.GetLeaves(5, rootHash, id2, context.Background()) - iterators = append(iterators, id3) - - assert.Equal(t, 3, len(lr.GetIterators())) - for i, id := range lr.GetLruIteratorIDs() { - assert.Equal(t, iterators[i], id) - } - - _, id4, _ := lr.GetLeaves(5, rootHash, id3, context.Background()) - iterators = append(iterators, id4) - assert.Equal(t, 3, len(lr.GetIterators())) - for i, id := range lr.GetLruIteratorIDs() { - assert.Equal(t, iterators[i+1], id) - } - }) - t.Run("when an iterator instance is used, it is moved in the front of the eviction queue", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - maxSize := uint64(100000) - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - iterators := make([][]byte, 0) - _, id1, _ := lr.GetLeaves(5, rootHash, []byte(""), context.Background()) - iterators = append(iterators, id1) - leaves1, id2, _ := lr.GetLeaves(5, rootHash, id1, context.Background()) - iterators = append(iterators, id2) - _, id3, _ := lr.GetLeaves(5, rootHash, id2, context.Background()) - iterators = append(iterators, id3) - - assert.Equal(t, 3, len(lr.GetIterators())) - for i, id := range lr.GetLruIteratorIDs() { - assert.Equal(t, iterators[i], id) - } - - leaves2, id4, _ := lr.GetLeaves(5, rootHash, id1, context.Background()) - assert.Equal(t, leaves1, leaves2) - assert.Equal(t, id2, id4) - - assert.Equal(t, 3, len(lr.GetIterators())) - retrievedIterators := lr.GetLruIteratorIDs() - assert.Equal(t, 3, len(retrievedIterators)) - assert.Equal(t, id2, retrievedIterators[0]) - assert.Equal(t, id3, retrievedIterators[1]) - assert.Equal(t, id1, retrievedIterators[2]) - }) - t.Run("iterator not found", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - maxSize := uint64(100000) - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - leaves, id, err := lr.GetLeaves(5, rootHash, []byte("invalid iterator"), context.Background()) - assert.Nil(t, leaves) - assert.Equal(t, 0, len(id)) - assert.Equal(t, leavesRetriever.ErrIteratorNotFound, err) - }) - t.Run("max size reached on the first iteration", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetTrieWithData() - rootHash, _ := tr.RootHash() - maxSize := uint64(100) - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - leaves, id1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) - assert.Nil(t, err) - assert.Equal(t, 2, len(leaves)) - assert.Equal(t, 0, len(id1)) - assert.Equal(t, 0, len(lr.GetIterators())) - }) -} - -func TestLeavesRetriever_Concurrency(t *testing.T) { - t.Parallel() - - numTries := 10 - numLeaves := 1000 - tries := buildTries(numTries, numLeaves) - - rootHashes := make([][]byte, 0) - for _, tr := range tries { - rootHash, _ := tr.RootHash() - rootHashes = append(rootHashes, rootHash) - - } - - maxSize := uint64(1000000) - lr, _ := leavesRetriever.NewLeavesRetriever(tries[0].GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - - wg := &sync.WaitGroup{} - wg.Add(numTries) - for i := 0; i < numTries; i++ { - go retrieveTrieLeaves(t, lr, rootHashes[i], numLeaves, wg) - } - wg.Wait() -} - -func retrieveTrieLeaves(t *testing.T, lr common.TrieLeavesRetriever, rootHash []byte, numLeaves int, wg *sync.WaitGroup) { - iteratorId := []byte("") - numRetrievedLeaves := 0 - for { - leaves, newId, err := lr.GetLeaves(100, rootHash, iteratorId, context.Background()) - assert.Nil(t, err) - iteratorId = newId - numRetrievedLeaves += len(leaves) - fmt.Println("Retrieved leaves: ", numRetrievedLeaves, " for root hash: ", hex.EncodeToString(rootHash)) - if len(iteratorId) == 0 { - break - } - } - assert.Equal(t, numLeaves, numRetrievedLeaves) - wg.Done() -} - -func buildTries(numTries int, numLeaves int) []common.Trie { - tries := make([]common.Trie, 0) - tsm, marshaller, hasher := trieTest.GetDefaultTrieParameters() - for i := 0; i < numTries; i++ { - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) - addDataToTrie(tr, numLeaves) - tries = append(tries, tr) - } - return tries -} - -func addDataToTrie(tr common.Trie, numLeaves int) { - for i := 0; i < numLeaves; i++ { - _ = tr.Update(generateRandomByteArray(32), generateRandomByteArray(32)) - } - _ = tr.Commit() -} - -func generateRandomByteArray(size int) []byte { - r := make([]byte, size) - _, _ = rand.Read(r) - return r + tr := trieTest.GetNewTrie() + trieTest.AddDataToTrie(tr, 25) + rootHash, _ := tr.RootHash() + + lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) + leaves, newIteratorState, err := lr.GetLeaves(10, [][]byte{rootHash}, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 10, len(leaves)) + assert.Equal(t, 8, len(newIteratorState)) + + newLr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) + leaves, newIteratorState, err = newLr.GetLeaves(10, newIteratorState, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 10, len(leaves)) + assert.Equal(t, 3, len(newIteratorState)) } diff --git a/trie/mock/keyBuilderStub.go b/trie/mock/keyBuilderStub.go index 1c68a95384d..7fec902d542 100644 --- a/trie/mock/keyBuilderStub.go +++ b/trie/mock/keyBuilderStub.go @@ -6,6 +6,7 @@ import "github.com/multiversx/mx-chain-go/common" type KeyBuilderStub struct { BuildKeyCalled func(keyPart []byte) GetKeyCalled func() ([]byte, error) + GetRawKeyCalled func() []byte ShallowCloneCalled func() common.KeyBuilder DeepCloneCalled func() common.KeyBuilder SizeCalled func() uint @@ -27,6 +28,15 @@ func (stub *KeyBuilderStub) GetKey() ([]byte, error) { return []byte{}, nil } +// GetRawKey - +func (stub *KeyBuilderStub) GetRawKey() []byte { + if stub.GetRawKeyCalled != nil { + return stub.GetRawKeyCalled() + } + + return []byte{} +} + // ShallowClone - func (stub *KeyBuilderStub) ShallowClone() common.KeyBuilder { if stub.ShallowCloneCalled != nil { From 65d088e7d51fab25f4feee739d57f07659417d9a Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 15 Jan 2025 12:56:44 +0200 Subject: [PATCH 7/8] use maxSize to control the collected leaves total size --- testscommon/state/testTrie.go | 16 ----- .../dfsTrieIterator/dfsTrieIterator.go | 26 ++----- .../dfsTrieIterator/dfsTrieIterator_test.go | 68 +++++++------------ trie/leavesRetriever/export_test.go | 16 ----- trie/leavesRetriever/leavesRetriever.go | 3 - 5 files changed, 30 insertions(+), 99 deletions(-) delete mode 100644 trie/leavesRetriever/export_test.go diff --git a/testscommon/state/testTrie.go b/testscommon/state/testTrie.go index bc33a5e2b6b..8744009aa18 100644 --- a/testscommon/state/testTrie.go +++ b/testscommon/state/testTrie.go @@ -53,19 +53,3 @@ func AddDataToTrie(tr common.Trie, numLeaves int) { } _ = tr.Commit() } - -// GetTrieWithData returns a trie with some data. -// The added data builds a rootNode that is a branch with 2 leaves and 1 extension node which will have 4 leaves when traversed; -// this way the size of the iterator will be highest when the extension node is reached but 2 leaves will -// have already been retrieved -func GetTrieWithData() common.Trie { - tr := GetNewTrie() - _ = tr.Update([]byte("key1"), []byte("value1")) - _ = tr.Update([]byte("key2"), []byte("value2")) - _ = tr.Update([]byte("key13"), []byte("value3")) - _ = tr.Update([]byte("key23"), []byte("value4")) - _ = tr.Update([]byte("key33"), []byte("value4")) - _ = tr.Update([]byte("key43"), []byte("value4")) - _ = tr.Commit() - return tr -} diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go index b7dc34329a4..d79568cf3eb 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -18,7 +18,6 @@ type dfsIterator struct { db common.TrieStorageInteractor marshaller marshal.Marshalizer hasher hashing.Hasher - size uint64 } // NewIterator creates a new DFS iterator for the trie. @@ -41,17 +40,11 @@ func NewIterator(initialState [][]byte, db common.TrieStorageInteractor, marshal return nil, err } - size := uint64(0) - for _, node := range nextNodes { - size += node.Size() - } - return &dfsIterator{ nextNodes: nextNodes, db: db, marshaller: marshaller, hasher: hasher, - size: size, }, nil } @@ -92,13 +85,14 @@ func getIteratorStateFromNextNodes(nextNodes []common.TrieNodeData) [][]byte { // GetLeaves retrieves leaves from the trie. It stops either when the number of leaves is reached or the context is done. func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) { retrievedLeaves := make(map[string]string) + leavesSize := uint64(0) for { nextNodes := make([]common.TrieNodeData, 0) - if it.size >= maxSize { + if leavesSize >= maxSize { return retrievedLeaves, nil } - if len(retrievedLeaves) >= numLeaves { + if len(retrievedLeaves) >= numLeaves && numLeaves != 0 { return retrievedLeaves, nil } @@ -117,7 +111,6 @@ func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Cont return nil, err } - childrenSize := uint64(0) for _, childNode := range childrenNodes { if childNode.IsLeaf() { key, err := childNode.GetKeyBuilder().GetKey() @@ -125,16 +118,16 @@ func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Cont return nil, err } - retrievedLeaves[hex.EncodeToString(key)] = hex.EncodeToString(childNode.GetData()) + hexKey := hex.EncodeToString(key) + hexData := hex.EncodeToString(childNode.GetData()) + retrievedLeaves[hexKey] = hexData + leavesSize += uint64(len(hexKey) + len(hexData)) continue } nextNodes = append(nextNodes, childNode) - childrenSize += childNode.Size() } - it.size += childrenSize - it.size -= it.nextNodes[0].Size() it.nextNodes = append(nextNodes, it.nextNodes[1:]...) } } @@ -153,11 +146,6 @@ func (it *dfsIterator) FinishedIteration() bool { return len(it.nextNodes) == 0 } -// Size returns the size of the iterator. -func (it *dfsIterator) Size() uint64 { - return it.size -} - // IsInterfaceNil returns true if there is no value under the interface func (it *dfsIterator) IsInterfaceNil() bool { return it == nil diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go index 657dc302bd2..ede367a40ac 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go @@ -68,7 +68,6 @@ func TestNewIterator(t *testing.T) { iterator, err := NewIterator(initialState, db, marshaller, hasher) assert.Nil(t, err) - assert.Equal(t, uint64(80), iterator.size) assert.Equal(t, 2, len(iterator.nextNodes)) }) } @@ -139,20 +138,38 @@ func TestDfsIterator_GetLeaves(t *testing.T) { assert.Nil(t, err) assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) }) + t.Run("num leaves 0 iterates until maxSize reached", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(0, 200, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 8, len(trieData)) + assert.Equal(t, 8, len(iterator.nextNodes)) + }) t.Run("max size reached returns retrieved leaves and saves iterator context", func(t *testing.T) { t.Parallel() - tr := trieTest.GetTrieWithData() - expectedNumRetrievedLeaves := 2 + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - iteratorMaxSize := uint64(100) - trieData, err := iterator.GetLeaves(5, iteratorMaxSize, context.Background()) + iteratorMaxSize := uint64(200) + trieData, err := iterator.GetLeaves(numLeaves, iteratorMaxSize, context.Background()) assert.Nil(t, err) - assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + assert.Equal(t, 8, len(trieData)) + assert.Equal(t, 8, len(iterator.nextNodes)) }) t.Run("retrieve all leaves in multiple calls", func(t *testing.T) { t.Parallel() @@ -249,42 +266,3 @@ func TestDfsIterator_FinishedIteration(t *testing.T) { assert.Equal(t, numLeaves, numRetrievedLeaves) assert.True(t, iterator.FinishedIteration()) } - -func TestDfsIterator_Size(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - numLeaves := 25 - trieTest.AddDataToTrie(tr, numLeaves) - rootHash, _ := tr.RootHash() - _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - - iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - assert.Equal(t, uint64(32), iterator.Size()) // root hash - assert.False(t, iterator.FinishedIteration()) - - _, err := iterator.GetLeaves(5, maxSize, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(299), iterator.Size()) // 9 hashes + leaf key(3) + 8 x intermediary nodes key(8 * 1) - assert.False(t, iterator.FinishedIteration()) - - _, err = iterator.GetLeaves(5, maxSize, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(268), iterator.Size()) // 8 hashes + 2 x leaf keys(2 * 3) + 6 x intermediary nodes key(6*1) - assert.False(t, iterator.FinishedIteration()) - - _, err = iterator.GetLeaves(5, maxSize, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(165), iterator.Size()) // 5 hashes + 5 x intermediary nodes key(5*1) - assert.False(t, iterator.FinishedIteration()) - - _, err = iterator.GetLeaves(5, maxSize, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(101), iterator.Size()) // 3 hashes + leaf key(3) + 2 x intermediary nodes key(2*1) - assert.False(t, iterator.FinishedIteration()) - - _, err = iterator.GetLeaves(5, maxSize, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(0), iterator.Size()) - assert.True(t, iterator.FinishedIteration()) -} diff --git a/trie/leavesRetriever/export_test.go b/trie/leavesRetriever/export_test.go deleted file mode 100644 index 3d9ce098b3e..00000000000 --- a/trie/leavesRetriever/export_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package leavesRetriever - -//// GetIterators - -//func (lr *leavesRetriever) GetIterators() map[string]common.DfsIterator { -// return lr.iterators -//} -// -//// GetLruIteratorIDs - -//func (lr *leavesRetriever) GetLruIteratorIDs() [][]byte { -// return lr.lruIteratorIDs -//} -// -//// Size - -//func (lr *leavesRetriever) Size() uint64 { -// return lr.size -//} diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go index ee0937647c7..ce199d1caa4 100644 --- a/trie/leavesRetriever/leavesRetriever.go +++ b/trie/leavesRetriever/leavesRetriever.go @@ -8,11 +8,8 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/trie/leavesRetriever/dfsTrieIterator" - logger "github.com/multiversx/mx-chain-logger-go" ) -var log = logger.GetOrCreate("trie/leavesRetriever") - type leavesRetriever struct { db common.TrieStorageInteractor marshaller marshal.Marshalizer From 9246e1a94bfb7e70c6d2f190bfac02c706369f37 Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Wed, 15 Jan 2025 15:45:40 +0200 Subject: [PATCH 8/8] use a leafParser to transform trie entries --- common/interface.go | 5 +++-- node/node.go | 11 +++++++++- state/accounts/userAccount.go | 5 +++++ trie/leafNode.go | 7 ++++++- .../dfsTrieIterator/dfsTrieIterator.go | 11 +++++++--- .../dfsTrieIterator/dfsTrieIterator_test.go | 20 ++++++++++--------- .../disabledLeavesRetriever.go | 8 ++++++-- trie/leavesRetriever/leavesRetriever.go | 9 +++++++-- trie/leavesRetriever/leavesRetriever_test.go | 13 +++++++++--- .../trieNodeData/intermediaryNodeData.go | 6 ++++++ .../trieNodeData/leafNodeData.go | 10 +++++++++- .../trieNodeData/leafNodeData_test.go | 7 ++++--- 12 files changed, 85 insertions(+), 27 deletions(-) diff --git a/common/interface.go b/common/interface.go index a72be45b1f4..b9734a357fe 100644 --- a/common/interface.go +++ b/common/interface.go @@ -382,11 +382,12 @@ type TrieNodeData interface { GetData() []byte Size() uint64 IsLeaf() bool + GetVersion() core.TrieNodeVersion } // DfsIterator is used to iterate the trie nodes in a depth-first search manner type DfsIterator interface { - GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) + GetLeaves(numLeaves int, maxSize uint64, leavesParser TrieLeafParser, ctx context.Context) (map[string]string, error) GetIteratorState() [][]byte IsInterfaceNil() bool } @@ -394,6 +395,6 @@ type DfsIterator interface { // TrieLeavesRetriever is used to retrieve the leaves from the trie. If there is a saved checkpoint for the iterator id, // it will continue to iterate from the checkpoint. type TrieLeavesRetriever interface { - GetLeaves(numLeaves int, iteratorState [][]byte, ctx context.Context) (map[string]string, [][]byte, error) + GetLeaves(numLeaves int, iteratorState [][]byte, leavesParser TrieLeafParser, ctx context.Context) (map[string]string, [][]byte, error) IsInterfaceNil() bool } diff --git a/node/node.go b/node/node.go index 73c1cc88da8..dac30c060cd 100644 --- a/node/node.go +++ b/node/node.go @@ -308,6 +308,10 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, return mapToReturn, blockInfo, nil } +type userAccountWithLeavesParser interface { + GetLeavesParser() common.TrieLeafParser +} + // IterateKeys starts from the given iteratorState and returns the next key-value pairs and the new iteratorState func (n *Node) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) { userAccount, blockInfo, err := n.loadUserAccountHandlerByAddress(address, options) @@ -324,11 +328,16 @@ func (n *Node) IterateKeys(address string, numKeys uint, iteratorState [][]byte, return map[string]string{}, nil, blockInfo, nil } + account, ok := userAccount.(userAccountWithLeavesParser) + if !ok { + return nil, nil, api.BlockInfo{}, fmt.Errorf("cannot cast user account to userAccountWithLeavesParser") + } + if len(iteratorState) == 0 { iteratorState = append(iteratorState, userAccount.GetRootHash()) } - mapToReturn, newIteratorState, err := n.stateComponents.TrieLeavesRetriever().GetLeaves(int(numKeys), iteratorState, ctx) + mapToReturn, newIteratorState, err := n.stateComponents.TrieLeavesRetriever().GetLeaves(int(numKeys), iteratorState, account.GetLeavesParser(), ctx) if err != nil { return nil, nil, api.BlockInfo{}, err } diff --git a/state/accounts/userAccount.go b/state/accounts/userAccount.go index d626f024559..4d7d280fdcf 100644 --- a/state/accounts/userAccount.go +++ b/state/accounts/userAccount.go @@ -210,6 +210,11 @@ func (a *userAccount) AccountDataHandler() vmcommon.AccountDataHandler { return a.dataTrieInteractor } +// GetLeavesParser returns the leaves parser +func (a *userAccount) GetLeavesParser() common.TrieLeafParser { + return a.dataTrieLeafParser +} + // IsInterfaceNil returns true if there is no value under the interface func (a *userAccount) IsInterfaceNil() bool { return a == nil diff --git a/trie/leafNode.go b/trie/leafNode.go index 185160f8fc6..5cefe3754ff 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -568,10 +568,15 @@ func (ln *leafNode) getNodeData(keyBuilder common.KeyBuilder) ([]common.TrieNode return nil, fmt.Errorf("getNodeData error %w", err) } + version, err := ln.getVersion() + if err != nil { + return nil, err + } + data := make([]common.TrieNodeData, 1) clonedKeyBuilder := keyBuilder.DeepClone() clonedKeyBuilder.BuildKey(ln.Key) - nodeData, err := trieNodeData.NewLeafNodeData(clonedKeyBuilder, ln.Value) + nodeData, err := trieNodeData.NewLeafNodeData(clonedKeyBuilder, ln.Value, version) if err != nil { return nil, err } diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go index d79568cf3eb..8f2b903d1d8 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -83,7 +83,7 @@ func getIteratorStateFromNextNodes(nextNodes []common.TrieNodeData) [][]byte { } // GetLeaves retrieves leaves from the trie. It stops either when the number of leaves is reached or the context is done. -func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Context) (map[string]string, error) { +func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, leavesParser common.TrieLeafParser, ctx context.Context) (map[string]string, error) { retrievedLeaves := make(map[string]string) leavesSize := uint64(0) for { @@ -118,8 +118,13 @@ func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, ctx context.Cont return nil, err } - hexKey := hex.EncodeToString(key) - hexData := hex.EncodeToString(childNode.GetData()) + keyValHolder, err := leavesParser.ParseLeaf(key, childNode.GetData(), childNode.GetVersion()) + if err != nil { + return nil, err + } + + hexKey := hex.EncodeToString(keyValHolder.Key()) + hexData := hex.EncodeToString(keyValHolder.Value()) retrievedLeaves[hexKey] = hexData leavesSize += uint64(len(hexKey) + len(hexData)) continue diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go index ede367a40ac..80699761eee 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go @@ -7,6 +7,7 @@ import ( "math" "testing" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" @@ -103,7 +104,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator([][]byte{rootHash}, dbWrapper, marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, maxSize, ctx) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), ctx) assert.Nil(t, err) assert.Equal(t, expectedNumLeaves, len(trieData)) }) @@ -118,7 +119,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, maxSize, context.Background()) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) assert.Equal(t, numLeaves, len(trieData)) }) @@ -134,7 +135,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(17, maxSize, context.Background()) + trieData, err := iterator.GetLeaves(17, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) }) @@ -149,7 +150,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(0, 200, context.Background()) + trieData, err := iterator.GetLeaves(0, 200, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) assert.Equal(t, 8, len(trieData)) assert.Equal(t, 8, len(iterator.nextNodes)) @@ -166,7 +167,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) iteratorMaxSize := uint64(200) - trieData, err := iterator.GetLeaves(numLeaves, iteratorMaxSize, context.Background()) + trieData, err := iterator.GetLeaves(numLeaves, iteratorMaxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) assert.Equal(t, 8, len(trieData)) assert.Equal(t, 8, len(iterator.nextNodes)) @@ -184,7 +185,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { numRetrievedLeaves := 0 numIterations := 0 for numRetrievedLeaves < numLeaves { - trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -206,7 +207,7 @@ func TestDfsIterator_GetLeaves(t *testing.T) { _, marshaller, hasher := trieTest.GetDefaultTrieParameters() iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, maxSize, nil) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), nil) assert.Nil(t, err) assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) }) @@ -222,9 +223,10 @@ func TestDfsIterator_GetIteratorState(t *testing.T) { _ = tr.Commit() rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - leaves, err := iterator.GetLeaves(2, maxSize, context.Background()) + leaves, err := iterator.GetLeaves(2, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) assert.Equal(t, 2, len(leaves)) val, ok := leaves[hex.EncodeToString([]byte("doe"))] @@ -257,7 +259,7 @@ func TestDfsIterator_FinishedIteration(t *testing.T) { numRetrievedLeaves := 0 for numRetrievedLeaves < numLeaves { assert.False(t, iterator.FinishedIteration()) - trieData, err := iterator.GetLeaves(5, maxSize, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) diff --git a/trie/leavesRetriever/disabledLeavesRetriever.go b/trie/leavesRetriever/disabledLeavesRetriever.go index df087d7b47b..8d3d33720ba 100644 --- a/trie/leavesRetriever/disabledLeavesRetriever.go +++ b/trie/leavesRetriever/disabledLeavesRetriever.go @@ -1,6 +1,10 @@ package leavesRetriever -import "context" +import ( + "context" + + "github.com/multiversx/mx-chain-go/common" +) type disabledLeavesRetriever struct{} @@ -10,7 +14,7 @@ func NewDisabledLeavesRetriever() *disabledLeavesRetriever { } // GetLeaves returns an empty map and a nil byte slice for this implementation -func (dlr *disabledLeavesRetriever) GetLeaves(_ int, _ [][]byte, _ context.Context) (map[string]string, [][]byte, error) { +func (dlr *disabledLeavesRetriever) GetLeaves(_ int, _ [][]byte, _ common.TrieLeafParser, _ context.Context) (map[string]string, [][]byte, error) { return make(map[string]string), [][]byte{}, nil } diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go index ce199d1caa4..a2822975f43 100644 --- a/trie/leavesRetriever/leavesRetriever.go +++ b/trie/leavesRetriever/leavesRetriever.go @@ -2,6 +2,7 @@ package leavesRetriever import ( "context" + "fmt" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" @@ -39,13 +40,17 @@ func NewLeavesRetriever(db common.TrieStorageInteractor, marshaller marshal.Mars // GetLeaves retrieves leaves from the trie starting from the iterator state. It will also return the new iterator state // from which one can continue the iteration. -func (lr *leavesRetriever) GetLeaves(numLeaves int, iteratorState [][]byte, ctx context.Context) (map[string]string, [][]byte, error) { +func (lr *leavesRetriever) GetLeaves(numLeaves int, iteratorState [][]byte, leavesParser common.TrieLeafParser, ctx context.Context) (map[string]string, [][]byte, error) { + if check.IfNil(leavesParser) { + return nil, nil, fmt.Errorf("nil leaves parser") + } + iterator, err := dfsTrieIterator.NewIterator(iteratorState, lr.db, lr.marshaller, lr.hasher) if err != nil { return nil, nil, err } - leavesData, err := iterator.GetLeaves(numLeaves, lr.maxSize, ctx) + leavesData, err := iterator.GetLeaves(numLeaves, lr.maxSize, leavesParser, ctx) if err != nil { return nil, nil, err } diff --git a/trie/leavesRetriever/leavesRetriever_test.go b/trie/leavesRetriever/leavesRetriever_test.go index e8dade186b6..8fd376de439 100644 --- a/trie/leavesRetriever/leavesRetriever_test.go +++ b/trie/leavesRetriever/leavesRetriever_test.go @@ -4,12 +4,15 @@ import ( "context" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/keyValStorage" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieTest "github.com/multiversx/mx-chain-go/testscommon/state" + trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie/leavesRetriever" "github.com/stretchr/testify/assert" ) @@ -56,15 +59,19 @@ func TestLeavesRetriever_GetLeaves(t *testing.T) { tr := trieTest.GetNewTrie() trieTest.AddDataToTrie(tr, 25) rootHash, _ := tr.RootHash() - + leafParser := &trieMock.TrieLeafParserStub{ + ParseLeafCalled: func(key []byte, val []byte, version core.TrieNodeVersion) (core.KeyValueHolder, error) { + return keyValStorage.NewKeyValStorage(key, val), nil + }, + } lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) - leaves, newIteratorState, err := lr.GetLeaves(10, [][]byte{rootHash}, context.Background()) + leaves, newIteratorState, err := lr.GetLeaves(10, [][]byte{rootHash}, leafParser, context.Background()) assert.Nil(t, err) assert.Equal(t, 10, len(leaves)) assert.Equal(t, 8, len(newIteratorState)) newLr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) - leaves, newIteratorState, err = newLr.GetLeaves(10, newIteratorState, context.Background()) + leaves, newIteratorState, err = newLr.GetLeaves(10, newIteratorState, leafParser, context.Background()) assert.Nil(t, err) assert.Equal(t, 10, len(leaves)) assert.Equal(t, 3, len(newIteratorState)) diff --git a/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go b/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go index bd6d029868f..10a18a856fa 100644 --- a/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go +++ b/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go @@ -1,6 +1,7 @@ package trieNodeData import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" ) @@ -28,6 +29,11 @@ func (ind *intermediaryNodeData) IsLeaf() bool { return false } +// GetVersion returns NotSpecified +func (ind *intermediaryNodeData) GetVersion() core.TrieNodeVersion { + return core.NotSpecified +} + // IsInterfaceNil returns true if there is no value under the interface func (ind *intermediaryNodeData) IsInterfaceNil() bool { return ind == nil diff --git a/trie/leavesRetriever/trieNodeData/leafNodeData.go b/trie/leavesRetriever/trieNodeData/leafNodeData.go index 5f5ef574c70..08a80f2d3f8 100644 --- a/trie/leavesRetriever/trieNodeData/leafNodeData.go +++ b/trie/leavesRetriever/trieNodeData/leafNodeData.go @@ -1,16 +1,18 @@ package trieNodeData import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" ) type leafNodeData struct { *baseNodeData + version core.TrieNodeVersion } // NewLeafNodeData creates a new leaf node data -func NewLeafNodeData(key common.KeyBuilder, data []byte) (*leafNodeData, error) { +func NewLeafNodeData(key common.KeyBuilder, data []byte, version core.TrieNodeVersion) (*leafNodeData, error) { if check.IfNil(key) { return nil, ErrNilKeyBuilder } @@ -20,6 +22,7 @@ func NewLeafNodeData(key common.KeyBuilder, data []byte) (*leafNodeData, error) keyBuilder: key, data: data, }, + version: version, }, nil } @@ -28,6 +31,11 @@ func (lnd *leafNodeData) IsLeaf() bool { return true } +// GetVersion returns the version of the leaf +func (lnd *leafNodeData) GetVersion() core.TrieNodeVersion { + return lnd.version +} + // IsInterfaceNil returns true if there is no value under the interface func (lnd *leafNodeData) IsInterfaceNil() bool { return lnd == nil diff --git a/trie/leavesRetriever/trieNodeData/leafNodeData_test.go b/trie/leavesRetriever/trieNodeData/leafNodeData_test.go index 730ca9836a7..dc4b4ab656b 100644 --- a/trie/leavesRetriever/trieNodeData/leafNodeData_test.go +++ b/trie/leavesRetriever/trieNodeData/leafNodeData_test.go @@ -3,6 +3,7 @@ package trieNodeData import ( "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/stretchr/testify/assert" @@ -14,11 +15,11 @@ func TestNewLeafNodeData(t *testing.T) { var lnd *leafNodeData assert.True(t, check.IfNil(lnd)) - lnd, err := NewLeafNodeData(nil, nil) + lnd, err := NewLeafNodeData(nil, nil, core.NotSpecified) assert.Equal(t, ErrNilKeyBuilder, err) assert.True(t, check.IfNil(lnd)) - lnd, err = NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data")) + lnd, err = NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data"), core.NotSpecified) assert.Nil(t, err) assert.False(t, check.IfNil(lnd)) } @@ -26,6 +27,6 @@ func TestNewLeafNodeData(t *testing.T) { func TestLeafNodeData(t *testing.T) { t.Parallel() - lnd, _ := NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data")) + lnd, _ := NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data"), core.NotSpecified) assert.True(t, lnd.IsLeaf()) }