1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package transport
16
17 import (
18 "context"
19 "errors"
20 "fmt"
21 "io"
22 "net/http"
23 "strings"
24 "time"
25
26 authchallenge "github.com/docker/distribution/registry/client/auth/challenge"
27 "github.com/google/go-containerregistry/pkg/logs"
28 "github.com/google/go-containerregistry/pkg/name"
29 )
30
31
// fallbackDelay is how long pingParallel waits on its first scheme before it
// also starts an attempt on the second one.
//
// NOTE(review): declared as a variable rather than a constant, presumably so
// it can be overridden (e.g. by tests) — confirm before relying on that.
var fallbackDelay = 300 * time.Millisecond
33
// Challenge summarizes how a registry answered a ping of its /v2/ endpoint.
type Challenge struct {
	// Scheme is the scheme of the selected authentication challenge. It is
	// empty when the registry answered 200. When the registry answered 401 and
	// no challenge could be extracted, it holds the raw WWW-Authenticate
	// header value instead.
	Scheme string

	// Parameters holds the parameters attached to the selected challenge.
	// It is nil when no structured challenge was selected.
	Parameters map[string]string

	// Insecure reports whether the ping that produced this Challenge was made
	// over plain "http".
	Insecure bool
}
44
45
46 func Ping(ctx context.Context, reg name.Registry, t http.RoundTripper) (*Challenge, error) {
47
48
49
50 schemes := []string{"https"}
51 if reg.Scheme() == "http" {
52 schemes = append(schemes, "http")
53 }
54 if len(schemes) == 1 {
55 return pingSingle(ctx, reg, t, schemes[0])
56 }
57 return pingParallel(ctx, reg, t, schemes)
58 }
59
60 func pingSingle(ctx context.Context, reg name.Registry, t http.RoundTripper, scheme string) (*Challenge, error) {
61 client := http.Client{Transport: t}
62 url := fmt.Sprintf("%s://%s/v2/", scheme, reg.RegistryStr())
63 req, err := http.NewRequest(http.MethodGet, url, nil)
64 if err != nil {
65 return nil, err
66 }
67 resp, err := client.Do(req.WithContext(ctx))
68 if err != nil {
69 return nil, err
70 }
71 defer func() {
72
73
74 io.Copy(io.Discard, resp.Body)
75 resp.Body.Close()
76 }()
77
78 insecure := scheme == "http"
79
80 switch resp.StatusCode {
81 case http.StatusOK:
82
83 return &Challenge{
84 Insecure: insecure,
85 }, nil
86 case http.StatusUnauthorized:
87 if challenges := authchallenge.ResponseChallenges(resp); len(challenges) != 0 {
88
89 wac := pickFromMultipleChallenges(challenges)
90 return &Challenge{
91 Scheme: wac.Scheme,
92 Parameters: wac.Parameters,
93 Insecure: insecure,
94 }, nil
95 }
96
97 return &Challenge{
98 Scheme: resp.Header.Get("WWW-Authenticate"),
99 Insecure: insecure,
100 }, nil
101 default:
102 return nil, CheckError(resp, http.StatusOK, http.StatusUnauthorized)
103 }
104 }
105
106
107 func pingParallel(ctx context.Context, reg name.Registry, t http.RoundTripper, schemes []string) (*Challenge, error) {
108 returned := make(chan struct{})
109 defer close(returned)
110
111 type pingResult struct {
112 *Challenge
113 error
114 primary bool
115 done bool
116 }
117
118 results := make(chan pingResult)
119
120 startRacer := func(ctx context.Context, scheme string) {
121 pr, err := pingSingle(ctx, reg, t, scheme)
122 select {
123 case results <- pingResult{Challenge: pr, error: err, primary: scheme == "https", done: true}:
124 case <-returned:
125 if pr != nil {
126 logs.Debug.Printf("%s lost race", scheme)
127 }
128 }
129 }
130
131 var primary, fallback pingResult
132
133 primaryCtx, primaryCancel := context.WithCancel(ctx)
134 defer primaryCancel()
135 go startRacer(primaryCtx, schemes[0])
136
137 fallbackTimer := time.NewTimer(fallbackDelay)
138 defer fallbackTimer.Stop()
139
140 for {
141 select {
142 case <-fallbackTimer.C:
143 fallbackCtx, fallbackCancel := context.WithCancel(ctx)
144 defer fallbackCancel()
145 go startRacer(fallbackCtx, schemes[1])
146
147 case res := <-results:
148 if res.error == nil {
149 return res.Challenge, nil
150 }
151 if res.primary {
152 primary = res
153 } else {
154 fallback = res
155 }
156 if primary.done && fallback.done {
157 return nil, multierrs{primary.error, fallback.error}
158 }
159 if res.primary && fallbackTimer.Stop() {
160
161
162 fallbackTimer.Reset(0)
163 }
164 }
165 }
166 }
167
168 func pickFromMultipleChallenges(challenges []authchallenge.Challenge) authchallenge.Challenge {
169
170
171
172 allowedSchemes := []string{"basic", "bearer"}
173
174 for _, wac := range challenges {
175 currentScheme := strings.ToLower(wac.Scheme)
176 for _, allowed := range allowedSchemes {
177 if allowed == currentScheme {
178 return wac
179 }
180 }
181 }
182
183 return challenges[0]
184 }
185
// multierrs bundles several errors into a single error value while keeping
// each of them reachable through errors.Is and errors.As.
type multierrs []error

// Error joins the messages of the contained errors with "; ", in slice order.
// Nil entries are skipped, since calling Error on a nil error would panic.
// It returns "" when there is no non-nil entry.
func (m multierrs) Error() string {
	msgs := make([]string, 0, len(m))
	for _, err := range m {
		if err == nil {
			continue
		}
		msgs = append(msgs, err.Error())
	}
	return strings.Join(msgs, "; ")
}

// As reports whether errors.As succeeds for any contained error, trying them
// in slice order. On success target is set by the first entry that matches.
func (m multierrs) As(target any) bool {
	for _, err := range m {
		if errors.As(err, target) {
			return true
		}
	}
	return false
}

// Is reports whether errors.Is(err, target) holds for any contained error.
func (m multierrs) Is(target error) bool {
	for _, err := range m {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
218
View as plain text