...

Source file src/github.com/grpc-ecosystem/go-grpc-middleware/testing/interceptor_suite.go

Documentation: github.com/grpc-ecosystem/go-grpc-middleware/testing

     1  // Copyright 2016 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_testing
     5  
     6  import (
     7  	"context"
     8  	"crypto/rand"
     9  	"crypto/rsa"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"crypto/x509/pkix"
    13  	"encoding/pem"
    14  	"flag"
    15  	"math/big"
    16  	"net"
    17  	"time"
    18  
    19  	pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/stretchr/testify/suite"
    22  	"google.golang.org/grpc"
    23  	"google.golang.org/grpc/credentials"
    24  )
    25  
var (
	// flagTls selects whether the test server and the clients created by
	// NewClient use TLS (the default) or a plaintext connection.
	flagTls = flag.Bool("use_tls", true, "whether all gRPC middleware tests should use tls")

	// certPEM and keyPEM hold the PEM-encoded self-signed certificate and its
	// private key. They are (re)generated by SetupSuite and shared by the
	// server credentials and the client root pool.
	certPEM []byte
	keyPEM  []byte
)
    32  
// InterceptorTestSuite is a testify/Suite that starts a gRPC PingService server and a client.
//
// Callers may pre-populate TestService, ServerOpts, ClientOpts and Client
// before the suite runs; SetupSuite fills in defaults for TestService and
// Client when they are nil.
type InterceptorTestSuite struct {
	suite.Suite

	// TestService is the service implementation registered on the server.
	// SetupSuite substitutes a TestPingService when it is nil.
	TestService pb_testproto.TestServiceServer
	// ServerOpts are passed to grpc.NewServer; this is where server
	// interceptors are hooked in. SetupSuite appends TLS credentials to it
	// when the use_tls flag is set.
	ServerOpts []grpc.ServerOption
	// ClientOpts are passed to NewClient when SetupSuite creates the default Client.
	ClientOpts []grpc.DialOption

	// serverAddr starts as "127.0.0.1:0" and is replaced by the listener's
	// actual address once the server is listening.
	serverAddr string
	// ServerListener is the listener of the most recently started server.
	ServerListener net.Listener
	// Server is the most recently started gRPC server.
	Server *grpc.Server
	// clientConn, when non-nil, is closed by TearDownSuite.
	clientConn *grpc.ClientConn
	// Client is the default client; SetupSuite creates it via NewClient when nil.
	Client pb_testproto.TestServiceClient

	// restartServerWithDelayedStart carries the delay to wait between
	// stopping the current server and starting the next one.
	restartServerWithDelayedStart chan time.Duration
	// serverRunning receives true each time a server has been (re)started.
	serverRunning chan bool
}
    50  
    51  func (s *InterceptorTestSuite) SetupSuite() {
    52  	s.restartServerWithDelayedStart = make(chan time.Duration)
    53  	s.serverRunning = make(chan bool)
    54  
    55  	s.serverAddr = "127.0.0.1:0"
    56  	var err error
    57  	certPEM, keyPEM, err = generateCertAndKey([]string{"localhost", "example.com"})
    58  	if err != nil {
    59  		s.T().Fatalf("unable to generate test certificate/key: " + err.Error())
    60  	}
    61  	go func() {
    62  		for {
    63  			var err error
    64  			s.ServerListener, err = net.Listen("tcp", s.serverAddr)
    65  			if err != nil {
    66  				s.T().Fatalf("unable to listen on address %s: %v", s.serverAddr, err)
    67  			}
    68  			s.serverAddr = s.ServerListener.Addr().String()
    69  			require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
    70  			if *flagTls {
    71  				cert, err := tls.X509KeyPair(certPEM, keyPEM)
    72  				if err != nil {
    73  					s.T().Fatalf("unable to load test TLS certificate: %v", err)
    74  				}
    75  				creds := credentials.NewServerTLSFromCert(&cert)
    76  				s.ServerOpts = append(s.ServerOpts, grpc.Creds(creds))
    77  			}
    78  			// This is the point where we hook up the interceptor
    79  			s.Server = grpc.NewServer(s.ServerOpts...)
    80  			// Create a service of the instantiator hasn't provided one.
    81  			if s.TestService == nil {
    82  				s.TestService = &TestPingService{T: s.T()}
    83  			}
    84  			pb_testproto.RegisterTestServiceServer(s.Server, s.TestService)
    85  
    86  			go func() {
    87  				s.Server.Serve(s.ServerListener)
    88  			}()
    89  			if s.Client == nil {
    90  				s.Client = s.NewClient(s.ClientOpts...)
    91  			}
    92  
    93  			s.serverRunning <- true
    94  
    95  			d := <-s.restartServerWithDelayedStart
    96  			s.Server.Stop()
    97  			time.Sleep(d)
    98  		}
    99  	}()
   100  
   101  	select {
   102  	case <-s.serverRunning:
   103  	case <-time.After(2 * time.Second):
   104  		s.T().Fatal("server failed to start before deadline")
   105  	}
   106  }
   107  
   108  func (s *InterceptorTestSuite) RestartServer(delayedStart time.Duration) <-chan bool {
   109  	s.restartServerWithDelayedStart <- delayedStart
   110  	time.Sleep(10 * time.Millisecond)
   111  	return s.serverRunning
   112  }
   113  
   114  func (s *InterceptorTestSuite) NewClient(dialOpts ...grpc.DialOption) pb_testproto.TestServiceClient {
   115  	newDialOpts := append(dialOpts, grpc.WithBlock())
   116  	if *flagTls {
   117  		cp := x509.NewCertPool()
   118  		if !cp.AppendCertsFromPEM(certPEM) {
   119  			s.T().Fatal("failed to append certificate")
   120  		}
   121  		creds := credentials.NewTLS(&tls.Config{ServerName: "localhost", RootCAs: cp})
   122  		newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(creds))
   123  	} else {
   124  		newDialOpts = append(newDialOpts, grpc.WithInsecure())
   125  	}
   126  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   127  	defer cancel()
   128  	clientConn, err := grpc.DialContext(ctx, s.ServerAddr(), newDialOpts...)
   129  	require.NoError(s.T(), err, "must not error on client Dial")
   130  	return pb_testproto.NewTestServiceClient(clientConn)
   131  }
   132  
// ServerAddr returns the address the test server listens on. Before the
// server has started listening this is the placeholder "127.0.0.1:0" set by
// SetupSuite; afterwards it is the listener's actual address.
func (s *InterceptorTestSuite) ServerAddr() string {
	return s.serverAddr
}
   136  
   137  func (s *InterceptorTestSuite) SimpleCtx() context.Context {
   138  	ctx, _ := context.WithTimeout(context.TODO(), 2*time.Second)
   139  	return ctx
   140  }
   141  
   142  func (s *InterceptorTestSuite) DeadlineCtx(deadline time.Time) context.Context {
   143  	ctx, _ := context.WithDeadline(context.TODO(), deadline)
   144  	return ctx
   145  }
   146  
   147  func (s *InterceptorTestSuite) TearDownSuite() {
   148  	time.Sleep(10 * time.Millisecond)
   149  	if s.ServerListener != nil {
   150  		s.Server.GracefulStop()
   151  		s.T().Logf("stopped grpc.Server at: %v", s.ServerAddr())
   152  		s.ServerListener.Close()
   153  	}
   154  	if s.clientConn != nil {
   155  		s.clientConn.Close()
   156  	}
   157  }
   158  
   159  // generateCertAndKey copied from https://github.com/johanbrandhorst/certify/blob/master/issuers/vault/vault_suite_test.go#L255
   160  // with minor modifications.
   161  func generateCertAndKey(san []string) ([]byte, []byte, error) {
   162  	priv, err := rsa.GenerateKey(rand.Reader, 2048)
   163  	if err != nil {
   164  		return nil, nil, err
   165  	}
   166  	notBefore := time.Now()
   167  	notAfter := notBefore.Add(time.Hour)
   168  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
   169  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
   170  	if err != nil {
   171  		return nil, nil, err
   172  	}
   173  	template := x509.Certificate{
   174  		SerialNumber: serialNumber,
   175  		Subject: pkix.Name{
   176  			CommonName: "example.com",
   177  		},
   178  		NotBefore:             notBefore,
   179  		NotAfter:              notAfter,
   180  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
   181  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   182  		BasicConstraintsValid: true,
   183  		DNSNames:              san,
   184  	}
   185  	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
   186  	if err != nil {
   187  		return nil, nil, err
   188  	}
   189  	certOut := pem.EncodeToMemory(&pem.Block{
   190  		Type:  "CERTIFICATE",
   191  		Bytes: derBytes,
   192  	})
   193  	keyOut := pem.EncodeToMemory(&pem.Block{
   194  		Type:  "RSA PRIVATE KEY",
   195  		Bytes: x509.MarshalPKCS1PrivateKey(priv),
   196  	})
   197  
   198  	return certOut, keyOut, nil
   199  }
   200  

View as plain text