...

Source file src/github.com/letsencrypt/boulder/grpc/interceptors_test.go

Documentation: github.com/letsencrypt/boulder/grpc

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"errors"
     8  	"fmt"
     9  	"log"
    10  	"net"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/jmhodges/clock"
    18  	"github.com/prometheus/client_golang/prometheus"
    19  	"google.golang.org/grpc"
    20  	"google.golang.org/grpc/balancer/roundrobin"
    21  	"google.golang.org/grpc/credentials"
    22  	"google.golang.org/grpc/credentials/insecure"
    23  	"google.golang.org/grpc/metadata"
    24  	"google.golang.org/grpc/peer"
    25  	"google.golang.org/grpc/status"
    26  	"google.golang.org/protobuf/types/known/durationpb"
    27  
    28  	"github.com/letsencrypt/boulder/grpc/test_proto"
    29  	"github.com/letsencrypt/boulder/metrics"
    30  	"github.com/letsencrypt/boulder/test"
    31  )
    32  
    33  var fc = clock.NewFake()
    34  
    35  func testHandler(_ context.Context, i interface{}) (interface{}, error) {
    36  	if i != nil {
    37  		return nil, errors.New("")
    38  	}
    39  	fc.Sleep(time.Second)
    40  	return nil, nil
    41  }
    42  
    43  func testInvoker(_ context.Context, method string, _, _ interface{}, _ *grpc.ClientConn, opts ...grpc.CallOption) error {
    44  	switch method {
    45  	case "-service-brokeTest":
    46  		return errors.New("")
    47  	case "-service-requesterCanceledTest":
    48  		return status.Error(1, context.Canceled.Error())
    49  	}
    50  	fc.Sleep(time.Second)
    51  	return nil
    52  }
    53  
    54  func TestServerInterceptor(t *testing.T) {
    55  	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
    56  	test.AssertNotError(t, err, "creating server metrics")
    57  	si := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
    58  
    59  	md := metadata.New(map[string]string{clientRequestTimeKey: "0"})
    60  	ctxWithMetadata := metadata.NewIncomingContext(context.Background(), md)
    61  
    62  	_, err = si.Unary(context.Background(), nil, nil, testHandler)
    63  	test.AssertError(t, err, "si.intercept didn't fail with a context missing metadata")
    64  
    65  	_, err = si.Unary(ctxWithMetadata, nil, nil, testHandler)
    66  	test.AssertError(t, err, "si.intercept didn't fail with a nil grpc.UnaryServerInfo")
    67  
    68  	_, err = si.Unary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler)
    69  	test.AssertNotError(t, err, "si.intercept failed with a non-nil grpc.UnaryServerInfo")
    70  
    71  	_, err = si.Unary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler)
    72  	test.AssertError(t, err, "si.intercept didn't fail when handler returned a error")
    73  }
    74  
    75  func TestClientInterceptor(t *testing.T) {
    76  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
    77  	test.AssertNotError(t, err, "creating client metrics")
    78  	ci := clientMetadataInterceptor{
    79  		timeout: time.Second,
    80  		metrics: clientMetrics,
    81  		clk:     clock.NewFake(),
    82  	}
    83  
    84  	err = ci.Unary(context.Background(), "-service-test", nil, nil, nil, testInvoker)
    85  	test.AssertNotError(t, err, "ci.intercept failed with a non-nil grpc.UnaryServerInfo")
    86  
    87  	err = ci.Unary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker)
    88  	test.AssertError(t, err, "ci.intercept didn't fail when handler returned a error")
    89  }
    90  
    91  // TestWaitForReadyTrue configures a gRPC client with waitForReady: true and
    92  // sends a request to a backend that is unavailable. It ensures that the
    93  // request doesn't error out until the timeout is reached, i.e. that
    94  // FailFast is set to false.
    95  // https://github.com/grpc/grpc/blob/main/doc/wait-for-ready.md
    96  func TestWaitForReadyTrue(t *testing.T) {
    97  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
    98  	test.AssertNotError(t, err, "creating client metrics")
    99  	ci := &clientMetadataInterceptor{
   100  		timeout:      100 * time.Millisecond,
   101  		metrics:      clientMetrics,
   102  		clk:          clock.NewFake(),
   103  		waitForReady: true,
   104  	}
   105  	conn, err := grpc.Dial("localhost:19876", // random, probably unused port
   106  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)),
   107  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   108  		grpc.WithUnaryInterceptor(ci.Unary))
   109  	if err != nil {
   110  		t.Fatalf("did not connect: %v", err)
   111  	}
   112  	defer conn.Close()
   113  	c := test_proto.NewChillerClient(conn)
   114  
   115  	start := time.Now()
   116  	_, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)})
   117  	if err == nil {
   118  		t.Errorf("Successful Chill when we expected failure.")
   119  	}
   120  	if time.Since(start) < 90*time.Millisecond {
   121  		t.Errorf("Chill failed fast, when WaitForReady should be enabled.")
   122  	}
   123  }
   124  
   125  // TestWaitForReadyFalse configures a gRPC client with waitForReady: false and
   126  // sends a request to a backend that is unavailable, and ensures that the request
   127  // errors out promptly.
   128  func TestWaitForReadyFalse(t *testing.T) {
   129  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
   130  	test.AssertNotError(t, err, "creating client metrics")
   131  	ci := &clientMetadataInterceptor{
   132  		timeout:      time.Second,
   133  		metrics:      clientMetrics,
   134  		clk:          clock.NewFake(),
   135  		waitForReady: false,
   136  	}
   137  	conn, err := grpc.Dial("localhost:19876", // random, probably unused port
   138  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)),
   139  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   140  		grpc.WithUnaryInterceptor(ci.Unary))
   141  	if err != nil {
   142  		t.Fatalf("did not connect: %v", err)
   143  	}
   144  	defer conn.Close()
   145  	c := test_proto.NewChillerClient(conn)
   146  
   147  	start := time.Now()
   148  	_, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)})
   149  	if err == nil {
   150  		t.Errorf("Successful Chill when we expected failure.")
   151  	}
   152  	if time.Since(start) > 200*time.Millisecond {
   153  		t.Errorf("Chill failed slow, when WaitForReady should be disabled.")
   154  	}
   155  }
   156  
   157  // testServer is used to implement TestTimeouts, and will attempt to sleep for
   158  // the given amount of time (unless it hits a timeout or cancel).
   159  type testServer struct {
   160  	test_proto.UnimplementedChillerServer
   161  }
   162  
   163  // Chill implements ChillerServer.Chill
   164  func (s *testServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) {
   165  	start := time.Now()
   166  	// Sleep for either the requested amount of time, or the context times out or
   167  	// is canceled.
   168  	select {
   169  	case <-time.After(in.Duration.AsDuration() * time.Nanosecond):
   170  		spent := time.Since(start) / time.Nanosecond
   171  		return &test_proto.Time{Duration: durationpb.New(spent)}, nil
   172  	case <-ctx.Done():
   173  		return nil, errors.New("unique error indicating that the server's shortened context timed itself out")
   174  	}
   175  }
   176  
   177  func TestTimeouts(t *testing.T) {
   178  	// start server
   179  	lis, err := net.Listen("tcp", ":0")
   180  	if err != nil {
   181  		log.Fatalf("failed to listen: %v", err)
   182  	}
   183  	port := lis.Addr().(*net.TCPAddr).Port
   184  
   185  	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
   186  	test.AssertNotError(t, err, "creating server metrics")
   187  	si := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
   188  	s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
   189  	test_proto.RegisterChillerServer(s, &testServer{})
   190  	go func() {
   191  		start := time.Now()
   192  		err := s.Serve(lis)
   193  		if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
   194  			t.Logf("s.Serve: %v after %s", err, time.Since(start))
   195  		}
   196  	}()
   197  	defer s.Stop()
   198  
   199  	// make client
   200  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
   201  	test.AssertNotError(t, err, "creating client metrics")
   202  	ci := &clientMetadataInterceptor{
   203  		timeout: 30 * time.Second,
   204  		metrics: clientMetrics,
   205  		clk:     clock.NewFake(),
   206  	}
   207  	conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
   208  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   209  		grpc.WithUnaryInterceptor(ci.Unary))
   210  	if err != nil {
   211  		t.Fatalf("did not connect: %v", err)
   212  	}
   213  	c := test_proto.NewChillerClient(conn)
   214  
   215  	testCases := []struct {
   216  		timeout             time.Duration
   217  		expectedErrorPrefix string
   218  	}{
   219  		{250 * time.Millisecond, "rpc error: code = Unknown desc = unique error indicating that the server's shortened context timed itself out"},
   220  		{100 * time.Millisecond, "Chiller.Chill timed out after 0 ms"},
   221  		{10 * time.Millisecond, "Chiller.Chill timed out after 0 ms"},
   222  	}
   223  	for _, tc := range testCases {
   224  		t.Run(tc.timeout.String(), func(t *testing.T) {
   225  			ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
   226  			defer cancel()
   227  			_, err := c.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)})
   228  			if err == nil {
   229  				t.Fatal("Got no error, expected a timeout")
   230  			}
   231  			if !strings.HasPrefix(err.Error(), tc.expectedErrorPrefix) {
   232  				t.Errorf("Wrong error. Got %s, expected %s", err.Error(), tc.expectedErrorPrefix)
   233  			}
   234  		})
   235  	}
   236  }
   237  
   238  func TestRequestTimeTagging(t *testing.T) {
   239  	clk := clock.NewFake()
   240  	// Listen for TCP requests on a random system assigned port number
   241  	lis, err := net.Listen("tcp", ":0")
   242  	if err != nil {
   243  		log.Fatalf("failed to listen: %v", err)
   244  	}
   245  	// Retrieve the concrete port numberthe system assigned our listener
   246  	port := lis.Addr().(*net.TCPAddr).Port
   247  
   248  	// Create a new ChillerServer
   249  	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
   250  	test.AssertNotError(t, err, "creating server metrics")
   251  	si := newServerMetadataInterceptor(serverMetrics, clk)
   252  	s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
   253  	test_proto.RegisterChillerServer(s, &testServer{})
   254  	// Chill until ill
   255  	go func() {
   256  		start := time.Now()
   257  		err := s.Serve(lis)
   258  		if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
   259  			t.Logf("s.Serve: %v after %s", err, time.Since(start))
   260  		}
   261  	}()
   262  	defer s.Stop()
   263  
   264  	// Dial the ChillerServer
   265  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
   266  	test.AssertNotError(t, err, "creating client metrics")
   267  	ci := &clientMetadataInterceptor{
   268  		timeout: 30 * time.Second,
   269  		metrics: clientMetrics,
   270  		clk:     clk,
   271  	}
   272  	conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
   273  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   274  		grpc.WithUnaryInterceptor(ci.Unary))
   275  	if err != nil {
   276  		t.Fatalf("did not connect: %v", err)
   277  	}
   278  	// Create a ChillerClient with the connection to the ChillerServer
   279  	c := test_proto.NewChillerClient(conn)
   280  
   281  	// Make an RPC request with the ChillerClient with a timeout higher than the
   282  	// requested ChillerServer delay so that the RPC completes normally
   283  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   284  	defer cancel()
   285  	if _, err := c.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil {
   286  		t.Fatalf("Unexpected error calling Chill RPC: %s", err)
   287  	}
   288  
   289  	// There should be one histogram sample in the serverInterceptor rpcLag stat
   290  	test.AssertMetricWithLabelsEquals(t, si.metrics.rpcLag, prometheus.Labels{}, 1)
   291  }
   292  
   293  // blockedServer implements a ChillerServer with a Chill method that:
   294  //  1. Calls Done() on the received waitgroup when receiving an RPC
   295  //  2. Blocks the RPC on the roadblock waitgroup
   296  //
   297  // This is used by TestInFlightRPCStat to test that the gauge for in-flight RPCs
   298  // is incremented and decremented as expected.
   299  type blockedServer struct {
   300  	test_proto.UnimplementedChillerServer
   301  	roadblock, received sync.WaitGroup
   302  }
   303  
   304  // Chill implements ChillerServer.Chill
   305  func (s *blockedServer) Chill(_ context.Context, _ *test_proto.Time) (*test_proto.Time, error) {
   306  	// Note that a client RPC arrived
   307  	s.received.Done()
   308  	// Wait for the roadblock to be cleared
   309  	s.roadblock.Wait()
   310  	// Return a dummy spent value to adhere to the chiller protocol
   311  	return &test_proto.Time{Duration: durationpb.New(time.Millisecond)}, nil
   312  }
   313  
   314  func TestInFlightRPCStat(t *testing.T) {
   315  	clk := clock.NewFake()
   316  	// Listen for TCP requests on a random system assigned port number
   317  	lis, err := net.Listen("tcp", ":0")
   318  	if err != nil {
   319  		log.Fatalf("failed to listen: %v", err)
   320  	}
   321  	// Retrieve the concrete port numberthe system assigned our listener
   322  	port := lis.Addr().(*net.TCPAddr).Port
   323  
   324  	// Create a new blockedServer to act as a ChillerServer
   325  	server := &blockedServer{}
   326  
   327  	// Increment the roadblock waitgroup - this will cause all chill RPCs to
   328  	// the server to block until we call Done()!
   329  	server.roadblock.Add(1)
   330  
   331  	// Increment the sentRPCs waitgroup - we use this to find out when all the
   332  	// RPCs we want to send have been received and we can count the in-flight
   333  	// gauge
   334  	numRPCs := 5
   335  	server.received.Add(numRPCs)
   336  
   337  	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
   338  	test.AssertNotError(t, err, "creating server metrics")
   339  	si := newServerMetadataInterceptor(serverMetrics, clk)
   340  	s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
   341  	test_proto.RegisterChillerServer(s, server)
   342  	// Chill until ill
   343  	go func() {
   344  		start := time.Now()
   345  		err := s.Serve(lis)
   346  		if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
   347  			t.Logf("s.Serve: %v after %s", err, time.Since(start))
   348  		}
   349  	}()
   350  	defer s.Stop()
   351  
   352  	// Dial the ChillerServer
   353  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
   354  	test.AssertNotError(t, err, "creating client metrics")
   355  	ci := &clientMetadataInterceptor{
   356  		timeout: 30 * time.Second,
   357  		metrics: clientMetrics,
   358  		clk:     clk,
   359  	}
   360  	conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
   361  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   362  		grpc.WithUnaryInterceptor(ci.Unary))
   363  	if err != nil {
   364  		t.Fatalf("did not connect: %v", err)
   365  	}
   366  	// Create a ChillerClient with the connection to the ChillerServer
   367  	c := test_proto.NewChillerClient(conn)
   368  
   369  	// Fire off a few RPCs. They will block on the blockedServer's roadblock wg
   370  	for i := 0; i < numRPCs; i++ {
   371  		go func() {
   372  			// Ignore errors, just chilllll.
   373  			_, _ = c.Chill(context.Background(), &test_proto.Time{})
   374  		}()
   375  	}
   376  
   377  	// wait until all of the client RPCs have been sent and are blocking. We can
   378  	// now check the gauge.
   379  	server.received.Wait()
   380  
   381  	// Specify the labels for the RPCs we're interested in
   382  	labels := prometheus.Labels{
   383  		"service": "Chiller",
   384  		"method":  "Chill",
   385  	}
   386  
   387  	// We expect the inFlightRPCs gauge for the Chiller.Chill RPCs to be equal to numRPCs.
   388  	test.AssertMetricWithLabelsEquals(t, ci.metrics.inFlightRPCs, labels, float64(numRPCs))
   389  
   390  	// Unblock the blockedServer to let all of the Chiller.Chill RPCs complete
   391  	server.roadblock.Done()
   392  	// Sleep for a little bit to let all the RPCs complete
   393  	time.Sleep(1 * time.Second)
   394  
   395  	// Check the gauge value again
   396  	test.AssertMetricWithLabelsEquals(t, ci.metrics.inFlightRPCs, labels, 0)
   397  }
   398  
   399  func TestServiceAuthChecker(t *testing.T) {
   400  	ac := authInterceptor{
   401  		map[string]map[string]struct{}{
   402  			"package.ServiceName": {
   403  				"allowed.client": {},
   404  				"also.allowed":   {},
   405  			},
   406  		},
   407  	}
   408  
   409  	// No allowlist is a bad configuration.
   410  	ctx := context.Background()
   411  	err := ac.checkContextAuth(ctx, "/package.OtherService/Method/")
   412  	test.AssertError(t, err, "checking empty allowlist")
   413  
   414  	// Context with no peering information is disallowed.
   415  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   416  	test.AssertError(t, err, "checking un-peered context")
   417  
   418  	// Context with no auth info is disallowed.
   419  	ctx = peer.NewContext(ctx, &peer.Peer{})
   420  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   421  	test.AssertError(t, err, "checking peer with no auth")
   422  
   423  	// Context with no verified chains is disallowed.
   424  	ctx = peer.NewContext(ctx, &peer.Peer{
   425  		AuthInfo: credentials.TLSInfo{
   426  			State: tls.ConnectionState{},
   427  		},
   428  	})
   429  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   430  	test.AssertError(t, err, "checking TLS with no valid chains")
   431  
   432  	// Context with cert with wrong name is disallowed.
   433  	ctx = peer.NewContext(ctx, &peer.Peer{
   434  		AuthInfo: credentials.TLSInfo{
   435  			State: tls.ConnectionState{
   436  				VerifiedChains: [][]*x509.Certificate{
   437  					{
   438  						&x509.Certificate{
   439  							DNSNames: []string{
   440  								"disallowed.client",
   441  							},
   442  						},
   443  					},
   444  				},
   445  			},
   446  		},
   447  	})
   448  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   449  	test.AssertError(t, err, "checking disallowed cert")
   450  
   451  	// Context with cert with good name is allowed.
   452  	ctx = peer.NewContext(ctx, &peer.Peer{
   453  		AuthInfo: credentials.TLSInfo{
   454  			State: tls.ConnectionState{
   455  				VerifiedChains: [][]*x509.Certificate{
   456  					{
   457  						&x509.Certificate{
   458  							DNSNames: []string{
   459  								"disallowed.client",
   460  								"also.allowed",
   461  							},
   462  						},
   463  					},
   464  				},
   465  			},
   466  		},
   467  	})
   468  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   469  	test.AssertNotError(t, err, "checking allowed cert")
   470  }
   471  

View as plain text