...

Source file src/github.com/awslabs/amazon-ecr-credential-helper/ecr-login/api/client.go

Documentation: github.com/awslabs/amazon-ecr-credential-helper/ecr-login/api

     1  // Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"). You may
     4  // not use this file except in compliance with the License. A copy of the
     5  // License is located at
     6  //
     7  //	http://aws.amazon.com/apache2.0/
     8  //
     9  // or in the "license" file accompanying this file. This file is distributed
    10  // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
    11  // express or implied. See the License for the specific language governing
    12  // permissions and limitations under the License.
    13  
    14  package api
    15  
    16  import (
    17  	"context"
    18  	"encoding/base64"
    19  	"fmt"
    20  	"net/url"
    21  	"regexp"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/aws/aws-sdk-go-v2/aws"
    26  	"github.com/aws/aws-sdk-go-v2/service/ecr"
    27  	"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
    28  	"github.com/sirupsen/logrus"
    29  
    30  	"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
    31  )
    32  
const (
	// proxyEndpointScheme is the URL scheme prefix. ExtractRegistry strips it
	// from its input (when present) and prepends it again before parsing.
	proxyEndpointScheme = "https://"
	// programName is the helper's name as it appears in the error that
	// ExtractRegistry returns for hosts that are not ECR registries.
	programName         = "docker-credential-ecr-login"
	// ecrPublicName is the hostname that identifies ECR Public. It is also
	// the key under which public tokens are stored in the credential cache.
	ecrPublicName       = "public.ecr.aws"
	// ecrPublicEndpoint is the proxy endpoint recorded for ECR Public tokens.
	ecrPublicEndpoint   = proxyEndpointScheme + ecrPublicName
)
    39  
// ecrPattern matches private ECR registry hostnames of the form
// <id>.dkr.ecr[-fips].<region>.amazonaws.com[.cn].
//
// Capture groups: 1 is the registry ID, 2 is the optional "-fips" marker,
// 3 is the region, and 4 is the optional ".cn" suffix.
var ecrPattern = regexp.MustCompile(`(^[a-zA-Z0-9][a-zA-Z0-9-_]*)\.dkr\.ecr(-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.amazonaws\.com(\.cn)?$`)
    41  
// Service identifies which registry service a Registry belongs to.
type Service string

const (
	// ServiceECR marks a private registry whose hostname matches ecrPattern.
	ServiceECR       Service = "ecr"
	// ServiceECRPublic marks the ECR Public registry (hostname ecrPublicName).
	ServiceECRPublic Service = "ecr-public"
)
    48  
// Registry describes a registry parsed from a server URL by ExtractRegistry.
//
// For ServiceECRPublic only Service is populated; ID, FIPS and Region keep
// their zero values.
type Registry struct {
	// Service is the registry service the hostname belongs to.
	Service Service
	// ID is the first hostname label (capture group 1 of ecrPattern).
	ID      string
	// FIPS is true when the hostname contains the "-fips" marker.
	FIPS    bool
	// Region is the region label of the hostname (capture group 3).
	Region  string
}
    56  
    57  // ExtractRegistry returns the ECR registry behind a given service endpoint
    58  func ExtractRegistry(input string) (*Registry, error) {
    59  	if strings.HasPrefix(input, proxyEndpointScheme) {
    60  		input = strings.TrimPrefix(input, proxyEndpointScheme)
    61  	}
    62  	serverURL, err := url.Parse(proxyEndpointScheme + input)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	if serverURL.Hostname() == ecrPublicName {
    67  		return &Registry{
    68  			Service: ServiceECRPublic,
    69  		}, nil
    70  	}
    71  	matches := ecrPattern.FindStringSubmatch(serverURL.Hostname())
    72  	if len(matches) == 0 {
    73  		return nil, fmt.Errorf(programName + " can only be used with Amazon Elastic Container Registry.")
    74  	} else if len(matches) < 3 {
    75  		return nil, fmt.Errorf("%q is not a valid repository URI for Amazon Elastic Container Registry.", input)
    76  	}
    77  	return &Registry{
    78  		Service: ServiceECR,
    79  		ID:      matches[1],
    80  		FIPS:    matches[2] == "-fips",
    81  		Region:  matches[3],
    82  	}, nil
    83  }
    84  
// Client used for calling ECR service
type Client interface {
	// GetCredentials returns the credentials for the registry named by
	// serverURL.
	GetCredentials(serverURL string) (*Auth, error)
	// GetCredentialsByRegistryID returns the credentials for a private ECR
	// registry identified by registryID.
	GetCredentialsByRegistryID(registryID string) (*Auth, error)
	// ListCredentials returns all credentials known to the client.
	ListCredentials() ([]*Auth, error)
}
    91  
// Auth credentials returned by ECR service to allow docker login
type Auth struct {
	// ProxyEndpoint is the registry endpoint the credentials apply to.
	ProxyEndpoint string
	// Username is the part of the decoded token before the first ':'.
	Username      string
	// Password is the part of the decoded token after the first ':'.
	Password      string
}
    98  
// defaultClient implements the credential lookups in this file on top of the
// ECR and ECR Public APIs, backed by a credential cache.
type defaultClient struct {
	// ecrClient fetches authorization tokens for private registries.
	ecrClient       ECRAPI
	// ecrPublicClient fetches authorization tokens for ECR Public.
	ecrPublicClient ECRPublicAPI
	// credentialCache stores fetched tokens so they can be reused.
	credentialCache cache.CredentialsCache
}
   104  
// ECRAPI is the single ECR operation this package depends on, expressed as
// an interface so the client can be substituted.
// NOTE(review): presumably satisfied by the SDK's *ecr.Client — confirm.
type ECRAPI interface {
	// GetAuthorizationToken requests authorization data from ECR.
	GetAuthorizationToken(context.Context, *ecr.GetAuthorizationTokenInput, ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error)
}
   108  
// ECRPublicAPI is the single ECR Public operation this package depends on,
// expressed as an interface so the client can be substituted.
// NOTE(review): presumably satisfied by the SDK's *ecrpublic.Client — confirm.
type ECRPublicAPI interface {
	// GetAuthorizationToken requests authorization data from ECR Public.
	GetAuthorizationToken(context.Context, *ecrpublic.GetAuthorizationTokenInput, ...func(*ecrpublic.Options)) (*ecrpublic.GetAuthorizationTokenOutput, error)
}
   112  
   113  // GetCredentials returns username, password, and proxyEndpoint
   114  func (c *defaultClient) GetCredentials(serverURL string) (*Auth, error) {
   115  	registry, err := ExtractRegistry(serverURL)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	logrus.
   120  		WithField("service", registry.Service).
   121  		WithField("registry", registry.ID).
   122  		WithField("region", registry.Region).
   123  		WithField("serverURL", serverURL).
   124  		Debug("Retrieving credentials")
   125  	switch registry.Service {
   126  	case ServiceECR:
   127  		return c.GetCredentialsByRegistryID(registry.ID)
   128  	case ServiceECRPublic:
   129  		return c.GetPublicCredentials()
   130  	}
   131  	return nil, fmt.Errorf("unknown service %q", registry.Service)
   132  }
   133  
   134  // GetCredentialsByRegistryID returns username, password, and proxyEndpoint
   135  func (c *defaultClient) GetCredentialsByRegistryID(registryID string) (*Auth, error) {
   136  	cachedEntry := c.credentialCache.Get(registryID)
   137  	if cachedEntry != nil {
   138  		if cachedEntry.IsValid(time.Now()) {
   139  			logrus.WithField("registry", registryID).Debug("Using cached token")
   140  			return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
   141  		}
   142  		logrus.
   143  			WithField("requestedAt", cachedEntry.RequestedAt).
   144  			WithField("expiresAt", cachedEntry.ExpiresAt).
   145  			Debug("Cached token is no longer valid")
   146  	}
   147  
   148  	auth, err := c.getAuthorizationToken(registryID)
   149  
   150  	// if we have a cached token, fall back to avoid failing the request. This may result an expired token
   151  	// being returned, but if there is a 500 or timeout from the service side, we'd like to attempt to re-use an
   152  	// old token. We invalidate tokens prior to their expiration date to help mitigate this scenario.
   153  	if err != nil && cachedEntry != nil {
   154  		logrus.WithError(err).Info("Got error fetching authorization token. Falling back to cached token.")
   155  		return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
   156  	}
   157  	return auth, err
   158  }
   159  
   160  func (c *defaultClient) GetPublicCredentials() (*Auth, error) {
   161  	cachedEntry := c.credentialCache.GetPublic()
   162  	if cachedEntry != nil {
   163  		if cachedEntry.IsValid(time.Now()) {
   164  			logrus.WithField("registry", ecrPublicName).Debug("Using cached token")
   165  			return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
   166  		}
   167  		logrus.
   168  			WithField("requestedAt", cachedEntry.RequestedAt).
   169  			WithField("expiresAt", cachedEntry.ExpiresAt).
   170  			Debug("Cached token is no longer valid")
   171  	}
   172  
   173  	auth, err := c.getPublicAuthorizationToken()
   174  	// if we have a cached token, fall back to avoid failing the request. This may result an expired token
   175  	// being returned, but if there is a 500 or timeout from the service side, we'd like to attempt to re-use an
   176  	// old token. We invalidate tokens prior to their expiration date to help mitigate this scenario.
   177  	if err != nil && cachedEntry != nil {
   178  		logrus.WithError(err).Info("Got error fetching authorization token. Falling back to cached token.")
   179  		return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
   180  	}
   181  	return auth, err
   182  }
   183  
   184  func (c *defaultClient) ListCredentials() ([]*Auth, error) {
   185  	// prime the cache with default authorization tokens
   186  	_, err := c.GetCredentialsByRegistryID("")
   187  	if err != nil {
   188  		logrus.WithError(err).Debug("couldn't get authorization token for default registry")
   189  	}
   190  	_, err = c.GetPublicCredentials()
   191  	if err != nil {
   192  		logrus.WithError(err).Debug("couldn't get authorization token for public registry")
   193  	}
   194  
   195  	auths := make([]*Auth, 0)
   196  	for _, authEntry := range c.credentialCache.List() {
   197  		auth, err := extractToken(authEntry.AuthorizationToken, authEntry.ProxyEndpoint)
   198  		if err != nil {
   199  			logrus.WithError(err).Debug("Could not extract token")
   200  		} else {
   201  			auths = append(auths, auth)
   202  		}
   203  	}
   204  
   205  	return auths, nil
   206  }
   207  
   208  func (c *defaultClient) getAuthorizationToken(registryID string) (*Auth, error) {
   209  	var input *ecr.GetAuthorizationTokenInput
   210  	if registryID == "" {
   211  		logrus.Debug("Calling ECR.GetAuthorizationToken for default registry")
   212  		input = &ecr.GetAuthorizationTokenInput{}
   213  	} else {
   214  		logrus.WithField("registry", registryID).Debug("Calling ECR.GetAuthorizationToken")
   215  		input = &ecr.GetAuthorizationTokenInput{
   216  			RegistryIds: []string{registryID},
   217  		}
   218  	}
   219  
   220  	output, err := c.ecrClient.GetAuthorizationToken(context.TODO(), input)
   221  	if err != nil || output == nil {
   222  		if err == nil {
   223  			if registryID == "" {
   224  				err = fmt.Errorf("missing AuthorizationData in ECR response for default registry")
   225  			} else {
   226  				err = fmt.Errorf("missing AuthorizationData in ECR response for %s", registryID)
   227  			}
   228  		}
   229  		return nil, fmt.Errorf("ecr: Failed to get authorization token: %w", err)
   230  	}
   231  
   232  	for _, authData := range output.AuthorizationData {
   233  		if authData.ProxyEndpoint != nil && authData.AuthorizationToken != nil {
   234  			authEntry := cache.AuthEntry{
   235  				AuthorizationToken: aws.ToString(authData.AuthorizationToken),
   236  				RequestedAt:        time.Now(),
   237  				ExpiresAt:          aws.ToTime(authData.ExpiresAt),
   238  				ProxyEndpoint:      aws.ToString(authData.ProxyEndpoint),
   239  				Service:            cache.ServiceECR,
   240  			}
   241  			registry, err := ExtractRegistry(authEntry.ProxyEndpoint)
   242  			if err != nil {
   243  				return nil, fmt.Errorf("Invalid ProxyEndpoint returned by ECR: %s", authEntry.ProxyEndpoint)
   244  			}
   245  			auth, err := extractToken(authEntry.AuthorizationToken, authEntry.ProxyEndpoint)
   246  			if err != nil {
   247  				return nil, err
   248  			}
   249  			c.credentialCache.Set(registry.ID, &authEntry)
   250  			return auth, nil
   251  		}
   252  	}
   253  	if registryID == "" {
   254  		return nil, fmt.Errorf("No AuthorizationToken found for default registry")
   255  	}
   256  	return nil, fmt.Errorf("No AuthorizationToken found for %s", registryID)
   257  }
   258  
   259  func (c *defaultClient) getPublicAuthorizationToken() (*Auth, error) {
   260  	var input *ecrpublic.GetAuthorizationTokenInput
   261  
   262  	output, err := c.ecrPublicClient.GetAuthorizationToken(context.TODO(), input)
   263  	if err != nil {
   264  		return nil, fmt.Errorf("ecr: failed to get authorization token: %w", err)
   265  	}
   266  	if output == nil || output.AuthorizationData == nil {
   267  		return nil, fmt.Errorf("ecr: missing AuthorizationData in ECR Public response")
   268  	}
   269  	authData := output.AuthorizationData
   270  	token, err := extractToken(aws.ToString(authData.AuthorizationToken), ecrPublicEndpoint)
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	authEntry := cache.AuthEntry{
   275  		AuthorizationToken: aws.ToString(authData.AuthorizationToken),
   276  		RequestedAt:        time.Now(),
   277  		ExpiresAt:          aws.ToTime(authData.ExpiresAt),
   278  		ProxyEndpoint:      ecrPublicEndpoint,
   279  		Service:            cache.ServiceECRPublic,
   280  	}
   281  	c.credentialCache.Set(ecrPublicName, &authEntry)
   282  	return token, nil
   283  }
   284  
   285  func extractToken(token string, proxyEndpoint string) (*Auth, error) {
   286  	decodedToken, err := base64.StdEncoding.DecodeString(token)
   287  	if err != nil {
   288  		return nil, fmt.Errorf("invalid token: %w", err)
   289  	}
   290  
   291  	parts := strings.SplitN(string(decodedToken), ":", 2)
   292  	if len(parts) < 2 {
   293  		return nil, fmt.Errorf("invalid token: expected two parts, got %d", len(parts))
   294  	}
   295  
   296  	return &Auth{
   297  		Username:      parts[0],
   298  		Password:      parts[1],
   299  		ProxyEndpoint: proxyEndpoint,
   300  	}, nil
   301  }
   302  

View as plain text