...

Source file src/k8s.io/client-go/transport/transport_test.go

Documentation: k8s.io/client-go/transport

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package transport
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"net/http"
    26  	"testing"
    27  )
    28  
    29  const (
    30  	rootCACert = `-----BEGIN CERTIFICATE-----
    31  MIIC4DCCAcqgAwIBAgIBATALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
    32  MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIxNTczN1oXDTE2MDExNTIxNTcz
    33  OFowIzEhMB8GA1UEAwwYMTAuMTMuMTI5LjEwNkAxNDIxMzU5MDU4MIIBIjANBgkq
    34  hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAunDRXGwsiYWGFDlWH6kjGun+PshDGeZX
    35  xtx9lUnL8pIRWH3wX6f13PO9sktaOWW0T0mlo6k2bMlSLlSZgG9H6og0W6gLS3vq
    36  s4VavZ6DbXIwemZG2vbRwsvR+t4G6Nbwelm6F8RFnA1Fwt428pavmNQ/wgYzo+T1
    37  1eS+HiN4ACnSoDSx3QRWcgBkB1g6VReofVjx63i0J+w8Q/41L9GUuLqquFxu6ZnH
    38  60vTB55lHgFiDLjA1FkEz2dGvGh/wtnFlRvjaPC54JH2K1mPYAUXTreoeJtLJKX0
    39  ycoiyB24+zGCniUmgIsmQWRPaOPircexCp1BOeze82BT1LCZNTVaxQIDAQABoyMw
    40  ITAOBgNVHQ8BAf8EBAMCAKQwDwYDVR0TAQH/BAUwAwEB/zALBgkqhkiG9w0BAQsD
    41  ggEBADMxsUuAFlsYDpF4fRCzXXwrhbtj4oQwcHpbu+rnOPHCZupiafzZpDu+rw4x
    42  YGPnCb594bRTQn4pAu3Ac18NbLD5pV3uioAkv8oPkgr8aUhXqiv7KdDiaWm6sbAL
    43  EHiXVBBAFvQws10HMqMoKtO8f1XDNAUkWduakR/U6yMgvOPwS7xl0eUTqyRB6zGb
    44  K55q2dejiFWaFqB/y78txzvz6UlOZKE44g2JAVoJVM6kGaxh33q8/FmrL4kuN3ut
    45  W+MmJCVDvd4eEqPwbp7146ZWTqpIJ8lvA6wuChtqV8lhAPka2hD/LMqY8iXNmfXD
    46  uml0obOEy+ON91k+SWTJ3ggmF/U=
    47  -----END CERTIFICATE-----`
    48  
    49  	certData = `-----BEGIN CERTIFICATE-----
    50  MIIC6jCCAdSgAwIBAgIBCzALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
    51  MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIyMDEzMVoXDTE2MDExNTIyMDEz
    52  MlowGzEZMBcGA1UEAxMQb3BlbnNoaWZ0LWNsaWVudDCCASIwDQYJKoZIhvcNAQEB
    53  BQADggEPADCCAQoCggEBAKtdhz0+uCLXw5cSYns9rU/XifFSpb/x24WDdrm72S/v
    54  b9BPYsAStiP148buylr1SOuNi8sTAZmlVDDIpIVwMLff+o2rKYDicn9fjbrTxTOj
    55  lI4pHJBH+JU3AJ0tbajupioh70jwFS0oYpwtneg2zcnE2Z4l6mhrj2okrc5Q1/X2
    56  I2HChtIU4JYTisObtin10QKJX01CLfYXJLa8upWzKZ4/GOcHG+eAV3jXWoXidtjb
    57  1Usw70amoTZ6mIVCkiu1QwCoa8+ycojGfZhvqMsAp1536ZcCul+Na+AbCv4zKS7F
    58  kQQaImVrXdUiFansIoofGlw/JNuoKK6ssVpS5Ic3pgcCAwEAAaM1MDMwDgYDVR0P
    59  AQH/BAQDAgCgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAwGA1UdEwEB/wQCMAAwCwYJ
    60  KoZIhvcNAQELA4IBAQCKLREH7bXtXtZ+8vI6cjD7W3QikiArGqbl36bAhhWsJLp/
    61  p/ndKz39iFNaiZ3GlwIURWOOKx3y3GA0x9m8FR+Llthf0EQ8sUjnwaknWs0Y6DQ3
    62  jjPFZOpV3KPCFrdMJ3++E3MgwFC/Ih/N2ebFX9EcV9Vcc6oVWMdwT0fsrhu683rq
    63  6GSR/3iVX1G/pmOiuaR0fNUaCyCfYrnI4zHBDgSfnlm3vIvN2lrsR/DQBakNL8DJ
    64  HBgKxMGeUPoneBv+c8DMXIL0EhaFXRlBv9QW45/GiAIOuyFJ0i6hCtGZpJjq4OpQ
    65  BRjCI+izPzFTjsxD4aORE+WOkyWFCGPWKfNejfw0
    66  -----END CERTIFICATE-----`
    67  
    68  	keyData = `-----BEGIN RSA PRIVATE KEY-----
    69  MIIEowIBAAKCAQEAq12HPT64ItfDlxJiez2tT9eJ8VKlv/HbhYN2ubvZL+9v0E9i
    70  wBK2I/Xjxu7KWvVI642LyxMBmaVUMMikhXAwt9/6jaspgOJyf1+NutPFM6OUjikc
    71  kEf4lTcAnS1tqO6mKiHvSPAVLShinC2d6DbNycTZniXqaGuPaiStzlDX9fYjYcKG
    72  0hTglhOKw5u2KfXRAolfTUIt9hcktry6lbMpnj8Y5wcb54BXeNdaheJ22NvVSzDv
    73  RqahNnqYhUKSK7VDAKhrz7JyiMZ9mG+oywCnXnfplwK6X41r4BsK/jMpLsWRBBoi
    74  ZWtd1SIVqewiih8aXD8k26gorqyxWlLkhzemBwIDAQABAoIBAD2XYRs3JrGHQUpU
    75  FkdbVKZkvrSY0vAZOqBTLuH0zUv4UATb8487anGkWBjRDLQCgxH+jucPTrztekQK
    76  aW94clo0S3aNtV4YhbSYIHWs1a0It0UdK6ID7CmdWkAj6s0T8W8lQT7C46mWYVLm
    77  5mFnCTHi6aB42jZrqmEpC7sivWwuU0xqj3Ml8kkxQCGmyc9JjmCB4OrFFC8NNt6M
    78  ObvQkUI6Z3nO4phTbpxkE1/9dT0MmPIF7GhHVzJMS+EyyRYUDllZ0wvVSOM3qZT0
    79  JMUaBerkNwm9foKJ1+dv2nMKZZbJajv7suUDCfU44mVeaEO+4kmTKSGCGjjTBGkr
    80  7L1ySDECgYEA5ElIMhpdBzIivCuBIH8LlUeuzd93pqssO1G2Xg0jHtfM4tz7fyeI
    81  cr90dc8gpli24dkSxzLeg3Tn3wIj/Bu64m2TpZPZEIlukYvgdgArmRIPQVxerYey
    82  OkrfTNkxU1HXsYjLCdGcGXs5lmb+K/kuTcFxaMOs7jZi7La+jEONwf8CgYEAwCs/
    83  rUOOA0klDsWWisbivOiNPII79c9McZCNBqncCBfMUoiGe8uWDEO4TFHN60vFuVk9
    84  8PkwpCfvaBUX+ajvbafIfHxsnfk1M04WLGCeqQ/ym5Q4sQoQOcC1b1y9qc/xEWfg
    85  nIUuia0ukYRpl7qQa3tNg+BNFyjypW8zukUAC/kCgYB1/Kojuxx5q5/oQVPrx73k
    86  2bevD+B3c+DYh9MJqSCNwFtUpYIWpggPxoQan4LwdsmO0PKzocb/ilyNFj4i/vII
    87  NToqSc/WjDFpaDIKyuu9oWfhECye45NqLWhb/6VOuu4QA/Nsj7luMhIBehnEAHW+
    88  GkzTKM8oD1PxpEG3nPKXYQKBgQC6AuMPRt3XBl1NkCrpSBy/uObFlFaP2Enpf39S
    89  3OZ0Gv0XQrnSaL1kP8TMcz68rMrGX8DaWYsgytstR4W+jyy7WvZwsUu+GjTJ5aMG
    90  77uEcEBpIi9CBzivfn7hPccE8ZgqPf+n4i6q66yxBJflW5xhvafJqDtW2LcPNbW/
    91  bvzdmQKBgExALRUXpq+5dbmkdXBHtvXdRDZ6rVmrnjy4nI5bPw+1GqQqk6uAR6B/
    92  F6NmLCQOO4PDG/cuatNHIr2FrwTmGdEL6ObLUGWn9Oer9gJhHVqqsY5I4sEPo4XX
    93  stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
    94  -----END RSA PRIVATE KEY-----`
    95  )
    96  
    97  func TestNew(t *testing.T) {
    98  	globalGetCert := &GetCertHolder{
    99  		GetCert: func() (*tls.Certificate, error) { return nil, nil },
   100  	}
   101  	globalDial := &DialHolder{
   102  		Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
   103  	}
   104  
   105  	testCases := map[string]struct {
   106  		Config       *Config
   107  		Err          bool
   108  		TLS          bool
   109  		TLSCert      bool
   110  		TLSErr       bool
   111  		Default      bool
   112  		Insecure     bool
   113  		DefaultRoots bool
   114  	}{
   115  		"default transport": {
   116  			Default: true,
   117  			Config:  &Config{},
   118  		},
   119  
   120  		"insecure": {
   121  			TLS:          true,
   122  			Insecure:     true,
   123  			DefaultRoots: true,
   124  			Config: &Config{TLS: TLSConfig{
   125  				Insecure: true,
   126  			}},
   127  		},
   128  
   129  		"server name": {
   130  			TLS:          true,
   131  			DefaultRoots: true,
   132  			Config: &Config{TLS: TLSConfig{
   133  				ServerName: "foo",
   134  			}},
   135  		},
   136  
   137  		"ca transport": {
   138  			TLS: true,
   139  			Config: &Config{
   140  				TLS: TLSConfig{
   141  					CAData: []byte(rootCACert),
   142  				},
   143  			},
   144  		},
   145  		"bad ca file transport": {
   146  			Err: true,
   147  			Config: &Config{
   148  				TLS: TLSConfig{
   149  					CAFile: "invalid file",
   150  				},
   151  			},
   152  		},
   153  		"bad ca data transport": {
   154  			Err: true,
   155  			Config: &Config{
   156  				TLS: TLSConfig{
   157  					CAData: []byte(rootCACert + "this is not valid"),
   158  				},
   159  			},
   160  		},
   161  		"ca data overriding bad ca file transport": {
   162  			TLS: true,
   163  			Config: &Config{
   164  				TLS: TLSConfig{
   165  					CAData: []byte(rootCACert),
   166  					CAFile: "invalid file",
   167  				},
   168  			},
   169  		},
   170  
   171  		"cert transport": {
   172  			TLS:     true,
   173  			TLSCert: true,
   174  			Config: &Config{
   175  				TLS: TLSConfig{
   176  					CAData:   []byte(rootCACert),
   177  					CertData: []byte(certData),
   178  					KeyData:  []byte(keyData),
   179  				},
   180  			},
   181  		},
   182  		"bad cert data transport": {
   183  			Err: true,
   184  			Config: &Config{
   185  				TLS: TLSConfig{
   186  					CAData:   []byte(rootCACert),
   187  					CertData: []byte(certData),
   188  					KeyData:  []byte("bad key data"),
   189  				},
   190  			},
   191  		},
   192  		"bad file cert transport": {
   193  			Err: true,
   194  			Config: &Config{
   195  				TLS: TLSConfig{
   196  					CAData:   []byte(rootCACert),
   197  					CertData: []byte(certData),
   198  					KeyFile:  "invalid file",
   199  				},
   200  			},
   201  		},
   202  		"key data overriding bad file cert transport": {
   203  			TLS:     true,
   204  			TLSCert: true,
   205  			Config: &Config{
   206  				TLS: TLSConfig{
   207  					CAData:   []byte(rootCACert),
   208  					CertData: []byte(certData),
   209  					KeyData:  []byte(keyData),
   210  					KeyFile:  "invalid file",
   211  				},
   212  			},
   213  		},
   214  		"callback cert and key": {
   215  			TLS:     true,
   216  			TLSCert: true,
   217  			Config: &Config{
   218  				TLS: TLSConfig{
   219  					CAData: []byte(rootCACert),
   220  					GetCertHolder: &GetCertHolder{
   221  						GetCert: func() (*tls.Certificate, error) {
   222  							crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
   223  							return &crt, err
   224  						},
   225  					},
   226  				},
   227  			},
   228  		},
   229  		"cert callback error": {
   230  			TLS:     true,
   231  			TLSCert: true,
   232  			TLSErr:  true,
   233  			Config: &Config{
   234  				TLS: TLSConfig{
   235  					CAData: []byte(rootCACert),
   236  					GetCertHolder: &GetCertHolder{
   237  						GetCert: func() (*tls.Certificate, error) {
   238  							return nil, errors.New("GetCert failure")
   239  						},
   240  					},
   241  				},
   242  			},
   243  		},
   244  		"cert data overrides empty callback result": {
   245  			TLS:     true,
   246  			TLSCert: true,
   247  			Config: &Config{
   248  				TLS: TLSConfig{
   249  					CAData: []byte(rootCACert),
   250  					GetCertHolder: &GetCertHolder{
   251  						GetCert: func() (*tls.Certificate, error) {
   252  							return nil, nil
   253  						},
   254  					},
   255  					CertData: []byte(certData),
   256  					KeyData:  []byte(keyData),
   257  				},
   258  			},
   259  		},
   260  		"callback returns nothing": {
   261  			TLS:     true,
   262  			TLSCert: true,
   263  			Config: &Config{
   264  				TLS: TLSConfig{
   265  					CAData: []byte(rootCACert),
   266  					GetCertHolder: &GetCertHolder{
   267  						GetCert: func() (*tls.Certificate, error) {
   268  							return nil, nil
   269  						},
   270  					},
   271  				},
   272  			},
   273  		},
   274  		"nil holders": {
   275  			Config: &Config{
   276  				TLS: TLSConfig{
   277  					GetCertHolder: nil,
   278  				},
   279  				DialHolder: nil,
   280  			},
   281  			Err:          false,
   282  			TLS:          false,
   283  			TLSCert:      false,
   284  			TLSErr:       false,
   285  			Default:      true,
   286  			Insecure:     false,
   287  			DefaultRoots: false,
   288  		},
   289  		"non-nil dial holder and nil internal": {
   290  			Config: &Config{
   291  				TLS: TLSConfig{
   292  					GetCertHolder: nil,
   293  				},
   294  				DialHolder: &DialHolder{},
   295  			},
   296  			Err: true,
   297  		},
   298  		"non-nil cert holder and nil internal": {
   299  			Config: &Config{
   300  				TLS: TLSConfig{
   301  					GetCertHolder: &GetCertHolder{},
   302  				},
   303  				DialHolder: nil,
   304  			},
   305  			Err: true,
   306  		},
   307  		"non-nil dial holder+internal": {
   308  			Config: &Config{
   309  				TLS: TLSConfig{
   310  					GetCertHolder: nil,
   311  				},
   312  				DialHolder: &DialHolder{
   313  					Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
   314  				},
   315  			},
   316  			Err:          false,
   317  			TLS:          true,
   318  			TLSCert:      false,
   319  			TLSErr:       false,
   320  			Default:      false,
   321  			Insecure:     false,
   322  			DefaultRoots: true,
   323  		},
   324  		"non-nil cert holder+internal": {
   325  			Config: &Config{
   326  				TLS: TLSConfig{
   327  					GetCertHolder: &GetCertHolder{
   328  						GetCert: func() (*tls.Certificate, error) { return nil, nil },
   329  					},
   330  				},
   331  				DialHolder: nil,
   332  			},
   333  			Err:          false,
   334  			TLS:          true,
   335  			TLSCert:      true,
   336  			TLSErr:       false,
   337  			Default:      false,
   338  			Insecure:     false,
   339  			DefaultRoots: true,
   340  		},
   341  		"non-nil holders+internal with global address": {
   342  			Config: &Config{
   343  				TLS: TLSConfig{
   344  					GetCertHolder: globalGetCert,
   345  				},
   346  				DialHolder: globalDial,
   347  			},
   348  			Err:          false,
   349  			TLS:          true,
   350  			TLSCert:      true,
   351  			TLSErr:       false,
   352  			Default:      false,
   353  			Insecure:     false,
   354  			DefaultRoots: true,
   355  		},
   356  	}
   357  	for k, testCase := range testCases {
   358  		t.Run(k, func(t *testing.T) {
   359  			rt, err := New(testCase.Config)
   360  			switch {
   361  			case testCase.Err && err == nil:
   362  				t.Fatal("unexpected non-error")
   363  			case !testCase.Err && err != nil:
   364  				t.Fatalf("unexpected error: %v", err)
   365  			}
   366  			if testCase.Err {
   367  				return
   368  			}
   369  
   370  			switch {
   371  			case testCase.Default && rt != http.DefaultTransport:
   372  				t.Fatalf("got %#v, expected the default transport", rt)
   373  			case !testCase.Default && rt == http.DefaultTransport:
   374  				t.Fatalf("got %#v, expected non-default transport", rt)
   375  			}
   376  
   377  			// We only know how to check TLSConfig on http.Transports
   378  			transport := rt.(*http.Transport)
   379  			switch {
   380  			case testCase.TLS && transport.TLSClientConfig == nil:
   381  				t.Fatalf("got %#v, expected TLSClientConfig", transport)
   382  			case !testCase.TLS && transport.TLSClientConfig != nil:
   383  				t.Fatalf("got %#v, expected no TLSClientConfig", transport)
   384  			}
   385  			if !testCase.TLS {
   386  				return
   387  			}
   388  
   389  			switch {
   390  			case testCase.DefaultRoots && transport.TLSClientConfig.RootCAs != nil:
   391  				t.Fatalf("got %#v, expected nil root CAs", transport.TLSClientConfig.RootCAs)
   392  			case !testCase.DefaultRoots && transport.TLSClientConfig.RootCAs == nil:
   393  				t.Fatalf("got %#v, expected non-nil root CAs", transport.TLSClientConfig.RootCAs)
   394  			}
   395  
   396  			switch {
   397  			case testCase.Insecure != transport.TLSClientConfig.InsecureSkipVerify:
   398  				t.Fatalf("got %#v, expected %#v", transport.TLSClientConfig.InsecureSkipVerify, testCase.Insecure)
   399  			}
   400  
   401  			switch {
   402  			case testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate == nil:
   403  				t.Fatalf("got %#v, expected TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
   404  			case !testCase.TLSCert && transport.TLSClientConfig.GetClientCertificate != nil:
   405  				t.Fatalf("got %#v, expected no TLSClientConfig.GetClientCertificate", transport.TLSClientConfig)
   406  			}
   407  			if !testCase.TLSCert {
   408  				return
   409  			}
   410  
   411  			_, err = transport.TLSClientConfig.GetClientCertificate(nil)
   412  			switch {
   413  			case testCase.TLSErr && err == nil:
   414  				t.Error("got nil error from GetClientCertificate, expected non-nil")
   415  			case !testCase.TLSErr && err != nil:
   416  				t.Errorf("got error from GetClientCertificate: %q, expected nil", err)
   417  			}
   418  		})
   419  	}
   420  }
   421  
   422  type fakeRoundTripper struct {
   423  	Req  *http.Request
   424  	Resp *http.Response
   425  	Err  error
   426  }
   427  
   428  func (rt *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   429  	rt.Req = req
   430  	return rt.Resp, rt.Err
   431  }
   432  
   433  type chainRoundTripper struct {
   434  	rt    http.RoundTripper
   435  	value string
   436  }
   437  
   438  func testChain(value string) WrapperFunc {
   439  	return func(rt http.RoundTripper) http.RoundTripper {
   440  		return &chainRoundTripper{rt: rt, value: value}
   441  	}
   442  }
   443  
   444  func (rt *chainRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   445  	resp, err := rt.rt.RoundTrip(req)
   446  	if resp != nil {
   447  		if resp.Header == nil {
   448  			resp.Header = make(http.Header)
   449  		}
   450  		resp.Header.Set("Value", resp.Header.Get("Value")+rt.value)
   451  	}
   452  	return resp, err
   453  }
   454  
   455  func TestWrappers(t *testing.T) {
   456  	resp1 := &http.Response{}
   457  	wrapperResp1 := func(rt http.RoundTripper) http.RoundTripper {
   458  		return &fakeRoundTripper{Resp: resp1}
   459  	}
   460  	resp2 := &http.Response{}
   461  	wrapperResp2 := func(rt http.RoundTripper) http.RoundTripper {
   462  		return &fakeRoundTripper{Resp: resp2}
   463  	}
   464  
   465  	tests := []struct {
   466  		name    string
   467  		fns     []WrapperFunc
   468  		wantNil bool
   469  		want    func(*http.Response) bool
   470  	}{
   471  		{fns: []WrapperFunc{}, wantNil: true},
   472  		{fns: []WrapperFunc{nil, nil}, wantNil: true},
   473  		{fns: []WrapperFunc{nil}, wantNil: false},
   474  
   475  		{fns: []WrapperFunc{nil, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
   476  		{fns: []WrapperFunc{wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
   477  		{fns: []WrapperFunc{nil, wrapperResp1, nil}, want: func(resp *http.Response) bool { return resp == resp1 }},
   478  		{fns: []WrapperFunc{nil, wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
   479  		{fns: []WrapperFunc{wrapperResp1, wrapperResp2}, want: func(resp *http.Response) bool { return resp == resp2 }},
   480  		{fns: []WrapperFunc{wrapperResp2, wrapperResp1}, want: func(resp *http.Response) bool { return resp == resp1 }},
   481  
   482  		{fns: []WrapperFunc{testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "1" }},
   483  		{fns: []WrapperFunc{testChain("1"), testChain("2")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "12" }},
   484  		{fns: []WrapperFunc{testChain("2"), testChain("1")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "21" }},
   485  		{fns: []WrapperFunc{testChain("1"), testChain("2"), testChain("3")}, want: func(resp *http.Response) bool { return resp.Header.Get("Value") == "123" }},
   486  	}
   487  	for _, tt := range tests {
   488  		t.Run(tt.name, func(t *testing.T) {
   489  			got := Wrappers(tt.fns...)
   490  			if got == nil != tt.wantNil {
   491  				t.Errorf("Wrappers() = %v", got)
   492  				return
   493  			}
   494  			if got == nil {
   495  				return
   496  			}
   497  
   498  			rt := &fakeRoundTripper{Resp: &http.Response{}}
   499  			nested := got(rt)
   500  			req := &http.Request{}
   501  			resp, _ := nested.RoundTrip(req)
   502  			if tt.want != nil && !tt.want(resp) {
   503  				t.Errorf("unexpected response: %#v", resp)
   504  			}
   505  		})
   506  	}
   507  }
   508  
   509  func Test_contextCanceller_RoundTrip(t *testing.T) {
   510  	tests := []struct {
   511  		name string
   512  		open bool
   513  		want bool
   514  	}{
   515  		{name: "open context should call nested round tripper", open: true, want: true},
   516  		{name: "closed context should return a known error", open: false, want: false},
   517  	}
   518  	for _, tt := range tests {
   519  		t.Run(tt.name, func(t *testing.T) {
   520  			req := &http.Request{}
   521  			rt := &fakeRoundTripper{Resp: &http.Response{}}
   522  			ctx := context.Background()
   523  			if !tt.open {
   524  				c, fn := context.WithCancel(ctx)
   525  				fn()
   526  				ctx = c
   527  			}
   528  			errTesting := fmt.Errorf("testing")
   529  			b := &contextCanceller{
   530  				rt:  rt,
   531  				ctx: ctx,
   532  				err: errTesting,
   533  			}
   534  			got, err := b.RoundTrip(req)
   535  			if tt.want {
   536  				if err != nil {
   537  					t.Errorf("unexpected error: %v", err)
   538  				}
   539  				if got != rt.Resp {
   540  					t.Errorf("wanted response")
   541  				}
   542  				if req != rt.Req {
   543  					t.Errorf("expect nested call")
   544  				}
   545  			} else {
   546  				if err != errTesting {
   547  					t.Errorf("unexpected error: %v", err)
   548  				}
   549  				if got != nil {
   550  					t.Errorf("wanted no response")
   551  				}
   552  				if rt.Req != nil {
   553  					t.Errorf("want no nested call")
   554  				}
   555  			}
   556  		})
   557  	}
   558  }
   559  

View as plain text