diff --git a/signer/cosigner.go b/signer/cosigner.go index f3db4254..b1a9f116 100644 --- a/signer/cosigner.go +++ b/signer/cosigner.go @@ -2,9 +2,13 @@ package signer import ( "context" + "errors" + "fmt" "time" cometcrypto "github.com/cometbft/cometbft/crypto" + "github.com/cometbft/cometbft/libs/protoio" + cometproto "github.com/cometbft/cometbft/proto/tendermint/types" "github.com/google/uuid" "github.com/strangelove-ventures/horcrux/v3/signer/proto" ) @@ -150,3 +154,51 @@ type CosignerSetNoncesAndSignRequest struct { SignBytes []byte VoteExtensionSignBytes []byte } + +func verifySignPayload(chainID string, signBytes, voteExtensionSignBytes []byte) (HRSTKey, bool, error) { + var vote cometproto.CanonicalVote + voteErr := protoio.UnmarshalDelimited(signBytes, &vote) + if voteErr == nil && (vote.Type == cometproto.PrevoteType || vote.Type == cometproto.PrecommitType) { + hrstKey := HRSTKey{ + Height: vote.Height, + Round: vote.Round, + Step: CanonicalVoteToStep(&vote), + Timestamp: vote.Timestamp.UnixNano(), + } + + if hrstKey.Step == stepPrecommit && len(voteExtensionSignBytes) > 0 && vote.BlockID != nil { + var voteExt cometproto.CanonicalVoteExtension + if err := protoio.UnmarshalDelimited(voteExtensionSignBytes, &voteExt); err != nil { + return hrstKey, false, fmt.Errorf("failed to unmarshal vote extension: %w", err) + } + if voteExt.ChainId != chainID { + return hrstKey, false, fmt.Errorf("vote extension chain ID %s does not match chain ID %s", voteExt.ChainId, chainID) + } + if voteExt.Height != hrstKey.Height { + return hrstKey, false, + fmt.Errorf("vote extension height %d does not match block height %d", voteExt.Height, hrstKey.Height) + } + if voteExt.Round != hrstKey.Round { + return hrstKey, false, + fmt.Errorf("vote extension round %d does not match block round %d", voteExt.Round, hrstKey.Round) + } + return hrstKey, true, nil + } + + return hrstKey, false, nil + } + + var proposal cometproto.CanonicalProposal + proposalErr := protoio.UnmarshalDelimited(signBytes, &proposal) + if proposalErr == nil { + return HRSTKey{ + Height: proposal.Height, + Round: proposal.Round, + Step: stepPropose, + Timestamp: proposal.Timestamp.UnixNano(), + }, false, nil + } + + return HRSTKey{}, false, + fmt.Errorf("failed to unmarshal sign bytes into vote or proposal: %w", errors.Join(voteErr, proposalErr)) +} diff --git a/signer/file.go b/signer/file.go index b0eae6f6..d4379dbc 100644 --- a/signer/file.go +++ b/signer/file.go @@ -200,7 +200,7 @@ func (pv *FilePV) GetPubKey() (crypto.PubKey, error) { return pv.Key.PubKey, nil } -func (pv *FilePV) Sign(block Block) ([]byte, []byte, time.Time, error) { +func (pv *FilePV) Sign(chainID string, block Block) ([]byte, []byte, time.Time, error) { height, round, step := block.Height, int32(block.Round), block.Step signBytes, voteExtensionSignBytes := block.SignBytes, block.VoteExtensionSignBytes @@ -211,13 +211,18 @@ func (pv *FilePV) Sign(block Block) ([]byte, []byte, time.Time, error) { return nil, nil, block.Timestamp, err } + _, hasVoteExtensions, err := verifySignPayload(chainID, signBytes, voteExtensionSignBytes) + if err != nil { + return nil, nil, block.Timestamp, err + } + // Vote extensions are non-deterministic, so it is possible that an // application may have created a different extension. We therefore always // re-sign the vote extensions of precommits. For prevotes and nil // precommits, the extension signature will always be empty. // Even if the signed over data is empty, we still add the signature var extSig []byte - if block.Step == stepPrecommit && len(voteExtensionSignBytes) > 0 { + if hasVoteExtensions { extSig, err = pv.Key.PrivKey.Sign(voteExtensionSignBytes) if err != nil { return nil, nil, block.Timestamp, err diff --git a/signer/io.go b/signer/io.go new file mode 100644 index 00000000..45237a8a --- /dev/null +++ b/signer/io.go @@ -0,0 +1,23 @@ +package signer + +import ( + "io" + + "github.com/cometbft/cometbft/libs/protoio" + cometprotoprivval "github.com/cometbft/cometbft/proto/tendermint/privval" +) + +// ReadMsg reads a message from an io.Reader +func ReadMsg(reader io.Reader) (msg cometprotoprivval.Message, err error) { + const maxRemoteSignerMsgSize = 1024 * 10 + protoReader := protoio.NewDelimitedReader(reader, maxRemoteSignerMsgSize) + _, err = protoReader.ReadMsg(&msg) + return msg, err +} + +// WriteMsg writes a message to an io.Writer +func WriteMsg(writer io.Writer, msg cometprotoprivval.Message) (err error) { + protoWriter := protoio.NewDelimitedWriter(writer) + _, err = protoWriter.WriteMsg(&msg) + return err +} diff --git a/signer/local_cosigner.go b/signer/local_cosigner.go index 10875b88..4559974f 100644 --- a/signer/local_cosigner.go +++ b/signer/local_cosigner.go @@ -214,14 +214,14 @@ func (cosigner *LocalCosigner) sign(req CosignerSignRequest) (CosignerSignRespon return res, err } - // This function has multiple exit points. Only start time can be guaranteed - metricsTimeKeeper.SetPreviousLocalSignStart(time.Now()) - - hrst, err := UnpackHRST(req.SignBytes) + hrst, hasVoteExtensions, err := verifySignPayload(chainID, req.SignBytes, req.VoteExtensionSignBytes) if err != nil { return res, err } + // This function has multiple exit points. Only start time can be guaranteed + metricsTimeKeeper.SetPreviousLocalSignStart(time.Now()) + existingSignature, err := ccs.lastSignState.existingSignatureOrErrorIfRegression(hrst, req.SignBytes) if err != nil { return res, err @@ -249,7 +249,7 @@ func (cosigner *LocalCosigner) sign(req CosignerSignRequest) (CosignerSignRespon } var voteExtNonces []Nonce - if len(req.VoteExtensionSignBytes) > 0 { + if hasVoteExtensions { voteExtNonces, err = cosigner.combinedNonces( cosigner.GetID(), uint8(cosigner.config.Config.ThresholdModeConfig.Threshold), @@ -268,7 +268,7 @@ func (cosigner *LocalCosigner) sign(req CosignerSignRequest) (CosignerSignRespon sig, err = ccs.signer.Sign(nonces, req.SignBytes) return err }) - if len(req.VoteExtensionSignBytes) > 0 { + if hasVoteExtensions { eg.Go(func() error { var err error voteExtSig, err = ccs.signer.Sign(voteExtNonces, req.VoteExtensionSignBytes) diff --git a/signer/serialization.go b/signer/serialization.go deleted file mode 100644 index ee6b5ab4..00000000 --- a/signer/serialization.go +++ /dev/null @@ -1,44 +0,0 @@ -package signer - -import ( - "errors" - "io" - - "github.com/cometbft/cometbft/libs/protoio" - cometprotoprivval "github.com/cometbft/cometbft/proto/tendermint/privval" - cometproto "github.com/cometbft/cometbft/proto/tendermint/types" -) - -// ReadMsg reads a message from an io.Reader -func ReadMsg(reader io.Reader) (msg cometprotoprivval.Message, err error) { - const maxRemoteSignerMsgSize = 1024 * 10 - protoReader := protoio.NewDelimitedReader(reader, maxRemoteSignerMsgSize) - _, err = protoReader.ReadMsg(&msg) - return msg, err -} - -// WriteMsg writes a message to an io.Writer -func WriteMsg(writer io.Writer, msg cometprotoprivval.Message) (err error) { - protoWriter := protoio.NewDelimitedWriter(writer) - _, err = protoWriter.WriteMsg(&msg) - return err -} - -// UnpackHRS deserializes sign bytes and gets the height, round, and step -func UnpackHRST(signBytes []byte) (HRSTKey, error) { - { - var proposal cometproto.CanonicalProposal - if err := protoio.UnmarshalDelimited(signBytes, &proposal); err == nil { - return HRSTKey{proposal.Height, proposal.Round, stepPropose, proposal.Timestamp.UnixNano()}, nil - } - } - - { - var vote cometproto.CanonicalVote - if err := protoio.UnmarshalDelimited(signBytes, &vote); err == nil { - return HRSTKey{vote.Height, vote.Round, CanonicalVoteToStep(&vote), vote.Timestamp.UnixNano()}, nil - } - } - - return HRSTKey{0, 0, 0, 0}, errors.New("could not UnpackHRS from sign bytes") -} diff --git a/signer/serialization_test.go b/signer/serialization_test.go deleted file mode 100644 index d51a678c..00000000 --- a/signer/serialization_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package signer - -import ( - "testing" - - cometproto "github.com/cometbft/cometbft/proto/tendermint/types" - comet "github.com/cometbft/cometbft/types" - "github.com/stretchr/testify/require" -) - -func TestUnpackHRSPrevote(t *testing.T) { - vote := cometproto.Vote{ - Height: 1, - Round: 2, - Type: cometproto.PrevoteType, - } - - signBytes := comet.VoteSignBytes("chain-id", &vote) - - hrs, err := UnpackHRST(signBytes) - require.NoError(t, err) - require.Equal(t, int64(1), hrs.Height) - require.Equal(t, int64(2), hrs.Round) - require.Equal(t, int8(2), hrs.Step) -} - -func TestUnpackHRSPrecommit(t *testing.T) { - vote := cometproto.Vote{ - Height: 3, - Round: 2, - Type: cometproto.PrecommitType, - } - - signBytes := comet.VoteSignBytes("chain-id", &vote) - - hrs, err := UnpackHRST(signBytes) - require.NoError(t, err) - require.Equal(t, int64(3), hrs.Height) - require.Equal(t, int64(2), hrs.Round) - require.Equal(t, int8(3), hrs.Step) -} - -func TestUnpackHRSProposal(t *testing.T) { - proposal := cometproto.Proposal{ - Height: 1, - Round: 2, - Type: cometproto.ProposalType, - } - - signBytes := comet.ProposalSignBytes("chain-id", &proposal) - - hrs, err := UnpackHRST(signBytes) - require.NoError(t, err) - require.Equal(t, int64(1), hrs.Height) - require.Equal(t, int64(2), hrs.Round) - require.Equal(t, int8(1), hrs.Step) -} diff --git a/signer/sign_state.go b/signer/sign_state.go index 819a36b0..f6c55eb0 100644 --- a/signer/sign_state.go +++ b/signer/sign_state.go @@ -100,12 +100,13 @@ func StepToType(step int8) cometproto.SignedMsgType { // SignState stores signing information for high level watermark management. type SignState struct { - Height int64 `json:"height"` - Round int64 `json:"round"` - Step int8 `json:"step"` - NoncePublic []byte `json:"nonce_public"` - Signature []byte `json:"signature,omitempty"` - SignBytes cometbytes.HexBytes `json:"signbytes,omitempty"` + Height int64 `json:"height"` + Round int64 `json:"round"` + Step int8 `json:"step"` + NoncePublic []byte `json:"nonce_public"` + Signature []byte `json:"signature,omitempty"` + SignBytes cometbytes.HexBytes `json:"signbytes,omitempty"` + VoteExtensionSignature []byte `json:"vote_ext_signature,omitempty"` filePath string @@ -232,6 +233,7 @@ func (signState *SignState) cacheAndMarshal(ssc SignStateConsensus) []byte { signState.Step = ssc.Step signState.Signature = ssc.Signature signState.SignBytes = ssc.SignBytes + signState.VoteExtensionSignature = ssc.VoteExtensionSignature jsonBytes, err := cometjson.MarshalIndent(signState, "", " ") if err != nil { @@ -416,13 +418,14 @@ func (signState *SignState) GetErrorIfLessOrEqual(height int64, round int64, ste // including the most recent sign state. func (signState *SignState) FreshCache() *SignState { newSignState := &SignState{ - Height: signState.Height, - Round: signState.Round, - Step: signState.Step, - NoncePublic: signState.NoncePublic, - Signature: signState.Signature, - SignBytes: signState.SignBytes, - cache: make(map[HRSKey]SignStateConsensus), + Height: signState.Height, + Round: signState.Round, + Step: signState.Step, + NoncePublic: signState.NoncePublic, + Signature: signState.Signature, + SignBytes: signState.SignBytes, + VoteExtensionSignature: signState.VoteExtensionSignature, + cache: make(map[HRSKey]SignStateConsensus), filePath: signState.filePath, } @@ -434,11 +437,12 @@ func (signState *SignState) FreshCache() *SignState { Round: signState.Round, Step: signState.Step, }] = SignStateConsensus{ - Height: signState.Height, - Round: signState.Round, - Step: signState.Step, - Signature: signState.Signature, - SignBytes: signState.SignBytes, + Height: signState.Height, + Round: signState.Round, + Step: signState.Step, + Signature: signState.Signature, + SignBytes: signState.SignBytes, + VoteExtensionSignature: signState.VoteExtensionSignature, } return newSignState diff --git a/signer/single_signer_validator.go b/signer/single_signer_validator.go index 33d78de0..846e283d 100644 --- a/signer/single_signer_validator.go +++ b/signer/single_signer_validator.go @@ -65,7 +65,7 @@ func (pv *SingleSignerValidator) Sign( chainState.pvMutex.Lock() defer chainState.pvMutex.Unlock() - return chainState.filePV.Sign(block) + return chainState.filePV.Sign(chainID, block) } func (pv *SingleSignerValidator) loadChainStateIfNecessary(chainID string) (*SingleSignerChainState, error) { diff --git a/signer/threshold_validator.go b/signer/threshold_validator.go index ba33dbb1..7ca5f87b 100644 --- a/signer/threshold_validator.go +++ b/signer/threshold_validator.go @@ -699,7 +699,14 @@ func (pv *ThresholdValidator) Sign( } var voteExtNonces *CosignerUUIDNonces - if step == stepPrecommit && len(voteExtensionSignBytes) > 0 { + + _, hasVoteExtensions, err := verifySignPayload(chainID, signBytes, voteExtensionSignBytes) + if err != nil { + pv.notifyBlockSignError(chainID, block.HRSKey(), signBytes) + return nil, nil, stamp, fmt.Errorf("failed to verify payload: %w", err) + } + + if hasVoteExtensions { voteExtNonces, err = pv.nonceCache.GetNonces(cosignersForThisBlock) if err != nil { // TODO how to handle fallback for vote extensions? @@ -853,7 +860,7 @@ func (pv *ThresholdValidator) Sign( var voteExtSig []byte - if step == stepPrecommit && len(voteExtensionSignBytes) > 0 { + if hasVoteExtensions { // collect all valid responses into array of partial signatures voteExtShareSigs := make([]PartialSignature, 0, pv.threshold) for idx, shareSig := range voteExtShareSignatures { diff --git a/signer/threshold_validator_test.go b/signer/threshold_validator_test.go index 05d499d5..ca016cdc 100644 --- a/signer/threshold_validator_test.go +++ b/signer/threshold_validator_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "crypto/sha256" "fmt" mrand "math/rand" "path/filepath" @@ -254,9 +255,13 @@ func testThresholdValidator(t *testing.T, threshold, total uint8) { err = eg.Wait() require.NoError(t, err) + blockIDHash := sha256.New() + blockIDHash.Write([]byte("something")) + precommit := cometproto.Vote{ Height: int64(i), Round: 0, + BlockID: cometproto.BlockID{Hash: blockIDHash.Sum(nil)}, Type: cometproto.PrecommitType, Timestamp: time.Now(), Extension: []byte("test"),