1
2
3
4
21
22 package test
23
24 import (
25 "context"
26 "fmt"
27 "net"
28 "os"
29 "strings"
30 "sync"
31 "testing"
32
33 "google.golang.org/grpc"
34 "google.golang.org/grpc/codes"
35 "google.golang.org/grpc/credentials/insecure"
36 "google.golang.org/grpc/internal/stubserver"
37 "google.golang.org/grpc/metadata"
38 "google.golang.org/grpc/resolver"
39 "google.golang.org/grpc/resolver/manual"
40 "google.golang.org/grpc/status"
41
42 testgrpc "google.golang.org/grpc/interop/grpc_testing"
43 testpb "google.golang.org/grpc/interop/grpc_testing"
44 )
45
46 func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) {
47 md, ok := metadata.FromIncomingContext(ctx)
48 if !ok {
49 return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
50 }
51 auths, ok := md[":authority"]
52 if !ok {
53 return nil, status.Error(codes.InvalidArgument, "no authority header")
54 }
55 if len(auths) != 1 {
56 return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths))
57 }
58 if auths[0] != expectedAuthority {
59 return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority))
60 }
61 return &testpb.Empty{}, nil
62 }
63
64 func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) {
65 if !strings.HasPrefix(target, "unix-abstract:") {
66 if err := os.RemoveAll(address); err != nil {
67 t.Fatalf("Error removing socket file %v: %v\n", address, err)
68 }
69 }
70 ss := &stubserver.StubServer{
71 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
72 return authorityChecker(ctx, expectedAuthority)
73 },
74 Network: "unix",
75 Address: address,
76 Target: target,
77 }
78 opts := []grpc.DialOption{}
79 if dialer != nil {
80 opts = append(opts, grpc.WithContextDialer(dialer))
81 }
82 if err := ss.Start(nil, opts...); err != nil {
83 t.Fatalf("Error starting endpoint server: %v", err)
84 }
85 defer ss.Stop()
86 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
87 defer cancel()
88 _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
89 if err != nil {
90 t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
91 }
92 }
93
// authorityTest is one unix-socket authority scenario used by the table-driven
// tests below.
type authorityTest struct {
	// name is used as the subtest name.
	name string
	// address is where the stub server listens; target is what the client dials.
	address, target string
	// authority is the :authority header value the server handler requires.
	authority string
	// dialTargetWant is the exact address a custom context dialer must receive.
	dialTargetWant string
}
101
102 var authorityTests = []authorityTest{
103 {
104 name: "UnixRelative",
105 address: "sock.sock",
106 target: "unix:sock.sock",
107 authority: "localhost",
108 dialTargetWant: "unix:sock.sock",
109 },
110 {
111 name: "UnixAbsolute",
112 address: "/tmp/sock.sock",
113 target: "unix:/tmp/sock.sock",
114 authority: "localhost",
115 dialTargetWant: "unix:///tmp/sock.sock",
116 },
117 {
118 name: "UnixAbsoluteAlternate",
119 address: "/tmp/sock.sock",
120 target: "unix:///tmp/sock.sock",
121 authority: "localhost",
122 dialTargetWant: "unix:///tmp/sock.sock",
123 },
124 {
125 name: "UnixPassthrough",
126 address: "/tmp/sock.sock",
127 target: "passthrough:///unix:///tmp/sock.sock",
128 authority: "unix:%2F%2F%2Ftmp%2Fsock.sock",
129 dialTargetWant: "unix:///tmp/sock.sock",
130 },
131 {
132 name: "UnixAbstract",
133 address: "@abc efg",
134 target: "unix-abstract:abc efg",
135 authority: "localhost",
136 dialTargetWant: "unix:@abc efg",
137 },
138 }
139
140
141
142 func (s) TestUnix(t *testing.T) {
143 for _, test := range authorityTests {
144 t.Run(test.name, func(t *testing.T) {
145 runUnixTest(t, test.address, test.target, test.authority, nil)
146 })
147 }
148 }
149
150
151
152
153 func (s) TestUnixCustomDialer(t *testing.T) {
154 for _, test := range authorityTests {
155 t.Run(test.name+"WithDialer", func(t *testing.T) {
156 dialer := func(ctx context.Context, address string) (net.Conn, error) {
157 if address != test.dialTargetWant {
158 return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address)
159 }
160 address = address[len("unix:"):]
161 return (&net.Dialer{}).DialContext(ctx, "unix", address)
162 }
163 runUnixTest(t, test.address, test.target, test.authority, dialer)
164 })
165 }
166 }
167
168
169
170 func (s) TestColonPortAuthority(t *testing.T) {
171 expectedAuthority := ""
172 var authorityMu sync.Mutex
173 ss := &stubserver.StubServer{
174 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
175 authorityMu.Lock()
176 defer authorityMu.Unlock()
177 return authorityChecker(ctx, expectedAuthority)
178 },
179 Network: "tcp",
180 }
181 if err := ss.Start(nil); err != nil {
182 t.Fatalf("Error starting endpoint server: %v", err)
183 }
184 defer ss.Stop()
185 _, port, err := net.SplitHostPort(ss.Address)
186 if err != nil {
187 t.Fatalf("Failed splitting host from post: %v", err)
188 }
189 authorityMu.Lock()
190 expectedAuthority = "localhost:" + port
191 authorityMu.Unlock()
192
193
194
195
196
197 cc, err := grpc.Dial(":"+port, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
198 return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost"+addr)
199 }))
200 if err != nil {
201 t.Fatalf("grpc.Dial(%q) = %v", ss.Target, err)
202 }
203 defer cc.Close()
204 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
205 defer cancel()
206 _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{})
207 if err != nil {
208 t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
209 }
210 }
211
212
213
214
215
216 func (s) TestAuthorityReplacedWithResolverAddress(t *testing.T) {
217 const expectedAuthority = "test.server.name"
218
219 ss := &stubserver.StubServer{
220 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
221 return authorityChecker(ctx, expectedAuthority)
222 },
223 }
224 if err := ss.Start(nil); err != nil {
225 t.Fatalf("Error starting endpoint server: %v", err)
226 }
227 defer ss.Stop()
228
229 r := manual.NewBuilderWithScheme("whatever")
230 r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address, ServerName: expectedAuthority}}})
231 cc, err := grpc.NewClient(r.Scheme()+":///whatever", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
232 if err != nil {
233 t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
234 }
235 defer cc.Close()
236
237 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
238 defer cancel()
239 if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}); err != nil {
240 t.Fatalf("EmptyCall() rpc failed: %v", err)
241 }
242 }
243
View as plain text