diff --git a/agent/ipc/messagebus/respondent.go b/agent/ipc/messagebus/respondent.go index 9e03d1ff1..8a7bdbdf7 100644 --- a/agent/ipc/messagebus/respondent.go +++ b/agent/ipc/messagebus/respondent.go @@ -28,6 +28,11 @@ import ( _ "go.nanomsg.org/mangos/v3/transport/ipc" ) +const ( + recvErrSleepTime = 30 * time.Second + maxRecvErrCount = 5 +) + // IMessageBus is the interface for process the core agent broadcast request type IMessageBus interface { ProcessHealthRequest() @@ -93,12 +98,25 @@ func (bus *MessageBus) ProcessHealthRequest() { } log.Infof("Start to listen to Core Agent health channel") + errRecvCount := 0 + for { var request *message.Message if msg, err = bus.healthChannel.Recv(); err != nil { - log.Errorf("Failed to receive from health channel: %s", err.Error()) + errRecvCount++ + log.Errorf("failed to receive from health channel: %s", err.Error()) + if errRecvCount >= maxRecvErrCount { + // Agent core will still consider worker healthy as the ssm-agent-worker exists in the system process tree + log.Errorf("failed to receive from agent core health channel %v times. 
Stopping health ipc listener", errRecvCount) + return + } + + log.Debugf("Retrying receive from core agent health channel in %v seconds", recvErrSleepTime.Seconds()) + bus.sleepFunc(recvErrSleepTime) continue } + + errRecvCount = 0 log.Debugf("Received health request from core agent %s", string(msg)) if err = json.Unmarshal(msg, &request); err != nil { @@ -158,15 +176,33 @@ func (bus *MessageBus) ProcessTerminationRequest() { log.Infof("Start to listen to Core Agent termination channel") bus.terminationChannelConnected <- true + errRecvCount := 0 for { var request *message.Message if msg, err = bus.terminationChannel.Recv(); err != nil { - log.Errorf("cannot recv: %s", err.Error()) + log.Errorf("cannot receive message from core agent termination channel: %s", err.Error()) + errRecvCount++ + if errRecvCount >= maxRecvErrCount { + // Consider communication channel to agent core to be broken + // When ssm-agent-worker exits, agent-ssm-agent (agent core) creates a new worker with an updated communication channel + log.Errorf("failed to receive message from core agent termination channel %v times. 
Initiating worker shutdown", errRecvCount) + // Unblock main() function to allow agent worker to exit + bus.terminationRequestChannel <- true + close(bus.terminationRequestChannel) + // exit termination channel goroutine + break + } + + log.Debugf("Retrying receive from core agent termination channel in %v seconds", recvErrSleepTime.Seconds()) + bus.sleepFunc(recvErrSleepTime) continue } + log.Infof("Received termination message from core agent %s", string(msg)) + errRecvCount = 0 if err = json.Unmarshal(msg, &request); err != nil { + // Incoming message is dropped and not retried log.Errorf("failed to unmarshal message: %s", err.Error()) continue } @@ -180,10 +216,12 @@ func (bus *MessageBus) ProcessTerminationRequest() { message.LongRunning, os.Getpid(), true); err != nil { + // Response message is dropped and not retried log.Errorf("failed to create termination response: %s", err.Error()) } if err = bus.terminationChannel.Send(result); err != nil { + // Response message is dropped and not retried log.Errorf("failed to send termination response: %s", err.Error()) continue } diff --git a/agent/ipc/messagebus/respondent_test.go b/agent/ipc/messagebus/respondent_test.go index 3737f295c..7cb2ef5f7 100644 --- a/agent/ipc/messagebus/respondent_test.go +++ b/agent/ipc/messagebus/respondent_test.go @@ -24,7 +24,7 @@ import ( "github.com/aws/amazon-ssm-agent/agent/log" contextmocks "github.com/aws/amazon-ssm-agent/agent/mocks/context" logmocks "github.com/aws/amazon-ssm-agent/agent/mocks/log" - channel "github.com/aws/amazon-ssm-agent/common/channel" + "github.com/aws/amazon-ssm-agent/common/channel" channelmocks "github.com/aws/amazon-ssm-agent/common/channel/mocks" "github.com/aws/amazon-ssm-agent/common/message" "github.com/stretchr/testify/mock" @@ -66,6 +66,109 @@ func (suite *MessageBusTestSuite) SetupTest() { } } +func (suite *MessageBusTestSuite) TestProcessHealthRequest_Successful() { + // Arrange + 
suite.mockHealthChannel.On("IsChannelInitialized").Return(true).Once() + suite.mockHealthChannel.On("IsDialSuccessful").Return(true).Once() + suite.mockHealthChannel.On("Close").Return(nil).Once() + request := message.CreateHealthRequest() + requestString, _ := jsonutil.Marshal(request) + suite.mockHealthChannel.On("Recv").Return([]byte(requestString), nil).Once() + suite.mockHealthChannel.On("Send", mock.Anything).Return(nil) + // Kills the infinite loop + suite.mockHealthChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount) + + // Act + suite.messageBus.ProcessHealthRequest() + + // Assert + suite.mockHealthChannel.AssertExpectations(suite.T()) +} + +func (suite *MessageBusTestSuite) TestProcessHealthRequest_RecvError() { + // Arrange + suite.mockHealthChannel.On("IsChannelInitialized").Return(true).Once() + suite.mockHealthChannel.On("IsDialSuccessful").Return(true).Once() + suite.mockHealthChannel.On("Close").Return(nil).Once() + suite.mockHealthChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount) + + // Act + suite.messageBus.ProcessHealthRequest() + + // Assert + suite.mockHealthChannel.AssertExpectations(suite.T()) +} + +func (suite *MessageBusTestSuite) TestProcessHealthRequest_RecvErrorCount_Resets() { + // Arrange + suite.mockHealthChannel.On("IsChannelInitialized").Return(true).Once() + suite.mockHealthChannel.On("IsDialSuccessful").Return(true).Once() + suite.mockHealthChannel.On("Close").Return(nil).Once() + suite.mockHealthChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount - 1) + request := message.CreateHealthRequest() + requestString, _ := jsonutil.Marshal(request) + suite.mockHealthChannel.On("Recv").Return([]byte(requestString), nil).Once() + suite.mockHealthChannel.On("Send", mock.Anything).Return(nil) + // Kills the infinite loop + suite.mockHealthChannel.On("Recv").Return(nil, 
fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount) + + // Act + suite.messageBus.ProcessHealthRequest() + + // Assert + suite.mockHealthChannel.AssertExpectations(suite.T()) +} + +func (suite *MessageBusTestSuite) TestProcessTerminationRequest_Error() { + suite.mockTerminateChannel.On("IsDialSuccessful").Return(true).Once() + suite.mockTerminateChannel.On("IsChannelInitialized").Return(true).Once() + suite.mockTerminateChannel.On("Close").Return(nil).Once() + suite.mockTerminateChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount) + + suite.messageBus.ProcessTerminationRequest() + + suite.mockTerminateChannel.AssertExpectations(suite.T()) + + // Assert termination channel connected and that a termination message is sent + suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationChannelConnectedChan()) + suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationRequestChan()) +} + +func (suite *MessageBusTestSuite) TestProcessTerminationRequest_RecvRetried() { + suite.mockTerminateChannel.On("IsDialSuccessful").Return(true).Once() + suite.mockTerminateChannel.On("IsChannelInitialized").Return(true).Once() + suite.mockTerminateChannel.On("Close").Return(nil).Once() + suite.mockTerminateChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount - 1) + request := message.CreateTerminateWorkerRequest() + requestString, _ := jsonutil.Marshal(request) + suite.mockTerminateChannel.On("Recv").Return([]byte(requestString), nil) + suite.mockTerminateChannel.On("Send", mock.Anything).Return(nil) + suite.messageBus.ProcessTerminationRequest() + suite.mockTerminateChannel.AssertExpectations(suite.T()) + + // Assert termination channel connected and that a termination message is sent + suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationChannelConnectedChan()) + suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationRequestChan()) +} 
+ +func (suite *MessageBusTestSuite) TestProcessTerminationRequest_RecvRetriesReset() { + suite.mockTerminateChannel.On("IsDialSuccessful").Return(true).Once() + suite.mockTerminateChannel.On("IsChannelInitialized").Return(true).Once() + suite.mockTerminateChannel.On("Close").Return(nil).Once() + suite.mockTerminateChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount - 1) + suite.mockTerminateChannel.On("Recv").Return([]byte("not valid json message"), nil).Once() + request := message.CreateTerminateWorkerRequest() + requestString, _ := jsonutil.Marshal(request) + suite.mockTerminateChannel.On("Recv").Return([]byte(requestString), nil).Once() + suite.mockTerminateChannel.On("Send", mock.Anything).Return(nil) + suite.messageBus.ProcessTerminationRequest() + suite.mockTerminateChannel.AssertExpectations(suite.T()) + + // Assert termination channel connected and that a termination message is sent + suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationChannelConnectedChan()) + suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationRequestChan()) +} + // Execute the test suite func TestMessageBusTestSuite(t *testing.T) { suite.Run(t, new(MessageBusTestSuite)) @@ -113,7 +216,7 @@ func (suite *MessageBusTestSuite) TestProcessTerminationRequest_SuccessfulConnec suite.mockTerminateChannel.On("Recv").Return([]byte(requestString), nil) suite.mockTerminateChannel.On("Send", mock.Anything).Return(nil) - // Fourth call to isConnect succeeds, fourth call is for defer where it will call close + // The fourth call to IsChannelInitialized succeeds; it is made by the deferred cleanup, which then calls Close suite.mockTerminateChannel.On("IsChannelInitialized").Return(true).Once() suite.mockTerminateChannel.On("Close").Return(nil).Once()