diff --git a/execution.go b/execution.go index 88574eb..2eeba4c 100644 --- a/execution.go +++ b/execution.go @@ -1,6 +1,7 @@ package execution import ( + "context" "time" "github.com/rollkit/go-execution/types" @@ -9,33 +10,14 @@ import ( // Executor defines a common interface for interacting with the execution client. type Executor interface { // InitChain initializes the blockchain with genesis information. - InitChain( - genesisTime time.Time, - initialHeight uint64, - chainID string, - ) ( - stateRoot types.Hash, - maxBytes uint64, - err error, - ) + InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (stateRoot types.Hash, maxBytes uint64, err error) // GetTxs retrieves all available transactions from the execution client's mempool. - GetTxs() ([]types.Tx, error) + GetTxs(ctx context.Context) ([]types.Tx, error) // ExecuteTxs executes a set of transactions to produce a new block header. - ExecuteTxs( - txs []types.Tx, - blockHeight uint64, - timestamp time.Time, - prevStateRoot types.Hash, - ) ( - updatedStateRoot types.Hash, - maxBytes uint64, - err error, - ) + ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (updatedStateRoot types.Hash, maxBytes uint64, err error) // SetFinal marks a block at the given height as final. - SetFinal( - blockHeight uint64, - ) error + SetFinal(ctx context.Context, blockHeight uint64) error } diff --git a/mocks/mock_Executor.go b/mocks/mock_Executor.go index 7245228..2110bf6 100644 --- a/mocks/mock_Executor.go +++ b/mocks/mock_Executor.go @@ -3,7 +3,10 @@ package mocks import ( + context "context" + header "github.com/celestiaorg/go-header" + mock "github.com/stretchr/testify/mock" time "time" @@ -24,9 +27,9 @@ func (_m *MockExecutor) EXPECT() *MockExecutor_Expecter { return &MockExecutor_Expecter{mock: &_m.Mock} } -// ExecuteTxs provides a mock function with given fields: txs, blockHeight, timestamp, prevStateRoot -func (_m *MockExecutor) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash) (header.Hash, uint64, error) { - ret := _m.Called(txs, blockHeight, timestamp, prevStateRoot) +// ExecuteTxs provides a mock function with given fields: ctx, txs, blockHeight, timestamp, prevStateRoot +func (_m *MockExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash) (header.Hash, uint64, error) { + ret := _m.Called(ctx, txs, blockHeight, timestamp, prevStateRoot) if len(ret) == 0 { panic("no return value specified for ExecuteTxs") @@ -35,25 +38,25 @@ func (_m *MockExecutor) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp var r0 header.Hash var r1 uint64 var r2 error - if rf, ok := ret.Get(0).(func([]types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)); ok { - return rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(0).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)); ok { + return rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } - if rf, ok := ret.Get(0).(func([]types.Tx, uint64, time.Time, header.Hash) header.Hash); ok { - r0 = rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(0).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) header.Hash); ok { + r0 = rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(header.Hash) } } - if rf, ok := ret.Get(1).(func([]types.Tx, uint64, time.Time, header.Hash) uint64); ok { - r1 = rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(1).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) uint64); ok { + r1 = rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } else { r1 = ret.Get(1).(uint64) } - if rf, ok := ret.Get(2).(func([]types.Tx, uint64, time.Time, header.Hash) error); ok { - r2 = rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(2).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) error); ok { + r2 = rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } else { r2 = ret.Error(2) } @@ -67,17 +70,18 @@ type MockExecutor_ExecuteTxs_Call struct { } // ExecuteTxs is a helper method to define mock.On call +// - ctx context.Context // - txs []types.Tx // - blockHeight uint64 // - timestamp time.Time // - prevStateRoot header.Hash -func (_e *MockExecutor_Expecter) ExecuteTxs(txs interface{}, blockHeight interface{}, timestamp interface{}, prevStateRoot interface{}) *MockExecutor_ExecuteTxs_Call { - return &MockExecutor_ExecuteTxs_Call{Call: _e.mock.On("ExecuteTxs", txs, blockHeight, timestamp, prevStateRoot)} +func (_e *MockExecutor_Expecter) ExecuteTxs(ctx interface{}, txs interface{}, blockHeight interface{}, timestamp interface{}, prevStateRoot interface{}) *MockExecutor_ExecuteTxs_Call { + return &MockExecutor_ExecuteTxs_Call{Call: _e.mock.On("ExecuteTxs", ctx, txs, blockHeight, timestamp, prevStateRoot)} } -func (_c *MockExecutor_ExecuteTxs_Call) Run(run func(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash)) *MockExecutor_ExecuteTxs_Call { +func (_c *MockExecutor_ExecuteTxs_Call) Run(run func(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash)) *MockExecutor_ExecuteTxs_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]types.Tx), args[1].(uint64), args[2].(time.Time), args[3].(header.Hash)) + run(args[0].(context.Context), args[1].([]types.Tx), args[2].(uint64), args[3].(time.Time), args[4].(header.Hash)) }) return _c } @@ -87,14 +91,14 @@ func (_c *MockExecutor_ExecuteTxs_Call) Return(updatedStateRoot header.Hash, max return _c } -func (_c *MockExecutor_ExecuteTxs_Call) RunAndReturn(run func([]types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)) *MockExecutor_ExecuteTxs_Call { +func (_c *MockExecutor_ExecuteTxs_Call) RunAndReturn(run func(context.Context, []types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)) *MockExecutor_ExecuteTxs_Call { _c.Call.Return(run) return _c } -// GetTxs provides a mock function with given fields: -func (_m *MockExecutor) GetTxs() ([]types.Tx, error) { - ret := _m.Called() +// GetTxs provides a mock function with given fields: ctx +func (_m *MockExecutor) GetTxs(ctx context.Context) ([]types.Tx, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for GetTxs") @@ -102,19 +106,19 @@ func (_m *MockExecutor) GetTxs() ([]types.Tx, error) { var r0 []types.Tx var r1 error - if rf, ok := ret.Get(0).(func() ([]types.Tx, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]types.Tx, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []types.Tx); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []types.Tx); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]types.Tx) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -128,13 +132,14 @@ type MockExecutor_GetTxs_Call struct { } // GetTxs is a helper method to define mock.On call -func (_e *MockExecutor_Expecter) GetTxs() *MockExecutor_GetTxs_Call { - return &MockExecutor_GetTxs_Call{Call: _e.mock.On("GetTxs")} +// - ctx context.Context +func (_e *MockExecutor_Expecter) GetTxs(ctx interface{}) *MockExecutor_GetTxs_Call { + return &MockExecutor_GetTxs_Call{Call: _e.mock.On("GetTxs", ctx)} } -func (_c *MockExecutor_GetTxs_Call) Run(run func()) *MockExecutor_GetTxs_Call { +func (_c *MockExecutor_GetTxs_Call) Run(run func(ctx context.Context)) *MockExecutor_GetTxs_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -144,14 +149,14 @@ func (_c *MockExecutor_GetTxs_Call) Return(_a0 []types.Tx, _a1 error) *MockExecu return _c } -func (_c *MockExecutor_GetTxs_Call) RunAndReturn(run func() ([]types.Tx, error)) *MockExecutor_GetTxs_Call { +func (_c *MockExecutor_GetTxs_Call) RunAndReturn(run func(context.Context) ([]types.Tx, error)) *MockExecutor_GetTxs_Call { _c.Call.Return(run) return _c } -// InitChain provides a mock function with given fields: genesisTime, initialHeight, chainID -func (_m *MockExecutor) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (header.Hash, uint64, error) { - ret := _m.Called(genesisTime, initialHeight, chainID) +// InitChain provides a mock function with given fields: ctx, genesisTime, initialHeight, chainID +func (_m *MockExecutor) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (header.Hash, uint64, error) { + ret := _m.Called(ctx, genesisTime, initialHeight, chainID) if len(ret) == 0 { panic("no return value specified for InitChain") @@ -160,25 +165,25 @@ func (_m *MockExecutor) InitChain(genesisTime time.Time, initialHeight uint64, c var r0 header.Hash var r1 uint64 var r2 error - if rf, ok := ret.Get(0).(func(time.Time, uint64, string) (header.Hash, uint64, error)); ok { - return rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint64, string) (header.Hash, uint64, error)); ok { + return rf(ctx, genesisTime, initialHeight, chainID) } - if rf, ok := ret.Get(0).(func(time.Time, uint64, string) header.Hash); ok { - r0 = rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint64, string) header.Hash); ok { + r0 = rf(ctx, genesisTime, initialHeight, chainID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(header.Hash) } } - if rf, ok := ret.Get(1).(func(time.Time, uint64, string) uint64); ok { - r1 = rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(1).(func(context.Context, time.Time, uint64, string) uint64); ok { + r1 = rf(ctx, genesisTime, initialHeight, chainID) } else { r1 = ret.Get(1).(uint64) } - if rf, ok := ret.Get(2).(func(time.Time, uint64, string) error); ok { - r2 = rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(2).(func(context.Context, time.Time, uint64, string) error); ok { + r2 = rf(ctx, genesisTime, initialHeight, chainID) } else { r2 = ret.Error(2) } @@ -192,16 +197,17 @@ type MockExecutor_InitChain_Call struct { } // InitChain is a helper method to define mock.On call +// - ctx context.Context // - genesisTime time.Time // - initialHeight uint64 // - chainID string -func (_e *MockExecutor_Expecter) InitChain(genesisTime interface{}, initialHeight interface{}, chainID interface{}) *MockExecutor_InitChain_Call { - return &MockExecutor_InitChain_Call{Call: _e.mock.On("InitChain", genesisTime, initialHeight, chainID)} +func (_e *MockExecutor_Expecter) InitChain(ctx interface{}, genesisTime interface{}, initialHeight interface{}, chainID interface{}) *MockExecutor_InitChain_Call { + return &MockExecutor_InitChain_Call{Call: _e.mock.On("InitChain", ctx, genesisTime, initialHeight, chainID)} } -func (_c *MockExecutor_InitChain_Call) Run(run func(genesisTime time.Time, initialHeight uint64, chainID string)) *MockExecutor_InitChain_Call { +func (_c *MockExecutor_InitChain_Call) Run(run func(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string)) *MockExecutor_InitChain_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(time.Time), args[1].(uint64), args[2].(string)) + run(args[0].(context.Context), args[1].(time.Time), args[2].(uint64), args[3].(string)) }) return _c } @@ -211,22 +217,22 @@ func (_c *MockExecutor_InitChain_Call) Return(stateRoot header.Hash, maxBytes ui return _c } -func (_c *MockExecutor_InitChain_Call) RunAndReturn(run func(time.Time, uint64, string) (header.Hash, uint64, error)) *MockExecutor_InitChain_Call { +func (_c *MockExecutor_InitChain_Call) RunAndReturn(run func(context.Context, time.Time, uint64, string) (header.Hash, uint64, error)) *MockExecutor_InitChain_Call { _c.Call.Return(run) return _c } -// SetFinal provides a mock function with given fields: blockHeight -func (_m *MockExecutor) SetFinal(blockHeight uint64) error { - ret := _m.Called(blockHeight) +// SetFinal provides a mock function with given fields: ctx, blockHeight +func (_m *MockExecutor) SetFinal(ctx context.Context, blockHeight uint64) error { + ret := _m.Called(ctx, blockHeight) if len(ret) == 0 { panic("no return value specified for SetFinal") } var r0 error - if rf, ok := ret.Get(0).(func(uint64) error); ok { - r0 = rf(blockHeight) + if rf, ok := ret.Get(0).(func(context.Context, uint64) error); ok { + r0 = rf(ctx, blockHeight) } else { r0 = ret.Error(0) } @@ -240,14 +246,15 @@ type MockExecutor_SetFinal_Call struct { } // SetFinal is a helper method to define mock.On call +// - ctx context.Context // - blockHeight uint64 -func (_e *MockExecutor_Expecter) SetFinal(blockHeight interface{}) *MockExecutor_SetFinal_Call { - return &MockExecutor_SetFinal_Call{Call: _e.mock.On("SetFinal", blockHeight)} +func (_e *MockExecutor_Expecter) SetFinal(ctx interface{}, blockHeight interface{}) *MockExecutor_SetFinal_Call { + return &MockExecutor_SetFinal_Call{Call: _e.mock.On("SetFinal", ctx, blockHeight)} } -func (_c *MockExecutor_SetFinal_Call) Run(run func(blockHeight uint64)) *MockExecutor_SetFinal_Call { +func (_c *MockExecutor_SetFinal_Call) Run(run func(ctx context.Context, blockHeight uint64)) *MockExecutor_SetFinal_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(uint64)) + run(args[0].(context.Context), args[1].(uint64)) }) return _c } @@ -257,7 +264,7 @@ func (_c *MockExecutor_SetFinal_Call) Return(_a0 error) *MockExecutor_SetFinal_C return _c } -func (_c *MockExecutor_SetFinal_Call) RunAndReturn(run func(uint64) error) *MockExecutor_SetFinal_Call { +func (_c *MockExecutor_SetFinal_Call) RunAndReturn(run func(context.Context, uint64) error) *MockExecutor_SetFinal_Call { _c.Call.Return(run) return _c } diff --git a/proxy/grpc/client.go b/proxy/grpc/client.go index 3d1e1d3..7b14f73 100644 --- a/proxy/grpc/client.go +++ b/proxy/grpc/client.go @@ -51,10 +51,7 @@ func (c *Client) Stop() error { } // InitChain initializes the blockchain with genesis information. -func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { resp, err := c.client.InitChain(ctx, &pb.InitChainRequest{ GenesisTime: genesisTime.Unix(), InitialHeight: initialHeight, @@ -71,10 +68,7 @@ func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID } // GetTxs retrieves all available transactions from the execution client's mempool. -func (c *Client) GetTxs() ([]types.Tx, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) GetTxs(ctx context.Context) ([]types.Tx, error) { resp, err := c.client.GetTxs(ctx, &pb.GetTxsRequest{}) if err != nil { return nil, err @@ -89,10 +83,7 @@ func (c *Client) GetTxs() ([]types.Tx, error) { } // ExecuteTxs executes a set of transactions to produce a new block header. -func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { req := &pb.ExecuteTxsRequest{ Txs: make([][]byte, len(txs)), BlockHeight: blockHeight, @@ -115,10 +106,7 @@ func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.T } // SetFinal marks a block at the given height as final. -func (c *Client) SetFinal(blockHeight uint64) error { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) SetFinal(ctx context.Context, blockHeight uint64) error { _, err := c.client.SetFinal(ctx, &pb.SetFinalRequest{ BlockHeight: blockHeight, }) diff --git a/proxy/grpc/client_server_test.go b/proxy/grpc/client_server_test.go index 14b22b5..af9bc34 100644 --- a/proxy/grpc/client_server_test.go +++ b/proxy/grpc/client_server_test.go @@ -1,11 +1,14 @@ package grpc_test import ( + "context" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -45,7 +48,7 @@ func TestClientServer(t *testing.T) { require.NoError(t, err) defer func() { _ = client.Stop() }() - mockExec.On("GetTxs").Return([]types.Tx{}, nil).Maybe() + mockExec.On("GetTxs", mock.Anything).Return([]types.Tx{}, nil).Maybe() t.Run("InitChain", func(t *testing.T) { genesisTime := time.Now().UTC().Truncate(time.Second) @@ -64,10 +67,10 @@ func TestClientServer(t *testing.T) { unixTime := genesisTime.Unix() expectedTime := time.Unix(unixTime, 0).UTC() - mockExec.On("InitChain", expectedTime, initialHeight, chainID). + mockExec.On("InitChain", mock.Anything, expectedTime, initialHeight, chainID). Return(stateRootHash, expectedMaxBytes, nil).Once() - stateRoot, maxBytes, err := client.InitChain(genesisTime, initialHeight, chainID) + stateRoot, maxBytes, err := client.InitChain(context.TODO(), genesisTime, initialHeight, chainID) require.NoError(t, err) assert.Equal(t, stateRootHash, stateRoot) diff --git a/proxy/grpc/proxy_test.go b/proxy/grpc/proxy_test.go index a88b871..ab45ae0 100644 --- a/proxy/grpc/proxy_test.go +++ b/proxy/grpc/proxy_test.go @@ -65,7 +65,7 @@ func (s *ProxyTestSuite) SetupTest() { require.NoError(s.T(), err) for i := 0; i < 10; i++ { - if _, err := client.GetTxs(); err == nil { + if _, err := client.GetTxs(context.TODO()); err == nil { break } time.Sleep(100 * time.Millisecond) diff --git a/proxy/grpc/server.go b/proxy/grpc/server.go index e285179..629360e 100644 --- a/proxy/grpc/server.go +++ b/proxy/grpc/server.go @@ -48,11 +48,7 @@ func (s *Server) InitChain(ctx context.Context, req *pb.InitChainRequest) (*pb.I // Convert Unix timestamp to UTC time genesisTime := time.Unix(req.GenesisTime, 0).UTC() - stateRoot, maxBytes, err := s.exec.InitChain( - genesisTime, - req.InitialHeight, - req.ChainId, - ) + stateRoot, maxBytes, err := s.exec.InitChain(ctx, genesisTime, req.InitialHeight, req.ChainId) if err != nil { return nil, err } @@ -65,7 +61,7 @@ func (s *Server) InitChain(ctx context.Context, req *pb.InitChainRequest) (*pb.I // GetTxs handles GetTxs method call from execution API. func (s *Server) GetTxs(ctx context.Context, req *pb.GetTxsRequest) (*pb.GetTxsResponse, error) { - txs, err := s.exec.GetTxs() + txs, err := s.exec.GetTxs(ctx) if err != nil { return nil, err } @@ -91,6 +87,7 @@ func (s *Server) ExecuteTxs(ctx context.Context, req *pb.ExecuteTxsRequest) (*pb copy(prevStateRoot[:], req.PrevStateRoot) updatedStateRoot, maxBytes, err := s.exec.ExecuteTxs( + ctx, txs, req.BlockHeight, time.Unix(req.Timestamp, 0), @@ -108,7 +105,7 @@ func (s *Server) ExecuteTxs(ctx context.Context, req *pb.ExecuteTxsRequest) (*pb // SetFinal handles SetFinal method call from execution API. func (s *Server) SetFinal(ctx context.Context, req *pb.SetFinalRequest) (*pb.SetFinalResponse, error) { - err := s.exec.SetFinal(req.BlockHeight) + err := s.exec.SetFinal(ctx, req.BlockHeight) if err != nil { return nil, err } diff --git a/proxy/jsonrpc/client.go b/proxy/jsonrpc/client.go index cdebf9f..294a88e 100644 --- a/proxy/jsonrpc/client.go +++ b/proxy/jsonrpc/client.go @@ -47,7 +47,7 @@ func (c *Client) Stop() error { } // InitChain initializes the blockchain with genesis information. -func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { +func (c *Client) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { params := map[string]interface{}{ "genesis_time": genesisTime.Unix(), "initial_height": initialHeight, @@ -59,7 +59,7 @@ func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID MaxBytes uint64 `json:"max_bytes"` } - if err := c.call("init_chain", params, &result); err != nil { + if err := c.call(ctx, "init_chain", params, &result); err != nil { return types.Hash{}, 0, err } @@ -75,12 +75,12 @@ func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID } // GetTxs retrieves all available transactions from the execution client's mempool. -func (c *Client) GetTxs() ([]types.Tx, error) { +func (c *Client) GetTxs(ctx context.Context) ([]types.Tx, error) { var result struct { Txs []string `json:"txs"` } - if err := c.call("get_txs", nil, &result); err != nil { + if err := c.call(ctx, "get_txs", nil, &result); err != nil { return nil, err } @@ -97,7 +97,7 @@ func (c *Client) GetTxs() ([]types.Tx, error) { } // ExecuteTxs executes a set of transactions to produce a new block header. -func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { +func (c *Client) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { // Encode txs to base64 encodedTxs := make([]string, len(txs)) for i, tx := range txs { @@ -116,7 +116,7 @@ func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.T MaxBytes uint64 `json:"max_bytes"` } - if err := c.call("execute_txs", params, &result); err != nil { + if err := c.call(ctx, "execute_txs", params, &result); err != nil { return types.Hash{}, 0, err } @@ -132,15 +132,15 @@ func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.T } // SetFinal marks a block at the given height as final. -func (c *Client) SetFinal(blockHeight uint64) error { +func (c *Client) SetFinal(ctx context.Context, blockHeight uint64) error { params := map[string]interface{}{ "block_height": blockHeight, } - return c.call("set_final", params, nil) + return c.call(ctx, "set_final", params, nil) } -func (c *Client) call(method string, params interface{}, result interface{}) error { +func (c *Client) call(ctx context.Context, method string, params interface{}, result interface{}) error { request := struct { JSONRPC string `json:"jsonrpc"` Method string `json:"method"` @@ -158,7 +158,7 @@ func (c *Client) call(method string, params interface{}, result interface{}) err return fmt.Errorf("failed to marshal request: %w", err) } - req, err := http.NewRequestWithContext(context.Background(), "POST", c.endpoint, bytes.NewReader(reqBody)) + req, err := http.NewRequestWithContext(ctx, "POST", c.endpoint, bytes.NewReader(reqBody)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } diff --git a/proxy/jsonrpc/client_server_test.go b/proxy/jsonrpc/client_server_test.go index 2295663..ebb6a5d 100644 --- a/proxy/jsonrpc/client_server_test.go +++ b/proxy/jsonrpc/client_server_test.go @@ -1,11 +1,13 @@ package jsonrpc_test import ( + "context" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/rollkit/go-execution/mocks" @@ -47,10 +49,10 @@ func TestClientServer(t *testing.T) { unixTime := genesisTime.Unix() expectedTime := time.Unix(unixTime, 0).UTC() - mockExec.On("InitChain", expectedTime, initialHeight, chainID). + mockExec.On("InitChain", mock.Anything, expectedTime, initialHeight, chainID). Return(stateRootHash, expectedMaxBytes, nil).Once() - stateRoot, maxBytes, err := client.InitChain(genesisTime, initialHeight, chainID) + stateRoot, maxBytes, err := client.InitChain(context.TODO(), genesisTime, initialHeight, chainID) require.NoError(t, err) assert.Equal(t, stateRootHash, stateRoot) @@ -60,9 +62,9 @@ func TestClientServer(t *testing.T) { t.Run("GetTxs", func(t *testing.T) { expectedTxs := []types.Tx{[]byte("tx1"), []byte("tx2")} - mockExec.On("GetTxs").Return(expectedTxs, nil).Once() + mockExec.On("GetTxs", mock.Anything).Return(expectedTxs, nil).Once() - txs, err := client.GetTxs() + txs, err := client.GetTxs(context.TODO()) require.NoError(t, err) assert.Equal(t, expectedTxs, txs) mockExec.AssertExpectations(t) @@ -85,10 +87,10 @@ func TestClientServer(t *testing.T) { unixTime := timestamp.Unix() expectedTime := time.Unix(unixTime, 0).UTC() - mockExec.On("ExecuteTxs", txs, blockHeight, expectedTime, prevStateRoot). + mockExec.On("ExecuteTxs", mock.Anything, txs, blockHeight, expectedTime, prevStateRoot). Return(expectedStateRoot, expectedMaxBytes, nil).Once() - updatedStateRoot, maxBytes, err := client.ExecuteTxs(txs, blockHeight, timestamp, prevStateRoot) + updatedStateRoot, maxBytes, err := client.ExecuteTxs(context.TODO(), txs, blockHeight, timestamp, prevStateRoot) require.NoError(t, err) assert.Equal(t, expectedStateRoot, updatedStateRoot) @@ -98,9 +100,9 @@ func TestClientServer(t *testing.T) { t.Run("SetFinal", func(t *testing.T) { blockHeight := uint64(1) - mockExec.On("SetFinal", blockHeight).Return(nil).Once() + mockExec.On("SetFinal", mock.Anything, blockHeight).Return(nil).Once() - err := client.SetFinal(blockHeight) + err := client.SetFinal(context.TODO(), blockHeight) require.NoError(t, err) mockExec.AssertExpectations(t) }) diff --git a/proxy/jsonrpc/server.go b/proxy/jsonrpc/server.go index a5d94ad..b91b24b 100644 --- a/proxy/jsonrpc/server.go +++ b/proxy/jsonrpc/server.go @@ -1,6 +1,7 @@ package jsonrpc import ( + "context" "encoding/base64" "encoding/json" "net/http" @@ -90,11 +91,7 @@ func (s *Server) handleInitChain(params json.RawMessage) (interface{}, *jsonRPCE return nil, ErrInvalidParams } - stateRoot, maxBytes, err := s.exec.InitChain( - time.Unix(p.GenesisTime, 0).UTC(), - p.InitialHeight, - p.ChainID, - ) + stateRoot, maxBytes, err := s.exec.InitChain(context.TODO(), time.Unix(p.GenesisTime, 0).UTC(), p.InitialHeight, p.ChainID) if err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } @@ -106,7 +103,7 @@ func (s *Server) handleInitChain(params json.RawMessage) (interface{}, *jsonRPCE } func (s *Server) handleGetTxs() (interface{}, *jsonRPCError) { - txs, err := s.exec.GetTxs() + txs, err := s.exec.GetTxs(context.TODO()) if err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } @@ -152,12 +149,7 @@ func (s *Server) handleExecuteTxs(params json.RawMessage) (interface{}, *jsonRPC var prevStateRoot types.Hash copy(prevStateRoot[:], prevStateRootBytes) - updatedStateRoot, maxBytes, err := s.exec.ExecuteTxs( - txs, - p.BlockHeight, - time.Unix(p.Timestamp, 0).UTC(), - prevStateRoot, - ) + updatedStateRoot, maxBytes, err := s.exec.ExecuteTxs(context.TODO(), txs, p.BlockHeight, time.Unix(p.Timestamp, 0).UTC(), prevStateRoot) if err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } @@ -177,7 +169,7 @@ func (s *Server) handleSetFinal(params json.RawMessage) (interface{}, *jsonRPCEr return nil, ErrInvalidParams } - if err := s.exec.SetFinal(p.BlockHeight); err != nil { + if err := s.exec.SetFinal(context.TODO(), p.BlockHeight); err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } diff --git a/test/dummy.go b/test/dummy.go index 2214aea..bac37a4 100644 --- a/test/dummy.go +++ b/test/dummy.go @@ -1,6 +1,7 @@ package test import ( + "context" "time" "github.com/rollkit/go-execution/types" @@ -24,22 +25,22 @@ func NewDummyExecutor() *DummyExecutor { // InitChain initializes the chain state with the given genesis time, initial height, and chain ID. // It returns the state root hash, the maximum byte size, and an error if the initialization fails. -func (e *DummyExecutor) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { +func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { return e.stateRoot, e.maxBytes, nil } // GetTxs returns the list of transactions (types.Tx) within the DummyExecutor instance and an error if any. -func (e *DummyExecutor) GetTxs() ([]types.Tx, error) { +func (e *DummyExecutor) GetTxs(context.Context) ([]types.Tx, error) { return e.txs, nil } // ExecuteTxs simulate execution of transactions. -func (e *DummyExecutor) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { +func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { e.txs = append(e.txs, txs...) return e.stateRoot, e.maxBytes, nil } // SetFinal marks block at given height as finalized. Currently not implemented. -func (e *DummyExecutor) SetFinal(blockHeight uint64) error { +func (e *DummyExecutor) SetFinal(ctx context.Context, blockHeight uint64) error { return nil } diff --git a/test/suite.go b/test/suite.go index 244dc90..c0536fb 100644 --- a/test/suite.go +++ b/test/suite.go @@ -1,6 +1,7 @@ package test import ( + "context" "time" "github.com/stretchr/testify/suite" @@ -21,7 +22,7 @@ func (s *ExecutorSuite) TestInitChain() { initialHeight := uint64(1) chainID := "test-chain" - stateRoot, maxBytes, err := s.Exec.InitChain(genesisTime, initialHeight, chainID) + stateRoot, maxBytes, err := s.Exec.InitChain(context.TODO(), genesisTime, initialHeight, chainID) s.Require().NoError(err) s.NotEqual(types.Hash{}, stateRoot) s.Greater(maxBytes, uint64(0)) @@ -29,7 +30,7 @@ func (s *ExecutorSuite) TestInitChain() { // TestGetTxs tests GetTxs method. func (s *ExecutorSuite) TestGetTxs() { - txs, err := s.Exec.GetTxs() + txs, err := s.Exec.GetTxs(context.TODO()) s.Require().NoError(err) s.NotNil(txs) } @@ -41,7 +42,7 @@ func (s *ExecutorSuite) TestExecuteTxs() { timestamp := time.Now().UTC() prevStateRoot := types.Hash{1, 2, 3} - stateRoot, maxBytes, err := s.Exec.ExecuteTxs(txs, blockHeight, timestamp, prevStateRoot) + stateRoot, maxBytes, err := s.Exec.ExecuteTxs(context.TODO(), txs, blockHeight, timestamp, prevStateRoot) s.Require().NoError(err) s.NotEqual(types.Hash{}, stateRoot) s.Greater(maxBytes, uint64(0)) @@ -49,6 +50,6 @@ func (s *ExecutorSuite) TestExecuteTxs() { // TestSetFinal tests SetFinal method. func (s *ExecutorSuite) TestSetFinal() { - err := s.Exec.SetFinal(1) + err := s.Exec.SetFinal(context.TODO(), 1) s.Require().NoError(err) }