1 package ociauth
2
3 import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "strings"
12 "sync"
13 "time"
14
15 "cuelabs.dev/go/oci/ociregistry/internal/exp/slices"
16 )
17
18
// oauthClientID is the client_id sent to token servers when using the
// OAuth2 refresh-token grant.
const oauthClientID = "cuelabs-ociauth"

// ErrNoAuth reports that no authorization token was available to add to a
// request.
var ErrNoAuth = errors.New("no authorization token available to add to request")
22
23
24
25
26
27
28
// stdTransport is the http.RoundTripper returned by NewStdTransport.
// It keeps authorization state per registry host.
type stdTransport struct {
	config    Config
	transport http.RoundTripper

	// mu guards registries, which maps a request URL host to the
	// authorization state for that host.
	mu         sync.Mutex
	registries map[string]*registry
}
35
// StdTransportParams holds the parameters for NewStdTransport.
type StdTransportParams struct {
	// Config is consulted for the authorization information of each
	// registry host. If it is nil, an empty configuration is used,
	// so no credentials are available from configuration.
	Config Config

	// Transport is used to make all underlying HTTP requests, including
	// requests to token servers. If it is nil, http.DefaultTransport
	// is used.
	Transport http.RoundTripper
}
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64 func NewStdTransport(p StdTransportParams) http.RoundTripper {
65 if p.Config == nil {
66 p.Config = emptyConfig{}
67 }
68 if p.Transport == nil {
69 p.Transport = http.DefaultTransport
70 }
71 return &stdTransport{
72 config: p.Config,
73 transport: p.Transport,
74 registries: make(map[string]*registry),
75 }
76 }
77
78
// registry holds the authorization state for a single registry host.
type registry struct {
	host      string
	transport http.RoundTripper
	config    Config
	initOnce  sync.Once
	initErr   error

	// mu is held by setAuthorization and setAuthorizationFromChallenge
	// while they read and update the fields below. (init populates some
	// of them under initOnce instead.)
	mu sync.Mutex

	// wwwAuthenticate holds the most recent challenge received in a
	// 401 response from this registry, or nil if none has been seen yet.
	wwwAuthenticate *authHeader

	accessTokens []*scopedToken
	refreshToken string
	basic        *userPass
}
99
// scopedToken is a cached access token together with the scope it was
// acquired for and the time at which it should be discarded.
type scopedToken struct {
	// scope holds the scope that the token was acquired for.
	scope Scope

	// token holds the token itself; it is sent as a Bearer token.
	token string

	// expires holds the time after which the token is discarded.
	expires time.Time
}
108
// userPass holds credentials for HTTP basic authentication, used both
// on registry requests and on GET requests to token servers.
type userPass struct {
	username string
	password string
}
113
// forever is the expiry time given to tokens that never expire, such as
// an access token supplied directly by the configuration.
var forever = time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC)
115
116
117 func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {
118
119
120
121 req = req.Clone(req.Context())
122
123
124
125 needBodyClose := true
126 defer func() {
127 if needBodyClose && req.Body != nil {
128 req.Body.Close()
129 }
130 }()
131
132 a.mu.Lock()
133 r := a.registries[req.URL.Host]
134 if r == nil {
135 r = ®istry{
136 host: req.URL.Host,
137 config: a.config,
138 transport: a.transport,
139 }
140 a.registries[r.host] = r
141 }
142 a.mu.Unlock()
143 if err := r.init(); err != nil {
144 return nil, err
145 }
146
147 ctx := req.Context()
148 requiredScope := RequestInfoFromContext(ctx).RequiredScope
149 wantScope := ScopeFromContext(ctx)
150
151 if err := r.setAuthorization(ctx, req, requiredScope, wantScope); err != nil {
152 return nil, err
153 }
154 resp, err := r.transport.RoundTrip(req)
155
156
157
158 needBodyClose = false
159 if err != nil {
160 return nil, err
161 }
162 if resp.StatusCode != http.StatusUnauthorized {
163 return resp, nil
164 }
165 challenge := challengeFromResponse(resp)
166 if challenge == nil {
167 return resp, nil
168 }
169 authAdded, err := r.setAuthorizationFromChallenge(ctx, req, challenge, requiredScope, wantScope)
170 if err != nil {
171 resp.Body.Close()
172 return nil, err
173 }
174 if !authAdded {
175
176 return resp, nil
177 }
178 resp.Body.Close()
179
180 if req.GetBody != nil {
181 req.Body, err = req.GetBody()
182 if err != nil {
183 return nil, err
184 }
185 }
186 return r.transport.RoundTrip(req)
187 }
188
189
190
191 func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requiredScope, wantScope Scope) error {
192 r.mu.Lock()
193 defer r.mu.Unlock()
194
195
196
197 r.deleteExpiredTokens(time.Now().UTC().Add(time.Second))
198
199 if accessToken := r.accessTokenForScope(requiredScope); accessToken != nil {
200
201 req.Header.Set("Authorization", "Bearer "+accessToken.token)
202 return nil
203 }
204 if r.wwwAuthenticate == nil {
205
206
207
208
209 return nil
210 }
211 if r.refreshToken != "" && r.wwwAuthenticate.scheme == "bearer" {
212
213
214
215
216
217
218
219
220 accessToken, err := r.acquireAccessToken(ctx, requiredScope, wantScope)
221 if err != nil {
222 return err
223 }
224 req.Header.Set("Authorization", "Bearer "+accessToken)
225 return nil
226 }
227 if r.wwwAuthenticate.scheme != "bearer" && r.basic != nil {
228 req.SetBasicAuth(r.basic.username, r.basic.password)
229 return nil
230 }
231 return nil
232 }
233
234 func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader, requiredScope, wantScope Scope) (bool, error) {
235 r.mu.Lock()
236 defer r.mu.Unlock()
237 r.wwwAuthenticate = challenge
238
239 switch {
240 case r.wwwAuthenticate.scheme == "bearer":
241 scope := ParseScope(r.wwwAuthenticate.params["scope"])
242 accessToken, err := r.acquireAccessToken(ctx, scope, wantScope.Union(requiredScope))
243 if err != nil {
244 return false, err
245 }
246 req.Header.Set("Authorization", "Bearer "+accessToken)
247 return true, nil
248 case r.basic != nil:
249 req.SetBasicAuth(r.basic.username, r.basic.password)
250 return true, nil
251 }
252 return false, nil
253 }
254
255
256
257
258
259
260
261 func (r *registry) init() error {
262 inner := func() error {
263 info, err := r.config.EntryForRegistry(r.host)
264 if err != nil {
265 return fmt.Errorf("cannot acquire auth info for registry %q: %v", r.host, err)
266 }
267 r.refreshToken = info.RefreshToken
268 if info.AccessToken != "" {
269 r.accessTokens = append(r.accessTokens, &scopedToken{
270 scope: UnlimitedScope(),
271 token: info.AccessToken,
272 expires: forever,
273 })
274 }
275 if info.Username != "" && info.Password != "" {
276 r.basic = &userPass{
277 username: info.Username,
278 password: info.Password,
279 }
280 }
281 return nil
282 }
283 r.initOnce.Do(func() {
284 r.initErr = inner()
285 })
286 return r.initErr
287 }
288
289
290
291
292
293
294
295
296
297
298 func (r *registry) acquireAccessToken(ctx context.Context, requiredScope, wantScope Scope) (string, error) {
299 scope := requiredScope.Union(wantScope)
300 tok, err := r.acquireToken(ctx, scope)
301 if err != nil {
302 var rerr *responseError
303 if !errors.As(err, &rerr) || rerr.statusCode != http.StatusUnauthorized {
304 return "", err
305 }
306
307
308
309
310
311
312
313
314
315
316
317
318 scope = requiredScope
319 tok, err = r.acquireToken(ctx, scope)
320 if err != nil {
321 return "", err
322 }
323
324
325 }
326 if tok.RefreshToken != "" {
327 r.refreshToken = tok.RefreshToken
328 }
329 accessToken := tok.Token
330 if accessToken == "" {
331 accessToken = tok.AccessToken
332 }
333 if accessToken == "" {
334 return "", fmt.Errorf("no access token found in auth server response")
335 }
336 var expires time.Time
337 now := time.Now().UTC()
338 if tok.ExpiresIn == 0 {
339 expires = now.Add(60 * time.Second)
340 } else {
341 expires = now.Add(time.Duration(tok.ExpiresIn) * time.Second)
342 }
343 r.accessTokens = append(r.accessTokens, &scopedToken{
344 scope: scope,
345 token: accessToken,
346 expires: expires,
347 })
348
349
350 return accessToken, nil
351 }
352
353 func (r *registry) acquireToken(ctx context.Context, scope Scope) (*wireToken, error) {
354 realm := r.wwwAuthenticate.params["realm"]
355 if realm == "" {
356 return nil, fmt.Errorf("malformed Www-Authenticate header (missing realm)")
357 }
358 if r.refreshToken != "" {
359 v := url.Values{}
360 v.Set("scope", scope.String())
361 if service := r.wwwAuthenticate.params["service"]; service != "" {
362 v.Set("service", service)
363 }
364 v.Set("client_id", oauthClientID)
365 v.Set("grant_type", "refresh_token")
366 v.Set("refresh_token", r.refreshToken)
367 req, err := http.NewRequestWithContext(ctx, "POST", realm, strings.NewReader(v.Encode()))
368 if err != nil {
369 return nil, fmt.Errorf("cannot form HTTP request to %q: %v", realm, err)
370 }
371 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
372 tok, err := r.doTokenRequest(req)
373 if err == nil {
374 return tok, nil
375 }
376 var rerr *responseError
377 if !errors.As(err, &rerr) || rerr.statusCode != http.StatusNotFound {
378 return tok, err
379 }
380
381
382
383
384
385 }
386 u, err := url.Parse(realm)
387 if err != nil {
388 return nil, fmt.Errorf("malformed Www-Authenticate header (malformed realm %q): %v", realm, err)
389 }
390 v := u.Query()
391
392
393
394 v["scope"] = strings.Split(scope.String(), " ")
395 if service := r.wwwAuthenticate.params["service"]; service != "" {
396
397
398 v.Set("service", service)
399 }
400 u.RawQuery = v.Encode()
401 req, err := http.NewRequest("GET", u.String(), nil)
402 if err != nil {
403 return nil, err
404 }
405
406
407
408 if r.basic != nil {
409 req.SetBasicAuth(r.basic.username, r.basic.password)
410 }
411 return r.doTokenRequest(req)
412 }
413
414
415
416
417
418
// wireToken describes the JSON body of a successful (200 OK) response
// from a token server.
type wireToken struct {
	// Token and AccessToken both carry the access token; servers may use
	// either field. acquireAccessToken uses Token and falls back to
	// AccessToken when Token is empty.
	Token       string `json:"token"`
	AccessToken string `json:"access_token,omitempty"`

	// RefreshToken, when non-empty, replaces the registry's current
	// refresh token.
	RefreshToken string `json:"refresh_token"`

	// ExpiresIn holds the token lifetime in seconds. When it is zero
	// (including absent), the token is given a lifetime of 60 seconds.
	ExpiresIn int `json:"expires_in"`
}
442
443 func (r *registry) doTokenRequest(req *http.Request) (*wireToken, error) {
444 client := &http.Client{
445 Transport: r.transport,
446 }
447 resp, err := client.Do(req)
448 if err != nil {
449 return nil, err
450 }
451 defer resp.Body.Close()
452 if resp.StatusCode != http.StatusOK {
453 return nil, errorFromResponse(resp)
454 }
455 data, err := io.ReadAll(resp.Body)
456 if err != nil {
457 return nil, fmt.Errorf("cannot read response body: %v", err)
458 }
459 var tok wireToken
460 if err := json.Unmarshal(data, &tok); err != nil {
461 return nil, fmt.Errorf("malformed JSON token in response: %v", err)
462 }
463 return &tok, nil
464 }
465
// responseError reports an unexpected HTTP status from a server.
type responseError struct {
	// statusCode holds the HTTP status code of the response.
	statusCode int
	// msg holds a bounded, whitespace-trimmed prefix of the response
	// body. It is empty if the body was absent or empty.
	msg string
}

// errorFromResponse returns a *responseError describing resp. It reads
// at most 512 bytes of the body to use as the error message; the caller
// remains responsible for closing the body.
func errorFromResponse(resp *http.Response) error {
	rerr := &responseError{
		statusCode: resp.StatusCode,
	}
	if resp.Body != nil {
		// Servers usually explain the failure in the body. A read error is
		// deliberately ignored: the status code alone is still a
		// meaningful error.
		data, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		rerr.msg = strings.TrimSpace(string(data))
	}
	return rerr
}

func (e *responseError) Error() string {
	if e.msg == "" {
		return fmt.Sprintf("unexpected HTTP response %d", e.statusCode)
	}
	return fmt.Sprintf("unexpected HTTP response %d: %s", e.statusCode, e.msg)
}
481
482
483
484
485 func (r *registry) deleteExpiredTokens(now time.Time) {
486 r.accessTokens = slices.DeleteFunc(r.accessTokens, func(tok *scopedToken) bool {
487 return now.After(tok.expires)
488 })
489 }
490
491 func (r *registry) accessTokenForScope(scope Scope) *scopedToken {
492 for _, tok := range r.accessTokens {
493 if tok.scope.Contains(scope) {
494
495 return tok
496 }
497 }
498 return nil
499 }
500
501 type emptyConfig struct{}
502
503 func (emptyConfig) EntryForRegistry(host string) (ConfigEntry, error) {
504 return ConfigEntry{}, nil
505 }
506
View as plain text