1
18
19 package s2a
20
21 import (
22 "bytes"
23 "context"
24 "crypto/tls"
25 "crypto/x509"
26 "fmt"
27 "io"
28 "io/ioutil"
29 "net"
30 "net/http"
31 "os"
32 "path/filepath"
33 "testing"
34 "time"
35
36 _ "embed"
37
38 "github.com/google/s2a-go/fallback"
39 "github.com/google/s2a-go/internal/fakehandshaker/service"
40 "github.com/google/s2a-go/internal/v2/fakes2av2"
41 "github.com/google/s2a-go/retry"
42 "google.golang.org/grpc/credentials"
43 "google.golang.org/grpc/grpclog"
44 "google.golang.org/grpc/peer"
45
46 grpc "google.golang.org/grpc"
47
48 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
49 helloworldpb "github.com/google/s2a-go/internal/proto/examples/helloworld_go_proto"
50 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
51 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
52 )
53
// Access-token settings. accessTokenEnvVariable is the environment variable
// the tests set or clear; testV2AccessToken is the token the v2 fake S2A
// services are configured to expect, and testAccessToken is set by a legacy
// (v1) test.
const (
	accessTokenEnvVariable = "S2A_ACCESS_TOKEN"
	testAccessToken        = "test_access_token"
	testV2AccessToken      = "valid_token"
)

// Identities, expected handshake results, and the RPC payload shared by the
// end-to-end tests.
const (
	applicationProtocol = "grpc"
	authType            = "s2a"
	clientHostname      = "test_client_hostname"
	serverSpiffeID      = "test_server_spiffe_id"
	clientMessage       = "echo"
)

// defaultE2ETestTimeout bounds the context used for each end-to-end exchange.
const defaultE2ETestTimeout = 5 * time.Second
66
// PEM-encoded certificates and keys used by the tests.
//
// NOTE(review): the blank `embed` import suggests these are meant to be
// populated by go:embed directives, but none are visible here. Confirm they
// exist. Otherwise every slice below is nil and each tls.X509KeyPair /
// AppendCertsFromPEM call that consumes them fails.
var (
	// Client certificate/key; the TLS client config factory tests expect the
	// built tls.Config to carry this leaf certificate.
	clientCertpem []byte

	clientKeypem []byte

	// Server certificate/key served by the TLS fallback gRPC server and by
	// the HTTPS test servers.
	serverCertpem []byte

	serverKeypem []byte

	// mTLS material for the connection between an application and its fake
	// S2A. mdsRootCertPem is the CA pool used on both sides.
	mdsRootCertPem []byte

	// Certificate/key presented by the fake S2A server.
	mdsServerCertPem []byte

	mdsServerKeyPem []byte

	// Certificate/key presented by applications when connecting to the fake S2A.
	mdsClientCertPem []byte

	mdsClientKeyPem []byte

	// Self-signed client certificate/key. Tests use it to make the mTLS
	// connection to the fake S2A fail, presumably because it does not chain
	// to mdsRootCertPem.
	selfSignedCertPem []byte

	selfSignedKeyPem []byte
)
92
93
94 type server struct {
95 helloworldpb.UnimplementedGreeterServer
96 }
97
98
99 func (s *server) SayHello(_ context.Context, in *helloworldpb.HelloRequest) (*helloworldpb.HelloReply, error) {
100 return &helloworldpb.HelloReply{Message: "Hello " + in.GetName()}, nil
101 }
102
103
104
105 func startFakeS2A(t *testing.T, enableLegacyMode bool, expToken string, serverTransportCreds credentials.TransportCredentials) string {
106 lis, err := net.Listen("tcp", ":")
107 if err != nil {
108 t.Errorf("net.Listen(tcp, :0) failed: %v", err)
109 }
110
111 var s *grpc.Server
112 if serverTransportCreds != nil {
113 s = grpc.NewServer(grpc.Creds(serverTransportCreds))
114 } else {
115 s = grpc.NewServer()
116 }
117
118 if enableLegacyMode {
119 s2apb.RegisterS2AServiceServer(s, &service.FakeHandshakerService{})
120 } else {
121 s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
122 }
123 go func() {
124 if err := s.Serve(lis); err != nil {
125 t.Errorf("s.Serve(%v) failed: %v", lis, err)
126 }
127 }()
128 return lis.Addr().String()
129 }
130
131
132
133 func startFakeS2AOnUDS(t *testing.T, enableLegacyMode bool, expToken string) string {
134 dir, err := ioutil.TempDir("/tmp", "socket_dir")
135 if err != nil {
136 t.Errorf("Unable to create temporary directory: %v", err)
137 }
138 udsAddress := filepath.Join(dir, "socket")
139 lis, err := net.Listen("unix", filepath.Join(dir, "socket"))
140 if err != nil {
141 t.Errorf("net.Listen(unix, %s) failed: %v", udsAddress, err)
142 }
143 s := grpc.NewServer()
144 if enableLegacyMode {
145 s2apb.RegisterS2AServiceServer(s, &service.FakeHandshakerService{})
146 } else {
147 s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
148 }
149 go func() {
150 if err := s.Serve(lis); err != nil {
151 t.Errorf("s.Serve(%v) failed: %v", lis, err)
152 }
153 }()
154 return fmt.Sprintf("unix://%s", lis.Addr().String())
155 }
156
157
158
159 func startServer(t *testing.T, s2aAddress string, transportCreds credentials.TransportCredentials, enableLegacyMode bool) string {
160 serverOpts := &ServerOptions{
161 LocalIdentities: []Identity{NewSpiffeID(serverSpiffeID)},
162 S2AAddress: s2aAddress,
163 TransportCreds: transportCreds,
164 EnableLegacyMode: enableLegacyMode,
165 }
166 creds, err := NewServerCreds(serverOpts)
167 if err != nil {
168 t.Errorf("NewServerCreds(%v) failed: %v", serverOpts, err)
169 }
170
171 lis, err := net.Listen("tcp", ":0")
172 if err != nil {
173 t.Errorf("net.Listen(tcp, :0) failed: %v", err)
174 }
175 s := grpc.NewServer(grpc.Creds(creds))
176 helloworldpb.RegisterGreeterServer(s, &server{})
177 go func() {
178 if err := s.Serve(lis); err != nil {
179 t.Errorf("s.Serve(%v) failed: %v", lis, err)
180 }
181 }()
182 return lis.Addr().String()
183 }
184
185
186 func runClient(ctx context.Context, t *testing.T, clientS2AAddress string, transportCreds credentials.TransportCredentials, serverAddr string, enableLegacyMode bool, fallbackHandshake fallback.ClientHandshake) {
187 clientOpts := &ClientOptions{
188 TargetIdentities: []Identity{NewSpiffeID(serverSpiffeID)},
189 LocalIdentity: NewHostname(clientHostname),
190 S2AAddress: clientS2AAddress,
191 TransportCreds: transportCreds,
192 EnableLegacyMode: enableLegacyMode,
193 FallbackOpts: &FallbackOptions{
194 FallbackClientHandshakeFunc: fallbackHandshake,
195 },
196 }
197 creds, err := NewClientCreds(clientOpts)
198 if err != nil {
199 t.Errorf("NewClientCreds(%v) failed: %v", clientOpts, err)
200 }
201 dialOptions := []grpc.DialOption{
202 grpc.WithTransportCredentials(creds),
203 grpc.WithBlock(),
204 }
205
206 grpclog.Info("Client dialing server at address: %v", serverAddr)
207
208 conn, err := grpc.Dial(serverAddr, dialOptions...)
209 if err != nil {
210 t.Errorf("grpc.Dial(%v, %v) failed: %v", serverAddr, dialOptions, err)
211 }
212 defer conn.Close()
213
214
215 peer := new(peer.Peer)
216 c := helloworldpb.NewGreeterClient(conn)
217 req := &helloworldpb.HelloRequest{Name: clientMessage}
218 grpclog.Infof("Client calling SayHello with request: %v", req)
219 resp, err := c.SayHello(ctx, req, grpc.Peer(peer), grpc.WaitForReady(true))
220 if err != nil {
221 t.Errorf("c.SayHello(%v, %v) failed: %v", ctx, req, err)
222 }
223 if got, want := resp.GetMessage(), "Hello "+clientMessage; got != want {
224 t.Errorf("r.GetMessage() = %v, want %v", got, want)
225 }
226 grpclog.Infof("Client received message from server: %s", resp.GetMessage())
227
228 if enableLegacyMode {
229
230 authInfo, err := AuthInfoFromPeer(peer)
231 if err != nil {
232 t.Errorf("AuthInfoFromContext(peer) failed: %v", err)
233 }
234 s2aAuthInfo, ok := authInfo.(AuthInfo)
235 if !ok {
236 t.Errorf("authInfo is not an s2a.AuthInfo")
237 }
238 if got, want := s2aAuthInfo.AuthType(), authType; got != want {
239 t.Errorf("s2aAuthInfo.AuthType() = %v, want %v", got, want)
240 }
241 if got, want := s2aAuthInfo.ApplicationProtocol(), applicationProtocol; got != want {
242 t.Errorf("s2aAuthInfo.ApplicationProtocol() = %v, want %v", got, want)
243 }
244 if got, want := s2aAuthInfo.TLSVersion(), commonpb.TLSVersion_TLS1_3; got != want {
245 t.Errorf("s2aAuthInfo.TLSVersion() = %v, want %v", got, want)
246 }
247 if got, want := s2aAuthInfo.IsHandshakeResumed(), false; got != want {
248 t.Errorf("s2aAuthInfo.IsHandshakeResumed() = %v, want %v", got, want)
249 }
250 if got, want := s2aAuthInfo.SecurityLevel(), credentials.PrivacyAndIntegrity; got != want {
251 t.Errorf("s2aAuthInfo.SecurityLevel() = %v, want %v", got, want)
252 }
253 }
254 }
255
256 func TestV1EndToEndUsingFakeS2AOverTCP(t *testing.T) {
257 os.Setenv(accessTokenEnvVariable, "")
258
259
260 serverHandshakerAddr := startFakeS2A(t, true, "", nil)
261 grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
262 clientHandshakerAddr := startFakeS2A(t, true, "", nil)
263 grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
264
265
266 serverAddr := startServer(t, serverHandshakerAddr, nil, true)
267 grpclog.Infof("Server running at address: %v", serverAddr)
268
269
270 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
271 defer cancel()
272 runClient(ctx, t, clientHandshakerAddr, nil, serverAddr, true, nil)
273 }
274
275 func TestV2EndToEndUsingFakeS2AOverTCP(t *testing.T) {
276 os.Setenv(accessTokenEnvVariable, testV2AccessToken)
277 oldRetry := retry.NewRetryer
278 defer func() { retry.NewRetryer = oldRetry }()
279 testRetryer := retry.NewRetryer()
280 retry.NewRetryer = func() *retry.S2ARetryer {
281 return testRetryer
282 }
283
284 serverHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
285 grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
286 clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
287 grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
288
289
290 serverAddr := startServer(t, serverHandshakerAddr, nil, false)
291 grpclog.Infof("Server running at address: %v", serverAddr)
292
293
294 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
295 defer cancel()
296 runClient(ctx, t, clientHandshakerAddr, nil, serverAddr, false, nil)
297 if got, want := testRetryer.Attempts(), 0; got != want {
298 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
299 }
300 }
301
302 func TestV2EndToEndUsingFakeMTLSS2AOverTCP(t *testing.T) {
303 os.Setenv(accessTokenEnvVariable, "")
304 oldRetry := retry.NewRetryer
305 defer func() { retry.NewRetryer = oldRetry }()
306 testRetryer := retry.NewRetryer()
307 retry.NewRetryer = func() *retry.S2ARetryer {
308 return testRetryer
309 }
310 serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
311
312 serverHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
313 grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
314 clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
315 grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
316
317 clientTransportCreds := loadClientTransportCreds(t, mdsClientCertPem, mdsClientKeyPem)
318
319 serverAddr := startServer(t, serverHandshakerAddr, clientTransportCreds, false)
320 grpclog.Infof("Server running at address: %v", serverAddr)
321
322
323 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
324 defer cancel()
325 runClient(ctx, t, clientHandshakerAddr, clientTransportCreds, serverAddr, false, nil)
326 if got, want := testRetryer.Attempts(), 0; got != want {
327 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
328 }
329 }
330
331 func TestV2EndToEndUsingFakeMTLSS2AOverTCP_SelfSignedClientTransportCreds(t *testing.T) {
332 os.Setenv(accessTokenEnvVariable, "")
333 fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
334 oldRetry := retry.NewRetryer
335 defer func() { retry.NewRetryer = oldRetry }()
336 testRetryer := retry.NewRetryer()
337 retry.NewRetryer = func() *retry.S2ARetryer {
338 return testRetryer
339 }
340 serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
341
342 serverHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
343 grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
344 clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
345 grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
346
347 clientTransportCreds := loadClientTransportCreds(t, mdsClientCertPem, mdsClientKeyPem)
348
349 selfSignedClientTransportCreds := loadClientTransportCreds(t, selfSignedCertPem, selfSignedKeyPem)
350
351 serverAddr := startServer(t, serverHandshakerAddr, clientTransportCreds, false)
352 fallbackServerAddr := startFallbackServer(t)
353 t.Logf("server running at address: %v", serverAddr)
354 t.Logf("fallback server running at address: %v", fallbackServerAddr)
355
356
357 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
358 defer cancel()
359 fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
360 if err != nil {
361 t.Errorf("error creating fallback handshake function: %v", err)
362 }
363 fallbackCalled := false
364 fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
365 fallbackCalled = true
366 return fallbackHandshake(ctx, targetServer, conn, err)
367 }
368
369
370
371 runClient(ctx, t, clientHandshakerAddr, selfSignedClientTransportCreds, serverAddr, false, fallbackHandshakeWrapper)
372 if !fallbackCalled {
373 t.Errorf("fallbackHandshake is not called")
374 }
375 if got, want := testRetryer.Attempts(), 5; got != want {
376 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
377 }
378 }
379
380 func loadServerTransportCreds(t *testing.T, cert, key []byte) credentials.TransportCredentials {
381 certificate, err := tls.X509KeyPair(cert, key)
382 if err != nil {
383 t.Errorf("failed to load S2A server cert/key: %v", err)
384 }
385 caPool := x509.NewCertPool()
386 if !caPool.AppendCertsFromPEM(mdsRootCertPem) {
387 t.Errorf("failed to add ca cert")
388 }
389 tlsConfig := &tls.Config{
390 ClientAuth: tls.RequireAndVerifyClientCert,
391 Certificates: []tls.Certificate{certificate},
392 ClientCAs: caPool,
393 }
394 return credentials.NewTLS(tlsConfig)
395 }
396
397 func loadClientTransportCreds(t *testing.T, cert, key []byte) credentials.TransportCredentials {
398 certificate, err := tls.X509KeyPair(cert, key)
399 if err != nil {
400 t.Errorf("failed to load S2A client cert/key: %v", err)
401 }
402 caPool := x509.NewCertPool()
403 if !caPool.AppendCertsFromPEM(mdsRootCertPem) {
404 t.Errorf("failed to add ca cert")
405 }
406 tlsConfig := &tls.Config{
407 Certificates: []tls.Certificate{certificate},
408 RootCAs: caPool,
409 }
410 return credentials.NewTLS(tlsConfig)
411 }
412
413
414
415 func startFallbackServer(t *testing.T) string {
416 lis, err := net.Listen("tcp", ":0")
417 if err != nil {
418 t.Errorf("net.Listen(tcp, :0) failed: %v", err)
419 }
420 cert, err := tls.X509KeyPair(serverCertpem, serverKeypem)
421 if err != nil {
422 t.Errorf("failure initializing tls.certificate: %v", err)
423 }
424
425 creds := credentials.NewTLS(&tls.Config{
426 MinVersion: tls.VersionTLS13,
427 MaxVersion: tls.VersionTLS13,
428 Certificates: []tls.Certificate{cert},
429 })
430 s := grpc.NewServer(grpc.Creds(creds))
431 helloworldpb.RegisterGreeterServer(s, &server{})
432 go func() {
433 if err := s.Serve(lis); err != nil {
434 t.Errorf("s.Serve(%v) failed: %v", lis, err)
435 }
436 }()
437 return lis.Addr().String()
438 }
439 func TestV2GRPCFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
440
441 fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
442 os.Setenv(accessTokenEnvVariable, testV2AccessToken)
443 oldRetry := retry.NewRetryer
444 defer func() { retry.NewRetryer = oldRetry }()
445 testRetryer := retry.NewRetryer()
446 retry.NewRetryer = func() *retry.S2ARetryer {
447 return testRetryer
448 }
449
450 serverHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
451 grpclog.Infof("fake handshaker for server running at address: %v", serverHandshakerAddr)
452
453
454 serverAddr := startServer(t, serverHandshakerAddr, nil, false)
455 fallbackServerAddr := startFallbackServer(t)
456 t.Logf("server running at address: %v", serverAddr)
457 t.Logf("fallback server running at address: %v", fallbackServerAddr)
458
459
460 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
461 defer cancel()
462 fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
463 if err != nil {
464 t.Errorf("error creating fallback handshake function: %v", err)
465 }
466 fallbackCalled := false
467 fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
468 fallbackCalled = true
469 return fallbackHandshake(ctx, targetServer, conn, err)
470 }
471 runClient(ctx, t, "not_exist", nil, serverAddr, false, fallbackHandshakeWrapper)
472 if !fallbackCalled {
473 t.Errorf("fallbackHandshake is not called")
474 }
475 if got, want := testRetryer.Attempts(), 5; got != want {
476 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
477 }
478 }
479
480 func TestV2GRPCRetryAndFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
481
482 fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
483
484 os.Setenv(accessTokenEnvVariable, "invalid_token")
485 oldRetry := retry.NewRetryer
486 defer func() { retry.NewRetryer = oldRetry }()
487 testRetryer := retry.NewRetryer()
488 retry.NewRetryer = func() *retry.S2ARetryer {
489 return testRetryer
490 }
491
492 serverHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
493 grpclog.Infof("fake handshaker for server running at address: %v", serverHandshakerAddr)
494 clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
495 grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
496
497
498 serverAddr := startServer(t, serverHandshakerAddr, nil, false)
499 fallbackServerAddr := startFallbackServer(t)
500 t.Logf("server running at address: %v", serverAddr)
501 t.Logf("fallback server running at address: %v", fallbackServerAddr)
502
503
504 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
505 defer cancel()
506 fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
507 if err != nil {
508 t.Errorf("error creating fallback handshake function: %v", err)
509 }
510 fallbackCalled := false
511 fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
512 fallbackCalled = true
513 return fallbackHandshake(ctx, targetServer, conn, err)
514 }
515 runClient(ctx, t, clientHandshakerAddr, nil, serverAddr, false, fallbackHandshakeWrapper)
516 if !fallbackCalled {
517 t.Errorf("fallbackHandshake is not called")
518 }
519 if got, want := testRetryer.Attempts(), 5; got != want {
520 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
521 }
522 }
523
524 func TestV1EndToEndUsingTokens(t *testing.T) {
525 os.Setenv(accessTokenEnvVariable, testAccessToken)
526
527
528 serverS2AAddress := startFakeS2A(t, true, "", nil)
529 grpclog.Infof("Fake S2A for server running at address: %v", serverS2AAddress)
530 clientS2AAddress := startFakeS2A(t, true, "", nil)
531 grpclog.Infof("Fake S2A for client running at address: %v", clientS2AAddress)
532
533
534 serverAddr := startServer(t, serverS2AAddress, nil, true)
535 grpclog.Infof("Server running at address: %v", serverAddr)
536
537
538 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
539 defer cancel()
540 runClient(ctx, t, clientS2AAddress, nil, serverAddr, true, nil)
541 }
542
543 func TestV2EndToEndUsingTokens(t *testing.T) {
544 os.Setenv(accessTokenEnvVariable, testV2AccessToken)
545
546
547 serverS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
548 grpclog.Infof("Fake S2A for server running at address: %v", serverS2AAddress)
549 clientS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
550 grpclog.Infof("Fake S2A for client running at address: %v", clientS2AAddress)
551
552
553 serverAddr := startServer(t, serverS2AAddress, nil, false)
554 grpclog.Infof("Server running at address: %v", serverAddr)
555
556
557 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
558 defer cancel()
559 runClient(ctx, t, clientS2AAddress, nil, serverAddr, false, nil)
560 }
561
562 func TestV2EndToEndEmptyToken(t *testing.T) {
563 os.Unsetenv(accessTokenEnvVariable)
564
565
566 serverS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
567 grpclog.Infof("Fake S2A for server running at address: %v", serverS2AAddress)
568 clientS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
569 grpclog.Infof("Fake S2A for client running at address: %v", clientS2AAddress)
570
571
572 serverAddr := startServer(t, serverS2AAddress, nil, false)
573 grpclog.Infof("Server running at address: %v", serverAddr)
574
575
576 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
577 defer cancel()
578 runClient(ctx, t, clientS2AAddress, nil, serverAddr, false, nil)
579 }
580
581 func TestV1EndToEndUsingFakeS2AOnUDS(t *testing.T) {
582 os.Setenv(accessTokenEnvVariable, "")
583
584
585 serverS2AAddress := startFakeS2AOnUDS(t, true, "")
586 grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddress)
587 clientS2AAddress := startFakeS2AOnUDS(t, true, "")
588 grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddress)
589
590
591 serverAddress := startServer(t, serverS2AAddress, nil, true)
592 grpclog.Infof("Server running at address: %v", serverS2AAddress)
593
594
595 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
596 defer cancel()
597 runClient(ctx, t, clientS2AAddress, nil, serverAddress, true, nil)
598 }
599
600 func TestV2EndToEndUsingFakeS2AOnUDS(t *testing.T) {
601 os.Setenv(accessTokenEnvVariable, testV2AccessToken)
602
603
604 serverS2AAddress := startFakeS2AOnUDS(t, false, testV2AccessToken)
605 grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddress)
606 clientS2AAddress := startFakeS2AOnUDS(t, false, testV2AccessToken)
607 grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddress)
608
609
610 serverAddress := startServer(t, serverS2AAddress, nil, false)
611 grpclog.Infof("Server running at address: %v", serverS2AAddress)
612
613
614 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
615 defer cancel()
616 runClient(ctx, t, clientS2AAddress, nil, serverAddress, false, nil)
617 }
618
619 func TestNewTLSClientConfigFactoryWithTokenManager(t *testing.T) {
620 os.Setenv(accessTokenEnvVariable, "TestNewTLSClientConfigFactory_token")
621 s2AAddr := startFakeS2A(t, false, "TestNewTLSClientConfigFactory_token", nil)
622 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
623 defer cancel()
624
625 factory, err := NewTLSClientConfigFactory(&ClientOptions{
626 S2AAddress: s2AAddr,
627 })
628 if err != nil {
629 t.Errorf("NewTLSClientConfigFactory() failed: %v", err)
630 }
631
632 config, err := factory.Build(ctx, nil)
633 if err != nil {
634 t.Errorf("Build tls config failed: %v", err)
635 }
636
637 cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
638 if err != nil {
639 t.Fatalf("tls.X509KeyPair failed: %v", err)
640 }
641
642 if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
643 t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
644 }
645 }
646
647 func TestNewTLSClientConfigFactoryWithoutTokenManager(t *testing.T) {
648 os.Unsetenv(accessTokenEnvVariable)
649 s2AAddr := startFakeS2A(t, false, "ignored-value", nil)
650 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
651 defer cancel()
652
653 factory, err := NewTLSClientConfigFactory(&ClientOptions{
654 S2AAddress: s2AAddr,
655 })
656 if err != nil {
657 t.Errorf("NewTLSClientConfigFactory() failed: %v", err)
658 }
659
660 config, err := factory.Build(ctx, nil)
661 if err != nil {
662 t.Errorf("Build tls config failed: %v", err)
663 }
664
665 cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
666 if err != nil {
667 t.Fatalf("tls.X509KeyPair failed: %v", err)
668 }
669 if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
670 t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
671 }
672 }
673
674
675
676
677 func startHTTPServer(t *testing.T, resp string) string {
678 cert, _ := tls.X509KeyPair(serverCertpem, serverKeypem)
679 tlsConfig := tls.Config{
680 MinVersion: tls.VersionTLS13,
681 MaxVersion: tls.VersionTLS13,
682 Certificates: []tls.Certificate{cert},
683 }
684 s := http.NewServeMux()
685 s.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
686 fmt.Fprintf(w, resp)
687 })
688 lis, err := tls.Listen("tcp", ":0", &tlsConfig)
689 if err != nil {
690 t.Errorf("net.Listen(tcp, :0) failed: %v", err)
691 }
692 go func() {
693 http.Serve(lis, s)
694 }()
695 return lis.Addr().String()
696 }
697
698
699
700 func runHTTPClient(t *testing.T, clientS2AAddress string, transportCreds credentials.TransportCredentials, serverAddr string, fallbackOpts *FallbackOptions) string {
701 dialTLSContext := NewS2ADialTLSContextFunc(&ClientOptions{
702 S2AAddress: clientS2AAddress,
703 TransportCreds: transportCreds,
704 FallbackOpts: fallbackOpts,
705 })
706
707 tr := http.Transport{
708 DialTLSContext: dialTLSContext,
709 }
710
711 client := &http.Client{Transport: &tr}
712 reqURL := fmt.Sprintf("https://%s/hello", serverAddr)
713 t.Logf("reqURL is set to: %v", reqURL)
714 req, err := http.NewRequest(http.MethodGet, reqURL, nil)
715 if err != nil {
716 t.Errorf("error creating new HTTP request: %v", err)
717 }
718 resp, err := client.Do(req)
719 if err != nil {
720 t.Errorf("error making client HTTP request: %v", err)
721 }
722 respBody, err := io.ReadAll(resp.Body)
723 if err != nil {
724 t.Errorf("error reading HTTP response: %v", err)
725 }
726 return string(respBody)
727 }
728 func TestHTTPEndToEndUsingFakeS2AOverTCP(t *testing.T) {
729 os.Setenv(accessTokenEnvVariable, testV2AccessToken)
730 oldRetry := retry.NewRetryer
731 defer func() { retry.NewRetryer = oldRetry }()
732 testRetryer := retry.NewRetryer()
733 retry.NewRetryer = func() *retry.S2ARetryer {
734 return testRetryer
735 }
736
737
738 clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
739 t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
740
741
742 serverAddr := startHTTPServer(t, "hello")
743 t.Logf("HTTP server running at address: %v", serverAddr)
744
745
746 resp := runHTTPClient(t, clientHandshakerAddr, nil, serverAddr, nil)
747
748 if got, want := resp, "hello"; got != want {
749 t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
750 }
751 if got, want := testRetryer.Attempts(), 0; got != want {
752 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
753 }
754 }
755
756 func TestHTTPEndToEndSUsingFakeMTLSS2AOverTCP(t *testing.T) {
757 os.Setenv(accessTokenEnvVariable, "")
758 oldRetry := retry.NewRetryer
759 defer func() { retry.NewRetryer = oldRetry }()
760 testRetryer := retry.NewRetryer()
761 retry.NewRetryer = func() *retry.S2ARetryer {
762 return testRetryer
763 }
764
765
766 serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
767 clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
768 t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
769
770
771 serverAddr := startHTTPServer(t, "hello")
772 t.Logf("HTTP server running at address: %v", serverAddr)
773
774
775 clientTransportCreds := loadClientTransportCreds(t, mdsClientCertPem, mdsClientKeyPem)
776 resp := runHTTPClient(t, clientHandshakerAddr, clientTransportCreds, serverAddr, nil)
777
778 if got, want := resp, "hello"; got != want {
779 t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
780 }
781 if got, want := testRetryer.Attempts(), 0; got != want {
782 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
783 }
784 }
785
786 func TestHTTPEndToEndSUsingFakeMTLSS2AOverTCP_SelfSignedClientTransportCreds(t *testing.T) {
787 fallback.FallbackTLSConfigHTTP.InsecureSkipVerify = true
788 os.Setenv(accessTokenEnvVariable, "")
789 oldRetry := retry.NewRetryer
790 defer func() { retry.NewRetryer = oldRetry }()
791 testRetryer := retry.NewRetryer()
792 retry.NewRetryer = func() *retry.S2ARetryer {
793 return testRetryer
794 }
795
796
797 serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
798 clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
799 t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
800
801 serverAddr := startHTTPServer(t, "hello")
802 t.Logf("HTTP server running at address: %v", serverAddr)
803
804 fallbackServerAddr := startHTTPServer(t, "hello fallback")
805 t.Logf("fallback HTTP server running at address: %v", fallbackServerAddr)
806
807
808 fbDialer, fbAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackServerAddr)
809 if err != nil {
810 t.Errorf("error creating fallback dialer: %v", err)
811 }
812 fallbackOpts := &FallbackOptions{
813 FallbackDialer: &FallbackDialer{
814 Dialer: fbDialer,
815 ServerAddr: fbAddr,
816 },
817 }
818
819 selfSignedClientTransportCreds := loadClientTransportCreds(t, selfSignedCertPem, selfSignedKeyPem)
820
821
822 resp := runHTTPClient(t, clientHandshakerAddr, selfSignedClientTransportCreds, serverAddr, fallbackOpts)
823 if got, want := resp, "hello fallback"; got != want {
824 t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
825 }
826
827 if got, want := testRetryer.Attempts(), 5; got != want {
828 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
829 }
830 }
831
832 func TestHTTPFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
833 fallback.FallbackTLSConfigHTTP.InsecureSkipVerify = true
834 os.Setenv(accessTokenEnvVariable, testV2AccessToken)
835 oldRetry := retry.NewRetryer
836 defer func() { retry.NewRetryer = oldRetry }()
837 testRetryer := retry.NewRetryer()
838 retry.NewRetryer = func() *retry.S2ARetryer {
839 return testRetryer
840 }
841
842
843 serverAddr := startHTTPServer(t, "hello")
844 t.Logf("HTTP server running at address: %v", serverAddr)
845
846 fallbackServerAddr := startHTTPServer(t, "hello fallback")
847 t.Logf("fallback HTTP server running at address: %v", fallbackServerAddr)
848
849
850 fbDialer, fbAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackServerAddr)
851 if err != nil {
852 t.Errorf("error creating fallback dialer: %v", err)
853 }
854
855 fallbackOpts := &FallbackOptions{
856 FallbackDialer: &FallbackDialer{
857 Dialer: fbDialer,
858 ServerAddr: fbAddr,
859 },
860 }
861
862 resp := runHTTPClient(t, "not_exist", nil, serverAddr, fallbackOpts)
863
864 if got, want := resp, "hello fallback"; got != want {
865 t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
866 }
867
868 if got, want := testRetryer.Attempts(), 5; got != want {
869 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
870 }
871 }
872
873 func TestHTTPRetryAndFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
874 fallback.FallbackTLSConfigHTTP.InsecureSkipVerify = true
875
876 os.Setenv(accessTokenEnvVariable, "invalid_token")
877 oldRetry := retry.NewRetryer
878 defer func() { retry.NewRetryer = oldRetry }()
879 testRetryer := retry.NewRetryer()
880 retry.NewRetryer = func() *retry.S2ARetryer {
881 return testRetryer
882 }
883
884
885 clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
886 t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
887
888 serverAddr := startHTTPServer(t, "hello")
889 t.Logf("HTTP server running at address: %v", serverAddr)
890
891 fallbackServerAddr := startHTTPServer(t, "hello fallback")
892 t.Logf("fallback HTTP server running at address: %v", fallbackServerAddr)
893
894
895 fbDialer, fbAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackServerAddr)
896 if err != nil {
897 t.Errorf("error creating fallback dialer: %v", err)
898 }
899
900 fallbackOpts := &FallbackOptions{
901 FallbackDialer: &FallbackDialer{
902 Dialer: fbDialer,
903 ServerAddr: fbAddr,
904 },
905 }
906
907 resp := runHTTPClient(t, clientHandshakerAddr, nil, serverAddr, fallbackOpts)
908
909 if got, want := resp, "hello fallback"; got != want {
910 t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
911 }
912
913 if got, want := testRetryer.Attempts(), 5; got != want {
914 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
915 }
916 }
917
View as plain text