...

Source file src/google.golang.org/grpc/balancer/rls/control_channel_test.go

Documentation: google.golang.org/grpc/balancer/rls

     1  /*
     2   *
     3   * Copyright 2021 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package rls
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"errors"
    26  	"fmt"
    27  	"os"
    28  	"regexp"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/google/go-cmp/cmp"
    33  	"google.golang.org/grpc"
    34  	"google.golang.org/grpc/balancer"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/credentials"
    37  	"google.golang.org/grpc/internal"
    38  	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
    39  	rlstest "google.golang.org/grpc/internal/testutils/rls"
    40  	"google.golang.org/grpc/metadata"
    41  	"google.golang.org/grpc/status"
    42  	"google.golang.org/grpc/testdata"
    43  	"google.golang.org/protobuf/proto"
    44  )
    45  
    46  // TestControlChannelThrottled tests the case where the adaptive throttler
    47  // indicates that the control channel needs to be throttled.
    48  func (s) TestControlChannelThrottled(t *testing.T) {
    49  	// Start an RLS server and set the throttler to always throttle requests.
    50  	rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil)
    51  	overrideAdaptiveThrottler(t, alwaysThrottlingThrottler())
    52  
    53  	// Create a control channel to the fake RLS server.
    54  	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil)
    55  	if err != nil {
    56  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
    57  	}
    58  	defer ctrlCh.close()
    59  
    60  	// Perform the lookup and expect the attempt to be throttled.
    61  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil)
    62  
    63  	select {
    64  	case <-rlsReqCh:
    65  		t.Fatal("RouteLookup RPC invoked when control channel is throtlled")
    66  	case <-time.After(defaultTestShortTimeout):
    67  	}
    68  }
    69  
    70  // TestLookupFailure tests the case where the RLS server responds with an error.
    71  func (s) TestLookupFailure(t *testing.T) {
    72  	// Start an RLS server and set the throttler to never throttle requests.
    73  	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
    74  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
    75  
    76  	// Setup the RLS server to respond with errors.
    77  	rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
    78  		return &rlstest.RouteLookupResponse{Err: errors.New("rls failure")}
    79  	})
    80  
    81  	// Create a control channel to the fake RLS server.
    82  	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil)
    83  	if err != nil {
    84  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
    85  	}
    86  	defer ctrlCh.close()
    87  
    88  	// Perform the lookup and expect the callback to be invoked with an error.
    89  	errCh := make(chan error, 1)
    90  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
    91  		if err == nil {
    92  			errCh <- errors.New("rlsClient.lookup() succeeded, should have failed")
    93  			return
    94  		}
    95  		errCh <- nil
    96  	})
    97  
    98  	select {
    99  	case <-time.After(defaultTestTimeout):
   100  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   101  	case err := <-errCh:
   102  		if err != nil {
   103  			t.Fatal(err)
   104  		}
   105  	}
   106  }
   107  
   108  // TestLookupDeadlineExceeded tests the case where the RLS server does not
   109  // respond within the configured rpc timeout.
   110  func (s) TestLookupDeadlineExceeded(t *testing.T) {
   111  	// A unary interceptor which returns a status error with DeadlineExceeded.
   112  	interceptor := func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
   113  		return nil, status.Error(codes.DeadlineExceeded, "deadline exceeded")
   114  	}
   115  
   116  	// Start an RLS server and set the throttler to never throttle.
   117  	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor))
   118  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
   119  
   120  	// Create a control channel with a small deadline.
   121  	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestShortTimeout, balancer.BuildOptions{}, nil)
   122  	if err != nil {
   123  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
   124  	}
   125  	defer ctrlCh.close()
   126  
   127  	// Perform the lookup and expect the callback to be invoked with an error.
   128  	errCh := make(chan error)
   129  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
   130  		if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
   131  			errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
   132  			return
   133  		}
   134  		errCh <- nil
   135  	})
   136  
   137  	select {
   138  	case <-time.After(defaultTestTimeout):
   139  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   140  	case err := <-errCh:
   141  		if err != nil {
   142  			t.Fatal(err)
   143  		}
   144  	}
   145  }
   146  
   147  // testCredsBundle wraps a test call creds and real transport creds.
   148  type testCredsBundle struct {
   149  	transportCreds credentials.TransportCredentials
   150  	callCreds      credentials.PerRPCCredentials
   151  }
   152  
   153  func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials {
   154  	return f.transportCreds
   155  }
   156  
   157  func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
   158  	return f.callCreds
   159  }
   160  
   161  func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
   162  	if mode != internal.CredsBundleModeFallback {
   163  		return nil, fmt.Errorf("unsupported mode: %v", mode)
   164  	}
   165  	return &testCredsBundle{
   166  		transportCreds: f.transportCreds,
   167  		callCreds:      f.callCreds,
   168  	}, nil
   169  }
   170  
   171  var (
   172  	// Call creds sent by the testPerRPCCredentials on the client, and verified
   173  	// by an interceptor on the server.
   174  	perRPCCredsData = map[string]string{
   175  		"test-key":     "test-value",
   176  		"test-key-bin": string([]byte{1, 2, 3}),
   177  	}
   178  )
   179  
   180  type testPerRPCCredentials struct {
   181  	callCreds map[string]string
   182  }
   183  
   184  func (f *testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
   185  	return f.callCreds, nil
   186  }
   187  
   188  func (f *testPerRPCCredentials) RequireTransportSecurity() bool {
   189  	return true
   190  }
   191  
   192  // Unary server interceptor which validates if the RPC contains call credentials
   193  // which match `perRPCCredsData
   194  func callCredsValidatingServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
   195  	md, ok := metadata.FromIncomingContext(ctx)
   196  	if !ok {
   197  		return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context")
   198  	}
   199  	for k, want := range perRPCCredsData {
   200  		got, ok := md[k]
   201  		if !ok {
   202  			return ctx, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k)
   203  		}
   204  		if got[0] != want {
   205  			return ctx, status.Errorf(codes.PermissionDenied, "for key %v, got value %v, want %v", k, got, want)
   206  		}
   207  	}
   208  	return handler(ctx, req)
   209  }
   210  
   211  // makeTLSCreds is a test helper which creates a TLS based transport credentials
   212  // from files specified in the arguments.
   213  func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
   214  	cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
   215  	if err != nil {
   216  		t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
   217  	}
   218  	b, err := os.ReadFile(testdata.Path(rootsPath))
   219  	if err != nil {
   220  		t.Fatalf("os.ReadFile(%q) failed: %v", rootsPath, err)
   221  	}
   222  	roots := x509.NewCertPool()
   223  	if !roots.AppendCertsFromPEM(b) {
   224  		t.Fatal("failed to append certificates")
   225  	}
   226  	return credentials.NewTLS(&tls.Config{
   227  		Certificates: []tls.Certificate{cert},
   228  		RootCAs:      roots,
   229  	})
   230  }
   231  
   232  const (
   233  	wantHeaderData  = "headerData"
   234  	staleHeaderData = "staleHeaderData"
   235  )
   236  
   237  var (
   238  	keyMap = map[string]string{
   239  		"k1": "v1",
   240  		"k2": "v2",
   241  	}
   242  	wantTargets   = []string{"us_east_1.firestore.googleapis.com"}
   243  	lookupRequest = &rlspb.RouteLookupRequest{
   244  		TargetType:      "grpc",
   245  		KeyMap:          keyMap,
   246  		Reason:          rlspb.RouteLookupRequest_REASON_MISS,
   247  		StaleHeaderData: staleHeaderData,
   248  	}
   249  	lookupResponse = &rlstest.RouteLookupResponse{
   250  		Resp: &rlspb.RouteLookupResponse{
   251  			Targets:    wantTargets,
   252  			HeaderData: wantHeaderData,
   253  		},
   254  	}
   255  )
   256  
   257  func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) {
   258  	// Start an RLS server and set the throttler to never throttle requests.
   259  	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...)
   260  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
   261  
   262  	// Setup the RLS server to respond with a valid response.
   263  	rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse {
   264  		return lookupResponse
   265  	})
   266  
   267  	// Verify that the request received by the RLS matches the expected one.
   268  	rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) {
   269  		if diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal)); diff != "" {
   270  			t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff)
   271  		}
   272  	})
   273  
   274  	// Create a control channel to the fake server.
   275  	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil)
   276  	if err != nil {
   277  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
   278  	}
   279  	defer ctrlCh.close()
   280  
   281  	// Perform the lookup and expect a successful callback invocation.
   282  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   283  	defer cancel()
   284  	errCh := make(chan error, 1)
   285  	ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, err error) {
   286  		if err != nil {
   287  			errCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", err)
   288  			return
   289  		}
   290  		if !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData {
   291  			errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData)
   292  			return
   293  		}
   294  		errCh <- nil
   295  	})
   296  
   297  	select {
   298  	case <-ctx.Done():
   299  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   300  	case err := <-errCh:
   301  		if err != nil {
   302  			t.Fatal(err)
   303  		}
   304  	}
   305  }
   306  
   307  // TestControlChannelCredsSuccess tests creation of the control channel with
   308  // different credentials, which are expected to succeed.
   309  func (s) TestControlChannelCredsSuccess(t *testing.T) {
   310  	serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
   311  	clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
   312  
   313  	tests := []struct {
   314  		name  string
   315  		sopts []grpc.ServerOption
   316  		bopts balancer.BuildOptions
   317  	}{
   318  		{
   319  			name:  "insecure",
   320  			sopts: nil,
   321  			bopts: balancer.BuildOptions{},
   322  		},
   323  		{
   324  			name:  "transport creds only",
   325  			sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
   326  			bopts: balancer.BuildOptions{
   327  				DialCreds: clientCreds,
   328  				Authority: "x.test.example.com",
   329  			},
   330  		},
   331  		{
   332  			name: "creds bundle",
   333  			sopts: []grpc.ServerOption{
   334  				grpc.Creds(serverCreds),
   335  				grpc.UnaryInterceptor(callCredsValidatingServerInterceptor),
   336  			},
   337  			bopts: balancer.BuildOptions{
   338  				CredsBundle: &testCredsBundle{
   339  					transportCreds: clientCreds,
   340  					callCreds:      &testPerRPCCredentials{callCreds: perRPCCredsData},
   341  				},
   342  				Authority: "x.test.example.com",
   343  			},
   344  		},
   345  	}
   346  	for _, test := range tests {
   347  		t.Run(test.name, func(t *testing.T) {
   348  			testControlChannelCredsSuccess(t, test.sopts, test.bopts)
   349  		})
   350  	}
   351  }
   352  
   353  func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErrRegex *regexp.Regexp) {
   354  	// StartFakeRouteLookupServer a fake server.
   355  	//
   356  	// Start an RLS server and set the throttler to never throttle requests. The
   357  	// creds failures happen before the RPC handler on the server is invoked.
   358  	// So, there is need to setup the request and responses on the fake server.
   359  	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...)
   360  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
   361  
   362  	// Create the control channel to the fake server.
   363  	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil)
   364  	if err != nil {
   365  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
   366  	}
   367  	defer ctrlCh.close()
   368  
   369  	// Perform the lookup and expect the callback to be invoked with an error.
   370  	errCh := make(chan error)
   371  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
   372  		if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !wantErrRegex.MatchString(st.String()) {
   373  			errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErrRegex.String())
   374  			return
   375  		}
   376  		errCh <- nil
   377  	})
   378  
   379  	select {
   380  	case <-time.After(defaultTestTimeout):
   381  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   382  	case err := <-errCh:
   383  		if err != nil {
   384  			t.Fatal(err)
   385  		}
   386  	}
   387  }
   388  
   389  // TestControlChannelCredsFailure tests creation of the control channel with
   390  // different credentials, which are expected to fail.
   391  func (s) TestControlChannelCredsFailure(t *testing.T) {
   392  	serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
   393  	clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
   394  
   395  	tests := []struct {
   396  		name         string
   397  		sopts        []grpc.ServerOption
   398  		bopts        balancer.BuildOptions
   399  		wantCode     codes.Code
   400  		wantErrRegex *regexp.Regexp
   401  	}{
   402  		{
   403  			name:  "transport creds authority mismatch",
   404  			sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
   405  			bopts: balancer.BuildOptions{
   406  				DialCreds: clientCreds,
   407  				Authority: "authority-mismatch",
   408  			},
   409  			wantCode:     codes.Unavailable,
   410  			wantErrRegex: regexp.MustCompile(`transport: authentication handshake failed: .* \*\.test\.example\.com.*authority-mismatch`),
   411  		},
   412  		{
   413  			name:  "transport creds handshake failure",
   414  			sopts: nil, // server expects insecure connection
   415  			bopts: balancer.BuildOptions{
   416  				DialCreds: clientCreds,
   417  				Authority: "x.test.example.com",
   418  			},
   419  			wantCode:     codes.Unavailable,
   420  			wantErrRegex: regexp.MustCompile("transport: authentication handshake failed: .*"),
   421  		},
   422  		{
   423  			name: "call creds mismatch",
   424  			sopts: []grpc.ServerOption{
   425  				grpc.Creds(serverCreds),
   426  				grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), // server expects call creds
   427  			},
   428  			bopts: balancer.BuildOptions{
   429  				CredsBundle: &testCredsBundle{
   430  					transportCreds: clientCreds,
   431  					callCreds:      &testPerRPCCredentials{}, // sends no call creds
   432  				},
   433  				Authority: "x.test.example.com",
   434  			},
   435  			wantCode:     codes.PermissionDenied,
   436  			wantErrRegex: regexp.MustCompile("didn't find call creds"),
   437  		},
   438  	}
   439  	for _, test := range tests {
   440  		t.Run(test.name, func(t *testing.T) {
   441  			testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErrRegex)
   442  		})
   443  	}
   444  }
   445  
   446  type unsupportedCredsBundle struct {
   447  	credentials.Bundle
   448  }
   449  
   450  func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
   451  	return nil, fmt.Errorf("unsupported mode: %v", mode)
   452  }
   453  
   454  // TestNewControlChannelUnsupportedCredsBundle tests the case where the control
   455  // channel is configured with a bundle which does not support the mode we use.
   456  func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
   457  	rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
   458  
   459  	// Create the control channel to the fake server.
   460  	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}, nil)
   461  	if err == nil {
   462  		ctrlCh.close()
   463  		t.Fatal("newControlChannel succeeded when expected to fail")
   464  	}
   465  }
   466  

View as plain text