1
2
3
4
21
22 package alts
23
24 import (
25 "context"
26 "reflect"
27 "sync"
28 "testing"
29 "time"
30
31 "google.golang.org/grpc"
32 "google.golang.org/grpc/codes"
33 "google.golang.org/grpc/credentials/alts/internal/handshaker"
34 "google.golang.org/grpc/credentials/alts/internal/handshaker/service"
35 altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
36 altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
37 "google.golang.org/grpc/credentials/alts/internal/testutil"
38 "google.golang.org/grpc/internal/grpctest"
39 "google.golang.org/grpc/internal/testutils"
40 testgrpc "google.golang.org/grpc/interop/grpc_testing"
41 testpb "google.golang.org/grpc/interop/grpc_testing"
42 "google.golang.org/grpc/peer"
43 "google.golang.org/grpc/status"
44 "google.golang.org/protobuf/proto"
45 )
46
// Timeouts used by the handshake tests.
const (
	// defaultTestLongTimeout bounds the context that establishAltsConnection
	// uses while retrying its UnaryCall.
	defaultTestLongTimeout = time.Minute
	// defaultTestShortTimeout is the pause between those retries.
	defaultTestShortTimeout = time.Millisecond * 10
)
51
52 type s struct {
53 grpctest.Tester
54 }
55
56 func init() {
57
58
59
60 once.Do(func() {})
61 vmOnGCP = true
62 }
63
64 func Test(t *testing.T) {
65 grpctest.RunSubTests(t, s{})
66 }
67
68 func (s) TestInfoServerName(t *testing.T) {
69
70
71 alts := NewServerCreds(DefaultServerOptions())
72 if got, want := alts.Info().ServerName, ""; got != want {
73 t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
74 }
75 }
76
77 func (s) TestOverrideServerName(t *testing.T) {
78 wantServerName := "server.name"
79
80
81 c := NewServerCreds(DefaultServerOptions())
82 c.OverrideServerName(wantServerName)
83 if got, want := c.Info().ServerName, wantServerName; got != want {
84 t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
85 }
86 }
87
88 func (s) TestCloneClient(t *testing.T) {
89 wantServerName := "server.name"
90 opt := DefaultClientOptions()
91 opt.TargetServiceAccounts = []string{"not", "empty"}
92 c := NewClientCreds(opt)
93 c.OverrideServerName(wantServerName)
94 cc := c.Clone()
95 if got, want := cc.Info().ServerName, wantServerName; got != want {
96 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
97 }
98 cc.OverrideServerName("")
99 if got, want := c.Info().ServerName, wantServerName; got != want {
100 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
101 }
102 if got, want := cc.Info().ServerName, ""; got != want {
103 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
104 }
105
106 ct := c.(*altsTC)
107 cct := cc.(*altsTC)
108
109 if ct.side != cct.side {
110 t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
111 }
112 if ct.hsAddress != cct.hsAddress {
113 t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
114 }
115 if !reflect.DeepEqual(ct.accounts, cct.accounts) {
116 t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
117 }
118 }
119
120 func (s) TestCloneServer(t *testing.T) {
121 wantServerName := "server.name"
122 c := NewServerCreds(DefaultServerOptions())
123 c.OverrideServerName(wantServerName)
124 cc := c.Clone()
125 if got, want := cc.Info().ServerName, wantServerName; got != want {
126 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
127 }
128 cc.OverrideServerName("")
129 if got, want := c.Info().ServerName, wantServerName; got != want {
130 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
131 }
132 if got, want := cc.Info().ServerName, ""; got != want {
133 t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
134 }
135
136 ct := c.(*altsTC)
137 cct := cc.(*altsTC)
138
139 if ct.side != cct.side {
140 t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
141 }
142 if ct.hsAddress != cct.hsAddress {
143 t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
144 }
145 if !reflect.DeepEqual(ct.accounts, cct.accounts) {
146 t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
147 }
148 }
149
150 func (s) TestInfo(t *testing.T) {
151
152
153 c := NewServerCreds(DefaultServerOptions())
154 info := c.Info()
155 if got, want := info.ProtocolVersion, ""; got != want {
156 t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
157 }
158 if got, want := info.SecurityProtocol, "alts"; got != want {
159 t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
160 }
161 if got, want := info.SecurityVersion, "1.0"; got != want {
162 t.Errorf("info.SecurityVersion=%v, want %v", got, want)
163 }
164 if got, want := info.ServerName, ""; got != want {
165 t.Errorf("info.ServerName=%v, want %v", got, want)
166 }
167 }
168
169 func (s) TestCompareRPCVersions(t *testing.T) {
170 for _, tc := range []struct {
171 v1 *altspb.RpcProtocolVersions_Version
172 v2 *altspb.RpcProtocolVersions_Version
173 output int
174 }{
175 {
176 version(3, 2),
177 version(2, 1),
178 1,
179 },
180 {
181 version(3, 2),
182 version(3, 1),
183 1,
184 },
185 {
186 version(2, 1),
187 version(3, 2),
188 -1,
189 },
190 {
191 version(3, 1),
192 version(3, 2),
193 -1,
194 },
195 {
196 version(3, 2),
197 version(3, 2),
198 0,
199 },
200 } {
201 if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
202 t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
203 }
204 }
205 }
206
207 func (s) TestCheckRPCVersions(t *testing.T) {
208 for _, tc := range []struct {
209 desc string
210 local *altspb.RpcProtocolVersions
211 peer *altspb.RpcProtocolVersions
212 output bool
213 maxCommonVersion *altspb.RpcProtocolVersions_Version
214 }{
215 {
216 "local.max > peer.max and local.min > peer.min",
217 versions(2, 1, 3, 2),
218 versions(1, 2, 2, 1),
219 true,
220 version(2, 1),
221 },
222 {
223 "local.max > peer.max and local.min < peer.min",
224 versions(1, 2, 3, 2),
225 versions(2, 1, 2, 1),
226 true,
227 version(2, 1),
228 },
229 {
230 "local.max > peer.max and local.min = peer.min",
231 versions(2, 1, 3, 2),
232 versions(2, 1, 2, 1),
233 true,
234 version(2, 1),
235 },
236 {
237 "local.max < peer.max and local.min > peer.min",
238 versions(2, 1, 2, 1),
239 versions(1, 2, 3, 2),
240 true,
241 version(2, 1),
242 },
243 {
244 "local.max = peer.max and local.min > peer.min",
245 versions(2, 1, 2, 1),
246 versions(1, 2, 2, 1),
247 true,
248 version(2, 1),
249 },
250 {
251 "local.max < peer.max and local.min < peer.min",
252 versions(1, 2, 2, 1),
253 versions(2, 1, 3, 2),
254 true,
255 version(2, 1),
256 },
257 {
258 "local.max < peer.max and local.min = peer.min",
259 versions(1, 2, 2, 1),
260 versions(1, 2, 3, 2),
261 true,
262 version(2, 1),
263 },
264 {
265 "local.max = peer.max and local.min < peer.min",
266 versions(1, 2, 2, 1),
267 versions(2, 1, 2, 1),
268 true,
269 version(2, 1),
270 },
271 {
272 "all equal",
273 versions(2, 1, 2, 1),
274 versions(2, 1, 2, 1),
275 true,
276 version(2, 1),
277 },
278 {
279 "max is smaller than min",
280 versions(2, 1, 1, 2),
281 versions(2, 1, 1, 2),
282 false,
283 nil,
284 },
285 {
286 "no overlap, local > peer",
287 versions(4, 3, 6, 5),
288 versions(1, 0, 2, 1),
289 false,
290 nil,
291 },
292 {
293 "no overlap, local < peer",
294 versions(1, 0, 2, 1),
295 versions(4, 3, 6, 5),
296 false,
297 nil,
298 },
299 {
300 "no overlap, max < min",
301 versions(6, 5, 4, 3),
302 versions(2, 1, 1, 0),
303 false,
304 nil,
305 },
306 } {
307 output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
308 if got, want := output, tc.output; got != want {
309 t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
310 }
311 if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
312 t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
313 }
314 }
315 }
316
317
318
319
320 func (s) TestFullHandshake(t *testing.T) {
321
322 var wait sync.WaitGroup
323 defer wait.Wait()
324 stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
325 defer stopHandshaker()
326 stopServer, serverAddress := startServer(t, handshakerAddress, &wait)
327 defer stopServer()
328
329
330 establishAltsConnection(t, handshakerAddress, serverAddress)
331
332
333 if err := service.CloseForTesting(); err != nil {
334 t.Errorf("service.CloseForTesting() failed: %v", err)
335 }
336 }
337
338
339
340
341 func (s) TestConcurrentHandshakes(t *testing.T) {
342
343
344
345 handshaker.ResetConcurrentHandshakeSemaphoreForTesting(3)
346
347
348 var wait sync.WaitGroup
349 defer wait.Wait()
350 stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
351 defer stopHandshaker()
352 stopServer, serverAddress := startServer(t, handshakerAddress, &wait)
353 defer stopServer()
354
355
356 var waitForConnections sync.WaitGroup
357 for i := 0; i < 10; i++ {
358 waitForConnections.Add(1)
359 go func() {
360 establishAltsConnection(t, handshakerAddress, serverAddress)
361 waitForConnections.Done()
362 }()
363 }
364 waitForConnections.Wait()
365
366
367 if err := service.CloseForTesting(); err != nil {
368 t.Errorf("service.CloseForTesting() failed: %v", err)
369 }
370 }
371
372 func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
373 return &altspb.RpcProtocolVersions_Version{
374 Major: major,
375 Minor: minor,
376 }
377 }
378
379 func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
380 return &altspb.RpcProtocolVersions{
381 MinRpcVersion: version(minMajor, minMinor),
382 MaxRpcVersion: version(maxMajor, maxMinor),
383 }
384 }
385
386 func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
387 clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
388 conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds))
389 if err != nil {
390 t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err)
391 }
392 defer conn.Close()
393 ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
394 defer cancel()
395 c := testgrpc.NewTestServiceClient(conn)
396 var peer peer.Peer
397 success := false
398 for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
399 _, err = c.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.Peer(&peer))
400 if err == nil {
401 success = true
402 break
403 }
404 if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded {
405
406
407 continue
408 }
409 t.Fatalf("c.UnaryCall() failed: %v", err)
410 }
411 if !success {
412 t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout)
413 }
414
415
416
417
418 if got, want := peer.AuthInfo.AuthType(), "alts"; got != want {
419 t.Errorf("authInfo.AuthType() = %s, want = %s", got, want)
420 }
421 authInfo, err := AuthInfoFromPeer(&peer)
422 if err != nil {
423 t.Errorf("AuthInfoFromPeer failed: %v", err)
424 }
425 if got, want := authInfo.ApplicationProtocol(), "grpc"; got != want {
426 t.Errorf("authInfo.ApplicationProtocol() = %s, want = %s", got, want)
427 }
428 }
429
430 func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
431 listener, err := testutils.LocalTCPListener()
432 if err != nil {
433 t.Fatalf("LocalTCPListener() failed: %v", err)
434 }
435 s := grpc.NewServer()
436 altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{})
437 wait.Add(1)
438 go func() {
439 defer wait.Done()
440 if err := s.Serve(listener); err != nil {
441 t.Errorf("failed to serve: %v", err)
442 }
443 }()
444 return func() { s.Stop() }, listener.Addr().String()
445 }
446
447 func startServer(t *testing.T, handshakerServiceAddress string, wait *sync.WaitGroup) (stop func(), address string) {
448 listener, err := testutils.LocalTCPListener()
449 if err != nil {
450 t.Fatalf("LocalTCPListener() failed: %v", err)
451 }
452 serverOpts := &ServerOptions{HandshakerServiceAddress: handshakerServiceAddress}
453 creds := NewServerCreds(serverOpts)
454 s := grpc.NewServer(grpc.Creds(creds))
455 testgrpc.RegisterTestServiceServer(s, &testServer{})
456 wait.Add(1)
457 go func() {
458 defer wait.Done()
459 if err := s.Serve(listener); err != nil {
460 t.Errorf("s.Serve(%v) failed: %v", listener, err)
461 }
462 }()
463 return func() { s.Stop() }, listener.Addr().String()
464 }
465
466 type testServer struct {
467 testgrpc.UnimplementedTestServiceServer
468 }
469
470 func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
471 return &testpb.SimpleResponse{
472 Payload: &testpb.Payload{},
473 }, nil
474 }
475
View as plain text