1
18
19 package tlscreds_test
20
21 import (
22 "context"
23 "crypto/tls"
24 "encoding/json"
25 "fmt"
26 "os"
27 "strings"
28 "testing"
29 "time"
30
31 "google.golang.org/grpc"
32 "google.golang.org/grpc/codes"
33 "google.golang.org/grpc/internal/grpctest"
34 "google.golang.org/grpc/internal/stubserver"
35 "google.golang.org/grpc/internal/testutils/xds/e2e"
36 "google.golang.org/grpc/internal/xds/bootstrap/tlscreds"
37 testgrpc "google.golang.org/grpc/interop/grpc_testing"
38 testpb "google.golang.org/grpc/interop/grpc_testing"
39 "google.golang.org/grpc/status"
40 "google.golang.org/grpc/testdata"
41 )
42
43 const defaultTestTimeout = 5 * time.Second
44
45 type s struct {
46 grpctest.Tester
47 }
48
49 func Test(t *testing.T) {
50 grpctest.RunSubTests(t, s{})
51 }
52
53 type Closable interface {
54 Close()
55 }
56
57 func (s) TestValidTlsBuilder(t *testing.T) {
58 caCert := testdata.Path("x509/server_ca_cert.pem")
59 clientCert := testdata.Path("x509/client1_cert.pem")
60 clientKey := testdata.Path("x509/client1_key.pem")
61 tests := []struct {
62 name string
63 jd string
64 }{
65 {
66 name: "Absent configuration",
67 jd: `null`,
68 },
69 {
70 name: "Empty configuration",
71 jd: `{}`,
72 },
73 {
74 name: "Only CA certificate chain",
75 jd: fmt.Sprintf(`{"ca_certificate_file": "%s"}`, caCert),
76 },
77 {
78 name: "Only private key and certificate chain",
79 jd: fmt.Sprintf(`{"certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey),
80 },
81 {
82 name: "CA chain, private key and certificate chain",
83 jd: fmt.Sprintf(`{"ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey),
84 },
85 {
86 name: "Only refresh interval", jd: `{"refresh_interval": "1s"}`,
87 },
88 {
89 name: "Refresh interval and CA certificate chain",
90 jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file": "%s"}`, caCert),
91 },
92 {
93 name: "Refresh interval, private key and certificate chain",
94 jd: fmt.Sprintf(`{"refresh_interval": "1s","certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey),
95 },
96 {
97 name: "Refresh interval, CA chain, private key and certificate chain",
98 jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey),
99 },
100 {
101 name: "Unknown field",
102 jd: `{"unknown_field": "foo"}`,
103 },
104 }
105
106 for _, test := range tests {
107 t.Run(test.name, func(t *testing.T) {
108 msg := json.RawMessage(test.jd)
109 _, stop, err := tlscreds.NewBundle(msg)
110 if err != nil {
111 t.Fatalf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err)
112 }
113 stop()
114 })
115 }
116 }
117
118 func (s) TestInvalidTlsBuilder(t *testing.T) {
119 tests := []struct {
120 name, jd, wantErrPrefix string
121 }{
122 {
123 name: "Wrong type in json",
124 jd: `{"ca_certificate_file": 1}`,
125 wantErrPrefix: "failed to unmarshal config:"},
126 {
127 name: "Missing private key",
128 jd: fmt.Sprintf(`{"certificate_file":"%s"}`, testdata.Path("x509/server_cert.pem")),
129 wantErrPrefix: "pemfile: private key file and identity cert file should be both specified or not specified",
130 },
131 }
132
133 for _, test := range tests {
134 t.Run(test.name, func(t *testing.T) {
135 msg := json.RawMessage(test.jd)
136 _, stop, err := tlscreds.NewBundle(msg)
137 if err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) {
138 if stop != nil {
139 stop()
140 }
141 t.Fatalf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix)
142 }
143 })
144 }
145 }
146
147 func (s) TestCaReloading(t *testing.T) {
148 serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
149 if err != nil {
150 t.Fatalf("Failed to read test CA cert: %s", err)
151 }
152
153
154 caPath := t.TempDir() + "/ca.pem"
155 if err = os.WriteFile(caPath, serverCa, 0644); err != nil {
156 t.Fatalf("Failed to write test CA cert: %v", err)
157 }
158 cfg := fmt.Sprintf(`{
159 "ca_certificate_file": "%s",
160 "refresh_interval": ".01s"
161 }`, caPath)
162 tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg))
163 if err != nil {
164 t.Fatalf("Failed to create TLS bundle: %v", err)
165 }
166 defer stop()
167
168 serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert))
169 server := stubserver.StartTestService(t, nil, serverCredentials)
170
171 conn, err := grpc.NewClient(
172 server.Address,
173 grpc.WithCredentialsBundle(tlsBundle),
174 grpc.WithAuthority("x.test.example.com"),
175 )
176 if err != nil {
177 t.Fatalf("Error dialing: %v", err)
178 }
179 defer conn.Close()
180
181 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
182 defer cancel()
183
184 client := testgrpc.NewTestServiceClient(conn)
185 if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
186 t.Errorf("Error calling EmptyCall: %v", err)
187 }
188
189
190 server.Stop()
191
192 invalidCa, err := os.ReadFile(testdata.Path("ca.pem"))
193 if err != nil {
194 t.Fatalf("Failed to read test CA cert: %v", err)
195 }
196
197 err = os.WriteFile(caPath, invalidCa, 0644)
198 if err != nil {
199 t.Fatalf("Failed to write test CA cert: %v", err)
200 }
201
202 for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) {
203 ss := stubserver.StubServer{
204 Address: server.Address,
205 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
206 }
207 server = stubserver.StartTestService(t, &ss, serverCredentials)
208
209
210
211 t.Log(server)
212 _, err = client.EmptyCall(ctx, &testpb.Empty{})
213 const wantErr = "certificate signed by unknown authority"
214 if status.Code(err) == codes.Unavailable && strings.Contains(err.Error(), wantErr) {
215
216 server.Stop()
217 break
218 }
219 t.Logf("EmptyCall() got err: %s, want code: %s, want err: %s", err, codes.Unavailable, wantErr)
220 server.Stop()
221 }
222 if ctx.Err() != nil {
223 t.Errorf("Timed out waiting for CA certs reloading")
224 }
225 }
226
227 func (s) TestMTLS(t *testing.T) {
228 s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)))
229 defer s.Stop()
230
231 cfg := fmt.Sprintf(`{
232 "ca_certificate_file": "%s",
233 "certificate_file": "%s",
234 "private_key_file": "%s"
235 }`,
236 testdata.Path("x509/server_ca_cert.pem"),
237 testdata.Path("x509/client1_cert.pem"),
238 testdata.Path("x509/client1_key.pem"))
239 tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg))
240 if err != nil {
241 t.Fatalf("Failed to create TLS bundle: %v", err)
242 }
243 defer stop()
244 conn, err := grpc.NewClient(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"))
245 if err != nil {
246 t.Fatalf("Error dialing: %v", err)
247 }
248 defer conn.Close()
249 client := testgrpc.NewTestServiceClient(conn)
250 if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
251 t.Errorf("EmptyCall(): got error %v when expected to succeed", err)
252 }
253 }
254
View as plain text