diff --git a/bridge/setu/listener/rootchain_selfheal_graph.go b/bridge/setu/listener/rootchain_selfheal_graph.go index 736f4bc99..43542c7f0 100644 --- a/bridge/setu/listener/rootchain_selfheal_graph.go +++ b/bridge/setu/listener/rootchain_selfheal_graph.go @@ -14,6 +14,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" jsoniter "github.com/json-iterator/go" + + "github.com/maticnetwork/heimdall/helper" ) // StakeUpdate represents the StakeUpdate event @@ -42,6 +44,7 @@ type stateSyncResponse struct { } `json:"data"` } +// querySubGraph queries the subgraph and limits the read size func (rl *RootChainListener) querySubGraph(query []byte, ctx context.Context) (data []byte, err error) { request, err := http.NewRequestWithContext(ctx, http.MethodPost, rl.subGraphClient.graphUrl, bytes.NewBuffer(query)) if err != nil { @@ -56,7 +59,10 @@ func (rl *RootChainListener) querySubGraph(query []byte, ctx context.Context) (d } defer response.Body.Close() - return io.ReadAll(response.Body) + // Limit the number of bytes read from the response body + limitedBody := http.MaxBytesReader(nil, response.Body, helper.APIBodyLimit) + + return io.ReadAll(limitedBody) } // getLatestStateID returns state ID from the latest StateSynced event diff --git a/bridge/setu/processor/base.go b/bridge/setu/processor/base.go index f33b4e357..5505463dc 100644 --- a/bridge/setu/processor/base.go +++ b/bridge/setu/processor/base.go @@ -162,12 +162,14 @@ func (bp *BaseProcessor) checkTxAgainstMempool(msg types.Msg, event interface{}) bp.Logger.Error("Error fetching mempool tx", "url", endpoint, "error", err) return false, err } - - body, err := io.ReadAll(resp.Body) defer resp.Body.Close() + // Limit the number of bytes read from the response body + limitedBody := http.MaxBytesReader(nil, resp.Body, helper.APIBodyLimit) + + body, err := io.ReadAll(limitedBody) if err != nil { - bp.Logger.Error("Error fetching mempool tx", "error", err) + bp.Logger.Error("Error reading response body for mempool tx", "error", err) return false, err } diff --git a/bridge/setu/util/common.go b/bridge/setu/util/common.go index d8173b004..12a4ae108 100644 --- a/bridge/setu/util/common.go +++ b/bridge/setu/util/common.go @@ -652,10 +652,12 @@ func GetUnconfirmedTxnCount(event interface{}) int { logger.Error("Error fetching mempool txs count", "url", endpoint, "error", err) return 0 } - - body, err := io.ReadAll(resp.Body) defer resp.Body.Close() + // Limit the number of bytes read from the response body + limitedBody := http.MaxBytesReader(nil, resp.Body, helper.APIBodyLimit) + + body, err := io.ReadAll(limitedBody) if err != nil { logger.Error("Error fetching mempool txs count", "error", err) return 0 diff --git a/helper/call.go b/helper/call.go index dca52bece..75424a496 100644 --- a/helper/call.go +++ b/helper/call.go @@ -310,10 +310,10 @@ func (c *ContractCaller) GetRootHash(start uint64, end uint64, checkpointLength return common.FromHex(rootHash), nil } -// GetRootHash get root hash from bor chain +// GetVoteOnHash gets vote on hash from bor chain func (c *ContractCaller) GetVoteOnHash(start uint64, end uint64, milestoneLength uint64, hash string, milestoneID string) (bool, error) { if start > end { - return false, errors.New("Start block number is greater than the end block number") + return false, errors.New("start block number is greater than the end block number") } ctx, cancel := context.WithTimeout(context.Background(), c.MaticChainTimeout) @@ -368,7 +368,7 @@ func (c *ContractCaller) GetValidatorInfo(valID types.ValidatorID, stakingInfoIn // amount, startEpoch, endEpoch, signer, status, err := c.StakingInfoInstance.GetStakerDetails(nil, big.NewInt(int64(valID))) stakerDetails, err := stakingInfoInstance.GetStakerDetails(nil, big.NewInt(int64(valID))) if err != nil { - Logger.Error("Error fetching validator information from stake manager", "validatorId", valID, "status", stakerDetails.Status, "error", err) + Logger.Error("Error fetching validator information from stake manager", "validatorId", valID, "error", err) return } @@ -814,6 +814,9 @@ func (c *ContractCaller) GetSpanDetails(id *big.Int, validatorSetInstance *valid error, ) { d, err := validatorSetInstance.GetSpan(nil, id) + if err != nil { + return nil, nil, nil, err + } return d.Number, d.StartBlock, d.EndBlock, err } diff --git a/helper/util.go b/helper/util.go index e3f41956e..5e17674d7 100644 --- a/helper/util.go +++ b/helper/util.go @@ -42,6 +42,8 @@ import ( "github.com/maticnetwork/heimdall/types/rest" ) +const APIBodyLimit = 128 * 1024 * 1024 // 128 MB + //go:generate mockgen -destination=./mocks/http_client_mock.go -package=mocks . HTTPClient type HTTPClient interface { Get(string) (resp *http.Response, err error) @@ -567,10 +569,19 @@ func SignStdTx(cliCtx context.CLIContext, stdTx authTypes.StdTx, appendSig bool, // ReadStdTxFromFile and decode a StdTx from the given filename. Can pass "-" to read from stdin. func ReadStdTxFromFile(cdc *amino.Codec, filename string) (stdTx authTypes.StdTx, err error) { var bytes []byte + if filename == "-" { - bytes, err = io.ReadAll(os.Stdin) + limitedReader := &io.LimitedReader{R: os.Stdin, N: APIBodyLimit} + bytes, err = io.ReadAll(limitedReader) } else { - bytes, err = os.ReadFile(filename) + file, er := os.Open(filename) + if er != nil { + err = er + return + } + defer file.Close() + limitedReader := &io.LimitedReader{R: file, N: APIBodyLimit} + bytes, err = io.ReadAll(limitedReader) } if err != nil { @@ -802,7 +813,7 @@ func GetHeimdallServerEndpoint(endpoint string) string { return u.String() } -// FetchFromAPI fetches data from any URL +// FetchFromAPI fetches data from any URL with limited read size func FetchFromAPI(cliCtx cliContext.CLIContext, URL string) (result rest.ResponseWithHeight, err error) { resp, err := Client.Get(URL) if err != nil { @@ -811,9 +822,12 @@ func FetchFromAPI(cliCtx cliContext.CLIContext, URL string) (result rest.Respons defer resp.Body.Close() - // response + // Limit the number of bytes read from the response body + limitedBody := http.MaxBytesReader(nil, resp.Body, APIBodyLimit) + + // Handle the response if resp.StatusCode == 200 { - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(limitedBody) if err != nil { return result, err } diff --git a/types/rest/rest.go b/types/rest/rest.go index 622e743cb..92f1cfe9e 100644 --- a/types/rest/rest.go +++ b/types/rest/rest.go @@ -23,7 +23,8 @@ import ( const ( DefaultPage = 1 - DefaultLimit = 30 // should be consistent with tendermint/tendermint/rpc/core/pipe.go:19 + DefaultLimit = 30 // should be consistent with tendermint/tendermint/rpc/core/pipe.go:19 + APIBodyLimit = 128 * 1024 * 1024 // 128 MB ) var ( @@ -128,7 +129,10 @@ func (br BaseReq) ValidateBasic(w http.ResponseWriter) bool { // ReadRESTReq reads and unmarshals a Request's body to the BaseReq struct. // Writes an error response to ResponseWriter and returns true if errors occurred. func ReadRESTReq(w http.ResponseWriter, r *http.Request, cdc *codec.Codec, req interface{}) bool { - body, err := io.ReadAll(r.Body) + // Limit the number of bytes read from the request body + limitedBody := http.MaxBytesReader(w, r.Body, APIBodyLimit) + + body, err := io.ReadAll(limitedBody) if err != nil { WriteErrorResponse(w, http.StatusBadRequest, err.Error()) return false