1 package auth
2
3 import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "net/http"
8 "net/url"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/docker/distribution/registry/client"
14 "github.com/docker/distribution/registry/client/auth/challenge"
15 "github.com/docker/distribution/registry/client/transport"
16 )
17
18 var (
19
20
21 ErrNoBasicAuthCredentials = errors.New("no basic auth credentials")
22
23
24
25 ErrNoToken = errors.New("authorization server did not include a token in the response")
26 )
27
28 const defaultClientID = "registry-client"
29
30
31
32 type AuthenticationHandler interface {
33
34 Scheme() string
35
36
37
38
39 AuthorizeRequest(req *http.Request, params map[string]string) error
40 }
41
42
43
44 type CredentialStore interface {
45
46 Basic(*url.URL) (string, string)
47
48
49
50 RefreshToken(*url.URL, string) string
51
52
53
54 SetRefreshToken(realm *url.URL, service, token string)
55 }
56
57
58
59
60
61 func NewAuthorizer(manager challenge.Manager, handlers ...AuthenticationHandler) transport.RequestModifier {
62 return &endpointAuthorizer{
63 challenges: manager,
64 handlers: handlers,
65 }
66 }
67
68 type endpointAuthorizer struct {
69 challenges challenge.Manager
70 handlers []AuthenticationHandler
71 }
72
73 func (ea *endpointAuthorizer) ModifyRequest(req *http.Request) error {
74 pingPath := req.URL.Path
75 if v2Root := strings.Index(req.URL.Path, "/v2/"); v2Root != -1 {
76 pingPath = pingPath[:v2Root+4]
77 } else if v1Root := strings.Index(req.URL.Path, "/v1/"); v1Root != -1 {
78 pingPath = pingPath[:v1Root] + "/v2/"
79 } else {
80 return nil
81 }
82
83 ping := url.URL{
84 Host: req.URL.Host,
85 Scheme: req.URL.Scheme,
86 Path: pingPath,
87 }
88
89 challenges, err := ea.challenges.GetChallenges(ping)
90 if err != nil {
91 return err
92 }
93
94 if len(challenges) > 0 {
95 for _, handler := range ea.handlers {
96 for _, c := range challenges {
97 if c.Scheme != handler.Scheme() {
98 continue
99 }
100 if err := handler.AuthorizeRequest(req, c.Parameters); err != nil {
101 return err
102 }
103 }
104 }
105 }
106
107 return nil
108 }
109
110
111
112
113
114
115 const minimumTokenLifetimeSeconds = 60
116
117
118 type clock interface {
119 Now() time.Time
120 }
121
122 type tokenHandler struct {
123 creds CredentialStore
124 transport http.RoundTripper
125 clock clock
126
127 offlineAccess bool
128 forceOAuth bool
129 clientID string
130 scopes []Scope
131
132 tokenLock sync.Mutex
133 tokenCache string
134 tokenExpiration time.Time
135
136 logger Logger
137 }
138
139
140
141 type Scope interface {
142 String() string
143 }
144
145
146
147 type RepositoryScope struct {
148 Repository string
149 Class string
150 Actions []string
151 }
152
153
154
155 func (rs RepositoryScope) String() string {
156 repoType := "repository"
157
158
159 if rs.Class != "" && rs.Class != "image" {
160 repoType = fmt.Sprintf("%s(%s)", repoType, rs.Class)
161 }
162 return fmt.Sprintf("%s:%s:%s", repoType, rs.Repository, strings.Join(rs.Actions, ","))
163 }
164
165
166
167 type RegistryScope struct {
168 Name string
169 Actions []string
170 }
171
172
173
174 func (rs RegistryScope) String() string {
175 return fmt.Sprintf("registry:%s:%s", rs.Name, strings.Join(rs.Actions, ","))
176 }
177
178
// Logger receives optional debug output from the token handler.
type Logger interface {
	Debugf(format string, args ...interface{})
}

// logDebugf forwards format and args to logger.Debugf. It does nothing when
// logger is a nil interface value.
func logDebugf(logger Logger, format string, args ...interface{}) {
	if logger != nil {
		logger.Debugf(format, args...)
	}
}
189
190
191 type TokenHandlerOptions struct {
192 Transport http.RoundTripper
193 Credentials CredentialStore
194
195 OfflineAccess bool
196 ForceOAuth bool
197 ClientID string
198 Scopes []Scope
199 Logger Logger
200 }
201
202
203 type realClock struct{}
204
205
206 func (realClock) Now() time.Time { return time.Now() }
207
208
209
// NewTokenHandler returns a bearer-token AuthenticationHandler whose tokens
// cover a single repository scope: repository scope with the given actions.
// It is shorthand for NewTokenHandlerWithOptions with only Transport,
// Credentials and Scopes set.
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler {
	repo := RepositoryScope{Repository: scope, Actions: actions}
	opts := TokenHandlerOptions{
		Transport:   transport,
		Credentials: creds,
		Scopes:      []Scope{repo},
	}
	return NewTokenHandlerWithOptions(opts)
}
223
224
225
226 func NewTokenHandlerWithOptions(options TokenHandlerOptions) AuthenticationHandler {
227 handler := &tokenHandler{
228 transport: options.Transport,
229 creds: options.Credentials,
230 offlineAccess: options.OfflineAccess,
231 forceOAuth: options.ForceOAuth,
232 clientID: options.ClientID,
233 scopes: options.Scopes,
234 clock: realClock{},
235 logger: options.Logger,
236 }
237
238 return handler
239 }
240
// client returns a fresh HTTP client for talking to the token server. It uses
// the handler's transport and an overall 15 second timeout per request.
func (th *tokenHandler) client() *http.Client {
	c := &http.Client{Transport: th.transport}
	c.Timeout = 15 * time.Second
	return c
}

// Scheme reports "bearer", the challenge scheme tokenHandler answers.
func (th *tokenHandler) Scheme() string {
	return "bearer"
}
251
// AuthorizeRequest obtains a bearer token for the challenge described by
// params and sets it as the request's Authorization header.
//
// When the request URL has a non-empty "from" query parameter (presumably the
// source repository of a cross-repository operation), pull access on that
// repository is requested in addition to the handler's configured scopes.
func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
	var extra []string
	if from := req.URL.Query().Get("from"); from != "" {
		pull := RepositoryScope{Repository: from, Actions: []string{"pull"}}
		extra = []string{pull.String()}
	}

	token, err := th.getToken(params, extra...)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+token)
	return nil
}
270
271 func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) {
272 th.tokenLock.Lock()
273 defer th.tokenLock.Unlock()
274 scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
275 for _, scope := range th.scopes {
276 scopes = append(scopes, scope.String())
277 }
278 var addedScopes bool
279 for _, scope := range additionalScopes {
280 if hasScope(scopes, scope) {
281 continue
282 }
283 scopes = append(scopes, scope)
284 addedScopes = true
285 }
286
287 now := th.clock.Now()
288 if now.After(th.tokenExpiration) || addedScopes {
289 token, expiration, err := th.fetchToken(params, scopes)
290 if err != nil {
291 return "", err
292 }
293
294
295 if !addedScopes {
296 th.tokenCache = token
297 th.tokenExpiration = expiration
298 }
299
300 return token, nil
301 }
302
303 return th.tokenCache, nil
304 }
305
// hasScope reports whether scope occurs in scopes. The comparison is an exact,
// case-sensitive string match.
func hasScope(scopes []string, scope string) bool {
	for i := range scopes {
		if scopes[i] == scope {
			return true
		}
	}
	return false
}
314
315 type postTokenResponse struct {
316 AccessToken string `json:"access_token"`
317 RefreshToken string `json:"refresh_token"`
318 ExpiresIn int `json:"expires_in"`
319 IssuedAt time.Time `json:"issued_at"`
320 Scope string `json:"scope"`
321 }
322
// fetchTokenWithOAuth requests a token from realm using the OAuth2 flow: a
// form-encoded POST (via http.Client.PostForm) carrying the scopes, service
// and client_id.
//
// The grant is chosen as follows:
//   - refreshToken, when it is non-empty;
//   - otherwise the username and password from the credential store, with
//     access_type=offline so the server may also issue a refresh token;
//   - otherwise an error is returned.
//
// A new refresh token in the response is saved to the credential store.
//
// It returns the access token and its expiry, which is issued_at plus
// expires_in. The local clock is used when issued_at is absent, and
// expires_in is raised to at least minimumTokenLifetimeSeconds. ErrNoToken is
// returned when the response has no access token, as in the GET flow.
func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
	form := url.Values{}
	form.Set("scope", strings.Join(scopes, " "))
	form.Set("service", service)

	clientID := th.clientID
	if clientID == "" {
		clientID = defaultClientID
	}
	form.Set("client_id", clientID)

	switch {
	case refreshToken != "":
		form.Set("grant_type", "refresh_token")
		form.Set("refresh_token", refreshToken)
	case th.creds != nil:
		form.Set("grant_type", "password")
		username, password := th.creds.Basic(realm)
		form.Set("username", username)
		form.Set("password", password)
		// Ask the server to also issue a refresh token.
		form.Set("access_type", "offline")
	default:
		// Refuse to attempt OAuth without a usable grant.
		return "", time.Time{}, errors.New("no supported grant type")
	}

	resp, err := th.client().PostForm(realm.String(), form)
	if err != nil {
		return "", time.Time{}, err
	}
	defer resp.Body.Close()

	if !client.SuccessStatus(resp.StatusCode) {
		return "", time.Time{}, client.HandleErrorResponse(resp)
	}

	var tr postTokenResponse
	if err = json.NewDecoder(resp.Body).Decode(&tr); err != nil {
		return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
	}

	// Persist a newly issued refresh token. th.creds can be nil here when a
	// caller passed refreshToken directly, so guard it as the GET flow does
	// instead of panicking.
	if tr.RefreshToken != "" && tr.RefreshToken != refreshToken && th.creds != nil {
		th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
	}

	// An empty access token would otherwise be cached and sent as "Bearer ".
	if tr.AccessToken == "" {
		return "", time.Time{}, ErrNoToken
	}

	if tr.ExpiresIn < minimumTokenLifetimeSeconds {
		tr.ExpiresIn = minimumTokenLifetimeSeconds
		logDebugf(th.logger, "Increasing token expiration to: %d seconds", tr.ExpiresIn)
	}

	if tr.IssuedAt.IsZero() {
		// issued_at is optional in the response; fall back to the local clock.
		tr.IssuedAt = th.clock.Now().UTC()
	}

	return tr.AccessToken, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}
386
387 type getTokenResponse struct {
388 Token string `json:"token"`
389 AccessToken string `json:"access_token"`
390 ExpiresIn int `json:"expires_in"`
391 IssuedAt time.Time `json:"issued_at"`
392 RefreshToken string `json:"refresh_token"`
393 }
394
// fetchTokenWithBasicAuth requests a token from realm with a GET request. It
// passes service (when non-empty) and every scope as query parameters.
//
// When the credential store yields both a username and a password, the
// request uses HTTP basic auth and also sends the username as "account".
// With offlineAccess set, it adds offline_token=true and a client_id.
// A refresh token in the response is saved to the credential store.
//
// It returns the token and its expiry, which is issued_at (or the local clock
// when absent) plus expires_in. expires_in is raised to at least
// minimumTokenLifetimeSeconds. ErrNoToken is returned when the response has
// neither "access_token" nor "token".
func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
	req, err := http.NewRequest("GET", realm.String(), nil)
	if err != nil {
		return "", time.Time{}, err
	}

	query := req.URL.Query()
	if service != "" {
		query.Add("service", service)
	}
	for _, s := range scopes {
		query.Add("scope", s)
	}
	if th.offlineAccess {
		id := th.clientID
		if id == "" {
			id = defaultClientID
		}
		query.Add("offline_token", "true")
		query.Add("client_id", id)
	}
	if th.creds != nil {
		if user, pass := th.creds.Basic(realm); user != "" && pass != "" {
			query.Add("account", user)
			req.SetBasicAuth(user, pass)
		}
	}
	req.URL.RawQuery = query.Encode()

	resp, err := th.client().Do(req)
	if err != nil {
		return "", time.Time{}, err
	}
	defer resp.Body.Close()

	if !client.SuccessStatus(resp.StatusCode) {
		return "", time.Time{}, client.HandleErrorResponse(resp)
	}

	var tr getTokenResponse
	if err = json.NewDecoder(resp.Body).Decode(&tr); err != nil {
		return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
	}

	if th.creds != nil && tr.RefreshToken != "" {
		th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
	}

	// access_token, when present, takes precedence over token.
	tok := tr.Token
	if tr.AccessToken != "" {
		tok = tr.AccessToken
	}
	if tok == "" {
		return "", time.Time{}, ErrNoToken
	}

	lifetime := tr.ExpiresIn
	if lifetime < minimumTokenLifetimeSeconds {
		lifetime = minimumTokenLifetimeSeconds
		logDebugf(th.logger, "Increasing token expiration to: %d seconds", lifetime)
	}

	issued := tr.IssuedAt
	if issued.IsZero() {
		issued = th.clock.Now().UTC()
	}

	return tok, issued.Add(time.Duration(lifetime) * time.Second), nil
}
477
478 func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
479 realm, ok := params["realm"]
480 if !ok {
481 return "", time.Time{}, errors.New("no realm specified for token auth challenge")
482 }
483
484
485 realmURL, err := url.Parse(realm)
486 if err != nil {
487 return "", time.Time{}, fmt.Errorf("invalid token auth challenge realm: %s", err)
488 }
489
490 service := params["service"]
491
492 var refreshToken string
493
494 if th.creds != nil {
495 refreshToken = th.creds.RefreshToken(realmURL, service)
496 }
497
498 if refreshToken != "" || th.forceOAuth {
499 return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
500 }
501
502 return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
503 }
504
505 type basicHandler struct {
506 creds CredentialStore
507 }
508
509
510
511 func NewBasicHandler(creds CredentialStore) AuthenticationHandler {
512 return &basicHandler{
513 creds: creds,
514 }
515 }
516
517 func (*basicHandler) Scheme() string {
518 return "basic"
519 }
520
521 func (bh *basicHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
522 if bh.creds != nil {
523 username, password := bh.creds.Basic(req.URL)
524 if username != "" && password != "" {
525 req.SetBasicAuth(username, password)
526 return nil
527 }
528 }
529 return ErrNoBasicAuthCredentials
530 }
531
View as plain text