diff --git a/api/errors/errors.go b/api/errors/errors.go index 3f4e495b9d2..88ebeeec1c2 100644 --- a/api/errors/errors.go +++ b/api/errors/errors.go @@ -28,6 +28,9 @@ var ErrGetValueForKey = errors.New("get value for key error") // ErrGetKeyValuePairs signals an error in getting the key-value pairs of a key for an account var ErrGetKeyValuePairs = errors.New("get key-value pairs error") +// ErrIterateKeys signals an error in iterating over the keys of an account +var ErrIterateKeys = errors.New("iterate keys error") + // ErrGetESDTBalance signals an error in getting esdt balance for given address var ErrGetESDTBalance = errors.New("get esdt balance for account error") @@ -43,6 +46,12 @@ var ErrGetESDTNFTData = errors.New("get esdt nft data for account error") // ErrEmptyAddress signals that an empty address was provided var ErrEmptyAddress = errors.New("address is empty") +// ErrEmptyNumKeys signals that an empty numKeys was provided +var ErrEmptyNumKeys = errors.New("numKeys is empty") + +// ErrEmptyCheckpointId signals that an empty checkpointId was provided +var ErrEmptyCheckpointId = errors.New("checkpointId is empty") + // ErrEmptyKey signals that an empty key was provided var ErrEmptyKey = errors.New("key is empty") diff --git a/api/groups/addressGroup.go b/api/groups/addressGroup.go index 151b7f53372..a9a15957328 100644 --- a/api/groups/addressGroup.go +++ b/api/groups/addressGroup.go @@ -32,6 +32,7 @@ const ( getRegisteredNFTsPath = "/:address/registered-nfts" getESDTNFTDataPath = "/:address/nft/:tokenIdentifier/nonce/:nonce" getGuardianData = "/:address/guardian-data" + iterateKeysPath = "/iterate-keys" urlParamOnFinalBlock = "onFinalBlock" urlParamOnStartOfEpoch = "onStartOfEpoch" urlParamBlockNonce = "blockNonce" @@ -55,6 +56,7 @@ type addressFacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) IsInterfaceNil() bool @@ -134,6 +136,11 @@ func NewAddressGroup(facade addressFacadeHandler) (*addressGroup, error) { Method: http.MethodGet, Handler: ag.getKeyValuePairs, }, + { + Path: iterateKeysPath, + Method: http.MethodPost, + Handler: ag.iterateKeys, + }, { Path: getESDTBalancePath, Method: http.MethodGet, @@ -327,7 +334,7 @@ func (ag *addressGroup) getGuardianData(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"guardianData": guardianData, "blockInfo": blockInfo}) } -// addressGroup returns all the key-value pairs for the given address +// getKeyValuePairs returns all the key-value pairs for the given address func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { addr, options, err := extractBaseParams(c) if err != nil { @@ -344,6 +351,47 @@ func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"pairs": value, "blockInfo": blockInfo}) } +// IterateKeysRequest defines the request structure for iterating keys +type IterateKeysRequest struct { + Address string `json:"address"` + NumKeys uint `json:"numKeys"` + IteratorState [][]byte `json:"iteratorState"` +} + +// iterateKeys iterates keys for the given address +func (ag *addressGroup) iterateKeys(c *gin.Context) { + var iterateKeysRequest = &IterateKeysRequest{} + err := c.ShouldBindJSON(&iterateKeysRequest) + if err != nil { + shared.RespondWithValidationError(c, errors.ErrValidation, err) + return + } + + if len(iterateKeysRequest.Address) == 0 { + shared.RespondWithValidationError(c, errors.ErrValidation, errors.ErrEmptyAddress) + return + } + + options, err := extractAccountQueryOptions(c) + if err != nil { + shared.RespondWithValidationError(c, errors.ErrIterateKeys, err) + return + } + + value, newIteratorState, blockInfo, err := ag.getFacade().IterateKeys( + iterateKeysRequest.Address, + iterateKeysRequest.NumKeys, + iterateKeysRequest.IteratorState, + options, + ) + if err != nil { + shared.RespondWithInternalError(c, errors.ErrIterateKeys, err) + return + } + + shared.RespondWithSuccess(c, gin.H{"pairs": value, "newIteratorState": newIteratorState, "blockInfo": blockInfo}) +} + // getESDTBalance returns the balance for the given address and esdt token func (ag *addressGroup) getESDTBalance(c *gin.Context) { addr, tokenIdentifier, options, err := extractGetESDTBalanceParams(c) diff --git a/api/groups/addressGroup_test.go b/api/groups/addressGroup_test.go index bb19bb81d2c..03f4a1c5088 100644 --- a/api/groups/addressGroup_test.go +++ b/api/groups/addressGroup_test.go @@ -125,6 +125,16 @@ type keyValuePairsResponse struct { Code string } +type iterateKeysResponseData struct { + Pairs map[string]string `json:"pairs"` + NewIteratorState [][]byte `json:"newIteratorState"` +} +type iterateKeysResponse struct { + Data iterateKeysResponseData `json:"data"` + Error string `json:"error"` + Code string +} + type esdtRolesResponseData struct { Roles map[string][]string `json:"roles"` } @@ -662,6 +672,106 @@ func TestAddressGroup_getKeyValuePairs(t *testing.T) { }) } +func TestAddressGroup_iterateKeys(t *testing.T) { + t.Parallel() + + t.Run("invalid body should error", + testErrorScenario("/address/iterate-keys", "POST", bytes.NewBuffer([]byte("invalid body")), + formatExpectedErr(apiErrors.ErrValidation, errors.New("invalid character 'i' looking for beginning of value")))) + t.Run("empty address should error", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "", + } + bodyBytes, _ := json.Marshal(body) + testAddressGroup( + t, + &mock.FacadeStub{}, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusBadRequest, + formatExpectedErr(apiErrors.ErrValidation, apiErrors.ErrEmptyAddress), + ) + }) + t.Run("invalid query options should error", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "erd1", + } + bodyBytes, _ := json.Marshal(body) + testAddressGroup( + t, + &mock.FacadeStub{}, + "/address/iterate-keys?blockNonce=not-uint64", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusBadRequest, + formatExpectedErr(apiErrors.ErrIterateKeys, apiErrors.ErrBadUrlParams), + ) + }) + t.Run("with node fail should err", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "erd1", + } + bodyBytes, _ := json.Marshal(body) + facade := &mock.FacadeStub{ + IterateKeysCalled: func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + return nil, nil, api.BlockInfo{}, expectedErr + }, + } + testAddressGroup( + t, + facade, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusInternalServerError, + formatExpectedErr(apiErrors.ErrIterateKeys, expectedErr), + ) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + pairs := map[string]string{ + "k1": "v1", + "k2": "v2", + } + + body := &groups.IterateKeysRequest{ + Address: "erd1", + NumKeys: 10, + IteratorState: [][]byte{[]byte("starting"), []byte("state")}, + } + newIteratorState := [][]byte{[]byte("new"), []byte("state")} + bodyBytes, _ := json.Marshal(body) + facade := &mock.FacadeStub{ + IterateKeysCalled: func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + assert.Equal(t, body.Address, address) + assert.Equal(t, body.NumKeys, numKeys) + assert.Equal(t, body.IteratorState, iteratorState) + return pairs, newIteratorState, api.BlockInfo{}, nil + }, + } + + response := &iterateKeysResponse{} + loadAddressGroupResponse( + t, + facade, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), + response, + ) + assert.Equal(t, pairs, response.Data.Pairs) + assert.Equal(t, newIteratorState, response.Data.NewIteratorState) + }) +} + func TestAddressGroup_getESDTBalance(t *testing.T) { t.Parallel() @@ -1143,6 +1253,7 @@ func getAddressRoutesConfig() config.ApiRoutesConfig { {Name: "/:address/username", Open: true}, {Name: "/:address/code-hash", Open: true}, {Name: "/:address/keys", Open: true}, + {Name: "/iterate-keys", Open: true}, {Name: "/:address/key/:key", Open: true}, {Name: "/:address/esdt", Open: true}, {Name: "/:address/esdts/roles", Open: true}, diff --git a/api/mock/facadeStub.go b/api/mock/facadeStub.go index 62de2febc81..94bc0551c76 100644 --- a/api/mock/facadeStub.go +++ b/api/mock/facadeStub.go @@ -49,6 +49,7 @@ type FacadeStub struct { GetUsernameCalled func(address string, options api.AccountQueryOptions) (string, api.BlockInfo, error) GetCodeHashCalled func(address string, options api.AccountQueryOptions) ([]byte, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeysCalled func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) SimulateTransactionExecutionHandler func(tx *transaction.Transaction) (*txSimData.SimulationResultsWithVMOutput, error) GetESDTDataCalled func(address string, key string, nonce uint64, options api.AccountQueryOptions) (*esdt.ESDigitalToken, api.BlockInfo, error) GetAllESDTTokensCalled func(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) @@ -241,6 +242,15 @@ func (f *FacadeStub) GetKeyValuePairs(address string, options api.AccountQueryOp return nil, api.BlockInfo{}, nil } +// IterateKeys - +func (f *FacadeStub) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + if f.IterateKeysCalled != nil { + return f.IterateKeysCalled(address, numKeys, iteratorState, options) + } + + return nil, nil, api.BlockInfo{}, nil +} + // GetGuardianData - func (f *FacadeStub) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) { if f.GetGuardianDataCalled != nil { diff --git a/api/shared/interface.go b/api/shared/interface.go index 206cea6ee30..adedd6642af 100644 --- a/api/shared/interface.go +++ b/api/shared/interface.go @@ -74,6 +74,7 @@ type FacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*api.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*api.Block, error) diff --git a/cmd/node/config/api.toml b/cmd/node/config/api.toml index fcf9cf7fc0b..378b4157e47 100644 --- a/cmd/node/config/api.toml +++ b/cmd/node/config/api.toml @@ -79,6 +79,9 @@ # /address/:address/keys will return all the key-value pairs of a given account { Name = "/:address/keys", Open = true }, + # address//iterate-keys will return the given num of key-value pairs for the given account. The iteration will start from the given starting state + { Name = "/iterate-keys", Open = true }, + # /address/:address/key/:key will return the value of a key for a given account { Name = "/:address/key/:key", Open = true }, diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 6e1205d5f7e..7d0ffeb57fe 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -663,6 +663,10 @@ MaxPeerTrieLevelInMemory = 5 StateStatisticsEnabled = false +[TrieLeavesRetrieverConfig] + Enabled = false + MaxSizeInBytes = 10485760 #10MB + [BlockSizeThrottleConfig] MinSizeInBytes = 104857 # 104857 is 10% from 1MB MaxSizeInBytes = 943718 # 943718 is 90% from 1MB diff --git a/common/interface.go b/common/interface.go index 696d4b0182c..b9734a357fe 100644 --- a/common/interface.go +++ b/common/interface.go @@ -80,6 +80,7 @@ type StorageMarker interface { type KeyBuilder interface { BuildKey(keyPart []byte) GetKey() ([]byte, error) + GetRawKey() []byte DeepClone() KeyBuilder ShallowClone() KeyBuilder Size() uint @@ -381,21 +382,19 @@ type TrieNodeData interface { GetData() []byte Size() uint64 IsLeaf() bool + GetVersion() core.TrieNodeVersion } // DfsIterator is used to iterate the trie nodes in a depth-first search manner type DfsIterator interface { - GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error) - GetIteratorId() []byte - Clone() DfsIterator - FinishedIteration() bool - Size() uint64 + GetLeaves(numLeaves int, maxSize uint64, leavesParser TrieLeafParser, ctx context.Context) (map[string]string, error) + GetIteratorState() [][]byte IsInterfaceNil() bool } // TrieLeavesRetriever is used to retrieve the leaves from the trie. If there is a saved checkpoint for the iterator id, // it will continue to iterate from the checkpoint. type TrieLeavesRetriever interface { - GetLeaves(numLeaves int, rootHash []byte, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) + GetLeaves(numLeaves int, iteratorState [][]byte, leavesParser TrieLeafParser, ctx context.Context) (map[string]string, [][]byte, error) IsInterfaceNil() bool } diff --git a/config/config.go b/config/config.go index 49ef257c341..a9956dec5b2 100644 --- a/config/config.go +++ b/config/config.go @@ -161,12 +161,13 @@ type Config struct { BootstrapStorage StorageConfig MetaBlockStorage StorageConfig - AccountsTrieStorage StorageConfig - PeerAccountsTrieStorage StorageConfig - EvictionWaitingList EvictionWaitingListConfig - StateTriesConfig StateTriesConfig - TrieStorageManagerConfig TrieStorageManagerConfig - BadBlocksCache CacheConfig + AccountsTrieStorage StorageConfig + PeerAccountsTrieStorage StorageConfig + EvictionWaitingList EvictionWaitingListConfig + StateTriesConfig StateTriesConfig + TrieStorageManagerConfig TrieStorageManagerConfig + TrieLeavesRetrieverConfig TrieLeavesRetrieverConfig + BadBlocksCache CacheConfig TxBlockBodyDataPool CacheConfig PeerBlockBodyDataPool CacheConfig @@ -640,3 +641,9 @@ type PoolsCleanersConfig struct { type RedundancyConfig struct { MaxRoundsOfInactivityAccepted int } + +// TrieLeavesRetrieverConfig represents the config options to be used when setting up the trie leaves retriever +type TrieLeavesRetrieverConfig struct { + Enabled bool + MaxSizeInBytes uint64 +} diff --git a/errors/errors.go b/errors/errors.go index dd475327876..8071fffc219 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -598,3 +598,6 @@ var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") // ErrNilEpochSystemSCProcessor defines the error for setting a nil EpochSystemSCProcessor var ErrNilEpochSystemSCProcessor = errors.New("nil epoch system SC processor") + +// ErrNilTrieLeavesRetriever defines the error for setting a nil TrieLeavesRetriever +var ErrNilTrieLeavesRetriever = errors.New("nil trie leaves retriever") diff --git a/facade/initial/initialNodeFacade.go b/facade/initial/initialNodeFacade.go index d6043dbcd62..ea9268d0bde 100644 --- a/facade/initial/initialNodeFacade.go +++ b/facade/initial/initialNodeFacade.go @@ -346,6 +346,11 @@ func (inf *initialNodeFacade) GetKeyValuePairs(_ string, _ api.AccountQueryOptio return nil, api.BlockInfo{}, errNodeStarting } +// IterateKeys returns error +func (inf *initialNodeFacade) IterateKeys(_ string, _ uint, _ [][]byte, _ api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + return nil, nil, api.BlockInfo{}, errNodeStarting +} + // GetGuardianData returns error func (inf *initialNodeFacade) GetGuardianData(_ string, _ api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) { return api.GuardianData{}, api.BlockInfo{}, errNodeStarting diff --git a/facade/interface.go b/facade/interface.go index 309f6c98d6f..2dfa8b503bd 100644 --- a/facade/interface.go +++ b/facade/interface.go @@ -41,6 +41,9 @@ type NodeHandler interface { // GetKeyValuePairs returns the key-value pairs under a given address GetKeyValuePairs(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) + // IterateKeys returns the key-value pairs under a given address starting from a given state + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) + // GetAllIssuedESDTs returns all the issued esdt tokens from esdt system smart contract GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]string, error) diff --git a/facade/mock/nodeStub.go b/facade/mock/nodeStub.go index 1e779e0ebce..e7b2817a32e 100644 --- a/facade/mock/nodeStub.go +++ b/facade/mock/nodeStub.go @@ -49,6 +49,7 @@ type NodeStub struct { GetESDTsWithRoleCalled func(address string, role string, options api.AccountQueryOptions, ctx context.Context) ([]string, api.BlockInfo, error) GetESDTsRolesCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string][]string, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) + IterateKeysCalled func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) GetAllIssuedESDTsCalled func(tokenType string, ctx context.Context) ([]string, error) GetProofCalled func(rootHash string, key string) (*common.GetProofResponse, error) GetProofDataTrieCalled func(rootHash string, address string, key string) (*common.GetProofResponse, *common.GetProofResponse, error) @@ -112,6 +113,15 @@ func (ns *NodeStub) GetKeyValuePairs(address string, options api.AccountQueryOpt return nil, api.BlockInfo{}, nil } +// IterateKeys - +func (ns *NodeStub) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) { + if ns.IterateKeysCalled != nil { + return ns.IterateKeysCalled(address, numKeys, iteratorState, options, ctx) + } + + return nil, nil, api.BlockInfo{}, nil +} + // GetValueForKey - func (ns *NodeStub) GetValueForKey(address string, key string, options api.AccountQueryOptions) (string, api.BlockInfo, error) { if ns.GetValueForKeyCalled != nil { diff --git a/facade/nodeFacade.go b/facade/nodeFacade.go index c3a7f290edf..e516b506b52 100644 --- a/facade/nodeFacade.go +++ b/facade/nodeFacade.go @@ -229,6 +229,14 @@ func (nf *nodeFacade) GetKeyValuePairs(address string, options apiData.AccountQu return nf.node.GetKeyValuePairs(address, options, ctx) } +// IterateKeys starts from the given iteratorState and returns the next key-value pairs and the new iteratorState +func (nf *nodeFacade) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options apiData.AccountQueryOptions) (map[string]string, [][]byte, apiData.BlockInfo, error) { + ctx, cancel := nf.getContextForApiTrieRangeOperations() + defer cancel() + + return nf.node.IterateKeys(address, numKeys, iteratorState, options, ctx) +} + // GetGuardianData returns the guardian data for the provided address func (nf *nodeFacade) GetGuardianData(address string, options apiData.AccountQueryOptions) (apiData.GuardianData, apiData.BlockInfo, error) { return nf.node.GetGuardianData(address, options) diff --git a/factory/interface.go b/factory/interface.go index 0f1c237d0d9..a452d640320 100644 --- a/factory/interface.go +++ b/factory/interface.go @@ -335,6 +335,7 @@ type StateComponentsHolder interface { TriesContainer() common.TriesHolder TrieStorageManagers() map[string]common.StorageManager MissingTrieNodesNotifier() common.MissingTrieNodesNotifier + TrieLeavesRetriever() common.TrieLeavesRetriever Close() error IsInterfaceNil() bool } diff --git a/factory/mock/stateComponentsHolderStub.go b/factory/mock/stateComponentsHolderStub.go index c851fdc6dac..e6b6b6b86cf 100644 --- a/factory/mock/stateComponentsHolderStub.go +++ b/factory/mock/stateComponentsHolderStub.go @@ -14,6 +14,7 @@ type StateComponentsHolderStub struct { TriesContainerCalled func() common.TriesHolder TrieStorageManagersCalled func() map[string]common.StorageManager MissingTrieNodesNotifierCalled func() common.MissingTrieNodesNotifier + TrieLeavesRetrieverCalled func() common.TrieLeavesRetriever } // PeerAccounts - @@ -79,6 +80,14 @@ func (s *StateComponentsHolderStub) MissingTrieNodesNotifier() common.MissingTri return nil } +// TrieLeavesRetriever - +func (s *StateComponentsHolderStub) TrieLeavesRetriever() common.TrieLeavesRetriever { + if s.TrieLeavesRetrieverCalled != nil { + return s.TrieLeavesRetrieverCalled() + } + return nil +} + // Close - func (s *StateComponentsHolderStub) Close() error { return nil diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index 8da3251e230..e09aae7b1c9 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/state/syncer" trieFactory "github.com/multiversx/mx-chain-go/trie/factory" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever" ) // TODO: merge this with data components @@ -53,6 +54,7 @@ type stateComponents struct { triesContainer common.TriesHolder trieStorageManagers map[string]common.StorageManager missingTrieNodesNotifier common.MissingTrieNodesNotifier + trieLeavesRetriever common.TrieLeavesRetriever } // NewStateComponentsFactory will return a new instance of stateComponentsFactory @@ -100,6 +102,11 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { return nil, err } + trieLeavesRetriever, err := scf.createTrieLeavesRetriever(trieStorageManagers[dataRetriever.UserAccountsUnit.String()]) + if err != nil { + return nil, err + } + return &stateComponents{ peerAccounts: peerAdapter, accountsAdapter: accountsAdapter, @@ -108,9 +115,23 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { triesContainer: triesContainer, trieStorageManagers: trieStorageManagers, missingTrieNodesNotifier: syncer.NewMissingTrieNodesNotifier(), + trieLeavesRetriever: trieLeavesRetriever, }, nil } +func (scf *stateComponentsFactory) createTrieLeavesRetriever(trieStorage common.TrieStorageInteractor) (common.TrieLeavesRetriever, error) { + if !scf.config.TrieLeavesRetrieverConfig.Enabled { + return leavesRetriever.NewDisabledLeavesRetriever(), nil + } + + return leavesRetriever.NewLeavesRetriever( + trieStorage, + scf.core.InternalMarshalizer(), + scf.core.Hasher(), + scf.config.TrieLeavesRetrieverConfig.MaxSizeInBytes, + ) +} + func (scf *stateComponentsFactory) createSnapshotManager( accountFactory state.AccountFactory, stateMetrics state.StateMetrics, diff --git a/factory/state/stateComponentsHandler.go b/factory/state/stateComponentsHandler.go index 78271a28ffe..e84c1f8b3b5 100644 --- a/factory/state/stateComponentsHandler.go +++ b/factory/state/stateComponentsHandler.go @@ -93,6 +93,9 @@ func (msc *managedStateComponents) CheckSubcomponents() error { if check.IfNil(msc.missingTrieNodesNotifier) { return errors.ErrNilMissingTrieNodesNotifier } + if check.IfNil(msc.trieLeavesRetriever) { + return errors.ErrNilTrieLeavesRetriever + } return nil } @@ -214,6 +217,18 @@ func (msc *managedStateComponents) MissingTrieNodesNotifier() common.MissingTrie return msc.stateComponents.missingTrieNodesNotifier } +// TrieLeavesRetriever returns the trie leaves retriever +func (msc *managedStateComponents) TrieLeavesRetriever() common.TrieLeavesRetriever { + msc.mutStateComponents.RLock() + defer msc.mutStateComponents.RUnlock() + + if msc.stateComponents == nil { + return nil + } + + return msc.stateComponents.trieLeavesRetriever +} + // IsInterfaceNil returns true if the interface is nil func (msc *managedStateComponents) IsInterfaceNil() bool { return msc == nil diff --git a/integrationTests/interface.go b/integrationTests/interface.go index ad90ffbb6a3..2b78eec1f0f 100644 --- a/integrationTests/interface.go +++ b/integrationTests/interface.go @@ -69,6 +69,7 @@ type Facade interface { GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetESDTsRoles(address string, options api.AccountQueryOptions) (map[string][]string, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*dataApi.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*dataApi.Block, error) diff --git a/node/chainSimulator/components/stateComponents.go b/node/chainSimulator/components/stateComponents.go index b3fddf55f40..998263a8d7a 100644 --- a/node/chainSimulator/components/stateComponents.go +++ b/node/chainSimulator/components/stateComponents.go @@ -29,6 +29,7 @@ type stateComponentsHolder struct { triesContainer common.TriesHolder triesStorageManager map[string]common.StorageManager missingTrieNodesNotifier common.MissingTrieNodesNotifier + trieLeavesRetriever common.TrieLeavesRetriever stateComponentsCloser io.Closer } @@ -70,6 +71,7 @@ func CreateStateComponents(args ArgsStateComponents) (*stateComponentsHolder, er triesContainer: stateComp.TriesContainer(), triesStorageManager: stateComp.TrieStorageManagers(), missingTrieNodesNotifier: stateComp.MissingTrieNodesNotifier(), + trieLeavesRetriever: stateComp.TrieLeavesRetriever(), stateComponentsCloser: stateComp, }, nil } @@ -109,6 +111,11 @@ func (s *stateComponentsHolder) MissingTrieNodesNotifier() common.MissingTrieNod return s.missingTrieNodesNotifier } +// TrieLeavesRetriever will return the trie leaves retriever +func (s *stateComponentsHolder) TrieLeavesRetriever() common.TrieLeavesRetriever { + return s.trieLeavesRetriever +} + // Close will close the state components func (s *stateComponentsHolder) Close() error { return s.stateComponentsCloser.Close() diff --git a/node/node.go b/node/node.go index a652e80be60..dac30c060cd 100644 --- a/node/node.go +++ b/node/node.go @@ -308,6 +308,43 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, return mapToReturn, blockInfo, nil } +type userAccountWithLeavesParser interface { + GetLeavesParser() common.TrieLeafParser +} + +// IterateKeys starts from the given iteratorState and returns the next key-value pairs and the new iteratorState +func (n *Node) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) { + userAccount, blockInfo, err := n.loadUserAccountHandlerByAddress(address, options) + if err != nil { + adaptedBlockInfo, isEmptyAccount := extractBlockInfoIfNewAccount(err) + if isEmptyAccount { + return make(map[string]string), nil, adaptedBlockInfo, nil + } + + return nil, nil, api.BlockInfo{}, err + } + + if check.IfNil(userAccount.DataTrie()) { + return map[string]string{}, nil, blockInfo, nil + } + + account, ok := userAccount.(userAccountWithLeavesParser) + if !ok { + return nil, nil, api.BlockInfo{}, fmt.Errorf("cannot cast user account to userAccountWithLeavesParser") + } + + if len(iteratorState) == 0 { + iteratorState = append(iteratorState, userAccount.GetRootHash()) + } + + mapToReturn, newIteratorState, err := n.stateComponents.TrieLeavesRetriever().GetLeaves(int(numKeys), iteratorState, account.GetLeavesParser(), ctx) + if err != nil { + return nil, nil, api.BlockInfo{}, err + } + + return mapToReturn, newIteratorState, blockInfo, nil +} + func (n *Node) getKeys(userAccount state.UserAccountHandler, ctx context.Context) (map[string]string, error) { chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), diff --git a/state/accounts/userAccount.go b/state/accounts/userAccount.go index d626f024559..4d7d280fdcf 100644 --- a/state/accounts/userAccount.go +++ b/state/accounts/userAccount.go @@ -210,6 +210,11 @@ func (a *userAccount) AccountDataHandler() vmcommon.AccountDataHandler { return a.dataTrieInteractor } +// GetLeavesParser returns the leaves parser +func (a *userAccount) GetLeavesParser() common.TrieLeafParser { + return a.dataTrieLeafParser +} + // IsInterfaceNil returns true if there is no value under the interface func (a *userAccount) IsInterfaceNil() bool { return a == nil diff --git a/testscommon/factory/stateComponentsMock.go b/testscommon/factory/stateComponentsMock.go index 5aa541dffa0..0adb3f3bc10 100644 --- a/testscommon/factory/stateComponentsMock.go +++ b/testscommon/factory/stateComponentsMock.go @@ -16,6 +16,7 @@ type StateComponentsMock struct { Tries common.TriesHolder StorageManagers map[string]common.StorageManager MissingNodesNotifier common.MissingTrieNodesNotifier + LeavesRetriever common.TrieLeavesRetriever } // NewStateComponentsMockFromRealComponent - @@ -28,6 +29,7 @@ func NewStateComponentsMockFromRealComponent(stateComponents factory.StateCompon Tries: stateComponents.TriesContainer(), StorageManagers: stateComponents.TrieStorageManagers(), MissingNodesNotifier: stateComponents.MissingTrieNodesNotifier(), + LeavesRetriever: stateComponents.TrieLeavesRetriever(), } } @@ -89,6 +91,11 @@ func (scm *StateComponentsMock) MissingTrieNodesNotifier() common.MissingTrieNod return scm.MissingNodesNotifier } +// TrieLeavesRetriever - +func (scm *StateComponentsMock) TrieLeavesRetriever() common.TrieLeavesRetriever { + return scm.LeavesRetriever +} + // IsInterfaceNil - func (scm *StateComponentsMock) IsInterfaceNil() bool { return scm == nil diff --git a/trie/errors.go b/trie/errors.go index 9cc2588e501..a879fd6c94c 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -123,3 +123,9 @@ var ErrNilTrieLeafParser = errors.New("nil trie leaf parser") // ErrInvalidNodeVersion signals that an invalid node version has been provided var ErrInvalidNodeVersion = errors.New("invalid node version provided") + +// ErrEmptyInitialIteratorState signals that an empty initial iterator state was provided +var ErrEmptyInitialIteratorState = errors.New("empty initial iterator state") + +// ErrInvalidIteratorState signals that an invalid iterator state was provided +var ErrInvalidIteratorState = errors.New("invalid iterator state") diff --git a/trie/keyBuilder/disabledKeyBuilder.go b/trie/keyBuilder/disabledKeyBuilder.go index 78a0350aa26..71c2022d372 100644 --- a/trie/keyBuilder/disabledKeyBuilder.go +++ b/trie/keyBuilder/disabledKeyBuilder.go @@ -22,6 +22,11 @@ func (dkb *disabledKeyBuilder) GetKey() ([]byte, error) { return []byte{}, nil } +// GetRawKey returns an empty byte array for this implementation +func (dkb *disabledKeyBuilder) GetRawKey() []byte { + return []byte{} +} + // ShallowClone returns a new disabled key builder func (dkb *disabledKeyBuilder) ShallowClone() common.KeyBuilder { return &disabledKeyBuilder{} diff --git a/trie/keyBuilder/keyBuilder.go b/trie/keyBuilder/keyBuilder.go index a455fd5c9e6..c1b7f78f62a 100644 --- a/trie/keyBuilder/keyBuilder.go +++ b/trie/keyBuilder/keyBuilder.go @@ -34,6 +34,11 @@ func (kb *keyBuilder) GetKey() ([]byte, error) { return hexToTrieKeyBytes(kb.key) } +// GetRawKey returns the key as it is, without transforming it +func (kb *keyBuilder) GetRawKey() []byte { + return kb.key +} + // ShallowClone returns a new KeyBuilder with the same key. The key slice points to the same memory location. func (kb *keyBuilder) ShallowClone() common.KeyBuilder { return &keyBuilder{ diff --git a/trie/leafNode.go b/trie/leafNode.go index 185160f8fc6..5cefe3754ff 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -568,10 +568,15 @@ func (ln *leafNode) getNodeData(keyBuilder common.KeyBuilder) ([]common.TrieNode return nil, fmt.Errorf("getNodeData error %w", err) } + version, err := ln.getVersion() + if err != nil { + return nil, err + } + data := make([]common.TrieNodeData, 1) clonedKeyBuilder := keyBuilder.DeepClone() clonedKeyBuilder.BuildKey(ln.Key) - nodeData, err := trieNodeData.NewLeafNodeData(clonedKeyBuilder, ln.Value) + nodeData, err := trieNodeData.NewLeafNodeData(clonedKeyBuilder, ln.Value, version) if err != nil { return nil, err } diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go index 5b47e2c1dd2..8f2b903d1d8 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -2,6 +2,7 @@ package dfsTrieIterator import ( "context" + "encoding/hex" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" @@ -9,19 +10,18 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever/trieNodeData" ) type dfsIterator struct { nextNodes []common.TrieNodeData - rootHash []byte db common.TrieStorageInteractor marshaller marshal.Marshalizer hasher hashing.Hasher - size uint64 } // NewIterator creates a new DFS iterator for the trie. -func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller marshal.Marshalizer, hasher hashing.Hasher) (*dfsIterator, error) { +func NewIterator(initialState [][]byte, db common.TrieStorageInteractor, marshaller marshal.Marshalizer, hasher hashing.Hasher) (*dfsIterator, error) { if check.IfNil(db) { return nil, trie.ErrNilDatabase } @@ -31,34 +31,68 @@ func NewIterator(rootHash []byte, db common.TrieStorageInteractor, marshaller ma if check.IfNil(hasher) { return nil, trie.ErrNilHasher } + if len(initialState) == 0 { + return nil, trie.ErrEmptyInitialIteratorState + } - data, err := trie.GetNodeDataFromHash(rootHash, keyBuilder.NewKeyBuilder(), db, marshaller, hasher) + nextNodes, err := getNextNodesFromInitialState(initialState, uint(hasher.Size())) if err != nil { return nil, err } - size := uint64(0) - for _, node := range data { - size += node.Size() - } - return &dfsIterator{ - nextNodes: data, - rootHash: rootHash, + nextNodes: nextNodes, db: db, marshaller: marshaller, hasher: hasher, - size: size, }, nil } +func getNextNodesFromInitialState(initialState [][]byte, hashSize uint) ([]common.TrieNodeData, error) { + nextNodes := make([]common.TrieNodeData, len(initialState)) + for i, state := range initialState { + if len(state) < int(hashSize) { + return nil, trie.ErrInvalidIteratorState + } + + nodeHash := state[:hashSize] + key := state[hashSize:] + + kb := keyBuilder.NewKeyBuilder() + kb.BuildKey(key) + nodeData, err := trieNodeData.NewIntermediaryNodeData(kb, nodeHash) + if err != nil { + return nil, err + } + nextNodes[i] = nodeData + } + + return nextNodes, nil +} + +func getIteratorStateFromNextNodes(nextNodes []common.TrieNodeData) [][]byte { + iteratorState := make([][]byte, len(nextNodes)) + for i, node := range nextNodes { + nodeHash := node.GetData() + key := node.GetKeyBuilder().GetRawKey() + + iteratorState[i] = append(nodeHash, key...) + } + + return iteratorState +} + // GetLeaves retrieves leaves from the trie. It stops either when the number of leaves is reached or the context is done. -// TODO add a maxSize that will stop the iteration when the size is reached -func (it *dfsIterator) GetLeaves(numLeaves int, ctx context.Context) (map[string]string, error) { +func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, leavesParser common.TrieLeafParser, ctx context.Context) (map[string]string, error) { retrievedLeaves := make(map[string]string) + leavesSize := uint64(0) for { nextNodes := make([]common.TrieNodeData, 0) - if len(retrievedLeaves) >= numLeaves { + if leavesSize >= maxSize { + return retrievedLeaves, nil + } + + if len(retrievedLeaves) >= numLeaves && numLeaves != 0 { return retrievedLeaves, nil } @@ -77,7 +111,6 @@ func (it *dfsIterator) GetLeaves(numLeaves int, ctx context.Context) (map[string return nil, err } - childrenSize := uint64(0) for _, childNode := range childrenNodes { if childNode.IsLeaf() { key, err := childNode.GetKeyBuilder().GetKey() @@ -85,44 +118,32 @@ func (it *dfsIterator) GetLeaves(numLeaves int, ctx context.Context) (map[string return nil, err } - retrievedLeaves[string(key)] = string(childNode.GetData()) + keyValHolder, err := leavesParser.ParseLeaf(key, childNode.GetData(), childNode.GetVersion()) + if err != nil { + return nil, err + } + + hexKey := hex.EncodeToString(keyValHolder.Key()) + hexData := hex.EncodeToString(keyValHolder.Value()) + retrievedLeaves[hexKey] = hexData + leavesSize += uint64(len(hexKey) + len(hexData)) continue } nextNodes = append(nextNodes, childNode) - childrenSize += childNode.Size() } - it.size += childrenSize - it.size -= it.nextNodes[0].Size() it.nextNodes = append(nextNodes, it.nextNodes[1:]...) } } -// GetIteratorId returns the ID of the iterator. -func (it *dfsIterator) GetIteratorId() []byte { - if len(it.nextNodes) == 0 { +// GetIteratorState returns the state of the iterator from which it can be resumed by another call. +func (it *dfsIterator) GetIteratorState() [][]byte { + if it.FinishedIteration() { return nil } - nextNodeHash := it.nextNodes[0].GetData() - iteratorID := it.hasher.Compute(string(append(it.rootHash, nextNodeHash...))) - return iteratorID -} - -// Clone creates a copy of the iterator. -func (it *dfsIterator) Clone() common.DfsIterator { - nextNodes := make([]common.TrieNodeData, len(it.nextNodes)) - copy(nextNodes, it.nextNodes) - - return &dfsIterator{ - nextNodes: nextNodes, - rootHash: it.rootHash, - db: it.db, - marshaller: it.marshaller, - hasher: it.hasher, - size: it.size, - } + return getIteratorStateFromNextNodes(it.nextNodes) } // FinishedIteration checks if the iterator has finished the iteration. @@ -130,17 +151,11 @@ func (it *dfsIterator) FinishedIteration() bool { return len(it.nextNodes) == 0 } -// Size returns the size of the iterator. -func (it *dfsIterator) Size() uint64 { - return it.size + uint64(len(it.rootHash)) -} - // IsInterfaceNil returns true if there is no value under the interface func (it *dfsIterator) IsInterfaceNil() bool { return it == nil } -// TODO add context nil test func checkContextDone(ctx context.Context) bool { if ctx == nil { return false diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go index 4489a43a437..80699761eee 100644 --- a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go @@ -1,10 +1,13 @@ package dfsTrieIterator import ( + "bytes" "context" - "fmt" + "encoding/hex" + "math" "testing" + "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" @@ -14,52 +17,59 @@ import ( "github.com/stretchr/testify/assert" ) +var maxSize = uint64(math.MaxUint64) + func TestNewIterator(t *testing.T) { t.Parallel() t.Run("nil db", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("hash"), nil, &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, nil, &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) assert.Nil(t, iterator) assert.Equal(t, trie.ErrNilDatabase, err) }) t.Run("nil marshaller", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("hash"), testscommon.NewMemDbMock(), nil, &hashingMocks.HasherMock{}) + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, testscommon.NewMemDbMock(), nil, &hashingMocks.HasherMock{}) assert.Nil(t, iterator) assert.Equal(t, trie.ErrNilMarshalizer, err) }) t.Run("nil hasher", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("hash"), testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, nil) + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, nil) assert.Nil(t, iterator) assert.Equal(t, trie.ErrNilHasher, err) }) - t.Run("invalid hash", func(t *testing.T) { + t.Run("empty initial state", func(t *testing.T) { t.Parallel() - iterator, err := NewIterator([]byte("invalid hash"), testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + iterator, err := NewIterator([][]byte{}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) assert.Nil(t, iterator) - assert.NotNil(t, err) + assert.Equal(t, trie.ErrEmptyInitialIteratorState, err) }) - t.Run("initialize iterator with a valid hash", func(t *testing.T) { + t.Run("invalid initial state", func(t *testing.T) { t.Parallel() - tr := trieTest.GetNewTrie() - _ = tr.Update([]byte("key1"), []byte("value1")) - _ = tr.Commit() - rootHash, _ := tr.RootHash() + iterator, err := NewIterator([][]byte{[]byte("invalid state")}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + assert.Nil(t, iterator) + assert.Equal(t, trie.ErrInvalidIteratorState, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() - _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, err := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + initialState := [][]byte{ + bytes.Repeat([]byte{0}, 40), + bytes.Repeat([]byte{1}, 40), + } + + db, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, err := NewIterator(initialState, db, marshaller, hasher) assert.Nil(t, err) - assert.Equal(t, rootHash, iterator.rootHash) - assert.Equal(t, uint64(15), iterator.size) - assert.Equal(t, 1, len(iterator.nextNodes)) + assert.Equal(t, 2, len(iterator.nextNodes)) }) } @@ -92,9 +102,9 @@ func TestDfsIterator_GetLeaves(t *testing.T) { }, } _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, dbWrapper, marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, dbWrapper, marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, ctx) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), ctx) assert.Nil(t, err) assert.Equal(t, expectedNumLeaves, len(trieData)) }) @@ -107,9 +117,9 @@ func TestDfsIterator_GetLeaves(t *testing.T) { rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(numLeaves, context.Background()) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) assert.Equal(t, numLeaves, len(trieData)) }) @@ -123,12 +133,45 @@ func TestDfsIterator_GetLeaves(t *testing.T) { rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(17, context.Background()) + trieData, err := iterator.GetLeaves(17, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) }) + t.Run("num leaves 0 iterates until maxSize reached", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(0, 200, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 8, len(trieData)) + assert.Equal(t, 8, len(iterator.nextNodes)) + }) + t.Run("max size reached returns retrieved leaves and saves iterator context", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + iteratorMaxSize := uint64(200) + trieData, err := iterator.GetLeaves(numLeaves, iteratorMaxSize, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 8, len(trieData)) + assert.Equal(t, 8, len(iterator.nextNodes)) + }) t.Run("retrieve all leaves in multiple calls", func(t *testing.T) { t.Parallel() @@ -137,12 +180,12 @@ func TestDfsIterator_GetLeaves(t *testing.T) { trieTest.AddDataToTrie(tr, numLeaves) rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) numRetrievedLeaves := 0 numIterations := 0 for numRetrievedLeaves < numLeaves { - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -152,50 +195,55 @@ func TestDfsIterator_GetLeaves(t *testing.T) { assert.Equal(t, numLeaves, numRetrievedLeaves) assert.Equal(t, 5, numIterations) }) -} - -func TestDfsIterator_GetIteratorId(t *testing.T) { - t.Parallel() + t.Run("retrieve leaves with nil context does not panic", func(t *testing.T) { + t.Parallel() - tr := trieTest.GetNewTrie() - numLeaves := 25 - trieTest.AddDataToTrie(tr, numLeaves) - rootHash, _ := tr.RootHash() - _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + tr := trieTest.GetNewTrie() + numLeaves := 25 + expectedNumRetrievedLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() - numRetrievedLeaves := 0 - for numRetrievedLeaves < numLeaves { - iteratorId := hasher.Compute(string(append(rootHash, iterator.nextNodes[0].GetData()...))) - assert.Equal(t, iteratorId, iterator.GetIteratorId()) + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), nil) assert.Nil(t, err) - - numRetrievedLeaves += len(trieData) - } - - assert.Equal(t, numLeaves, numRetrievedLeaves) - assert.Nil(t, iterator.GetIteratorId()) + assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + }) } -func TestDfsIterator_Clone(t *testing.T) { +func TestDfsIterator_GetIteratorState(t *testing.T) { t.Parallel() tr := trieTest.GetNewTrie() - numLeaves := 25 - trieTest.AddDataToTrie(tr, numLeaves) + _ = tr.Update([]byte("doe"), []byte("reindeer")) + _ = tr.Update([]byte("dog"), []byte("puppy")) + _ = tr.Update([]byte("ddog"), []byte("cat")) + _ = tr.Commit() rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - clonedIterator := iterator.Clone() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) - nextNodesMemAddr := fmt.Sprintf("%p", iterator.nextNodes) - clonedNextNodesMemAddr := fmt.Sprintf("%p", clonedIterator.(*dfsIterator).nextNodes) - assert.NotEqual(t, nextNodesMemAddr, clonedNextNodesMemAddr) - assert.Equal(t, iterator.rootHash, clonedIterator.(*dfsIterator).rootHash) - assert.Equal(t, iterator.size, clonedIterator.(*dfsIterator).size) + leaves, err := iterator.GetLeaves(2, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 2, len(leaves)) + val, ok := leaves[hex.EncodeToString([]byte("doe"))] + assert.True(t, ok) + assert.Equal(t, hex.EncodeToString([]byte("reindeer")), val) + val, ok = leaves[hex.EncodeToString([]byte("ddog"))] + assert.True(t, ok) + assert.Equal(t, hex.EncodeToString([]byte("cat")), val) + + iteratorState := iterator.GetIteratorState() + assert.Equal(t, 1, len(iteratorState)) + hash := iteratorState[0][:hasher.Size()] + key := iteratorState[0][hasher.Size():] + assert.Equal(t, []byte{0x7, 0x6, 0xf, 0x6, 0x4, 0x6, 0x10}, key) + leafBytes, err := tr.GetStorageManager().Get(hash) + assert.Nil(t, err) + assert.NotNil(t, leafBytes) } func TestDfsIterator_FinishedIteration(t *testing.T) { @@ -206,12 +254,12 @@ func TestDfsIterator_FinishedIteration(t *testing.T) { trieTest.AddDataToTrie(tr, numLeaves) rootHash, _ := tr.RootHash() _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) numRetrievedLeaves := 0 for numRetrievedLeaves < numLeaves { assert.False(t, iterator.FinishedIteration()) - trieData, err := iterator.GetLeaves(5, context.Background()) + trieData, err := iterator.GetLeaves(5, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) assert.Nil(t, err) numRetrievedLeaves += len(trieData) @@ -220,40 +268,3 @@ func TestDfsIterator_FinishedIteration(t *testing.T) { assert.Equal(t, numLeaves, numRetrievedLeaves) assert.True(t, iterator.FinishedIteration()) } - -func TestDfsIterator_Size(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - numLeaves := 25 - trieTest.AddDataToTrie(tr, numLeaves) - rootHash, _ := tr.RootHash() - _, marshaller, hasher := trieTest.GetDefaultTrieParameters() - - // branch node size = 33 - // root hash size = 32 - // extension nodes size = 34 - // leaf nodes size = 35 - iterator, _ := NewIterator(rootHash, tr.GetStorageManager(), marshaller, hasher) - assert.Equal(t, uint64(362), iterator.Size()) // 10 branch nodes + 1 root hash - - _, err := iterator.GetLeaves(5, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(331), iterator.Size()) // 8 branch nodes + 1 leaf node + 1 root hash - - _, err = iterator.GetLeaves(5, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(300), iterator.Size()) // 6 branch nodes + 2 leaf node + 1 root hash - - _, err = iterator.GetLeaves(5, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(197), iterator.Size()) // 5 branch nodes + 1 root hash - - _, err = iterator.GetLeaves(5, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(133), iterator.Size()) // 2 branch nodes + 1 leaf node + 1 root hash - - _, err = iterator.GetLeaves(5, context.Background()) - assert.Nil(t, err) - assert.Equal(t, uint64(32), iterator.Size()) // 1 root hash -} diff --git a/trie/leavesRetriever/disabledLeavesRetriever.go b/trie/leavesRetriever/disabledLeavesRetriever.go new file mode 100644 index 00000000000..8d3d33720ba --- /dev/null +++ b/trie/leavesRetriever/disabledLeavesRetriever.go @@ -0,0 +1,24 @@ +package leavesRetriever + +import ( + "context" + + "github.com/multiversx/mx-chain-go/common" +) + +type disabledLeavesRetriever struct{} + +// NewDisabledLeavesRetriever creates a new disabled leaves retriever +func NewDisabledLeavesRetriever() *disabledLeavesRetriever { + return &disabledLeavesRetriever{} +} + +// GetLeaves returns an empty map and a nil byte slice for this implementation +func (dlr *disabledLeavesRetriever) GetLeaves(_ int, _ [][]byte, _ common.TrieLeafParser, _ context.Context) (map[string]string, [][]byte, error) { + return make(map[string]string), [][]byte{}, nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (dlr *disabledLeavesRetriever) IsInterfaceNil() bool { + return dlr == nil +} diff --git a/trie/leavesRetriever/export_test.go b/trie/leavesRetriever/export_test.go deleted file mode 100644 index 3135262e01a..00000000000 --- a/trie/leavesRetriever/export_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package leavesRetriever - -import "github.com/multiversx/mx-chain-go/common" - -// GetIterators - -func (lr *leavesRetriever) GetIterators() map[string]common.DfsIterator { - return lr.iterators -} - -// GetLruIteratorIDs - -func (lr *leavesRetriever) GetLruIteratorIDs() [][]byte { - return lr.lruIteratorIDs -} - -// Size - -func (lr *leavesRetriever) Size() uint64 { - return lr.size -} diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go index 89a11569bc0..a2822975f43 100644 --- a/trie/leavesRetriever/leavesRetriever.go +++ b/trie/leavesRetriever/leavesRetriever.go @@ -2,7 +2,7 @@ package leavesRetriever import ( "context" - "sync" + "fmt" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" @@ -12,14 +12,10 @@ import ( ) type leavesRetriever struct { - iterators map[string]common.DfsIterator - lruIteratorIDs [][]byte - db common.TrieStorageInteractor - marshaller marshal.Marshalizer - hasher hashing.Hasher - size uint64 - maxSize uint64 - mutex sync.RWMutex + db common.TrieStorageInteractor + marshaller marshal.Marshalizer + hasher hashing.Hasher + maxSize uint64 } // NewLeavesRetriever creates a new leaves retriever @@ -35,125 +31,31 @@ func NewLeavesRetriever(db common.TrieStorageInteractor, marshaller marshal.Mars } return &leavesRetriever{ - iterators: make(map[string]common.DfsIterator), - lruIteratorIDs: make([][]byte, 0), - db: db, - marshaller: marshaller, - hasher: hasher, - size: 0, - maxSize: maxSize, + db: db, + marshaller: marshaller, + hasher: hasher, + maxSize: maxSize, }, nil } -// GetLeaves retrieves the leaves from the trie. If there is a saved checkpoint for the iterator id, it will continue to iterate from the checkpoint. -func (lr *leavesRetriever) GetLeaves(numLeaves int, rootHash []byte, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) { - if len(iteratorID) == 0 { - return lr.getLeavesFromNewInstance(numLeaves, rootHash, ctx) +// GetLeaves retrieves leaves from the trie starting from the iterator state. It will also return the new iterator state +// from which one can continue the iteration. +func (lr *leavesRetriever) GetLeaves(numLeaves int, iteratorState [][]byte, leavesParser common.TrieLeafParser, ctx context.Context) (map[string]string, [][]byte, error) { + if check.IfNil(leavesParser) { + return nil, nil, fmt.Errorf("nil leaves parser") } - lr.mutex.RLock() - iterator, ok := lr.iterators[string(iteratorID)] - lr.mutex.RUnlock() - if !ok { - return nil, nil, ErrIteratorNotFound - } - - return lr.getLeavesFromCheckpoint(numLeaves, iterator, iteratorID, ctx) -} - -func (lr *leavesRetriever) getLeavesFromNewInstance(numLeaves int, rootHash []byte, ctx context.Context) (map[string]string, []byte, error) { - iterator, err := dfsTrieIterator.NewIterator(rootHash, lr.db, lr.marshaller, lr.hasher) + iterator, err := dfsTrieIterator.NewIterator(iteratorState, lr.db, lr.marshaller, lr.hasher) if err != nil { return nil, nil, err } - return lr.getLeavesFromIterator(iterator, numLeaves, ctx) -} - -func (lr *leavesRetriever) getLeavesFromCheckpoint(numLeaves int, iterator common.DfsIterator, iteratorID []byte, ctx context.Context) (map[string]string, []byte, error) { - lr.markIteratorAsRecentlyUsed(iteratorID) - clonedIterator := iterator.Clone() - - return lr.getLeavesFromIterator(clonedIterator, numLeaves, ctx) -} - -func (lr *leavesRetriever) getLeavesFromIterator(iterator common.DfsIterator, numLeaves int, ctx context.Context) (map[string]string, []byte, error) { - leaves, err := iterator.GetLeaves(numLeaves, ctx) + leavesData, err := iterator.GetLeaves(numLeaves, lr.maxSize, leavesParser, ctx) if err != nil { return nil, nil, err } - if iterator.FinishedIteration() { - return leaves, nil, nil - } - - iteratorId := iterator.GetIteratorId() - if len(iteratorId) == 0 { - return leaves, nil, nil - } - - lr.manageIterators(iteratorId, iterator) - return leaves, iteratorId, nil -} - -func (lr *leavesRetriever) manageIterators(iteratorId []byte, iterator common.DfsIterator) { - lr.mutex.Lock() - defer lr.mutex.Unlock() - - lr.saveIterator(iteratorId, iterator) - lr.removeIteratorsIfMaxSizeIsExceeded() -} - -func (lr *leavesRetriever) saveIterator(iteratorId []byte, iterator common.DfsIterator) { - _, isPresent := lr.iterators[string(iteratorId)] - if isPresent { - return - } - - lr.lruIteratorIDs = append(lr.lruIteratorIDs, iteratorId) - lr.iterators[string(iteratorId)] = iterator - lr.size += iterator.Size() + uint64(len(iteratorId)) -} - -func (lr *leavesRetriever) markIteratorAsRecentlyUsed(iteratorId []byte) { - lr.mutex.Lock() - defer lr.mutex.Unlock() - - for i, id := range lr.lruIteratorIDs { - if string(id) == string(iteratorId) { - lr.lruIteratorIDs = append(lr.lruIteratorIDs[:i], lr.lruIteratorIDs[i+1:]...) - lr.lruIteratorIDs = append(lr.lruIteratorIDs, id) - return - } - } -} - -func (lr *leavesRetriever) removeIteratorsIfMaxSizeIsExceeded() { - if lr.size <= lr.maxSize { - return - } - - idsToRemove := make([][]byte, 0) - sizeOfRemoved := uint64(0) - numOfRemoved := 0 - - for i := 0; i < len(lr.lruIteratorIDs); i++ { - id := lr.lruIteratorIDs[i] - idsToRemove = append(idsToRemove, id) - iterator := lr.iterators[string(id)] - sizeOfRemoved += iterator.Size() + uint64(len(id)) - numOfRemoved++ - - if lr.size-sizeOfRemoved <= lr.maxSize { - break - } - } - - for _, id := range idsToRemove { - delete(lr.iterators, string(id)) - } - lr.lruIteratorIDs = lr.lruIteratorIDs[numOfRemoved:] - lr.size -= sizeOfRemoved + return leavesData, iterator.GetIteratorState(), nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/leavesRetriever/leavesRetriever_test.go b/trie/leavesRetriever/leavesRetriever_test.go index 28dd6131475..8fd376de439 100644 --- a/trie/leavesRetriever/leavesRetriever_test.go +++ b/trie/leavesRetriever/leavesRetriever_test.go @@ -2,20 +2,17 @@ package leavesRetriever_test import ( "context" - "crypto/rand" - "encoding/hex" - "fmt" - "sync" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/keyValStorage" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieTest "github.com/multiversx/mx-chain-go/testscommon/state" - "github.com/multiversx/mx-chain-go/trie" + trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie/leavesRetriever" "github.com/stretchr/testify/assert" ) @@ -59,213 +56,23 @@ func TestNewLeavesRetriever(t *testing.T) { func TestLeavesRetriever_GetLeaves(t *testing.T) { t.Parallel() - t.Run("get leaves from new instance", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) - leaves, iteratorId, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves)) - assert.Equal(t, 32, len(iteratorId)) - }) - t.Run("get leaves from existing instance", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 10000000) - leaves1, iteratorId1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves1)) - assert.Equal(t, 32, len(iteratorId1)) - assert.Equal(t, 1, len(lr.GetIterators())) - assert.Equal(t, 1, len(lr.GetLruIteratorIDs())) - - leaves2, iteratorId2, err := lr.GetLeaves(10, rootHash, iteratorId1, context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves2)) - assert.Equal(t, 32, len(iteratorId2)) - assert.Equal(t, 2, len(lr.GetIterators())) - assert.Equal(t, 2, len(lr.GetLruIteratorIDs())) - - assert.NotEqual(t, leaves1, leaves2) - assert.NotEqual(t, iteratorId1, iteratorId2) - }) - t.Run("traversing a trie saves all iterator instances", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 10000000) - leaves1, iteratorId1, err := lr.GetLeaves(10, rootHash, []byte(""), context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves1)) - assert.Equal(t, 32, len(iteratorId1)) - assert.Equal(t, 1, len(lr.GetIterators())) - assert.Equal(t, 1, len(lr.GetLruIteratorIDs())) - - leaves2, iteratorId2, err := lr.GetLeaves(10, rootHash, iteratorId1, context.Background()) - assert.Nil(t, err) - assert.Equal(t, 10, len(leaves2)) - assert.Equal(t, 32, len(iteratorId2)) - assert.Equal(t, 2, len(lr.GetIterators())) - assert.Equal(t, 2, len(lr.GetLruIteratorIDs())) - - leaves3, iteratorId3, err := lr.GetLeaves(10, rootHash, iteratorId2, context.Background()) - assert.Nil(t, err) - assert.Equal(t, 5, len(leaves3)) - assert.Equal(t, 0, len(iteratorId3)) - assert.Equal(t, 2, len(lr.GetIterators())) - assert.Equal(t, 2, len(lr.GetLruIteratorIDs())) - }) - t.Run("iterator instances are evicted in a lru manner", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - maxSize := uint64(1000) - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - iterators := make([][]byte, 0) - _, id1, _ := lr.GetLeaves(5, rootHash, []byte(""), context.Background()) - iterators = append(iterators, id1) - _, id2, _ := lr.GetLeaves(5, rootHash, id1, context.Background()) - iterators = append(iterators, id2) - _, id3, _ := lr.GetLeaves(5, rootHash, id2, context.Background()) - iterators = append(iterators, id3) - - assert.Equal(t, 3, len(lr.GetIterators())) - for i, id := range lr.GetLruIteratorIDs() { - assert.Equal(t, iterators[i], id) - } - - _, id4, _ := lr.GetLeaves(5, rootHash, id3, context.Background()) - iterators = append(iterators, id4) - assert.Equal(t, 3, len(lr.GetIterators())) - for i, id := range lr.GetLruIteratorIDs() { - assert.Equal(t, iterators[i+1], id) - } - }) - t.Run("when an iterator instance is used, it is moved in the front of the eviction queue", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - maxSize := uint64(100000) - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - iterators := make([][]byte, 0) - _, id1, _ := lr.GetLeaves(5, rootHash, []byte(""), context.Background()) - iterators = append(iterators, id1) - leaves1, id2, _ := lr.GetLeaves(5, rootHash, id1, context.Background()) - iterators = append(iterators, id2) - _, id3, _ := lr.GetLeaves(5, rootHash, id2, context.Background()) - iterators = append(iterators, id3) - - assert.Equal(t, 3, len(lr.GetIterators())) - for i, id := range lr.GetLruIteratorIDs() { - assert.Equal(t, iterators[i], id) - } - - leaves2, id4, _ := lr.GetLeaves(5, rootHash, id1, context.Background()) - assert.Equal(t, leaves1, leaves2) - assert.Equal(t, id2, id4) - - assert.Equal(t, 3, len(lr.GetIterators())) - retrievedIterators := lr.GetLruIteratorIDs() - assert.Equal(t, 3, len(retrievedIterators)) - assert.Equal(t, id2, retrievedIterators[0]) - assert.Equal(t, id3, retrievedIterators[1]) - assert.Equal(t, id1, retrievedIterators[2]) - }) - t.Run("iterator not found", func(t *testing.T) { - t.Parallel() - - tr := trieTest.GetNewTrie() - trieTest.AddDataToTrie(tr, 25) - rootHash, _ := tr.RootHash() - maxSize := uint64(100000) - - lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - leaves, id, err := lr.GetLeaves(5, rootHash, []byte("invalid iterator"), context.Background()) - assert.Nil(t, leaves) - assert.Equal(t, 0, len(id)) - assert.Equal(t, leavesRetriever.ErrIteratorNotFound, err) - }) -} - -func TestLeavesRetriever_Concurrency(t *testing.T) { - t.Parallel() - - numTries := 10 - numLeaves := 1000 - tries := buildTries(numTries, numLeaves) - - rootHashes := make([][]byte, 0) - for _, tr := range tries { - rootHash, _ := tr.RootHash() - rootHashes = append(rootHashes, rootHash) - + tr := trieTest.GetNewTrie() + trieTest.AddDataToTrie(tr, 25) + rootHash, _ := tr.RootHash() + leafParser := &trieMock.TrieLeafParserStub{ + ParseLeafCalled: func(key []byte, val []byte, version core.TrieNodeVersion) (core.KeyValueHolder, error) { + return keyValStorage.NewKeyValStorage(key, val), nil + }, } - - maxSize := uint64(1000000) - lr, _ := leavesRetriever.NewLeavesRetriever(tries[0].GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, maxSize) - - wg := &sync.WaitGroup{} - wg.Add(numTries) - for i := 0; i < numTries; i++ { - go retrieveTrieLeaves(t, lr, rootHashes[i], numLeaves, wg) - } - wg.Wait() -} - -func retrieveTrieLeaves(t *testing.T, lr common.TrieLeavesRetriever, rootHash []byte, numLeaves int, wg *sync.WaitGroup) { - iteratorId := []byte("") - numRetrievedLeaves := 0 - for { - leaves, newId, err := lr.GetLeaves(100, rootHash, iteratorId, context.Background()) - assert.Nil(t, err) - iteratorId = newId - numRetrievedLeaves += len(leaves) - fmt.Println("Retrieved leaves: ", numRetrievedLeaves, " for root hash: ", hex.EncodeToString(rootHash)) - if len(iteratorId) == 0 { - break - } - } - assert.Equal(t, numLeaves, numRetrievedLeaves) - wg.Done() -} - -func buildTries(numTries int, numLeaves int) []common.Trie { - tries := make([]common.Trie, 0) - tsm, marshaller, hasher := trieTest.GetDefaultTrieParameters() - for i := 0; i < numTries; i++ { - tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) - addDataToTrie(tr, numLeaves) - tries = append(tries, tr) - } - return tries -} - -func addDataToTrie(tr common.Trie, numLeaves int) { - for i := 0; i < numLeaves; i++ { - _ = tr.Update(generateRandomByteArray(32), generateRandomByteArray(32)) - } - _ = tr.Commit() -} - -func generateRandomByteArray(size int) []byte { - r := make([]byte, size) - _, _ = rand.Read(r) - return r + lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) + leaves, newIteratorState, err := lr.GetLeaves(10, [][]byte{rootHash}, leafParser, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 10, len(leaves)) + assert.Equal(t, 8, len(newIteratorState)) + + newLr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) + leaves, newIteratorState, err = newLr.GetLeaves(10, newIteratorState, leafParser, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 10, len(leaves)) + assert.Equal(t, 3, len(newIteratorState)) } diff --git a/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go b/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go index bd6d029868f..10a18a856fa 100644 --- a/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go +++ b/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go @@ -1,6 +1,7 @@ package trieNodeData import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" ) @@ -28,6 +29,11 @@ func (ind *intermediaryNodeData) IsLeaf() bool { return false } +// GetVersion returns NotSpecified +func (ind *intermediaryNodeData) GetVersion() core.TrieNodeVersion { + return core.NotSpecified +} + // IsInterfaceNil returns true if there is no value under the interface func (ind *intermediaryNodeData) IsInterfaceNil() bool { return ind == nil diff --git a/trie/leavesRetriever/trieNodeData/leafNodeData.go b/trie/leavesRetriever/trieNodeData/leafNodeData.go index 5f5ef574c70..08a80f2d3f8 100644 --- a/trie/leavesRetriever/trieNodeData/leafNodeData.go +++ b/trie/leavesRetriever/trieNodeData/leafNodeData.go @@ -1,16 +1,18 @@ package trieNodeData import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" ) type leafNodeData struct { *baseNodeData + version core.TrieNodeVersion } // NewLeafNodeData creates a new leaf node data -func NewLeafNodeData(key common.KeyBuilder, data []byte) (*leafNodeData, error) { +func NewLeafNodeData(key common.KeyBuilder, data []byte, version core.TrieNodeVersion) (*leafNodeData, error) { if check.IfNil(key) { return nil, ErrNilKeyBuilder } @@ -20,6 +22,7 @@ func NewLeafNodeData(key common.KeyBuilder, data []byte) (*leafNodeData, error) keyBuilder: key, data: data, }, + version: version, }, nil } @@ -28,6 +31,11 @@ func (lnd *leafNodeData) IsLeaf() bool { return true } +// GetVersion returns the version of the leaf +func (lnd *leafNodeData) GetVersion() core.TrieNodeVersion { + return lnd.version +} + // IsInterfaceNil returns true if there is no value under the interface func (lnd *leafNodeData) IsInterfaceNil() bool { return lnd == nil diff --git a/trie/leavesRetriever/trieNodeData/leafNodeData_test.go b/trie/leavesRetriever/trieNodeData/leafNodeData_test.go index 730ca9836a7..dc4b4ab656b 100644 --- a/trie/leavesRetriever/trieNodeData/leafNodeData_test.go +++ b/trie/leavesRetriever/trieNodeData/leafNodeData_test.go @@ -3,6 +3,7 @@ package trieNodeData import ( "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/stretchr/testify/assert" @@ -14,11 +15,11 @@ func TestNewLeafNodeData(t *testing.T) { var lnd *leafNodeData assert.True(t, check.IfNil(lnd)) - lnd, err := NewLeafNodeData(nil, nil) + lnd, err := NewLeafNodeData(nil, nil, core.NotSpecified) assert.Equal(t, ErrNilKeyBuilder, err) assert.True(t, check.IfNil(lnd)) - lnd, err = NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data")) + lnd, err = NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data"), core.NotSpecified) assert.Nil(t, err) assert.False(t, check.IfNil(lnd)) } @@ -26,6 +27,6 @@ func TestNewLeafNodeData(t *testing.T) { func TestLeafNodeData(t *testing.T) { t.Parallel() - lnd, _ := NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data")) + lnd, _ := NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data"), core.NotSpecified) assert.True(t, lnd.IsLeaf()) } diff --git a/trie/mock/keyBuilderStub.go b/trie/mock/keyBuilderStub.go index 1c68a95384d..7fec902d542 100644 --- a/trie/mock/keyBuilderStub.go +++ b/trie/mock/keyBuilderStub.go @@ -6,6 +6,7 @@ import "github.com/multiversx/mx-chain-go/common" type KeyBuilderStub struct { BuildKeyCalled func(keyPart []byte) GetKeyCalled func() ([]byte, error) + GetRawKeyCalled func() []byte ShallowCloneCalled func() common.KeyBuilder DeepCloneCalled func() common.KeyBuilder SizeCalled func() uint @@ -27,6 +28,15 @@ func (stub *KeyBuilderStub) GetKey() ([]byte, error) { return []byte{}, nil } +// GetRawKey - +func (stub *KeyBuilderStub) GetRawKey() []byte { + if stub.GetRawKeyCalled != nil { + return stub.GetRawKeyCalled() + } + + return []byte{} +} + // ShallowClone - func (stub *KeyBuilderStub) ShallowClone() common.KeyBuilder { if stub.ShallowCloneCalled != nil {