diff --git a/cmd/gossamer/main.go b/cmd/gossamer/main.go index 764b93de83a..6f17ac322f7 100644 --- a/cmd/gossamer/main.go +++ b/cmd/gossamer/main.go @@ -154,6 +154,10 @@ func main() { } } +var ( + errStateVersionNotSet = errors.New("state version not set") +) + func importStateAction(ctx *cli.Context) error { var ( stateFP, headerFP string diff --git a/dot/build_spec.go b/dot/build_spec.go index fb877196a26..a8e59d45158 100644 --- a/dot/build_spec.go +++ b/dot/build_spec.go @@ -112,8 +112,19 @@ func BuildFromDB(path string) (*BuildSpec, error) { if err != nil { return nil, fmt.Errorf("cannot start state service: %w", err) } + + bestBlockStateRoot, err := stateSrvc.Block.BestBlockStateRoot() + if err != nil { + return nil, fmt.Errorf("getting best block state root: %w", err) + } + + stateVersion, err := stateSrvc.Block.GetRuntimeStateVersion(bestBlockStateRoot) + if err != nil { + return nil, fmt.Errorf("getting runtime state version for block state root: %w", err) + } + // set genesis fields data - ent, err := stateSrvc.Storage.Entries(nil) + ent, err := stateSrvc.Storage.Entries(&bestBlockStateRoot, stateVersion) if err != nil { return nil, fmt.Errorf("failed to get storage trie entries: %w", err) } diff --git a/dot/core/helpers_test.go b/dot/core/helpers_test.go index 371183a9a94..c317d8ccde7 100644 --- a/dot/core/helpers_test.go +++ b/dot/core/helpers_test.go @@ -149,7 +149,8 @@ func getGssmrRuntimeCode(t *testing.T) (code []byte) { gssmrGenesis, err := genesis.NewGenesisFromJSONRaw(path) require.NoError(t, err) - trie, err := genesis.NewTrieFromGenesis(gssmrGenesis) + const stateVersion = trie.V0 + trie, err := genesis.NewTrieFromGenesis(gssmrGenesis, stateVersion) require.NoError(t, err) trieState := rtstorage.NewTrieState(trie) diff --git a/dot/core/interface.go b/dot/core/interface.go index aadf9b908ad..bdc1678a173 100644 --- a/dot/core/interface.go +++ b/dot/core/interface.go @@ -15,6 +15,7 @@ import ( "github.com/ChainSafe/gossamer/lib/runtime" rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/lib/transaction" + "github.com/ChainSafe/gossamer/lib/trie" ) //go:generate mockgen -destination=mock_core_test.go -package $GOPACKAGE . BlockState,StorageState,TransactionState,Network,EpochState,CodeSubstitutedState @@ -42,7 +43,7 @@ type BlockState interface { SubChain(start, end common.Hash) ([]common.Hash, error) GetBlockBody(hash common.Hash) (*types.Body, error) HandleRuntimeChanges(newState *rtstorage.TrieState, in runtime.Instance, bHash common.Hash) error - GetRuntime(*common.Hash) (runtime.Instance, error) + GetRuntime(blockHash *common.Hash) (instance runtime.Instance, err error) StoreRuntime(common.Hash, runtime.Instance) } @@ -50,11 +51,13 @@ type BlockState interface { type StorageState interface { LoadCode(root *common.Hash) ([]byte, error) LoadCodeHash(root *common.Hash) (common.Hash, error) - TrieState(root *common.Hash) (*rtstorage.TrieState, error) - StoreTrie(*rtstorage.TrieState, *types.Header) error + TrieState(root *common.Hash, stateVersion trie.Version) (*rtstorage.TrieState, error) + StoreTrie(trieState *rtstorage.TrieState, blockHeader *types.Header, + stateVersion trie.Version) error GetStateRootFromBlock(bhash *common.Hash) (*common.Hash, error) GetStorage(root *common.Hash, key []byte) ([]byte, error) - GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ([][]byte, error) + GenerateTrieProof(stateRoot common.Hash, keys [][]byte, + stateVersion trie.Version) (encodedNodes [][]byte, err error) sync.Locker } diff --git a/dot/core/messages.go b/dot/core/messages.go index 839e4df9962..2cb7f312a1c 100644 --- a/dot/core/messages.go +++ b/dot/core/messages.go @@ -12,6 +12,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/runtime" "github.com/ChainSafe/gossamer/lib/transaction" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/libp2p/go-libp2p-core/peer" ) @@ -20,7 +21,15 @@ func (s *Service) validateTransaction(peerID peer.ID, head *types.Header, rt run tx types.Extrinsic) (validity *transaction.Validity, valid bool, err error) { s.storageState.Lock() - ts, err := s.storageState.TrieState(&head.StateRoot) + // Note this is a cheap call getting the runtime cached version + // so we can call this in this function and not pass it as argument. + coreVersion := rt.Version() + stateVersion, err := trie.ParseVersion(coreVersion.StateVersion) + if err != nil { + return nil, false, fmt.Errorf("parsing state version: %w", err) + } + + ts, err := s.storageState.TrieState(&head.StateRoot, stateVersion) s.storageState.Unlock() if err != nil { return nil, false, fmt.Errorf("cannot get trie state from storage for root %s: %w", head.StateRoot, err) diff --git a/dot/core/service.go b/dot/core/service.go index 2776a2de36a..f812b75fcc1 100644 --- a/dot/core/service.go +++ b/dot/core/service.go @@ -22,6 +22,7 @@ import ( "github.com/ChainSafe/gossamer/lib/runtime/wasmer" "github.com/ChainSafe/gossamer/lib/services" "github.com/ChainSafe/gossamer/lib/transaction" + "github.com/ChainSafe/gossamer/lib/trie" cscale "github.com/centrifuge/go-substrate-rpc-client/v4/scale" ctypes "github.com/centrifuge/go-substrate-rpc-client/v4/types" ) @@ -60,6 +61,7 @@ type Service struct { // Config holds the configuration for the core Service. type Config struct { + // TODO add state version field here LogLvl log.Level BlockState BlockState @@ -102,7 +104,7 @@ func NewService(cfg *Config) (*Service, error) { blockAddCh := make(chan *types.Block, 256) ctx, cancel := context.WithCancel(context.Background()) - srv := &Service{ + return &Service{ ctx: ctx, cancel: cancel, keys: cfg.Keystore, @@ -114,9 +116,7 @@ func NewService(cfg *Config) (*Service, error) { blockAddCh: blockAddCh, codeSubstitute: cfg.CodeSubstitutes, codeSubstitutedState: cfg.CodeSubstitutedState, - } - - return srv, nil + }, nil } // Start starts the core service @@ -135,18 +135,21 @@ func (s *Service) Stop() error { return nil } -// StorageRoot returns the hash of the storage root -func (s *Service) StorageRoot() (common.Hash, error) { +// StorageRoot returns the hash of the storage root. +// It is only used by tests, but has to exported because +// internal fields are not exported to other packages. +func (s *Service) StorageRoot(stateVersion trie.Version) ( + rootHash common.Hash, err error) { if s.storageState == nil { - return common.Hash{}, ErrNilStorageState + return rootHash, ErrNilStorageState } - ts, err := s.storageState.TrieState(nil) + ts, err := s.storageState.TrieState(nil, stateVersion) if err != nil { - return common.Hash{}, err + return rootHash, err } - return ts.Root() + return ts.Root(stateVersion) } // HandleBlockImport handles a block that was imported via the network @@ -188,8 +191,19 @@ func (s *Service) handleBlock(block *types.Block, state *rtstorage.TrieState) er return ErrNilBlockHandlerParameter } + rt, err := s.blockState.GetRuntime(&block.Header.ParentHash) + if err != nil { + return fmt.Errorf("getting runtime: %w", err) + } + + version := rt.Version() + stateVersion, err := trie.ParseVersion(version.StateVersion) + if err != nil { + return fmt.Errorf("parsing state version: %w", err) + } + // store updates state trie nodes in database - err := s.storageState.StoreTrie(state, &block.Header) + err = s.storageState.StoreTrie(state, &block.Header, stateVersion) if err != nil { logger.Warnf("failed to store state trie for imported block %s: %s", block.Header.Hash(), err) @@ -208,12 +222,7 @@ func (s *Service) handleBlock(block *types.Block, state *rtstorage.TrieState) er } logger.Debugf("imported block %s and stored state trie with root %s", - block.Header.Hash(), state.MustRoot()) - - rt, err := s.blockState.GetRuntime(&block.Header.ParentHash) - if err != nil { - return err - } + block.Header.Hash(), state.MustRoot(stateVersion)) // check for runtime changes if err := s.blockState.HandleRuntimeChanges(state, rt, block.Header.Hash()); err != nil { @@ -467,14 +476,22 @@ func (s *Service) GetRuntimeVersion(bhash *common.Hash) ( } } - ts, err := s.storageState.TrieState(stateRootHash) + rt, err := s.blockState.GetRuntime(bhash) if err != nil { return version, err } - rt, err := s.blockState.GetRuntime(bhash) + coreVersion := rt.Version() + stateVersion, err := trie.ParseVersion(coreVersion.StateVersion) if err != nil { - return version, err + return version, fmt.Errorf("parsing state version: %w", err) + } + + // Note: not too sure why this trie state call is needed, + // but some RPC tests fail without it. + ts, err := s.storageState.TrieState(stateRootHash, stateVersion) + if err != nil { + return version, fmt.Errorf("getting trie state: %w", err) } rt.SetContextStorage(ts) @@ -498,15 +515,20 @@ func (s *Service) HandleSubmittedExtrinsic(ext types.Extrinsic) error { return fmt.Errorf("could not get state root from block %s: %w", bestBlockHash, err) } - ts, err := s.storageState.TrieState(stateRoot) + rt, err := s.blockState.GetRuntime(&bestBlockHash) if err != nil { - return err + return fmt.Errorf("getting runtime: %w", err) } - rt, err := s.blockState.GetRuntime(&bestBlockHash) + coreVersion := rt.Version() + stateVersion, err := trie.ParseVersion(coreVersion.StateVersion) if err != nil { - logger.Critical("failed to get runtime") - return err + return fmt.Errorf("parsing state version: %w", err) + } + + ts, err := s.storageState.TrieState(stateRoot, stateVersion) + if err != nil { + return fmt.Errorf("computing trie state: %w", err) } rt.SetContextStorage(ts) @@ -528,11 +550,8 @@ func (s *Service) HandleSubmittedExtrinsic(ext types.Extrinsic) error { } //GetMetadata calls runtime Metadata_metadata function -func (s *Service) GetMetadata(bhash *common.Hash) ([]byte, error) { - var ( - stateRootHash *common.Hash - err error - ) +func (s *Service) GetMetadata(bhash *common.Hash) (metadata []byte, err error) { + var stateRootHash *common.Hash // If block hash is not nil then fetch the state root corresponding to the block. if bhash != nil { @@ -541,12 +560,19 @@ func (s *Service) GetMetadata(bhash *common.Hash) ([]byte, error) { return nil, err } } - ts, err := s.storageState.TrieState(stateRootHash) + + rt, err := s.blockState.GetRuntime(bhash) if err != nil { return nil, err } - rt, err := s.blockState.GetRuntime(bhash) + coreVersion := rt.Version() + stateVersion, err := trie.ParseVersion(coreVersion.StateVersion) + if err != nil { + return nil, fmt.Errorf("parsing state version: %w", err) + } + + ts, err := s.storageState.TrieState(stateRootHash, stateVersion) if err != nil { return nil, err } @@ -563,12 +589,23 @@ func (s *Service) GetReadProofAt(block common.Hash, keys [][]byte) ( block = s.blockState.BestBlockHash() } + instance, err := s.blockState.GetRuntime(&block) + if err != nil { + return hash, nil, fmt.Errorf("getting block runtime: %w", err) + } + + coreVersion := instance.Version() + stateVersion, err := trie.ParseVersion(coreVersion.StateVersion) + if err != nil { + return hash, nil, fmt.Errorf("parsing state version: %w", err) + } + stateRoot, err := s.blockState.GetBlockStateRoot(block) if err != nil { return hash, nil, err } - proofForKeys, err = s.storageState.GenerateTrieProof(stateRoot, keys) + proofForKeys, err = s.storageState.GenerateTrieProof(stateRoot, keys, stateVersion) if err != nil { return hash, nil, err } diff --git a/dot/import.go b/dot/import.go index af352575098..cc61d067bb2 100644 --- a/dot/import.go +++ b/dot/import.go @@ -13,6 +13,7 @@ import ( "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/runtime/wasmer" "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/pkg/scale" @@ -21,11 +22,6 @@ import ( // ImportState imports the state in the given files to the database with the given path. func ImportState(basepath, stateFP, headerFP string, firstSlot uint64) error { - tr, err := newTrieFromPairs(stateFP) - if err != nil { - return err - } - header, err := newHeaderFromFile(headerFP) if err != nil { return err @@ -38,10 +34,32 @@ func ImportState(basepath, stateFP, headerFP string, firstSlot uint64) error { LogLevel: log.Info, } srv := state.NewService(config) - return srv.Import(header, tr, firstSlot) + + blockHash := header.Hash() + runtimeCode, err := srv.Storage.LoadCode(&blockHash) + if err != nil { + return fmt.Errorf("loading code from storage: %w", err) + } + + coreVersion, err := wasmer.GetRuntimeVersion(runtimeCode) + if err != nil { + return fmt.Errorf("getting runtime version: %w", err) + } + + stateVersion, err := trie.ParseVersion(coreVersion.StateVersion) + if err != nil { + return fmt.Errorf("parsing state version: %w", err) + } + + trie, err := newTrieFromPairs(stateFP, stateVersion) + if err != nil { + return fmt.Errorf("creating trie from pairs: %w", err) + } + + return srv.Import(header, trie, firstSlot, stateVersion) } -func newTrieFromPairs(filename string) (*trie.Trie, error) { +func newTrieFromPairs(filename string, stateVersion trie.Version) (*trie.Trie, error) { data, err := os.ReadFile(filepath.Clean(filename)) if err != nil { return nil, err @@ -63,7 +81,7 @@ func newTrieFromPairs(filename string) (*trie.Trie, error) { } tr := trie.NewEmptyTrie() - err = tr.LoadFromMap(entries) + err = tr.LoadFromMap(entries, stateVersion) if err != nil { return nil, err } diff --git a/dot/node.go b/dot/node.go index edb21cd1374..ac9d6a9cfa4 100644 --- a/dot/node.go +++ b/dot/node.go @@ -32,7 +32,9 @@ import ( "github.com/ChainSafe/gossamer/lib/grandpa" "github.com/ChainSafe/gossamer/lib/keystore" "github.com/ChainSafe/gossamer/lib/runtime" + "github.com/ChainSafe/gossamer/lib/runtime/wasmer" "github.com/ChainSafe/gossamer/lib/services" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/lib/utils" ) @@ -143,14 +145,29 @@ func (*nodeBuilder) initNode(cfg *Config) error { } } + runtimeCode, err := gen.RuntimeCode() + if err != nil { + return fmt.Errorf("getting runtime code from genesis: %w", err) + } + + coreVersion, err := wasmer.GetRuntimeVersion(runtimeCode) + if err != nil { + return fmt.Errorf("getting runtime version from genesis: %w", err) + } + + stateVersion, err := trie.ParseVersion(coreVersion.StateVersion) + if err != nil { + return fmt.Errorf("parsing state version from genesis runtime: %w", err) + } + // create trie from genesis - t, err := genesis.NewTrieFromGenesis(gen) + t, err := genesis.NewTrieFromGenesis(gen, stateVersion) if err != nil { return fmt.Errorf("failed to create trie from genesis: %w", err) } // create genesis block from trie - header, err := genesis.NewGenesisBlockFromTrie(t) + header, err := genesis.NewGenesisBlockFromTrie(t, stateVersion) if err != nil { return fmt.Errorf("failed to create genesis block from trie: %w", err) } diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index 5b06a8cba36..dc973639643 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -21,13 +21,13 @@ import ( // StorageAPI is the interface for the storage state type StorageAPI interface { GetStorage(root *common.Hash, key []byte) ([]byte, error) - GetStorageChild(root *common.Hash, keyToChild []byte) (*trie.Trie, error) - GetStorageFromChild(root *common.Hash, keyToChild, key []byte) ([]byte, error) + GetStorageChild(root *common.Hash, keyToChild []byte, stateVersion trie.Version) (*trie.Trie, error) + GetStorageFromChild(root *common.Hash, keyToChild, key []byte, stateVersion trie.Version) ([]byte, error) GetStorageByBlockHash(bhash *common.Hash, key []byte) ([]byte, error) - Entries(root *common.Hash) (map[string][]byte, error) + Entries(root *common.Hash, stateVersion trie.Version) (map[string][]byte, error) GetStateRootFromBlock(bhash *common.Hash) (*common.Hash, error) - GetKeysWithPrefix(root *common.Hash, prefix []byte) ([][]byte, error) - RegisterStorageObserver(observer state.Observer) + GetKeysWithPrefix(root *common.Hash, prefix []byte, stateVersion trie.Version) ([][]byte, error) + RegisterStorageObserver(observer state.Observer) (err error) UnregisterStorageObserver(observer state.Observer) } @@ -101,7 +101,8 @@ type CoreAPI interface { HandleSubmittedExtrinsic(types.Extrinsic) error GetMetadata(bhash *common.Hash) ([]byte, error) DecodeSessionKeys(enc []byte) ([]byte, error) - GetReadProofAt(block common.Hash, keys [][]byte) (common.Hash, [][]byte, error) + GetReadProofAt(block common.Hash, keys [][]byte) ( + blockHash common.Hash, encodedNodes [][]byte, err error) } //go:generate mockery --name RPCAPI --structname RPCAPI --case underscore --keeptree diff --git a/dot/rpc/modules/childstate.go b/dot/rpc/modules/childstate.go index 270e2cb54b8..833e4ee6788 100644 --- a/dot/rpc/modules/childstate.go +++ b/dot/rpc/modules/childstate.go @@ -4,6 +4,7 @@ package modules import ( + "fmt" "net/http" "github.com/ChainSafe/gossamer/lib/common" @@ -66,7 +67,12 @@ func (cs *ChildStateModule) GetKeys(_ *http.Request, req *GetKeysRequest, res *[ return err } - trie, err := cs.storageAPI.GetStorageChild(stateRoot, req.Key) + stateVersion, err := getStateVersion(cs.blockAPI, hash) + if err != nil { + return fmt.Errorf("getting state version: %w", err) + } + + trie, err := cs.storageAPI.GetStorageChild(stateRoot, req.Key, stateVersion) if err != nil { return err } @@ -96,7 +102,13 @@ func (cs *ChildStateModule) GetStorageSize(_ *http.Request, req *GetChildStorage return err } - item, err := cs.storageAPI.GetStorageFromChild(stateRoot, req.KeyChild, req.EntryKey) + stateVersion, err := getStateVersion(cs.blockAPI, hash) + if err != nil { + return fmt.Errorf("getting state version: %w", err) + } + + item, err := cs.storageAPI.GetStorageFromChild(stateRoot, + req.KeyChild, req.EntryKey, stateVersion) if err != nil { return err } @@ -123,7 +135,13 @@ func (cs *ChildStateModule) GetStorageHash(_ *http.Request, req *GetStorageHash, return err } - item, err := cs.storageAPI.GetStorageFromChild(stateRoot, req.KeyChild, req.EntryKey) + stateVersion, err := getStateVersion(cs.blockAPI, hash) + if err != nil { + return fmt.Errorf("getting state version: %w", err) + } + + item, err := cs.storageAPI.GetStorageFromChild(stateRoot, + req.KeyChild, req.EntryKey, stateVersion) if err != nil { return err } @@ -155,7 +173,13 @@ func (cs *ChildStateModule) GetStorage( return err } - item, err = cs.storageAPI.GetStorageFromChild(stateRoot, req.ChildStorageKey, req.Key) + stateVersion, err := getStateVersion(cs.blockAPI, hash) + if err != nil { + return fmt.Errorf("getting state version: %w", err) + } + + item, err = cs.storageAPI.GetStorageFromChild(stateRoot, + req.ChildStorageKey, req.Key, stateVersion) if err != nil { return err } diff --git a/dot/rpc/modules/helpers.go b/dot/rpc/modules/helpers.go new file mode 100644 index 00000000000..68b55e64a1b --- /dev/null +++ b/dot/rpc/modules/helpers.go @@ -0,0 +1,25 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package modules + +import ( + "fmt" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" +) + +func getStateVersion(blockState BlockAPI, blockHash common.Hash) (stateVersion trie.Version, err error) { + runtimeInstance, err := blockState.GetRuntime(&blockHash) + if err != nil { + return stateVersion, fmt.Errorf("getting runtime: %w", err) + } + + stateVersion, err = trie.ParseVersion(runtimeInstance.Version().StateVersion) + if err != nil { + return stateVersion, fmt.Errorf("parsing state version: %w", err) + } + + return stateVersion, nil +} diff --git a/dot/rpc/modules/state.go b/dot/rpc/modules/state.go index 4a742e6e754..7b806e7ec0b 100644 --- a/dot/rpc/modules/state.go +++ b/dot/rpc/modules/state.go @@ -14,7 +14,7 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) -//StateGetReadProofRequest json fields +// StateGetReadProofRequest json fields type StateGetReadProofRequest struct { Keys []string Hash common.Hash @@ -189,19 +189,28 @@ func NewStateModule(net NetworkAPI, storage StorageAPI, core CoreAPI, blockAPI B // GetPairs returns the keys with prefix, leave empty to get all the keys. func (sm *StateModule) GetPairs(_ *http.Request, req *StatePairRequest, res *StatePairResponse) error { var ( + blockHash common.Hash stateRootHash *common.Hash err error ) if req.Bhash != nil { + blockHash = *req.Bhash stateRootHash, err = sm.storageAPI.GetStateRootFromBlock(req.Bhash) if err != nil { return err } + } else { + blockHash = sm.blockAPI.BestBlockHash() + } + + stateVersion, err := getStateVersion(sm.blockAPI, blockHash) + if err != nil { + return fmt.Errorf("getting state version: %w", err) } if req.Prefix == nil || *req.Prefix == "" || *req.Prefix == "0x" { - pairs, err := sm.storageAPI.Entries(stateRootHash) + pairs, err := sm.storageAPI.Entries(stateRootHash, stateVersion) if err != nil { return err } @@ -216,7 +225,7 @@ func (sm *StateModule) GetPairs(_ *http.Request, req *StatePairRequest, res *Sta if err != nil { return fmt.Errorf("cannot convert hex prefix %s to bytes: %w", *req.Prefix, err) } - keys, err := sm.storageAPI.GetKeysWithPrefix(stateRootHash, reqBytes) + keys, err := sm.storageAPI.GetKeysWithPrefix(stateRootHash, reqBytes, stateVersion) if err != nil { return err } @@ -255,7 +264,20 @@ func (sm *StateModule) GetKeysPaged(_ *http.Request, req *StateStorageKeyRequest if err != nil { return err } - keys, err := sm.storageAPI.GetKeysWithPrefix(req.Block, hPrefix) + + var blockHash common.Hash + if req.Block != nil { + blockHash = *req.Block + } else { + blockHash = sm.blockAPI.BestBlockHash() + } + + stateVersion, err := getStateVersion(sm.blockAPI, blockHash) + if err != nil { + return fmt.Errorf("getting state version: %w", err) + } + + keys, err := sm.storageAPI.GetKeysWithPrefix(req.Block, hPrefix, stateVersion) if err != nil { return fmt.Errorf("cannot get keys with prefix %s: %w", hPrefix, err) } diff --git a/dot/rpc/modules/sync_state.go b/dot/rpc/modules/sync_state.go index 5d08a87a2cb..755c71eee9c 100644 --- a/dot/rpc/modules/sync_state.go +++ b/dot/rpc/modules/sync_state.go @@ -8,6 +8,7 @@ import ( "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/genesis" + "github.com/ChainSafe/gossamer/lib/trie" ) // GenSyncSpecRequest represents request to get chain specification. @@ -43,21 +44,17 @@ type syncState struct { } // NewStateSync creates an instance of SyncStateAPI given a chain specification. -func NewStateSync(gData *genesis.Data, storageAPI StorageAPI) (SyncStateAPI, error) { +func NewStateSync(gData *genesis.Data, storageAPI StorageAPI, + stateVersion trie.Version) (stateSync SyncStateAPI, err error) { tmpGen := &genesis.Genesis{ - Name: "", - ID: "", - Bootnodes: nil, - ProtocolID: "", Genesis: genesis.Fields{ - Runtime: nil, + Raw: make(map[string]map[string]string), + Runtime: make(map[string]map[string]interface{}), }, } - tmpGen.Genesis.Raw = make(map[string]map[string]string) - tmpGen.Genesis.Runtime = make(map[string]map[string]interface{}) // set genesis fields data - ent, err := storageAPI.Entries(nil) + ent, err := storageAPI.Entries(nil, stateVersion) if err != nil { return nil, err } diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index 2c1fcc5c432..baf6c13dd15 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -203,7 +203,11 @@ func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (L c.mu.Unlock() - c.StorageAPI.RegisterStorageObserver(stgobs) + err := c.StorageAPI.RegisterStorageObserver(stgobs) + if err != nil { + return nil, fmt.Errorf("registering storage observer: %w", err) + } + initRes := NewSubscriptionResponseJSON(stgobs.id, reqID) c.safeSend(initRes) diff --git a/dot/services.go b/dot/services.go index e2a18404398..42e03d58f33 100644 --- a/dot/services.go +++ b/dot/services.go @@ -31,6 +31,7 @@ import ( "github.com/ChainSafe/gossamer/lib/keystore" "github.com/ChainSafe/gossamer/lib/runtime" "github.com/ChainSafe/gossamer/lib/runtime/wasmer" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/lib/utils" ) @@ -121,7 +122,18 @@ func createRuntime(cfg *Config, ns runtime.NodeStorage, st *state.Service, code = common.MustHexToBytes(codeString) } - ts, err := st.Storage.TrieState(nil) + // Use the runtime state version to load the trie state in storage + version, err := wasmer.GetRuntimeVersion(code) + if err != nil { + return nil, fmt.Errorf("getting runtime version: %w", err) + } + + stateVersion, err := trie.ParseVersion(version.StateVersion) + if err != nil { + return nil, fmt.Errorf("parsing state version: %w", err) + } + + ts, err := st.Storage.TrieState(nil, stateVersion) if err != nil { return nil, err } @@ -321,12 +333,24 @@ func (nodeBuilder) createRPCService(params rpcServiceSettings) (*rpc.HTTPServer, ) rpcService := rpc.NewService() + runtime, err := params.state.Block.GetRuntime(nil) + if err != nil { + return nil, fmt.Errorf("getting runtime: %w", err) + } + + stateVersion, err := trie.ParseVersion(runtime.Version().StateVersion) + if err != nil { + return nil, fmt.Errorf("parsing state version: %w", err) + } + genesisData, err := params.state.Base.LoadGenesisData() if err != nil { return nil, fmt.Errorf("failed to load genesis data: %s", err) } - syncStateSrvc, err := modules.NewStateSync(genesisData, params.state.Storage) + // TODO should we inject the state version from the latest runtime? + // Or somehow find the genesis state version??? + syncStateSrvc, err := modules.NewStateSync(genesisData, params.state.Storage, stateVersion) if err != nil { return nil, fmt.Errorf("failed to create sync state service: %s", err) } diff --git a/dot/state/block.go b/dot/state/block.go index 0135e6bdd75..54351563025 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -17,6 +17,7 @@ import ( "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/runtime" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -651,6 +652,26 @@ func (bs *BlockState) GetRuntime(hash *common.Hash) (runtime.Instance, error) { return bs.bt.GetBlockRuntime(*hash) } +// GetRuntimeStateVersion returns the state version of the runtime +// corresponding to the block hash given. +func (bs *BlockState) GetRuntimeStateVersion(blockHash common.Hash) ( + stateVersion trie.Version, err error) { + // TODO I don't think we should have instantiated instances in memory, we should just fetch + // the wasm code from disk and load it up when needed. + instance, err := bs.GetRuntime(&blockHash) + if err != nil { + return stateVersion, fmt.Errorf("getting runtime instance: %w", err) + } + + version := instance.Version() + stateVersion, err = trie.ParseVersion(version.StateVersion) + if err != nil { + return stateVersion, fmt.Errorf("parsing state version: %w", err) + } + + return stateVersion, nil +} + // StoreRuntime stores the runtime for corresponding block hash. func (bs *BlockState) StoreRuntime(hash common.Hash, rt runtime.Instance) { bs.bt.StoreRuntime(hash, rt) diff --git a/dot/state/initialize.go b/dot/state/initialize.go index 465bc190104..b3f9837dfe3 100644 --- a/dot/state/initialize.go +++ b/dot/state/initialize.go @@ -52,6 +52,12 @@ func (s *Service) Initialise(gen *genesis.Genesis, header *types.Header, t *trie return err } + version := rt.Version() + stateVersion, err := trie.ParseVersion(version.StateVersion) + if err != nil { + return fmt.Errorf("parsing state version: %w", err) + } + babeCfg, err := s.loadBabeConfigurationFromRuntime(rt) if err != nil { return err @@ -64,7 +70,7 @@ func (s *Service) Initialise(gen *genesis.Genesis, header *types.Header, t *trie } tries := NewTries() - tries.SetTrie(t) + tries.SetTrie(t, stateVersion) // create block state from genesis block blockState, err := NewBlockStateFromGenesis(db, tries, header, s.Telemetry) diff --git a/dot/state/offline_pruner.go b/dot/state/offline_pruner.go index 1a59c471e23..b4c28473881 100644 --- a/dot/state/offline_pruner.go +++ b/dot/state/offline_pruner.go @@ -116,7 +116,8 @@ func (p *OfflinePruner) SetBloomFilter() (err error) { // loop from latest to last `retainBlockNum` blocks for blockNum := header.Number; blockNum > 0 && blockNum >= latestBlockNum-uint(p.retainBlockNum); { var tr *trie.Trie - tr, err = p.storageState.LoadFromDB(header.StateRoot) + // TODO do we need this pruner? + tr, err = p.storageState.LoadFromDB(header.StateRoot, trie.V0) if err != nil { return err } diff --git a/dot/state/service.go b/dot/state/service.go index 0d9fa51b7f4..15bb81c3785 100644 --- a/dot/state/service.go +++ b/dot/state/service.go @@ -142,8 +142,14 @@ func (s *Service) Start() (err error) { return fmt.Errorf("failed to create storage state: %w", err) } + bestBlockHash := s.Block.BestBlockHash() + stateVersion, err := s.Block.GetRuntimeStateVersion(bestBlockHash) + if err != nil { + return fmt.Errorf("getting runtime state version: %w", err) + } + // load current storage state trie into memory - _, err = s.Storage.LoadFromDB(stateRoot) + _, err = s.Storage.LoadFromDB(stateRoot, stateVersion) if err != nil { return fmt.Errorf("failed to load storage trie from database: %w", err) } @@ -263,7 +269,8 @@ func (s *Service) Stop() error { // Import imports the given state corresponding to the given header and sets the head of the chain // to it. Additionally, it uses the first slot to correctly set the epoch number of the block. -func (s *Service) Import(header *types.Header, t *trie.Trie, firstSlot uint64) error { +func (s *Service) Import(header *types.Header, t *trie.Trie, + firstSlot uint64, stateVersion trie.Version) error { var err error // initialise database using data directory if !s.isMemDB { @@ -308,7 +315,7 @@ func (s *Service) Import(header *types.Header, t *trie.Trie, firstSlot uint64) e return err } - root := t.MustHash() + root := t.MustHash(stateVersion) if root != header.StateRoot { return fmt.Errorf("trie state root does not equal header state root") } diff --git a/dot/state/storage.go b/dot/state/storage.go index a07468380db..aa4ab56053b 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -73,8 +73,9 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, } // StoreTrie stores the given trie in the StorageState and writes it to the database -func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) error { - root := ts.MustRoot() +func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header, + stateVersion trie.Version) error { + root := ts.MustRoot(stateVersion) s.tries.softSet(root, ts.Trie()) @@ -102,13 +103,13 @@ func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) return err } - go s.notifyAll(root) + go s.notifyAll(root, stateVersion) return nil } // TrieState returns the TrieState for a given state root. // If no state root is provided, it returns the TrieState for the current chain head. -func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error) { +func (s *StorageState) TrieState(root *common.Hash, version trie.Version) (*rtstorage.TrieState, error) { if root == nil { sr, err := s.blockState.BestBlockStateRoot() if err != nil { @@ -120,13 +121,13 @@ func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error t := s.tries.get(*root) if t == nil { var err error - t, err = s.LoadFromDB(*root) + t, err = s.LoadFromDB(*root, version) if err != nil { return nil, err } s.tries.softSet(*root, t) - } else if t.MustHash() != *root { + } else if t.MustHash(version) != *root { panic("trie does not have expected root") } @@ -138,18 +139,18 @@ func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error } // LoadFromDB loads an encoded trie from the DB where the key is `root` -func (s *StorageState) LoadFromDB(root common.Hash) (*trie.Trie, error) { +func (s *StorageState) LoadFromDB(root common.Hash, version trie.Version) (*trie.Trie, error) { t := trie.NewEmptyTrie() - err := t.Load(s.db, root) + err := t.Load(s.db, root, version) if err != nil { return nil, err } - s.tries.softSet(t.MustHash(), t) + s.tries.softSet(t.MustHash(version), t) return t, nil } -func (s *StorageState) loadTrie(root *common.Hash) (*trie.Trie, error) { +func (s *StorageState) loadTrie(root *common.Hash, version trie.Version) (*trie.Trie, error) { if root == nil { sr, err := s.blockState.BestBlockStateRoot() if err != nil { @@ -163,7 +164,7 @@ func (s *StorageState) loadTrie(root *common.Hash) (*trie.Trie, error) { return t, nil } - tr, err := s.LoadFromDB(*root) + tr, err := s.LoadFromDB(*root, version) if err != nil { return nil, fmt.Errorf("trie does not exist at root %s: %w", *root, err) } @@ -242,9 +243,11 @@ func (s *StorageState) StorageRoot() (common.Hash, error) { return s.blockState.BestBlockStateRoot() } -// Entries returns Entries from the trie with the given state root -func (s *StorageState) Entries(root *common.Hash) (map[string][]byte, error) { - tr, err := s.loadTrie(root) +// Entries returns the entries from the trie corresponding to the given state +// root as a map of key (string of LE encoded bytes) to value byte slice. +func (s *StorageState) Entries(root *common.Hash, version trie.Version) ( + entries map[string][]byte, err error) { + tr, err := s.loadTrie(root, version) if err != nil { return nil, err } @@ -254,8 +257,9 @@ func (s *StorageState) Entries(root *common.Hash) (map[string][]byte, error) { // GetKeysWithPrefix returns all that match the given prefix for the given hash // (or best block state root if hash is nil) in lexicographic order -func (s *StorageState) GetKeysWithPrefix(root *common.Hash, prefix []byte) ([][]byte, error) { - tr, err := s.loadTrie(root) +func (s *StorageState) GetKeysWithPrefix(root *common.Hash, prefix []byte, + version trie.Version) ([][]byte, error) { + tr, err := s.loadTrie(root, version) if err != nil { return nil, err } @@ -264,8 +268,9 @@ func (s *StorageState) GetKeysWithPrefix(root *common.Hash, prefix []byte) ([][] } // GetStorageChild returns a child trie, if it exists -func (s *StorageState) GetStorageChild(root *common.Hash, keyToChild []byte) (*trie.Trie, error) { - tr, err := s.loadTrie(root) +func (s *StorageState) GetStorageChild(root *common.Hash, keyToChild []byte, + version trie.Version) (*trie.Trie, error) { + tr, err := s.loadTrie(root, version) if err != nil { return nil, err } @@ -274,8 +279,9 @@ func (s *StorageState) GetStorageChild(root *common.Hash, keyToChild []byte) (*t } // GetStorageFromChild get a value from a child trie -func (s *StorageState) GetStorageFromChild(root *common.Hash, keyToChild, key []byte) ([]byte, error) { - tr, err := s.loadTrie(root) +func (s *StorageState) GetStorageFromChild(root *common.Hash, keyToChild, + key []byte, version trie.Version) (value []byte, err error) { + tr, err := s.loadTrie(root, version) if err != nil { return nil, err } @@ -299,7 +305,7 @@ func (s *StorageState) LoadCodeHash(hash *common.Hash) (common.Hash, error) { } // GenerateTrieProof returns the proofs related to the keys on the state root trie -func (s *StorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ( - encodedProofNodes [][]byte, err error) { - return proof.Generate(stateRoot[:], keys, s.db) +func (s *StorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte, + version trie.Version) (encodedProofNodes [][]byte, err error) { + return proof.Generate(stateRoot[:], keys, s.db, version) } diff --git a/dot/state/storage_notify.go b/dot/state/storage_notify.go index 8160813a3b5..0f1456a0bc8 100644 --- a/dot/state/storage_notify.go +++ b/dot/state/storage_notify.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" ) // KeyValue struct to hold key value pairs @@ -47,21 +48,27 @@ type Observer interface { } // RegisterStorageObserver to add abserver to notification list -func (s *StorageState) RegisterStorageObserver(o Observer) { +func (s *StorageState) RegisterStorageObserver(o Observer) (err error) { s.observerList = append(s.observerList, o) // notifyObserver here to send storage value of current state sr, err := s.blockState.BestBlockStateRoot() if err != nil { - logger.Debugf("error registering storage change channel: %s", err) - return + return fmt.Errorf("getting best block state root: %w", err) } + + stateVersion, err := s.blockState.GetRuntimeStateVersion(sr) + if err != nil { + return fmt.Errorf("getting runtime state version: %w", err) + } + go func() { - if err := s.notifyObserver(sr, o); err != nil { + if err := s.notifyObserver(sr, o, stateVersion); err != nil { logger.Warnf("failed to notify storage subscriptions: %s", err) } }() + return nil } // UnregisterStorageObserver removes observer from notification list @@ -69,19 +76,20 @@ func (s *StorageState) UnregisterStorageObserver(o Observer) { s.observerList = s.removeFromSlice(s.observerList, o) } -func (s *StorageState) notifyAll(root common.Hash) { +func (s *StorageState) notifyAll(root common.Hash, stateVersion trie.Version) { s.changedLock.RLock() defer s.changedLock.RUnlock() for _, observer := range s.observerList { - err := s.notifyObserver(root, observer) + err := s.notifyObserver(root, observer, stateVersion) if err != nil { logger.Warnf("failed to notify storage subscriptions: %s", err) } } } -func (s *StorageState) notifyObserver(root common.Hash, o Observer) error { - t, err := s.TrieState(&root) +func (s *StorageState) notifyObserver(root common.Hash, o Observer, + stateVersion trie.Version) (err error) { + t, err := s.TrieState(&root, stateVersion) if err != nil { return err } diff --git a/dot/state/test_helpers.go b/dot/state/test_helpers.go index cd02d9b358f..691c6459ce1 100644 --- a/dot/state/test_helpers.go +++ b/dot/state/test_helpers.go @@ -236,17 +236,18 @@ func AddBlocksToStateWithFixedBranches(t *testing.T, blockState *BlockState, dep } func generateBlockWithRandomTrie(t *testing.T, serv *Service, - parent *common.Hash, bNum uint) (*types.Block, *runtime.TrieState) { - trieState, err := serv.Storage.TrieState(nil) + parent *common.Hash, bNum uint, stateVersion trie.Version) ( + block *types.Block, trieState *runtime.TrieState) { + trieState, err := serv.Storage.TrieState(nil, stateVersion) require.NoError(t, err) // Generate random data for trie state. rand := time.Now().UnixNano() key := []byte("testKey" + fmt.Sprint(rand)) value := []byte("testValue" + fmt.Sprint(rand)) - trieState.Set(key, value) + trieState.Set(key, value, stateVersion) - trieStateRoot, err := trieState.Root() + trieStateRoot, err := trieState.Root(stateVersion) require.NoError(t, err) if parent == nil { @@ -257,7 +258,7 @@ func generateBlockWithRandomTrie(t *testing.T, serv *Service, body, err := types.NewBodyFromBytes([]byte{}) require.NoError(t, err) - block := &types.Block{ + block = &types.Block{ Header: types.Header{ ParentHash: *parent, Number: bNum, diff --git a/dot/state/tries.go b/dot/state/tries.go index 5800d33e12c..bc461cfb461 100644 --- a/dot/state/tries.go +++ b/dot/state/tries.go @@ -55,8 +55,19 @@ func (t *Tries) SetEmptyTrie() { t.softSet(trie.EmptyHash, trie.NewEmptyTrie()) } -func (t *Tries) SetTrie(trie *trie.Trie) { - t.softSet(trie.MustHash(), trie) +func (t *Tries) SetTrie(trie *trie.Trie, stateVersion trie.Version) { + t.softSet(trie.MustHash(stateVersion), trie) +} + +func NewTriesWithEmptyTrie() (trs *Tries) { + return &Tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + trie.EmptyHash: trie.NewEmptyTrie(), + }, + triesGauge: triesGauge, + setCounter: setCounter, + deleteCounter: deleteCounter, + } } // softSet sets the given trie at the given root hash diff --git a/dot/sync/chain_processor.go b/dot/sync/chain_processor.go index 0e03b0426eb..10555ebd990 100644 --- a/dot/sync/chain_processor.go +++ b/dot/sync/chain_processor.go @@ -4,7 +4,6 @@ package sync import ( - "bytes" "context" "errors" "fmt" @@ -12,6 +11,7 @@ import ( "github.com/ChainSafe/gossamer/dot/telemetry" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/blocktree" + "github.com/ChainSafe/gossamer/lib/trie" ) //go:generate mockgen -destination=mock_chain_processor_test.go -package=$GOPACKAGE . ChainProcessor @@ -139,10 +139,20 @@ func (s *chainProcessor) processBlockData(bd *types.BlockData) error { s.handleJustification(&block.Header, *bd.Justification) } + runtimeInstance, err := s.blockState.GetRuntime(&block.Header.ParentHash) + if err != nil { + return fmt.Errorf("getting runtime for parent hash: %w", err) + } + + stateVersion, err := trie.ParseVersion(runtimeInstance.Version().StateVersion) + if err != nil { + return fmt.Errorf("parsing state version: %w", err) + } + // TODO: this is probably unnecessary, since the state is already in the database // however, this case shouldn't be hit often, since it's only hit if the node state // is rewinded or if the node shuts down unexpectedly (#1784) - state, err := s.storageState.TrieState(&block.Header.StateRoot) + state, err := s.storageState.TrieState(&block.Header.StateRoot, stateVersion) if err != nil { logger.Warnf("failed to load state for block with hash %s: %s", block.Header.Hash(), err) return err @@ -196,32 +206,39 @@ func (s *chainProcessor) handleBody(body *types.Body) { } } -// handleHeader handles blocks (header+body) included in BlockResponses +// handleBlock handles blocks (header+body) included in BlockResponses func (s *chainProcessor) handleBlock(block *types.Block) error { parent, err := s.blockState.GetHeader(block.Header.ParentHash) if err != nil { return fmt.Errorf("%w: %s", errFailedToGetParent, err) } + hash := parent.Hash() + rt, err := s.blockState.GetRuntime(&hash) + if err != nil { + return fmt.Errorf("getting runtime for parent hash: %w", err) + } + + stateVersion, err := trie.ParseVersion(rt.Version().StateVersion) + if err != nil { + return fmt.Errorf("parsing state version: %w", err) + } + s.storageState.Lock() defer s.storageState.Unlock() - ts, err := s.storageState.TrieState(&parent.StateRoot) + ts, err := s.storageState.TrieState(&parent.StateRoot, stateVersion) if err != nil { return err } - root := ts.MustRoot() - if !bytes.Equal(parent.StateRoot[:], root[:]) { + // TODO shall we remove this? Both are coming from the internal state + rootV0 := ts.MustRoot(trie.V0) + if !rootV0.Equal(parent.StateRoot) { + // TODO add extra check with rootV1 when v1 is supported panic("parent state root does not match snapshot state root") } - hash := parent.Hash() - rt, err := s.blockState.GetRuntime(&hash) - if err != nil { - return err - } - rt.SetContextStorage(ts) _, err = rt.ExecuteBlock(block) diff --git a/dot/sync/interface.go b/dot/sync/interface.go index e1a2a6d4125..fcf7dcff825 100644 --- a/dot/sync/interface.go +++ b/dot/sync/interface.go @@ -12,6 +12,7 @@ import ( "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/runtime" rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/libp2p/go-libp2p-core/peer" ) @@ -50,7 +51,7 @@ type BlockState interface { // StorageState is the interface for the storage state type StorageState interface { - TrieState(root *common.Hash) (*rtstorage.TrieState, error) + TrieState(root *common.Hash, stateVersion trie.Version) (*rtstorage.TrieState, error) LoadCodeHash(*common.Hash) (common.Hash, error) sync.Locker } diff --git a/lib/babe/babe.go b/lib/babe/babe.go index d0dd5fa4e72..e87183b0ae4 100644 --- a/lib/babe/babe.go +++ b/lib/babe/babe.go @@ -15,6 +15,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/log" "github.com/ChainSafe/gossamer/lib/crypto/sr25519" + "github.com/ChainSafe/gossamer/lib/trie" ethmetrics "github.com/ethereum/go-ethereum/metrics" ) @@ -489,9 +490,7 @@ func (b *Service) handleEpoch(epoch uint64) (next uint64, err error) { } func (b *Service) handleSlot(epoch, slotNum uint64, - authorityIndex uint32, - preRuntimeDigest *types.PreRuntimeDigest, -) error { + authorityIndex uint32, preRuntimeDigest *types.PreRuntimeDigest) error { parentHeader, err := b.blockState.BestBlockHeader() if err != nil { return err @@ -517,20 +516,22 @@ func (b *Service) handleSlot(epoch, slotNum uint64, b.storageState.Lock() defer b.storageState.Unlock() + hash := parent.Hash() + rt, err := b.blockState.GetRuntime(&hash) + if err != nil { + return fmt.Errorf("getting runtime: %w", err) + } + + stateVersion := trie.Version(rt.Version().StateVersion) + // set runtime trie before building block // if block building is successful, store the resulting trie in the storage state - ts, err := b.storageState.TrieState(&parent.StateRoot) + ts, err := b.storageState.TrieState(&parent.StateRoot, stateVersion) if err != nil || ts == nil { logger.Errorf("failed to get parent trie with parent state root %s: %s", parent.StateRoot, err) return err } - hash := parent.Hash() - rt, err := b.blockState.GetRuntime(&hash) - if err != nil { - return err - } - rt.SetContextStorage(ts) block, err := b.buildBlock(parent, currentSlot, rt, authorityIndex, preRuntimeDigest) diff --git a/lib/babe/epoch_handler.go b/lib/babe/epoch_handler.go index 52e75ba3dd8..2f860431e54 100644 --- a/lib/babe/epoch_handler.go +++ b/lib/babe/epoch_handler.go @@ -14,7 +14,8 @@ import ( "github.com/ChainSafe/gossamer/lib/crypto/sr25519" ) -type handleSlotFunc = func(epoch, slotNum uint64, authorityIndex uint32, preRuntimeDigest *types.PreRuntimeDigest) error +type handleSlotFunc = func(epoch, slotNum uint64, authorityIndex uint32, + preRuntimeDigest *types.PreRuntimeDigest) error var ( errEpochPast = errors.New("cannot run epoch that has already passed") @@ -135,7 +136,8 @@ func (h *epochHandler) run(ctx context.Context, errCh chan<- error) { panic(fmt.Sprintf("no VRF proof for authoring slot! slot=%d", swt.slotNum)) } - err := h.handleSlot(h.epochNumber, swt.slotNum, h.epochData.authorityIndex, h.slotToPreRuntimeDigest[swt.slotNum]) + err := h.handleSlot(h.epochNumber, swt.slotNum, h.epochData.authorityIndex, + h.slotToPreRuntimeDigest[swt.slotNum]) if err != nil { logger.Warnf("failed to handle slot %d: %s", swt.slotNum, err) continue diff --git a/lib/babe/state.go b/lib/babe/state.go index 897f06bd35c..d9e9b5e5849 100644 --- a/lib/babe/state.go +++ b/lib/babe/state.go @@ -12,6 +12,7 @@ import ( "github.com/ChainSafe/gossamer/lib/runtime" rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/lib/transaction" + "github.com/ChainSafe/gossamer/lib/trie" ) //go:generate mockgen -destination=./mock_state_test.go -package $GOPACKAGE . BlockState,ImportedBlockNotifierManager,StorageState,TransactionState,EpochState,DigestHandler,BlockImportHandler @@ -47,7 +48,7 @@ type ImportedBlockNotifierManager interface { // StorageState interface for storage state methods type StorageState interface { - TrieState(hash *common.Hash) (*rtstorage.TrieState, error) + TrieState(hash *common.Hash, stateVersion trie.Version) (*rtstorage.TrieState, error) sync.Locker } diff --git a/lib/genesis/genesis.go b/lib/genesis/genesis.go index 9d27cd07993..4cf7b8e7f14 100644 --- a/lib/genesis/genesis.go +++ b/lib/genesis/genesis.go @@ -4,6 +4,9 @@ package genesis import ( + "errors" + "fmt" + "github.com/ChainSafe/gossamer/lib/common" ) @@ -94,6 +97,23 @@ func (g *Genesis) ToRaw() error { return nil } +var ErrRuntimeCodeNotFound = errors.New("runtime code not found") + +func (g *Genesis) RuntimeCode() (runtimeCode []byte, err error) { + const hexCodeTrieKey = "0x3a636f6465" // :code in hexadecimal + hexRuntimeCode, ok := g.Genesis.Raw["top"][hexCodeTrieKey] + if !ok { + return nil, ErrRuntimeCodeNotFound + } + + runtimeCode, err = common.HexToBytes(hexRuntimeCode) + if err != nil { + return nil, fmt.Errorf("converting runtime code hex to bytes: %w", err) + } + + return runtimeCode, nil +} + func interfaceToTelemetryEndpoint(endpoints []interface{}) []*TelemetryEndpoint { var res []*TelemetryEndpoint for _, v := range endpoints { diff --git a/lib/genesis/helpers.go b/lib/genesis/helpers.go index 5907b26b42c..47948612d72 100644 --- a/lib/genesis/helpers.go +++ b/lib/genesis/helpers.go @@ -57,12 +57,12 @@ func NewGenesisFromJSONRaw(file string) (*Genesis, error) { } // NewTrieFromGenesis creates a new trie from the raw genesis data -func NewTrieFromGenesis(g *Genesis) (*trie.Trie, error) { +func NewTrieFromGenesis(g *Genesis, stateVersion trie.Version) (*trie.Trie, error) { t := trie.NewEmptyTrie() r := g.GenesisFields().Raw["top"] - err := t.LoadFromMap(r) + err := t.LoadFromMap(r, stateVersion) if err != nil { return nil, fmt.Errorf("failed to create trie from genesis: %s", err) } @@ -71,10 +71,9 @@ func NewTrieFromGenesis(g *Genesis) (*trie.Trie, error) { } // NewGenesisBlockFromTrie creates a genesis block from the provided trie -func NewGenesisBlockFromTrie(t *trie.Trie) (*types.Header, error) { - +func NewGenesisBlockFromTrie(t *trie.Trie, version trie.Version) (*types.Header, error) { // create state root from trie hash - stateRoot, err := t.Hash() + stateRoot, err := t.Hash(version) if err != nil { return nil, fmt.Errorf("failed to create state root from trie hash: %s", err) } diff --git a/lib/genesis/test_utils.go b/lib/genesis/test_utils.go index f0bfe01ec37..32a5cac878e 100644 --- a/lib/genesis/test_utils.go +++ b/lib/genesis/test_utils.go @@ -76,32 +76,32 @@ func CreateTestGenesisJSONFile(t *testing.T, fields Fields) (filename string) { } // NewTestGenesisWithTrieAndHeader generates genesis, genesis trie and genesis header -func NewTestGenesisWithTrieAndHeader(t *testing.T) (*Genesis, *trie.Trie, *types.Header) { +func NewTestGenesisWithTrieAndHeader(t *testing.T, version trie.Version) (*Genesis, *trie.Trie, *types.Header) { genesisPath := utils.GetGssmrV3SubstrateGenesisRawPathTest(t) gen, err := NewGenesisFromJSONRaw(genesisPath) require.NoError(t, err) - tr, h := newGenesisTrieAndHeader(t, gen) + tr, h := newGenesisTrieAndHeader(t, gen, version) return gen, tr, h } // NewDevGenesisWithTrieAndHeader generates test dev genesis, genesis trie and genesis header -func NewDevGenesisWithTrieAndHeader(t *testing.T) (*Genesis, *trie.Trie, *types.Header) { +func NewDevGenesisWithTrieAndHeader(t *testing.T, version trie.Version) (*Genesis, *trie.Trie, *types.Header) { genesisPath := utils.GetDevV3SubstrateGenesisPath(t) gen, err := NewGenesisFromJSONRaw(genesisPath) require.NoError(t, err) - tr, h := newGenesisTrieAndHeader(t, gen) + tr, h := newGenesisTrieAndHeader(t, gen, version) return gen, tr, h } -func newGenesisTrieAndHeader(t *testing.T, gen *Genesis) (*trie.Trie, *types.Header) { - genTrie, err := NewTrieFromGenesis(gen) +func newGenesisTrieAndHeader(t *testing.T, gen *Genesis, version trie.Version) (*trie.Trie, *types.Header) { + genTrie, err := NewTrieFromGenesis(gen, version) require.NoError(t, err) genesisHeader, err := types.NewHeader(common.NewHash([]byte{0}), - genTrie.MustHash(), trie.EmptyHash, 0, types.NewDigest()) + genTrie.MustHash(version), trie.EmptyHash, 0, types.NewDigest()) require.NoError(t, err) return genTrie, genesisHeader diff --git a/lib/runtime/allocator.go b/lib/runtime/allocator.go index 0e84431822f..5ffba46c5e1 100644 --- a/lib/runtime/allocator.go +++ b/lib/runtime/allocator.go @@ -82,9 +82,11 @@ func (fbha *FreeingBumpHeapAllocator) growHeap(numPages uint32) error { return nil } -// Allocate determines if there is space available in WASM heap to grow the heap by 'size'. If there is space -// available it grows the heap to fit give 'size'. The heap grows is chunks of Powers of 2, so the growth becomes -// the next highest power of 2 of the requested size. +// Allocate determines if there is space available in WASM heap to grow +// the heap by 'size' bytes. If there is space available it grows the heap +// to fit the given 'size' number of bytes. The heap grows is chunks of +// powers of 2, so the next growed heap size becomes the next highest power +// of 2 of the requested size. func (fbha *FreeingBumpHeapAllocator) Allocate(size uint32) (uint32, error) { // test for space allocation if size > MaxPossibleAllocation { diff --git a/lib/runtime/interface.go b/lib/runtime/interface.go index 6fc9f3a01a5..b577c5af171 100644 --- a/lib/runtime/interface.go +++ b/lib/runtime/interface.go @@ -50,11 +50,11 @@ type Instance interface { // Storage interface type Storage interface { - Set(key []byte, value []byte) + Set(key []byte, value []byte, version trie.Version) Get(key []byte) []byte - Root() (common.Hash, error) - SetChild(keyToChild []byte, child *trie.Trie) error - SetChildStorage(keyToChild, key, value []byte) error + Root(version trie.Version) (common.Hash, error) + SetChild(keyToChild []byte, child *trie.Trie, version trie.Version) error + SetChildStorage(keyToChild, key, value []byte, version trie.Version) error GetChildStorage(keyToChild, key []byte) ([]byte, error) Delete(key []byte) DeleteChild(keyToChild []byte) diff --git a/lib/runtime/storage/trie.go b/lib/runtime/storage/trie.go index aace0977b4e..defc02453a1 100644 --- a/lib/runtime/storage/trie.go +++ b/lib/runtime/storage/trie.go @@ -68,10 +68,10 @@ func (s *TrieState) RollbackStorageTransaction() { } // Set sets a key-value pair in the trie -func (s *TrieState) Set(key, value []byte) { +func (s *TrieState) Set(key, value []byte, version trie.Version) { s.lock.Lock() defer s.lock.Unlock() - s.t.Put(key, value) + s.t.Put(key, value, version) } // Get gets a value from the trie @@ -82,13 +82,13 @@ func (s *TrieState) Get(key []byte) []byte { } // MustRoot returns the trie's root hash. It panics if it fails to compute the root. -func (s *TrieState) MustRoot() common.Hash { - return s.t.MustHash() +func (s *TrieState) MustRoot(version trie.Version) common.Hash { + return s.t.MustHash(version) } // Root returns the trie's root hash -func (s *TrieState) Root() (common.Hash, error) { - return s.t.Hash() +func (s *TrieState) Root(version trie.Version) (common.Hash, error) { + return s.t.Hash(version) } // Has returns whether or not a key exists @@ -139,17 +139,19 @@ func (s *TrieState) TrieEntries() map[string][]byte { } // SetChild sets the child trie at the given key -func (s *TrieState) SetChild(keyToChild []byte, child *trie.Trie) error { +func (s *TrieState) SetChild(keyToChild []byte, child *trie.Trie, + version trie.Version) (err error) { s.lock.Lock() defer s.lock.Unlock() - return s.t.PutChild(keyToChild, child) + return s.t.PutChild(keyToChild, child, version) } // SetChildStorage sets a key-value pair in a child trie -func (s *TrieState) SetChildStorage(keyToChild, key, value []byte) error { +func (s *TrieState) SetChildStorage(keyToChild, key, value []byte, + version trie.Version) (err error) { s.lock.Lock() defer s.lock.Unlock() - return s.t.PutIntoChild(keyToChild, key, value) + return s.t.PutIntoChild(keyToChild, key, value, version) } // GetChild returns the child trie at the given key diff --git a/lib/runtime/wasmer/imports.go b/lib/runtime/wasmer/imports.go index d6e9c21b99b..6d0c3a26f85 100644 --- a/lib/runtime/wasmer/imports.go +++ b/lib/runtime/wasmer/imports.go @@ -825,7 +825,7 @@ func ext_trie_blake2_256_root_version_1(context unsafe.Pointer, dataSpan C.int64 } for _, kv := range kvs { - t.Put(kv.Key, kv.Value) + t.Put(kv.Key, kv.Value, trie.V0) } // allocate memory for value and copy value to memory @@ -835,7 +835,7 @@ func ext_trie_blake2_256_root_version_1(context unsafe.Pointer, dataSpan C.int64 return 0 } - hash, err := t.Hash() + hash, err := t.Hash(trie.V0) if err != nil { logger.Errorf("failed computing trie Merkle root hash: %s", err) return 0 @@ -873,7 +873,7 @@ func ext_trie_blake2_256_ordered_root_version_1(context unsafe.Pointer, dataSpan "put key=0x%x and value=0x%x", key, val) - t.Put(key, val) + t.Put(key, val, trie.V0) } // allocate memory for value and copy value to memory @@ -883,7 +883,7 @@ func ext_trie_blake2_256_ordered_root_version_1(context unsafe.Pointer, dataSpan return 0 } - hash, err := t.Hash() + hash, err := t.Hash(trie.V0) if err != nil { logger.Errorf("failed computing trie Merkle root hash: %s", err) return 0 @@ -922,7 +922,7 @@ func ext_trie_blake2_256_verify_proof_version_1(context unsafe.Pointer, mem := instanceContext.Memory().Data() trieRoot := mem[rootSpan : rootSpan+32] - err = proof.Verify(encodedProofNodes, trieRoot, key, value) + err = proof.Verify(encodedProofNodes, trieRoot, key, value, trie.V0) if err != nil { logger.Errorf("failed proof verification: %s", err) return C.int32_t(0) @@ -1142,7 +1142,7 @@ func ext_default_child_storage_root_version_1(context unsafe.Pointer, return 0 } - childRoot, err := child.Hash() + childRoot, err := child.Hash(trie.V0) if err != nil { logger.Errorf("failed to encode child root: %s", err) return 0 @@ -1173,7 +1173,8 @@ func ext_default_child_storage_set_version_1(context unsafe.Pointer, cp := make([]byte, len(value)) copy(cp, value) - err := storage.SetChildStorage(childStorageKey, key, cp) + // TODO obtain version from context + err := storage.SetChildStorage(childStorageKey, key, cp, trie.V0) if err != nil { logger.Errorf("failed to set value in child storage: %s", err) return @@ -1809,7 +1810,7 @@ func ext_offchain_http_request_add_header_version_1(context unsafe.Pointer, return C.int64_t(ptr) } -func storageAppend(storage runtime.Storage, key, valueToAppend []byte) error { +func storageAppend(storage runtime.Storage, key, valueToAppend []byte, version trie.Version) error { nextLength := big.NewInt(1) var valueRes []byte @@ -1825,7 +1826,7 @@ func storageAppend(storage runtime.Storage, key, valueToAppend []byte) error { if err != nil { logger.Tracef( "item in storage is not SCALE encoded, overwriting at key 0x%x", key) - storage.Set(key, append([]byte{4}, valueToAppend...)) + storage.Set(key, append([]byte{4}, valueToAppend...), version) return nil //nolint:nilerr } @@ -1850,7 +1851,7 @@ func storageAppend(storage runtime.Storage, key, valueToAppend []byte) error { // append new length prefix to start of items array lengthEnc = append(lengthEnc, valueRes...) logger.Debugf("resulting value: 0x%x", lengthEnc) - storage.Set(key, lengthEnc) + storage.Set(key, lengthEnc, version) return nil } @@ -1870,7 +1871,8 @@ func ext_storage_append_version_1(context unsafe.Pointer, keySpan, valueSpan C.i cp := make([]byte, len(valueAppend)) copy(cp, valueAppend) - err := storageAppend(storage, key, cp) + // TODO obtain version from context + err := storageAppend(storage, key, cp, trie.V0) if err != nil { logger.Errorf("failed appending to storage: %s", err) } @@ -2071,7 +2073,7 @@ func ext_storage_root_version_1(context unsafe.Pointer) C.int64_t { instanceContext := wasm.IntoInstanceContext(context) storage := instanceContext.Data().(*runtime.Context).Storage - root, err := storage.Root() + root, err := storage.Root(trie.V0) if err != nil { logger.Errorf("failed to get storage root: %s", err) return 0 @@ -2111,7 +2113,8 @@ func ext_storage_set_version_1(context unsafe.Pointer, keySpan, valueSpan C.int6 logger.Debugf( "key 0x%x has value 0x%x", key, value) - storage.Set(key, cp) + // TODO obtain version from context + storage.Set(key, cp, trie.V0) } //export ext_storage_start_transaction_version_1 diff --git a/lib/trie/child_storage.go b/lib/trie/child_storage.go index fc80da533f3..0f8161478f9 100644 --- a/lib/trie/child_storage.go +++ b/lib/trie/child_storage.go @@ -18,8 +18,8 @@ var ErrChildTrieDoesNotExist = errors.New("child trie does not exist") // PutChild inserts a child trie into the main trie at key :child_storage:[keyToChild] // A child trie is added as a node (K, V) in the main trie. K is the child storage key // associated to the child trie, and V is the root hash of the child trie. -func (t *Trie) PutChild(keyToChild []byte, child *Trie) error { - childHash, err := child.Hash() +func (t *Trie) PutChild(keyToChild []byte, child *Trie, version Version) error { + childHash, err := child.Hash(version) if err != nil { return err } @@ -28,7 +28,7 @@ func (t *Trie) PutChild(keyToChild []byte, child *Trie) error { copy(key, ChildStorageKeyPrefix) copy(key[len(ChildStorageKeyPrefix):], keyToChild) - t.Put(key, childHash.ToBytes()) + t.Put(key, childHash.ToBytes(), version) t.childTries[childHash] = child return nil } @@ -48,19 +48,19 @@ func (t *Trie) GetChild(keyToChild []byte) (*Trie, error) { } // PutIntoChild puts a key-value pair into the child trie located in the main trie at key :child_storage:[keyToChild] -func (t *Trie) PutIntoChild(keyToChild, key, value []byte) error { +func (t *Trie) PutIntoChild(keyToChild, key, value []byte, version Version) error { child, err := t.GetChild(keyToChild) if err != nil { return err } - origChildHash, err := child.Hash() + origChildHash, err := child.Hash(version) if err != nil { return err } - child.Put(key, value) - childHash, err := child.Hash() + child.Put(key, value, version) + childHash, err := child.Hash(version) if err != nil { return err } @@ -68,7 +68,7 @@ func (t *Trie) PutIntoChild(keyToChild, key, value []byte) error { delete(t.childTries, origChildHash) t.childTries[childHash] = child - return t.PutChild(keyToChild, child) + return t.PutChild(keyToChild, child, version) } // GetFromChild retrieves a key-value pair from the child trie located diff --git a/lib/trie/database.go b/lib/trie/database.go index 35753fc303e..75a7829e373 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -83,7 +83,8 @@ func (t *Trie) storeNode(db chaindb.Batch, n *Node) (err error) { // Load reconstructs the trie from the database from the given root hash. // It is used when restarting the node to load the current state trie. -func (t *Trie) Load(db Database, rootHash common.Hash) error { +func (t *Trie) Load(db Database, rootHash common.Hash, + version Version) (err error) { if rootHash == EmptyHash { t.root = nil return nil @@ -106,10 +107,10 @@ func (t *Trie) Load(db Database, rootHash common.Hash) error { t.root.Encoding = encodedNode t.root.MerkleValue = rootHashBytes - return t.loadNode(db, t.root) + return t.loadNode(db, t.root, version) } -func (t *Trie) loadNode(db Database, n *Node) error { +func (t *Trie) loadNode(db Database, n *Node, version Version) error { if n.Kind() != node.Branch { return nil } @@ -149,7 +150,7 @@ func (t *Trie) loadNode(db Database, n *Node) error { decodedNode.MerkleValue = merkleValue branch.Children[i] = decodedNode - err = t.loadNode(db, decodedNode) + err = t.loadNode(db, decodedNode, version) if err != nil { return fmt.Errorf("loading child at index %d with Merkle value 0x%x: %w", i, merkleValue, err) } @@ -170,12 +171,12 @@ func (t *Trie) loadNode(db Database, n *Node) error { childTrie := NewEmptyTrie() value := t.Get(key) rootHash := common.BytesToHash(value) - err := childTrie.Load(db, rootHash) + err := childTrie.Load(db, rootHash, version) if err != nil { return fmt.Errorf("failed to load child trie with root hash=%s: %w", rootHash, err) } - hash, err := childTrie.Hash() + hash, err := childTrie.Hash(version) if err != nil { return fmt.Errorf("cannot hash chilld trie at key 0x%x: %w", key, err) } @@ -208,8 +209,8 @@ func (t *Trie) PopulateNodeHashes(n *Node, hashesSet map[common.Hash]struct{}) { // PutInDB inserts a value in the trie at the key given. // It writes the updated nodes from the changed node up to the root node // to the database in a batch operation. -func (t *Trie) PutInDB(db chaindb.Database, key, value []byte) error { - t.Put(key, value) +func (t *Trie) PutInDB(db chaindb.Database, key, value []byte, version Version) error { + t.Put(key, value, version) return t.WriteDirty(db) } diff --git a/lib/trie/proof/generate.go b/lib/trie/proof/generate.go index eb8075de9d9..27b545b4ded 100644 --- a/lib/trie/proof/generate.go +++ b/lib/trie/proof/generate.go @@ -29,10 +29,11 @@ type Database interface { // for the trie corresponding to the root hash given, and for // the slice of (Little Endian) full keys given. The database given // is used to load the trie using the root hash given. -func Generate(rootHash []byte, fullKeys [][]byte, database Database) ( - encodedProofNodes [][]byte, err error) { +func Generate(rootHash []byte, fullKeys [][]byte, database Database, + version trie.Version) (encodedProofNodes [][]byte, err error) { trie := trie.NewEmptyTrie() - if err := trie.Load(database, common.BytesToHash(rootHash)); err != nil { + err = trie.Load(database, common.BytesToHash(rootHash), version) + if err != nil { return nil, fmt.Errorf("loading trie: %w", err) } rootNode := trie.RootNode() diff --git a/lib/trie/proof/verify.go b/lib/trie/proof/verify.go index 93aff2054ef..a53bb318c80 100644 --- a/lib/trie/proof/verify.go +++ b/lib/trie/proof/verify.go @@ -25,7 +25,8 @@ var ( // A nil error is returned on success. // Note this is exported because it is imported and used by: // https://github.com/ComposableFi/ibc-go/blob/6d62edaa1a3cb0768c430dab81bb195e0b0c72db/modules/light-clients/11-beefy/types/client_state.go#L78 -func Verify(encodedProofNodes [][]byte, rootHash, key, value []byte) (err error) { +func Verify(encodedProofNodes [][]byte, rootHash, key, value []byte, + version trie.Version) (err error) { proofTrie, err := buildTrie(encodedProofNodes, rootHash) if err != nil { return fmt.Errorf("building trie from proof encoded nodes: %w", err) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 72d661751c9..99075761356 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -13,8 +13,11 @@ import ( "github.com/ChainSafe/gossamer/lib/common" ) -// EmptyHash is the empty trie hash. -var EmptyHash, _ = NewEmptyTrie().Hash() +var ( + // EmptyHash is the empty trie hash, which is the same for + // both the V0 and V1 state trie versions + EmptyHash, _ = NewEmptyTrie().Hash(V0) +) // Trie is a base 16 modified Merkle Patricia trie. type Trie struct { @@ -167,8 +170,8 @@ func encodeRoot(root *Node, buffer node.Buffer) (err error) { // MustHash returns the hashed root of the trie. // It panics if it fails to hash the root node. -func (t *Trie) MustHash() common.Hash { - h, err := t.Hash() +func (t *Trie) MustHash(version Version) common.Hash { + h, err := t.Hash(version) if err != nil { panic(err) } @@ -177,7 +180,7 @@ func (t *Trie) MustHash() common.Hash { } // Hash returns the hashed root of the trie. -func (t *Trie) Hash() (rootHash common.Hash, err error) { +func (t *Trie) Hash(version Version) (rootHash common.Hash, err error) { buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() defer pools.EncodingBuffers.Put(buffer) @@ -317,14 +320,15 @@ func findNextKeyChild(children []*Node, startIndex byte, // Put inserts a value into the trie at the // key specified in little Endian format. -func (t *Trie) Put(keyLE, value []byte) { +func (t *Trie) Put(keyLE, value []byte, version Version) { nibblesKey := codec.KeyLEToNibbles(keyLE) - t.root, _ = t.insert(t.root, nibblesKey, value) + t.root, _ = t.insert(t.root, nibblesKey, value, version) } // insert inserts a value in the trie at the key specified. // It may create one or more new nodes or update an existing node. -func (t *Trie) insert(parent *Node, key, value []byte) (newParent *Node, nodesCreated uint32) { +func (t *Trie) insert(parent *Node, key, value []byte, version Version) ( + newParent *Node, nodesCreated uint32) { if parent == nil { const nodesCreated = 1 return &Node{ @@ -338,12 +342,12 @@ func (t *Trie) insert(parent *Node, key, value []byte) (newParent *Node, nodesCr // TODO ensure all values have dirty set to true if parent.Kind() == node.Branch { - return t.insertInBranch(parent, key, value) + return t.insertInBranch(parent, key, value, version) } - return t.insertInLeaf(parent, key, value) + return t.insertInLeaf(parent, key, value, version) } -func (t *Trie) insertInLeaf(parentLeaf *Node, key, value []byte) ( +func (t *Trie) insertInLeaf(parentLeaf *Node, key, value []byte, version Version) ( newParent *Node, nodesCreated uint32) { if bytes.Equal(parentLeaf.Key, key) { nodesCreated = 0 @@ -413,7 +417,7 @@ func (t *Trie) insertInLeaf(parentLeaf *Node, key, value []byte) ( return newBranchParent, nodesCreated } -func (t *Trie) insertInBranch(parentBranch *Node, key, value []byte) ( +func (t *Trie) insertInBranch(parentBranch *Node, key, value []byte, version Version) ( newParent *Node, nodesCreated uint32) { copySettings := node.DefaultCopySettings parentBranch = t.prepBranchForMutation(parentBranch, copySettings) @@ -439,7 +443,7 @@ func (t *Trie) insertInBranch(parentBranch *Node, key, value []byte) ( } nodesCreated = 1 } else { - child, nodesCreated = t.insert(child, remainingKey, value) + child, nodesCreated = t.insert(child, remainingKey, value, version) } parentBranch.Children[childIndex] = child @@ -471,7 +475,7 @@ func (t *Trie) insertInBranch(parentBranch *Node, key, value []byte) ( childIndex := key[commonPrefixLength] remainingKey := key[commonPrefixLength+1:] var additionalNodesCreated uint32 - newParentBranch.Children[childIndex], additionalNodesCreated = t.insert(nil, remainingKey, value) + newParentBranch.Children[childIndex], additionalNodesCreated = t.insert(nil, remainingKey, value, version) nodesCreated += additionalNodesCreated newParentBranch.Descendants += additionalNodesCreated } @@ -482,7 +486,7 @@ func (t *Trie) insertInBranch(parentBranch *Node, key, value []byte) ( // LoadFromMap loads the given data mapping of key to value into the trie. // The keys are in hexadecimal little Endian encoding and the values // are hexadecimal encoded. -func (t *Trie) LoadFromMap(data map[string]string) (err error) { +func (t *Trie) LoadFromMap(data map[string]string, version Version) (err error) { for key, value := range data { keyLEBytes, err := common.HexToBytes(key) if err != nil { @@ -494,7 +498,7 @@ func (t *Trie) LoadFromMap(data map[string]string) (err error) { return fmt.Errorf("cannot convert value hex to bytes: %w", err) } - t.Put(keyLEBytes, valueBytes) + t.Put(keyLEBytes, valueBytes, version) } return nil diff --git a/lib/trie/version.go b/lib/trie/version.go index 23527890c81..0b8c8b7f6c7 100644 --- a/lib/trie/version.go +++ b/lib/trie/version.go @@ -30,15 +30,24 @@ func (v Version) String() string { } } -var ErrParseVersion = errors.New("parsing version failed") +var ErrVersionNotValid = errors.New("version not valid") + +// ParseVersion parses a state trie version string or uint32. +func ParseVersion[T string | uint32](x T) (version Version, err error) { + var s string + switch value := any(x).(type) { + case string: + s = value + case uint32: + s = fmt.Sprintf("V%d", value) + default: + panic(fmt.Sprintf("unsupported type %T", x)) + } -// ParseVersion parses a state trie version string. -func ParseVersion(s string) (version Version, err error) { switch { case strings.EqualFold(s, V0.String()): return V0, nil default: - return version, fmt.Errorf("%w: %q must be %s", - ErrParseVersion, s, V0) + return version, fmt.Errorf("%w: %s", ErrVersionNotValid, s) } } diff --git a/lib/trie/version_test.go b/lib/trie/version_test.go index ab2ac03ebed..0eea1c8e01f 100644 --- a/lib/trie/version_test.go +++ b/lib/trie/version_test.go @@ -4,6 +4,7 @@ package trie import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -49,23 +50,32 @@ func Test_ParseVersion(t *testing.T) { t.Parallel() testCases := map[string]struct { - s string + x any version Version errWrapped error errMessage string }{ "v0": { - s: "v0", + x: "v0", version: V0, }, "V0": { - s: "V0", + x: "V0", version: V0, }, - "invalid": { - s: "xyz", - errWrapped: ErrParseVersion, - errMessage: "parsing version failed: \"xyz\" must be v0", + "invalid string": { + x: "xyz", + errWrapped: ErrVersionNotValid, + errMessage: "parsing version failed: xyz", + }, + "0 uint32": { + x: uint32(0), + version: V0, + }, + "invalid uint32": { + x: uint32(100), + errWrapped: ErrVersionNotValid, + errMessage: "parsing version failed: v100", }, } @@ -74,7 +84,16 @@ func Test_ParseVersion(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - version, err := ParseVersion(testCase.s) + var version Version + var err error + switch typedX := testCase.x.(type) { + case string: + version, err = ParseVersion(typedX) + case uint32: + version, err = ParseVersion(typedX) + default: + panic(fmt.Sprintf("unsupported type %T", testCase.x)) + } assert.Equal(t, testCase.version, version) assert.ErrorIs(t, err, testCase.errWrapped)