1
16
17 package rest
18
19 import (
20 "fmt"
21 "net/http"
22 "reflect"
23 "strconv"
24 "testing"
25
26 clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
27 )
28
29 func TestAuthPluginWrapTransport(t *testing.T) {
30 if err := RegisterAuthProviderPlugin("pluginA", pluginAProvider); err != nil {
31 t.Errorf("Unexpected error: failed to register pluginA: %v", err)
32 }
33 if err := RegisterAuthProviderPlugin("pluginB", pluginBProvider); err != nil {
34 t.Errorf("Unexpected error: failed to register pluginB: %v", err)
35 }
36 if err := RegisterAuthProviderPlugin("pluginFail", pluginFailProvider); err != nil {
37 t.Errorf("Unexpected error: failed to register pluginFail: %v", err)
38 }
39 testCases := []struct {
40 useWrapTransport bool
41 plugin string
42 expectErr bool
43 expectPluginA bool
44 expectPluginB bool
45 }{
46 {false, "", false, false, false},
47 {false, "pluginA", false, true, false},
48 {false, "pluginB", false, false, true},
49 {false, "pluginFail", true, false, false},
50 {false, "pluginUnknown", true, false, false},
51 }
52 for i, tc := range testCases {
53 c := Config{}
54 if tc.useWrapTransport {
55
56
57 c.WrapTransport = func(rt http.RoundTripper) http.RoundTripper {
58 return &wrapTransport{rt}
59 }
60 }
61 if len(tc.plugin) != 0 {
62 c.AuthProvider = &clientcmdapi.AuthProviderConfig{Name: tc.plugin}
63 }
64 tConfig, err := c.TransportConfig()
65 if err != nil {
66
67 if !tc.expectErr {
68 t.Errorf("%d. Did not expect errors loading Auth Plugin: %q. Got: %v", i, tc.plugin, err)
69 }
70 continue
71 }
72 var fullyWrappedTransport http.RoundTripper
73 fullyWrappedTransport = &emptyTransport{}
74 if tConfig.WrapTransport != nil {
75 fullyWrappedTransport = tConfig.WrapTransport(&emptyTransport{})
76 }
77 res, err := fullyWrappedTransport.RoundTrip(&http.Request{})
78 if err != nil {
79 t.Errorf("%d. Unexpected error in RoundTrip: %v", i, err)
80 continue
81 }
82 hasWrapTransport := res.Header.Get("wrapTransport") == "Y"
83 hasPluginA := res.Header.Get("pluginA") == "Y"
84 hasPluginB := res.Header.Get("pluginB") == "Y"
85 if hasWrapTransport != tc.useWrapTransport {
86 t.Errorf("%d. Expected Existing config.WrapTransport: %t; Got: %t", i, tc.useWrapTransport, hasWrapTransport)
87 }
88 if hasPluginA != tc.expectPluginA {
89 t.Errorf("%d. Expected Plugin A: %t; Got: %t", i, tc.expectPluginA, hasPluginA)
90 }
91 if hasPluginB != tc.expectPluginB {
92 t.Errorf("%d. Expected Plugin B: %t; Got: %t", i, tc.expectPluginB, hasPluginB)
93 }
94 }
95 }
96
97 func TestAuthPluginPersist(t *testing.T) {
98
99 if err := RegisterAuthProviderPlugin("pluginA2", pluginAProvider); err != nil {
100 t.Errorf("Unexpected error: failed to register pluginA: %v", err)
101 }
102 if err := RegisterAuthProviderPlugin("pluginPersist", pluginPersistProvider); err != nil {
103 t.Errorf("Unexpected error: failed to register pluginPersist: %v", err)
104 }
105 fooBarConfig := map[string]string{"foo": "bar"}
106 testCases := []struct {
107 plugin string
108 startingConfig map[string]string
109 expectedConfigAfterLogin map[string]string
110 expectedConfigAfterRoundTrip map[string]string
111 }{
112
113 {"pluginA2", map[string]string{}, map[string]string{}, map[string]string{}},
114 {"pluginA2", fooBarConfig, fooBarConfig, fooBarConfig},
115
116 {
117 "pluginPersist",
118 map[string]string{},
119 map[string]string{
120 "login": "Y",
121 },
122 map[string]string{
123 "login": "Y",
124 "roundTrips": "1",
125 },
126 },
127 {
128 "pluginPersist",
129 map[string]string{
130 "login": "Y",
131 "roundTrips": "123",
132 },
133 map[string]string{
134 "login": "Y",
135 "roundTrips": "123",
136 },
137 map[string]string{
138 "login": "Y",
139 "roundTrips": "124",
140 },
141 },
142 }
143 for i, tc := range testCases {
144 cfg := &clientcmdapi.AuthProviderConfig{
145 Name: tc.plugin,
146 Config: tc.startingConfig,
147 }
148 persister := &inMemoryPersister{make(map[string]string)}
149 persister.Persist(tc.startingConfig)
150 plugin, err := GetAuthProvider("127.0.0.1", cfg, persister)
151 if err != nil {
152 t.Errorf("%d. Unexpected error: failed to get plugin %q: %v", i, tc.plugin, err)
153 }
154 if err := plugin.Login(); err != nil {
155 t.Errorf("%d. Unexpected error calling Login() w/ plugin %q: %v", i, tc.plugin, err)
156 }
157
158 if !reflect.DeepEqual(persister.savedConfig, tc.expectedConfigAfterLogin) {
159 t.Errorf("%d. Unexpected persisted config after calling %s.Login(): \nGot:\n%v\nExpected:\n%v",
160 i, tc.plugin, persister.savedConfig, tc.expectedConfigAfterLogin)
161 }
162 if _, err := plugin.WrapTransport(&emptyTransport{}).RoundTrip(&http.Request{}); err != nil {
163 t.Errorf("%d. Unexpected error round-tripping w/ plugin %q: %v", i, tc.plugin, err)
164 }
165
166 if !reflect.DeepEqual(persister.savedConfig, tc.expectedConfigAfterRoundTrip) {
167 t.Errorf("%d. Unexpected persisted config after calling %s.WrapTransport.RoundTrip(): \nGot:\n%v\nExpected:\n%v",
168 i, tc.plugin, persister.savedConfig, tc.expectedConfigAfterLogin)
169 }
170 }
171
172 }
173
174 func Test_WhenNilPersister_NoOpPersisterIsAssigned(t *testing.T) {
175
176 if err := RegisterAuthProviderPlugin("anyPlugin", pluginPersistProvider); err != nil {
177 t.Errorf("unexpected error: failed to register 'anyPlugin': %v", err)
178 }
179 cfg := &clientcmdapi.AuthProviderConfig{
180 Name: "anyPlugin",
181 Config: nil,
182 }
183 plugin, err := GetAuthProvider("127.0.0.1", cfg, nil)
184 if err != nil {
185 t.Errorf("unexpected error: failed to get 'anyPlugin': %v", err)
186 }
187
188 anyPlugin := plugin.(*pluginPersist)
189
190 if _, ok := anyPlugin.persister.(*noopPersister); !ok {
191 t.Errorf("expected to be No Operation persister")
192 }
193
194 }
195
196
197
// emptyTransport is the innermost http.RoundTripper stub used by these tests.
// It ignores the request and answers with an empty, non-nil header map, so
// wrapping transports can add marker headers to the response.
type emptyTransport struct{}

// RoundTrip returns a fresh *http.Response whose Header is empty, and a nil error.
func (*emptyTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	return &http.Response{Header: http.Header{}}, nil
}
206
207
// wrapTransport stands in for a caller-supplied Config.WrapTransport. It
// delegates to rt and marks each successful response with a
// "wrapTransport: Y" header.
type wrapTransport struct {
	rt http.RoundTripper
}

// RoundTrip forwards req to the wrapped transport. On failure it returns a nil
// response and the error; otherwise it adds the marker header.
func (w *wrapTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := w.rt.RoundTrip(req)
	if err == nil {
		resp.Header.Add("wrapTransport", "Y")
		return resp, nil
	}
	return nil, err
}
220
221
// wrapTransportA is the transport installed by pluginA. It delegates to rt and
// marks each successful response with a "pluginA: Y" header.
type wrapTransportA struct {
	rt http.RoundTripper
}

// RoundTrip forwards req to the wrapped transport. On failure it returns a nil
// response and the error; otherwise it adds the pluginA marker header.
func (w *wrapTransportA) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := w.rt.RoundTrip(req)
	if err == nil {
		resp.Header.Add("pluginA", "Y")
		return resp, nil
	}
	return nil, err
}
234
235 type pluginA struct{}
236
237 func (*pluginA) WrapTransport(rt http.RoundTripper) http.RoundTripper {
238 return &wrapTransportA{rt}
239 }
240
241 func (*pluginA) Login() error { return nil }
242
243 func pluginAProvider(string, map[string]string, AuthProviderConfigPersister) (AuthProvider, error) {
244 return &pluginA{}, nil
245 }
246
247
// wrapTransportB is the transport installed by pluginB. It delegates to rt and
// marks each successful response with a "pluginB: Y" header.
type wrapTransportB struct {
	rt http.RoundTripper
}

// RoundTrip forwards req to the wrapped transport. On failure it returns a nil
// response and the error; otherwise it adds the pluginB marker header.
func (w *wrapTransportB) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := w.rt.RoundTrip(req)
	if err == nil {
		resp.Header.Add("pluginB", "Y")
		return resp, nil
	}
	return nil, err
}
260
261 type pluginB struct{}
262
263 func (*pluginB) WrapTransport(rt http.RoundTripper) http.RoundTripper {
264 return &wrapTransportB{rt}
265 }
266
267 func (*pluginB) Login() error { return nil }
268
269 func pluginBProvider(string, map[string]string, AuthProviderConfigPersister) (AuthProvider, error) {
270 return &pluginB{}, nil
271 }
272
273
274 func pluginFailProvider(string, map[string]string, AuthProviderConfigPersister) (AuthProvider, error) {
275 return nil, fmt.Errorf("Failed to load AuthProvider")
276 }
277
// inMemoryPersister is a persister test double. It keeps a private copy of the
// most recently persisted config so tests can inspect it.
type inMemoryPersister struct {
	savedConfig map[string]string
}

// Persist replaces savedConfig with a fresh copy of config, so later mutations
// of config are not observed. A nil or empty config yields an empty, non-nil
// map. It never returns an error.
func (i *inMemoryPersister) Persist(config map[string]string) error {
	snapshot := make(map[string]string, len(config))
	for key, value := range config {
		snapshot[key] = value
	}
	i.savedConfig = snapshot
	return nil
}
289
290
291
292 type wrapTransportPersist struct {
293 rt http.RoundTripper
294 config map[string]string
295 persister AuthProviderConfigPersister
296 }
297
298 func (w *wrapTransportPersist) RoundTrip(req *http.Request) (*http.Response, error) {
299 roundTrips := 0
300 if rtVal, ok := w.config["roundTrips"]; ok {
301 var err error
302 roundTrips, err = strconv.Atoi(rtVal)
303 if err != nil {
304 return nil, err
305 }
306 }
307 roundTrips++
308 w.config["roundTrips"] = fmt.Sprintf("%d", roundTrips)
309 if err := w.persister.Persist(w.config); err != nil {
310 return nil, err
311 }
312 return w.rt.RoundTrip(req)
313 }
314
315 type pluginPersist struct {
316 config map[string]string
317 persister AuthProviderConfigPersister
318 }
319
320 func (p *pluginPersist) WrapTransport(rt http.RoundTripper) http.RoundTripper {
321 return &wrapTransportPersist{rt, p.config, p.persister}
322 }
323
324
325 func (p *pluginPersist) Login() error {
326 p.config["login"] = "Y"
327 p.persister.Persist(p.config)
328 return nil
329 }
330
331 func pluginPersistProvider(_ string, config map[string]string, persister AuthProviderConfigPersister) (AuthProvider, error) {
332 return &pluginPersist{config, persister}, nil
333 }
334
View as plain text