Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

snssqs: fix consumer starvation #3478

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pubsub/aws/snssqs/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type snsSqsMetadata struct {
AccountID string `mapstructure:"accountID"`
// processing concurrency mode
ConcurrencyMode pubsub.ConcurrencyMode `mapstructure:"concurrencyMode"`
// limits the number of concurrent goroutines
ConcurrencyLimit int `mapstructure:"concurrencyLimit"`
}

func maskLeft(s string) string {
Expand Down Expand Up @@ -130,6 +132,10 @@ func (s *snsSqs) getSnsSqsMetatdata(meta pubsub.Metadata) (*snsSqsMetadata, erro
return nil, err
}

if md.ConcurrencyLimit < 0 {
return nil, errors.New("concurrencyLimit must be greater than or equal to 0")
}

s.logger.Debug(md.hideDebugPrintedCredentials())

return md, nil
Expand Down
9 changes: 9 additions & 0 deletions pubsub/aws/snssqs/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ metadata:
default: '"parallel"'
example: '"single", "parallel"'
type: string
- name: concurrencyLimit
required: false
description: |
Defines the maximum number of concurrent workers handling messages.
This value is ignored when "concurrencyMode" is set to “single“.
To avoid limiting the number of concurrent workers set this to “0“.
type: number
default: '0'
example: '100'
- name: accountId
required: false
description: |
Expand Down
23 changes: 17 additions & 6 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,13 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
WaitTimeSeconds: aws.Int64(s.metadata.MessageWaitTimeSeconds),
}

// sem is a semaphore used to control the concurrencyLimit.
// It is set only when we are in parallel mode and limit is > 0.
var sem chan (struct{}) = nil
if (s.metadata.ConcurrencyMode == pubsub.Parallel) && s.metadata.ConcurrencyLimit > 0 {
sem = make(chan struct{}, s.metadata.ConcurrencyLimit)
}

for {
// If the context is canceled, stop requesting messages
if ctx.Err() != nil {
Expand Down Expand Up @@ -623,33 +630,37 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters
}
s.logger.Debugf("%v message(s) received on queue %s", len(messageResponse.Messages), queueInfo.arn)

var wg sync.WaitGroup
for _, message := range messageResponse.Messages {
if err := s.validateMessage(ctx, message, queueInfo, deadLettersQueueInfo); err != nil {
s.logger.Errorf("message is not valid for further processing by the handler. error is: %v", err)
continue
}

f := func(message *sqs.Message) {
defer wg.Done()
if err := s.callHandler(ctx, message, queueInfo); err != nil {
s.logger.Errorf("error while handling received message. error is: %v", err)
}
}

wg.Add(1)
switch s.metadata.ConcurrencyMode {
case pubsub.Single:
f(message)
case pubsub.Parallel:
wg.Add(1)
// This is the back pressure mechanism.
// It will block until another goroutine frees a slot.
if sem != nil {
sem <- struct{}{}
}

go func(message *sqs.Message) {
defer wg.Done()
if sem != nil {
defer func() { <-sem }()
}

f(message)
}(message)
}
}
wg.Wait()
}
}

Expand Down
17 changes: 17 additions & 0 deletions pubsub/aws/snssqs/snssqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) {
"consumerID": "consumer",
"Endpoint": "endpoint",
"concurrencyMode": string(pubsub.Single),
"concurrencyLimit": "42",
"accessKey": "a",
"secretKey": "s",
"sessionToken": "t",
Expand All @@ -68,6 +69,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) {
r.Equal("consumer", md.SqsQueueName)
r.Equal("endpoint", md.Endpoint)
r.Equal(pubsub.Single, md.ConcurrencyMode)
r.Equal(42, md.ConcurrencyLimit)
r.Equal("a", md.AccessKey)
r.Equal("s", md.SecretKey)
r.Equal("t", md.SessionToken)
Expand Down Expand Up @@ -105,6 +107,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) {
r.Equal("", md.SessionToken)
r.Equal("r", md.Region)
r.Equal(pubsub.Parallel, md.ConcurrencyMode)
r.Equal(0, md.ConcurrencyLimit)
r.Equal(int64(10), md.MessageVisibilityTimeout)
r.Equal(int64(10), md.MessageRetryLimit)
r.Equal(int64(2), md.MessageWaitTimeSeconds)
Expand Down Expand Up @@ -273,6 +276,20 @@ func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) {
}}},
name: "invalid message concurrencyMode",
},
// invalid concurrencyLimit
{
metadata: pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{
"consumerID": "consumer",
"Endpoint": "endpoint",
"AccessKey": "acctId",
"SecretKey": "secret",
"awsToken": "token",
"Region": "region",
"messageRetryLimit": "10",
"concurrencyLimit": "-1",
}}},
name: "invalid message concurrencyLimit",
},
}

l := logger.NewLogger("SnsSqs unit test")
Expand Down
Loading