diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go new file mode 100644 index 000000000..5281151eb --- /dev/null +++ b/transport/awssqs/consumer.go @@ -0,0 +1,229 @@ +package awssqs + +import ( + "context" + "encoding/json" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/transport" +) + +// Consumer wraps an endpoint and provides a handler for SQS messages. +type Consumer struct { + sqsClient sqsiface.SQSAPI + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + wantRep WantReplyFunc + queueURL string + before []ConsumerRequestFunc + after []ConsumerResponseFunc + errorEncoder ErrorEncoder + finalizer []ConsumerFinalizerFunc + errorHandler transport.ErrorHandler +} + +// NewConsumer constructs a new Consumer, which provides a Consume method +// and message handlers that wrap the provided endpoint. +func NewConsumer( + sqsClient sqsiface.SQSAPI, + e endpoint.Endpoint, + dec DecodeRequestFunc, + enc EncodeResponseFunc, + queueURL string, + options ...ConsumerOption, +) *Consumer { + s := &Consumer{ + sqsClient: sqsClient, + e: e, + dec: dec, + enc: enc, + wantRep: DoNotRespond, + queueURL: queueURL, + errorEncoder: DefaultErrorEncoder, + errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), + } + for _, option := range options { + option(s) + } + return s +} + +// ConsumerOption sets an optional parameter for consumers. +type ConsumerOption func(*Consumer) + +// ConsumerBefore functions are executed on the producer request object before the +// request is decoded. +func ConsumerBefore(before ...ConsumerRequestFunc) ConsumerOption { + return func(c *Consumer) { c.before = append(c.before, before...) } +} + +// ConsumerAfter functions are executed on the consumer reply after the +// endpoint is invoked, but before anything is published to the reply. +func ConsumerAfter(after ...ConsumerResponseFunc) ConsumerOption { + return func(c *Consumer) { c.after = append(c.after, after...) } +} + +// ConsumerErrorEncoder is used to encode errors to the consumer reply +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting. By default, +// errors will be published with the DefaultErrorEncoder. +func ConsumerErrorEncoder(ee ErrorEncoder) ConsumerOption { + return func(c *Consumer) { c.errorEncoder = ee } +} + +// ConsumerWantReplyFunc overrides the default value for the consumer's +// wantRep field. +func ConsumerWantReplyFunc(replyFunc WantReplyFunc) ConsumerOption { + return func(c *Consumer) { c.wantRep = replyFunc } +} + +// ConsumerErrorHandler is used to handle non-terminal errors. By default, non-terminal errors +// are ignored. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ConsumerErrorEncoder which has access to the context. +func ConsumerErrorHandler(errorHandler transport.ErrorHandler) ConsumerOption { + return func(c *Consumer) { c.errorHandler = errorHandler } +} + +// ConsumerFinalizer is executed once all the received SQS messages are done being processed. +// By default, no finalizer is registered. +func ConsumerFinalizer(f ...ConsumerFinalizerFunc) ConsumerOption { + return func(c *Consumer) { c.finalizer = f } +} + +// ConsumerDeleteMessageBefore returns a ConsumerOption that appends a function +// that delete the message from queue to the list of consumer's before functions. +func ConsumerDeleteMessageBefore() ConsumerOption { + return func(c *Consumer) { + deleteBefore := func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message) context.Context { + if err := deleteMessage(ctx, c.sqsClient, c.queueURL, msg); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + cancel() + } + return ctx + } + c.before = append(c.before, deleteBefore) + } +} + +// ConsumerDeleteMessageAfter returns a ConsumerOption that appends a function +// that delete a message from queue to the list of consumer's after functions. +func ConsumerDeleteMessageAfter() ConsumerOption { + return func(c *Consumer) { + deleteAfter := func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message, _ *sqs.SendMessageInput) context.Context { + if err := deleteMessage(ctx, c.sqsClient, c.queueURL, msg); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + cancel() + } + return ctx + } + c.after = append(c.after, deleteAfter) + } +} + +// ServeMessage serves an SQS message. +func (c Consumer) ServeMessage(ctx context.Context) func(msg *sqs.Message) error { + return func(msg *sqs.Message) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + if len(c.finalizer) > 0 { + defer func() { + for _, f := range c.finalizer { + f(ctx, msg) + } + }() + } + + for _, f := range c.before { + ctx = f(ctx, cancel, msg) + } + + req, err := c.dec(ctx, msg) + if err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + + response, err := c.e(ctx, req) + if err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + + responseMsg := sqs.SendMessageInput{} + for _, f := range c.after { + ctx = f(ctx, cancel, msg, &responseMsg) + } + + if !c.wantRep(ctx, msg) { + return nil + } + + if err := c.enc(ctx, &responseMsg, response); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + + if _, err := c.sqsClient.SendMessageWithContext(ctx, &responseMsg); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + return nil + } +} + +// ErrorEncoder is responsible for encoding an error to the consumer's reply. +// Users are encouraged to use custom ErrorEncoders to encode errors to +// their replies, and will likely want to pass and check for their own error +// types. +type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) + +// ConsumerFinalizerFunc can be used to perform work at the end of a request +// from a producer, after the response has been written to the producer. The +// principal intended use is for request logging. +// Can also be used to delete messages once fully proccessed. +type ConsumerFinalizerFunc func(ctx context.Context, msg *sqs.Message) + +// WantReplyFunc encapsulates logic to check whether message awaits response or not +// for example check for a given message attribute value. +type WantReplyFunc func(context.Context, *sqs.Message) bool + +// DefaultErrorEncoder simply ignores the message. It does not reply. +func DefaultErrorEncoder(context.Context, error, *sqs.Message, sqsiface.SQSAPI) { +} + +func deleteMessage(ctx context.Context, sqsClient sqsiface.SQSAPI, queueURL string, msg *sqs.Message) error { + _, err := sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + QueueUrl: &queueURL, + ReceiptHandle: msg.ReceiptHandle, + }) + return err +} + +// DoNotRespond is a WantReplyFunc and is the default value for consumer's wantRep field. +// It indicates that the message do not expect a response. +func DoNotRespond(context.Context, *sqs.Message) bool { + return false +} + +// EncodeJSONResponse marshals response as json and loads it into an sqs.SendMessageInput MessageBody. +func EncodeJSONResponse(_ context.Context, input *sqs.SendMessageInput, response interface{}) error { + payload, err := json.Marshal(response) + if err != nil { + return err + } + input.MessageBody = aws.String(string(payload)) + return nil +} diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go new file mode 100644 index 000000000..939eb7286 --- /dev/null +++ b/transport/awssqs/consumer_test.go @@ -0,0 +1,510 @@ +package awssqs_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/go-kit/kit/transport/awssqs" + "github.com/pborman/uuid" +) + +var ( + errTypeAssertion = errors.New("type assertion error") +) + +func (mock *mockClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) { + // Add logic to allow context errors. + for { + select { + case d := <-mock.receiveOuputChan: + return d, mock.err + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (mock *mockClient) DeleteMessageWithContext(ctx context.Context, input *sqs.DeleteMessageInput, opts ...request.Option) (*sqs.DeleteMessageOutput, error) { + return nil, mock.deleteError +} + +// TestConsumerDeleteBefore checks if deleteMessage is set properly using consumer options. +func TestConsumerDeleteBefore(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + deleteError: fmt.Errorf("delete err!"), + } + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + queueURL, + errEncoder, + awssqs.ConsumerDeleteMessageBefore(), + ) + + consumer.ServeMessage(context.Background())(&sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "delete err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestConsumerBadDecode checks if decoder errors are handled properly. +func TestConsumerBadDecode(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, errors.New("err!") }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + queueURL, + errEncoder, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), + ) + + consumer.ServeMessage(context.Background())(&sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestConsumerBadEndpoint checks if endpoint errors are handled properly. +func TestConsumerBadEndpoint(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("err!") }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + queueURL, + errEncoder, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), + ) + + consumer.ServeMessage(context.Background())(&sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestConsumerBadEncoder checks if encoder errors are handled properly. +func TestConsumerBadEncoder(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, + queueURL, + errEncoder, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), + ) + + consumer.ServeMessage(context.Background())(&sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestConsumerSuccess checks if consumer responds correctly to message. +func TestConsumerSuccess(t *testing.T) { + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + consumer := awssqs.NewConsumer(mock, + testEndpoint, + testReqDecoderfunc, + awssqs.EncodeJSONResponse, + queueURL, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), + ) + + consumer.ServeMessage(context.Background())(&sqs.Message{ + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeResponse(receiveOutput) + if err != nil { + t.Fatal(err) + } + want := testRes{ + Squadron: 436, + Name: "tusker", + } + if have := res; want != have { + t.Errorf("want %v, have %v", want, have) + } +} + +// TestConsumerSuccessNoReply checks if consumer processes correctly message +// without sending response. +func TestConsumerSuccessNoReply(t *testing.T) { + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + consumer := awssqs.NewConsumer(mock, + testEndpoint, + testReqDecoderfunc, + awssqs.EncodeJSONResponse, + queueURL, + ) + + consumer.ServeMessage(context.Background())(&sqs.Message{ + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + t.Errorf("received output when none was expected, have %v", receiveOutput) + return + + case <-time.After(200 * time.Millisecond): + // As expected, we did not receive any response from consumer. + return + } +} + +// TestConsumerBeforeFilterMessages checks if consumer before is called as expected. +// Here before is used to add a value in context. +func TestConsumerBeforeAddValueToContext(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + msg := &sqs.Message{ + Body: aws.String("someBody"), + MessageId: aws.String("fakeMsgID1"), + MessageAttributes: map[string]*sqs.MessageAttributeValue{ + "recipient": { + DataType: aws.String("String"), + StringValue: aws.String("me"), + }, + }, + } + type ctxKey struct { + key string + } + consumer := awssqs.NewConsumer(mock, + // endpoint + func(ctx context.Context, request interface{}) (interface{}, error) { + return ctx.Value(ctxKey{"recipient"}).(string), nil + }, + // request decoder + func(_ context.Context, msg *sqs.Message) (interface{}, error) { + return *msg.Body, nil + }, + // response encoder + func(_ context.Context, input *sqs.SendMessageInput, response interface{}) error { + input.MessageBody = aws.String(fmt.Sprintf("%v", response)) + return nil + }, + queueURL, + awssqs.ConsumerBefore(func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message) context.Context { + // Filter a message that is not destined to the consumer. + if recipient, exists := msg.MessageAttributes["recipient"]; exists { + ctx = context.WithValue(ctx, ctxKey{"recipient"}, *recipient.StringValue) + } + return ctx + }), + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), + ) + ctx := context.Background() + err := consumer.ServeMessage(ctx)(msg) + if err != nil { + t.Errorf("got err %s", err) + } + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + if len(receiveOutput.Messages) != 1 { + t.Errorf("Error : received %d messages instead of 1", len(receiveOutput.Messages)) + } + res := *receiveOutput.Messages[0].Body + want := "me" + if have := res; want != have { + t.Errorf("want %v, have %v", want, have) + } + // Try fetching responses again. + select { + case receiveOutput = <-mock.receiveOuputChan: + t.Errorf("received second output when only one was expected, have %v", receiveOutput) + return + + case <-time.After(200 * time.Millisecond): + // As expected, we did not receive a second response from consumer. + return + } +} + +// TestConsumerAfter checks if consumer after is called as expected. +// Here after is used to transfer some info from received message in response. +func TestConsumerAfter(t *testing.T) { + obj1 := testReq{ + Squadron: 436, + } + b1, _ := json.Marshal(obj1) + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + correlationID := uuid.NewRandom().String() + msg := &sqs.Message{ + Body: aws.String(string(b1)), + MessageId: aws.String("fakeMsgID1"), + MessageAttributes: map[string]*sqs.MessageAttributeValue{ + "correlationID": { + DataType: aws.String("String"), + StringValue: &correlationID, + }, + }, + } + type ctxKey struct { + key string + } + consumer := awssqs.NewConsumer(mock, + testEndpoint, + testReqDecoderfunc, + awssqs.EncodeJSONResponse, + queueURL, + awssqs.ConsumerAfter(func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message, resp *sqs.SendMessageInput) context.Context { + if correlationIDAttribute, exists := msg.MessageAttributes["correlationID"]; exists { + if resp.MessageAttributes == nil { + resp.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) + } + resp.MessageAttributes["correlationID"] = &sqs.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: correlationIDAttribute.StringValue, + } + } + return ctx + }), + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), + ) + ctx := context.Background() + consumer.ServeMessage(ctx)(msg) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + if len(receiveOutput.Messages) != 1 { + t.Errorf("received %d messages instead of 1", len(receiveOutput.Messages)) + } + if correlationIDAttribute, exists := receiveOutput.Messages[0].MessageAttributes["correlationID"]; exists { + if have := correlationIDAttribute.StringValue; *have != correlationID { + t.Errorf("have %s, want %s", *have, correlationID) + } + } else { + t.Errorf("expected message attribute with key correlationID in response, but it was not found") + } +} + +type sqsError struct { + Err string `json:"err"` + MsgID string `json:"msgID"` +} + +func decodeConsumerError(receiveOutput *sqs.ReceiveMessageOutput) (sqsError, error) { + receivedError := sqsError{} + err := json.Unmarshal([]byte(*receiveOutput.Messages[0].Body), &receivedError) + return receivedError, err +} + +func testEndpoint(ctx context.Context, request interface{}) (interface{}, error) { + req, ok := request.(testReq) + if !ok { + return nil, errTypeAssertion + } + name, prs := names[req.Squadron] + if !prs { + return nil, errors.New("unknown squadron name") + } + res := testRes{ + Squadron: req.Squadron, + Name: name, + } + return res, nil +} + +func testReqDecoderfunc(_ context.Context, msg *sqs.Message) (interface{}, error) { + var obj testReq + err := json.Unmarshal([]byte(*msg.Body), &obj) + return obj, err +} + +func decodeResponse(receiveOutput *sqs.ReceiveMessageOutput) (interface{}, error) { + if len(receiveOutput.Messages) != 1 { + return nil, fmt.Errorf("Error : received %d messages instead of 1", len(receiveOutput.Messages)) + } + resp := testRes{} + err := json.Unmarshal([]byte(*receiveOutput.Messages[0].Body), &resp) + return resp, err +} diff --git a/transport/awssqs/doc.go b/transport/awssqs/doc.go new file mode 100644 index 000000000..779c77d69 --- /dev/null +++ b/transport/awssqs/doc.go @@ -0,0 +1,2 @@ +// Package awssqs implements an AWS Simple Queue Service transport. +package awssqs diff --git a/transport/awssqs/encode_decode.go b/transport/awssqs/encode_decode.go new file mode 100644 index 000000000..0f6b0f3de --- /dev/null +++ b/transport/awssqs/encode_decode.go @@ -0,0 +1,23 @@ +package awssqs + +import ( + "context" + + "github.com/aws/aws-sdk-go/service/sqs" +) + +// DecodeRequestFunc extracts a user-domain request object from +// an SQS message object. It is designed to be used in Consumers. +type DecodeRequestFunc func(context.Context, *sqs.Message) (request interface{}, err error) + +// EncodeRequestFunc encodes the passed payload object into +// an SQS message object. It is designed to be used in Producers. +type EncodeRequestFunc func(context.Context, *sqs.SendMessageInput, interface{}) error + +// EncodeResponseFunc encodes the passed response object to +// an SQS message object. It is designed to be used in Consumers. +type EncodeResponseFunc func(context.Context, *sqs.SendMessageInput, interface{}) error + +// DecodeResponseFunc extracts a user-domain response object from +// an SQS message object. It is designed to be used in Producers. +type DecodeResponseFunc func(context.Context, *sqs.Message) (response interface{}, err error) diff --git a/transport/awssqs/producer.go b/transport/awssqs/producer.go new file mode 100644 index 000000000..8192d8689 --- /dev/null +++ b/transport/awssqs/producer.go @@ -0,0 +1,140 @@ +package awssqs + +import ( + "context" + "encoding/json" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/go-kit/kit/endpoint" +) + +type contextKey int + +const ( + // ContextKeyResponseQueueURL is the context key that allows fetching + // and setting the response queue URL from and into context. + ContextKeyResponseQueueURL contextKey = iota +) + +// Producer wraps an SQS client and queue, and provides a method that +// implements endpoint.Endpoint. +type Producer struct { + sqsClient sqsiface.SQSAPI + queueURL string + enc EncodeRequestFunc + dec DecodeResponseFunc + before []ProducerRequestFunc + after []ProducerResponseFunc + timeout time.Duration +} + +// NewProducer constructs a usable Producer for a single remote method. +func NewProducer( + sqsClient sqsiface.SQSAPI, + queueURL string, + enc EncodeRequestFunc, + dec DecodeResponseFunc, + options ...ProducerOption, +) *Producer { + p := &Producer{ + sqsClient: sqsClient, + queueURL: queueURL, + enc: enc, + dec: dec, + timeout: 20 * time.Second, + } + for _, option := range options { + option(p) + } + return p +} + +// ProducerOption sets an optional parameter for clients. +type ProducerOption func(*Producer) + +// ProducerBefore sets the RequestFuncs that are applied to the outgoing SQS +// request before it's invoked. +func ProducerBefore(before ...ProducerRequestFunc) ProducerOption { + return func(p *Producer) { p.before = append(p.before, before...) } +} + +// ProducerAfter sets the ClientResponseFuncs applied to the incoming SQS +// request prior to it being decoded. This is useful for obtaining the response +// and adding any information onto the context prior to decoding. +func ProducerAfter(after ...ProducerResponseFunc) ProducerOption { + return func(p *Producer) { p.after = append(p.after, after...) } +} + +// ProducerTimeout sets the available timeout for an SQS request. +func ProducerTimeout(timeout time.Duration) ProducerOption { + return func(p *Producer) { p.timeout = timeout } +} + +// SetProducerResponseQueueURL can be used as a before function to add +// provided url as responseQueueURL in context. +func SetProducerResponseQueueURL(url string) ProducerRequestFunc { + return func(ctx context.Context, _ *sqs.SendMessageInput) context.Context { + return context.WithValue(ctx, ContextKeyResponseQueueURL, url) + } +} + +// Endpoint returns a usable endpoint that invokes the remote endpoint. +func (p Producer) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + ctx, cancel := context.WithTimeout(ctx, p.timeout) + defer cancel() + msgInput := sqs.SendMessageInput{ + QueueUrl: &p.queueURL, + } + if err := p.enc(ctx, &msgInput, request); err != nil { + return nil, err + } + + for _, f := range p.before { + ctx = f(ctx, &msgInput) + } + + output, err := p.sqsClient.SendMessageWithContext(ctx, &msgInput) + if err != nil { + return nil, err + } + + var responseMsg *sqs.Message + for _, f := range p.after { + ctx, responseMsg, err = f(ctx, p.sqsClient, output) + if err != nil { + return nil, err + } + } + + response, err := p.dec(ctx, responseMsg) + if err != nil { + return nil, err + } + + return response, nil + } +} + +// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a +// JSON object and loads it as the MessageBody of the sqs.SendMessageInput. +// This can be enough for most JSON over SQS communications. +func EncodeJSONRequest(_ context.Context, msg *sqs.SendMessageInput, request interface{}) error { + b, err := json.Marshal(request) + if err != nil { + return err + } + + msg.MessageBody = aws.String(string(b)) + + return nil +} + +// NoResponseDecode is a DecodeResponseFunc that can be used when no response is needed. +// It returns nil value and nil error. +func NoResponseDecode(_ context.Context, _ *sqs.Message) (interface{}, error) { + return nil, nil +} diff --git a/transport/awssqs/producer_test.go b/transport/awssqs/producer_test.go new file mode 100644 index 000000000..e05f5655d --- /dev/null +++ b/transport/awssqs/producer_test.go @@ -0,0 +1,372 @@ +package awssqs_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/go-kit/kit/transport/awssqs" +) + +type testReq struct { + Squadron int `json:"s"` +} + +type testRes struct { + Squadron int `json:"s"` + Name string `json:"n"` +} + +var names = map[int]string{ + 424: "tiger", + 426: "thunderbird", + 429: "bison", + 436: "tusker", + 437: "husky", +} + +// mockClient is a mock of *sqs.SQS. +type mockClient struct { + sqsiface.SQSAPI + err error + sendOutputChan chan *sqs.SendMessageOutput + receiveOuputChan chan *sqs.ReceiveMessageOutput + sendMsgID string + deleteError error +} + +func (mock *mockClient) SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) { + if input != nil && input.MessageBody != nil && *input.MessageBody != "" { + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + MessageAttributes: input.MessageAttributes, + Body: input.MessageBody, + MessageId: aws.String(mock.sendMsgID), + }, + }, + } + }() + return &sqs.SendMessageOutput{MessageId: aws.String(mock.sendMsgID)}, nil + } + // Add logic to allow context errors. + for { + select { + case d := <-mock.sendOutputChan: + return d, mock.err + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (mock *mockClient) ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) { + return nil, nil +} + +// TestBadEncode tests if encode errors are handled properly. +func TestBadEncode(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + } + pub := awssqs.NewProducer( + mock, + queueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, + func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, nil }, + ) + errChan := make(chan error, 1) + var err error + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + + }() + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestBadDecode tests if decode errors are handled properly. +func TestBadDecode(t *testing.T) { + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + } + go func() { + mock.sendOutputChan <- &sqs.SendMessageOutput{ + MessageId: aws.String("someMsgID"), + } + }() + + queueURL := "someURL" + pub := awssqs.NewProducer( + mock, + queueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + func(context.Context, *sqs.Message) (response interface{}, err error) { + return struct{}{}, errors.New("err!") + }, + awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + // Set the actual response for the request. + return ctx, &sqs.Message{Body: aws.String("someMsgContent")}, nil + }), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + }() + + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestProducerTimeout ensures that the producer timeout mechanism works. +func TestProducerTimeout(t *testing.T) { + sendOutputChan := make(chan *sqs.SendMessageOutput) + mock := &mockClient{ + sendOutputChan: sendOutputChan, + } + queueURL := "someURL" + pub := awssqs.NewProducer( + mock, + queueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + func(context.Context, *sqs.Message) (response interface{}, err error) { + return struct{}{}, nil + }, + awssqs.ProducerTimeout(50*time.Millisecond), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + + }() + + select { + case err = <-errChan: + break + + case <-time.After(1000 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + return + } + if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSuccessfulProducer ensures that the producer mechanisms work. +func TestSuccessfulProducer(t *testing.T) { + mockReq := testReq{437} + mockRes := testRes{ + Squadron: mockReq.Squadron, + Name: names[mockReq.Squadron], + } + b, err := json.Marshal(mockRes) + if err != nil { + t.Fatal(err) + } + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + sendMsgID: "someMsgID", + } + go func() { + mock.sendOutputChan <- &sqs.SendMessageOutput{ + MessageId: aws.String("someMsgID"), + } + }() + + queueURL := "someURL" + pub := awssqs.NewProducer( + mock, + queueURL, + awssqs.EncodeJSONRequest, + func(_ context.Context, msg *sqs.Message) (interface{}, error) { + response := testRes{} + err := json.Unmarshal([]byte(*msg.Body), &response) + return response, err + }, + awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + // Sets the actual response for the request. + if *msg.MessageId == "someMsgID" { + return ctx, &sqs.Message{Body: aws.String(string(b))}, nil + } + return nil, nil, fmt.Errorf("Did not receive expected SendMessageOutput") + }), + ) + var res testRes + var ok bool + resChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + go func() { + res, err := pub.Endpoint()(context.Background(), mockReq) + if err != nil { + errChan <- err + } else { + resChan <- res + } + }() + + select { + case response := <-resChan: + res, ok = response.(testRes) + if !ok { + t.Error("failed to assert endpoint response type") + } + break + + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err != nil { + t.Fatal(err) + } + if want, have := mockRes.Name, res.Name; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestSuccessfulProducerNoResponse ensures that the producer response mechanism works. +func TestSuccessfulProducerNoResponse(t *testing.T) { + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + sendMsgID: "someMsgID", + } + + queueURL := "someURL" + pub := awssqs.NewProducer( + mock, + queueURL, + awssqs.EncodeJSONRequest, + awssqs.NoResponseDecode, + ) + var err error + errChan := make(chan error, 1) + finishChan := make(chan bool, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + if err != nil { + errChan <- err + } else { + finishChan <- true + } + }() + + select { + case <-finishChan: + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } +} + +// TestProducerWithBefore adds a ProducerBefore function that adds responseQueueURL to context, +// and another on that adds it as a message attribute to outgoing message. +// This test ensures that setting multiple before functions work as expected +// and that SetProducerResponseQueueURL works as expected. +func TestProducerWithBefore(t *testing.T) { + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + sendMsgID: "someMsgID", + } + + queueURL := "someURL" + responseQueueURL := "someOtherURL" + pub := awssqs.NewProducer( + mock, + queueURL, + awssqs.EncodeJSONRequest, + awssqs.NoResponseDecode, + awssqs.ProducerBefore(awssqs.SetProducerResponseQueueURL(responseQueueURL)), + awssqs.ProducerBefore(func(c context.Context, s *sqs.SendMessageInput) context.Context { + responseQueueURL := c.Value(awssqs.ContextKeyResponseQueueURL).(string) + if s.MessageAttributes == nil { + s.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) + } + s.MessageAttributes["responseQueueURL"] = &sqs.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: &responseQueueURL, + } + return c + }), + ) + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + if err != nil { + errChan <- err + } + }() + + want := sqs.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: &responseQueueURL, + } + + select { + case receiveOutput := <-mock.receiveOuputChan: + if len(receiveOutput.Messages) != 1 { + t.Errorf("published %d messages instead of 1", len(receiveOutput.Messages)) + } + if have, exists := receiveOutput.Messages[0].MessageAttributes["responseQueueURL"]; !exists { + t.Errorf("expected MessageAttributes responseQueueURL not found") + } else if *have.StringValue != responseQueueURL || *have.DataType != "String" { + t.Errorf("want %s, have %s", want, *have) + } + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } +} diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go new file mode 100644 index 000000000..478f61f62 --- /dev/null +++ b/transport/awssqs/request_response_func.go @@ -0,0 +1,33 @@ +package awssqs + +import ( + "context" + + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" +) + +// ConsumerRequestFunc may take information from a consumer request result and +// put it into a request context. In Consumers, RequestFuncs are executed prior +// to invoking the endpoint. +// use cases eg. in Consumer : extract message information into context. +type ConsumerRequestFunc func(ctx context.Context, cancel context.CancelFunc, req *sqs.Message) context.Context + +// ProducerRequestFunc may take information from a producer request and put it into a +// request context, or add some informations to SendMessageInput. In Producers, +// RequestFuncs are executed prior to publishing the message but after encoding. +// use cases eg. in Producer : enforce some message attributes to SendMessageInput. +type ProducerRequestFunc func(ctx context.Context, input *sqs.SendMessageInput) context.Context + +// ConsumerResponseFunc may take information from a request context and use it to +// manipulate a Producer. ConsumerResponseFunc are only executed in +// consumers, after invoking the endpoint but prior to publishing a reply. +// use cases eg. : Pipe information from request message to response MessageInput, +// delete msg from queue or update leftMsgs slice. +type ConsumerResponseFunc func(ctx context.Context, cancel context.CancelFunc, req *sqs.Message, resp *sqs.SendMessageInput) context.Context + +// ProducerResponseFunc may take information from an sqs.SendMessageOutput and +// fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. +// ProducerResponseFunc are only executed in producers, after a request has been made, +// but prior to its response being decoded. So this is the perfect place to fetch actual response. +type ProducerResponseFunc func(context.Context, sqsiface.SQSAPI, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error)