1
18
19 package rls
20
21 import (
22 "context"
23 "crypto/tls"
24 "crypto/x509"
25 "errors"
26 "fmt"
27 "os"
28 "regexp"
29 "testing"
30 "time"
31
32 "github.com/google/go-cmp/cmp"
33 "google.golang.org/grpc"
34 "google.golang.org/grpc/balancer"
35 "google.golang.org/grpc/codes"
36 "google.golang.org/grpc/credentials"
37 "google.golang.org/grpc/internal"
38 rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
39 rlstest "google.golang.org/grpc/internal/testutils/rls"
40 "google.golang.org/grpc/metadata"
41 "google.golang.org/grpc/status"
42 "google.golang.org/grpc/testdata"
43 "google.golang.org/protobuf/proto"
44 )
45
46
47
48 func (s) TestControlChannelThrottled(t *testing.T) {
49
50 rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil)
51 overrideAdaptiveThrottler(t, alwaysThrottlingThrottler())
52
53
54 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil)
55 if err != nil {
56 t.Fatalf("Failed to create control channel to RLS server: %v", err)
57 }
58 defer ctrlCh.close()
59
60
61 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil)
62
63 select {
64 case <-rlsReqCh:
65 t.Fatal("RouteLookup RPC invoked when control channel is throtlled")
66 case <-time.After(defaultTestShortTimeout):
67 }
68 }
69
70
71 func (s) TestLookupFailure(t *testing.T) {
72
73 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
74 overrideAdaptiveThrottler(t, neverThrottlingThrottler())
75
76
77 rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
78 return &rlstest.RouteLookupResponse{Err: errors.New("rls failure")}
79 })
80
81
82 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil)
83 if err != nil {
84 t.Fatalf("Failed to create control channel to RLS server: %v", err)
85 }
86 defer ctrlCh.close()
87
88
89 errCh := make(chan error, 1)
90 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
91 if err == nil {
92 errCh <- errors.New("rlsClient.lookup() succeeded, should have failed")
93 return
94 }
95 errCh <- nil
96 })
97
98 select {
99 case <-time.After(defaultTestTimeout):
100 t.Fatal("timeout when waiting for lookup callback to be invoked")
101 case err := <-errCh:
102 if err != nil {
103 t.Fatal(err)
104 }
105 }
106 }
107
108
109
110 func (s) TestLookupDeadlineExceeded(t *testing.T) {
111
112 interceptor := func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
113 return nil, status.Error(codes.DeadlineExceeded, "deadline exceeded")
114 }
115
116
117 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor))
118 overrideAdaptiveThrottler(t, neverThrottlingThrottler())
119
120
121 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestShortTimeout, balancer.BuildOptions{}, nil)
122 if err != nil {
123 t.Fatalf("Failed to create control channel to RLS server: %v", err)
124 }
125 defer ctrlCh.close()
126
127
128 errCh := make(chan error)
129 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
130 if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
131 errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
132 return
133 }
134 errCh <- nil
135 })
136
137 select {
138 case <-time.After(defaultTestTimeout):
139 t.Fatal("timeout when waiting for lookup callback to be invoked")
140 case err := <-errCh:
141 if err != nil {
142 t.Fatal(err)
143 }
144 }
145 }
146
147
148 type testCredsBundle struct {
149 transportCreds credentials.TransportCredentials
150 callCreds credentials.PerRPCCredentials
151 }
152
153 func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials {
154 return f.transportCreds
155 }
156
157 func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
158 return f.callCreds
159 }
160
161 func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
162 if mode != internal.CredsBundleModeFallback {
163 return nil, fmt.Errorf("unsupported mode: %v", mode)
164 }
165 return &testCredsBundle{
166 transportCreds: f.transportCreds,
167 callCreds: f.callCreds,
168 }, nil
169 }
170
// perRPCCredsData is the call-credentials metadata used by the creds tests.
// The success case attaches it through testPerRPCCredentials, and
// callCredsValidatingServerInterceptor requires every entry to be present.
// One key carries the "-bin" suffix so a binary-valued header is exercised too.
var perRPCCredsData = map[string]string{
	"test-key":     "test-value",
	"test-key-bin": string([]byte{1, 2, 3}),
}
179
// testPerRPCCredentials is a credentials.PerRPCCredentials implementation that
// returns the same metadata map on every call and reports that transport
// security is required.
type testPerRPCCredentials struct {
	callCreds map[string]string
}

// GetRequestMetadata returns the configured metadata map unchanged. The
// context and URI arguments are ignored.
func (c *testPerRPCCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) {
	md := c.callCreds
	return md, nil
}

// RequireTransportSecurity always reports true.
func (c *testPerRPCCredentials) RequireTransportSecurity() bool {
	const requireSecurity = true
	return requireSecurity
}
191
192
193
194 func callCredsValidatingServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
195 md, ok := metadata.FromIncomingContext(ctx)
196 if !ok {
197 return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context")
198 }
199 for k, want := range perRPCCredsData {
200 got, ok := md[k]
201 if !ok {
202 return ctx, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k)
203 }
204 if got[0] != want {
205 return ctx, status.Errorf(codes.PermissionDenied, "for key %v, got value %v, want %v", k, got, want)
206 }
207 }
208 return handler(ctx, req)
209 }
210
211
212
213 func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
214 cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
215 if err != nil {
216 t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
217 }
218 b, err := os.ReadFile(testdata.Path(rootsPath))
219 if err != nil {
220 t.Fatalf("os.ReadFile(%q) failed: %v", rootsPath, err)
221 }
222 roots := x509.NewCertPool()
223 if !roots.AppendCertsFromPEM(b) {
224 t.Fatal("failed to append certificates")
225 }
226 return credentials.NewTLS(&tls.Config{
227 Certificates: []tls.Certificate{cert},
228 RootCAs: roots,
229 })
230 }
231
// Header-data strings shared by the lookup tests. staleHeaderData is passed to
// every lookup and is expected in lookupRequest; wantHeaderData is the header
// data carried by lookupResponse.
const (
	staleHeaderData = "staleHeaderData"
	wantHeaderData  = "headerData"
)
236
237 var (
238 keyMap = map[string]string{
239 "k1": "v1",
240 "k2": "v2",
241 }
242 wantTargets = []string{"us_east_1.firestore.googleapis.com"}
243 lookupRequest = &rlspb.RouteLookupRequest{
244 TargetType: "grpc",
245 KeyMap: keyMap,
246 Reason: rlspb.RouteLookupRequest_REASON_MISS,
247 StaleHeaderData: staleHeaderData,
248 }
249 lookupResponse = &rlstest.RouteLookupResponse{
250 Resp: &rlspb.RouteLookupResponse{
251 Targets: wantTargets,
252 HeaderData: wantHeaderData,
253 },
254 }
255 )
256
257 func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) {
258
259 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...)
260 overrideAdaptiveThrottler(t, neverThrottlingThrottler())
261
262
263 rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
264 return lookupResponse
265 })
266
267
268 rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) {
269 if diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal)); diff != "" {
270 t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff)
271 }
272 })
273
274
275 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil)
276 if err != nil {
277 t.Fatalf("Failed to create control channel to RLS server: %v", err)
278 }
279 defer ctrlCh.close()
280
281
282 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
283 defer cancel()
284 errCh := make(chan error, 1)
285 ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, err error) {
286 if err != nil {
287 errCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", err)
288 return
289 }
290 if !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData {
291 errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData)
292 return
293 }
294 errCh <- nil
295 })
296
297 select {
298 case <-ctx.Done():
299 t.Fatal("timeout when waiting for lookup callback to be invoked")
300 case err := <-errCh:
301 if err != nil {
302 t.Fatal(err)
303 }
304 }
305 }
306
307
308
309 func (s) TestControlChannelCredsSuccess(t *testing.T) {
310 serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
311 clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
312
313 tests := []struct {
314 name string
315 sopts []grpc.ServerOption
316 bopts balancer.BuildOptions
317 }{
318 {
319 name: "insecure",
320 sopts: nil,
321 bopts: balancer.BuildOptions{},
322 },
323 {
324 name: "transport creds only",
325 sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
326 bopts: balancer.BuildOptions{
327 DialCreds: clientCreds,
328 Authority: "x.test.example.com",
329 },
330 },
331 {
332 name: "creds bundle",
333 sopts: []grpc.ServerOption{
334 grpc.Creds(serverCreds),
335 grpc.UnaryInterceptor(callCredsValidatingServerInterceptor),
336 },
337 bopts: balancer.BuildOptions{
338 CredsBundle: &testCredsBundle{
339 transportCreds: clientCreds,
340 callCreds: &testPerRPCCredentials{callCreds: perRPCCredsData},
341 },
342 Authority: "x.test.example.com",
343 },
344 },
345 }
346 for _, test := range tests {
347 t.Run(test.name, func(t *testing.T) {
348 testControlChannelCredsSuccess(t, test.sopts, test.bopts)
349 })
350 }
351 }
352
353 func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErrRegex *regexp.Regexp) {
354
355
356
357
358
359 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...)
360 overrideAdaptiveThrottler(t, neverThrottlingThrottler())
361
362
363 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil)
364 if err != nil {
365 t.Fatalf("Failed to create control channel to RLS server: %v", err)
366 }
367 defer ctrlCh.close()
368
369
370 errCh := make(chan error)
371 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
372 if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !wantErrRegex.MatchString(st.String()) {
373 errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErrRegex.String())
374 return
375 }
376 errCh <- nil
377 })
378
379 select {
380 case <-time.After(defaultTestTimeout):
381 t.Fatal("timeout when waiting for lookup callback to be invoked")
382 case err := <-errCh:
383 if err != nil {
384 t.Fatal(err)
385 }
386 }
387 }
388
389
390
391 func (s) TestControlChannelCredsFailure(t *testing.T) {
392 serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
393 clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
394
395 tests := []struct {
396 name string
397 sopts []grpc.ServerOption
398 bopts balancer.BuildOptions
399 wantCode codes.Code
400 wantErrRegex *regexp.Regexp
401 }{
402 {
403 name: "transport creds authority mismatch",
404 sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
405 bopts: balancer.BuildOptions{
406 DialCreds: clientCreds,
407 Authority: "authority-mismatch",
408 },
409 wantCode: codes.Unavailable,
410 wantErrRegex: regexp.MustCompile(`transport: authentication handshake failed: .* \*\.test\.example\.com.*authority-mismatch`),
411 },
412 {
413 name: "transport creds handshake failure",
414 sopts: nil,
415 bopts: balancer.BuildOptions{
416 DialCreds: clientCreds,
417 Authority: "x.test.example.com",
418 },
419 wantCode: codes.Unavailable,
420 wantErrRegex: regexp.MustCompile("transport: authentication handshake failed: .*"),
421 },
422 {
423 name: "call creds mismatch",
424 sopts: []grpc.ServerOption{
425 grpc.Creds(serverCreds),
426 grpc.UnaryInterceptor(callCredsValidatingServerInterceptor),
427 },
428 bopts: balancer.BuildOptions{
429 CredsBundle: &testCredsBundle{
430 transportCreds: clientCreds,
431 callCreds: &testPerRPCCredentials{},
432 },
433 Authority: "x.test.example.com",
434 },
435 wantCode: codes.PermissionDenied,
436 wantErrRegex: regexp.MustCompile("didn't find call creds"),
437 },
438 }
439 for _, test := range tests {
440 t.Run(test.name, func(t *testing.T) {
441 testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErrRegex)
442 })
443 }
444 }
445
446 type unsupportedCredsBundle struct {
447 credentials.Bundle
448 }
449
450 func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
451 return nil, fmt.Errorf("unsupported mode: %v", mode)
452 }
453
454
455
456 func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
457 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
458
459
460 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}, nil)
461 if err == nil {
462 ctrlCh.close()
463 t.Fatal("newControlChannel succeeded when expected to fail")
464 }
465 }
466
View as plain text