1
2
3
4 package grpc_testing
5
6 import (
7 "context"
8 "crypto/rand"
9 "crypto/rsa"
10 "crypto/tls"
11 "crypto/x509"
12 "crypto/x509/pkix"
13 "encoding/pem"
14 "flag"
15 "math/big"
16 "net"
17 "time"
18
19 pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
20 "github.com/stretchr/testify/require"
21 "github.com/stretchr/testify/suite"
22 "google.golang.org/grpc"
23 "google.golang.org/grpc/credentials"
24 )
25
var (
	// flagTls (-use_tls, default true) selects whether the suite's server and
	// clients talk over TLS using the generated self-signed certificate.
	flagTls = flag.Bool("use_tls", true, "whether all gRPC middleware tests should use tls")

	// certPEM and keyPEM hold the PEM-encoded test certificate and its RSA
	// private key. They are assigned by InterceptorTestSuite.SetupSuite.
	certPEM, keyPEM []byte
)
32
33
// InterceptorTestSuite is a testify suite that runs a pb_testproto
// TestService gRPC server on a local port and dials a client to it, so tests
// can issue real RPCs through the configured server/client options.
type InterceptorTestSuite struct {
	suite.Suite

	// Configuration, presumably set by the embedding test before the suite
	// runs. TestService is the implementation registered on the server
	// (SetupSuite substitutes a TestPingService when it is nil); ServerOpts are
	// passed to grpc.NewServer and ClientOpts to NewClient for the default
	// client.
	TestService pb_testproto.TestServiceServer
	ServerOpts  []grpc.ServerOption
	ClientOpts  []grpc.DialOption

	// Runtime state filled in by SetupSuite. serverAddr starts as
	// "127.0.0.1:0" and is replaced by the listener's bound host:port so that
	// restarts reuse the same port. Client is dialed via NewClient when nil.
	// clientConn is closed by TearDownSuite when non-nil.
	serverAddr     string
	ServerListener net.Listener
	Server         *grpc.Server
	clientConn     *grpc.ClientConn
	Client         pb_testproto.TestServiceClient

	// Channels between RestartServer and the server goroutine started in
	// SetupSuite: a restart request carries the delay before the server is
	// started again, and serverRunning receives true once it is serving.
	restartServerWithDelayedStart chan time.Duration
	serverRunning                 chan bool
}
50
51 func (s *InterceptorTestSuite) SetupSuite() {
52 s.restartServerWithDelayedStart = make(chan time.Duration)
53 s.serverRunning = make(chan bool)
54
55 s.serverAddr = "127.0.0.1:0"
56 var err error
57 certPEM, keyPEM, err = generateCertAndKey([]string{"localhost", "example.com"})
58 if err != nil {
59 s.T().Fatalf("unable to generate test certificate/key: " + err.Error())
60 }
61 go func() {
62 for {
63 var err error
64 s.ServerListener, err = net.Listen("tcp", s.serverAddr)
65 if err != nil {
66 s.T().Fatalf("unable to listen on address %s: %v", s.serverAddr, err)
67 }
68 s.serverAddr = s.ServerListener.Addr().String()
69 require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
70 if *flagTls {
71 cert, err := tls.X509KeyPair(certPEM, keyPEM)
72 if err != nil {
73 s.T().Fatalf("unable to load test TLS certificate: %v", err)
74 }
75 creds := credentials.NewServerTLSFromCert(&cert)
76 s.ServerOpts = append(s.ServerOpts, grpc.Creds(creds))
77 }
78
79 s.Server = grpc.NewServer(s.ServerOpts...)
80
81 if s.TestService == nil {
82 s.TestService = &TestPingService{T: s.T()}
83 }
84 pb_testproto.RegisterTestServiceServer(s.Server, s.TestService)
85
86 go func() {
87 s.Server.Serve(s.ServerListener)
88 }()
89 if s.Client == nil {
90 s.Client = s.NewClient(s.ClientOpts...)
91 }
92
93 s.serverRunning <- true
94
95 d := <-s.restartServerWithDelayedStart
96 s.Server.Stop()
97 time.Sleep(d)
98 }
99 }()
100
101 select {
102 case <-s.serverRunning:
103 case <-time.After(2 * time.Second):
104 s.T().Fatal("server failed to start before deadline")
105 }
106 }
107
108 func (s *InterceptorTestSuite) RestartServer(delayedStart time.Duration) <-chan bool {
109 s.restartServerWithDelayedStart <- delayedStart
110 time.Sleep(10 * time.Millisecond)
111 return s.serverRunning
112 }
113
114 func (s *InterceptorTestSuite) NewClient(dialOpts ...grpc.DialOption) pb_testproto.TestServiceClient {
115 newDialOpts := append(dialOpts, grpc.WithBlock())
116 if *flagTls {
117 cp := x509.NewCertPool()
118 if !cp.AppendCertsFromPEM(certPEM) {
119 s.T().Fatal("failed to append certificate")
120 }
121 creds := credentials.NewTLS(&tls.Config{ServerName: "localhost", RootCAs: cp})
122 newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(creds))
123 } else {
124 newDialOpts = append(newDialOpts, grpc.WithInsecure())
125 }
126 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
127 defer cancel()
128 clientConn, err := grpc.DialContext(ctx, s.ServerAddr(), newDialOpts...)
129 require.NoError(s.T(), err, "must not error on client Dial")
130 return pb_testproto.NewTestServiceClient(clientConn)
131 }
132
133 func (s *InterceptorTestSuite) ServerAddr() string {
134 return s.serverAddr
135 }
136
137 func (s *InterceptorTestSuite) SimpleCtx() context.Context {
138 ctx, _ := context.WithTimeout(context.TODO(), 2*time.Second)
139 return ctx
140 }
141
142 func (s *InterceptorTestSuite) DeadlineCtx(deadline time.Time) context.Context {
143 ctx, _ := context.WithDeadline(context.TODO(), deadline)
144 return ctx
145 }
146
147 func (s *InterceptorTestSuite) TearDownSuite() {
148 time.Sleep(10 * time.Millisecond)
149 if s.ServerListener != nil {
150 s.Server.GracefulStop()
151 s.T().Logf("stopped grpc.Server at: %v", s.ServerAddr())
152 s.ServerListener.Close()
153 }
154 if s.clientConn != nil {
155 s.clientConn.Close()
156 }
157 }
158
159
160
// generateCertAndKey creates a self-signed RSA-2048 certificate valid for one
// hour from now. It returns the certificate and its private key, both PEM
// encoded: a "CERTIFICATE" block and a PKCS#1 "RSA PRIVATE KEY" block.
//
// san becomes the certificate's DNS subject alternative names. The subject
// common name is always "example.com". The certificate carries the
// server-authentication extended key usage. The key and the random 128-bit
// serial number are drawn from crypto/rand.
func generateCertAndKey(san []string) ([]byte, []byte, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	validFrom := time.Now()
	validUntil := validFrom.Add(time.Hour)

	// Serial numbers are drawn uniformly from [0, 2^128).
	maxSerial := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, maxSerial)
	if err != nil {
		return nil, nil, err
	}

	tmpl := &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{CommonName: "example.com"},
		NotBefore:             validFrom,
		NotAfter:              validUntil,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              san,
	}

	// Passing the template as its own parent makes the certificate self-signed.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
	if err != nil {
		return nil, nil, err
	}

	toPEM := func(blockType string, body []byte) []byte {
		return pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: body})
	}
	return toPEM("CERTIFICATE", der), toPEM("RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(key)), nil
}
200
View as plain text