...

Source file src/cloud.google.com/go/internal/testutil/headers_enforcer.go

Documentation: cloud.google.com/go/internal/testutil

     1  // Copyright 2019 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testutil
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"errors"
    21  	"fmt"
    22  	"log"
    23  	"os"
    24  	"strings"
    25  
    26  	"google.golang.org/api/option"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/metadata"
    29  )
    30  
    31  // HeaderChecker defines header checking and validation rules for any outgoing metadata.
    32  type HeaderChecker struct {
    33  	// Key is the header name to be checked against e.g. "x-goog-api-client".
    34  	Key string
    35  
    36  	// ValuesValidator validates the header values retrieved from mapping against
    37  	// Key in the Headers.
    38  	ValuesValidator func(values ...string) error
    39  }
    40  
    41  // HeadersEnforcer asserts that outgoing RPC headers
    42  // are present and match expectations. If the expected headers
    43  // are not present or don't match expectations, it'll invoke OnFailure
    44  // with the validation error, or instead log.Fatal if OnFailure is nil.
    45  //
    46  // It expects that every declared key will be present in the outgoing
    47  // RPC header and each value will be validated by the validation function.
    48  type HeadersEnforcer struct {
    49  	// Checkers maps header keys that are expected to be sent in the metadata
    50  	// of outgoing gRPC requests, against the values passed into the custom
    51  	// validation functions.
    52  	//
    53  	// If Checkers is nil or empty, only the default header "x-goog-api-client"
    54  	// will be checked for.
    55  	// Otherwise, if you supply Matchers, those keys and their respective
    56  	// validation functions will be checked.
    57  	Checkers []*HeaderChecker
    58  
    59  	// OnFailure is the function that will be invoked after all validation
    60  	// failures have been composed. If OnFailure is nil, log.Fatal will be
    61  	// invoked instead.
    62  	OnFailure func(fmt_ string, args ...interface{})
    63  }
    64  
    65  // StreamInterceptors returns a list of StreamClientInterceptor functions which
    66  // enforce the presence and validity of expected headers during streaming RPCs.
    67  //
    68  // For client implementations which provide their own StreamClientInterceptor(s)
    69  // these interceptors should be specified as the final elements to
    70  // WithChainStreamInterceptor.
    71  //
    72  // Alternatively, users may apply gPRC options produced from DialOptions to
    73  // apply all applicable gRPC interceptors.
    74  func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor {
    75  	return []grpc.StreamClientInterceptor{h.interceptStream}
    76  }
    77  
    78  // UnaryInterceptors returns a list of UnaryClientInterceptor functions which
    79  // enforce the presence and validity of expected headers during unary RPCs.
    80  //
    81  // For client implementations which provide their own UnaryClientInterceptor(s)
    82  // these interceptors should be specified as the final elements to
    83  // WithChainUnaryInterceptor.
    84  //
    85  // Alternatively, users may apply gPRC options produced from DialOptions to
    86  // apply all applicable gRPC interceptors.
    87  func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor {
    88  	return []grpc.UnaryClientInterceptor{h.interceptUnary}
    89  }
    90  
    91  // DialOptions returns gRPC DialOptions consisting of unary and stream interceptors
    92  // to enforce the presence and validity of expected headers.
    93  func (h *HeadersEnforcer) DialOptions() []grpc.DialOption {
    94  	return []grpc.DialOption{
    95  		grpc.WithChainStreamInterceptor(h.interceptStream),
    96  		grpc.WithChainUnaryInterceptor(h.interceptUnary),
    97  	}
    98  }
    99  
   100  // CallOptions returns ClientOptions consisting of unary and stream interceptors
   101  // to enforce the presence and validity of expected headers.
   102  func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) {
   103  	dopts := h.DialOptions()
   104  	for _, dopt := range dopts {
   105  		copts = append(copts, option.WithGRPCDialOption(dopt))
   106  	}
   107  	return
   108  }
   109  
   110  func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   111  	h.checkMetadata(ctx, method)
   112  	return invoker(ctx, method, req, res, cc, opts...)
   113  }
   114  
   115  func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   116  	h.checkMetadata(ctx, method)
   117  	return streamer(ctx, desc, cc, method, opts...)
   118  }
   119  
   120  // XGoogClientHeaderChecker is a HeaderChecker that ensures that the "x-goog-api-client"
   121  // header is present on outgoing metadata.
   122  var XGoogClientHeaderChecker = &HeaderChecker{
   123  	Key: "x-goog-api-client",
   124  	ValuesValidator: func(values ...string) error {
   125  		if len(values) == 0 {
   126  			return errors.New("expecting values")
   127  		}
   128  		for _, value := range values {
   129  			switch {
   130  			case strings.Contains(value, "gl-go/"):
   131  				// TODO: check for exact version strings.
   132  				return nil
   133  
   134  			default: // Add others here.
   135  			}
   136  		}
   137  		return errors.New("unmatched values")
   138  	},
   139  }
   140  
   141  // DefaultHeadersEnforcer returns a HeadersEnforcer that at bare minimum checks that
   142  // the "x-goog-api-client" key is present in the outgoing metadata headers. On any
   143  // validation failure, it will invoke log.Fatalf with the error message.
   144  func DefaultHeadersEnforcer() *HeadersEnforcer {
   145  	return &HeadersEnforcer{
   146  		Checkers: []*HeaderChecker{XGoogClientHeaderChecker},
   147  	}
   148  }
   149  
   150  func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) {
   151  	onFailure := h.OnFailure
   152  	if onFailure == nil {
   153  		lgr := log.New(os.Stderr, "", 0) // Do not log the time prefix, it is noisy in test failure logs.
   154  		onFailure = func(fmt_ string, args ...interface{}) {
   155  			lgr.Fatalf(fmt_, args...)
   156  		}
   157  	}
   158  
   159  	md, ok := metadata.FromOutgoingContext(ctx)
   160  	if !ok {
   161  		onFailure("Missing metadata for method %q", method)
   162  		return
   163  	}
   164  	checkers := h.Checkers
   165  	if len(checkers) == 0 {
   166  		// Instead use the default HeaderChecker.
   167  		checkers = append(checkers, XGoogClientHeaderChecker)
   168  	}
   169  
   170  	errBuf := new(bytes.Buffer)
   171  	for _, checker := range checkers {
   172  		hdrKey := checker.Key
   173  		outHdrValues, ok := md[hdrKey]
   174  		if !ok {
   175  			fmt.Fprintf(errBuf, "missing header %q\n", hdrKey)
   176  			continue
   177  		}
   178  		if err := checker.ValuesValidator(outHdrValues...); err != nil {
   179  			fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err)
   180  		}
   181  	}
   182  
   183  	if errBuf.Len() != 0 {
   184  		onFailure("For method %q, errors:\n%s", method, errBuf)
   185  		return
   186  	}
   187  }
   188  

View as plain text