1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package transport
16
17 import (
18 "crypto/tls"
19 "encoding/json"
20 "fmt"
21 "net/http"
22 "testing"
23 "time"
24 )
25
26 const (
27 testMTLSEndpoint = "https://test.mtls.googleapis.com/"
28 testEndpointTemplate = "https://test.UNIVERSE_DOMAIN/"
29 testRegularEndpoint = "https://test.googleapis.com/"
30 testOverrideEndpoint = "https://test.override.example.com/"
31 testUniverseDomain = "example.com"
32 testUniverseDomainEndpoint = "https://test.example.com/"
33 )
34
35 var (
36 validConfigResp = func() (string, error) {
37 validConfig := mtlsConfig{
38 S2A: &s2aAddresses{
39 PlaintextAddress: testS2AAddr,
40 MTLSAddress: "",
41 },
42 }
43 configStr, err := json.Marshal(validConfig)
44 if err != nil {
45 return "", err
46 }
47 return string(configStr), nil
48 }
49
50 errorConfigResp = func() (string, error) {
51 return "", fmt.Errorf("error getting config")
52 }
53
54 invalidConfigResp = func() (string, error) {
55 return "{}", nil
56 }
57
58 invalidJSONResp = func() (string, error) {
59 return "test", nil
60 }
61 fakeClientCertSource = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, nil }
62 )
63
64 func TestOptions_UniverseDomain(t *testing.T) {
65 testCases := []struct {
66 name string
67 opts *Options
68 wantUniverseDomain string
69 wantDefaultEndpoint string
70 wantIsGDU bool
71 wantMergedEndpoint string
72 }{
73 {
74 name: "empty",
75 opts: &Options{},
76 wantUniverseDomain: "googleapis.com",
77 wantDefaultEndpoint: "",
78 wantIsGDU: true,
79 wantMergedEndpoint: "",
80 },
81 {
82 name: "defaults",
83 opts: &Options{
84 DefaultEndpointTemplate: "https://test.UNIVERSE_DOMAIN/",
85 },
86 wantUniverseDomain: "googleapis.com",
87 wantDefaultEndpoint: "https://test.googleapis.com/",
88 wantIsGDU: true,
89 wantMergedEndpoint: "",
90 },
91 {
92 name: "non-GDU",
93 opts: &Options{
94 DefaultEndpointTemplate: "https://test.UNIVERSE_DOMAIN/",
95 UniverseDomain: "example.com",
96 },
97 wantUniverseDomain: "example.com",
98 wantDefaultEndpoint: "https://test.example.com/",
99 wantIsGDU: false,
100 wantMergedEndpoint: "",
101 },
102 {
103 name: "merged endpoint",
104 opts: &Options{
105 DefaultEndpointTemplate: "https://test.UNIVERSE_DOMAIN/bar/baz",
106 Endpoint: "myhost:8000",
107 },
108 wantUniverseDomain: "googleapis.com",
109 wantDefaultEndpoint: "https://test.googleapis.com/bar/baz",
110 wantIsGDU: true,
111 wantMergedEndpoint: "https://myhost:8000/bar/baz",
112 },
113 }
114 for _, tc := range testCases {
115 t.Run(tc.name, func(t *testing.T) {
116 if got := tc.opts.getUniverseDomain(); got != tc.wantUniverseDomain {
117 t.Errorf("got %v, want %v", got, tc.wantUniverseDomain)
118 }
119 if got := tc.opts.isUniverseDomainGDU(); got != tc.wantIsGDU {
120 t.Errorf("got %v, want %v", got, tc.wantIsGDU)
121 }
122 if got := tc.opts.defaultEndpoint(); got != tc.wantDefaultEndpoint {
123 t.Errorf("got %v, want %v", got, tc.wantDefaultEndpoint)
124 }
125 if tc.opts.Endpoint != "" {
126 got, err := tc.opts.mergedEndpoint()
127 if err != nil {
128 t.Fatalf("%v", err)
129 }
130 if got != tc.wantMergedEndpoint {
131 t.Errorf("got %v, want %v", got, tc.wantMergedEndpoint)
132 }
133 }
134 })
135 }
136 }
137
138 func TestGetEndpoint(t *testing.T) {
139 testCases := []struct {
140 endpoint string
141 defaultEndpointTemplate string
142 want string
143 wantErr bool
144 }{
145 {
146 defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
147 want: "https://foo.googleapis.com/bar/baz",
148 },
149 {
150 endpoint: "myhost:3999",
151 defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
152 want: "https://myhost:3999/bar/baz",
153 },
154 {
155 endpoint: "https://host/path/to/bar",
156 defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
157 want: "https://host/path/to/bar",
158 },
159 {
160 endpoint: "host:123",
161 defaultEndpointTemplate: "",
162 want: "host:123",
163 },
164 {
165 endpoint: "host:123",
166 defaultEndpointTemplate: "default:443",
167 want: "host:123",
168 },
169 {
170 endpoint: "host:123",
171 defaultEndpointTemplate: "default:443/bar/baz",
172 want: "host:123/bar/baz",
173 },
174 }
175
176 for _, tc := range testCases {
177 t.Run(tc.want, func(t *testing.T) {
178 got, err := getEndpoint(&Options{
179 Endpoint: tc.endpoint,
180 DefaultEndpointTemplate: tc.defaultEndpointTemplate,
181 }, nil)
182 if tc.wantErr && err == nil {
183 t.Fatalf("want err, got nil err")
184 }
185 if !tc.wantErr && err != nil {
186 t.Fatalf("want nil err, got %v", err)
187 }
188 if tc.want != got {
189 t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.endpoint, tc.defaultEndpointTemplate, got, tc.want)
190 }
191 })
192 }
193 }
194
195 func TestGetEndpointWithClientCertSource(t *testing.T) {
196 testCases := []struct {
197 endpoint string
198 defaultEndpointTemplate string
199 defaultMTLSEndpoint string
200 want string
201 wantErr bool
202 }{
203 {
204 defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
205 defaultMTLSEndpoint: "https://foo.mtls.googleapis.com/bar/baz",
206 want: "https://foo.mtls.googleapis.com/bar/baz",
207 },
208 {
209 defaultEndpointTemplate: "https://staging-foo.sandbox.UNIVERSE_DOMAIN/bar/baz",
210 defaultMTLSEndpoint: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
211 want: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
212 },
213 {
214 endpoint: "myhost:3999",
215 defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
216 want: "https://myhost:3999/bar/baz",
217 },
218 {
219 endpoint: "https://host/path/to/bar",
220 defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
221 want: "https://host/path/to/bar",
222 },
223 {
224 endpoint: "host:port",
225 defaultEndpointTemplate: "",
226 want: "host:port",
227 },
228 }
229
230 for _, tc := range testCases {
231 t.Run(tc.want, func(t *testing.T) {
232 got, err := getEndpoint(&Options{
233 Endpoint: tc.endpoint,
234 DefaultEndpointTemplate: tc.defaultEndpointTemplate,
235 DefaultMTLSEndpoint: tc.defaultMTLSEndpoint,
236 }, fakeClientCertSource)
237 if tc.wantErr && err == nil {
238 t.Fatalf("want err, got nil err")
239 }
240 if !tc.wantErr && err != nil {
241 t.Fatalf("want nil err, got %v", err)
242 }
243 if tc.want != got {
244 t.Fatalf("getEndpoint(%q, %q): got %v; want %v", tc.endpoint, tc.defaultEndpointTemplate, got, tc.want)
245 }
246 })
247 }
248 }
249
250 func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
251 testCases := []struct {
252 name string
253 opts *Options
254 s2ARespFn func() (string, error)
255 want string
256 }{
257 {
258 name: "has client cert",
259 opts: &Options{
260 DefaultMTLSEndpoint: testMTLSEndpoint,
261 DefaultEndpointTemplate: testEndpointTemplate,
262 ClientCertProvider: fakeClientCertSource,
263 },
264 s2ARespFn: validConfigResp,
265 want: testMTLSEndpoint,
266 },
267 {
268 name: "no client cert, S2A address not empty",
269 opts: &Options{
270 DefaultMTLSEndpoint: testMTLSEndpoint,
271 DefaultEndpointTemplate: testEndpointTemplate,
272 },
273 s2ARespFn: validConfigResp,
274 want: testMTLSEndpoint,
275 },
276 {
277 name: "no client cert, S2A address not empty, EnableDirectPath == true",
278 opts: &Options{
279 DefaultMTLSEndpoint: testMTLSEndpoint,
280 DefaultEndpointTemplate: testEndpointTemplate,
281 EnableDirectPath: true,
282 },
283 s2ARespFn: validConfigResp,
284 want: testRegularEndpoint,
285 },
286 {
287 name: "no client cert, S2A address not empty, EnableDirectPathXds == true",
288 opts: &Options{
289 DefaultMTLSEndpoint: testMTLSEndpoint,
290 DefaultEndpointTemplate: testEndpointTemplate,
291 EnableDirectPathXds: true,
292 },
293 s2ARespFn: validConfigResp,
294 want: testRegularEndpoint,
295 },
296 {
297 name: "no client cert, S2A address empty",
298 opts: &Options{
299 DefaultMTLSEndpoint: testMTLSEndpoint,
300 DefaultEndpointTemplate: testEndpointTemplate,
301 },
302 s2ARespFn: invalidConfigResp,
303 want: testRegularEndpoint,
304 },
305 {
306 name: "no client cert, S2A address not empty, override endpoint",
307 opts: &Options{
308 DefaultMTLSEndpoint: testMTLSEndpoint,
309 DefaultEndpointTemplate: testEndpointTemplate,
310 Endpoint: testOverrideEndpoint,
311 },
312 s2ARespFn: validConfigResp,
313 want: testOverrideEndpoint,
314 },
315 {
316 "no client cert, S2A address not empty, DefaultMTLSEndpoint not set",
317 &Options{
318 DefaultMTLSEndpoint: "",
319 DefaultEndpointTemplate: testEndpointTemplate,
320 },
321 validConfigResp,
322 testRegularEndpoint,
323 },
324 }
325 defer setupTest(t)()
326 for _, tc := range testCases {
327 t.Run(tc.name, func(t *testing.T) {
328 httpGetMetadataMTLSConfig = tc.s2ARespFn
329 if tc.opts.ClientCertProvider != nil {
330 t.Setenv(googleAPIUseCertSource, "true")
331 } else {
332 t.Setenv(googleAPIUseCertSource, "false")
333 }
334 _, endpoint, _ := GetGRPCTransportCredsAndEndpoint(tc.opts)
335 if tc.want != endpoint {
336 t.Fatalf("want endpoint: %s, got %s", tc.want, endpoint)
337 }
338
339 time.Sleep(2 * time.Millisecond)
340 })
341 }
342 }
343
344 func TestGetHTTPTransportConfig_S2a(t *testing.T) {
345 testCases := []struct {
346 name string
347 opts *Options
348 s2aFn func() (string, error)
349 want string
350 isDialFnNil bool
351 }{
352 {
353 name: "has client cert",
354 opts: &Options{
355 DefaultMTLSEndpoint: testMTLSEndpoint,
356 DefaultEndpointTemplate: testEndpointTemplate,
357 ClientCertProvider: fakeClientCertSource,
358 },
359 s2aFn: validConfigResp,
360 want: testMTLSEndpoint,
361 isDialFnNil: true,
362 },
363 {
364 name: "no client cert, S2A address not empty",
365 opts: &Options{
366 DefaultMTLSEndpoint: testMTLSEndpoint,
367 DefaultEndpointTemplate: testEndpointTemplate,
368 },
369 s2aFn: validConfigResp,
370 want: testMTLSEndpoint,
371 },
372 {
373 name: "no client cert, S2A address empty",
374 opts: &Options{
375 DefaultMTLSEndpoint: testMTLSEndpoint,
376 DefaultEndpointTemplate: testEndpointTemplate,
377 },
378 s2aFn: invalidConfigResp,
379 want: testRegularEndpoint,
380 isDialFnNil: true,
381 },
382 {
383 name: "no client cert, S2A address not empty, override endpoint",
384 opts: &Options{
385 DefaultMTLSEndpoint: testMTLSEndpoint,
386 DefaultEndpointTemplate: testEndpointTemplate,
387 Endpoint: testOverrideEndpoint,
388 },
389 s2aFn: validConfigResp,
390 want: testOverrideEndpoint,
391 isDialFnNil: true,
392 },
393 {
394 name: "no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set",
395 opts: &Options{
396 DefaultMTLSEndpoint: "",
397 DefaultEndpointTemplate: testEndpointTemplate,
398 },
399 s2aFn: validConfigResp,
400 want: testRegularEndpoint,
401 isDialFnNil: true,
402 },
403 {
404 name: "no client cert, S2A address not empty, custom HTTP client",
405 opts: &Options{
406 DefaultMTLSEndpoint: testMTLSEndpoint,
407 DefaultEndpointTemplate: testEndpointTemplate,
408 Client: http.DefaultClient,
409 },
410 s2aFn: validConfigResp,
411 want: testRegularEndpoint,
412 isDialFnNil: true,
413 },
414 }
415 defer setupTest(t)()
416 for _, tc := range testCases {
417 t.Run(tc.name, func(t *testing.T) {
418 httpGetMetadataMTLSConfig = tc.s2aFn
419 if tc.opts.ClientCertProvider != nil {
420 t.Setenv(googleAPIUseCertSource, "true")
421 } else {
422 t.Setenv(googleAPIUseCertSource, "false")
423 }
424 _, dialFunc, err := GetHTTPTransportConfig(tc.opts)
425 if err != nil {
426 t.Fatalf("err: %v", err)
427 }
428 if want, got := tc.isDialFnNil, dialFunc == nil; want != got {
429 t.Errorf("expecting returned dialFunc is nil: [%v], got [%v]", tc.isDialFnNil, got)
430 }
431
432 time.Sleep(2 * time.Millisecond)
433 })
434 }
435 }
436
437 func setupTest(t *testing.T) func() {
438 oldHTTPGet := httpGetMetadataMTLSConfig
439 oldExpiry := configExpiry
440
441 configExpiry = time.Millisecond
442 t.Setenv(googleAPIUseS2AEnv, "true")
443
444 return func() {
445 httpGetMetadataMTLSConfig = oldHTTPGet
446 configExpiry = oldExpiry
447 }
448 }
449
450 func TestGetTransportConfig_UniverseDomain(t *testing.T) {
451 testCases := []struct {
452 name string
453 opts *Options
454 wantEndpoint string
455 wantErr error
456 }{
457 {
458 name: "google default universe (GDU), no client cert, template is regular endpoint",
459 opts: &Options{
460 DefaultEndpointTemplate: testRegularEndpoint,
461 DefaultMTLSEndpoint: testMTLSEndpoint,
462 },
463 wantEndpoint: testRegularEndpoint,
464 },
465 {
466 name: "google default universe (GDU), no client cert",
467 opts: &Options{
468 DefaultEndpointTemplate: testEndpointTemplate,
469 DefaultMTLSEndpoint: testMTLSEndpoint,
470 },
471 wantEndpoint: testRegularEndpoint,
472 },
473 {
474 name: "google default universe (GDU), client cert",
475 opts: &Options{
476 DefaultEndpointTemplate: testEndpointTemplate,
477 DefaultMTLSEndpoint: testMTLSEndpoint,
478 ClientCertProvider: fakeClientCertSource,
479 },
480 wantEndpoint: testMTLSEndpoint,
481 },
482 {
483 name: "UniverseDomain, no client cert",
484 opts: &Options{
485 DefaultEndpointTemplate: testEndpointTemplate,
486 DefaultMTLSEndpoint: testMTLSEndpoint,
487 UniverseDomain: testUniverseDomain,
488 },
489 wantEndpoint: testUniverseDomainEndpoint,
490 },
491 {
492 name: "UniverseDomain, client cert",
493 opts: &Options{
494 DefaultEndpointTemplate: testEndpointTemplate,
495 DefaultMTLSEndpoint: testMTLSEndpoint,
496 UniverseDomain: testUniverseDomain,
497 ClientCertProvider: fakeClientCertSource,
498 },
499 wantEndpoint: testUniverseDomainEndpoint,
500 wantErr: errUniverseNotSupportedMTLS,
501 },
502 }
503
504 for _, tc := range testCases {
505 t.Run(tc.name, func(t *testing.T) {
506 if tc.opts.ClientCertProvider != nil {
507 t.Setenv(googleAPIUseCertSource, "true")
508 } else {
509 t.Setenv(googleAPIUseCertSource, "false")
510 }
511 config, err := getTransportConfig(tc.opts)
512 if err != nil {
513 if err != tc.wantErr {
514 t.Fatalf("err: %v", err)
515 }
516 } else {
517 if tc.wantEndpoint != config.endpoint {
518 t.Errorf("want endpoint: %s, got %s", tc.wantEndpoint, config.endpoint)
519 }
520 }
521 })
522 }
523 }
524
525 func TestGetGRPCTransportCredsAndEndpoint_UniverseDomain(t *testing.T) {
526 testCases := []struct {
527 name string
528 opts *Options
529 wantEndpoint string
530 wantErr error
531 }{
532 {
533 name: "google default universe (GDU), no client cert",
534 opts: &Options{
535 DefaultEndpointTemplate: testEndpointTemplate,
536 DefaultMTLSEndpoint: testMTLSEndpoint,
537 },
538 wantEndpoint: testRegularEndpoint,
539 },
540 {
541 name: "google default universe (GDU), no client cert, endpoint",
542 opts: &Options{
543 DefaultEndpointTemplate: testEndpointTemplate,
544 DefaultMTLSEndpoint: testMTLSEndpoint,
545 Endpoint: testOverrideEndpoint,
546 },
547 wantEndpoint: testOverrideEndpoint,
548 },
549 {
550 name: "google default universe (GDU), client cert",
551 opts: &Options{
552 DefaultEndpointTemplate: testEndpointTemplate,
553 DefaultMTLSEndpoint: testMTLSEndpoint,
554 ClientCertProvider: fakeClientCertSource,
555 },
556 wantEndpoint: testMTLSEndpoint,
557 },
558 {
559 name: "google default universe (GDU), client cert, endpoint",
560 opts: &Options{
561 DefaultEndpointTemplate: testEndpointTemplate,
562 DefaultMTLSEndpoint: testMTLSEndpoint,
563 ClientCertProvider: fakeClientCertSource,
564 Endpoint: testOverrideEndpoint,
565 },
566 wantEndpoint: testOverrideEndpoint,
567 },
568 {
569 name: "UniverseDomain, no client cert",
570 opts: &Options{
571 DefaultEndpointTemplate: testEndpointTemplate,
572 DefaultMTLSEndpoint: testMTLSEndpoint,
573 UniverseDomain: testUniverseDomain,
574 },
575 wantEndpoint: testUniverseDomainEndpoint,
576 },
577 {
578 name: "UniverseDomain, no client cert, endpoint",
579 opts: &Options{
580 DefaultEndpointTemplate: testEndpointTemplate,
581 DefaultMTLSEndpoint: testMTLSEndpoint,
582 UniverseDomain: testUniverseDomain,
583 Endpoint: testOverrideEndpoint,
584 },
585 wantEndpoint: testOverrideEndpoint,
586 },
587 {
588 name: "UniverseDomain, client cert",
589 opts: &Options{
590 DefaultEndpointTemplate: testEndpointTemplate,
591 DefaultMTLSEndpoint: testMTLSEndpoint,
592 UniverseDomain: testUniverseDomain,
593 ClientCertProvider: fakeClientCertSource,
594 },
595 wantErr: errUniverseNotSupportedMTLS,
596 },
597 {
598 name: "UniverseDomain, client cert, endpoint",
599 opts: &Options{
600 DefaultEndpointTemplate: testEndpointTemplate,
601 DefaultMTLSEndpoint: testMTLSEndpoint,
602 UniverseDomain: testUniverseDomain,
603 ClientCertProvider: fakeClientCertSource,
604 Endpoint: testOverrideEndpoint,
605 },
606 wantEndpoint: testOverrideEndpoint,
607 },
608 }
609
610 for _, tc := range testCases {
611 t.Run(tc.name, func(t *testing.T) {
612 if tc.opts.ClientCertProvider != nil {
613 t.Setenv(googleAPIUseCertSource, "true")
614 } else {
615 t.Setenv(googleAPIUseCertSource, "false")
616 }
617 _, endpoint, err := GetGRPCTransportCredsAndEndpoint(tc.opts)
618 if err != nil {
619 if err != tc.wantErr {
620 t.Fatalf("err: %v", err)
621 }
622 } else {
623 if tc.wantEndpoint != endpoint {
624 t.Errorf("want endpoint: %s, got %s", tc.wantEndpoint, endpoint)
625 }
626 }
627 })
628 }
629 }
630
View as plain text