1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package transport
16
17 import (
18 "context"
19 "net/http"
20 "net/http/httptest"
21 "net/url"
22 "strings"
23 "sync/atomic"
24 "testing"
25 "time"
26
27 "github.com/google/go-containerregistry/pkg/name"
28 )
29
30 var (
31 testRegistry, _ = name.NewRegistry("localhost:8080", name.StrictValidation)
32 )
33
34 func TestPingNoChallenge(t *testing.T) {
35 server := httptest.NewServer(
36 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37 w.WriteHeader(http.StatusOK)
38 }))
39 defer server.Close()
40 tprt := &http.Transport{
41 Proxy: func(req *http.Request) (*url.URL, error) {
42 return url.Parse(server.URL)
43 },
44 }
45
46 pr, err := Ping(context.Background(), testRegistry, tprt)
47 if err != nil {
48 t.Errorf("ping() = %v", err)
49 }
50 if pr.Scheme != "" {
51 t.Errorf("ping(); got %v, want %v", pr.Scheme, "")
52 }
53 if !pr.Insecure {
54 t.Errorf("ping(); got %v, want %v", pr.Insecure, true)
55 }
56 }
57
58 func TestPingBasicChallengeNoParams(t *testing.T) {
59 server := httptest.NewServer(
60 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
61 w.Header().Set("WWW-Authenticate", `BASIC`)
62 http.Error(w, "Unauthorized", http.StatusUnauthorized)
63 }))
64 defer server.Close()
65 tprt := &http.Transport{
66 Proxy: func(req *http.Request) (*url.URL, error) {
67 return url.Parse(server.URL)
68 },
69 }
70
71 pr, err := Ping(context.Background(), testRegistry, tprt)
72 if err != nil {
73 t.Errorf("ping() = %v", err)
74 }
75 if pr.Scheme != "basic" {
76 t.Errorf("ping(); got %v, want %v", pr.Scheme, "basic")
77 }
78 if got, want := len(pr.Parameters), 0; got != want {
79 t.Errorf("ping(); got %v, want %v", got, want)
80 }
81 }
82
83 func TestPingBearerChallengeWithParams(t *testing.T) {
84 server := httptest.NewServer(
85 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
86 w.Header().Set("WWW-Authenticate", `Bearer realm="http://auth.example.com/token"`)
87 http.Error(w, "Unauthorized", http.StatusUnauthorized)
88 }))
89 defer server.Close()
90 tprt := &http.Transport{
91 Proxy: func(req *http.Request) (*url.URL, error) {
92 return url.Parse(server.URL)
93 },
94 }
95
96 pr, err := Ping(context.Background(), testRegistry, tprt)
97 if err != nil {
98 t.Errorf("ping() = %v", err)
99 }
100 if pr.Scheme != "bearer" {
101 t.Errorf("ping(); got %v, want %v", pr.Scheme, "bearer")
102 }
103 if got, want := len(pr.Parameters), 1; got != want {
104 t.Errorf("ping(); got %v, want %v", got, want)
105 }
106 }
107
108 func TestPingMultipleChallenges(t *testing.T) {
109 server := httptest.NewServer(
110 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
111 w.Header().Add("WWW-Authenticate", "Negotiate")
112 w.Header().Add("WWW-Authenticate", `Basic realm="http://auth.example.com/token"`)
113 http.Error(w, "Unauthorized", http.StatusUnauthorized)
114 }))
115 defer server.Close()
116 tprt := &http.Transport{
117 Proxy: func(req *http.Request) (*url.URL, error) {
118 return url.Parse(server.URL)
119 },
120 }
121
122 pr, err := Ping(context.Background(), testRegistry, tprt)
123 if err != nil {
124 t.Errorf("ping() = %v", err)
125 }
126 if pr.Scheme != "basic" {
127 t.Errorf("ping(); got %v, want %v", pr.Scheme, "basic")
128 }
129 if got, want := len(pr.Parameters), 1; got != want {
130 t.Errorf("ping(); got %v, want %v", got, want)
131 }
132 }
133
134 func TestPingMultipleNotSupportedChallenges(t *testing.T) {
135 server := httptest.NewServer(
136 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
137 w.Header().Add("WWW-Authenticate", "Negotiate")
138 w.Header().Add("WWW-Authenticate", "Digest")
139 http.Error(w, "Unauthorized", http.StatusUnauthorized)
140 }))
141 defer server.Close()
142 tprt := &http.Transport{
143 Proxy: func(req *http.Request) (*url.URL, error) {
144 return url.Parse(server.URL)
145 },
146 }
147
148 pr, err := Ping(context.Background(), testRegistry, tprt)
149 if err != nil {
150 t.Errorf("ping() = %v", err)
151 }
152 if pr.Scheme != "negotiate" {
153 t.Errorf("ping(); got %v, want %v", pr.Scheme, "negotiate")
154 }
155 }
156
157 func TestUnsupportedStatus(t *testing.T) {
158 server := httptest.NewServer(
159 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
160 w.Header().Set("WWW-Authenticate", `Bearer realm="http://auth.example.com/token`)
161 http.Error(w, "Forbidden", http.StatusForbidden)
162 }))
163 defer server.Close()
164 tprt := &http.Transport{
165 Proxy: func(req *http.Request) (*url.URL, error) {
166 return url.Parse(server.URL)
167 },
168 }
169
170 pr, err := Ping(context.Background(), testRegistry, tprt)
171 if err == nil {
172 t.Errorf("ping() = %v", pr)
173 }
174 }
175
176 func TestPingHttpFallback(t *testing.T) {
177 tests := []struct {
178 reg name.Registry
179 wantCount int64
180 err string
181 contains []string
182 }{{
183 reg: mustRegistry("gcr.io"),
184 wantCount: 1,
185 err: `Get "https://gcr.io/v2/": http: server gave HTTP response to HTTPS client`,
186 }, {
187 reg: mustRegistry("ko.local"),
188 wantCount: 2,
189 }, {
190 reg: mustInsecureRegistry("us.gcr.io"),
191 wantCount: 0,
192 contains: []string{"https://us.gcr.io/v2/", "http://us.gcr.io/v2/"},
193 }}
194
195 gotCount := int64(0)
196 server := httptest.NewServer(
197 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
198 atomic.AddInt64(&gotCount, 1)
199 if r.URL.Scheme != "http" {
200
201
202 time.Sleep(5 * time.Millisecond)
203 }
204 w.WriteHeader(http.StatusOK)
205 }))
206 defer server.Close()
207
208 tprt := &http.Transport{
209 Proxy: func(req *http.Request) (*url.URL, error) {
210 return url.Parse(server.URL)
211 },
212 }
213
214 fallbackDelay = 2 * time.Millisecond
215
216 for _, test := range tests {
217
218 if strings.Contains(test.reg.String(), "us.gcr.io") {
219 server.Close()
220 }
221
222 _, err := Ping(context.Background(), test.reg, tprt)
223 if got, want := gotCount, test.wantCount; got != want {
224 t.Errorf("%s: got %d requests, wanted %d", test.reg.String(), got, want)
225 }
226 gotCount = 0
227
228 if err == nil {
229 if test.err != "" {
230 t.Error("expected err, got nil")
231 }
232 continue
233 }
234 if len(test.contains) != 0 {
235 for _, c := range test.contains {
236 if !strings.Contains(err.Error(), c) {
237 t.Errorf("expected err to contain %q but did not: %q", c, err)
238 }
239 }
240 } else if got, want := err.Error(), test.err; got != want {
241 t.Errorf("got %q want %q", got, want)
242 }
243 }
244 }
245
246 func mustRegistry(r string) name.Registry {
247 reg, err := name.NewRegistry(r)
248 if err != nil {
249 panic(err)
250 }
251 return reg
252 }
253
254 func mustInsecureRegistry(r string) name.Registry {
255 reg, err := name.NewRegistry(r, name.Insecure)
256 if err != nil {
257 panic(err)
258 }
259 return reg
260 }
261
View as plain text