1
2
3
4
5 package internal
6
7 import (
8 "crypto/tls"
9 "net/http"
10 "os"
11 "testing"
12 "time"
13 )
14
15 const (
16 testMTLSEndpoint = "https://test.mtls.googleapis.com/"
17 testRegularEndpoint = "https://test.googleapis.com/"
18 testEndpointTemplate = "https://test.UNIVERSE_DOMAIN/"
19 testOverrideEndpoint = "https://test.override.example.com/"
20 testUniverseDomain = "example.com"
21 testUniverseDomainEndpoint = "https://test.example.com/"
22 )
23
24 var dummyClientCertSource = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, nil }
25
26 func TestGetEndpoint(t *testing.T) {
27 testCases := []struct {
28 UserEndpoint string
29 DefaultEndpoint string
30 DefaultEndpointTemplate string
31 Want string
32 WantErr bool
33 }{
34 {
35 DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
36 Want: "https://foo.googleapis.com/bar/baz",
37 },
38 {
39 UserEndpoint: "myhost:3999",
40 DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
41 Want: "https://myhost:3999/bar/baz",
42 },
43 {
44 UserEndpoint: "https://host/path/to/bar",
45 DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
46 Want: "https://host/path/to/bar",
47 },
48 {
49 UserEndpoint: "host:123",
50 DefaultEndpoint: "",
51 Want: "host:123",
52 },
53 {
54 UserEndpoint: "host:123",
55 DefaultEndpoint: "default:443",
56 Want: "host:123",
57 },
58 {
59 UserEndpoint: "host:123",
60 DefaultEndpoint: "default:443/bar/baz",
61 Want: "host:123/bar/baz",
62 },
63 }
64
65 for _, tc := range testCases {
66 got, err := getEndpoint(&DialSettings{
67 Endpoint: tc.UserEndpoint,
68 DefaultEndpoint: tc.DefaultEndpoint,
69 DefaultEndpointTemplate: tc.DefaultEndpointTemplate,
70 }, nil)
71 if tc.WantErr && err == nil {
72 t.Errorf("want err, got nil err")
73 continue
74 }
75 if !tc.WantErr && err != nil {
76 t.Errorf("want nil err, got %v", err)
77 continue
78 }
79 if tc.Want != got {
80 t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpointTemplate, got, tc.Want)
81 }
82 }
83 }
84
85 func TestGetEndpointWithClientCertSource(t *testing.T) {
86
87 testCases := []struct {
88 UserEndpoint string
89 DefaultEndpoint string
90 DefaultMTLSEndpoint string
91 Want string
92 WantErr bool
93 }{
94 {
95 DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
96 DefaultMTLSEndpoint: "https://foo.mtls.googleapis.com/bar/baz",
97 Want: "https://foo.mtls.googleapis.com/bar/baz",
98 },
99 {
100 DefaultEndpoint: "https://staging-foo.sandbox.googleapis.com/bar/baz",
101 DefaultMTLSEndpoint: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
102 Want: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
103 },
104 {
105 UserEndpoint: "myhost:3999",
106 DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
107 Want: "https://myhost:3999/bar/baz",
108 },
109 {
110 UserEndpoint: "https://host/path/to/bar",
111 DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
112 Want: "https://host/path/to/bar",
113 },
114 {
115 UserEndpoint: "host:port",
116 DefaultEndpoint: "",
117 Want: "host:port",
118 },
119 }
120
121 for _, tc := range testCases {
122 got, err := getEndpoint(&DialSettings{
123 Endpoint: tc.UserEndpoint,
124 DefaultEndpoint: tc.DefaultEndpoint,
125 DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
126 }, dummyClientCertSource)
127 if tc.WantErr && err == nil {
128 t.Errorf("want err, got nil err")
129 continue
130 }
131 if !tc.WantErr && err != nil {
132 t.Errorf("want nil err, got %v", err)
133 continue
134 }
135 if tc.Want != got {
136 t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
137 }
138 }
139 }
140
141 func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
142 testCases := []struct {
143 Desc string
144 InputSettings *DialSettings
145 S2ARespFunc func() (string, error)
146 WantEndpoint string
147 }{
148 {
149 "has client cert",
150 &DialSettings{
151 DefaultMTLSEndpoint: testMTLSEndpoint,
152 DefaultEndpoint: testRegularEndpoint,
153 ClientCertSource: dummyClientCertSource,
154 },
155 validConfigResp,
156 testMTLSEndpoint,
157 },
158 {
159 "no client cert, S2A address not empty",
160 &DialSettings{
161 DefaultMTLSEndpoint: testMTLSEndpoint,
162 DefaultEndpoint: testRegularEndpoint,
163 },
164 validConfigResp,
165 testMTLSEndpoint,
166 },
167 {
168 "no client cert, S2A address not empty, EnableDirectPath == true",
169 &DialSettings{
170 DefaultMTLSEndpoint: testMTLSEndpoint,
171 DefaultEndpoint: testRegularEndpoint,
172 EnableDirectPath: true,
173 },
174 validConfigResp,
175 testRegularEndpoint,
176 },
177 {
178 "no client cert, S2A address not empty, EnableDirectPathXds == true",
179 &DialSettings{
180 DefaultMTLSEndpoint: testMTLSEndpoint,
181 DefaultEndpoint: testRegularEndpoint,
182 EnableDirectPathXds: true,
183 },
184 validConfigResp,
185 testRegularEndpoint,
186 },
187 {
188 "no client cert, S2A address empty",
189 &DialSettings{
190 DefaultMTLSEndpoint: testMTLSEndpoint,
191 DefaultEndpoint: testRegularEndpoint,
192 },
193 invalidConfigResp,
194 testRegularEndpoint,
195 },
196 {
197 "no client cert, S2A address not empty, override endpoint",
198 &DialSettings{
199 DefaultMTLSEndpoint: testMTLSEndpoint,
200 DefaultEndpointTemplate: testEndpointTemplate,
201 Endpoint: testOverrideEndpoint,
202 },
203 validConfigResp,
204 testOverrideEndpoint,
205 },
206 {
207 "no client cert, S2A address not empty, DefaultMTLSEndpoint not set",
208 &DialSettings{
209 DefaultMTLSEndpoint: "",
210 DefaultEndpointTemplate: testEndpointTemplate,
211 },
212 validConfigResp,
213 testRegularEndpoint,
214 },
215 }
216 defer setupTest()()
217
218 for _, tc := range testCases {
219 httpGetMetadataMTLSConfig = tc.S2ARespFunc
220 if tc.InputSettings.ClientCertSource != nil {
221 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
222 } else {
223 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
224 }
225 _, endpoint, _ := GetGRPCTransportConfigAndEndpoint(tc.InputSettings)
226 if tc.WantEndpoint != endpoint {
227 t.Errorf("%s: want endpoint: [%s], got [%s]", tc.Desc, tc.WantEndpoint, endpoint)
228 }
229
230 time.Sleep(2 * time.Millisecond)
231 }
232 }
233
234 func TestGetHTTPTransportConfigAndEndpoint_s2a(t *testing.T) {
235 testCases := []struct {
236 Desc string
237 InputSettings *DialSettings
238 S2ARespFunc func() (string, error)
239 WantEndpoint string
240 DialFuncNil bool
241 }{
242 {
243 "has client cert",
244 &DialSettings{
245 DefaultMTLSEndpoint: testMTLSEndpoint,
246 DefaultEndpoint: testRegularEndpoint,
247 ClientCertSource: dummyClientCertSource,
248 },
249 validConfigResp,
250 testMTLSEndpoint,
251 true,
252 },
253 {
254 "no client cert, S2A address not empty",
255 &DialSettings{
256 DefaultMTLSEndpoint: testMTLSEndpoint,
257 DefaultEndpoint: testRegularEndpoint,
258 },
259 validConfigResp,
260 testMTLSEndpoint,
261 false,
262 },
263 {
264 "no client cert, S2A address not empty, EnableDirectPath == true",
265 &DialSettings{
266 DefaultMTLSEndpoint: testMTLSEndpoint,
267 DefaultEndpoint: testRegularEndpoint,
268 EnableDirectPath: true,
269 },
270 validConfigResp,
271 testRegularEndpoint,
272 true,
273 },
274 {
275 "no client cert, S2A address not empty, EnableDirectPathXds == true",
276 &DialSettings{
277 DefaultMTLSEndpoint: testMTLSEndpoint,
278 DefaultEndpoint: testRegularEndpoint,
279 EnableDirectPathXds: true,
280 },
281 validConfigResp,
282 testRegularEndpoint,
283 true,
284 },
285 {
286 "no client cert, S2A address empty",
287 &DialSettings{
288 DefaultMTLSEndpoint: testMTLSEndpoint,
289 DefaultEndpoint: testRegularEndpoint,
290 },
291 invalidConfigResp,
292 testRegularEndpoint,
293 true,
294 },
295 {
296 "no client cert, S2A address not empty, override endpoint",
297 &DialSettings{
298 DefaultMTLSEndpoint: testMTLSEndpoint,
299 DefaultEndpoint: testRegularEndpoint,
300 Endpoint: testOverrideEndpoint,
301 },
302 validConfigResp,
303 testOverrideEndpoint,
304 true,
305 },
306 {
307 "no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set",
308 &DialSettings{
309 DefaultMTLSEndpoint: "",
310 DefaultEndpoint: testRegularEndpoint,
311 },
312 validConfigResp,
313 testRegularEndpoint,
314 true,
315 },
316 {
317 "no client cert, S2A address not empty, custom HTTP client",
318 &DialSettings{
319 DefaultMTLSEndpoint: testMTLSEndpoint,
320 DefaultEndpoint: testRegularEndpoint,
321 HTTPClient: http.DefaultClient,
322 },
323 validConfigResp,
324 testRegularEndpoint,
325 true,
326 },
327 }
328
329 defer setupTest()()
330
331 for _, tc := range testCases {
332 httpGetMetadataMTLSConfig = tc.S2ARespFunc
333 if tc.InputSettings.ClientCertSource != nil {
334 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
335 } else {
336 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
337 }
338 _, dialFunc, endpoint, err := GetHTTPTransportConfigAndEndpoint(tc.InputSettings)
339 if err != nil {
340 t.Fatalf("%s: err: %v", tc.Desc, err)
341 }
342 if tc.WantEndpoint != endpoint {
343 t.Errorf("%s: want endpoint: [%s], got [%s]", tc.Desc, tc.WantEndpoint, endpoint)
344 }
345 if want, got := tc.DialFuncNil, dialFunc == nil; want != got {
346 t.Errorf("%s: expecting returned dialFunc is nil: [%v], got [%v]", tc.Desc, tc.DialFuncNil, got)
347 }
348
349 time.Sleep(2 * time.Millisecond)
350 }
351 }
352
353 func setupTest() func() {
354 oldHTTPGet := httpGetMetadataMTLSConfig
355 oldExpiry := configExpiry
356 oldUseS2A := os.Getenv(googleAPIUseS2AEnv)
357 oldUseClientCert := os.Getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE")
358
359 configExpiry = time.Millisecond
360 os.Setenv(googleAPIUseS2AEnv, "true")
361
362 return func() {
363 httpGetMetadataMTLSConfig = oldHTTPGet
364 configExpiry = oldExpiry
365 os.Setenv(googleAPIUseS2AEnv, oldUseS2A)
366 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", oldUseClientCert)
367 }
368 }
369
370 func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
371 testCases := []struct {
372 name string
373 ds *DialSettings
374 wantEndpoint string
375 wantErr error
376 }{
377 {
378 name: "google default universe (GDU), no client cert",
379 ds: &DialSettings{
380 DefaultEndpoint: testRegularEndpoint,
381 DefaultEndpointTemplate: testEndpointTemplate,
382 DefaultMTLSEndpoint: testMTLSEndpoint,
383 },
384 wantEndpoint: testRegularEndpoint,
385 },
386 {
387 name: "google default universe (GDU), client cert",
388 ds: &DialSettings{
389 DefaultEndpoint: testRegularEndpoint,
390 DefaultEndpointTemplate: testEndpointTemplate,
391 DefaultMTLSEndpoint: testMTLSEndpoint,
392 ClientCertSource: dummyClientCertSource,
393 },
394 wantEndpoint: testMTLSEndpoint,
395 },
396 {
397 name: "UniverseDomain, no client cert",
398 ds: &DialSettings{
399 DefaultEndpoint: testRegularEndpoint,
400 DefaultEndpointTemplate: testEndpointTemplate,
401 DefaultMTLSEndpoint: testMTLSEndpoint,
402 UniverseDomain: testUniverseDomain,
403 },
404 wantEndpoint: testUniverseDomainEndpoint,
405 },
406 {
407 name: "UniverseDomain, client cert",
408 ds: &DialSettings{
409 DefaultEndpoint: testRegularEndpoint,
410 DefaultEndpointTemplate: testEndpointTemplate,
411 DefaultMTLSEndpoint: testMTLSEndpoint,
412 UniverseDomain: testUniverseDomain,
413 ClientCertSource: dummyClientCertSource,
414 },
415 wantEndpoint: testUniverseDomainEndpoint,
416 wantErr: errUniverseNotSupportedMTLS,
417 },
418 }
419
420 for _, tc := range testCases {
421 if tc.ds.ClientCertSource != nil {
422 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
423 } else {
424 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
425 }
426 _, _, endpoint, err := GetHTTPTransportConfigAndEndpoint(tc.ds)
427 if err != nil {
428 if err != tc.wantErr {
429 t.Fatalf("%s: err: %v", tc.name, err)
430 }
431 } else {
432 if tc.wantEndpoint != endpoint {
433 t.Errorf("%s: want endpoint: [%s], got [%s]", tc.name, tc.wantEndpoint, endpoint)
434 }
435 }
436 }
437 }
438
439 func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
440 testCases := []struct {
441 name string
442 ds *DialSettings
443 wantEndpoint string
444 wantErr error
445 }{
446 {
447 name: "google default universe (GDU), no client cert",
448 ds: &DialSettings{
449 DefaultEndpoint: testRegularEndpoint,
450 DefaultEndpointTemplate: testEndpointTemplate,
451 DefaultMTLSEndpoint: testMTLSEndpoint,
452 },
453 wantEndpoint: testRegularEndpoint,
454 },
455 {
456 name: "google default universe (GDU), no client cert, endpoint",
457 ds: &DialSettings{
458 DefaultEndpoint: testRegularEndpoint,
459 DefaultEndpointTemplate: testEndpointTemplate,
460 DefaultMTLSEndpoint: testMTLSEndpoint,
461 Endpoint: testOverrideEndpoint,
462 },
463 wantEndpoint: testOverrideEndpoint,
464 },
465 {
466 name: "google default universe (GDU), client cert",
467 ds: &DialSettings{
468 DefaultEndpoint: testRegularEndpoint,
469 DefaultEndpointTemplate: testEndpointTemplate,
470 DefaultMTLSEndpoint: testMTLSEndpoint,
471 ClientCertSource: dummyClientCertSource,
472 },
473 wantEndpoint: testMTLSEndpoint,
474 },
475 {
476 name: "google default universe (GDU), client cert, endpoint",
477 ds: &DialSettings{
478 DefaultEndpoint: testRegularEndpoint,
479 DefaultEndpointTemplate: testEndpointTemplate,
480 DefaultMTLSEndpoint: testMTLSEndpoint,
481 ClientCertSource: dummyClientCertSource,
482 Endpoint: testOverrideEndpoint,
483 },
484 wantEndpoint: testOverrideEndpoint,
485 },
486 {
487 name: "UniverseDomain, no client cert",
488 ds: &DialSettings{
489 DefaultEndpoint: testRegularEndpoint,
490 DefaultEndpointTemplate: testEndpointTemplate,
491 DefaultMTLSEndpoint: testMTLSEndpoint,
492 UniverseDomain: testUniverseDomain,
493 },
494 wantEndpoint: testUniverseDomainEndpoint,
495 },
496 {
497 name: "UniverseDomain, no client cert, endpoint",
498 ds: &DialSettings{
499 DefaultEndpoint: testRegularEndpoint,
500 DefaultEndpointTemplate: testEndpointTemplate,
501 DefaultMTLSEndpoint: testMTLSEndpoint,
502 UniverseDomain: testUniverseDomain,
503 Endpoint: testOverrideEndpoint,
504 },
505 wantEndpoint: testOverrideEndpoint,
506 },
507 {
508 name: "UniverseDomain, client cert",
509 ds: &DialSettings{
510 DefaultEndpoint: testRegularEndpoint,
511 DefaultEndpointTemplate: testEndpointTemplate,
512 DefaultMTLSEndpoint: testMTLSEndpoint,
513 UniverseDomain: testUniverseDomain,
514 ClientCertSource: dummyClientCertSource,
515 },
516 wantErr: errUniverseNotSupportedMTLS,
517 },
518 {
519 name: "UniverseDomain, client cert, endpoint",
520 ds: &DialSettings{
521 DefaultEndpoint: testRegularEndpoint,
522 DefaultEndpointTemplate: testEndpointTemplate,
523 DefaultMTLSEndpoint: testMTLSEndpoint,
524 UniverseDomain: testUniverseDomain,
525 ClientCertSource: dummyClientCertSource,
526 Endpoint: testOverrideEndpoint,
527 },
528 wantEndpoint: testOverrideEndpoint,
529 },
530 }
531
532 for _, tc := range testCases {
533 if tc.ds.ClientCertSource != nil {
534 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
535 } else {
536 os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
537 }
538 _, endpoint, err := GetGRPCTransportConfigAndEndpoint(tc.ds)
539 if err != nil {
540 if err != tc.wantErr {
541 t.Fatalf("%s: err: %v", tc.name, err)
542 }
543 } else {
544 if tc.wantEndpoint != endpoint {
545 t.Errorf("%s: want endpoint: [%s], got [%s]", tc.name, tc.wantEndpoint, endpoint)
546 }
547 }
548 }
549 }
550
View as plain text