diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go index e23d1244ca..20ed547006 100644 --- a/common/authentication/aws/client_test.go +++ b/common/authentication/aws/client_test.go @@ -23,7 +23,7 @@ type mockedSQS struct { GetQueueURLFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) } -func (m *mockedSQS) GetQueueURLWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { +func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { //nolint:stylecheck return m.GetQueueURLFn(ctx, input) } diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index a9fcc17f05..495edb522d 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -458,6 +458,9 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er config = a.cfg.WithRegion(*a.region).WithHTTPClient(client).WithLogLevel(aws.LogOff) } + // this is needed for testing purposes to mock the client, + // so code never sets the client, but tests do. + var rolesClient *rolesanywhere.RolesAnywhere if a.rolesAnywhereClient == nil { mySession = session.Must(session.NewSession(config)) rolesAnywhereClient := rolesanywhere.New(mySession, config) @@ -465,7 +468,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er if err := a.setSigningFunction(rolesAnywhereClient); err != nil { return nil, err } - a.rolesAnywhereClient = rolesAnywhereClient + rolesClient = rolesAnywhereClient } var ( @@ -498,7 +501,7 @@ func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, er } } - output, err := a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + output, err := rolesClient.CreateSessionWithContext(ctx, &createSessionRequest) if err != nil { return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) }