diff --git a/cmd/aws-iam-authenticator/verify.go b/cmd/aws-iam-authenticator/verify.go index 9bb37f4f4..839b19ef8 100644 --- a/cmd/aws-iam-authenticator/verify.go +++ b/cmd/aws-iam-authenticator/verify.go @@ -25,9 +25,9 @@ import ( "sigs.k8s.io/aws-iam-authenticator/pkg/token" - "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -54,14 +54,18 @@ var verifyCmd = &cobra.Command{ os.Exit(1) } - sess := session.Must(session.NewSession()) - ec2metadata := ec2metadata.New(sess) - instanceRegion, err := ec2metadata.Region() + cfg, err := config.LoadDefaultConfig(cmd.Context()) + if err != nil { + fmt.Printf("Error constructing aws config: %v", err) + os.Exit(1) + } + client := imds.NewFromConfig(cfg) + resp, err := client.GetRegion(cmd.Context(), nil) if err != nil { fmt.Printf("[Warn] Region not found in instance metadata, err: %v", err) } - id, err := token.NewVerifier(clusterID, partition, instanceRegion).Verify(tok) + id, err := token.NewVerifier(clusterID, partition, resp.Region).Verify(tok) if err != nil { fmt.Fprintf(os.Stderr, "could not verify token: %v\n", err) os.Exit(1) diff --git a/go.mod b/go.mod index eb704fd2c..8fbd7f9a3 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,20 @@ go 1.22.5 require ( github.com/aws/aws-sdk-go v1.54.6 + github.com/aws/aws-sdk-go-v2 v1.30.4 + github.com/aws/aws-sdk-go-v2/config v1.27.30 + github.com/aws/aws-sdk-go-v2/credentials v1.17.29 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0 + github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 + github.com/aws/smithy-go v1.20.4 github.com/fsnotify/fsnotify v1.7.0 github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.6.0 github.com/manifoldco/promptui v0.9.0 github.com/prometheus/client_golang v1.19.1 
github.com/sirupsen/logrus v1.9.3 + github.com/spf13/afero v1.11.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.18.2 golang.org/x/time v0.5.0 @@ -24,12 +32,20 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/emicklei/go-restful/v3 v3.11.1 // indirect + github.com/emicklei/go-restful/v3 v3.11.3 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.20.2 // indirect @@ -61,7 +77,6 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect diff --git a/go.sum b/go.sum index 2a3bcb3a0..05290a6e2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,33 @@ github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 
h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/aws-sdk-go-v2/config v1.27.30 h1:AQF3/+rOgeJBQP3iI4vojlPib5X6eeOYoa/af7OxAYg= +github.com/aws/aws-sdk-go-v2/config v1.27.30/go.mod h1:yxqvuubha9Vw8stEgNiStO+yZpP68Wm9hLmcm+R/Qk4= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29 h1:CwGsupsXIlAFYuDVHv1nnK0wnxO0wZ/g1L8DSK/xiIw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29/go.mod h1:BPJ/yXV92ZVq6G8uYvbU0gSl8q94UB63nMT5ctNO38g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 h1:yjwoSyDZF8Jth+mUk5lSPJCkMC0lMy6FaCD51jm6ayE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12/go.mod h1:fuR57fAgMk7ot3WcNQfb6rSEn+SUffl7ri+aa8uKysI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 h1:TNyt/+X43KJ9IJJMjKfa3bNTiZbUP7DeCxfbTROESwY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16/go.mod h1:2DwJF39FlNAUiX5pAc0UNeiz16lK2t7IaFcm0LFHEgc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY0L1/KftReOGxI/4NtVSTh9O/I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0 h1:fWhkSvaQqa5eWiRwBw10FUnk1YatAQ9We4GdGxKiCtg= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0/go.mod h1:ISODge3zgdwOEa4Ou6WM9PKbxJWJ15DYKnr2bfmCAIA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbLPPHEmf9DgMGw51jMj77VfGPAN2Kv4cfhlfgI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHCiSH0jyd6gROjlJtNwov0eGYNz8s8nFcR0jQ= 
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 h1:zCsFCKvbj25i7p1u94imVoO447I/sFv8qq+lGJhRN0c= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5/go.mod h1:ZeDX1SnKsVlejeuz41GiajjZpRSWR7/42q/EyA/QEiM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 h1:SKvPgvdvmiTWoi0GAJ7AsJfOz3ngVkD/ERbs5pUnHNI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5/go.mod h1:20sz31hv/WsPa3HhU3hfrIet2kxM4Pe0r20eBZ20Tac= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 h1:OMsEmCyz2i89XwRwPouAJvhj81wINh+4UK+k/0Yo/q8= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5/go.mod h1:vmSqFK+BVIwVpDAGZB3CoCXHzurt4qBE8lf+I/kRTh0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= @@ -17,8 +45,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/emicklei/go-restful/v3 v3.11.1 h1:S+9bSbua1z3FgCnV0KKOSSZ3mDthb5NyEPL5gEpCvyk= -github.com/emicklei/go-restful/v3 v3.11.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.11.3 h1:yagOQz/38xJmcNeZJtrUcKjkHRltIaIFXKWeG1SkWGE= +github.com/emicklei/go-restful/v3 v3.11.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/frankban/quicktest v1.14.6 
h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= diff --git a/pkg/arn/arn.go b/pkg/arn/arn.go index e9b73b587..22900c96d 100644 --- a/pkg/arn/arn.go +++ b/pkg/arn/arn.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - awsarn "github.com/aws/aws-sdk-go/aws/arn" + awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go/aws/endpoints" ) diff --git a/pkg/ec2provider/ec2provider.go b/pkg/ec2provider/ec2provider.go index d760f0bda..5d18ec720 100644 --- a/pkg/ec2provider/ec2provider.go +++ b/pkg/ec2provider/ec2provider.go @@ -1,25 +1,24 @@ package ec2provider import ( + "context" "errors" "fmt" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/sirupsen/logrus" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/httputil" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" + "github.com/sirupsen/logrus" ) const ( @@ -36,11 +35,6 @@ const ( // Maximum time in Milliseconds to wait for a new batch call this also depends on if the instance size has // already become 100 then it will not respect this limit maxWaitIntervalForBatch = 200 - - // 
Headers for STS request for source ARN - headerSourceArn = "x-amz-source-arn" - // Headers for STS request for source account - headerSourceAccount = "x-amz-source-account" ) // Get a node name from instance ID @@ -60,13 +54,13 @@ type ec2Requests struct { } type ec2ProviderImpl struct { - ec2 ec2iface.EC2API + ec2 ec2.DescribeInstancesAPIClient privateDNSCache ec2PrivateDNSCache ec2Requests ec2Requests instanceIdsChannel chan string } -func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider { +func New(roleARN, sourceARN, region string, qps int, burst int) (EC2Provider, error) { dnsCache := ec2PrivateDNSCache{ cache: make(map[string]string), lock: sync.RWMutex{}, @@ -75,50 +69,56 @@ func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider { set: make(map[string]bool), lock: sync.RWMutex{}, } + cfg, err := newConfig(roleARN, sourceARN, region, qps, burst) + if err != nil { + return nil, err + } + return &ec2ProviderImpl{ - ec2: ec2.New(newSession(roleARN, sourceARN, region, qps, burst)), + ec2: ec2.NewFromConfig(cfg), privateDNSCache: dnsCache, ec2Requests: ec2Requests, instanceIdsChannel: make(chan string, maxChannelSize), - } + }, nil } -// Initial credentials loaded from SDK's default credential chain, such as -// the environment, shared credentials (~/.aws/credentials), or EC2 Instance -// Role. 
- -func newSession(roleARN, sourceARN, region string, qps int, burst int) *session.Session { - sess := session.Must(session.NewSession()) - sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ - Name: "authenticatorUserAgent", - Fn: request.MakeAddToUserAgentHandler( - "aws-iam-authenticator", pkg.Version), - }) - if aws.StringValue(sess.Config.Region) == "" { - sess.Config.Region = aws.String(region) +func newConfig(roleARN, sourceArn, region string, qps, burst int) (aws.Config, error) { + rateLimitedClient, err := httputil.NewRateLimitedClient(qps, burst) + if err != nil { + logrus.Errorf("error creating rate limited client %s", err) + return aws.Config{}, err + } + loadOpts := []func(*config.LoadOptions) error{ + config.WithRegion(region), + config.WithAPIOptions( + []func(*smithymiddleware.Stack) error{ + middleware.AddUserAgentKeyValue("aws-iam-authenticator", pkg.Version), + }), + config.WithHTTPClient(rateLimitedClient), } - if roleARN != "" { logrus.WithFields(logrus.Fields{ "roleARN": roleARN, }).Infof("Using assumed role for EC2 API") - rateLimitedClient, err := httputil.NewRateLimitedClient(qps, burst) - + cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...) if err != nil { - logrus.Errorf("Getting error = %s while creating rate limited client ", err) + logrus.Errorf("error loading AWS config %s", err) + return aws.Config{}, err } - - stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), sourceARN) - ap := &stscreds.AssumeRoleProvider{ - Client: stsClient, - RoleARN: roleARN, - Duration: time.Duration(60) * time.Minute, + stsOpts := []func(*sts.Options){} + if sourceArn != "" { + stsOpts = append(stsOpts, WithSourceHeaders(sourceArn)) } - sess.Config.Credentials = credentials.NewCredentials(ap) + stsCli := sts.NewFromConfig(cfg, stsOpts...) 
+ creds := stscreds.NewAssumeRoleProvider(stsCli, roleARN, + func(o *stscreds.AssumeRoleOptions) { + o.Duration = time.Duration(60) * time.Minute + }) + loadOpts = append(loadOpts, config.WithCredentialsProvider(creds)) } - return sess + return config.LoadDefaultConfig(context.Background(), loadOpts...) } func (p *ec2ProviderImpl) setPrivateDNSNameCache(id string, privateDNSName string) { @@ -197,8 +197,8 @@ func (p *ec2ProviderImpl) GetPrivateDNSName(id string) (string, error) { logrus.Infof("Calling ec2:DescribeInstances for the InstanceId = %s ", id) metrics.Get().EC2DescribeInstanceCallCount.Inc() // Look up instance from EC2 API - output, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice([]string{id}), + output, err := p.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ + InstanceIds: []string{id}, }) if err != nil { p.unsetRequestInFlightForInstanceId(id) @@ -206,8 +206,8 @@ func (p *ec2ProviderImpl) GetPrivateDNSName(id string) (string, error) { } for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { - if aws.StringValue(instance.InstanceId) == id { - privateDNSName = aws.StringValue(instance.PrivateDnsName) + if aws.ToString(instance.InstanceId) == id { + privateDNSName = aws.ToString(instance.PrivateDnsName) p.setPrivateDNSNameCache(id, privateDNSName) p.unsetRequestInFlightForInstanceId(id) } @@ -258,8 +258,8 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string // Look up instance from EC2 API logrus.Infof("Making Batch Query to DescribeInstances for %v instances ", len(instanceIdList)) metrics.Get().EC2DescribeInstanceCallCount.Inc() - output, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice(instanceIdList), + output, err := p.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ + InstanceIds: instanceIdList, }) if err != nil { logrus.Errorf("Batch call 
failed querying private DNS from EC2 API for nodes [%s] : with error = []%s ", instanceIdList, err.Error()) @@ -272,8 +272,8 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string // Adding the result to privateDNSChache as well as removing from the requestQueueMap. for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { - id := aws.StringValue(instance.InstanceId) - privateDNSName := aws.StringValue(instance.PrivateDnsName) + id := aws.ToString(instance.InstanceId) + privateDNSName := aws.ToString(instance.PrivateDnsName) p.setPrivateDNSNameCache(id, privateDNSName) } } @@ -284,40 +284,3 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string p.unsetRequestInFlightForInstanceId(id) } } - -func applySTSRequestHeaders(stsClient *sts.STS, sourceARN string) *sts.STS { - // parse both source account and source arn from the sourceARN, and add them as headers to the STS client - if sourceARN != "" { - sourceAcct, err := getSourceAccount(sourceARN) - if err != nil { - panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceARN, err)) - } - reqHeaders := map[string]string{ - headerSourceAccount: sourceAcct, - headerSourceArn: sourceARN, - } - stsClient.Handlers.Sign.PushFront(func(s *request.Request) { - s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) - }) - logrus.Infof("configuring STS client with extra headers, %v", reqHeaders) - } - return stsClient -} - -// getSourceAccount constructs source acct and return them for use -func getSourceAccount(roleARN string) (string, error) { - // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) - // arn:partition:service:region:account-id:resource-type/resource-id - // IAM format, region is always blank - // arn:aws:iam::account:role/role-name-with-path - if !arn.IsARN(roleARN) { - return "", fmt.Errorf("incorrect ARN format for role %s", roleARN) - } - - parsedArn, err := arn.Parse(roleARN) 
- if err != nil { - return "", err - } - - return parsedArn.AccountID, nil -} diff --git a/pkg/ec2provider/ec2provider_test.go b/pkg/ec2provider/ec2provider_test.go index 912d73c8c..21ebb94aa 100644 --- a/pkg/ec2provider/ec2provider_test.go +++ b/pkg/ec2provider/ec2provider_test.go @@ -1,14 +1,15 @@ package ec2provider import ( + "context" "strconv" "sync" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/prometheus/client_golang/prometheus" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -18,25 +19,24 @@ const ( ) type mockEc2Client struct { - ec2iface.EC2API - Reservations []*ec2.Reservation + Reservations []ec2types.Reservation } -func (c *mockEc2Client) DescribeInstances(in *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { +func (c *mockEc2Client) DescribeInstances(ctx context.Context, in *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { // simulate the time it takes for aws to return time.Sleep(DescribeDelay * time.Millisecond) - var reservations []*ec2.Reservation + var reservations []ec2types.Reservation for _, res := range c.Reservations { - var reservation ec2.Reservation + var reservation ec2types.Reservation for _, inst := range res.Instances { for _, id := range in.InstanceIds { - if aws.StringValue(id) == aws.StringValue(inst.InstanceId) { + if id == aws.ToString(inst.InstanceId) { reservation.Instances = append(reservation.Instances, inst) } } } if len(reservation.Instances) > 0 { - reservations = append(reservations, &reservation) + reservations = append(reservations, reservation) } } return &ec2.DescribeInstancesOutput{ @@ -76,12 +76,12 @@ func TestGetPrivateDNSName(t *testing.T) { } } -func prepareSingleInstanceOutput() []*ec2.Reservation 
{ - reservations := []*ec2.Reservation{ +func prepareSingleInstanceOutput() []ec2types.Reservation { + reservations := []ec2types.Reservation{ { Groups: nil, - Instances: []*ec2.Instance{ - &ec2.Instance{ + Instances: []ec2types.Instance{ + ec2types.Instance{ InstanceId: aws.String("ec2-1"), PrivateDnsName: aws.String("ec2-dns-1"), }, @@ -125,20 +125,20 @@ func getPrivateDNSName(ec2provider *ec2ProviderImpl, instanceString string, dnsS } } -func prepare100InstanceOutput() []*ec2.Reservation { +func prepare100InstanceOutput() []ec2types.Reservation { - var reservations []*ec2.Reservation + var reservations []ec2types.Reservation for i := 1; i < 101; i++ { instanceString := "ec2-" + strconv.Itoa(i) dnsString := "ec2-dns-" + strconv.Itoa(i) - instance := &ec2.Instance{ + instance := ec2types.Instance{ InstanceId: aws.String(instanceString), PrivateDnsName: aws.String(dnsString), } - var instances []*ec2.Instance + var instances []ec2types.Instance instances = append(instances, instance) - res1 := &ec2.Reservation{ + res1 := ec2types.Reservation{ Groups: nil, Instances: instances, OwnerId: nil, @@ -150,44 +150,3 @@ func prepare100InstanceOutput() []*ec2.Reservation { return reservations } - -func TestGetSourceAcctAndArn(t *testing.T) { - type args struct { - roleARN string - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "corect role arn", - args: args{ - roleARN: "arn:aws:iam::123456789876:role/test-cluster", - }, - want: "123456789876", - wantErr: false, - }, - { - name: "incorect role arn", - args: args{ - roleARN: "arn:aws:iam::123456789876", - }, - want: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getSourceAccount(tt.args.roleARN) - if (err != nil) != tt.wantErr { - t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want) - } - }) - } -} 
diff --git a/pkg/ec2provider/source_headers.go b/pkg/ec2provider/source_headers.go new file mode 100644 index 000000000..f2473a689 --- /dev/null +++ b/pkg/ec2provider/source_headers.go @@ -0,0 +1,66 @@ +package ec2provider + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws/arn" + smithyhttp "github.com/aws/smithy-go/transport/http" + + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" +) + +const ( + // Headers for STS request for source ARN + headerSourceArn = "x-amz-source-arn" + // Headers for STS request for source account + headerSourceAccount = "x-amz-source-account" +) + +type withSourceHeaders struct { + sourceARN string +} + +// implements middleware.BuildMiddleware, which runs AFTER a request has been +// serialized and can operate on the transport request +var _ smithymiddleware.BuildMiddleware = (*withSourceHeaders)(nil) + +func (*withSourceHeaders) ID() string { + return "withSourceHeaders" +} + +func (m *withSourceHeaders) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) ( + out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unrecognized transport type %T", in.Request) + } + + if arn.IsARN(m.sourceARN) { + req.Header.Set(headerSourceArn, m.sourceARN) + } + + if parsedArn, err := arn.Parse(m.sourceARN); err == nil && parsedArn.AccountID != "" { + req.Header.Set(headerSourceAccount, parsedArn.AccountID) + } + + return next.HandleBuild(ctx, in) +} + +// WithSourceHeaders adds the x-amz-source-arn and x-amz-source-account headers to the request. +// These can be referenced in an IAM role trust policy document with the condition keys +// aws:SourceArn and aws:SourceAccount for sts:AssumeRole calls +// +// If the sourceARN is invalid, the source arn header is skipped. 
If the ARN is valid but doesn't +// contain an account ID, the source account header is skipped +func WithSourceHeaders(sourceARN string) func(*sts.Options) { + return func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, func(s *smithymiddleware.Stack) error { + return s.Build.Add(&withSourceHeaders{ + sourceARN: sourceARN, + }, smithymiddleware.After) + }) + } +} diff --git a/pkg/filecache/filecache.go b/pkg/filecache/filecache.go new file mode 100644 index 000000000..3d8624a2c --- /dev/null +++ b/pkg/filecache/filecache.go @@ -0,0 +1,272 @@ +package filecache + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "runtime" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/gofrs/flock" + "github.com/spf13/afero" + "gopkg.in/yaml.v2" +) + +// env variable name for custom credential cache file location +const cacheFileNameEnv = "AWS_IAM_AUTHENTICATOR_CACHE_FILE" + +// FileLocker is a subset of the methods exposed by *flock.Flock +type FileLocker interface { + Unlock() error + TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) + TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) +} + +// NewFileLocker returns a *flock.Flock that satisfies FileLocker +func NewFileLocker(filename string) FileLocker { + return flock.New(filename) +} + +// cacheFile is a map of clusterID/roleARNs to cached credentials +type cacheFile struct { + // a map of clusterIDs/profiles/roleARNs to cachedCredentials + ClusterMap map[string]map[string]map[string]aws.Credentials `yaml:"clusters"` +} + +// a utility type for dealing with compound cache keys +type cacheKey struct { + clusterID string + profile string + roleARN string +} + +func (c *cacheFile) Put(key cacheKey, credential aws.Credentials) { + if _, ok := c.ClusterMap[key.clusterID]; !ok { + // first use of this cluster id + c.ClusterMap[key.clusterID] = map[string]map[string]aws.Credentials{} + } + if _, ok := 
c.ClusterMap[key.clusterID][key.profile]; !ok { + // first use of this profile + c.ClusterMap[key.clusterID][key.profile] = map[string]aws.Credentials{} + } + c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential +} + +func (c *cacheFile) Get(key cacheKey) (credential aws.Credentials) { + if _, ok := c.ClusterMap[key.clusterID]; ok { + if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok { + // we at least have this cluster and profile combo in the map, if no matching roleARN, map will + // return the zero-value for cachedCredential, which expired a long time ago. + credential = c.ClusterMap[key.clusterID][key.profile][key.roleARN] + } + } + return +} + +// readCacheWhileLocked reads the contents of the credential cache and returns the +// parsed yaml as a cacheFile object. This method must be called while a shared +// lock is held on the filename. +func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { + cache = cacheFile{ + map[string]map[string]map[string]aws.Credentials{}, + } + data, err := afero.ReadFile(fs, filename) + if err != nil { + err = fmt.Errorf("unable to open file %s: %v", filename, err) + return + } + + err = yaml.Unmarshal(data, &cache) + if err != nil { + err = fmt.Errorf("unable to parse file %s: %v", filename, err) + } + return +} + +// writeCacheWhileLocked writes the contents of the credential cache using the +// yaml marshaled form of the passed cacheFile object. This method must be +// called while an exclusive lock is held on the filename. 
+func writeCacheWhileLocked(fs afero.Fs, filename string, cache cacheFile) error { + data, err := yaml.Marshal(cache) + if err == nil { + // write privately owned by the user + err = afero.WriteFile(fs, filename, data, 0600) + } + return err +} + +type FileCacheOpt func(*FileCacheProvider) + +// WithFs returns a FileCacheOpt that sets the cache's filesystem +func WithFs(fs afero.Fs) FileCacheOpt { + return func(p *FileCacheProvider) { + p.fs = fs + } +} + +// WithFilename returns a FileCacheOpt that sets the cache's file +func WithFilename(filename string) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filename = filename + } +} + +// WithFileLockCreator returns a FileCacheOpt that sets the cache's FileLocker +// creation function +func WithFileLockerCreator(f func(string) FileLocker) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filelockCreator = f + } +} + +// FileCacheProvider is a credentials.Provider implementation that wraps an underlying Provider +// (contained in Credentials) and provides caching support for credentials for the +// specified clusterID, profile, and roleARN (contained in cacheKey) +type FileCacheProvider struct { + fs afero.Fs + filelockCreator func(string) FileLocker + filename string + provider aws.CredentialsProvider // the underlying implementation that has the *real* Provider + cacheKey cacheKey // cache key parameters used to create Provider + cachedCredential aws.Credentials // the cached credential, if it exists +} + +var _ aws.CredentialsProvider = &FileCacheProvider{} + +// NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, +// and works with an on disk cache to speed up credential usage when the cached copy is not expired. +// If there are any problems accessing or initializing the cache, an error will be returned, and +// callers should just use the existing credentials provider. 
+func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.CredentialsProvider, opts ...FileCacheOpt) (*FileCacheProvider, error) { + if provider == nil { + return nil, errors.New("no underlying Credentials object provided") + } + + resp := &FileCacheProvider{ + fs: afero.NewOsFs(), + filelockCreator: NewFileLocker, + filename: defaultCacheFilename(), + provider: provider, + cacheKey: cacheKey{clusterID, profile, roleARN}, + cachedCredential: aws.Credentials{}, + } + + // override defaults + for _, opt := range opts { + opt(resp) + } + + // ensure path to cache file exists + _ = resp.fs.MkdirAll(filepath.Dir(resp.filename), 0700) + if info, err := resp.fs.Stat(resp.filename); err == nil { + if info.Mode()&0077 != 0 { + // cache file has secret credentials and should only be accessible to the user, refuse to use it. + return nil, fmt.Errorf("cache file %s is not private", resp.filename) + } + + // do file locking on cache to prevent inconsistent reads + lock := resp.filelockCreator(resp.filename) + defer lock.Unlock() + // wait up to a second for the file to lock + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + ok, err := lock.TryRLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second + if !ok { + // unable to lock the cache, something is wrong, refuse to use it. + return nil, fmt.Errorf("unable to read lock file %s: %v", resp.filename, err) + } + + cache, err := readCacheWhileLocked(resp.fs, resp.filename) + if err != nil { + // can't read or parse cache, refuse to use it. + return nil, err + } + + resp.cachedCredential = cache.Get(resp.cacheKey) + } else { + if errors.Is(err, fs.ErrNotExist) { + // cache file is missing. maybe this is the very first run? continue to use cache. 
+ _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", resp.filename) + } else { + return nil, fmt.Errorf("couldn't stat cache file: %w", err) + } + } + + return resp, nil +} + +// Retrieve() implements the aws.CredentialsProvider interface, returning the cached credential if is not expired, +// otherwise fetching the credential from the underlying CredentialProvider and caching the results on disk +// with an expiration time. +func (f *FileCacheProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() { + // use the cached credential + return f.cachedCredential, nil + } else { + _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") + // fetch the credentials from the underlying Provider + credential, err := f.provider.Retrieve(ctx) + if err != nil { + return credential, err + } + + if credential.CanExpire { + // Credential supports expiration, so we can cache + + // do file locking on cache to prevent inconsistent writes + lock := f.filelockCreator(f.filename) + defer lock.Unlock() + // wait up to a second for the file to lock + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second + if !ok { + // can't get write lock to create/update cache, but still return the credential + _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) + return credential, nil + } + f.cachedCredential = credential + // don't really care about read error. Either read the cache, or we create a new cache. 
+ cache, _ := readCacheWhileLocked(f.fs, f.filename) + cache.Put(f.cacheKey, f.cachedCredential) + err = writeCacheWhileLocked(f.fs, f.filename, cache) + if err != nil { + // can't write cache, but still return the credential + _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", f.filename, err) + err = nil + } else { + _, _ = fmt.Fprintf(os.Stderr, "Updated cached credential\n") + } + } else { + // credential doesn't support expiration time, so can't cache, but still return the credential + _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) + err = nil + } + return credential, err + } +} + +// defaultCacheFilename returns the name of the credential cache file, which can either be +// set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml +func defaultCacheFilename() string { + if filename := os.Getenv(cacheFileNameEnv); filename != "" { + return filename + } else { + return filepath.Join(userHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") + } +} + +// userHomeDir returns the home directory for the user the process is +// running under. 
+func userHomeDir() string {
+	if runtime.GOOS == "windows" { // Windows
+		return os.Getenv("USERPROFILE")
+	}
+
+	// *nix
+	return os.Getenv("HOME")
+}
diff --git a/pkg/filecache/filecache_test.go b/pkg/filecache/filecache_test.go
new file mode 100644
index 000000000..7c8fbaafd
--- /dev/null
+++ b/pkg/filecache/filecache_test.go
@@ -0,0 +1,576 @@
+package filecache
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io/fs"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/google/go-cmp/cmp"
+	"github.com/spf13/afero"
+)
+
+const (
+	testFilename = "/test.yaml"
+)
+
+// stubProvider implements aws.CredentialsProvider with configurable response values
+type stubProvider struct {
+	creds aws.Credentials
+	err   error
+}
+
+var _ aws.CredentialsProvider = &stubProvider{}
+
+func (s *stubProvider) Retrieve(_ context.Context) (aws.Credentials, error) {
+	s.creds.Source = "stubProvider"
+	return s.creds, s.err
+}
+
+// testFileInfo implements fs.FileInfo with configurable response values
+type testFileInfo struct {
+	name    string
+	size    int64
+	mode    fs.FileMode
+	modTime time.Time
+}
+
+var _ fs.FileInfo = &testFileInfo{}
+
+func (fs *testFileInfo) Name() string       { return fs.name }
+func (fs *testFileInfo) Size() int64        { return fs.size }
+func (fs *testFileInfo) Mode() fs.FileMode  { return fs.mode }
+func (fs *testFileInfo) ModTime() time.Time { return fs.modTime }
+func (fs *testFileInfo) IsDir() bool        { return fs.Mode().IsDir() }
+func (fs *testFileInfo) Sys() interface{}   { return nil }
+
+// testFS wraps afero.Fs with an overridable Stat() method
+type testFS struct {
+	afero.Fs
+
+	fileinfo fs.FileInfo
+	err      error
+}
+
+func (t *testFS) Stat(filename string) (fs.FileInfo, error) {
+	if t.err != nil {
+		return nil, t.err
+	}
+	if t.fileinfo != nil {
+		return t.fileinfo, nil
+	}
+	return t.Fs.Stat(filename)
+}
+
+// testFileLock implements FileLocker with configurable response options
+type testFilelock struct {
+	ctx        context.Context
+	retryDelay 
time.Duration
+	success    bool
+	err        error
+}
+
+var _ FileLocker = &testFilelock{}
+
+func (l *testFilelock) Unlock() error {
+	return nil
+}
+
+func (l *testFilelock) TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) {
+	l.ctx = ctx
+	l.retryDelay = retryDelay
+	return l.success, l.err
+}
+
+func (l *testFilelock) TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) {
+	l.ctx = ctx
+	l.retryDelay = retryDelay
+	return l.success, l.err
+}
+
+// getMocks returns a mocked filesystem and FileLocker
+func getMocks() (*testFS, *testFilelock) {
+	return &testFS{Fs: afero.NewMemMapFs()}, &testFilelock{context.TODO(), 0, true, nil}
+}
+
+// makeCredential returns a dummy AWS credential
+func makeCredential() aws.Credentials {
+	return aws.Credentials{
+		AccessKeyID:     "AKID",
+		SecretAccessKey: "SECRET",
+		SessionToken:    "TOKEN",
+		Source:          "stubProvider",
+		CanExpire:       false,
+	}
+}
+
+func makeExpiringCredential(e time.Time) aws.Credentials {
+	return aws.Credentials{
+		AccessKeyID:     "AKID",
+		SecretAccessKey: "SECRET",
+		SessionToken:    "TOKEN",
+		Source:          "stubProvider",
+		CanExpire:       true,
+		Expires:         e,
+	}
+}
+
+// validateFileCacheProvider ensures that the cache provider is properly initialized
+func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c aws.CredentialsProvider) {
+	t.Helper()
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	if p.provider != c {
+		t.Errorf("Credentials not copied")
+	}
+	if p.cacheKey.clusterID != "CLUSTER" {
+		t.Errorf("clusterID not copied")
+	}
+	if p.cacheKey.profile != "PROFILE" {
+		t.Errorf("profile not copied")
+	}
+	if p.cacheKey.roleARN != "ARN" {
+		t.Errorf("roleARN not copied")
+	}
+}
+
+// testSetEnv sets an env var, and returns a cleanup func
+func testSetEnv(t *testing.T, key, value string) func() {
+	t.Helper()
+	old := os.Getenv(key)
+	os.Setenv(key, value)
+	return func() {
+		if old == "" {
+			os.Unsetenv(key)
+		} else {
+			os.Setenv(key, old)
+		
} + } +} + +func TestCacheFilename(t *testing.T) { + + c1 := testSetEnv(t, "HOME", "homedir") + defer c1() + c2 := testSetEnv(t, "USERPROFILE", "homedir") + defer c2() + + filename := defaultCacheFilename() + expected := "homedir/.kube/cache/aws-iam-authenticator/credentials.yaml" + if filename != expected { + t.Errorf("Incorrect default cacheFilename, expected %s, got %s", + expected, filename) + } + + c3 := testSetEnv(t, "AWS_IAM_AUTHENTICATOR_CACHE_FILE", "special.yaml") + defer c3() + filename = defaultCacheFilename() + expected = "special.yaml" + if filename != expected { + t.Errorf("Incorrect custom cacheFilename, expected %s, got %s", + expected, filename) + } +} + +func TestNewFileCacheProvider_Missing(t *testing.T) { + provider := &stubProvider{} + + tfs, tfl := getMocks() + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + })) + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing cache file should result in empty cached credential") + } +} + +func TestNewFileCacheProvider_BadPermissions(t *testing.T) { + provider := &stubProvider{} + + tfs, _ := getMocks() + // afero.MemMapFs always returns tempfile FileInfo, + // so we manually set the response to the Stat() call + tfs.fileinfo = &testFileInfo{mode: 0777} + + // bad permissions + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + ) + if err == nil { + t.Errorf("Expected error due to public permissions") + } + wantMsg := fmt.Sprintf("cache file %s is not private", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) + } +} + +func TestNewFileCacheProvider_Unlockable(t *testing.T) { + provider := &stubProvider{} + + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // unable to lock + 
tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) + if err == nil { + t.Errorf("Expected error due to lock failure") + } +} + +func TestNewFileCacheProvider_Unreadable(t *testing.T) { + provider := &stubProvider{} + + tfs, tfl := getMocks() + tfs.Create(testFilename) + tfl.err = fmt.Errorf("open %s: permission denied", testFilename) + tfl.success = false + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) + if err == nil { + t.Errorf("Expected error due to read failure") + return + } + wantMsg := fmt.Sprintf("unable to read lock file %s: open %s: permission denied", testFilename, testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) + } +} + +func TestNewFileCacheProvider_Unparseable(t *testing.T) { + provider := &stubProvider{} + + tfs, tfl := getMocks() + tfs.Create(testFilename) + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + afero.WriteFile( + tfs, + testFilename, + []byte("invalid: yaml: file"), + 0700) + return tfl + }), + ) + if err == nil { + t.Errorf("Expected error due to bad yaml") + } + wantMsg := fmt.Sprintf("unable to parse file %s: yaml: mapping values are not allowed in this context", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) + } +} + +func TestNewFileCacheProvider_Empty(t *testing.T) { + provider := &stubProvider{} + + tfs, tfl := getMocks() + + // successfully parse existing but empty cache file + p, err := NewFileCacheProvider("CLUSTER", 
"PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("empty cache file should result in empty cached credential") + } +} + +func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { + provider := &stubProvider{} + + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // successfully parse existing cluster without matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: + CLUSTER: + PROFILE2: {} +`), + 0700) + return tfl + }), + ) + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing profile in cache file should result in empty cached credential") + } +} + +func TestNewFileCacheProvider_ExistingARN(t *testing.T) { + provider := &stubProvider{} + + expiry := time.Now().Add(time.Hour * 6) + content := []byte(`clusters: + CLUSTER: + PROFILE: + ARN: + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + expires: ` + expiry.Format(time.RFC3339Nano) + ` +`) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // successfully parse cluster with matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + }), + ) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.AccessKeyID != "ABC" || p.cachedCredential.SecretAccessKey != "DEF" 
|| + p.cachedCredential.SessionToken != "GHI" || p.cachedCredential.Source != "JKL" { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) + } + + if p.cachedCredential.Expired() { + t.Errorf("Cached credential should not be expired") + } + +} + +func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { + provider := &stubProvider{ + creds: makeCredential(), + } + + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) + validateFileCacheProvider(t, p, err, provider) + + credential, err := p.Retrieve(context.Background()) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken { + t.Errorf("Cache did not return provider credential, got %v, expected %v", + credential, provider.creds) + } +} + +func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } + + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) + validateFileCacheProvider(t, p, err, provider) + + // retrieve credential, which will fetch from underlying Provider + // fail to get write lock + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + + credential, err := p.Retrieve(context.Background()) + if err != nil { + 
t.Errorf("Unexpected error: %v", err) + } + if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" || + credential.SessionToken != "TOKEN" || credential.Source != "stubProvider" { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) + } +} + +func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } + + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) + validateFileCacheProvider(t, p, err, provider) + + credential, err := p.Retrieve(context.Background()) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.Source != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) + } + + expectedData := []byte(`clusters: + CLUSTER: + PROFILE: + ARN: + accesskeyid: AKID + secretaccesskey: SECRET + sessiontoken: TOKEN + source: stubProvider + canexpire: true + expires: ` + expires.Format(time.RFC3339Nano) + ` + accountid: "" +`) + got, err := afero.ReadFile(tfs, testFilename) + if err != nil { + t.Errorf("unexpected error reading generated file: %v", err) + } + if diff := cmp.Diff(got, expectedData); diff != "" { + t.Errorf("Wrong data written to cache, %s", diff) + } +} + +func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } + + tfs, tfl := getMocks() + 
// don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) + validateFileCacheProvider(t, p, err, provider) + + // retrieve credential, which will fetch from underlying Provider + // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, + // but write to disk (code coverage) + credential, err := p.Retrieve(context.Background()) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.Source != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) + } +} + +func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { + provider := &stubProvider{} + currentTime := time.Now() + + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // successfully parse cluster with matching arn + content := []byte(`clusters: + CLUSTER: + PROFILE: + ARN: + credential: + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + canexpire: true + expires: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` +`) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + })) + validateFileCacheProvider(t, p, err, provider) + + credential, err := p.Retrieve(context.Background()) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if credential.AccessKeyID != "ABC" || credential.SecretAccessKey != "DEF" || + credential.SessionToken != "GHI" || credential.Source != "JKL" || + 
!credential.Expires.Equal(currentTime.Add(time.Hour*6)) { + t.Errorf("cached credential not returned") + } + +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 045f948c2..ce47f0f7d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -28,8 +28,6 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/session" "sigs.k8s.io/aws-iam-authenticator/pkg/config" "sigs.k8s.io/aws-iam-authenticator/pkg/ec2provider" "sigs.k8s.io/aws-iam-authenticator/pkg/errutil" @@ -42,7 +40,9 @@ import ( "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" "sigs.k8s.io/aws-iam-authenticator/pkg/token" - awsarn "github.com/aws/aws-sdk-go/aws/arn" + awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" authenticationv1beta1 "k8s.io/api/authentication/v1beta1" @@ -88,7 +88,7 @@ func New(cfg config.Config, stopCh <-chan struct{}) *Server { backendMapper, err := BuildMapperChain(cfg, cfg.BackendMode) if err != nil { - logrus.Fatalf("failed to build mapper chain: %v", err) + logrus.WithError(err).Fatal("failed to build mapper chain") } for _, mapping := range c.RoleMappings { @@ -144,7 +144,11 @@ func New(cfg config.Config, stopCh <-chan struct{}) *Server { logrus.Infof("listening on %s", listener.Addr()) logrus.Infof("reconfigure your apiserver with `--authentication-token-webhook-config-file=%s` to enable (assuming default hostPath mounts)", c.GenerateKubeconfigPath) - internalHandler := c.getHandler(backendMapper, c.EC2DescribeInstancesQps, c.EC2DescribeInstancesBurst, stopCh) + internalHandler, err := c.getHandler(backendMapper, c.EC2DescribeInstancesQps, c.EC2DescribeInstancesBurst, stopCh) + if err != nil { + logrus.WithError(err).Fatal("Failed to create handlers") + } + c.httpServer = http.Server{ ErrorLog: log.New(errLog, "", 0), Handler: 
internalHandler, @@ -191,23 +195,35 @@ type healthzHandler struct{} func (m *healthzHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "ok") } -func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2DescribeBurst int, stopCh <-chan struct{}) *handler { + +func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2DescribeBurst int, stopCh <-chan struct{}) (*handler, error) { if c.ServerEC2DescribeInstancesRoleARN != "" { _, err := awsarn.Parse(c.ServerEC2DescribeInstancesRoleARN) if err != nil { - panic(fmt.Sprintf("describeinstancesrole %s is not a valid arn", c.ServerEC2DescribeInstancesRoleARN)) + logrus.WithError(err).Errorf("describeinstancesrole %s is not a valid arn", c.ServerEC2DescribeInstancesRoleARN) + return nil, err } } - sess := session.Must(session.NewSession()) - ec2metadata := ec2metadata.New(sess) - instanceRegion, err := ec2metadata.Region() + ctx := context.Background() + cfg, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + logrus.WithError(err).Error("EC2 instance metadata not configured") + } + cli := imds.NewFromConfig(cfg) + resp, err := cli.GetRegion(ctx, nil) + if err != nil { + logrus.WithError(err).Error("region not found in instance metadata.") + } + + ec2Prov, err := ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, resp.Region, ec2DescribeQps, ec2DescribeBurst) if err != nil { - logrus.WithError(err).Errorln("Region not found in instance metadata.") + logrus.WithError(err).Errorln("error initializing EC2 provider") + return nil, err } h := &handler{ - verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion), - ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst), + verifier: token.NewVerifier(c.ClusterID, c.PartitionID, resp.Region), + ec2Provider: ec2Prov, clusterID: c.ClusterID, backendMapper: backendMapper, scrubbedAccounts: 
c.Config.ScrubbedAWSAccounts,
@@ -226,7 +242,7 @@ func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2
 		fileutil.StartLoadDynamicFile(c.DynamicBackendModePath, h, stopCh)
 	}
 
-	return h
+	return h, nil
 }
 
 func BuildMapperChain(cfg config.Config, modes []string) (BackendMapper, error) {
diff --git a/pkg/token/cluster_id_header.go b/pkg/token/cluster_id_header.go
new file mode 100644
index 000000000..1f6a43d32
--- /dev/null
+++ b/pkg/token/cluster_id_header.go
@@ -0,0 +1,45 @@
+package token
+
+import (
+	"context"
+	"fmt"
+
+	smithyhttp "github.com/aws/smithy-go/transport/http"
+
+	"github.com/aws/aws-sdk-go-v2/service/sts"
+	smithymiddleware "github.com/aws/smithy-go/middleware"
+)
+
+type withClusterIDHeader struct {
+	clusterID string
+}
+
+// implements middleware.BuildMiddleware, which runs AFTER a request has been
+// serialized and can operate on the transport request
+var _ smithymiddleware.BuildMiddleware = (*withClusterIDHeader)(nil)
+
+func (*withClusterIDHeader) ID() string {
+	return "withClusterIDHeader"
+}
+
+func (m *withClusterIDHeader) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) (
+	out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error,
+) {
+	req, ok := in.Request.(*smithyhttp.Request)
+	if !ok {
+		return out, metadata, fmt.Errorf("unrecognized transport type %T", in.Request)
+	}
+	req.Header.Set(clusterIDHeader, m.clusterID)
+	return next.HandleBuild(ctx, in)
+}
+
+// WithClusterIDHeader adds the clusterID header to the request before signing
+func WithClusterIDHeader(clusterID string) func(*sts.Options) {
+	return func(o *sts.Options) {
+		o.APIOptions = append(o.APIOptions, func(s *smithymiddleware.Stack) error {
+			return s.Build.Add(&withClusterIDHeader{
+				clusterID: clusterID,
+			}, smithymiddleware.After)
+		})
+	}
+}
diff --git a/pkg/token/filecache.go b/pkg/token/filecache.go
deleted file mode 100644
index e1a0c2a84..000000000
--- 
a/pkg/token/filecache.go +++ /dev/null @@ -1,314 +0,0 @@ -package token - -import ( - "context" - "errors" - "fmt" - "io/fs" - "os" - "path/filepath" - "runtime" - "time" - - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/gofrs/flock" - "gopkg.in/yaml.v2" -) - -// env variable name for custom credential cache file location -const cacheFileNameEnv = "AWS_IAM_AUTHENTICATOR_CACHE_FILE" - -// A mockable filesystem interface -var f filesystem = osFS{} - -type filesystem interface { - Stat(filename string) (os.FileInfo, error) - ReadFile(filename string) ([]byte, error) - WriteFile(filename string, data []byte, perm os.FileMode) error - MkdirAll(path string, perm os.FileMode) error -} - -// default os based implementation -type osFS struct{} - -func (osFS) Stat(filename string) (os.FileInfo, error) { - return os.Stat(filename) -} - -func (osFS) ReadFile(filename string) ([]byte, error) { - return os.ReadFile(filename) -} - -func (osFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - return os.WriteFile(filename, data, perm) -} - -func (osFS) MkdirAll(path string, perm os.FileMode) error { - return os.MkdirAll(path, perm) -} - -// A mockable environment interface -var e environment = osEnv{} - -type environment interface { - Getenv(key string) string - LookupEnv(key string) (string, bool) -} - -// default os based implementation -type osEnv struct{} - -func (osEnv) Getenv(key string) string { - return os.Getenv(key) -} - -func (osEnv) LookupEnv(key string) (string, bool) { - return os.LookupEnv(key) -} - -// A mockable flock interface -type filelock interface { - Unlock() error - TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) - TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) -} - -var newFlock = func(filename string) filelock { - return flock.New(filename) -} - -// cacheFile is a map of clusterID/roleARNs to cached credentials -type cacheFile struct { - // a map of 
clusterIDs/profiles/roleARNs to cachedCredentials - ClusterMap map[string]map[string]map[string]cachedCredential `yaml:"clusters"` -} - -// a utility type for dealing with compound cache keys -type cacheKey struct { - clusterID string - profile string - roleARN string -} - -func (c *cacheFile) Put(key cacheKey, credential cachedCredential) { - if _, ok := c.ClusterMap[key.clusterID]; !ok { - // first use of this cluster id - c.ClusterMap[key.clusterID] = map[string]map[string]cachedCredential{} - } - if _, ok := c.ClusterMap[key.clusterID][key.profile]; !ok { - // first use of this profile - c.ClusterMap[key.clusterID][key.profile] = map[string]cachedCredential{} - } - c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential -} - -func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { - if _, ok := c.ClusterMap[key.clusterID]; ok { - if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok { - // we at least have this cluster and profile combo in the map, if no matching roleARN, map will - // return the zero-value for cachedCredential, which expired a long time ago. - credential = c.ClusterMap[key.clusterID][key.profile][key.roleARN] - } - } - return -} - -// cachedCredential is a single cached credential entry, along with expiration time -type cachedCredential struct { - Credential credentials.Value - Expiration time.Time - // If set will be used by IsExpired to determine the current time. - // Defaults to time.Now if CurrentTime is not set. Available for testing - // to be able to mock out the current time. - currentTime func() time.Time -} - -// IsExpired determines if the cached credential has expired -func (c *cachedCredential) IsExpired() bool { - curTime := c.currentTime - if curTime == nil { - curTime = time.Now - } - return c.Expiration.Before(curTime()) -} - -// readCacheWhileLocked reads the contents of the credential cache and returns the -// parsed yaml as a cacheFile object. 
This method must be called while a shared -// lock is held on the filename. -func readCacheWhileLocked(filename string) (cache cacheFile, err error) { - cache = cacheFile{ - map[string]map[string]map[string]cachedCredential{}, - } - data, err := f.ReadFile(filename) - if err != nil { - err = fmt.Errorf("unable to open file %s: %v", filename, err) - return - } - - err = yaml.Unmarshal(data, &cache) - if err != nil { - err = fmt.Errorf("unable to parse file %s: %v", filename, err) - } - return -} - -// writeCacheWhileLocked writes the contents of the credential cache using the -// yaml marshaled form of the passed cacheFile object. This method must be -// called while an exclusive lock is held on the filename. -func writeCacheWhileLocked(filename string, cache cacheFile) error { - data, err := yaml.Marshal(cache) - if err == nil { - // write privately owned by the user - err = f.WriteFile(filename, data, 0600) - } - return err -} - -// FileCacheProvider is a Provider implementation that wraps an underlying Provider -// (contained in Credentials) and provides caching support for credentials for the -// specified clusterID, profile, and roleARN (contained in cacheKey) -type FileCacheProvider struct { - credentials *credentials.Credentials // the underlying implementation that has the *real* Provider - cacheKey cacheKey // cache key parameters used to create Provider - cachedCredential cachedCredential // the cached credential, if it exists -} - -// NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, -// and works with an on disk cache to speed up credential usage when the cached copy is not expired. -// If there are any problems accessing or initializing the cache, an error will be returned, and -// callers should just use the existing credentials provider. 
-func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials) (FileCacheProvider, error) { - if creds == nil { - return FileCacheProvider{}, errors.New("no underlying Credentials object provided") - } - filename := CacheFilename() - cacheKey := cacheKey{clusterID, profile, roleARN} - cachedCredential := cachedCredential{} - // ensure path to cache file exists - _ = f.MkdirAll(filepath.Dir(filename), 0700) - if info, err := f.Stat(filename); err == nil { - if info.Mode()&0077 != 0 { - // cache file has secret credentials and should only be accessible to the user, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("cache file %s is not private", filename) - } - - // do file locking on cache to prevent inconsistent reads - lock := newFlock(filename) - defer lock.Unlock() - // wait up to a second for the file to lock - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) - defer cancel() - ok, err := lock.TryRLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second - if !ok { - // unable to lock the cache, something is wrong, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("unable to read lock file %s: %v", filename, err) - } - - cache, err := readCacheWhileLocked(filename) - if err != nil { - // can't read or parse cache, refuse to use it. - return FileCacheProvider{}, err - } - - cachedCredential = cache.Get(cacheKey) - } else { - if errors.Is(err, fs.ErrNotExist) { - // cache file is missing. maybe this is the very first run? continue to use cache. 
- _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", filename) - } else { - return FileCacheProvider{}, fmt.Errorf("couldn't stat cache file: %w", err) - } - } - - return FileCacheProvider{ - creds, - cacheKey, - cachedCredential, - }, nil -} - -// Retrieve() implements the Provider interface, returning the cached credential if is not expired, -// otherwise fetching the credential from the underlying Provider and caching the results on disk -// with an expiration time. -func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { - if !f.cachedCredential.IsExpired() { - // use the cached credential - return f.cachedCredential.Credential, nil - } else { - _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") - // fetch the credentials from the underlying Provider - credential, err := f.credentials.Get() - if err != nil { - return credential, err - } - if expiration, err := f.credentials.ExpiresAt(); err == nil { - // underlying provider supports Expirer interface, so we can cache - filename := CacheFilename() - // do file locking on cache to prevent inconsistent writes - lock := newFlock(filename) - defer lock.Unlock() - // wait up to a second for the file to lock - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) - defer cancel() - ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second - if !ok { - // can't get write lock to create/update cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", filename, err) - return credential, nil - } - f.cachedCredential = cachedCredential{ - credential, - expiration, - nil, - } - // don't really care about read error. Either read the cache, or we create a new cache. 
- cache, _ := readCacheWhileLocked(filename) - cache.Put(f.cacheKey, f.cachedCredential) - err = writeCacheWhileLocked(filename, cache) - if err != nil { - // can't write cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", filename, err) - err = nil - } else { - _, _ = fmt.Fprintf(os.Stderr, "Updated cached credential\n") - } - } else { - // credential doesn't support expiration time, so can't cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) - err = nil - } - return credential, err - } -} - -// IsExpired() implements the Provider interface, deferring to the cached credential first, -// but fall back to the underlying Provider if it is expired. -func (f *FileCacheProvider) IsExpired() bool { - return f.cachedCredential.IsExpired() && f.credentials.IsExpired() -} - -// ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential -func (f *FileCacheProvider) ExpiresAt() time.Time { - return f.cachedCredential.Expiration -} - -// CacheFilename returns the name of the credential cache file, which can either be -// set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml -func CacheFilename() string { - if filename, ok := e.LookupEnv(cacheFileNameEnv); ok { - return filename - } else { - return filepath.Join(UserHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") - } -} - -// UserHomeDir returns the home directory for the user the process is -// running under. 
-func UserHomeDir() string { - if runtime.GOOS == "windows" { // Windows - return e.Getenv("USERPROFILE") - } - - // *nix - return e.Getenv("HOME") -} diff --git a/pkg/token/filecache_test.go b/pkg/token/filecache_test.go deleted file mode 100644 index d69c75937..000000000 --- a/pkg/token/filecache_test.go +++ /dev/null @@ -1,512 +0,0 @@ -package token - -import ( - "bytes" - "context" - "errors" - "github.com/aws/aws-sdk-go/aws/credentials" - "os" - "testing" - "time" -) - -type stubProvider struct { - creds credentials.Value - expired bool - err error -} - -func (s *stubProvider) Retrieve() (credentials.Value, error) { - s.expired = false - s.creds.ProviderName = "stubProvider" - return s.creds, s.err -} - -func (s *stubProvider) IsExpired() bool { - return s.expired -} - -type stubProviderExpirer struct { - stubProvider - expiration time.Time -} - -func (s *stubProviderExpirer) ExpiresAt() time.Time { - return s.expiration -} - -type testFileInfo struct { - name string - size int64 - mode os.FileMode - modTime time.Time -} - -func (fs *testFileInfo) Name() string { return fs.name } -func (fs *testFileInfo) Size() int64 { return fs.size } -func (fs *testFileInfo) Mode() os.FileMode { return fs.mode } -func (fs *testFileInfo) ModTime() time.Time { return fs.modTime } -func (fs *testFileInfo) IsDir() bool { return fs.Mode().IsDir() } -func (fs *testFileInfo) Sys() interface{} { return nil } - -type testFS struct { - filename string - fileinfo testFileInfo - data []byte - err error - perm os.FileMode -} - -func (t *testFS) Stat(filename string) (os.FileInfo, error) { - t.filename = filename - if t.err == nil { - return &t.fileinfo, nil - } else { - return nil, t.err - } -} - -func (t *testFS) ReadFile(filename string) ([]byte, error) { - t.filename = filename - return t.data, t.err -} - -func (t *testFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - t.filename = filename - t.data = data - t.perm = perm - return t.err -} - -func (t *testFS) 
MkdirAll(path string, perm os.FileMode) error { - t.filename = path - t.perm = perm - return t.err -} - -func (t *testFS) reset() { - t.filename = "" - t.fileinfo = testFileInfo{} - t.data = []byte{} - t.err = nil - t.perm = 0600 -} - -type testEnv struct { - values map[string]string -} - -func (e *testEnv) Getenv(key string) string { - return e.values[key] -} - -func (e *testEnv) LookupEnv(key string) (string, bool) { - value, ok := e.values[key] - return value, ok -} - -func (e *testEnv) reset() { - e.values = map[string]string{} -} - -type testFilelock struct { - ctx context.Context - retryDelay time.Duration - success bool - err error -} - -func (l *testFilelock) Unlock() error { - return nil -} - -func (l *testFilelock) TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) { - l.ctx = ctx - l.retryDelay = retryDelay - return l.success, l.err -} - -func (l *testFilelock) TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) { - l.ctx = ctx - l.retryDelay = retryDelay - return l.success, l.err -} - -func (l *testFilelock) reset() { - l.ctx = context.TODO() - l.retryDelay = 0 - l.success = true - l.err = nil -} - -func getMocks() (tf *testFS, te *testEnv, testFlock *testFilelock) { - tf = &testFS{} - tf.reset() - f = tf - te = &testEnv{} - te.reset() - e = te - testFlock = &testFilelock{} - testFlock.reset() - newFlock = func(filename string) filelock { - return testFlock - } - return -} - -func makeCredential() credentials.Value { - return credentials.Value{ - AccessKeyID: "AKID", - SecretAccessKey: "SECRET", - SessionToken: "TOKEN", - ProviderName: "stubProvider", - } -} - -func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c *credentials.Credentials) { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if p.credentials != c { - t.Errorf("Credentials not copied") - } - if p.cacheKey.clusterID != "CLUSTER" { - t.Errorf("clusterID not copied") - } - if p.cacheKey.profile != 
"PROFILE" { - t.Errorf("profile not copied") - } - if p.cacheKey.roleARN != "ARN" { - t.Errorf("roleARN not copied") - } -} - -func TestCacheFilename(t *testing.T) { - _, te, _ := getMocks() - - te.values["HOME"] = "homedir" // unix - te.values["USERPROFILE"] = "homedir" // windows - - filename := CacheFilename() - expected := "homedir/.kube/cache/aws-iam-authenticator/credentials.yaml" - if filename != expected { - t.Errorf("Incorrect default cacheFilename, expected %s, got %s", - expected, filename) - } - - te.values["AWS_IAM_AUTHENTICATOR_CACHE_FILE"] = "special.yaml" - filename = CacheFilename() - expected = "special.yaml" - if filename != expected { - t.Errorf("Incorrect custom cacheFilename, expected %s, got %s", - expected, filename) - } -} - -func TestNewFileCacheProvider_Missing(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - tf, _, _ := getMocks() - - // missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing cache file should result in expired cached credential") - } - tf.err = nil -} - -func TestNewFileCacheProvider_BadPermissions(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - tf, _, _ := getMocks() - - // bad permissions - tf.fileinfo.mode = 0777 - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - if err == nil { - t.Errorf("Expected error due to public permissions") - } - if tf.filename != CacheFilename() { - t.Errorf("unexpected file checked, expected %s, got %s", - CacheFilename(), tf.filename) - } -} - -func TestNewFileCacheProvider_Unlockable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - _, _, testFlock := getMocks() - - // unable to lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - if err == nil { - 
t.Errorf("Expected error due to lock failure") - } - testFlock.success = true - testFlock.err = nil -} - -func TestNewFileCacheProvider_Unreadable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - tf, _, _ := getMocks() - - // unable to read existing cache - tf.err = errors.New("read failure") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - if err == nil { - t.Errorf("Expected error due to read failure") - } - tf.err = nil -} - -func TestNewFileCacheProvider_Unparseable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - tf, _, _ := getMocks() - - // unable to parse yaml - tf.data = []byte("invalid: yaml: file") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - if err == nil { - t.Errorf("Expected error due to bad yaml") - } -} - -func TestNewFileCacheProvider_Empty(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - _, _, _ = getMocks() - - // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("empty cache file should result in expired cached credential") - } -} - -func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - tf, _, _ := getMocks() - - // successfully parse existing cluster without matching arn - tf.data = []byte(`clusters: - CLUSTER: -`) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing arn in cache file should result in expired cached credential") - } -} - -func TestNewFileCacheProvider_ExistingARN(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - tf, _, _ := getMocks() - - // successfully parse cluster with matching arn - tf.data = []byte(`clusters: - CLUSTER: - PROFILE: - ARN: - credential: - accesskeyid: ABC 
- secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: 2018-01-02T03:04:56.789Z -`) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || - p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { - t.Errorf("cached credential not extracted correctly") - } - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } - if p.cachedCredential.IsExpired() { - t.Errorf("Cached credential should not be expired") - } - if p.IsExpired() { - t.Errorf("Cache credential should not be expired") - } - expectedExpiration := time.Date(2018, 01, 02, 03, 04, 56, 789000000, time.UTC) - if p.ExpiresAt() != expectedExpiration { - t.Errorf("Credential expiration time is not correct, expected %v, got %v", - expectedExpiration, p.ExpiresAt()) - } -} - -func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { - providerCredential := makeCredential() - c := credentials.NewCredentials(&stubProvider{ - creds: providerCredential, - }) - - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - - credential, err := p.Retrieve() - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) - } -} - -func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { - providerCredential = makeCredential() - expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) - c = credentials.NewCredentials(&stubProviderExpirer{ - 
stubProvider{ - creds: providerCredential, - }, - expiration, - }) - return -} - -func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() - - tf, _, testFlock := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - - // retrieve credential, which will fetch from underlying Provider - // fail to get write lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") - credential, err := p.Retrieve() - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) - } -} - -func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { - providerCredential, expiration, c := makeExpirerCredentials() - - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - - // retrieve credential, which will fetch from underlying Provider - // fail to write cache - tf.err = errors.New("can't write cache") - credential, err := p.Retrieve() - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) - } - if tf.filename != CacheFilename() { - t.Errorf("Wrote to wrong file, expected %v, got %v", - CacheFilename(), tf.filename) - } - if tf.perm != 0600 { - t.Errorf("Wrote with wrong permissions, expected %o, got %o", - 0600, tf.perm) - } - expectedData := []byte(`clusters: - CLUSTER: - PROFILE: - ARN: - credential: - accesskeyid: AKID - secretaccesskey: SECRET - sessiontoken: TOKEN - providername: 
stubProvider - expiration: ` + expiration.Format(time.RFC3339Nano) + ` -`) - if bytes.Compare(tf.data, expectedData) != 0 { - t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, tf.data) - } -} - -func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() - - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - tf.err = nil - - // retrieve credential, which will fetch from underlying Provider - // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, - // but write to disk (code coverage) - credential, err := p.Retrieve() - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) - } -} - -func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - - tf, _, _ := getMocks() - - // successfully parse cluster with matching arn - tf.data = []byte(`clusters: - CLUSTER: - PROFILE: - ARN: - credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: 2018-01-02T03:04:56.789Z -`) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) - validateFileCacheProvider(t, p, err, c) - - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } - - credential, err := p.Retrieve() - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if credential.AccessKeyID != "ABC" || credential.SecretAccessKey != "DEF" || - credential.SessionToken != "GHI" || credential.ProviderName != "JKL" { - t.Errorf("cached credential not returned") - } -} diff --git a/pkg/token/token.go b/pkg/token/token.go 
index 16ab8d92b..9ad001f1c 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -17,6 +17,7 @@ limitations under the License. package token import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -28,23 +29,24 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "sigs.k8s.io/aws-iam-authenticator/pkg" + "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" + "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + smithymiddleware "github.com/aws/smithy-go/middleware" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/pkg/apis/clientauthentication" clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" - "sigs.k8s.io/aws-iam-authenticator/pkg" - "sigs.k8s.io/aws-iam-authenticator/pkg/arn" - "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) // Identity is returned on successful Verify() results. It contains a parsed @@ -179,12 +181,16 @@ type getCallerIdentityWrapper struct { } `json:"GetCallerIdentityResponse"` } +type GCIPresigner interface { + PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) +} + // Generator provides new tokens for the AWS IAM Authenticator. 
type Generator interface { // Get a token using the provided options - GetWithOptions(options *GetTokenOptions) (Token, error) - // GetWithSTS returns a token valid for clusterID using the given STS client. - GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) + GetWithOptions(*GetTokenOptions) (Token, error) + // Presign returns a Token using the given STS client + Presign(GCIPresigner) (Token, error) // FormatJSON returns the client auth formatted json for the ExecCredential auth FormatJSON(Token) string } @@ -219,23 +225,14 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { if options.ClusterID == "" { return Token{}, fmt.Errorf("ClusterID is required") } - - // create a session with the "base" credentials available - // (from environment variable, profile files, EC2 metadata, etc) - sess, err := session.NewSessionWithOptions(session.Options{ - AssumeRoleTokenProvider: StdinStderrTokenProvider, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return Token{}, fmt.Errorf("could not create session: %v", err) + loadOpts := []func(*config.LoadOptions) error{ + config.WithAPIOptions( + []func(*smithymiddleware.Stack) error{ + middleware.AddUserAgentKeyValue("aws-iam-authenticator", pkg.Version), + }), } - sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ - Name: "authenticatorUserAgent", - Fn: request.MakeAddToUserAgentHandler( - "aws-iam-authenticator", pkg.Version), - }) if options.Region != "" { - sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)) + loadOpts = append(loadOpts, config.WithRegion(options.Region)) } if g.cache { @@ -243,86 +240,91 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { var profile string if v := os.Getenv("AWS_PROFILE"); len(v) > 0 { profile = v - } else { - profile = session.DefaultSharedConfigProfile } - // create a cacheing Provider wrapper around the Credentials - if 
cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { - sess.Config.Credentials = credentials.NewCredentials(&cacheProvider) + // Create a new config to get the default cred chain + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + return Token{}, fmt.Errorf("could not create config: %v", err) + } + // create a caching Provider wrapper around the Credentials + cacheProvider, err := filecache.NewFileCacheProvider( + options.ClusterID, + profile, + options.AssumeRoleARN, + cfg.Credentials, + ) + if err == nil { + loadOpts = append(loadOpts, config.WithCredentialsProvider(cacheProvider)) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) } } - // use an STS client based on the direct credentials - stsAPI := sts.New(sess) - - // if a roleARN was specified, replace the STS client with one that uses - // temporary credentials from that role. - if options.AssumeRoleARN != "" { - var sessionSetters []func(*stscreds.AssumeRoleProvider) + cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...) 
+ if err != nil { + return Token{}, fmt.Errorf("could not create config: %v", err) - if options.AssumeRoleExternalID != "" { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.ExternalID = &options.AssumeRoleExternalID - }) - } + } + if options.AssumeRoleARN != "" { + var sessionName = options.SessionName if g.forwardSessionName { - // If the current session is already a federated identity, carry through - // this session name onto the new session to provide better debugging - // capabilities - resp, err := stsAPI.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + stsSvc := sts.NewFromConfig(cfg) + gciResp, err := stsSvc.GetCallerIdentity(context.Background(), nil) if err != nil { return Token{}, err } - - userIDParts := strings.Split(*resp.UserId, ":") + userIDParts := strings.Split(aws.ToString(gciResp.UserId), ":") if len(userIDParts) == 2 { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.RoleSessionName = userIDParts[1] - }) + sessionName = userIDParts[1] } - } else if options.SessionName != "" { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.RoleSessionName = options.SessionName - }) } - // create STS-based credentials that will assume the given role - creds := stscreds.NewCredentials(sess, options.AssumeRoleARN, sessionSetters...) + creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), options.AssumeRoleARN, func(o *stscreds.AssumeRoleOptions) { + o.RoleSessionName = sessionName + o.ExternalID = &options.AssumeRoleExternalID + // TODO: Can we get the serial number from the client? 
+ // o.SerialNumber = aws.String("myTokenSerialNumber") + o.TokenProvider = stscreds.StdinTokenProvider + }) - // create an STS API interface that uses the assumed role's temporary credentials - stsAPI = sts.New(sess, &aws.Config{Credentials: creds}) + cfg.Credentials = aws.NewCredentialsCache(creds) } - return g.GetWithSTS(options.ClusterID, stsAPI) + stsSvc := sts.NewFromConfig(cfg, WithClusterIDHeader(options.ClusterID)) + presigner := timedPresigner{v4.NewSigner(), g.nowFunc} + presignClient := sts.NewPresignClient(stsSvc, func(o *sts.PresignOptions) { + o.Presigner = &presigner + }) + return g.Presign(presignClient) } -func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler { - return request.NamedHandler{ - Name: "v4.SignRequestHandler", Fn: func(req *request.Request) { - v4.SignSDKRequestWithCurrentTime(req, nowFunc) - }, - } +// timedPresigner exists to wrap PresignHTTP() with a specific time +// and set the x-amz-expires header in the query url +type timedPresigner struct { + signer *v4.Signer + timeFunc func() time.Time } -// GetWithSTS returns a token valid for clusterID using the given STS client. -func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) { - // generate an sts:GetCallerIdentity request and add our custom cluster ID header - request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) - request.HTTPRequest.Header.Add(clusterIDHeader, clusterID) - - // override the Sign handler so we can control the now time for testing. 
- request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc)) +func (p *timedPresigner) PresignHTTP( + ctx context.Context, credentials aws.Credentials, r *http.Request, + payloadHash string, service string, region string, _ time.Time, + optFns ...func(*v4.SignerOptions), +) (url string, signedHeader http.Header, err error) { + query := r.URL.Query() + query.Set("X-Amz-Expires", strconv.Itoa(requestPresignParam)) + r.URL.RawQuery = query.Encode() + return p.signer.PresignHTTP(ctx, credentials, r, payloadHash, service, region, p.timeFunc(), optFns...) +} +// Presign returns a token valid for clusterID using the given STS client. +func (g generator) Presign(presigner GCIPresigner) (Token, error) { // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date - // timestamp regardless. We set it to 60 seconds for backwards compatibility (the - // parameter is a required argument to Presign(), and authenticators 0.3.0 and older are expecting a value between - // 0 and 60 on the server side). + // timestamp regardless. 
// https://github.com/aws/aws-sdk-go/issues/2167 - presignedURLString, err := request.Presign(requestPresignParam * time.Second) + + req, err := presigner.PresignGetCallerIdentity(context.Background(), nil) if err != nil { return Token{}, err } @@ -330,7 +332,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, // Set token expiration to 1 minute before the presigned URL expires for some cushion tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) // TODO: this may need to be a constant-time base64 encoding - return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil + return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(req.URL)), tokenExpiration}, nil } // FormatJSON formats the json to support ExecCredential authentication diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index a8e997c86..ce43cca45 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -2,6 +2,7 @@ package token import ( "bytes" + "context" "encoding/base64" "encoding/json" "errors" @@ -14,11 +15,12 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -587,12 +589,12 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { } } -func TestGetWithSTS(t *testing.T) { +func TestPresign(t *testing.T) { clusterID := "test-cluster" cases := []struct { name string - creds *credentials.Credentials + creds aws.CredentialsProvider nowTime 
time.Time want Token wantErr error @@ -600,10 +602,10 @@ func TestGetWithSTS(t *testing.T) { { "Non-zero time", // Example non-real credentials - func() *credentials.Credentials { + func() credentials.StaticCredentialsProvider { decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") - return credentials.NewStaticCredentials( + return credentials.NewStaticCredentialsProvider( string(decodedAkid), string(decodedSk), "", @@ -620,13 +622,15 @@ func TestGetWithSTS(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - svc := sts.New(session.Must(session.NewSession( - &aws.Config{ - Credentials: tc.creds, - Region: aws.String("us-west-2"), - STSRegionalEndpoint: endpoints.RegionalSTSEndpoint, - }, - ))) + cfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion("us-west-2"), + config.WithCredentialsProvider(tc.creds), + ) + if err != nil { + t.Errorf("unexpected error initialzing config: %v", err) + return + } gen := &generator{ forwardSessionName: false, @@ -634,7 +638,13 @@ func TestGetWithSTS(t *testing.T) { nowFunc: func() time.Time { return tc.nowTime }, } - got, err := gen.GetWithSTS(clusterID, svc) + stsSvc := sts.NewFromConfig(cfg, WithClusterIDHeader(clusterID)) + presigner := timedPresigner{v4.NewSigner(), gen.nowFunc} + presignClient := sts.NewPresignClient(stsSvc, func(o *sts.PresignOptions) { + o.Presigner = &presigner + }) + + got, err := gen.Presign(presignClient) if diff := cmp.Diff(err, tc.wantErr); diff != "" { t.Errorf("Unexpected error: %s", diff) } diff --git a/tests/integration/go.mod b/tests/integration/go.mod index 451b89c6d..6781a0caf 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -18,6 +18,8 @@ require ( github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/asaskevich/govalidator 
v0.0.0-20190424111038-f61b66f89f4a // indirect + github.com/aws/aws-sdk-go-v2 v1.30.4 // indirect + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -72,6 +74,7 @@ require ( github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect + github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index f4a756a0e..4794685e6 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -12,6 +12,10 @@ github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4 github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= @@ -172,6 +176,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 
github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=