1
15 package auth
16
17 import (
18 "context"
19 "encoding/base64"
20 "encoding/json"
21 "errors"
22 "fmt"
23 "io"
24 "net/http"
25 "net/url"
26 "strings"
27
28 "oras.land/oras-go/pkg/registry/remote/internal/errutil"
29 )
30
31
32 var DefaultClient = &Client{
33 Header: http.Header{
34 "User-Agent": {"oras-go"},
35 },
36 Cache: DefaultCache,
37 }
38
39
40
41
42
43
44
45
46
var (
	// maxResponseBytes caps how many bytes of a token endpoint response are
	// read while decoding it (128 KiB), bounding memory use for oversized
	// replies.
	maxResponseBytes int64 = 128 << 10

	// defaultClientID is sent as the OAuth2 "client_id" form value when
	// Client.ClientID is empty.
	defaultClientID = "oras-go"
)
52
53
54
// Client is an auth-aware HTTP client. Its Do method sends a request and, on a
// 401 Unauthorized response carrying a Basic or Bearer Www-Authenticate
// challenge, obtains credentials/tokens and retries with an Authorization
// header.
//
// The zero value is usable: every field has a fallback when left unset.
type Client struct {
	// Client is the underlying HTTP client used to send all requests,
	// including requests to token endpoints.
	// If nil, http.DefaultClient is used.
	Client *http.Client

	// Header contains extra headers added to every request sent by this
	// Client, including token endpoint requests. Values are appended to any
	// values the request already has for the same key.
	Header http.Header

	// Credential resolves the credential for a registry host (the request's
	// Host). If nil, EmptyCredential is used, which makes Basic auth fail and
	// bearer token requests go out without Basic authentication
	// (NOTE(review): assumes EmptyCredential has empty username/password).
	Credential func(context.Context, string) (Credential, error)

	// Cache stores authentication schemes and tokens per registry host.
	// If nil, noCache{} is used (presumably disabling caching — confirm
	// against the noCache implementation).
	Cache Cache

	// ClientID is sent as the "client_id" form value when requesting OAuth2
	// tokens. If empty, defaultClientID is used.
	ClientID string

	// ForceAttemptOAuth2 makes bearer token acquisition use the OAuth2 token
	// endpoint (POST; password grant) even when the credential has no refresh
	// token. When false, such credentials use the distribution token flow
	// (GET) instead. It has no effect for EmptyCredential or for credentials
	// that already carry an AccessToken.
	ForceAttemptOAuth2 bool
}
88
89
90
91 func (c *Client) client() *http.Client {
92 if c.Client == nil {
93 return http.DefaultClient
94 }
95 return c.Client
96 }
97
98
99 func (c *Client) send(req *http.Request) (*http.Response, error) {
100 for key, values := range c.Header {
101 req.Header[key] = append(req.Header[key], values...)
102 }
103 return c.client().Do(req)
104 }
105
106
107 func (c *Client) credential(ctx context.Context, reg string) (Credential, error) {
108 if c.Credential == nil {
109 return EmptyCredential, nil
110 }
111 return c.Credential(ctx, reg)
112 }
113
114
115
116 func (c *Client) cache() Cache {
117 if c.Cache == nil {
118 return noCache{}
119 }
120 return c.Cache
121 }
122
123
124 func (c *Client) SetUserAgent(userAgent string) {
125 if c.Header == nil {
126 c.Header = http.Header{}
127 }
128 c.Header.Set("User-Agent", userAgent)
129 }
130
131
132
133
134
135
136 func (c *Client) Do(originalReq *http.Request) (*http.Response, error) {
137 ctx := originalReq.Context()
138 req := originalReq.Clone(ctx)
139
140
141 var attemptedKey string
142 cache := c.cache()
143 registry := originalReq.Host
144 scheme, err := cache.GetScheme(ctx, registry)
145 if err == nil {
146 switch scheme {
147 case SchemeBasic:
148 token, err := cache.GetToken(ctx, registry, SchemeBasic, "")
149 if err == nil {
150 req.Header.Set("Authorization", "Basic "+token)
151 }
152 case SchemeBearer:
153 scopes := GetScopes(ctx)
154 attemptedKey = strings.Join(scopes, " ")
155 token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey)
156 if err == nil {
157 req.Header.Set("Authorization", "Bearer "+token)
158 }
159 }
160 }
161
162 resp, err := c.send(req)
163 if err != nil {
164 return nil, err
165 }
166 if resp.StatusCode != http.StatusUnauthorized {
167 return resp, nil
168 }
169
170
171 challenge := resp.Header.Get("Www-Authenticate")
172 scheme, params := parseChallenge(challenge)
173 switch scheme {
174 case SchemeBasic:
175 resp.Body.Close()
176
177 token, err := cache.Set(ctx, registry, SchemeBasic, "", func(ctx context.Context) (string, error) {
178 return c.fetchBasicAuth(ctx, registry)
179 })
180 if err != nil {
181 return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
182 }
183
184 req = originalReq.Clone(ctx)
185 req.Header.Set("Authorization", "Basic "+token)
186 case SchemeBearer:
187 resp.Body.Close()
188
189
190 scopes := GetScopes(ctx)
191 if scope := params["scope"]; scope != "" {
192 scopes = append(scopes, strings.Split(scope, " ")...)
193 scopes = CleanScopes(scopes)
194 }
195 key := strings.Join(scopes, " ")
196
197
198 if key != attemptedKey {
199 if token, err := cache.GetToken(ctx, registry, SchemeBearer, key); err == nil {
200 req = originalReq.Clone(ctx)
201 req.Header.Set("Authorization", "Bearer "+token)
202
203 resp, err := c.send(req)
204 if err != nil {
205 return nil, err
206 }
207 if resp.StatusCode != http.StatusUnauthorized {
208 return resp, nil
209 }
210 resp.Body.Close()
211 }
212 }
213
214
215 realm := params["realm"]
216 service := params["service"]
217 token, err := cache.Set(ctx, registry, SchemeBearer, key, func(ctx context.Context) (string, error) {
218 return c.fetchBearerToken(ctx, registry, realm, service, scopes)
219 })
220 if err != nil {
221 return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
222 }
223
224 req = originalReq.Clone(ctx)
225 req.Header.Set("Authorization", "Bearer "+token)
226 default:
227 return resp, nil
228 }
229
230 return c.send(req)
231 }
232
233
234 func (c *Client) fetchBasicAuth(ctx context.Context, registry string) (string, error) {
235 cred, err := c.credential(ctx, registry)
236 if err != nil {
237 return "", fmt.Errorf("failed to resolve credential: %w", err)
238 }
239 if cred == EmptyCredential {
240 return "", errors.New("credential required for basic auth")
241 }
242 if cred.Username == "" || cred.Password == "" {
243 return "", errors.New("missing username or password for basic auth")
244 }
245 auth := cred.Username + ":" + cred.Password
246 return base64.StdEncoding.EncodeToString([]byte(auth)), nil
247 }
248
249
250 func (c *Client) fetchBearerToken(ctx context.Context, registry, realm, service string, scopes []string) (string, error) {
251 cred, err := c.credential(ctx, registry)
252 if err != nil {
253 return "", err
254 }
255 if cred.AccessToken != "" {
256 return cred.AccessToken, nil
257 }
258 if cred == EmptyCredential || (cred.RefreshToken == "" && !c.ForceAttemptOAuth2) {
259 return c.fetchDistributionToken(ctx, realm, service, scopes, cred.Username, cred.Password)
260 }
261 return c.fetchOAuth2Token(ctx, realm, service, scopes, cred)
262 }
263
264
265
266
267
268
269
270 func (c *Client) fetchDistributionToken(ctx context.Context, realm, service string, scopes []string, username, password string) (string, error) {
271 req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil)
272 if err != nil {
273 return "", err
274 }
275 if username != "" || password != "" {
276 req.SetBasicAuth(username, password)
277 }
278 q := req.URL.Query()
279 if service != "" {
280 q.Set("service", service)
281 }
282 for _, scope := range scopes {
283 q.Add("scope", scope)
284 }
285 req.URL.RawQuery = q.Encode()
286
287 resp, err := c.send(req)
288 if err != nil {
289 return "", err
290 }
291 defer resp.Body.Close()
292 if resp.StatusCode != http.StatusOK {
293 return "", errutil.ParseErrorResponse(resp)
294 }
295
296
297
298
299 var result struct {
300 Token string `json:"token"`
301 AccessToken string `json:"access_token"`
302 }
303 lr := io.LimitReader(resp.Body, maxResponseBytes)
304 if err := json.NewDecoder(lr).Decode(&result); err != nil {
305 return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
306 }
307 if result.AccessToken != "" {
308 return result.AccessToken, nil
309 }
310 if result.Token != "" {
311 return result.Token, nil
312 }
313 return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
314 }
315
316
317
318 func (c *Client) fetchOAuth2Token(ctx context.Context, realm, service string, scopes []string, cred Credential) (string, error) {
319 form := url.Values{}
320 if cred.RefreshToken != "" {
321 form.Set("grant_type", "refresh_token")
322 form.Set("refresh_token", cred.RefreshToken)
323 } else if cred.Username != "" && cred.Password != "" {
324 form.Set("grant_type", "password")
325 form.Set("username", cred.Username)
326 form.Set("password", cred.Password)
327 } else {
328 return "", errors.New("missing username or password for bearer auth")
329 }
330 form.Set("service", service)
331 clientID := c.ClientID
332 if clientID == "" {
333 clientID = defaultClientID
334 }
335 form.Set("client_id", clientID)
336 if len(scopes) != 0 {
337 form.Set("scope", strings.Join(scopes, " "))
338 }
339 body := strings.NewReader(form.Encode())
340
341 req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, body)
342 if err != nil {
343 return "", err
344 }
345 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
346
347 resp, err := c.send(req)
348 if err != nil {
349 return "", err
350 }
351 defer resp.Body.Close()
352 if resp.StatusCode != http.StatusOK {
353 return "", errutil.ParseErrorResponse(resp)
354 }
355
356 var result struct {
357 AccessToken string `json:"access_token"`
358 }
359 lr := io.LimitReader(resp.Body, maxResponseBytes)
360 if err := json.NewDecoder(lr).Decode(&result); err != nil {
361 return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
362 }
363 if result.AccessToken != "" {
364 return result.AccessToken, nil
365 }
366 return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
367 }
368
View as plain text