1
16
17 package transport
18
19 import (
20 "context"
21 "crypto/tls"
22 "errors"
23 "fmt"
24 "net"
25 "net/http"
26 "testing"
27 )
28
// PEM fixtures used by TestNew. They are only parsed (loaded as CA data, or
// turned into a key pair by tls.X509KeyPair); no TLS handshake is performed
// with them in these tests, so the certificates' expiry should not matter.
const (
	// rootCACert is a PEM-encoded CA certificate, supplied as TLSConfig.CAData.
	// NOTE(review): subject and issuer appear identical (self-signed) — confirm
	// with `openssl x509 -noout -subject -issuer` if that ever matters.
	rootCACert = `-----BEGIN CERTIFICATE-----
MIIC4DCCAcqgAwIBAgIBATALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIxNTczN1oXDTE2MDExNTIxNTcz
OFowIzEhMB8GA1UEAwwYMTAuMTMuMTI5LjEwNkAxNDIxMzU5MDU4MIIBIjANBgkq
hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAunDRXGwsiYWGFDlWH6kjGun+PshDGeZX
xtx9lUnL8pIRWH3wX6f13PO9sktaOWW0T0mlo6k2bMlSLlSZgG9H6og0W6gLS3vq
s4VavZ6DbXIwemZG2vbRwsvR+t4G6Nbwelm6F8RFnA1Fwt428pavmNQ/wgYzo+T1
1eS+HiN4ACnSoDSx3QRWcgBkB1g6VReofVjx63i0J+w8Q/41L9GUuLqquFxu6ZnH
60vTB55lHgFiDLjA1FkEz2dGvGh/wtnFlRvjaPC54JH2K1mPYAUXTreoeJtLJKX0
ycoiyB24+zGCniUmgIsmQWRPaOPircexCp1BOeze82BT1LCZNTVaxQIDAQABoyMw
ITAOBgNVHQ8BAf8EBAMCAKQwDwYDVR0TAQH/BAUwAwEB/zALBgkqhkiG9w0BAQsD
ggEBADMxsUuAFlsYDpF4fRCzXXwrhbtj4oQwcHpbu+rnOPHCZupiafzZpDu+rw4x
YGPnCb594bRTQn4pAu3Ac18NbLD5pV3uioAkv8oPkgr8aUhXqiv7KdDiaWm6sbAL
EHiXVBBAFvQws10HMqMoKtO8f1XDNAUkWduakR/U6yMgvOPwS7xl0eUTqyRB6zGb
K55q2dejiFWaFqB/y78txzvz6UlOZKE44g2JAVoJVM6kGaxh33q8/FmrL4kuN3ut
W+MmJCVDvd4eEqPwbp7146ZWTqpIJ8lvA6wuChtqV8lhAPka2hD/LMqY8iXNmfXD
uml0obOEy+ON91k+SWTJ3ggmF/U=
-----END CERTIFICATE-----`

	// certData is a PEM-encoded client certificate. Together with keyData it
	// must form a valid key pair: the "cert transport" case and the
	// tls.X509KeyPair call in the "callback cert and key" case rely on that.
	certData = `-----BEGIN CERTIFICATE-----
MIIC6jCCAdSgAwIBAgIBCzALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIyMDEzMVoXDTE2MDExNTIyMDEz
MlowGzEZMBcGA1UEAxMQb3BlbnNoaWZ0LWNsaWVudDCCASIwDQYJKoZIhvcNAQEB
BQADggEPADCCAQoCggEBAKtdhz0+uCLXw5cSYns9rU/XifFSpb/x24WDdrm72S/v
b9BPYsAStiP148buylr1SOuNi8sTAZmlVDDIpIVwMLff+o2rKYDicn9fjbrTxTOj
lI4pHJBH+JU3AJ0tbajupioh70jwFS0oYpwtneg2zcnE2Z4l6mhrj2okrc5Q1/X2
I2HChtIU4JYTisObtin10QKJX01CLfYXJLa8upWzKZ4/GOcHG+eAV3jXWoXidtjb
1Usw70amoTZ6mIVCkiu1QwCoa8+ycojGfZhvqMsAp1536ZcCul+Na+AbCv4zKS7F
kQQaImVrXdUiFansIoofGlw/JNuoKK6ssVpS5Ic3pgcCAwEAAaM1MDMwDgYDVR0P
AQH/BAQDAgCgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAwGA1UdEwEB/wQCMAAwCwYJ
KoZIhvcNAQELA4IBAQCKLREH7bXtXtZ+8vI6cjD7W3QikiArGqbl36bAhhWsJLp/
p/ndKz39iFNaiZ3GlwIURWOOKx3y3GA0x9m8FR+Llthf0EQ8sUjnwaknWs0Y6DQ3
jjPFZOpV3KPCFrdMJ3++E3MgwFC/Ih/N2ebFX9EcV9Vcc6oVWMdwT0fsrhu683rq
6GSR/3iVX1G/pmOiuaR0fNUaCyCfYrnI4zHBDgSfnlm3vIvN2lrsR/DQBakNL8DJ
HBgKxMGeUPoneBv+c8DMXIL0EhaFXRlBv9QW45/GiAIOuyFJ0i6hCtGZpJjq4OpQ
BRjCI+izPzFTjsxD4aORE+WOkyWFCGPWKfNejfw0
-----END CERTIFICATE-----`

	// keyData is the PEM-encoded RSA private key matching certData.
	keyData = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAq12HPT64ItfDlxJiez2tT9eJ8VKlv/HbhYN2ubvZL+9v0E9i
wBK2I/Xjxu7KWvVI642LyxMBmaVUMMikhXAwt9/6jaspgOJyf1+NutPFM6OUjikc
kEf4lTcAnS1tqO6mKiHvSPAVLShinC2d6DbNycTZniXqaGuPaiStzlDX9fYjYcKG
0hTglhOKw5u2KfXRAolfTUIt9hcktry6lbMpnj8Y5wcb54BXeNdaheJ22NvVSzDv
RqahNnqYhUKSK7VDAKhrz7JyiMZ9mG+oywCnXnfplwK6X41r4BsK/jMpLsWRBBoi
ZWtd1SIVqewiih8aXD8k26gorqyxWlLkhzemBwIDAQABAoIBAD2XYRs3JrGHQUpU
FkdbVKZkvrSY0vAZOqBTLuH0zUv4UATb8487anGkWBjRDLQCgxH+jucPTrztekQK
aW94clo0S3aNtV4YhbSYIHWs1a0It0UdK6ID7CmdWkAj6s0T8W8lQT7C46mWYVLm
5mFnCTHi6aB42jZrqmEpC7sivWwuU0xqj3Ml8kkxQCGmyc9JjmCB4OrFFC8NNt6M
ObvQkUI6Z3nO4phTbpxkE1/9dT0MmPIF7GhHVzJMS+EyyRYUDllZ0wvVSOM3qZT0
JMUaBerkNwm9foKJ1+dv2nMKZZbJajv7suUDCfU44mVeaEO+4kmTKSGCGjjTBGkr
7L1ySDECgYEA5ElIMhpdBzIivCuBIH8LlUeuzd93pqssO1G2Xg0jHtfM4tz7fyeI
cr90dc8gpli24dkSxzLeg3Tn3wIj/Bu64m2TpZPZEIlukYvgdgArmRIPQVxerYey
OkrfTNkxU1HXsYjLCdGcGXs5lmb+K/kuTcFxaMOs7jZi7La+jEONwf8CgYEAwCs/
rUOOA0klDsWWisbivOiNPII79c9McZCNBqncCBfMUoiGe8uWDEO4TFHN60vFuVk9
8PkwpCfvaBUX+ajvbafIfHxsnfk1M04WLGCeqQ/ym5Q4sQoQOcC1b1y9qc/xEWfg
nIUuia0ukYRpl7qQa3tNg+BNFyjypW8zukUAC/kCgYB1/Kojuxx5q5/oQVPrx73k
2bevD+B3c+DYh9MJqSCNwFtUpYIWpggPxoQan4LwdsmO0PKzocb/ilyNFj4i/vII
NToqSc/WjDFpaDIKyuu9oWfhECye45NqLWhb/6VOuu4QA/Nsj7luMhIBehnEAHW+
GkzTKM8oD1PxpEG3nPKXYQKBgQC6AuMPRt3XBl1NkCrpSBy/uObFlFaP2Enpf39S
3OZ0Gv0XQrnSaL1kP8TMcz68rMrGX8DaWYsgytstR4W+jyy7WvZwsUu+GjTJ5aMG
77uEcEBpIi9CBzivfn7hPccE8ZgqPf+n4i6q66yxBJflW5xhvafJqDtW2LcPNbW/
bvzdmQKBgExALRUXpq+5dbmkdXBHtvXdRDZ6rVmrnjy4nI5bPw+1GqQqk6uAR6B/
F6NmLCQOO4PDG/cuatNHIr2FrwTmGdEL6ObLUGWn9Oer9gJhHVqqsY5I4sEPo4XX
stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
-----END RSA PRIVATE KEY-----`
)
96
97 func TestNew(t *testing.T) {
98 globalGetCert := &GetCertHolder{
99 GetCert: func() (*tls.Certificate, error) { return nil, nil },
100 }
101 globalDial := &DialHolder{
102 Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
103 }
104
105 testCases := map[string]struct {
106 Config *Config
107 Err bool
108 TLS bool
109 TLSCert bool
110 TLSErr bool
111 Default bool
112 Insecure bool
113 DefaultRoots bool
114 }{
115 "default transport": {
116 Default: true,
117 Config: &Config{},
118 },
119
120 "insecure": {
121 TLS: true,
122 Insecure: true,
123 DefaultRoots: true,
124 Config: &Config{TLS: TLSConfig{
125 Insecure: true,
126 }},
127 },
128
129 "server name": {
130 TLS: true,
131 DefaultRoots: true,
132 Config: &Config{TLS: TLSConfig{
133 ServerName: "foo",
134 }},
135 },
136
137 "ca transport": {
138 TLS: true,
139 Config: &Config{
140 TLS: TLSConfig{
141 CAData: []byte(rootCACert),
142 },
143 },
144 },
145 "bad ca file transport": {
146 Err: true,
147 Config: &Config{
148 TLS: TLSConfig{
149 CAFile: "invalid file",
150 },
151 },
152 },
153 "bad ca data transport": {
154 Err: true,
155 Config: &Config{
156 TLS: TLSConfig{
157 CAData: []byte(rootCACert + "this is not valid"),
158 },
159 },
160 },
161 "ca data overriding bad ca file transport": {
162 TLS: true,
163 Config: &Config{
164 TLS: TLSConfig{
165 CAData: []byte(rootCACert),
166 CAFile: "invalid file",
167 },
168 },
169 },
170
171 "cert transport": {
172 TLS: true,
173 TLSCert: true,
174 Config: &Config{
175 TLS: TLSConfig{
176 CAData: []byte(rootCACert),
177 CertData: []byte(certData),
178 KeyData: []byte(keyData),
179 },
180 },
181 },
182 "bad cert data transport": {
183 Err: true,
184 Config: &Config{
185 TLS: TLSConfig{
186 CAData: []byte(rootCACert),
187 CertData: []byte(certData),
188 KeyData: []byte("bad key data"),
189 },
190 },
191 },
192 "bad file cert transport": {
193 Err: true,
194 Config: &Config{
195 TLS: TLSConfig{
196 CAData: []byte(rootCACert),
197 CertData: []byte(certData),
198 KeyFile: "invalid file",
199 },
200 },
201 },
202 "key data overriding bad file cert transport": {
203 TLS: true,
204 TLSCert: true,
205 Config: &Config{
206 TLS: TLSConfig{
207 CAData: []byte(rootCACert),
208 CertData: []byte(certData),
209 KeyData: []byte(keyData),
210 KeyFile: "invalid file",
211 },
212 },
213 },
214 "callback cert and key": {
215 TLS: true,
216 TLSCert: true,
217 Config: &Config{
218 TLS: TLSConfig{
219 CAData: []byte(rootCACert),
220 GetCertHolder: &GetCertHolder{
221 GetCert: func() (*tls.Certificate, error) {
222 crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
223 return &crt, err
224 },
225 },
226 },
227 },
228 },
229 "cert callback error": {
230 TLS: true,
231 TLSCert: true,
232 TLSErr: true,
233 Config: &Config{
234 TLS: TLSConfig{
235 CAData: []byte(rootCACert),
236 GetCertHolder: &GetCertHolder{
237 GetCert: func() (*tls.Certificate, error) {
238 return nil, errors.New("GetCert failure")
239 },
240 },
241 },
242 },
243 },
244 "cert data overrides empty callback result": {
245 TLS: true,
246 TLSCert: true,
247 Config: &Config{
248 TLS: TLSConfig{
249 CAData: []byte(rootCACert),
250 GetCertHolder: &GetCertHolder{
251 GetCert: func() (*tls.Certificate, error) {
252 return nil, nil
253 },
254 },
255 CertData: []byte(certData),
256 KeyData: []byte(keyData),
257 },
258 },
259 },
260 "callback returns nothing": {
261 TLS: true,
262 TLSCert: true,
263 Config: &Config{
264 TLS: TLSConfig{
265 CAData: []byte(rootCACert),
266 GetCertHolder: &GetCertHolder{
267 GetCert: func() (*tls.Certificate, error) {
268 return nil, nil
269 },
270 },
271 },
272 },
273 },
274 "nil holders": {
275 Config: &Config{
276 TLS: TLSConfig{
277 GetCertHolder: nil,
278 },
279 DialHolder: nil,
280 },
281 Err: false,
282 TLS: false,
283 TLSCert: false,
284 TLSErr: false,
285 Default: true,
286 Insecure: false,
287 DefaultRoots: false,
288 },
289 "non-nil dial holder and nil internal": {
290 Config: &Config{
291 TLS: TLSConfig{
292 GetCertHolder: nil,
293 },
294 DialHolder: &DialHolder{},
295 },
296 Err: true,
297 },
298 "non-nil cert holder and nil internal": {
299 Config: &Config{
300 TLS: TLSConfig{
301 GetCertHolder: &GetCertHolder{},
302 },
303 DialHolder: nil,
304 },
305 Err: true,
306 },
307 "non-nil dial holder+internal": {
308 Config: &Config{
309 TLS: TLSConfig{
310 GetCertHolder: nil,
311 },
312 DialHolder: &DialHolder{
313 Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
314 },
315 },
316 Err: false,
317 TLS: true,
318 TLSCert: false,
319 TLSErr: false,
320 Default: false,
321 Insecure: false,
322 DefaultRoots: true,
323 },
324 "non-nil cert holder+internal": {
325 Config: &Config{
326 TLS: TLSConfig{
327 GetCertHolder: &GetCertHolder{
328 GetCert: func() (*tls.Certificate, error) { return nil, nil },
329 },
330 },
331 DialHolder: nil,
332 },
333 Err: false,
334 TLS: true,
335 TLSCert: true,
336 TLSErr: false,
337 Default: false,
338 Insecure: false,
339 DefaultRoots: true,
340 },
341 "non-nil holders+internal with global address": {
342 Config: &Config{
343 TLS: TLSConfig{
344 GetCertHolder: globalGetCert,
345 },
346 DialHolder: globalDial,
347 },
348 Err: false,
349 TLS: true,
350 TLSCert: true,
351 TLSErr: false,
352 Default: false,
353 Insecure: false,
354 DefaultRoots: true,
355 },
356 }
357 for k, testCase := range testCases {
358 t.Run(k, func(t *testing.T) {
359 rt, err := New(testCase.Config)
360 switch {
361 case testCase.Err && err == nil:
362 t.Fatal("unexpected non-error")
363 case !testCase.Err && err != nil:
364 t.Fatalf("unexpected error: %v", err)
365 }
366 if testCase.Err {
367 return
368 }
369
370 switch {
371 case testCase.Default && rt != http.DefaultTransport:
372 t.Fatalf("got %#v, expected the default transport", rt)
373 case !testCase.Default && rt == http.DefaultTransport:
374 t.Fatalf("got %#v, expected non-default transport", rt)
375 }
376
377
378 transport := rt.(*http.Transport)
379 switch {
380 case testCase.TLS && transport.TLSClientConfig == nil:
381 t.Fatalf("got %#v, expected TLSClientConfig", transport)
382 case !testCase.TLS && transport.TLSClientConfig != nil:
383 t.Fatalf("got %#v, expected no TLSClientConfig", transport)
384 }
385 if !testCase.TLS {
386 return
387 }
388
389 switch {
390 case testCase.DefaultRoots && transport.TLSClientConfig.RootCAs != nil:
391 t.Fatalf("got %#v, expected nil root CAs", transport.TLSClientConfig.RootCAs)
392 case !testCase.DefaultRoots && transport.TLSClientConfig.RootCAs == nil:
393 t.Fatalf("got %#v, expected non-nil root CAs", transport.TLSClientConfig.RootCAs)
394 }
395
396 switch {
397 case testCase.Insecure != transport.TLSClientConfig.InsecureSkipVerify:
398 t.Fatalf("got %#v, expected %#v", transport.TLSClientConfig.InsecureSkipVerify, testCase.Insecure)
399 }
400
401 switch {
402 case testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate == nil:
403 t.Fatalf("got %#v, expected TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
404 case !testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate != nil:
405 t.Fatalf("got %#v, expected no TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
406 }
407 if !testCase.TLSCert {
408 return
409 }
410
411 _, err = transport.TLSClientConfig.GetClientCertificate(nil)
412 switch {
413 case testCase.TLSErr && err == nil:
414 t.Error("got nil error from GetClientCertificate, expected non-nil")
415 case !testCase.TLSErr && err != nil:
416 t.Errorf("got error from GetClientCertificate: %q, expected nil", err)
417 }
418 })
419 }
420 }
421
// fakeRoundTripper is a stub http.RoundTripper for tests. It remembers the
// most recent request it was given and answers every call with the canned
// Resp and Err values.
type fakeRoundTripper struct {
	Req  *http.Request  // last request passed to RoundTrip
	Resp *http.Response // response handed back by RoundTrip
	Err  error          // error handed back by RoundTrip
}

// RoundTrip records req in f.Req and returns the preconfigured response and
// error unchanged; it never performs any network I/O.
func (f *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	f.Req = req
	resp, err := f.Resp, f.Err
	return resp, err
}
432
433 type chainRoundTripper struct {
434 rt http.RoundTripper
435 value string
436 }
437
438 func testChain(value string) WrapperFunc {
439 return func(rt http.RoundTripper) http.RoundTripper {
440 return &chainRoundTripper{rt: rt, value: value}
441 }
442 }
443
444 func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
445 resp, err := rt.rt.RoundTrip(req)
446 if resp != nil {
447 if resp.Header == nil {
448 resp.Header = make(http.Header)
449 }
450 resp.Header.Set("Value", resp.Header.Get("Value")+rt.value)
451 }
452 return resp, err
453 }
454
455 func TestWrappers(t *testing.T) {
456 resp1 := &http.Response{}
457 wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper {
458 return &fakeRoundTripper{Resp: resp1}
459 }
460 resp2 := &http.Response{}
461 wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper {
462 return &fakeRoundTripper{Resp: resp2}
463 }
464
465 tests := []struct {
466 name string
467 fns []WrapperFunc
468 wantNil bool
469 want func(*http.Response) bool
470 }{
471 {fns: []WrapperFunc{}, wantNil: true},
472 {fns: []WrapperFunc{nil, nil}, wantNil: true},
473 {fns: []WrapperFunc{nil}, wantNil: false},
474
475 {fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
476 {fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
477 {fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
478 {fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
479 {fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
480 {fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
481
482 {fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }},
483 {fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }},
484 {fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }},
485 {fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }},
486 }
487 for _, tt := range tests {
488 t.Run(tt.name, func(t *testing.T) {
489 got := Wrappers(tt.fns...)
490 if got == nil != tt.wantNil {
491 t.Errorf("Wrappers() = %v", got)
492 return
493 }
494 if got == nil {
495 return
496 }
497
498 rt := &fakeRoundTripper{Resp: &http.Response{}}
499 nested := got(rt)
500 req := &http.Request{}
501 resp, _ := nested.RoundTrip(req)
502 if tt.want != nil && !tt.want(resp) {
503 t.Errorf("unexpected response: %#v", resp)
504 }
505 })
506 }
507 }
508
509 func Test_contextCanceller_RoundTrip(t *testing.T) {
510 tests := []struct {
511 name string
512 open bool
513 want bool
514 }{
515 {name: "open context should call nested round tripper", open: true, want: true},
516 {name: "closed context should return a known error", open: false, want: false},
517 }
518 for _, tt := range tests {
519 t.Run(tt.name, func(t *testing.T) {
520 req := &http.Request{}
521 rt := &fakeRoundTripper{Resp: &http.Response{}}
522 ctx := context.Background()
523 if !tt.open {
524 c, fn := context.WithCancel(ctx)
525 fn()
526 ctx = c
527 }
528 errTesting := fmt.Errorf("testing")
529 b := &contextCanceller{
530 rt: rt,
531 ctx: ctx,
532 err: errTesting,
533 }
534 got, err := b.RoundTrip(req)
535 if tt.want {
536 if err != nil {
537 t.Errorf("unexpected error: %v", err)
538 }
539 if got != rt.Resp {
540 t.Errorf("wanted response")
541 }
542 if req != rt.Req {
543 t.Errorf("expect nested call")
544 }
545 } else {
546 if err != errTesting {
547 t.Errorf("unexpected error: %v", err)
548 }
549 if got != nil {
550 t.Errorf("wanted no response")
551 }
552 if rt.Req != nil {
553 t.Errorf("want no nested call")
554 }
555 }
556 })
557 }
558 }
559
View as plain text