Skip to content

Commit

Permalink
Updated agent worker to exit if termination channel not established
Browse files Browse the repository at this point in the history
  • Loading branch information
gianniLesl authored and VishnuKarthikRavindran committed Sep 18, 2023
1 parent 461af16 commit 655a0af
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 4 deletions.
42 changes: 40 additions & 2 deletions agent/ipc/messagebus/respondent.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ import (
_ "go.nanomsg.org/mangos/v3/transport/ipc"
)

// Retry policy applied when Recv fails on a core agent IPC channel.
const (
	// maxRecvErrCount is the number of consecutive Recv failures after which
	// a listener stops retrying.
	maxRecvErrCount = 5
	// recvErrSleepTime is how long a listener sleeps before retrying Recv
	// after a failure.
	recvErrSleepTime = time.Second * 30
)

// IMessageBus is the interface for processing core agent broadcast requests
type IMessageBus interface {
ProcessHealthRequest()
Expand Down Expand Up @@ -93,12 +98,25 @@ func (bus *MessageBus) ProcessHealthRequest() {
}

log.Infof("Start to listen to Core Agent health channel")
errRecvCount := 0

for {
var request *message.Message
if msg, err = bus.healthChannel.Recv(); err != nil {
log.Errorf("Failed to receive from health channel: %s", err.Error())
errRecvCount++
log.Errorf("failed to receive from health channel: %s", err.Error())
if errRecvCount >= maxRecvErrCount {
// Agent core will still consider worker healthy as the ssm-agent-worker exists in the system process tree
log.Errorf("failed to receive from agent core health channel %v times. Stopping health ipc listener", errRecvCount)
return
}

log.Debugf("Retrying receive from core agent health channel in %v seconds", recvErrSleepTime.Seconds())
bus.sleepFunc(recvErrSleepTime)
continue
}

errRecvCount = 0
log.Debugf("Received health request from core agent %s", string(msg))

if err = json.Unmarshal(msg, &request); err != nil {
Expand Down Expand Up @@ -158,15 +176,33 @@ func (bus *MessageBus) ProcessTerminationRequest() {

log.Infof("Start to listen to Core Agent termination channel")
bus.terminationChannelConnected <- true
errRecvCount := 0

for {
var request *message.Message
if msg, err = bus.terminationChannel.Recv(); err != nil {
log.Errorf("cannot recv: %s", err.Error())
log.Errorf("cannot receive message from core agent termination channel: %s", err.Error())
errRecvCount++
if errRecvCount >= maxRecvErrCount {
// Consider communication channel to agent core to be broken
// When ssm-agent-worker exits, amazon-ssm-agent (agent core) creates a new worker with an updated communication channel
log.Errorf("failed to receive message from core agent termination channel %v times. Initiating worker shutdown", errRecvCount)
// Unblock main() function to allow agent worker to exit
bus.terminationRequestChannel <- true
close(bus.terminationRequestChannel)
// exit termination channel goroutine
break
}

log.Debugf("Retrying receive from core agent termination channel in %v seconds", recvErrSleepTime.Seconds())
bus.sleepFunc(recvErrSleepTime)
continue
}

log.Infof("Received termination message from core agent %s", string(msg))
errRecvCount = 0
if err = json.Unmarshal(msg, &request); err != nil {
// Incoming message is dropped and not retried
log.Errorf("failed to unmarshal message: %s", err.Error())
continue
}
Expand All @@ -180,10 +216,12 @@ func (bus *MessageBus) ProcessTerminationRequest() {
message.LongRunning,
os.Getpid(),
true); err != nil {
// Response message is dropped and not retried
log.Errorf("failed to create termination response: %s", err.Error())
}

if err = bus.terminationChannel.Send(result); err != nil {
// Response message is dropped and not retried
log.Errorf("failed to send termination response: %s", err.Error())
continue
}
Expand Down
107 changes: 105 additions & 2 deletions agent/ipc/messagebus/respondent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"github.com/aws/amazon-ssm-agent/agent/log"
contextmocks "github.com/aws/amazon-ssm-agent/agent/mocks/context"
logmocks "github.com/aws/amazon-ssm-agent/agent/mocks/log"
channel "github.com/aws/amazon-ssm-agent/common/channel"
"github.com/aws/amazon-ssm-agent/common/channel"
channelmocks "github.com/aws/amazon-ssm-agent/common/channel/mocks"
"github.com/aws/amazon-ssm-agent/common/message"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -66,6 +66,109 @@ func (suite *MessageBusTestSuite) SetupTest() {
}
}

// TestProcessHealthRequest_Successful delivers one valid health request to the
// listener, then makes Recv fail maxRecvErrCount times so the listener returns.
func (suite *MessageBusTestSuite) TestProcessHealthRequest_Successful() {
	healthChan := suite.mockHealthChannel
	recvErr := fmt.Errorf("failed to receive message on channel")
	payload, _ := jsonutil.Marshal(message.CreateHealthRequest())

	// Channel lifecycle expectations.
	healthChan.On("IsChannelInitialized").Return(true).Once()
	healthChan.On("IsDialSuccessful").Return(true).Once()
	healthChan.On("Close").Return(nil).Once()

	// One good request, answered through Send.
	healthChan.On("Recv").Return([]byte(payload), nil).Once()
	healthChan.On("Send", mock.Anything).Return(nil)

	// Enough receive failures to end the otherwise infinite loop.
	healthChan.On("Recv").Return(nil, recvErr).Times(maxRecvErrCount)

	suite.messageBus.ProcessHealthRequest()

	healthChan.AssertExpectations(suite.T())
}

// TestProcessHealthRequest_RecvError makes every Recv fail and expects the
// listener to return after exactly maxRecvErrCount attempts.
func (suite *MessageBusTestSuite) TestProcessHealthRequest_RecvError() {
	healthChan := suite.mockHealthChannel
	recvErr := fmt.Errorf("failed to receive message on channel")

	healthChan.On("IsChannelInitialized").Return(true).Once()
	healthChan.On("IsDialSuccessful").Return(true).Once()
	healthChan.On("Close").Return(nil).Once()
	healthChan.On("Recv").Return(nil, recvErr).Times(maxRecvErrCount)

	suite.messageBus.ProcessHealthRequest()

	healthChan.AssertExpectations(suite.T())
}

// TestProcessHealthRequest_RecvErrorCount_Resets checks that a successful Recv
// clears the failure counter: maxRecvErrCount-1 failures, one valid request,
// and then a full maxRecvErrCount failures must all be consumed before the
// listener returns.
func (suite *MessageBusTestSuite) TestProcessHealthRequest_RecvErrorCount_Resets() {
	healthChan := suite.mockHealthChannel
	recvErr := fmt.Errorf("failed to receive message on channel")
	payload, _ := jsonutil.Marshal(message.CreateHealthRequest())

	healthChan.On("IsChannelInitialized").Return(true).Once()
	healthChan.On("IsDialSuccessful").Return(true).Once()
	healthChan.On("Close").Return(nil).Once()

	// Stop one failure short of the limit.
	healthChan.On("Recv").Return(nil, recvErr).Times(maxRecvErrCount - 1)

	// A good request in between.
	healthChan.On("Recv").Return([]byte(payload), nil).Once()
	healthChan.On("Send", mock.Anything).Return(nil)

	// A full run of failures ends the otherwise infinite loop.
	healthChan.On("Recv").Return(nil, recvErr).Times(maxRecvErrCount)

	suite.messageBus.ProcessHealthRequest()

	healthChan.AssertExpectations(suite.T())
}

// TestProcessTerminationRequest_Error makes Recv fail maxRecvErrCount times and
// expects the listener to report the channel as connected and then emit a
// termination request.
func (suite *MessageBusTestSuite) TestProcessTerminationRequest_Error() {
	termChan := suite.mockTerminateChannel
	recvErr := fmt.Errorf("failed to receive message on channel")

	termChan.On("IsDialSuccessful").Return(true).Once()
	termChan.On("IsChannelInitialized").Return(true).Once()
	termChan.On("Close").Return(nil).Once()
	termChan.On("Recv").Return(nil, recvErr).Times(maxRecvErrCount)

	suite.messageBus.ProcessTerminationRequest()

	termChan.AssertExpectations(suite.T())

	// Assert termination channel connected and that a termination message is sent
	connected := <-suite.messageBus.GetTerminationChannelConnectedChan()
	suite.Assertions.Equal(true, connected)
	terminate := <-suite.messageBus.GetTerminationRequestChan()
	suite.Assertions.Equal(true, terminate)
}

// TestProcessTerminationRequest_RecvRetried makes Recv fail maxRecvErrCount-1
// times before a valid terminate-worker request arrives, and expects the
// request to still be processed.
func (suite *MessageBusTestSuite) TestProcessTerminationRequest_RecvRetried() {
	termChan := suite.mockTerminateChannel
	recvErr := fmt.Errorf("failed to receive message on channel")
	payload, _ := jsonutil.Marshal(message.CreateTerminateWorkerRequest())

	termChan.On("IsDialSuccessful").Return(true).Once()
	termChan.On("IsChannelInitialized").Return(true).Once()
	termChan.On("Close").Return(nil).Once()

	// Stop one failure short of the limit, then deliver the real request.
	termChan.On("Recv").Return(nil, recvErr).Times(maxRecvErrCount - 1)
	termChan.On("Recv").Return([]byte(payload), nil)
	termChan.On("Send", mock.Anything).Return(nil)

	suite.messageBus.ProcessTerminationRequest()

	termChan.AssertExpectations(suite.T())

	// Assert termination channel connected and that a termination message is sent
	connected := <-suite.messageBus.GetTerminationChannelConnectedChan()
	suite.Assertions.Equal(true, connected)
	terminate := <-suite.messageBus.GetTerminationRequestChan()
	suite.Assertions.Equal(true, terminate)
}

// TestProcessTerminationRequest_RecvRetriesReset checks that a successful Recv
// resets the failure counter on the termination channel, even when the
// received payload is not valid JSON.
//
// Sequence fed to the listener:
//  1. maxRecvErrCount-1 Recv failures (one short of the limit),
//  2. one successful Recv carrying an undecodable message,
//  3. another maxRecvErrCount-1 Recv failures,
//  4. a valid terminate-worker request.
//
// Without the reset, the first failure of step 3 would reach the limit and end
// the listener early, leaving the remaining Recv and Send expectations unmet.
func (suite *MessageBusTestSuite) TestProcessTerminationRequest_RecvRetriesReset() {
	suite.mockTerminateChannel.On("IsDialSuccessful").Return(true).Once()
	suite.mockTerminateChannel.On("IsChannelInitialized").Return(true).Once()
	suite.mockTerminateChannel.On("Close").Return(nil).Once()

	// Step 1: stop one failure short of the limit.
	suite.mockTerminateChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount - 1)
	// Step 2: a successful receive; the message itself is dropped by the listener.
	suite.mockTerminateChannel.On("Recv").Return([]byte("not valid json message"), nil).Once()
	// Step 3: these failures are only tolerated if the counter was reset in step 2.
	suite.mockTerminateChannel.On("Recv").Return(nil, fmt.Errorf("failed to receive message on channel")).Times(maxRecvErrCount - 1)
	// Step 4: the real termination request, answered through Send.
	request := message.CreateTerminateWorkerRequest()
	requestString, _ := jsonutil.Marshal(request)
	suite.mockTerminateChannel.On("Recv").Return([]byte(requestString), nil).Once()
	suite.mockTerminateChannel.On("Send", mock.Anything).Return(nil)

	suite.messageBus.ProcessTerminationRequest()

	// Every queued Recv (and the Send) must have been consumed.
	suite.mockTerminateChannel.AssertExpectations(suite.T())

	// Assert termination channel connected and that a termination message is sent
	suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationChannelConnectedChan())
	suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationRequestChan())
}

// Execute the test suite
func TestMessageBusTestSuite(t *testing.T) {
suite.Run(t, new(MessageBusTestSuite))
Expand Down Expand Up @@ -113,7 +216,7 @@ func (suite *MessageBusTestSuite) TestProcessTerminationRequest_SuccessfulConnec
suite.mockTerminateChannel.On("Recv").Return([]byte(requestString), nil)
suite.mockTerminateChannel.On("Send", mock.Anything).Return(nil)

// Fourth call to isConnect succeeds, fourth call is for defer where it will call close
// Fourth call to IsChannelInitialized succeeds, fourth call is for defer where it will call close
suite.mockTerminateChannel.On("IsChannelInitialized").Return(true).Once()
suite.mockTerminateChannel.On("Close").Return(nil).Once()

Expand Down

0 comments on commit 655a0af

Please sign in to comment.