1
18
19 package v2
20
21 import (
22 "bytes"
23 "context"
24 "crypto/tls"
25 "fmt"
26 "io/ioutil"
27 "net"
28 "os"
29 "path/filepath"
30 "testing"
31 "time"
32
33 _ "embed"
34
35 "github.com/google/s2a-go/fallback"
36 "github.com/google/s2a-go/internal/tokenmanager"
37 "github.com/google/s2a-go/internal/v2/fakes2av2"
38 "github.com/google/s2a-go/retry"
39 "google.golang.org/grpc/credentials"
40 "google.golang.org/grpc/grpclog"
41
42 grpc "google.golang.org/grpc"
43
44 commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
45 helloworldpb "github.com/google/s2a-go/internal/proto/examples/helloworld_go_proto"
46 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
47 )
48
// Settings shared by the end-to-end tests in this file.
const (
	// accessTokenEnvVariable is the environment variable the tests set or
	// unset before talking to the fake S2A.
	accessTokenEnvVariable = "S2A_ACCESS_TOKEN"
	// defaultE2ETimeout bounds the context each test creates for its work.
	defaultE2ETimeout = 5 * time.Second
	// clientMessage is the name the client sends in its HelloRequest; the
	// expected reply is "Hello " followed by this value.
	clientMessage = "echo"
)
54
// Test credentials that this file hands to tls.X509KeyPair (the fallback
// server uses the server pair; the NewClientTLSConfig tests use the client
// pair).
//
// NOTE(review): nothing in the visible code assigns these variables, yet the
// file blank-imports "embed". Presumably each declaration originally carried
// a go:embed directive pointing at a PEM file and the directives were lost —
// confirm and restore them, otherwise tls.X509KeyPair receives empty input.
var (
	// clientCertpem holds the client certificate.
	clientCertpem []byte
	// clientKeypem holds the private key paired with clientCertpem.
	clientKeypem []byte
	// serverCertpem holds the certificate served by the fallback server.
	serverCertpem []byte
	// serverKeypem holds the private key paired with serverCertpem.
	serverKeypem []byte
)
65
66
// server is the Greeter service implementation registered on the test gRPC
// servers. It embeds UnimplementedGreeterServer and overrides SayHello.
type server struct {
	helloworldpb.UnimplementedGreeterServer
}
70
71
72 func (s *server) SayHello(_ context.Context, in *helloworldpb.HelloRequest) (*helloworldpb.HelloReply, error) {
73 return &helloworldpb.HelloReply{Message: "Hello " + in.GetName()}, nil
74 }
75
76
77
78 func startFakeS2A(t *testing.T, expToken string) string {
79 lis, err := net.Listen("tcp", ":")
80 if err != nil {
81 t.Errorf("net.Listen(tcp, :0) failed: %v", err)
82 }
83 s := grpc.NewServer()
84 s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
85 go func() {
86 if err := s.Serve(lis); err != nil {
87 t.Errorf("s.Serve(%v) failed: %v", lis, err)
88 }
89 }()
90 return lis.Addr().String()
91 }
92
93
94
95 func startFakeS2AOnUDS(t *testing.T, expToken string) string {
96 dir, err := ioutil.TempDir("/tmp", "socket_dir")
97 if err != nil {
98 t.Errorf("Unable to create temporary directory: %v", err)
99 }
100 udsAddress := filepath.Join(dir, "socket")
101 lis, err := net.Listen("unix", filepath.Join(dir, "socket"))
102 if err != nil {
103 t.Errorf("net.Listen(unix, %s) failed: %v", udsAddress, err)
104 }
105 s := grpc.NewServer()
106 s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
107 go func() {
108 if err := s.Serve(lis); err != nil {
109 t.Errorf("s.Serve(%v) failed: %v", lis, err)
110 }
111 }()
112 return fmt.Sprintf("unix://%s", lis.Addr().String())
113 }
114
115
116
117 func startServer(t *testing.T, s2aAddress string, localIdentities []*commonpbv1.Identity) string {
118
119 creds, err := NewServerCreds(s2aAddress, nil, localIdentities, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil)
120 if err != nil {
121 t.Errorf("NewServerCreds(%s) failed: %v", s2aAddress, err)
122 }
123
124 lis, err := net.Listen("tcp", ":0")
125 if err != nil {
126 t.Errorf("net.Listen(tcp, :0) failed: %v", err)
127 }
128 s := grpc.NewServer(grpc.Creds(creds))
129 helloworldpb.RegisterGreeterServer(s, &server{})
130 go func() {
131 if err := s.Serve(lis); err != nil {
132 t.Errorf("s.Serve(%v) failed: %v", lis, err)
133 }
134 }()
135 return lis.Addr().String()
136 }
137
138
139
140 func startFallbackServer(t *testing.T) string {
141 lis, err := net.Listen("tcp", ":0")
142 if err != nil {
143 t.Errorf("net.Listen(tcp, :0) failed: %v", err)
144 }
145 cert, err := tls.X509KeyPair(serverCertpem, serverKeypem)
146 if err != nil {
147 t.Errorf("failure initializing tls.certificate: %v", err)
148 }
149
150 creds := credentials.NewTLS(&tls.Config{
151 MinVersion: tls.VersionTLS13,
152 MaxVersion: tls.VersionTLS13,
153 Certificates: []tls.Certificate{cert},
154 })
155 s := grpc.NewServer(grpc.Creds(creds))
156 helloworldpb.RegisterGreeterServer(s, &server{})
157 go func() {
158 if err := s.Serve(lis); err != nil {
159 t.Errorf("s.Serve(%v) failed: %v", lis, err)
160 }
161 }()
162 return lis.Addr().String()
163 }
164
165
166 func runClient(ctx context.Context, t *testing.T, clientS2AAddress, serverAddr string, localIdentity *commonpbv1.Identity, fallbackHandshake fallback.ClientHandshake) {
167 creds, err := NewClientCreds(clientS2AAddress, nil, localIdentity, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, fallbackHandshake, nil, nil)
168 if err != nil {
169 t.Errorf("NewClientCreds(%s) failed: %v", clientS2AAddress, err)
170 }
171 dialOptions := []grpc.DialOption{
172 grpc.WithTransportCredentials(creds),
173 grpc.WithBlock(),
174 }
175
176 grpclog.Info("Client dialing server at address: %v", serverAddr)
177
178 conn, err := grpc.Dial(serverAddr, dialOptions...)
179 if err != nil {
180 t.Errorf("grpc.Dial(%v, %v) failed: %v", serverAddr, dialOptions, err)
181 }
182 defer conn.Close()
183
184
185 c := helloworldpb.NewGreeterClient(conn)
186 req := &helloworldpb.HelloRequest{Name: clientMessage}
187 grpclog.Infof("Client calling SayHello with request: %v", req)
188 resp, err := c.SayHello(ctx, req, grpc.WaitForReady(true))
189 if err != nil {
190 t.Errorf("c.SayHello(%v, %v) failed: %v", ctx, req, err)
191 }
192 if got, want := resp.GetMessage(), "Hello "+clientMessage; got != want {
193 t.Errorf("r.GetMessage() = %v, want %v", got, want)
194 }
195 grpclog.Infof("Client received message from server: %s", resp.GetMessage())
196 }
197
198 func TestEndToEndUsingFakeS2AOverTCP(t *testing.T) {
199 os.Setenv(accessTokenEnvVariable, "TestE2ETCP_token")
200 oldRetry := retry.NewRetryer
201 defer func() { retry.NewRetryer = oldRetry }()
202 testRetryer := retry.NewRetryer()
203 retry.NewRetryer = func() *retry.S2ARetryer {
204 return testRetryer
205 }
206
207 serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
208 grpclog.Infof("Fake handshaker for server running at address: %v", serverS2AAddr)
209 clientS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
210 grpclog.Infof("Fake handshaker for client running at address: %v", clientS2AAddr)
211
212
213 localIdentities := []*commonpbv1.Identity{
214 {
215 IdentityOneof: &commonpbv1.Identity_Hostname{
216 Hostname: "test_rsa_server_identity",
217 },
218 },
219 }
220 serverAddr := startServer(t, serverS2AAddr, localIdentities)
221 grpclog.Infof("Server running at address: %v", serverAddr)
222
223
224 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
225 defer cancel()
226 runClient(ctx, t, clientS2AAddr, serverAddr, &commonpbv1.Identity{
227 IdentityOneof: &commonpbv1.Identity_Hostname{
228 Hostname: "test_rsa_client_identity",
229 },
230 }, nil)
231 if got, want := testRetryer.Attempts(), 0; got != want {
232 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
233 }
234 }
235
236 func TestEndToEndUsingFakeS2AOverTCPEmptyId(t *testing.T) {
237 os.Setenv(accessTokenEnvVariable, "TestE2ETCP_token")
238
239 serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
240 grpclog.Infof("Fake handshaker for server running at address: %v", serverS2AAddr)
241 clientS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
242 grpclog.Infof("Fake handshaker for client running at address: %v", clientS2AAddr)
243
244
245 var localIdentities []*commonpbv1.Identity
246 localIdentities = append(localIdentities, nil)
247 serverAddr := startServer(t, serverS2AAddr, localIdentities)
248 grpclog.Infof("Server running at address: %v", serverAddr)
249
250
251 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
252 defer cancel()
253 runClient(ctx, t, clientS2AAddr, serverAddr, nil, nil)
254 }
255
256 func TestEndToEndUsingFakeS2AOnUDS(t *testing.T) {
257 os.Setenv(accessTokenEnvVariable, "TestE2EUDS_token")
258
259 serverS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
260 grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddr)
261 clientS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
262 grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddr)
263
264
265 localIdentities := []*commonpbv1.Identity{
266 {
267 IdentityOneof: &commonpbv1.Identity_Hostname{
268 Hostname: "test_rsa_server_identity",
269 },
270 },
271 }
272 serverAddr := startServer(t, serverS2AAddr, localIdentities)
273 grpclog.Infof("Server running at address: %v", serverAddr)
274
275
276 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
277 defer cancel()
278 runClient(ctx, t, clientS2AAddr, serverAddr, &commonpbv1.Identity{
279 IdentityOneof: &commonpbv1.Identity_Hostname{
280 Hostname: "test_rsa_client_identity",
281 },
282 }, nil)
283 }
284
285 func TestEndToEndUsingFakeS2AOnUDSEmptyId(t *testing.T) {
286 os.Setenv(accessTokenEnvVariable, "TestE2EUDS_token")
287
288 serverS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
289 grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddr)
290 clientS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
291 grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddr)
292
293
294 var localIdentities []*commonpbv1.Identity
295 localIdentities = append(localIdentities, nil)
296 serverAddr := startServer(t, serverS2AAddr, localIdentities)
297 grpclog.Infof("Server running at address: %v", serverAddr)
298
299
300 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
301 defer cancel()
302 runClient(ctx, t, clientS2AAddr, serverAddr, nil, nil)
303 }
304
305 func TestGRPCFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
306
307 fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
308 os.Setenv(accessTokenEnvVariable, "TestE2ETCP_token")
309 oldRetry := retry.NewRetryer
310 defer func() { retry.NewRetryer = oldRetry }()
311 testRetryer := retry.NewRetryer()
312 retry.NewRetryer = func() *retry.S2ARetryer {
313 return testRetryer
314 }
315
316
317 serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
318 t.Logf("Fake handshaker for server running at address: %v", serverS2AAddr)
319
320
321 localIdentities := []*commonpbv1.Identity{
322 {
323 IdentityOneof: &commonpbv1.Identity_Hostname{
324 Hostname: "test_rsa_server_identity",
325 },
326 },
327 }
328 serverAddr := startServer(t, serverS2AAddr, localIdentities)
329 fallbackServerAddr := startFallbackServer(t)
330 t.Logf("server running at address: %v", serverAddr)
331 t.Logf("fallback server running at address: %v", fallbackServerAddr)
332
333
334 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
335 defer cancel()
336 fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
337 if err != nil {
338 t.Errorf("error creating fallback handshake function: %v", err)
339 }
340 fallbackCalled := false
341 fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
342 fallbackCalled = true
343 return fallbackHandshake(ctx, targetServer, conn, err)
344 }
345
346 runClient(ctx, t, "not_exist", serverAddr, &commonpbv1.Identity{
347 IdentityOneof: &commonpbv1.Identity_Hostname{
348 Hostname: "test_rsa_client_identity",
349 },
350 }, fallbackHandshakeWrapper)
351
352 if !fallbackCalled {
353 t.Errorf("fallbackHandshake is not called")
354 }
355 if got, want := testRetryer.Attempts(), 5; got != want {
356 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
357 }
358 }
359
360 func TestGRPCRetryAndFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
361
362 fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
363
364 os.Setenv(accessTokenEnvVariable, "invalid_token")
365 oldRetry := retry.NewRetryer
366 defer func() { retry.NewRetryer = oldRetry }()
367 testRetryer := retry.NewRetryer()
368 retry.NewRetryer = func() *retry.S2ARetryer {
369 return testRetryer
370 }
371
372 clientS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
373 grpclog.Infof("Fake handshaker for client running at address: %v", clientS2AAddr)
374 serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
375 grpclog.Infof("Fake handshaker for server running at address: %v", serverS2AAddr)
376
377
378 localIdentities := []*commonpbv1.Identity{
379 {
380 IdentityOneof: &commonpbv1.Identity_Hostname{
381 Hostname: "test_rsa_server_identity",
382 },
383 },
384 }
385 serverAddr := startServer(t, serverS2AAddr, localIdentities)
386 fallbackServerAddr := startFallbackServer(t)
387 t.Logf("server running at address: %v", serverAddr)
388 t.Logf("fallback server running at address: %v", fallbackServerAddr)
389
390
391 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
392 defer cancel()
393 fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
394 if err != nil {
395 t.Errorf("error creating fallback handshake function: %v", err)
396 }
397 fallbackCalled := false
398 fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
399 fallbackCalled = true
400 return fallbackHandshake(ctx, targetServer, conn, err)
401 }
402 runClient(ctx, t, clientS2AAddr, serverAddr, &commonpbv1.Identity{
403 IdentityOneof: &commonpbv1.Identity_Hostname{
404 Hostname: "test_rsa_client_identity",
405 },
406 }, fallbackHandshakeWrapper)
407
408 if !fallbackCalled {
409 t.Errorf("fallbackHandshake is not called")
410 }
411 if got, want := testRetryer.Attempts(), 5; got != want {
412 t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
413 }
414 }
415
416 func TestNewClientTlsConfigWithTokenManager(t *testing.T) {
417 os.Setenv(accessTokenEnvVariable, "TestNewClientTlsConfig_token")
418 s2AAddr := startFakeS2A(t, "TestNewClientTlsConfig_token")
419 accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
420 if err != nil {
421 t.Errorf("tokenmanager.NewSingleTokenAccessTokenManager() failed: %v", err)
422 }
423 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
424 defer cancel()
425 config, err := NewClientTLSConfig(ctx, s2AAddr, nil, accessTokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil)
426 if err != nil {
427 t.Errorf("NewClientTLSConfig() failed: %v", err)
428 }
429
430 cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
431 if err != nil {
432 t.Fatalf("tls.X509KeyPair failed: %v", err)
433 }
434 if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
435 t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
436 }
437 }
438
439 func TestNewClientTlsConfigWithoutTokenManager(t *testing.T) {
440 os.Unsetenv(accessTokenEnvVariable)
441 s2AAddr := startFakeS2A(t, "ignored-value")
442 var tokenManager tokenmanager.AccessTokenManager
443 ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
444 defer cancel()
445 config, err := NewClientTLSConfig(ctx, s2AAddr, nil, tokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil)
446 if err != nil {
447 t.Errorf("NewClientTLSConfig() failed: %v", err)
448 }
449
450 cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
451 if err != nil {
452 t.Fatalf("tls.X509KeyPair failed: %v", err)
453 }
454 if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
455 t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
456 }
457 }
458
View as plain text