1
2
3
4
5
6
7
8
9
10
11
12
13
14 package api
15
16 import (
17 "context"
18 "encoding/base64"
19 "fmt"
20 "net/url"
21 "regexp"
22 "strings"
23 "time"
24
25 "github.com/aws/aws-sdk-go-v2/aws"
26 "github.com/aws/aws-sdk-go-v2/service/ecr"
27 "github.com/aws/aws-sdk-go-v2/service/ecrpublic"
28 "github.com/sirupsen/logrus"
29
30 "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
31 )
32
// Endpoint names and the program name used when parsing registry URLs and
// building error messages.
const (
	// proxyEndpointScheme is prepended to inputs before URL parsing and
	// stripped from inputs that already carry it.
	proxyEndpointScheme = "https://"
	// programName is the program name reported in error messages.
	programName = "docker-credential-ecr-login"
	// ecrPublicName is the hostname of the Amazon ECR Public registry.
	ecrPublicName = "public.ecr.aws"
	// ecrPublicEndpoint is the proxy endpoint recorded for ECR Public
	// credentials.
	ecrPublicEndpoint = proxyEndpointScheme + ecrPublicName
)

// ecrPattern matches private Amazon ECR registry hostnames of the form
// <registry-id>.dkr.ecr[-fips].<region>.amazonaws.com[.cn].
// Capture groups: 1 = registry ID, 2 = "-fips" or empty, 3 = region,
// 4 = ".cn" or empty.
var ecrPattern = regexp.MustCompile(`(^[a-zA-Z0-9][a-zA-Z0-9-_]*)\.dkr\.ecr(-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.amazonaws\.com(\.cn)?$`)
41
// Service identifies which Amazon ECR service a registry belongs to.
type Service string

const (
	// ServiceECR is a private Amazon ECR registry.
	ServiceECR Service = "ecr"
	// ServiceECRPublic is the Amazon ECR Public registry (public.ecr.aws).
	ServiceECRPublic Service = "ecr-public"
)
48
49
// Registry describes an Amazon ECR registry identified from a server URL by
// ExtractRegistry.
type Registry struct {
	// Service is the ECR service hosting the registry.
	Service Service
	// ID is the registry ID taken from the hostname; empty for ECR Public.
	ID string
	// FIPS reports whether the hostname uses the "-fips" endpoint variant.
	FIPS bool
	// Region is the region taken from the hostname; empty for ECR Public.
	Region string
}
56
57
58 func ExtractRegistry(input string) (*Registry, error) {
59 if strings.HasPrefix(input, proxyEndpointScheme) {
60 input = strings.TrimPrefix(input, proxyEndpointScheme)
61 }
62 serverURL, err := url.Parse(proxyEndpointScheme + input)
63 if err != nil {
64 return nil, err
65 }
66 if serverURL.Hostname() == ecrPublicName {
67 return &Registry{
68 Service: ServiceECRPublic,
69 }, nil
70 }
71 matches := ecrPattern.FindStringSubmatch(serverURL.Hostname())
72 if len(matches) == 0 {
73 return nil, fmt.Errorf(programName + " can only be used with Amazon Elastic Container Registry.")
74 } else if len(matches) < 3 {
75 return nil, fmt.Errorf("%q is not a valid repository URI for Amazon Elastic Container Registry.", input)
76 }
77 return &Registry{
78 Service: ServiceECR,
79 ID: matches[1],
80 FIPS: matches[2] == "-fips",
81 Region: matches[3],
82 }, nil
83 }
84
85
// Client retrieves Docker registry credentials for Amazon ECR registries.
type Client interface {
	// GetCredentials returns credentials for the registry named by serverURL.
	GetCredentials(serverURL string) (*Auth, error)
	// GetCredentialsByRegistryID returns credentials for the private ECR
	// registry with the given registry ID.
	GetCredentialsByRegistryID(registryID string) (*Auth, error)
	// ListCredentials returns the credentials currently known to the client.
	ListCredentials() ([]*Auth, error)
}
91
92
// Auth holds decoded registry credentials for one registry endpoint.
type Auth struct {
	// ProxyEndpoint is the registry endpoint the credentials apply to.
	ProxyEndpoint string
	// Username and Password are the two halves of the decoded
	// "username:password" authorization token.
	Username string
	Password string
}
98
// defaultClient implements Client on top of the ECR and ECR Public
// GetAuthorizationToken APIs, storing fetched tokens in credentialCache.
type defaultClient struct {
	ecrClient       ECRAPI
	ecrPublicClient ECRPublicAPI
	credentialCache cache.CredentialsCache
}

// ECRAPI is the part of the Amazon ECR client that defaultClient calls.
// Declaring it as an interface lets another implementation be supplied.
type ECRAPI interface {
	GetAuthorizationToken(context.Context, *ecr.GetAuthorizationTokenInput, ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error)
}

// ECRPublicAPI is the part of the Amazon ECR Public client that defaultClient
// calls. Declaring it as an interface lets another implementation be supplied.
type ECRPublicAPI interface {
	GetAuthorizationToken(context.Context, *ecrpublic.GetAuthorizationTokenInput, ...func(*ecrpublic.Options)) (*ecrpublic.GetAuthorizationTokenOutput, error)
}
112
113
114 func (c *defaultClient) GetCredentials(serverURL string) (*Auth, error) {
115 registry, err := ExtractRegistry(serverURL)
116 if err != nil {
117 return nil, err
118 }
119 logrus.
120 WithField("service", registry.Service).
121 WithField("registry", registry.ID).
122 WithField("region", registry.Region).
123 WithField("serverURL", serverURL).
124 Debug("Retrieving credentials")
125 switch registry.Service {
126 case ServiceECR:
127 return c.GetCredentialsByRegistryID(registry.ID)
128 case ServiceECRPublic:
129 return c.GetPublicCredentials()
130 }
131 return nil, fmt.Errorf("unknown service %q", registry.Service)
132 }
133
134
135 func (c *defaultClient) GetCredentialsByRegistryID(registryID string) (*Auth, error) {
136 cachedEntry := c.credentialCache.Get(registryID)
137 if cachedEntry != nil {
138 if cachedEntry.IsValid(time.Now()) {
139 logrus.WithField("registry", registryID).Debug("Using cached token")
140 return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
141 }
142 logrus.
143 WithField("requestedAt", cachedEntry.RequestedAt).
144 WithField("expiresAt", cachedEntry.ExpiresAt).
145 Debug("Cached token is no longer valid")
146 }
147
148 auth, err := c.getAuthorizationToken(registryID)
149
150
151
152
153 if err != nil && cachedEntry != nil {
154 logrus.WithError(err).Info("Got error fetching authorization token. Falling back to cached token.")
155 return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
156 }
157 return auth, err
158 }
159
160 func (c *defaultClient) GetPublicCredentials() (*Auth, error) {
161 cachedEntry := c.credentialCache.GetPublic()
162 if cachedEntry != nil {
163 if cachedEntry.IsValid(time.Now()) {
164 logrus.WithField("registry", ecrPublicName).Debug("Using cached token")
165 return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
166 }
167 logrus.
168 WithField("requestedAt", cachedEntry.RequestedAt).
169 WithField("expiresAt", cachedEntry.ExpiresAt).
170 Debug("Cached token is no longer valid")
171 }
172
173 auth, err := c.getPublicAuthorizationToken()
174
175
176
177 if err != nil && cachedEntry != nil {
178 logrus.WithError(err).Info("Got error fetching authorization token. Falling back to cached token.")
179 return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
180 }
181 return auth, err
182 }
183
184 func (c *defaultClient) ListCredentials() ([]*Auth, error) {
185
186 _, err := c.GetCredentialsByRegistryID("")
187 if err != nil {
188 logrus.WithError(err).Debug("couldn't get authorization token for default registry")
189 }
190 _, err = c.GetPublicCredentials()
191 if err != nil {
192 logrus.WithError(err).Debug("couldn't get authorization token for public registry")
193 }
194
195 auths := make([]*Auth, 0)
196 for _, authEntry := range c.credentialCache.List() {
197 auth, err := extractToken(authEntry.AuthorizationToken, authEntry.ProxyEndpoint)
198 if err != nil {
199 logrus.WithError(err).Debug("Could not extract token")
200 } else {
201 auths = append(auths, auth)
202 }
203 }
204
205 return auths, nil
206 }
207
208 func (c *defaultClient) getAuthorizationToken(registryID string) (*Auth, error) {
209 var input *ecr.GetAuthorizationTokenInput
210 if registryID == "" {
211 logrus.Debug("Calling ECR.GetAuthorizationToken for default registry")
212 input = &ecr.GetAuthorizationTokenInput{}
213 } else {
214 logrus.WithField("registry", registryID).Debug("Calling ECR.GetAuthorizationToken")
215 input = &ecr.GetAuthorizationTokenInput{
216 RegistryIds: []string{registryID},
217 }
218 }
219
220 output, err := c.ecrClient.GetAuthorizationToken(context.TODO(), input)
221 if err != nil || output == nil {
222 if err == nil {
223 if registryID == "" {
224 err = fmt.Errorf("missing AuthorizationData in ECR response for default registry")
225 } else {
226 err = fmt.Errorf("missing AuthorizationData in ECR response for %s", registryID)
227 }
228 }
229 return nil, fmt.Errorf("ecr: Failed to get authorization token: %w", err)
230 }
231
232 for _, authData := range output.AuthorizationData {
233 if authData.ProxyEndpoint != nil && authData.AuthorizationToken != nil {
234 authEntry := cache.AuthEntry{
235 AuthorizationToken: aws.ToString(authData.AuthorizationToken),
236 RequestedAt: time.Now(),
237 ExpiresAt: aws.ToTime(authData.ExpiresAt),
238 ProxyEndpoint: aws.ToString(authData.ProxyEndpoint),
239 Service: cache.ServiceECR,
240 }
241 registry, err := ExtractRegistry(authEntry.ProxyEndpoint)
242 if err != nil {
243 return nil, fmt.Errorf("Invalid ProxyEndpoint returned by ECR: %s", authEntry.ProxyEndpoint)
244 }
245 auth, err := extractToken(authEntry.AuthorizationToken, authEntry.ProxyEndpoint)
246 if err != nil {
247 return nil, err
248 }
249 c.credentialCache.Set(registry.ID, &authEntry)
250 return auth, nil
251 }
252 }
253 if registryID == "" {
254 return nil, fmt.Errorf("No AuthorizationToken found for default registry")
255 }
256 return nil, fmt.Errorf("No AuthorizationToken found for %s", registryID)
257 }
258
259 func (c *defaultClient) getPublicAuthorizationToken() (*Auth, error) {
260 var input *ecrpublic.GetAuthorizationTokenInput
261
262 output, err := c.ecrPublicClient.GetAuthorizationToken(context.TODO(), input)
263 if err != nil {
264 return nil, fmt.Errorf("ecr: failed to get authorization token: %w", err)
265 }
266 if output == nil || output.AuthorizationData == nil {
267 return nil, fmt.Errorf("ecr: missing AuthorizationData in ECR Public response")
268 }
269 authData := output.AuthorizationData
270 token, err := extractToken(aws.ToString(authData.AuthorizationToken), ecrPublicEndpoint)
271 if err != nil {
272 return nil, err
273 }
274 authEntry := cache.AuthEntry{
275 AuthorizationToken: aws.ToString(authData.AuthorizationToken),
276 RequestedAt: time.Now(),
277 ExpiresAt: aws.ToTime(authData.ExpiresAt),
278 ProxyEndpoint: ecrPublicEndpoint,
279 Service: cache.ServiceECRPublic,
280 }
281 c.credentialCache.Set(ecrPublicName, &authEntry)
282 return token, nil
283 }
284
285 func extractToken(token string, proxyEndpoint string) (*Auth, error) {
286 decodedToken, err := base64.StdEncoding.DecodeString(token)
287 if err != nil {
288 return nil, fmt.Errorf("invalid token: %w", err)
289 }
290
291 parts := strings.SplitN(string(decodedToken), ":", 2)
292 if len(parts) < 2 {
293 return nil, fmt.Errorf("invalid token: expected two parts, got %d", len(parts))
294 }
295
296 return &Auth{
297 Username: parts[0],
298 Password: parts[1],
299 ProxyEndpoint: proxyEndpoint,
300 }, nil
301 }
302
View as plain text