...
1
16
17 package ttrpc
18
19 import (
20 "context"
21 "strings"
22 "testing"
23 "time"
24
25 "github.com/containerd/ttrpc/internal"
26 "github.com/prometheus/procfs"
27 )
28
29 func TestUnixSocketHandshake(t *testing.T) {
30 var (
31 ctx = context.Background()
32 server = mustServer(t)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
33 addr, listener = newTestListener(t)
34 errs = make(chan error, 1)
35 client, cleanup = newTestClient(t, addr)
36 )
37 defer cleanup()
38 defer listener.Close()
39 go func() {
40 errs <- server.Serve(ctx, listener)
41 }()
42
43 registerTestingService(server, &testingServer{})
44
45 var tp internal.TestPayload
46
47 if err := client.Call(ctx, serviceName, "Test", &tp, &tp); err != nil {
48 t.Fatalf("unexpected error making call: %v", err)
49 }
50 }
51
52 func BenchmarkRoundTripUnixSocketCreds(b *testing.B) {
53
54
55
56
57 var (
58 ctx = context.Background()
59 server = mustServer(b)(NewServer(WithServerHandshaker(UnixSocketRequireSameUser())))
60 testImpl = &testingServer{}
61 addr, listener = newTestListener(b)
62 client, cleanup = newTestClient(b, addr)
63 tclient = newTestingClient(client)
64 )
65
66 defer listener.Close()
67 defer cleanup()
68
69 registerTestingService(server, testImpl)
70
71 go server.Serve(ctx, listener)
72 defer server.Shutdown(ctx)
73
74 var tp internal.TestPayload
75 b.ResetTimer()
76
77 for i := 0; i < b.N; i++ {
78 if _, err := tclient.Test(ctx, &tp); err != nil {
79 b.Fatal(err)
80 }
81 }
82 }
83
84 func TestServerEOF(t *testing.T) {
85 var (
86 ctx = context.Background()
87 server = mustServer(t)(NewServer())
88 addr, listener = newTestListener(t)
89 client, cleanup = newTestClient(t, addr)
90 )
91 defer cleanup()
92 defer listener.Close()
93
94 socketCountBefore := socketCount(t)
95
96 go server.Serve(ctx, listener)
97
98 registerTestingService(server, &testingServer{})
99
100 tp := &internal.TestPayload{}
101
102 if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil {
103 t.Fatalf("unexpected error during test call: %v", err)
104 }
105
106
107 if err := client.Close(); err != nil {
108 t.Fatalf("unexpected error while closing client: %v", err)
109 }
110
111
112 maxAttempts := 20
113 for i := 1; i <= maxAttempts; i++ {
114 socketCountAfter := socketCount(t)
115 if socketCountAfter < socketCountBefore {
116 break
117 }
118 if i == maxAttempts {
119 t.Fatalf("expected number of open sockets to be less than %d after client close, got %d open sockets",
120 socketCountBefore, socketCountAfter)
121 }
122 time.Sleep(100 * time.Millisecond)
123 }
124 }
125
126 func socketCount(t *testing.T) int {
127 proc, err := procfs.Self()
128 if err != nil {
129 t.Fatalf("unexpected error while reading procfs: %v", err)
130 }
131 fds, err := proc.FileDescriptorTargets()
132 if err != nil {
133 t.Fatalf("unexpected error while listing open file descriptors: %v", err)
134 }
135
136 sockets := 0
137 for _, fd := range fds {
138 if strings.Contains(fd, "socket") {
139 sockets++
140 }
141 }
142 return sockets
143 }
144
View as plain text