...

Source file src/k8s.io/client-go/transport/token_source_test.go

Documentation: k8s.io/client-go/transport

     1  /*
     2  Copyright 2018 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package transport
    18  
    19  import (
    20  	"fmt"
    21  	"net/http"
    22  	"reflect"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	"golang.org/x/oauth2"
    28  )
    29  
    30  type testTokenSource struct {
    31  	calls int
    32  	tok   *oauth2.Token
    33  	err   error
    34  }
    35  
    36  func (ts *testTokenSource) Token() (*oauth2.Token, error) {
    37  	ts.calls++
    38  	return ts.tok, ts.err
    39  }
    40  
    41  func TestCachingTokenSource(t *testing.T) {
    42  	start := time.Now()
    43  	tokA := &oauth2.Token{
    44  		AccessToken: "a",
    45  		Expiry:      start.Add(10 * time.Minute),
    46  	}
    47  	tokB := &oauth2.Token{
    48  		AccessToken: "b",
    49  		Expiry:      start.Add(20 * time.Minute),
    50  	}
    51  	tests := []struct {
    52  		name string
    53  
    54  		tok   *oauth2.Token
    55  		tsTok *oauth2.Token
    56  		tsErr error
    57  		wait  time.Duration
    58  
    59  		wantTok     *oauth2.Token
    60  		wantErr     bool
    61  		wantTSCalls int
    62  	}{
    63  		{
    64  			name:    "valid token returned from cache",
    65  			tok:     tokA,
    66  			wantTok: tokA,
    67  		},
    68  		{
    69  			name:    "valid token returned from cache 1 minute before scheduled refresh",
    70  			tok:     tokA,
    71  			wait:    8 * time.Minute,
    72  			wantTok: tokA,
    73  		},
    74  		{
    75  			name:        "new token created when cache is empty",
    76  			tsTok:       tokA,
    77  			wantTok:     tokA,
    78  			wantTSCalls: 1,
    79  		},
    80  		{
    81  			name:        "new token created 1 minute after scheduled refresh",
    82  			tok:         tokA,
    83  			tsTok:       tokB,
    84  			wait:        10 * time.Minute,
    85  			wantTok:     tokB,
    86  			wantTSCalls: 1,
    87  		},
    88  		{
    89  			name:        "error on create token returns error",
    90  			tsErr:       fmt.Errorf("error"),
    91  			wantErr:     true,
    92  			wantTSCalls: 1,
    93  		},
    94  	}
    95  	for _, c := range tests {
    96  		t.Run(c.name, func(t *testing.T) {
    97  			tts := &testTokenSource{
    98  				tok: c.tsTok,
    99  				err: c.tsErr,
   100  			}
   101  
   102  			ts := &cachingTokenSource{
   103  				base:   tts,
   104  				tok:    c.tok,
   105  				leeway: 1 * time.Minute,
   106  				now:    func() time.Time { return start.Add(c.wait) },
   107  			}
   108  
   109  			gotTok, gotErr := ts.Token()
   110  			if got, want := gotTok, c.wantTok; !reflect.DeepEqual(got, want) {
   111  				t.Errorf("unexpected token:\n\tgot:\t%#v\n\twant:\t%#v", got, want)
   112  			}
   113  			if got, want := tts.calls, c.wantTSCalls; got != want {
   114  				t.Errorf("unexpected number of Token() calls: got %d, want %d", got, want)
   115  			}
   116  			if gotErr == nil && c.wantErr {
   117  				t.Errorf("wanted error but got none")
   118  			}
   119  			if gotErr != nil && !c.wantErr {
   120  				t.Errorf("unexpected error: %v", gotErr)
   121  			}
   122  		})
   123  	}
   124  }
   125  
   126  func TestCachingTokenSourceRace(t *testing.T) {
   127  	for i := 0; i < 100; i++ {
   128  		tts := &testTokenSource{
   129  			tok: &oauth2.Token{
   130  				AccessToken: "a",
   131  				Expiry:      time.Now().Add(1000 * time.Hour),
   132  			},
   133  		}
   134  
   135  		ts := &cachingTokenSource{
   136  			now:    time.Now,
   137  			base:   tts,
   138  			leeway: 1 * time.Minute,
   139  		}
   140  
   141  		var wg sync.WaitGroup
   142  		wg.Add(100)
   143  		errc := make(chan error, 100)
   144  
   145  		for i := 0; i < 100; i++ {
   146  			go func() {
   147  				defer wg.Done()
   148  				if _, err := ts.Token(); err != nil {
   149  					errc <- err
   150  				}
   151  			}()
   152  		}
   153  		go func() {
   154  			wg.Wait()
   155  			close(errc)
   156  		}()
   157  		if err, ok := <-errc; ok {
   158  			t.Fatalf("err: %v", err)
   159  		}
   160  		if tts.calls != 1 {
   161  			t.Errorf("expected one call to Token() but saw: %d", tts.calls)
   162  		}
   163  	}
   164  }
   165  
   166  func TestTokenSourceTransportRoundTrip(t *testing.T) {
   167  	goodToken := &oauth2.Token{
   168  		AccessToken: "good",
   169  		Expiry:      time.Now().Add(1000 * time.Hour),
   170  	}
   171  	badToken := &oauth2.Token{
   172  		AccessToken: "bad",
   173  		Expiry:      time.Now().Add(1000 * time.Hour),
   174  	}
   175  	tests := []struct {
   176  		name        string
   177  		header      http.Header
   178  		token       *oauth2.Token
   179  		cachedToken *oauth2.Token
   180  		wantCalls   int
   181  		wantCaching bool
   182  	}{
   183  		{
   184  			name:   "skip oauth rt if has authorization header",
   185  			header: map[string][]string{"Authorization": {"Bearer TOKEN"}},
   186  			token:  goodToken,
   187  		},
   188  		{
   189  			name:        "authorized on newly acquired good token",
   190  			token:       goodToken,
   191  			wantCalls:   1,
   192  			wantCaching: true,
   193  		},
   194  		{
   195  			name:        "authorized on cached good token",
   196  			token:       goodToken,
   197  			cachedToken: goodToken,
   198  			wantCalls:   0,
   199  			wantCaching: true,
   200  		},
   201  		{
   202  			name:        "unauthorized on newly acquired bad token",
   203  			token:       badToken,
   204  			wantCalls:   1,
   205  			wantCaching: true,
   206  		},
   207  		{
   208  			name:        "unauthorized on cached bad token",
   209  			token:       badToken,
   210  			cachedToken: badToken,
   211  			wantCalls:   0,
   212  		},
   213  	}
   214  	for _, test := range tests {
   215  		t.Run(test.name, func(t *testing.T) {
   216  			tts := &testTokenSource{
   217  				tok: test.token,
   218  			}
   219  			cachedTokenSource := NewCachedTokenSource(tts)
   220  			cachedTokenSource.tok = test.cachedToken
   221  
   222  			rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{})
   223  
   224  			rt.RoundTrip(&http.Request{Header: test.header})
   225  			if tts.calls != test.wantCalls {
   226  				t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls)
   227  			}
   228  
   229  			if (cachedTokenSource.tok != nil) != test.wantCaching {
   230  				t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching)
   231  			}
   232  		})
   233  	}
   234  }
   235  
   236  type uncancellableRT struct {
   237  	rt http.RoundTripper
   238  }
   239  
   240  func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error) {
   241  	return urt.rt.RoundTrip(req)
   242  }
   243  
   244  func TestTokenSourceTransportCancelRequest(t *testing.T) {
   245  	tests := []struct {
   246  		name          string
   247  		header        http.Header
   248  		wrapTransport func(http.RoundTripper) http.RoundTripper
   249  		expectCancel  bool
   250  	}{
   251  		{
   252  			name:         "cancel req with bearer token skips oauth rt",
   253  			header:       map[string][]string{"Authorization": {"Bearer TOKEN"}},
   254  			expectCancel: true,
   255  		},
   256  		{
   257  			name: "can't cancel request with rts that doesn't implent unwrap or cancel",
   258  			wrapTransport: func(rt http.RoundTripper) http.RoundTripper {
   259  				return &uncancellableRT{rt: rt}
   260  			},
   261  			expectCancel: false,
   262  		},
   263  	}
   264  	for _, test := range tests {
   265  		t.Run(test.name, func(t *testing.T) {
   266  			baseRecorder := &testTransport{}
   267  
   268  			var base http.RoundTripper = baseRecorder
   269  			if test.wrapTransport != nil {
   270  				base = test.wrapTransport(base)
   271  			}
   272  
   273  			rt := &tokenSourceTransport{
   274  				base: base,
   275  				ort: &oauth2.Transport{
   276  					Base: base,
   277  				},
   278  			}
   279  
   280  			rt.CancelRequest(&http.Request{
   281  				Header: test.header,
   282  			})
   283  
   284  			if baseRecorder.canceled != test.expectCancel {
   285  				t.Errorf("unexpected cancel: got=%v, want=%v", baseRecorder.canceled, test.expectCancel)
   286  			}
   287  		})
   288  	}
   289  }
   290  
   291  type testTransport struct {
   292  	canceled bool
   293  	base     http.RoundTripper
   294  }
   295  
   296  func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   297  	if req.Header["Authorization"][0] == "Bearer bad" {
   298  		return &http.Response{StatusCode: 401}, nil
   299  	}
   300  	return nil, nil
   301  }
   302  
   303  func (rt *testTransport) CancelRequest(req *http.Request) {
   304  	rt.canceled = true
   305  	if rt.base != nil {
   306  		tryCancelRequest(rt.base, req)
   307  	}
   308  }
   309  

View as plain text