1
16
17 package ttrpc
18
19 import (
20 "bytes"
21 "context"
22 "errors"
23 "fmt"
24 "net"
25 "runtime"
26 "strings"
27 "sync"
28 "syscall"
29 "testing"
30 "time"
31
32 "github.com/containerd/ttrpc/internal"
33 "google.golang.org/grpc/codes"
34 "google.golang.org/grpc/status"
35 "google.golang.org/protobuf/proto"
36 )
37
// serviceName is the ttrpc service name under which the test handlers in this
// file are registered and through which the test clients address them.
const serviceName = "testService"
39
40
41
42
43
44 type testingService interface {
45 Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error)
46 }
47
48 type testingClient struct {
49 client *Client
50 }
51
52 func newTestingClient(client *Client) *testingClient {
53 return &testingClient{
54 client: client,
55 }
56 }
57
58 func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
59 var tp internal.TestPayload
60 return &tp, tc.client.Call(ctx, serviceName, "Test", req, &tp)
61 }
62
63
64 type testingServer struct{}
65
66 func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
67 tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)}
68 if dl, ok := ctx.Deadline(); ok {
69 tp.Deadline = dl.UnixNano()
70 }
71
72 if v, ok := GetMetadataValue(ctx, "foo"); ok {
73 tp.Metadata = v
74 }
75
76 return tp, nil
77 }
78
79
80
81
82 func registerTestingService(srv *Server, svc testingService) {
83 srv.Register(serviceName, map[string]Method{
84 "Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
85 var req internal.TestPayload
86 if err := unmarshal(&req); err != nil {
87 return nil, err
88 }
89 return svc.Test(ctx, &req)
90 },
91 })
92 }
93
94 func protoEqual(a, b proto.Message) (bool, error) {
95 ma, err := proto.Marshal(a)
96 if err != nil {
97 return false, err
98 }
99 mb, err := proto.Marshal(b)
100 if err != nil {
101 return false, err
102 }
103 return bytes.Equal(ma, mb), nil
104 }
105
106 func TestServer(t *testing.T) {
107 var (
108 ctx = context.Background()
109 server = mustServer(t)(NewServer())
110 testImpl = &testingServer{}
111 addr, listener = newTestListener(t)
112 client, cleanup = newTestClient(t, addr)
113 tclient = newTestingClient(client)
114 )
115
116 defer listener.Close()
117 defer cleanup()
118
119 registerTestingService(server, testImpl)
120
121 go server.Serve(ctx, listener)
122 defer server.Shutdown(ctx)
123
124 testCases := []string{"bar", "baz"}
125 results := make(chan callResult, len(testCases))
126 for _, tc := range testCases {
127 go func(expected string) {
128 results <- roundTrip(ctx, tclient, expected)
129 }(tc)
130 }
131
132 for i := 0; i < len(testCases); {
133 result := <-results
134 if result.err != nil {
135 t.Fatalf("(%s): %v", result.name, result.err)
136 }
137 equal, err := protoEqual(result.received, result.expected)
138 if err != nil {
139 t.Fatalf("failed to compare %s and %s: %s", result.received, result.expected, err)
140 }
141 if !equal {
142 t.Fatalf("unexpected response: %+#v != %+#v", result.received, result.expected)
143 }
144 i++
145 }
146 }
147
148 func TestServerUnimplemented(t *testing.T) {
149 var (
150 ctx = context.Background()
151 server = mustServer(t)(NewServer())
152 addr, listener = newTestListener(t)
153 errs = make(chan error, 1)
154 client, cleanup = newTestClient(t, addr)
155 )
156 defer cleanup()
157 defer listener.Close()
158 go func() {
159 errs <- server.Serve(ctx, listener)
160 }()
161
162 var tp internal.TestPayload
163 if err := client.Call(ctx, "Not", "Found", &tp, &tp); err == nil {
164 t.Fatalf("expected error from non-existent service call")
165 } else if status, ok := status.FromError(err); !ok {
166 t.Fatalf("expected status present in error: %v", err)
167 } else if status.Code() != codes.Unimplemented {
168 t.Fatalf("expected not found for method")
169 }
170
171 if err := server.Shutdown(ctx); err != nil {
172 t.Fatal(err)
173 }
174 if err := <-errs; err != ErrServerClosed {
175 t.Fatal(err)
176 }
177 }
178
179 func TestServerListenerClosed(t *testing.T) {
180 var (
181 ctx = context.Background()
182 server = mustServer(t)(NewServer())
183 _, listener = newTestListener(t)
184 errs = make(chan error, 1)
185 )
186
187 go func() {
188 errs <- server.Serve(ctx, listener)
189 }()
190
191 if err := listener.Close(); err != nil {
192 t.Fatal(err)
193 }
194
195 err := <-errs
196 if err == nil {
197 t.Fatal(err)
198 }
199 }
200
201 func TestServerShutdown(t *testing.T) {
202 const ncalls = 5
203 var (
204 ctx = context.Background()
205 server = mustServer(t)(NewServer())
206 addr, listener = newTestListener(t)
207 shutdownStarted = make(chan struct{})
208 shutdownFinished = make(chan struct{})
209 handlersStarted sync.WaitGroup
210 proceed = make(chan struct{})
211 serveErrs = make(chan error, 1)
212 callErrs = make(chan error, ncalls)
213 shutdownErrs = make(chan error, 1)
214 client, cleanup = newTestClient(t, addr)
215 _, cleanup2 = newTestClient(t, addr)
216 )
217 defer cleanup()
218 defer cleanup2()
219
220
221 server.Register(serviceName, map[string]Method{
222 "Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
223 var req internal.TestPayload
224 if err := unmarshal(&req); err != nil {
225 return nil, err
226 }
227
228 handlersStarted.Done()
229 <-proceed
230 return &internal.TestPayload{Foo: "waited"}, nil
231 },
232 })
233
234 go func() {
235 serveErrs <- server.Serve(ctx, listener)
236 }()
237
238
239 for i := 0; i < ncalls; i++ {
240 handlersStarted.Add(1)
241 go func(i int) {
242 tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
243 callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
244 }(i)
245 }
246
247 handlersStarted.Wait()
248 go func() {
249 close(shutdownStarted)
250 shutdownErrs <- server.Shutdown(ctx)
251 close(shutdownFinished)
252 }()
253
254 <-shutdownStarted
255 close(proceed)
256 <-shutdownFinished
257
258 for i := 0; i < ncalls; i++ {
259 if err := <-callErrs; err != nil && err != ErrClosed {
260 t.Fatal(err)
261 }
262 }
263
264 if err := <-shutdownErrs; err != nil {
265 t.Fatal(err)
266 }
267
268 if err := <-serveErrs; err != ErrServerClosed {
269 t.Fatal(err)
270 }
271 checkServerShutdown(t, server)
272 }
273
274 func TestServerClose(t *testing.T) {
275 var (
276 ctx = context.Background()
277 server = mustServer(t)(NewServer())
278 _, listener = newTestListener(t)
279 startClose = make(chan struct{})
280 errs = make(chan error, 1)
281 )
282
283 go func() {
284 close(startClose)
285 errs <- server.Serve(ctx, listener)
286 }()
287
288 <-startClose
289 if err := server.Close(); err != nil {
290 t.Fatal(err)
291 }
292
293 err := <-errs
294 if err != ErrServerClosed {
295 t.Fatal("expected an error from a closed server", err)
296 }
297
298 checkServerShutdown(t, server)
299 }
300
301 func TestOversizeCall(t *testing.T) {
302 var (
303 ctx = context.Background()
304 server = mustServer(t)(NewServer())
305 addr, listener = newTestListener(t)
306 errs = make(chan error, 1)
307 client, cleanup = newTestClient(t, addr)
308 )
309 defer cleanup()
310 defer listener.Close()
311 go func() {
312 errs <- server.Serve(ctx, listener)
313 }()
314
315 registerTestingService(server, &testingServer{})
316
317 tp := &internal.TestPayload{
318 Foo: strings.Repeat("a", 1+messageLengthMax),
319 }
320 if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
321 t.Fatalf("expected error from non-existent service call")
322 } else if status, ok := status.FromError(err); !ok {
323 t.Fatalf("expected status present in error: %v", err)
324 } else if status.Code() != codes.ResourceExhausted {
325 t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
326 }
327
328 if err := server.Shutdown(ctx); err != nil {
329 t.Fatal(err)
330 }
331 if err := <-errs; err != ErrServerClosed {
332 t.Fatal(err)
333 }
334 }
335
336 func TestClientEOF(t *testing.T) {
337 var (
338 ctx = context.Background()
339 server = mustServer(t)(NewServer())
340 addr, listener = newTestListener(t)
341 errs = make(chan error, 1)
342 client, cleanup = newTestClient(t, addr)
343 )
344 defer cleanup()
345 defer listener.Close()
346 go func() {
347 errs <- server.Serve(ctx, listener)
348 }()
349
350 registerTestingService(server, &testingServer{})
351
352 tp := &internal.TestPayload{}
353
354 if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil {
355 t.Fatalf("unexpected error: %v", err)
356 }
357
358
359 if err := server.Close(); err != nil {
360 t.Fatal(err)
361 }
362 if err := <-errs; err != ErrServerClosed {
363 t.Fatal(err)
364 }
365
366
367 if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
368 t.Fatalf("expected error when calling against shutdown server")
369 } else if !errors.Is(err, ErrClosed) {
370 var errno syscall.Errno
371 if errors.As(err, &errno) {
372 t.Logf("errno=%d", errno)
373 }
374
375 t.Fatalf("expected to have a cause of ErrClosed, got %v", err)
376 }
377 }
378
379 func TestServerRequestTimeout(t *testing.T) {
380 var (
381 ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(10*time.Minute))
382 server = mustServer(t)(NewServer())
383 addr, listener = newTestListener(t)
384 testImpl = &testingServer{}
385 client, cleanup = newTestClient(t, addr)
386 result internal.TestPayload
387 )
388 defer cancel()
389 defer cleanup()
390 defer listener.Close()
391
392 registerTestingService(server, testImpl)
393
394 go server.Serve(ctx, listener)
395 defer server.Shutdown(ctx)
396
397 if err := client.Call(ctx, serviceName, "Test", &internal.TestPayload{}, &result); err != nil {
398 t.Fatalf("unexpected error making call: %v", err)
399 }
400
401 dl, _ := ctx.Deadline()
402 if result.Deadline != dl.UnixNano() {
403 t.Fatalf("expected deadline %v, actual: %v", dl, time.Unix(0, result.Deadline))
404 }
405 }
406
407 func TestServerConnectionsLeak(t *testing.T) {
408 var (
409 ctx = context.Background()
410 server = mustServer(t)(NewServer())
411 addr, listener = newTestListener(t)
412 client, cleanup = newTestClient(t, addr)
413 )
414 defer cleanup()
415 defer listener.Close()
416
417 connectionCountBefore := server.countConnection()
418
419 go server.Serve(ctx, listener)
420
421 registerTestingService(server, &testingServer{})
422
423 tp := &internal.TestPayload{}
424
425 if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil {
426 t.Fatalf("unexpected error during test call: %v", err)
427 }
428
429 connectionCount := server.countConnection()
430 if connectionCount != 1 {
431 t.Fatalf("unexpected connection count: %d, expected: %d", connectionCount, 1)
432 }
433
434
435 if err := client.Close(); err != nil {
436 t.Fatalf("unexpected error while closing client: %v", err)
437 }
438
439
440 maxAttempts := 20
441 for i := 1; i <= maxAttempts; i++ {
442 connectionCountAfter := server.countConnection()
443 if connectionCountAfter == connectionCountBefore {
444 break
445 }
446 if i == maxAttempts {
447 t.Fatalf("expected number of connections to be equal %d after client close, got %d connections",
448 connectionCountBefore, connectionCountAfter)
449 }
450 time.Sleep(100 * time.Millisecond)
451 }
452 }
453
454 func BenchmarkRoundTrip(b *testing.B) {
455 var (
456 ctx = context.Background()
457 server = mustServer(b)(NewServer())
458 testImpl = &testingServer{}
459 addr, listener = newTestListener(b)
460 client, cleanup = newTestClient(b, addr)
461 tclient = newTestingClient(client)
462 )
463
464 defer listener.Close()
465 defer cleanup()
466
467 registerTestingService(server, testImpl)
468
469 go server.Serve(ctx, listener)
470 defer server.Shutdown(ctx)
471
472 var tp internal.TestPayload
473 b.ResetTimer()
474
475 for i := 0; i < b.N; i++ {
476 if _, err := tclient.Test(ctx, &tp); err != nil {
477 b.Fatal(err)
478 }
479 }
480 }
481
482 func checkServerShutdown(t *testing.T, server *Server) {
483 t.Helper()
484 server.mu.Lock()
485 defer server.mu.Unlock()
486
487 if len(server.listeners) > 0 {
488 t.Errorf("expected listeners to be empty: %v", server.listeners)
489 }
490 for listener := range server.listeners {
491 t.Logf("listener addr=%s", listener.Addr())
492 }
493
494 if len(server.connections) > 0 {
495 t.Errorf("expected connections to be empty: %v", server.connections)
496 }
497 for conn := range server.connections {
498 state, ok := conn.getState()
499 if !ok {
500 t.Errorf("failed to get state from %v", conn)
501 }
502 t.Logf("conn state=%s", state)
503 }
504 }
505
// callResult records the outcome of one roundTrip call.
type callResult struct {
	// name is the string sent as the payload; it also labels the call in
	// failure messages.
	name string
	// err is the error returned by the call, if any. roundTrip leaves the
	// payload fields below nil when err is set.
	err error
	// input is the request payload that was sent.
	input *internal.TestPayload
	// expected is the response the test service should produce for input.
	expected *internal.TestPayload
	// received is the response actually returned by the server.
	received *internal.TestPayload
}
513
514 func roundTrip(ctx context.Context, client *testingClient, name string) callResult {
515 var (
516 tp = &internal.TestPayload{
517 Foo: name,
518 }
519 )
520
521 ctx = WithMetadata(ctx, MD{"foo": []string{name}})
522
523 resp, err := client.Test(ctx, tp)
524 if err != nil {
525 return callResult{
526 name: name,
527 err: err,
528 }
529 }
530
531 return callResult{
532 name: name,
533 input: tp,
534 expected: &internal.TestPayload{Foo: strings.Repeat(tp.Foo, 2), Metadata: name},
535 received: resp,
536 }
537 }
538
539 func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func()) {
540 conn, err := net.Dial("unix", addr)
541 if err != nil {
542 t.Fatal(err)
543 }
544 client := NewClient(conn, opts...)
545 return client, func() {
546 conn.Close()
547 client.Close()
548 }
549 }
550
// newTestListener listens on a unix socket named after the running test. It
// returns the address together with the listener.
//
// On Linux the name is prefixed with a NUL byte, which places the socket in
// the abstract namespace, so no file is created on disk. On other systems the
// name is a relative path, so the socket file is created in the current
// working directory. The test fails immediately if listening fails.
func newTestListener(t testing.TB) (string, net.Listener) {
	// Attribute t.Fatal failures to the caller rather than to this helper.
	t.Helper()

	var prefix string
	if runtime.GOOS == "linux" {
		prefix = "\x00"
	}
	addr := prefix + t.Name()
	listener, err := net.Listen("unix", addr)
	if err != nil {
		t.Fatal(err)
	}

	return addr, listener
}
566
567 func mustServer(t testing.TB) func(server *Server, err error) *Server {
568 return func(server *Server, err error) *Server {
569 t.Helper()
570 if err != nil {
571 t.Fatal(err)
572 }
573
574 return server
575 }
576 }
577
View as plain text