...
1
16
17 package transport
18
19 import (
20 "fmt"
21 "net/http"
22 "os"
23 "strings"
24 "sync"
25 "time"
26
27 "golang.org/x/oauth2"
28
29 utilnet "k8s.io/apimachinery/pkg/util/net"
30 "k8s.io/klog/v2"
31 )
32
33
34
35 func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) http.RoundTripper {
36 return func(rt http.RoundTripper) http.RoundTripper {
37 return &tokenSourceTransport{
38 base: rt,
39 ort: &oauth2.Transport{
40 Source: ts,
41 Base: rt,
42 },
43 }
44 }
45 }
46
// ResettableTokenSource is an oauth2.TokenSource whose token can be
// invalidated. Transports built with ResettableTokenSourceWrapTransport call
// ResetTokenOlderThan after a 401 Unauthorized response. They pass the time
// at which the failed request started.
type ResettableTokenSource interface {
	oauth2.TokenSource
	// ResetTokenOlderThan asks the source to discard any token it obtained
	// before the given time.
	ResetTokenOlderThan(time.Time)
}
51
52
53
54 func ResettableTokenSourceWrapTransport(ts ResettableTokenSource) func(http.RoundTripper) http.RoundTripper {
55 return func(rt http.RoundTripper) http.RoundTripper {
56 return &tokenSourceTransport{
57 base: rt,
58 ort: &oauth2.Transport{
59 Source: ts,
60 Base: rt,
61 },
62 src: ts,
63 }
64 }
65 }
66
67
68
69 func NewCachedFileTokenSource(path string) *cachingTokenSource {
70 return &cachingTokenSource{
71 now: time.Now,
72 leeway: 10 * time.Second,
73 base: &fileTokenSource{
74 path: path,
75
76
77
78
79 period: time.Minute,
80 },
81 }
82 }
83
84
85
86 func NewCachedTokenSource(ts oauth2.TokenSource) *cachingTokenSource {
87 return &cachingTokenSource{
88 now: time.Now,
89 base: ts,
90 }
91 }
92
93 type tokenSourceTransport struct {
94 base http.RoundTripper
95 ort http.RoundTripper
96 src ResettableTokenSource
97 }
98
99 var _ utilnet.RoundTripperWrapper = &tokenSourceTransport{}
100
101 func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
102
103 if req.Header.Get("Authorization") != "" {
104 return tst.base.RoundTrip(req)
105 }
106
107
108
109 start := time.Now()
110 resp, err := tst.ort.RoundTrip(req)
111 if err == nil && resp != nil && resp.StatusCode == 401 && tst.src != nil {
112 tst.src.ResetTokenOlderThan(start)
113 }
114 return resp, err
115 }
116
117 func (tst *tokenSourceTransport) CancelRequest(req *http.Request) {
118 if req.Header.Get("Authorization") != "" {
119 tryCancelRequest(tst.base, req)
120 return
121 }
122 tryCancelRequest(tst.ort, req)
123 }
124
125 func (tst *tokenSourceTransport) WrappedRoundTripper() http.RoundTripper { return tst.base }
126
127 type fileTokenSource struct {
128 path string
129 period time.Duration
130 }
131
132 var _ = oauth2.TokenSource(&fileTokenSource{})
133
134 func (ts *fileTokenSource) Token() (*oauth2.Token, error) {
135 tokb, err := os.ReadFile(ts.path)
136 if err != nil {
137 return nil, fmt.Errorf("failed to read token file %q: %v", ts.path, err)
138 }
139 tok := strings.TrimSpace(string(tokb))
140 if len(tok) == 0 {
141 return nil, fmt.Errorf("read empty token from file %q", ts.path)
142 }
143
144 return &oauth2.Token{
145 AccessToken: tok,
146 Expiry: time.Now().Add(ts.period),
147 }, nil
148 }
149
// cachingTokenSource wraps base and reuses the last token it obtained while
// that token's Expiry, minus leeway, is still after the current time.
type cachingTokenSource struct {
	// base supplies a new token when the cached one is missing or stale.
	base oauth2.TokenSource
	// leeway is subtracted from a token's Expiry when judging freshness, so a
	// token is refreshed that long before it actually expires.
	leeway time.Duration

	// The embedded RWMutex guards tok and t.
	sync.RWMutex
	// tok is the cached token. It is nil before the first fetch and after a
	// reset.
	tok *oauth2.Token
	// t is the time, per now, at which tok was fetched. It is zero after a
	// reset.
	t time.Time

	// now returns the current time. It is presumably a field so tests can
	// inject a clock; confirm against callers.
	now func() time.Time
}
161
162 func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
163 now := ts.now()
164
165 ts.RLock()
166 tok := ts.tok
167 ts.RUnlock()
168
169 if tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) {
170 return tok, nil
171 }
172
173
174 ts.Lock()
175 defer ts.Unlock()
176 if tok := ts.tok; tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) {
177 return tok, nil
178 }
179
180 tok, err := ts.base.Token()
181 if err != nil {
182 if ts.tok == nil {
183 return nil, err
184 }
185 klog.Errorf("Unable to rotate token: %v", err)
186 return ts.tok, nil
187 }
188
189 ts.t = ts.now()
190 ts.tok = tok
191 return tok, nil
192 }
193
194 func (ts *cachingTokenSource) ResetTokenOlderThan(t time.Time) {
195 ts.Lock()
196 defer ts.Unlock()
197 if ts.t.Before(t) {
198 ts.tok = nil
199 ts.t = time.Time{}
200 }
201 }
202
View as plain text