...

Source file src/github.com/letsencrypt/boulder/grpc/interceptors.go

Documentation: github.com/letsencrypt/boulder/grpc

     1  package grpc
     2  
import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmhodges/clock"
	"github.com/prometheus/client_golang/prometheus"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"

	"github.com/letsencrypt/boulder/cmd"
	berrors "github.com/letsencrypt/boulder/errors"
)
    22  
    23  const (
    24  	returnOverhead         = 20 * time.Millisecond
    25  	meaningfulWorkOverhead = 100 * time.Millisecond
    26  	clientRequestTimeKey   = "client-request-time"
    27  )
    28  
    29  type serverInterceptor interface {
    30  	Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error)
    31  	Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error
    32  }
    33  
    34  // noopServerInterceptor provides no-op interceptors. It can be substituted for
    35  // an interceptor that has been disabled.
    36  type noopServerInterceptor struct{}
    37  
    38  // Unary is a gRPC unary interceptor.
    39  func (n *noopServerInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    40  	return handler(ctx, req)
    41  }
    42  
    43  // Stream is a gRPC stream interceptor.
    44  func (n *noopServerInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    45  	return handler(srv, ss)
    46  }
    47  
    48  // Ensure noopServerInterceptor matches the serverInterceptor interface.
    49  var _ serverInterceptor = &noopServerInterceptor{}
    50  
    51  type clientInterceptor interface {
    52  	Unary(ctx context.Context, method string, req interface{}, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error
    53  	Stream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error)
    54  }
    55  
    56  // serverMetadataInterceptor is a gRPC interceptor that adds Prometheus
    57  // metrics to requests handled by a gRPC server, and wraps Boulder-specific
    58  // errors for transmission in a grpc/metadata trailer (see bcodes.go).
    59  type serverMetadataInterceptor struct {
    60  	metrics serverMetrics
    61  	clk     clock.Clock
    62  }
    63  
    64  func newServerMetadataInterceptor(metrics serverMetrics, clk clock.Clock) serverMetadataInterceptor {
    65  	return serverMetadataInterceptor{
    66  		metrics: metrics,
    67  		clk:     clk,
    68  	}
    69  }
    70  
    71  // Unary implements the grpc.UnaryServerInterceptor interface.
    72  func (smi *serverMetadataInterceptor) Unary(
    73  	ctx context.Context,
    74  	req interface{},
    75  	info *grpc.UnaryServerInfo,
    76  	handler grpc.UnaryHandler) (interface{}, error) {
    77  	if info == nil {
    78  		return nil, berrors.InternalServerError("passed nil *grpc.UnaryServerInfo")
    79  	}
    80  
    81  	// Extract the grpc metadata from the context. If the context has
    82  	// a `clientRequestTimeKey` field, and it has a value, then observe the RPC
    83  	// latency with Prometheus.
    84  	if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 {
    85  		err := smi.observeLatency(md[clientRequestTimeKey][0])
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  	}
    90  
    91  	// Shave 20 milliseconds off the deadline to ensure that if the RPC server times
    92  	// out any sub-calls it makes (like DNS lookups, or onwards RPCs), it has a
    93  	// chance to report that timeout to the client. This allows for more specific
    94  	// errors, e.g "the VA timed out looking up CAA for example.com" (when called
    95  	// from RA.NewCertificate, which was called from WFE.NewCertificate), as
    96  	// opposed to "RA.NewCertificate timed out" (causing a 500).
    97  	// Once we've shaved the deadline, we ensure we have we have at least another
    98  	// 100ms left to do work; otherwise we abort early.
    99  	deadline, ok := ctx.Deadline()
   100  	// Should never happen: there was no deadline.
   101  	if !ok {
   102  		deadline = time.Now().Add(100 * time.Second)
   103  	}
   104  	deadline = deadline.Add(-returnOverhead)
   105  	remaining := time.Until(deadline)
   106  	if remaining < meaningfulWorkOverhead {
   107  		return nil, status.Errorf(codes.DeadlineExceeded, "not enough time left on clock: %s", remaining)
   108  	}
   109  
   110  	localCtx, cancel := context.WithDeadline(ctx, deadline)
   111  	defer cancel()
   112  
   113  	resp, err := handler(localCtx, req)
   114  	if err != nil {
   115  		err = wrapError(localCtx, err)
   116  	}
   117  	return resp, err
   118  }
   119  
   120  // interceptedServerStream wraps an existing server stream, but replaces its
   121  // context with its own.
   122  type interceptedServerStream struct {
   123  	grpc.ServerStream
   124  	ctx context.Context
   125  }
   126  
   127  // Context implements part of the grpc.ServerStream interface.
   128  func (iss interceptedServerStream) Context() context.Context {
   129  	return iss.ctx
   130  }
   131  
   132  // Stream implements the grpc.StreamServerInterceptor interface.
   133  func (smi *serverMetadataInterceptor) Stream(
   134  	srv interface{},
   135  	ss grpc.ServerStream,
   136  	info *grpc.StreamServerInfo,
   137  	handler grpc.StreamHandler) error {
   138  	ctx := ss.Context()
   139  
   140  	// Extract the grpc metadata from the context. If the context has
   141  	// a `clientRequestTimeKey` field, and it has a value, then observe the RPC
   142  	// latency with Prometheus.
   143  	if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 {
   144  		err := smi.observeLatency(md[clientRequestTimeKey][0])
   145  		if err != nil {
   146  			return err
   147  		}
   148  	}
   149  
   150  	// Shave 20 milliseconds off the deadline to ensure that if the RPC server times
   151  	// out any sub-calls it makes (like DNS lookups, or onwards RPCs), it has a
   152  	// chance to report that timeout to the client. This allows for more specific
   153  	// errors, e.g "the VA timed out looking up CAA for example.com" (when called
   154  	// from RA.NewCertificate, which was called from WFE.NewCertificate), as
   155  	// opposed to "RA.NewCertificate timed out" (causing a 500).
   156  	// Once we've shaved the deadline, we ensure we have we have at least another
   157  	// 100ms left to do work; otherwise we abort early.
   158  	deadline, ok := ctx.Deadline()
   159  	// Should never happen: there was no deadline.
   160  	if !ok {
   161  		deadline = time.Now().Add(100 * time.Second)
   162  	}
   163  	deadline = deadline.Add(-returnOverhead)
   164  	remaining := time.Until(deadline)
   165  	if remaining < meaningfulWorkOverhead {
   166  		return status.Errorf(codes.DeadlineExceeded, "not enough time left on clock: %s", remaining)
   167  	}
   168  
   169  	// Server stream interceptors are synchronous (they return their error, if
   170  	// any, when the stream is done) so defer cancel() is safe here.
   171  	localCtx, cancel := context.WithDeadline(ctx, deadline)
   172  	defer cancel()
   173  
   174  	err := handler(srv, interceptedServerStream{ss, localCtx})
   175  	if err != nil {
   176  		err = wrapError(localCtx, err)
   177  	}
   178  	return err
   179  }
   180  
   181  // splitMethodName is borrowed directly from
   182  // `grpc-ecosystem/go-grpc-prometheus/util.go` and is used to extract the
   183  // service and method name from the `method` argument to
   184  // a `UnaryClientInterceptor`.
   185  func splitMethodName(fullMethodName string) (string, string) {
   186  	fullMethodName = strings.TrimPrefix(fullMethodName, "/") // remove leading slash
   187  	if i := strings.Index(fullMethodName, "/"); i >= 0 {
   188  		return fullMethodName[:i], fullMethodName[i+1:]
   189  	}
   190  	return "unknown", "unknown"
   191  }
   192  
   193  // observeLatency is called with the `clientRequestTimeKey` value from
   194  // a request's gRPC metadata. This string value is converted to a timestamp and
   195  // used to calculate the latency between send and receive time. The latency is
   196  // published to the server interceptor's rpcLag prometheus histogram. An error
   197  // is returned if the `clientReqTime` string is not a valid timestamp.
   198  func (smi *serverMetadataInterceptor) observeLatency(clientReqTime string) error {
   199  	// Convert the metadata request time into an int64
   200  	reqTimeUnixNanos, err := strconv.ParseInt(clientReqTime, 10, 64)
   201  	if err != nil {
   202  		return berrors.InternalServerError("grpc metadata had illegal %s value: %q - %s",
   203  			clientRequestTimeKey, clientReqTime, err)
   204  	}
   205  	// Calculate the elapsed time since the client sent the RPC
   206  	reqTime := time.Unix(0, reqTimeUnixNanos)
   207  	elapsed := smi.clk.Since(reqTime)
   208  	// Publish an RPC latency observation to the histogram
   209  	smi.metrics.rpcLag.Observe(elapsed.Seconds())
   210  	return nil
   211  }
   212  
   213  // Ensure serverMetadataInterceptor matches the serverInterceptor interface.
   214  var _ serverInterceptor = (*serverMetadataInterceptor)(nil)
   215  
   216  // clientMetadataInterceptor is a gRPC interceptor that adds Prometheus
   217  // metrics to sent requests, and disables FailFast. We disable FailFast because
   218  // non-FailFast mode is most similar to the old AMQP RPC layer: If a client
   219  // makes a request while all backends are briefly down (e.g. for a restart), the
   220  // request doesn't necessarily fail. A backend can service the request if it
   221  // comes back up within the timeout. Under gRPC the same effect is achieved by
   222  // retries up to the Context deadline.
   223  type clientMetadataInterceptor struct {
   224  	timeout time.Duration
   225  	metrics clientMetrics
   226  	clk     clock.Clock
   227  
   228  	waitForReady bool
   229  }
   230  
   231  // Unary implements the grpc.UnaryClientInterceptor interface.
   232  func (cmi *clientMetadataInterceptor) Unary(
   233  	ctx context.Context,
   234  	fullMethod string,
   235  	req,
   236  	reply interface{},
   237  	cc *grpc.ClientConn,
   238  	invoker grpc.UnaryInvoker,
   239  	opts ...grpc.CallOption) error {
   240  	// This should not occur but fail fast with a clear error if it does (e.g.
   241  	// because of buggy unit test code) instead of a generic nil panic later!
   242  	if cmi.metrics.inFlightRPCs == nil {
   243  		return berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge")
   244  	}
   245  
   246  	// Ensure that the context has a deadline set.
   247  	localCtx, cancel := context.WithTimeout(ctx, cmi.timeout)
   248  	defer cancel()
   249  
   250  	// Convert the current unix nano timestamp to a string for embedding in the grpc metadata
   251  	nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10)
   252  	// Create a grpc/metadata.Metadata instance for the request metadata.
   253  	// Initialize it with the request time.
   254  	reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS})
   255  	// Configure the localCtx with the metadata so it gets sent along in the request
   256  	localCtx = metadata.NewOutgoingContext(localCtx, reqMD)
   257  
   258  	// Disable fail-fast so RPCs will retry until deadline, even if all backends
   259  	// are down.
   260  	opts = append(opts, grpc.WaitForReady(cmi.waitForReady))
   261  
   262  	// Create a grpc/metadata.Metadata instance for a grpc.Trailer.
   263  	respMD := metadata.New(nil)
   264  	// Configure a grpc Trailer with respMD. This allows us to wrap error
   265  	// types in the server interceptor later on.
   266  	opts = append(opts, grpc.Trailer(&respMD))
   267  
   268  	// Split the method and service name from the fullMethod.
   269  	// UnaryClientInterceptor's receive a `method` arg of the form
   270  	// "/ServiceName/MethodName"
   271  	service, method := splitMethodName(fullMethod)
   272  	// Slice the inFlightRPC inc/dec calls by method and service
   273  	labels := prometheus.Labels{
   274  		"method":  method,
   275  		"service": service,
   276  	}
   277  	// Increment the inFlightRPCs gauge for this method/service
   278  	cmi.metrics.inFlightRPCs.With(labels).Inc()
   279  	// And defer decrementing it when we're done
   280  	defer cmi.metrics.inFlightRPCs.With(labels).Dec()
   281  
   282  	// Handle the RPC
   283  	begin := cmi.clk.Now()
   284  	err := invoker(localCtx, fullMethod, req, reply, cc, opts...)
   285  	if err != nil {
   286  		err = unwrapError(err, respMD)
   287  		if status.Code(err) == codes.DeadlineExceeded {
   288  			return deadlineDetails{
   289  				service: service,
   290  				method:  method,
   291  				latency: cmi.clk.Since(begin),
   292  			}
   293  		}
   294  	}
   295  	return err
   296  }
   297  
   298  // interceptedClientStream wraps an existing client stream, and calls finish
   299  // when the stream ends or any operation on it fails.
   300  type interceptedClientStream struct {
   301  	grpc.ClientStream
   302  	finish func(error) error
   303  }
   304  
   305  // Header implements part of the grpc.ClientStream interface.
   306  func (ics interceptedClientStream) Header() (metadata.MD, error) {
   307  	md, err := ics.ClientStream.Header()
   308  	if err != nil {
   309  		err = ics.finish(err)
   310  	}
   311  	return md, err
   312  }
   313  
   314  // SendMsg implements part of the grpc.ClientStream interface.
   315  func (ics interceptedClientStream) SendMsg(m interface{}) error {
   316  	err := ics.ClientStream.SendMsg(m)
   317  	if err != nil {
   318  		err = ics.finish(err)
   319  	}
   320  	return err
   321  }
   322  
   323  // RecvMsg implements part of the grpc.ClientStream interface.
   324  func (ics interceptedClientStream) RecvMsg(m interface{}) error {
   325  	err := ics.ClientStream.RecvMsg(m)
   326  	if err != nil {
   327  		err = ics.finish(err)
   328  	}
   329  	return err
   330  }
   331  
   332  // CloseSend implements part of the grpc.ClientStream interface.
   333  func (ics interceptedClientStream) CloseSend() error {
   334  	err := ics.ClientStream.CloseSend()
   335  	if err != nil {
   336  		err = ics.finish(err)
   337  	}
   338  	return err
   339  }
   340  
   341  // Stream implements the grpc.StreamClientInterceptor interface.
   342  func (cmi *clientMetadataInterceptor) Stream(
   343  	ctx context.Context,
   344  	desc *grpc.StreamDesc,
   345  	cc *grpc.ClientConn,
   346  	fullMethod string,
   347  	streamer grpc.Streamer,
   348  	opts ...grpc.CallOption) (grpc.ClientStream, error) {
   349  	// This should not occur but fail fast with a clear error if it does (e.g.
   350  	// because of buggy unit test code) instead of a generic nil panic later!
   351  	if cmi.metrics.inFlightRPCs == nil {
   352  		return nil, berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge")
   353  	}
   354  
   355  	// We don't defer cancel() here, because this function is going to return
   356  	// immediately. Instead we store it in the interceptedClientStream.
   357  	localCtx, cancel := context.WithTimeout(ctx, cmi.timeout)
   358  
   359  	// Convert the current unix nano timestamp to a string for embedding in the grpc metadata
   360  	nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10)
   361  	// Create a grpc/metadata.Metadata instance for the request metadata.
   362  	// Initialize it with the request time.
   363  	reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS})
   364  	// Configure the localCtx with the metadata so it gets sent along in the request
   365  	localCtx = metadata.NewOutgoingContext(localCtx, reqMD)
   366  
   367  	// Disable fail-fast so RPCs will retry until deadline, even if all backends
   368  	// are down.
   369  	opts = append(opts, grpc.WaitForReady(cmi.waitForReady))
   370  
   371  	// Create a grpc/metadata.Metadata instance for a grpc.Trailer.
   372  	respMD := metadata.New(nil)
   373  	// Configure a grpc Trailer with respMD. This allows us to wrap error
   374  	// types in the server interceptor later on.
   375  	opts = append(opts, grpc.Trailer(&respMD))
   376  
   377  	// Split the method and service name from the fullMethod.
   378  	// UnaryClientInterceptor's receive a `method` arg of the form
   379  	// "/ServiceName/MethodName"
   380  	service, method := splitMethodName(fullMethod)
   381  	// Slice the inFlightRPC inc/dec calls by method and service
   382  	labels := prometheus.Labels{
   383  		"method":  method,
   384  		"service": service,
   385  	}
   386  	// Increment the inFlightRPCs gauge for this method/service
   387  	cmi.metrics.inFlightRPCs.With(labels).Inc()
   388  	begin := cmi.clk.Now()
   389  
   390  	// Cancel the local context and decrement the metric when we're done. Also
   391  	// transform the error into a more usable form, if necessary.
   392  	finish := func(err error) error {
   393  		cancel()
   394  		cmi.metrics.inFlightRPCs.With(labels).Dec()
   395  		if err != nil {
   396  			err = unwrapError(err, respMD)
   397  			if status.Code(err) == codes.DeadlineExceeded {
   398  				return deadlineDetails{
   399  					service: service,
   400  					method:  method,
   401  					latency: cmi.clk.Since(begin),
   402  				}
   403  			}
   404  		}
   405  		return err
   406  	}
   407  
   408  	// Handle the RPC
   409  	cs, err := streamer(localCtx, desc, cc, fullMethod, opts...)
   410  	ics := interceptedClientStream{cs, finish}
   411  	return ics, err
   412  }
   413  
   414  var _ clientInterceptor = (*clientMetadataInterceptor)(nil)
   415  
   416  // deadlineDetails is an error type that we use in place of gRPC's
   417  // DeadlineExceeded errors in order to add more detail for debugging.
   418  type deadlineDetails struct {
   419  	service string
   420  	method  string
   421  	latency time.Duration
   422  }
   423  
   424  func (dd deadlineDetails) Error() string {
   425  	return fmt.Sprintf("%s.%s timed out after %d ms",
   426  		dd.service, dd.method, int64(dd.latency/time.Millisecond))
   427  }
   428  
   429  // authInterceptor provides two server interceptors (Unary and Stream) which can
   430  // check that every request for a given gRPC service is being made over an mTLS
   431  // connection from a client which is allow-listed for that particular service.
   432  type authInterceptor struct {
   433  	// serviceClientNames is a map of gRPC service names (e.g. "ca.CertificateAuthority")
   434  	// to allowed client certificate SANs (e.g. "ra.boulder") which are allowed to
   435  	// make RPCs to that service. The set of client names is implemented as a map
   436  	// of names to empty structs for easy lookup.
   437  	serviceClientNames map[string]map[string]struct{}
   438  }
   439  
   440  // newServiceAuthChecker takes a GRPCServerConfig and uses its Service stanzas
   441  // to construct a serviceAuthChecker which enforces the service/client mappings
   442  // contained in the config.
   443  func newServiceAuthChecker(c *cmd.GRPCServerConfig) *authInterceptor {
   444  	names := make(map[string]map[string]struct{})
   445  	for serviceName, service := range c.Services {
   446  		names[serviceName] = make(map[string]struct{})
   447  		for _, clientName := range service.ClientNames {
   448  			names[serviceName][clientName] = struct{}{}
   449  		}
   450  	}
   451  	return &authInterceptor{names}
   452  }
   453  
   454  // Unary is a gRPC unary interceptor.
   455  func (ac *authInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   456  	err := ac.checkContextAuth(ctx, info.FullMethod)
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  	return handler(ctx, req)
   461  }
   462  
   463  // Stream is a gRPC stream interceptor.
   464  func (ac *authInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   465  	err := ac.checkContextAuth(ss.Context(), info.FullMethod)
   466  	if err != nil {
   467  		return err
   468  	}
   469  	return handler(srv, ss)
   470  }
   471  
   472  // checkContextAuth does most of the heavy lifting. It extracts TLS information
   473  // from the incoming context, gets the set of DNS names contained in the client
   474  // mTLS cert, and returns nil if at least one of those names appears in the set
   475  // of allowed client names for given service (or if the set of allowed client
   476  // names is empty).
   477  func (ac *authInterceptor) checkContextAuth(ctx context.Context, fullMethod string) error {
   478  	serviceName, _ := splitMethodName(fullMethod)
   479  
   480  	allowedClientNames, ok := ac.serviceClientNames[serviceName]
   481  	if !ok || len(allowedClientNames) == 0 {
   482  		return fmt.Errorf("service %q has no allowed client names", serviceName)
   483  	}
   484  
   485  	p, ok := peer.FromContext(ctx)
   486  	if !ok {
   487  		return fmt.Errorf("unable to fetch peer info from grpc context")
   488  	}
   489  
   490  	if p.AuthInfo == nil {
   491  		return fmt.Errorf("grpc connection appears to be plaintext")
   492  	}
   493  
   494  	tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo)
   495  	if !ok {
   496  		return fmt.Errorf("connection is not TLS authed")
   497  	}
   498  
   499  	if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 {
   500  		return fmt.Errorf("connection auth not verified")
   501  	}
   502  
   503  	cert := tlsAuth.State.VerifiedChains[0][0]
   504  
   505  	for _, clientName := range cert.DNSNames {
   506  		_, ok := allowedClientNames[clientName]
   507  		if ok {
   508  			return nil
   509  		}
   510  	}
   511  
   512  	return fmt.Errorf(
   513  		"client names %v are not authorized for service %q (%v)",
   514  		cert.DNSNames, serviceName, allowedClientNames)
   515  }
   516  
   517  // Ensure authInterceptor matches the serverInterceptor interface.
   518  var _ serverInterceptor = (*authInterceptor)(nil)
   519  

View as plain text