diff --git a/examples/config-sync-ecr-credential-helper.json b/examples/config-sync-ecr-credential-helper.json new file mode 100644 index 00000000..8cb5b538 --- /dev/null +++ b/examples/config-sync-ecr-credential-helper.json @@ -0,0 +1,40 @@ +{ + "distSpecVersion": "1.1.0", + "storage": { + "rootDirectory": "/tmp/zot", + "dedupe": false, + "storageDriver": { + "name": "s3", + "region": "REGION_NAME", + "bucket": "BUGKET_NAME", + "rootdirectory": "/ROOTDIR", + "secure": true, + "skipverify": false + } + }, + "http": { + "address": "0.0.0.0", + "port": "8080" + }, + "log": { + "level": "debug" + }, + "extensions": { + "sync": { + "credentialsFile": "", + "DownloadDir": "/tmp/zot", + "registries": [ + { + "urls": [ + "https://ACCOUNTID.dkr.ecr.REGION.amazonaws.com" + ], + "onDemand": true, + "maxRetries": 5, + "retryDelay": "2m", + "credentialHelper": "ecr" + } + ] + } + } +} + diff --git a/go.mod b/go.mod index fcb54ca0..1243a9a3 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.28.7 github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.22 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.38.1 + github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.8 github.com/aws/aws-secretsmanager-caching-go v1.2.0 github.com/aws/smithy-go v1.22.1 @@ -158,7 +159,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.24.10 // indirect github.com/aws/aws-sdk-go-v2/service/ebs v1.25.3 // indirect github.com/aws/aws-sdk-go-v2/service/ec2 v1.193.0 // indirect - github.com/aws/aws-sdk-go-v2/service/ecr v1.36.6 // indirect github.com/aws/aws-sdk-go-v2/service/ecrpublic v1.25.3 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.7 // indirect diff --git a/pkg/extensions/config/sync/config.go b/pkg/extensions/config/sync/config.go index ec888a08..180420ee 100644 --- a/pkg/extensions/config/sync/config.go +++ b/pkg/extensions/config/sync/config.go @@ -23,15 +23,16 @@ type Config struct { } type RegistryConfig struct { - URLs []string - PollInterval time.Duration - Content []Content - TLSVerify *bool - OnDemand bool - CertDir string - MaxRetries *int - RetryDelay *time.Duration - OnlySigned *bool + URLs []string + PollInterval time.Duration + Content []Content + TLSVerify *bool + OnDemand bool + CertDir string + MaxRetries *int + RetryDelay *time.Duration + OnlySigned *bool + CredentialHelper string } type Content struct { diff --git a/pkg/extensions/sync/ecr_credential_helper.go b/pkg/extensions/sync/ecr_credential_helper.go new file mode 100644 index 00000000..dea23f80 --- /dev/null +++ b/pkg/extensions/sync/ecr_credential_helper.go @@ -0,0 +1,158 @@ +//go:build sync +// +build sync + +package sync + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecr" + + syncconf "zotregistry.dev/zot/pkg/extensions/config/sync" + "zotregistry.dev/zot/pkg/log" +) + +// ECR tokens are valid for 12 hours. The ExpiryWindow variable is set to 1 hour, +// meaning if the remaining validity of the token is less than 1 hour, it will be considered expired. +const ( + ExpiryWindow int = 1 + ECRURLSplitPartsCount = 6 + UsernameTokenParts = 2 +) + +var ( + ErrInvalidURLFormat = errors.New("invalid ECR URL is received") + ErrInvalidTokenFormat = errors.New("invalid token format received from ECR") + ErrUnableToLoadAWSConfig = errors.New("unable to load AWS config for region") + ErrUnableToGetECRAuthToken = errors.New("unable to get ECR authorization token for account") + ErrUnableToDecodeECRToken = errors.New("unable to decode ECR token") + ErrFailedToGetECRCredentials = errors.New("failed to get ECR credentials") +) + +type ECRCredential struct { + username string + password string + expiry time.Time + account string + region string +} + +type ECRCredentialsHelper struct { + credentials map[string]ECRCredential + log log.Logger +} + +func NewECRCredentialHelper(log log.Logger) CredentialHelper { + return &ECRCredentialsHelper{ + credentials: make(map[string]ECRCredential), + log: log, + } +} + +// extractAccountAndRegion extracts the account ID and region from the given ECR URL. +// Example URL format: account.dkr.ecr.region.amazonaws.com. +func extractAccountAndRegion(url string) (string, string, error) { + parts := strings.Split(url, ".") + if len(parts) < ECRURLSplitPartsCount { + return "", "", fmt.Errorf("%w: %s", ErrInvalidURLFormat, url) + } + + accountID := parts[0] // First part is the account ID + region := parts[3] // Fourth part is the region + + return accountID, region, nil +} + +func getECRCredentials(remoteAddress string) (ECRCredential, error) { + // Extract account ID and region from the URL. + accountID, region, err := extractAccountAndRegion(remoteAddress) + if err != nil { + return ECRCredential{}, fmt.Errorf("%w %s: %w", ErrInvalidTokenFormat, remoteAddress, err) + } + + // Load the AWS config for the specific region. + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + if err != nil { + return ECRCredential{}, fmt.Errorf("%w %s: %w", ErrUnableToLoadAWSConfig, region, err) + } + + // Create an ECR client + ecrClient := ecr.NewFromConfig(cfg) + + // Fetch the ECR authorization token. + ecrAuth, err := ecrClient.GetAuthorizationToken(context.TODO(), &ecr.GetAuthorizationTokenInput{ + RegistryIds: []string{accountID}, // Filter by the account ID. + }) + if err != nil { + return ECRCredential{}, fmt.Errorf("%w %s: %w", ErrUnableToGetECRAuthToken, accountID, err) + } + + // Decode the base64-encoded ECR token. + authToken := *ecrAuth.AuthorizationData[0].AuthorizationToken + decodedToken, err := base64.StdEncoding.DecodeString(authToken) + if err != nil { + return ECRCredential{}, fmt.Errorf("%w: %w", ErrUnableToDecodeECRToken, err) + } + + // Split the decoded token into username and password (username is "AWS"). + tokenParts := strings.Split(string(decodedToken), ":") + if len(tokenParts) != UsernameTokenParts { + return ECRCredential{}, fmt.Errorf("%w", ErrInvalidTokenFormat) + } + + expiry := *ecrAuth.AuthorizationData[0].ExpiresAt + username := tokenParts[0] + password := tokenParts[1] + + return ECRCredential{username: username, password: password, expiry: expiry, account: accountID, region: region}, nil +} + +// GetECRCredentials retrieves the ECR credentials (username and password) from AWS ECR. +func (credHelper *ECRCredentialsHelper) GetCredentials(urls []string) (syncconf.CredentialsFile, error) { + ecrCredentials := make(syncconf.CredentialsFile) + + for _, url := range urls { + remoteAddress := StripRegistryTransport(url) + ecrCred, err := getECRCredentials(remoteAddress) + if err != nil { + return syncconf.CredentialsFile{}, fmt.Errorf("%w %s: %w", ErrFailedToGetECRCredentials, url, err) + } + // Store the credentials in the map using the base URL as the key. + ecrCredentials[remoteAddress] = syncconf.Credentials{ + Username: ecrCred.username, + Password: ecrCred.password, + } + credHelper.credentials[remoteAddress] = ecrCred + } + return ecrCredentials, nil +} + +func (credHelper *ECRCredentialsHelper) IsCredentialsValid(remoteAddress string) bool { + expiry := credHelper.credentials[remoteAddress].expiry + expiryDuration := time.Duration(ExpiryWindow) * time.Hour + + if time.Until(expiry) <= expiryDuration { + credHelper.log.Info().Str("url", remoteAddress).Msg("The credentials are close to expiring") + + return false + } + credHelper.log.Info().Str("url", remoteAddress).Msg("The credentials are valid") + + return true +} + +func (credHelper *ECRCredentialsHelper) RefreshCredentials(remoteAddress string) (syncconf.Credentials, error) { + credHelper.log.Info().Str("url", remoteAddress).Msg("Refreshing the ECR credentials") + ecrCred, err := getECRCredentials(remoteAddress) + if err != nil { + return syncconf.Credentials{}, fmt.Errorf("%w %s: %w", ErrFailedToGetECRCredentials, remoteAddress, err) + } + + return syncconf.Credentials{Username: ecrCred.username, Password: ecrCred.password}, nil +} diff --git a/pkg/extensions/sync/remote.go b/pkg/extensions/sync/remote.go index bf85f62f..174ae94c 100644 --- a/pkg/extensions/sync/remote.go +++ b/pkg/extensions/sync/remote.go @@ -44,6 +44,13 @@ func NewRemoteRegistry(client *client.Client, logger log.Logger) Remote { return registry } +func (registry *RemoteRegistry) SetUpstreamAuthConfig(username, password string) { + registry.context.DockerAuthConfig = &types.DockerAuthConfig{ + Username: username, + Password: password, + } +} + func (registry *RemoteRegistry) GetContext() *types.SystemContext { return registry.context } diff --git a/pkg/extensions/sync/service.go b/pkg/extensions/sync/service.go index 4f1fca23..cf63c64e 100644 --- a/pkg/extensions/sync/service.go +++ b/pkg/extensions/sync/service.go @@ -27,19 +27,20 @@ import ( ) type BaseService struct { - config syncconf.RegistryConfig - credentials syncconf.CredentialsFile - clusterConfig *config.ClusterConfig - remote Remote - destination Destination - retryOptions *retry.RetryOptions - contentManager ContentManager - storeController storage.StoreController - metaDB mTypes.MetaDB - repositories []string - references references.References - client *client.Client - log log.Logger + config syncconf.RegistryConfig + credentials syncconf.CredentialsFile + credentialHelper CredentialHelper + clusterConfig *config.ClusterConfig + remote Remote + destination Destination + retryOptions *retry.RetryOptions + contentManager ContentManager + storeController storage.StoreController + metaDB mTypes.MetaDB + repositories []string + references references.References + client *client.Client + log log.Logger } func New( @@ -60,16 +61,35 @@ func New( var err error var credentialsFile syncconf.CredentialsFile - if credentialsFilepath != "" { - credentialsFile, err = getFileCredentials(credentialsFilepath) - if err != nil { - log.Error().Str("errortype", common.TypeOf(err)).Str("path", credentialsFilepath). - Err(err).Msg("couldn't get registry credentials from configured path") + if service.config.CredentialHelper == "" { + // Only load credentials from file if CredentialHelper is not set + if credentialsFilepath != "" { + log.Info().Msgf("Using file-based credentials because CredentialHelper is not set") + credentialsFile, err = getFileCredentials(credentialsFilepath) + if err != nil { + log.Error().Str("errortype", common.TypeOf(err)).Str("path", credentialsFilepath). + Err(err).Msg("couldn't get registry credentials from configured path") + } + service.credentialHelper = nil + service.credentials = credentialsFile + } + } else { + log.Info().Msgf("Using credentials helper, because CredentialHelper is set to %s", service.config.CredentialHelper) + switch service.config.CredentialHelper { + case "ecr": + // Logic to fetch credentials for ECR + log.Info().Msg("Fetch the credentials using AWS ECR Auth Token.") + service.credentialHelper = NewECRCredentialHelper(log) + creds, err := service.credentialHelper.GetCredentials(service.config.URLs) + if err != nil { + log.Error().Err(err).Msg("Failed to retrieve credentials using ECR credentials helper.") + } + service.credentials = creds + default: + log.Warn().Msgf("Unsupported CredentialHelper: %s", service.config.CredentialHelper) } } - service.credentials = credentialsFile - // load the cluster config into the object // can be nil if the user did not configure cluster config service.clusterConfig = clusterConfig @@ -102,7 +122,6 @@ func New( service.retryOptions = retryOptions service.storeController = storeController - // try to set next client. if err := service.SetNextAvailableClient(); err != nil { // if it's a ping issue, it will be retried @@ -126,9 +145,44 @@ func New( return service, nil } +// refreshRegistryTemporaryCredentials refreshes the temporary credentials for the registry if necessary. +// It checks whether a CredentialHelper is configured and if the current credentials have expired. +// If the credentials are expired, it attempts to refresh them and updates the service configuration. +func (service *BaseService) refreshRegistryTemporaryCredentials() error { + // If a CredentialHelper is configured, proceed to refresh the credentials if they are invalid or expired. + if service.config.CredentialHelper != "" { + // Strip the transport protocol (e.g., https:// or http://) from the remote address. + remoteAddress := StripRegistryTransport(service.client.GetHostname()) + + if !service.credentialHelper.IsCredentialsValid(remoteAddress) { + // Attempt to refresh the credentials using the CredentialHelper. + credentials, err := service.credentialHelper.RefreshCredentials(remoteAddress) + if err != nil { + service.log.Error(). + Err(err). + Str("url", remoteAddress). + Msg("Failed to refresh the credentials") + return err + } + service.log.Info(). + Str("url", remoteAddress). + Msg("Refreshing the upstream remote registry credentials") + + // Update the service's credentials map with the new set of credentials. + service.credentials[remoteAddress] = credentials + + // Set the upstream authentication context using the refreshed credentials. + service.remote.SetUpstreamAuthConfig(credentials.Username, credentials.Password) + } + } + + // Return nil to indicate the operation completed successfully. + return nil +} + func (service *BaseService) SetNextAvailableClient() error { if service.client != nil && service.client.Ping() { - return nil + return service.refreshRegistryTemporaryCredentials() } found := false diff --git a/pkg/extensions/sync/sync.go b/pkg/extensions/sync/sync.go index 1afd1117..9bc59b69 100644 --- a/pkg/extensions/sync/sync.go +++ b/pkg/extensions/sync/sync.go @@ -13,6 +13,7 @@ import ( "github.com/containers/image/v5/types" "github.com/opencontainers/go-digest" + syncconf "zotregistry.dev/zot/pkg/extensions/config/sync" "zotregistry.dev/zot/pkg/log" "zotregistry.dev/zot/pkg/scheduler" ) @@ -48,6 +49,22 @@ type Registry interface { GetContext() *types.SystemContext } +// The CredentialHelper interface should be implemented by registries that use temporary tokens. +// This interface defines methods to: +// - Check if the credentials for a registry are still valid. +// - Retrieve credentials for the specified registry URLs. +// - Refresh credentials for a given registry URL. +type CredentialHelper interface { + // Validates whether the credentials for the specified registry URL have expired. + IsCredentialsValid(url string) bool + + // Retrieves credentials for the provided list of registry URLs. + GetCredentials(urls []string) (syncconf.CredentialsFile, error) + + // Refreshes credentials for the specified registry URL. + RefreshCredentials(url string) (syncconf.Credentials, error) +} + /* Temporary oci layout, sync first pulls an image to this oci layout (using oci:// transport) then moves them into ImageStore. @@ -68,6 +85,9 @@ type Remote interface { // In the case of public dockerhub images 'library' namespace is added to the repo names of images // eg: alpine -> library/alpine GetDockerRemoteRepo(repo string) string + // SetUpstreamAuthConfig sets the upstream credentials used when the credential helper is set. + // This method refreshes the authentication configuration with the provided username and password. + SetUpstreamAuthConfig(username, password string) } // Local registry. diff --git a/pkg/extensions/sync/sync_internal_test.go b/pkg/extensions/sync/sync_internal_test.go index 3609cb4a..786db3d0 100644 --- a/pkg/extensions/sync/sync_internal_test.go +++ b/pkg/extensions/sync/sync_internal_test.go @@ -676,3 +676,33 @@ func TestConvertDockerLayersToOCI(t *testing.T) { So(dockerLayers[3].MediaType, ShouldEqual, ispec.MediaTypeImageLayerGzip) }) } + +func TestECRCredentialsHelper(t *testing.T) { + Convey("Test Mock ECR Credentials Helper", t, func() { + mockHelper := mocks.NewMockECRCredentialsHelper() + + Convey("Test Valid Credentials Retrieval", func() { + url := "mockAccount.dkr.ecr.mockRegion.amazonaws.com" + creds, err := mockHelper.GetCredentials([]string{url}) + So(err, ShouldBeNil) + So(creds, ShouldNotBeNil) + So(creds[url].Username, ShouldEqual, "mockUsername") + So(creds[url].Password, ShouldEqual, "mockPassword") + }) + + Convey("Test Credentials Retrieval", func() { + url := "invalid.dkr.ecr.mockRegion.amazonaws.com" + _, err := mockHelper.GetCredentials([]string{url}) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, "mock error for remote address") + }) + Convey("Test Credentials Refresh", func() { + url := "mockAccount.dkr.ecr.mockRegion.amazonaws.com" + _, err := mockHelper.RefreshCredentials(url) + So(err, ShouldBeNil) + + // Check that the new credentials are valid + So(mockHelper.IsCredentialsValid(url), ShouldBeTrue) + }) + }) +} diff --git a/pkg/test/mocks/ecr_credentials_helper_mock.go b/pkg/test/mocks/ecr_credentials_helper_mock.go new file mode 100644 index 00000000..e4f81d2a --- /dev/null +++ b/pkg/test/mocks/ecr_credentials_helper_mock.go @@ -0,0 +1,117 @@ +// mock_ecr_credentials_helper.go +package mocks + +import ( + "errors" + "fmt" + "strings" + "time" + + "zotregistry.dev/zot/pkg/extensions/config/sync" +) + +const ( + ExpiryWindow int = 1 + EcrURLSplitParts = 6 + ExpiryDuration = 12 +) + +var ( + ErrInvalidURLFormat = errors.New("invalid URL format") + ErrMockRemoteAddress = errors.New("mock error for remote address") +) + +type ECRCredential struct { + username string + password string + expiry time.Time + account string + region string +} + +// MockECRCredentialsHelper is a mock implementation of ECRCredentialsHelper. +type MockECRCredentialsHelper struct { + credentials map[string]ECRCredential +} + +// NewMockECRCredentialsHelper creates a new instance of MockECRCredentialsHelper. +func NewMockECRCredentialsHelper() *MockECRCredentialsHelper { + return &MockECRCredentialsHelper{ + credentials: make(map[string]ECRCredential), + } +} + +// extractAccountAndRegion extracts the account ID and region from the given ECR URL. +// Example URL format: account.dkr.ecr.region.amazonaws.com. +func extractAccountAndRegion(url string) (string, string, error) { + parts := strings.Split(url, ".") + if len(parts) < EcrURLSplitParts { + return "", "", fmt.Errorf("%w: %s", ErrInvalidURLFormat, url) + } + + accountID := parts[0] // First part is the account ID + region := parts[3] // Fourth part is the region + + return accountID, region, nil +} + +// Mock GetECRCredentials function. +func (m *MockECRCredentialsHelper) getECRCredentials(remoteAddress string) (ECRCredential, error) { + // Simulate extracting account ID and region. + accountID, region, err := extractAccountAndRegion(remoteAddress) + if err != nil { + return ECRCredential{}, err + } + + // Simulate returning mock credentials. + if accountID == "mockAccount" && region == "mockRegion" { + return ECRCredential{ + username: "mockUsername", + password: "mockPassword", + expiry: time.Now().Add(ExpiryDuration * time.Hour), // Set a valid expiry + account: accountID, + region: region, + }, nil + } + + return ECRCredential{}, fmt.Errorf("%w: %s", ErrMockRemoteAddress, remoteAddress) +} + +// Mock method for getting credentials. +func (m *MockECRCredentialsHelper) GetCredentials(urls []string) (sync.CredentialsFile, error) { + ecrCredentials := make(sync.CredentialsFile) + + for _, url := range urls { + ecrCred, err := m.getECRCredentials(url) + if err != nil { + return sync.CredentialsFile{}, err + } + + ecrCredentials[url] = sync.Credentials{ + Username: ecrCred.username, + Password: ecrCred.password, + } + m.credentials[url] = ecrCred + } + + return ecrCredentials, nil +} + +// Mock method for checking if credentials are valid. +func (m *MockECRCredentialsHelper) IsCredentialsValid(remoteAddress string) bool { + if cred, exists := m.credentials[remoteAddress]; exists { + return time.Until(cred.expiry) > time.Duration(ExpiryWindow)*time.Hour + } + + return false +} + +// Mock method for refreshing credentials. +func (m *MockECRCredentialsHelper) RefreshCredentials(remoteAddress string) (sync.Credentials, error) { + ecrCred, err := m.getECRCredentials(remoteAddress) + if err != nil { + return sync.Credentials{}, err + } + + return sync.Credentials{Username: ecrCred.username, Password: ecrCred.password}, nil +} diff --git a/pkg/test/mocks/sync_remote_mock.go b/pkg/test/mocks/sync_remote_mock.go index c22d74fd..3fb5d01c 100644 --- a/pkg/test/mocks/sync_remote_mock.go +++ b/pkg/test/mocks/sync_remote_mock.go @@ -75,3 +75,6 @@ func (remote SyncRemote) GetManifestContent(imageReference types.ImageReference) return nil, "", "", nil } + +func (remote SyncRemote) SetUpstreamAuthConfig(username, password string) { +}