diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index b9b9dc61f7..9f2e48250b 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -6,7 +6,6 @@ package node import ( "bytes" "fmt" - "hash" "io" "runtime" @@ -147,56 +146,15 @@ func encodeChild(child *Node, buffer io.Writer) (err error) { // and then SCALE encodes it. This is used to encode children // nodes of branches. func scaleEncodeHash(node *Node) (encoding []byte, err error) { - buffer := pools.DigestBuffers.Get().(*bytes.Buffer) - buffer.Reset() - defer pools.DigestBuffers.Put(buffer) - - err = hashNode(node, buffer) + _, merkleValue, err := node.EncodeAndHash() if err != nil { - return nil, fmt.Errorf("cannot hash %s: %w", node.Kind(), err) + return nil, fmt.Errorf("encoding and hashing %s: %w", node.Kind(), err) } - encoding, err = scale.Marshal(buffer.Bytes()) + encoding, err = scale.Marshal(merkleValue) if err != nil { return nil, fmt.Errorf("cannot scale encode hashed %s: %w", node.Kind(), err) } return encoding, nil } - -func hashNode(node *Node, digestWriter io.Writer) (err error) { - encodingBuffer := pools.EncodingBuffers.Get().(*bytes.Buffer) - encodingBuffer.Reset() - defer pools.EncodingBuffers.Put(encodingBuffer) - - err = node.Encode(encodingBuffer) - if err != nil { - return fmt.Errorf("cannot encode %s: %w", node.Kind(), err) - } - - // if length of encoded leaf is less than 32 bytes, do not hash - if encodingBuffer.Len() < 32 { - _, err = digestWriter.Write(encodingBuffer.Bytes()) - if err != nil { - return fmt.Errorf("cannot write encoded %s to buffer: %w", node.Kind(), err) - } - return nil - } - - // otherwise, hash encoded node - hasher := pools.Hashers.Get().(hash.Hash) - hasher.Reset() - defer pools.Hashers.Put(hasher) - - // Note: using the sync.Pool's buffer is useful here. - _, err = hasher.Write(encodingBuffer.Bytes()) - if err != nil { - return fmt.Errorf("cannot hash encoding of %s: %w", node.Kind(), err) - } - - _, err = digestWriter.Write(hasher.Sum(nil)) - if err != nil { - return fmt.Errorf("cannot write hash sum of %s to buffer: %w", node.Kind(), err) - } - return nil -} diff --git a/internal/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go index e865a65e06..55386ff6f5 100644 --- a/internal/trie/node/branch_encode_test.go +++ b/internal/trie/node/branch_encode_test.go @@ -13,159 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -func Test_hashNode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - node *Node - write writeCall - errWrapped error - errMessage string - }{ - "small leaf buffer write error": { - node: &Node{ - Encoding: []byte{1, 2, 3}, - }, - write: writeCall{ - written: []byte{1, 2, 3}, - err: errTest, - }, - errWrapped: errTest, - errMessage: "cannot write encoded leaf to buffer: " + - "test error", - }, - "small leaf success": { - node: &Node{ - Encoding: []byte{1, 2, 3}, - }, - write: writeCall{ - written: []byte{1, 2, 3}, - }, - }, - "leaf hash sum buffer write error": { - node: &Node{ - Encoding: []byte{ - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - }, - write: writeCall{ - written: []byte{ - 107, 105, 154, 175, 253, 170, 232, - 135, 240, 21, 207, 148, 82, 117, - 249, 230, 80, 197, 254, 17, 149, - 108, 50, 7, 80, 56, 114, 176, - 84, 114, 125, 234}, - err: errTest, - }, - errWrapped: errTest, - errMessage: "cannot write hash sum of leaf to buffer: " + - "test error", - }, - "leaf hash sum success": { - node: &Node{ - Encoding: []byte{ - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - }, - write: writeCall{ - written: []byte{ - 107, 105, 154, 175, 253, 170, 232, - 135, 240, 21, 207, 148, 82, 117, - 249, 230, 80, 197, 254, 17, 149, - 108, 50, 7, 80, 56, 114, 176, - 84, 114, 125, 234}, - }, - }, - "empty branch": { - node: &Node{ - Children: make([]*Node, ChildrenCapacity), - }, - write: writeCall{ - written: []byte{128, 0, 0}, - }, - }, - "less than 32 bytes encoding": { - node: &Node{ - Children: make([]*Node, ChildrenCapacity), - Key: []byte{1, 2}, - }, - write: writeCall{ - written: []byte{130, 18, 0, 0}, - }, - }, - "less than 32 bytes encoding write error": { - node: &Node{ - Children: make([]*Node, ChildrenCapacity), - Key: []byte{1, 2}, - }, - write: writeCall{ - written: []byte{130, 18, 0, 0}, - err: errTest, - }, - errWrapped: errTest, - errMessage: "cannot write encoded branch to buffer: test error", - }, - "more than 32 bytes encoding": { - node: &Node{ - Children: make([]*Node, ChildrenCapacity), - Key: repeatBytes(100, 1), - }, - write: writeCall{ - written: []byte{ - 70, 102, 188, 24, 31, 68, 86, 114, - 95, 156, 225, 138, 175, 254, 176, 251, - 81, 84, 193, 40, 11, 234, 142, 233, - 69, 250, 158, 86, 72, 228, 66, 46}, - }, - }, - "more than 32 bytes encoding write error": { - node: &Node{ - Children: make([]*Node, ChildrenCapacity), - Key: repeatBytes(100, 1), - }, - write: writeCall{ - written: []byte{ - 70, 102, 188, 24, 31, 68, 86, 114, - 95, 156, 225, 138, 175, 254, 176, 251, - 81, 84, 193, 40, 11, 234, 142, 233, - 69, 250, 158, 86, 72, 228, 66, 46}, - err: errTest, - }, - errWrapped: errTest, - errMessage: "cannot write hash sum of branch to buffer: test error", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - digestBuffer := NewMockWriter(ctrl) - digestBuffer.EXPECT().Write(testCase.write.written). - Return(testCase.write.n, testCase.write.err) - - err := hashNode(testCase.node, digestBuffer) - - if testCase.errWrapped != nil { - assert.ErrorIs(t, err, testCase.errWrapped) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - // Opportunistic parallel: 13781602 ns/op 14419488 B/op 323575 allocs/op // Sequentially: 24269268 ns/op 20126525 B/op 327668 allocs/op func Benchmark_encodeChildrenOpportunisticParallel(b *testing.B) { diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index fdfed9f456..8a97fa3f9f 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -5,58 +5,144 @@ package node import ( "bytes" + "fmt" + "hash" + "io" "github.com/ChainSafe/gossamer/internal/trie/pools" - "github.com/ChainSafe/gossamer/lib/common" ) -// EncodeAndHash returns the encoding of the node and -// the Merkle value of the node. +// MerkleValue writes the Merkle value from the encoding of a non-root +// node to the writer given. +// If the encoding is less or equal to 32 bytes, the Merkle value is the encoding. +// Otherwise, the Merkle value is the Blake2b hash digest of the encoding. +func MerkleValue(encoding []byte, writer io.Writer) (err error) { + if len(encoding) < 32 { + _, err = writer.Write(encoding) + if err != nil { + return fmt.Errorf("writing encoding: %w", err) + } + return nil + } + + return hashEncoding(encoding, writer) +} + +// MerkleValueRoot writes the Merkle value for the root of the trie +// to the writer given as argument. +// The Merkle value is the Blake2b hash of the encoding of the root node. +func MerkleValueRoot(rootEncoding []byte, writer io.Writer) (err error) { + return hashEncoding(rootEncoding, writer) +} + +func hashEncoding(encoding []byte, writer io.Writer) (err error) { + hasher := pools.Hashers.Get().(hash.Hash) + hasher.Reset() + defer pools.Hashers.Put(hasher) + + _, err = hasher.Write(encoding) + if err != nil { + return fmt.Errorf("hashing encoding: %w", err) + } + + digest := hasher.Sum(nil) + _, err = writer.Write(digest) + if err != nil { + return fmt.Errorf("writing digest: %w", err) + } + + return nil +} + +// CalculateMerkleValue returns the Merkle value of the non-root node. +func (n *Node) CalculateMerkleValue() (merkleValue []byte, err error) { + if !n.Dirty && n.MerkleValue != nil { + return n.MerkleValue, nil + } + + _, merkleValue, err = n.EncodeAndHash() + if err != nil { + return nil, fmt.Errorf("encoding and hashing node: %w", err) + } + + return merkleValue, nil +} + +// CalculateRootMerkleValue returns the Merkle value of the root node. +func (n *Node) CalculateRootMerkleValue() (merkleValue []byte, err error) { + const rootMerkleValueLength = 32 + if !n.Dirty && len(n.MerkleValue) == rootMerkleValueLength { + return n.MerkleValue, nil + } + + _, merkleValue, err = n.EncodeAndHashRoot() + if err != nil { + return nil, fmt.Errorf("encoding and hashing root node: %w", err) + } + + return merkleValue, nil +} + +// EncodeAndHash returns the encoding of the node and the +// Merkle value of the node. See the `MerkleValue` method for +// more details on the value of the Merkle value. +// TODO change this function to write to an encoding writer +// and a merkle value writer, such that buffer sync pools can be used +// by the caller. func (n *Node) EncodeAndHash() (encoding, merkleValue []byte, err error) { if !n.Dirty && n.Encoding != nil && n.MerkleValue != nil { return n.Encoding, n.MerkleValue, nil } - buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) - buffer.Reset() - defer pools.EncodingBuffers.Put(buffer) + encoding, err = n.encodeIfNeeded() + if err != nil { + return nil, nil, fmt.Errorf("encoding node: %w", err) + } - err = n.Encode(buffer) + const maxMerkleValueSize = 32 + merkleValueBuffer := bytes.NewBuffer(make([]byte, 0, maxMerkleValueSize)) + err = MerkleValue(encoding, merkleValueBuffer) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("merkle value: %w", err) } + merkleValue = merkleValueBuffer.Bytes() + n.MerkleValue = merkleValue // no need to copy - bufferBytes := buffer.Bytes() + return encoding, merkleValue, nil +} - // TODO remove this copying since it defeats the purpose of `buffer` - // and the sync.Pool. - n.Encoding = make([]byte, len(bufferBytes)) - copy(n.Encoding, bufferBytes) - encoding = n.Encoding // no need to copy +// EncodeAndHashRoot returns the encoding of the node and the +// Merkle value of the node. See the `MerkleValueRoot` method +// for more details on the value of the Merkle value. +// TODO change this function to write to an encoding writer +// and a merkle value writer, such that buffer sync pools can be used +// by the caller. +func (n *Node) EncodeAndHashRoot() (encoding, merkleValue []byte, err error) { + const rootMerkleValueLength = 32 + if !n.Dirty && n.Encoding != nil && len(n.MerkleValue) == rootMerkleValueLength { + return n.Encoding, n.MerkleValue, nil + } - if buffer.Len() < 32 { - n.MerkleValue = make([]byte, len(bufferBytes)) - copy(n.MerkleValue, bufferBytes) - merkleValue = n.MerkleValue // no need to copy - return encoding, merkleValue, nil + encoding, err = n.encodeIfNeeded() + if err != nil { + return nil, nil, fmt.Errorf("encoding node: %w", err) } - // Note: using the sync.Pool's buffer is useful here. - hashArray, err := common.Blake2bHash(buffer.Bytes()) + const merkleValueSize = 32 + merkleValueBuffer := bytes.NewBuffer(make([]byte, 0, merkleValueSize)) + err = MerkleValueRoot(encoding, merkleValueBuffer) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("merkle value: %w", err) } - n.MerkleValue = hashArray[:] - merkleValue = n.MerkleValue // no need to copy + merkleValue = merkleValueBuffer.Bytes() + n.MerkleValue = merkleValue // no need to copy return encoding, merkleValue, nil } -// EncodeAndHashRoot returns the encoding of the root node and -// the Merkle value of the root node (the hash of its encoding). -func (n *Node) EncodeAndHashRoot() (encoding, merkleValue []byte, err error) { - if !n.Dirty && n.Encoding != nil && n.MerkleValue != nil { - return n.Encoding, n.MerkleValue, nil +func (n *Node) encodeIfNeeded() (encoding []byte, err error) { + if !n.Dirty && n.Encoding != nil { + return n.Encoding, nil // no need to copy } buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) @@ -65,7 +151,7 @@ func (n *Node) EncodeAndHashRoot() (encoding, merkleValue []byte, err error) { err = n.Encode(buffer) if err != nil { - return nil, nil, err + return nil, fmt.Errorf("encoding: %w", err) } bufferBytes := buffer.Bytes() @@ -74,15 +160,6 @@ func (n *Node) EncodeAndHashRoot() (encoding, merkleValue []byte, err error) { // and the sync.Pool. n.Encoding = make([]byte, len(bufferBytes)) copy(n.Encoding, bufferBytes) - encoding = n.Encoding // no need to copy - - // Note: using the sync.Pool's buffer is useful here. - hashArray, err := common.Blake2bHash(buffer.Bytes()) - if err != nil { - return nil, nil, err - } - n.MerkleValue = hashArray[:] - merkleValue = n.MerkleValue // no need to copy - return encoding, merkleValue, nil + return n.Encoding, nil // no need to copy } diff --git a/internal/trie/node/hash_test.go b/internal/trie/node/hash_test.go index 0d5b673451..e2682a75b7 100644 --- a/internal/trie/node/hash_test.go +++ b/internal/trie/node/hash_test.go @@ -4,11 +4,269 @@ package node import ( + "io" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) +func Test_MerkleValue(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + encoding []byte + writerBuilder func(ctrl *gomock.Controller) io.Writer + errWrapped error + errMessage string + }{ + "small encoding": { + encoding: []byte{1}, + writerBuilder: func(ctrl *gomock.Controller) io.Writer { + writer := NewMockWriter(ctrl) + writer.EXPECT().Write([]byte{1}).Return(1, nil) + return writer + }, + }, + "encoding write error": { + encoding: []byte{1}, + writerBuilder: func(ctrl *gomock.Controller) io.Writer { + writer := NewMockWriter(ctrl) + writer.EXPECT().Write([]byte{1}).Return(0, errTest) + return writer + }, + errWrapped: errTest, + errMessage: "writing encoding: test error", + }, + "long encoding": { + encoding: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33}, + writerBuilder: func(ctrl *gomock.Controller) io.Writer { + writer := NewMockWriter(ctrl) + writer.EXPECT().Write([]byte{ + 0xfc, 0xd2, 0xd9, 0xac, 0xe8, 0x70, 0x52, 0x81, + 0x1d, 0x9f, 0x34, 0x27, 0xb5, 0x8f, 0xf3, 0x98, + 0xd2, 0xe9, 0xed, 0x83, 0xf3, 0x1, 0xbc, 0x7e, + 0xc1, 0xbe, 0x8b, 0x59, 0x39, 0x62, 0xf1, 0x7d, + }).Return(32, nil) + return writer + }, + }, + "digest write error": { + encoding: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33}, + writerBuilder: func(ctrl *gomock.Controller) io.Writer { + writer := NewMockWriter(ctrl) + writer.EXPECT().Write([]byte{ + 0xfc, 0xd2, 0xd9, 0xac, 0xe8, 0x70, 0x52, 0x81, + 0x1d, 0x9f, 0x34, 0x27, 0xb5, 0x8f, 0xf3, 0x98, + 0xd2, 0xe9, 0xed, 0x83, 0xf3, 0x1, 0xbc, 0x7e, + 0xc1, 0xbe, 0x8b, 0x59, 0x39, 0x62, 0xf1, 0x7d, + }).Return(0, errTest) + return writer + }, + errWrapped: errTest, + errMessage: "writing digest: test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := testCase.writerBuilder(ctrl) + + err := MerkleValue(testCase.encoding, writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_MerkleValueRoot(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + encoding []byte + writerBuilder func(ctrl *gomock.Controller) io.Writer + errWrapped error + errMessage string + }{ + "digest write error": { + encoding: []byte{1}, + writerBuilder: func(ctrl *gomock.Controller) io.Writer { + writer := NewMockWriter(ctrl) + writer.EXPECT().Write([]byte{ + 0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25, + }).Return(0, errTest) + return writer + }, + errWrapped: errTest, + errMessage: "writing digest: test error", + }, + "small encoding": { + encoding: []byte{1}, + writerBuilder: func(ctrl *gomock.Controller) io.Writer { + writer := NewMockWriter(ctrl) + writer.EXPECT().Write([]byte{ + 0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25, + }).Return(32, nil) + return writer + }, + }, + "long encoding": { + encoding: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33}, + writerBuilder: func(ctrl *gomock.Controller) io.Writer { + writer := NewMockWriter(ctrl) + writer.EXPECT().Write([]byte{ + 0xfc, 0xd2, 0xd9, 0xac, 0xe8, 0x70, 0x52, 0x81, + 0x1d, 0x9f, 0x34, 0x27, 0xb5, 0x8f, 0xf3, 0x98, + 0xd2, 0xe9, 0xed, 0x83, 0xf3, 0x1, 0xbc, 0x7e, + 0xc1, 0xbe, 0x8b, 0x59, 0x39, 0x62, 0xf1, 0x7d, + }).Return(32, nil) + return writer + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := testCase.writerBuilder(ctrl) + + err := MerkleValueRoot(testCase.encoding, writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_Node_CalculateMerkleValue(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + node Node + merkleValue []byte + errWrapped error + errMessage string + }{ + "cached merkle value": { + node: Node{ + MerkleValue: []byte{1}, + }, + merkleValue: []byte{1}, + }, + "small encoding": { + node: Node{ + Encoding: []byte{1}, + }, + merkleValue: []byte{1}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + merkleValue, err := testCase.node.CalculateMerkleValue() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.merkleValue, merkleValue) + }) + } +} + +func Test_Node_CalculateRootMerkleValue(t *testing.T) { + t.Parallel() + + some32BHashDigest := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25} + + testCases := map[string]struct { + node Node + merkleValue []byte + errWrapped error + errMessage string + }{ + "cached merkle value 32 bytes": { + node: Node{ + MerkleValue: some32BHashDigest, + }, + merkleValue: some32BHashDigest, + }, + "cached merkle value not 32 bytes": { + node: Node{ + Encoding: []byte{1}, + MerkleValue: []byte{1}, + }, + merkleValue: []byte{ + 0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25}, + }, + "root small encoding": { + node: Node{ + Encoding: []byte{1}, + }, + merkleValue: []byte{ + 0xee, 0x15, 0x5a, 0xce, 0x9c, 0x40, 0x29, 0x20, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + merkleValue, err := testCase.node.CalculateRootMerkleValue() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.merkleValue, merkleValue) + }) + } +} + func Test_Node_EncodeAndHash(t *testing.T) { t.Parallel() @@ -173,6 +431,12 @@ func Test_Node_EncodeAndHash(t *testing.T) { func Test_Node_EncodeAndHashRoot(t *testing.T) { t.Parallel() + some32BHashDigest := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x74, 0xcb, 0x6a, 0xff, 0x8c, 0x9c, 0xcd, 0xd2, + 0x73, 0xc8, 0x16, 0x48, 0xff, 0x11, 0x49, 0xef, + 0x36, 0xbc, 0xea, 0x6e, 0xbb, 0x8a, 0x3e, 0x25} + testCases := map[string]struct { node Node expectedNode Node @@ -187,16 +451,16 @@ func Test_Node_EncodeAndHashRoot(t *testing.T) { SubValue: []byte{2}, Dirty: false, Encoding: []byte{3}, - MerkleValue: []byte{4}, + MerkleValue: some32BHashDigest, }, expectedNode: Node{ Key: []byte{1}, SubValue: []byte{2}, Encoding: []byte{3}, - MerkleValue: []byte{4}, + MerkleValue: some32BHashDigest, }, encoding: []byte{3}, - hash: []byte{4}, + hash: some32BHashDigest, }, "small leaf encoding": { node: Node{ diff --git a/lib/trie/database.go b/lib/trie/database.go index b113ad5ee3..35753fc303 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -125,9 +125,9 @@ func (t *Trie) loadNode(db Database, n *Node) error { if len(merkleValue) == 0 { // node has already been loaded inline // just set encoding + hash digest - _, _, err := child.EncodeAndHash() + _, err := child.CalculateMerkleValue() if err != nil { - return err + return fmt.Errorf("merkle value: %w", err) } child.SetClean() continue @@ -397,11 +397,11 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, hashes map[common.Hash]struc return nil } - var hash []byte + var merkleValue []byte if n == t.root { - _, hash, err = n.EncodeAndHashRoot() + merkleValue, err = n.CalculateRootMerkleValue() } else { - _, hash, err = n.EncodeAndHash() + merkleValue, err = n.CalculateMerkleValue() } if err != nil { return fmt.Errorf( @@ -409,7 +409,7 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, hashes map[common.Hash]struc n.MerkleValue, err) } - hashes[common.BytesToHash(hash)] = struct{}{} + hashes[common.BytesToHash(merkleValue)] = struct{}{} if n.Kind() != node.Branch { return nil diff --git a/lib/trie/proof/generate.go b/lib/trie/proof/generate.go index f82f0e7f47..eb8075de9d 100644 --- a/lib/trie/proof/generate.go +++ b/lib/trie/proof/generate.go @@ -10,6 +10,7 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" ) @@ -36,7 +37,10 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) ( } rootNode := trie.RootNode() - hashesSeen := make(map[string]struct{}) + buffer := pools.DigestBuffers.Get().(*bytes.Buffer) + defer pools.DigestBuffers.Put(buffer) + + merkleValuesSeen := make(map[string]struct{}) for _, fullKey := range fullKeys { fullKeyNibbles := codec.KeyLEToNibbles(fullKey) newEncodedProofNodes, err := walkRoot(rootNode, fullKeyNibbles) @@ -47,17 +51,18 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) ( } for _, encodedProofNode := range newEncodedProofNodes { - digest, err := common.Blake2bHash(encodedProofNode) + buffer.Reset() + err := node.MerkleValue(encodedProofNode, buffer) if err != nil { return nil, fmt.Errorf("blake2b hash: %w", err) } - hashString := string(digest.ToBytes()) + merkleValueString := buffer.String() - _, seen := hashesSeen[hashString] + _, seen := merkleValuesSeen[merkleValueString] if seen { continue } - hashesSeen[hashString] = struct{}{} + merkleValuesSeen[merkleValueString] = struct{}{} encodedProofNodes = append(encodedProofNodes, encodedProofNode) } diff --git a/lib/trie/proof/verify.go b/lib/trie/proof/verify.go index 5da7a6d5b0..93aff2054e 100644 --- a/lib/trie/proof/verify.go +++ b/lib/trie/proof/verify.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" ) @@ -59,6 +60,12 @@ func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err e digestToEncoding := make(map[string][]byte, len(encodedProofNodes)) + // note we can use a buffer from the pool since + // the calculated root hash digest is not used after + // the function completes. + buffer := pools.DigestBuffers.Get().(*bytes.Buffer) + defer pools.DigestBuffers.Put(buffer) + // This loop does two things: // 1. It finds the root node by comparing it with the root hash and decodes it. // 2. It stores other encoded nodes in a mapping from their encoding digest to @@ -70,12 +77,15 @@ func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err e // - trie root node // - child trie root node // - child node with an encoding larger than 32 bytes - // In all cases, their Merkle value is the encoding hash digest. - digestHash, err := common.Blake2bHash(encodedProofNode) + // In all cases, their Merkle value is the encoding hash digest, + // so we use MerkleValueRoot to force hashing the node in case + // it is a root node smaller or equal to 32 bytes. + buffer.Reset() + err = node.MerkleValueRoot(encodedProofNode, buffer) if err != nil { - return nil, fmt.Errorf("blake2b hash: %w", err) + return nil, fmt.Errorf("calculating Merkle value: %w", err) } - digest := digestHash[:] + digest := buffer.Bytes() if root != nil || !bytes.Equal(digest, rootHash) { // root node already found or the hash doesn't match the root hash. diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index 306be9d7ac..697da7d5bc 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -496,9 +496,10 @@ func Test_Trie_Hash(t *testing.T) { Descendants: 1, Children: padRightChildren([]*Node{ { - Key: []byte{9}, - SubValue: []byte{1}, - Encoding: []byte{0x41, 0x09, 0x04, 0x01}, + Key: []byte{9}, + SubValue: []byte{1}, + Encoding: []byte{0x41, 0x09, 0x04, 0x01}, + MerkleValue: []byte{0x41, 0x09, 0x04, 0x01}, }, }), },