Skip to content

Commit

Permalink
add additional checks
Browse files Browse the repository at this point in the history
  • Loading branch information
agouin committed Jan 9, 2024
1 parent f8a45c6 commit e68ce7f
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 130 deletions.
52 changes: 52 additions & 0 deletions signer/cosigner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package signer

import (
"context"
"errors"
"fmt"
"time"

cometcrypto "github.com/cometbft/cometbft/crypto"
"github.com/cometbft/cometbft/libs/protoio"
cometproto "github.com/cometbft/cometbft/proto/tendermint/types"
"github.com/google/uuid"
"github.com/strangelove-ventures/horcrux/v3/signer/proto"
)
Expand Down Expand Up @@ -150,3 +154,51 @@ type CosignerSetNoncesAndSignRequest struct {
SignBytes []byte
VoteExtensionSignBytes []byte
}

func verifySignPayload(chainID string, signBytes, voteExtensionSignBytes []byte) (HRSTKey, bool, error) {
var vote cometproto.CanonicalVote
voteErr := protoio.UnmarshalDelimited(signBytes, &vote)
if voteErr == nil && (vote.Type == cometproto.PrevoteType || vote.Type == cometproto.PrecommitType) {
hrstKey := HRSTKey{
Height: vote.Height,
Round: vote.Round,
Step: CanonicalVoteToStep(&vote),
Timestamp: vote.Timestamp.UnixNano(),
}

if hrstKey.Step == stepPrecommit && len(voteExtensionSignBytes) > 0 && vote.BlockID != nil {
var voteExt cometproto.CanonicalVoteExtension
if err := protoio.UnmarshalDelimited(voteExtensionSignBytes, &voteExt); err != nil {
return hrstKey, false, fmt.Errorf("failed to unmarshal vote extension: %w", err)
}
if voteExt.ChainId != chainID {
return hrstKey, false, fmt.Errorf("vote extension chain ID %s does not match chain ID %s", voteExt.ChainId, chainID)
}
if voteExt.Height != hrstKey.Height {
return hrstKey, false,
fmt.Errorf("vote extension height %d does not match block height %d", voteExt.Height, hrstKey.Height)
}
if voteExt.Round != hrstKey.Round {
return hrstKey, false,
fmt.Errorf("vote extension round %d does not match block round %d", voteExt.Round, hrstKey.Round)
}
return hrstKey, true, nil
}

return hrstKey, false, nil
}

var proposal cometproto.CanonicalProposal
proposalErr := protoio.UnmarshalDelimited(signBytes, &proposal)
if proposalErr == nil {
return HRSTKey{
Height: proposal.Height,
Round: proposal.Round,
Step: stepPropose,
Timestamp: proposal.Timestamp.UnixNano(),
}, false, nil
}

return HRSTKey{}, false,
fmt.Errorf("failed to unmarshal sign bytes into vote or proposal: %w", errors.Join(voteErr, proposalErr))
}
9 changes: 7 additions & 2 deletions signer/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func (pv *FilePV) GetPubKey() (crypto.PubKey, error) {
return pv.Key.PubKey, nil
}

func (pv *FilePV) Sign(block Block) ([]byte, []byte, time.Time, error) {
func (pv *FilePV) Sign(chainID string, block Block) ([]byte, []byte, time.Time, error) {
height, round, step := block.Height, int32(block.Round), block.Step
signBytes, voteExtensionSignBytes := block.SignBytes, block.VoteExtensionSignBytes

Expand All @@ -211,13 +211,18 @@ func (pv *FilePV) Sign(block Block) ([]byte, []byte, time.Time, error) {
return nil, nil, block.Timestamp, err
}

_, hasVoteExtensions, err := verifySignPayload(chainID, signBytes, voteExtensionSignBytes)
if err != nil {
return nil, nil, block.Timestamp, err
}

// Vote extensions are non-deterministic, so it is possible that an
// application may have created a different extension. We therefore always
// re-sign the vote extensions of precommits. For prevotes and nil
// precommits, the extension signature will always be empty.
// Even if the signed over data is empty, we still add the signature
var extSig []byte
if block.Step == stepPrecommit && len(voteExtensionSignBytes) > 0 {
if hasVoteExtensions {
extSig, err = pv.Key.PrivKey.Sign(voteExtensionSignBytes)
if err != nil {
return nil, nil, block.Timestamp, err
Expand Down
23 changes: 23 additions & 0 deletions signer/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package signer

import (
"io"

"github.com/cometbft/cometbft/libs/protoio"
cometprotoprivval "github.com/cometbft/cometbft/proto/tendermint/privval"
)

// ReadMsg reads a message from an io.Reader
func ReadMsg(reader io.Reader) (msg cometprotoprivval.Message, err error) {
const maxRemoteSignerMsgSize = 1024 * 10
protoReader := protoio.NewDelimitedReader(reader, maxRemoteSignerMsgSize)
_, err = protoReader.ReadMsg(&msg)
return msg, err
}

// WriteMsg writes a message to an io.Writer
func WriteMsg(writer io.Writer, msg cometprotoprivval.Message) (err error) {
protoWriter := protoio.NewDelimitedWriter(writer)
_, err = protoWriter.WriteMsg(&msg)
return err
}
12 changes: 6 additions & 6 deletions signer/local_cosigner.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,14 @@ func (cosigner *LocalCosigner) sign(req CosignerSignRequest) (CosignerSignRespon
return res, err
}

// This function has multiple exit points. Only start time can be guaranteed
metricsTimeKeeper.SetPreviousLocalSignStart(time.Now())

hrst, err := UnpackHRST(req.SignBytes)
hrst, hasVoteExtensions, err := verifySignPayload(chainID, req.SignBytes, req.VoteExtensionSignBytes)
if err != nil {
return res, err
}

// This function has multiple exit points. Only start time can be guaranteed
metricsTimeKeeper.SetPreviousLocalSignStart(time.Now())

existingSignature, err := ccs.lastSignState.existingSignatureOrErrorIfRegression(hrst, req.SignBytes)
if err != nil {
return res, err
Expand Down Expand Up @@ -249,7 +249,7 @@ func (cosigner *LocalCosigner) sign(req CosignerSignRequest) (CosignerSignRespon
}

var voteExtNonces []Nonce
if len(req.VoteExtensionSignBytes) > 0 {
if hasVoteExtensions {
voteExtNonces, err = cosigner.combinedNonces(
cosigner.GetID(),
uint8(cosigner.config.Config.ThresholdModeConfig.Threshold),
Expand All @@ -268,7 +268,7 @@ func (cosigner *LocalCosigner) sign(req CosignerSignRequest) (CosignerSignRespon
sig, err = ccs.signer.Sign(nonces, req.SignBytes)
return err
})
if len(req.VoteExtensionSignBytes) > 0 {
if hasVoteExtensions {
eg.Go(func() error {
var err error
voteExtSig, err = ccs.signer.Sign(voteExtNonces, req.VoteExtensionSignBytes)
Expand Down
44 changes: 0 additions & 44 deletions signer/serialization.go

This file was deleted.

57 changes: 0 additions & 57 deletions signer/serialization_test.go

This file was deleted.

40 changes: 22 additions & 18 deletions signer/sign_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ func StepToType(step int8) cometproto.SignedMsgType {

// SignState stores signing information for high level watermark management.
type SignState struct {
Height int64 `json:"height"`
Round int64 `json:"round"`
Step int8 `json:"step"`
NoncePublic []byte `json:"nonce_public"`
Signature []byte `json:"signature,omitempty"`
SignBytes cometbytes.HexBytes `json:"signbytes,omitempty"`
Height int64 `json:"height"`
Round int64 `json:"round"`
Step int8 `json:"step"`
NoncePublic []byte `json:"nonce_public"`
Signature []byte `json:"signature,omitempty"`
SignBytes cometbytes.HexBytes `json:"signbytes,omitempty"`
VoteExtensionSignature []byte `json:"vote_ext_signature,omitempty"`

filePath string

Expand Down Expand Up @@ -232,6 +233,7 @@ func (signState *SignState) cacheAndMarshal(ssc SignStateConsensus) []byte {
signState.Step = ssc.Step
signState.Signature = ssc.Signature
signState.SignBytes = ssc.SignBytes
signState.VoteExtensionSignature = ssc.VoteExtensionSignature

jsonBytes, err := cometjson.MarshalIndent(signState, "", " ")
if err != nil {
Expand Down Expand Up @@ -416,13 +418,14 @@ func (signState *SignState) GetErrorIfLessOrEqual(height int64, round int64, ste
// including the most recent sign state.
func (signState *SignState) FreshCache() *SignState {
newSignState := &SignState{
Height: signState.Height,
Round: signState.Round,
Step: signState.Step,
NoncePublic: signState.NoncePublic,
Signature: signState.Signature,
SignBytes: signState.SignBytes,
cache: make(map[HRSKey]SignStateConsensus),
Height: signState.Height,
Round: signState.Round,
Step: signState.Step,
NoncePublic: signState.NoncePublic,
Signature: signState.Signature,
SignBytes: signState.SignBytes,
VoteExtensionSignature: signState.VoteExtensionSignature,
cache: make(map[HRSKey]SignStateConsensus),

filePath: signState.filePath,
}
Expand All @@ -434,11 +437,12 @@ func (signState *SignState) FreshCache() *SignState {
Round: signState.Round,
Step: signState.Step,
}] = SignStateConsensus{
Height: signState.Height,
Round: signState.Round,
Step: signState.Step,
Signature: signState.Signature,
SignBytes: signState.SignBytes,
Height: signState.Height,
Round: signState.Round,
Step: signState.Step,
Signature: signState.Signature,
SignBytes: signState.SignBytes,
VoteExtensionSignature: signState.VoteExtensionSignature,
}

return newSignState
Expand Down
2 changes: 1 addition & 1 deletion signer/single_signer_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (pv *SingleSignerValidator) Sign(
chainState.pvMutex.Lock()
defer chainState.pvMutex.Unlock()

return chainState.filePV.Sign(block)
return chainState.filePV.Sign(chainID, block)
}

func (pv *SingleSignerValidator) loadChainStateIfNecessary(chainID string) (*SingleSignerChainState, error) {
Expand Down
11 changes: 9 additions & 2 deletions signer/threshold_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,14 @@ func (pv *ThresholdValidator) Sign(
}

var voteExtNonces *CosignerUUIDNonces
if step == stepPrecommit && len(voteExtensionSignBytes) > 0 {

_, hasVoteExtensions, err := verifySignPayload(chainID, signBytes, voteExtensionSignBytes)
if err != nil {
pv.notifyBlockSignError(chainID, block.HRSKey(), signBytes)
return nil, nil, stamp, fmt.Errorf("failed to verify payload: %w", err)
}

if hasVoteExtensions {
voteExtNonces, err = pv.nonceCache.GetNonces(cosignersForThisBlock)
if err != nil {
// TODO how to handle fallback for vote extensions?
Expand Down Expand Up @@ -853,7 +860,7 @@ func (pv *ThresholdValidator) Sign(

var voteExtSig []byte

if step == stepPrecommit && len(voteExtensionSignBytes) > 0 {
if hasVoteExtensions {
// collect all valid responses into array of partial signatures
voteExtShareSigs := make([]PartialSignature, 0, pv.threshold)
for idx, shareSig := range voteExtShareSignatures {
Expand Down
5 changes: 5 additions & 0 deletions signer/threshold_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"fmt"
mrand "math/rand"
"path/filepath"
Expand Down Expand Up @@ -254,9 +255,13 @@ func testThresholdValidator(t *testing.T, threshold, total uint8) {
err = eg.Wait()
require.NoError(t, err)

blockIDHash := sha256.New()
blockIDHash.Write([]byte("something"))

precommit := cometproto.Vote{
Height: int64(i),
Round: 0,
BlockID: cometproto.BlockID{Hash: blockIDHash.Sum(nil)},
Type: cometproto.PrecommitType,
Timestamp: time.Now(),
Extension: []byte("test"),
Expand Down

0 comments on commit e68ce7f

Please sign in to comment.