1
18
19 package pemfile
20
21 import (
22 "context"
23 "fmt"
24 "os"
25 "path"
26 "testing"
27 "time"
28
29 "google.golang.org/grpc/credentials/tls/certprovider"
30 "google.golang.org/grpc/internal/grpctest"
31 "google.golang.org/grpc/internal/testutils"
32 "google.golang.org/grpc/testdata"
33 )
34
// Base names of the credential files that the tests create inside their
// temporary directories.
const (
	certFile = "cert.pem"
	keyFile  = "key.pem"
	rootFile = "ca.pem"
)

// Timing values shared by the tests in this file.
const (
	// defaultTestRefreshDuration is passed as Options.RefreshDuration when the
	// tests create a provider; tests that expect no update wait for twice this
	// duration.
	defaultTestRefreshDuration = 100 * time.Millisecond
	// defaultTestTimeout bounds the contexts used while waiting for key
	// material or for distributor notifications.
	defaultTestTimeout = 5 * time.Second
)
45
// s is the receiver type for the test methods in this file. It embeds
// grpctest.Tester and is handed to grpctest.RunSubTests by Test.
type s struct {
	grpctest.Tester
}
49
// Test is the single entry point seen by "go test". It passes an s value to
// grpctest.RunSubTests, which is presumably what executes the Test* methods
// declared on s (confirm against the grpctest package).
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}
53
54 func compareKeyMaterial(got, want *certprovider.KeyMaterial) error {
55 if len(got.Certs) != len(want.Certs) {
56 return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
57 }
58 for i := 0; i < len(got.Certs); i++ {
59 if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) {
60 return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
61 }
62 }
63
64 if gotR, wantR := got.Roots, want.Roots; !gotR.Equal(wantR) {
65 return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR)
66 }
67
68 return nil
69 }
70
71
72 func (s) TestNewProvider(t *testing.T) {
73 tests := []struct {
74 desc string
75 options Options
76 wantError bool
77 }{
78 {
79 desc: "No credential files specified",
80 options: Options{},
81 wantError: true,
82 },
83 {
84 desc: "Only identity cert is specified",
85 options: Options{
86 CertFile: testdata.Path("x509/client1_cert.pem"),
87 },
88 wantError: true,
89 },
90 {
91 desc: "Only identity key is specified",
92 options: Options{
93 KeyFile: testdata.Path("x509/client1_key.pem"),
94 },
95 wantError: true,
96 },
97 {
98 desc: "Identity cert/key pair is specified",
99 options: Options{
100 KeyFile: testdata.Path("x509/client1_key.pem"),
101 CertFile: testdata.Path("x509/client1_cert.pem"),
102 },
103 },
104 {
105 desc: "Only root certs are specified",
106 options: Options{
107 RootFile: testdata.Path("x509/client_ca_cert.pem"),
108 },
109 },
110 {
111 desc: "Everything is specified",
112 options: Options{
113 KeyFile: testdata.Path("x509/client1_key.pem"),
114 CertFile: testdata.Path("x509/client1_cert.pem"),
115 RootFile: testdata.Path("x509/client_ca_cert.pem"),
116 },
117 wantError: false,
118 },
119 }
120 for _, test := range tests {
121 t.Run(test.desc, func(t *testing.T) {
122 provider, err := NewProvider(test.options)
123 if (err != nil) != test.wantError {
124 t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError)
125 }
126 if err != nil {
127 return
128 }
129 provider.Close()
130 })
131 }
132 }
133
134
135
// wrappedDistributor embeds a certprovider.Distributor and adds a test
// channel. Its Set method sends on distCh after every call, which lets tests
// observe when key material (or an error) has been handed to the distributor.
type wrappedDistributor struct {
	*certprovider.Distributor
	// distCh receives one value per call to Set.
	distCh *testutils.Channel
}
140
141 func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor {
142 return &wrappedDistributor{
143 distCh: distCh,
144 Distributor: certprovider.NewDistributor(),
145 }
146 }
147
// Set forwards km and err to the embedded Distributor and then sends a nil
// value on distCh so that a test waiting on the channel learns that a Set
// call has completed.
func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) {
	wd.Distributor.Set(km, err)
	// Notify only after the embedded distributor has been updated.
	wd.distCh.Send(nil)
}
152
// createTmpFile copies the contents of the file at src into the file at dst,
// creating dst or overwriting it if it already exists. The destination path
// and the copied contents are written to the test log. Any read or write
// failure aborts the test through t.Fatalf.
func createTmpFile(t *testing.T, src, dst string) {
	t.Helper()

	contents, readErr := os.ReadFile(src)
	if readErr != nil {
		t.Fatalf("os.ReadFile(%q) failed: %v", src, readErr)
	}

	writeErr := os.WriteFile(dst, contents, os.ModePerm)
	if writeErr != nil {
		t.Fatalf("os.WriteFile(%q) failed: %v", dst, writeErr)
	}

	t.Logf("Wrote file at: %s", dst)
	t.Logf("%s", string(contents))
}
166
167
168
169
170
171 func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string {
172 t.Helper()
173
174
175
176 dir, err := os.MkdirTemp("", dirSuffix)
177 if err != nil {
178 t.Fatalf("os.MkdirTemp() failed: %v", err)
179 }
180 t.Logf("Using tmpdir: %s", dir)
181
182 createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile))
183 createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile))
184 createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile))
185 return dir
186 }
187
188
189
190 func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) {
191 t.Helper()
192
193
194
195 origDistributorFunc := newDistributor
196 distCh := testutils.NewChannel()
197 d := newWrappedDistributor(distCh)
198 newDistributor = func() distributor { return d }
199
200
201 dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
202 opts := Options{
203 CertFile: path.Join(dir, certFile),
204 KeyFile: path.Join(dir, keyFile),
205 RootFile: path.Join(dir, rootFile),
206 RefreshDuration: defaultTestRefreshDuration,
207 }
208 prov, err := NewProvider(opts)
209 if err != nil {
210 t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
211 }
212
213
214
215 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
216 defer cancel()
217 for i := 0; i < 2; i++ {
218
219
220 if _, err := distCh.Receive(ctx); err != nil {
221 t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
222 }
223 }
224
225 return dir, prov, distCh, func() {
226 newDistributor = origDistributorFunc
227 prov.Close()
228 }
229 }
230
231
232
233
234 func (s) TestProvider_NoUpdate(t *testing.T) {
235 _, prov, distCh, cancel := initializeProvider(t, "no_update")
236 defer cancel()
237
238
239 ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
240 defer cc()
241 if _, err := prov.KeyMaterial(ctx); err != nil {
242 t.Fatalf("provider.KeyMaterial() failed: %v", err)
243 }
244
245
246 sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
247 defer sc()
248 if _, err := distCh.Receive(sCtx); err == nil {
249 t.Fatal("new key material pushed to distributor when underlying files did not change")
250 }
251 }
252
253
254
255
256 func (s) TestProvider_UpdateSuccess(t *testing.T) {
257 dir, prov, distCh, cancel := initializeProvider(t, "update_success")
258 defer cancel()
259
260
261 ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
262 defer cc()
263 km1, err := prov.KeyMaterial(ctx)
264 if err != nil {
265 t.Fatalf("provider.KeyMaterial() failed: %v", err)
266 }
267
268
269 createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile))
270 if _, err := distCh.Receive(ctx); err != nil {
271 t.Fatal("timeout waiting for new key material to be pushed to the distributor")
272 }
273
274
275 km2, err := prov.KeyMaterial(ctx)
276 if err != nil {
277 t.Fatalf("provider.KeyMaterial() failed: %v", err)
278 }
279 if err := compareKeyMaterial(km1, km2); err == nil {
280 t.Fatal("expected provider to return new key material after update to underlying file")
281 }
282
283
284 createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile))
285 createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile))
286 if _, err := distCh.Receive(ctx); err != nil {
287 t.Fatal("timeout waiting for new key material to be pushed to the distributor")
288 }
289
290
291 km3, err := prov.KeyMaterial(ctx)
292 if err != nil {
293 t.Fatalf("provider.KeyMaterial() failed: %v", err)
294 }
295 if err := compareKeyMaterial(km2, km3); err == nil {
296 t.Fatal("expected provider to return new key material after update to underlying file")
297 }
298 }
299
300
301
302
303
304 func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) {
305
306
307 origDistributorFunc := newDistributor
308 distCh := testutils.NewChannel()
309 d := newWrappedDistributor(distCh)
310 newDistributor = func() distributor { return d }
311 defer func() { newDistributor = origDistributorFunc }()
312
313
314 dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
315 dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem")
316
317
318 tmpdir, err := os.MkdirTemp("", "test_symlink_*")
319 if err != nil {
320 t.Fatalf("os.MkdirTemp() failed: %v", err)
321 }
322 symLinkName := path.Join(tmpdir, "test_symlink")
323 if err := os.Symlink(dir1, symLinkName); err != nil {
324 t.Fatalf("failed to create symlink to %q: %v", dir1, err)
325 }
326
327
328 opts := Options{
329 CertFile: path.Join(symLinkName, certFile),
330 KeyFile: path.Join(symLinkName, keyFile),
331 RootFile: path.Join(symLinkName, rootFile),
332 RefreshDuration: defaultTestRefreshDuration,
333 }
334 prov, err := NewProvider(opts)
335 if err != nil {
336 t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
337 }
338 defer prov.Close()
339
340
341
342 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
343 defer cancel()
344 for i := 0; i < 2; i++ {
345
346
347 if _, err := distCh.Receive(ctx); err != nil {
348 t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
349 }
350 }
351 km1, err := prov.KeyMaterial(ctx)
352 if err != nil {
353 t.Fatalf("provider.KeyMaterial() failed: %v", err)
354 }
355
356
357 symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp")
358 if err := os.Symlink(dir2, symLinkTmpName); err != nil {
359 t.Fatalf("failed to create symlink to %q: %v", dir2, err)
360 }
361 if err := os.Rename(symLinkTmpName, symLinkName); err != nil {
362 t.Fatalf("failed to update symlink: %v", err)
363 }
364
365
366
367 for i := 0; i < 2; i++ {
368
369
370 if _, err := distCh.Receive(ctx); err != nil {
371 t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
372 }
373 }
374 km2, err := prov.KeyMaterial(ctx)
375 if err != nil {
376 t.Fatalf("provider.KeyMaterial() failed: %v", err)
377 }
378
379 if err := compareKeyMaterial(km1, km2); err == nil {
380 t.Fatal("expected provider to return new key material after symlink update")
381 }
382 }
383
384
385
386
387
388 func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) {
389 dir, prov, distCh, cancel := initializeProvider(t, "update_failure")
390 defer cancel()
391
392
393 ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
394 defer cc()
395 km1, err := prov.KeyMaterial(ctx)
396 if err != nil {
397 t.Fatalf("provider.KeyMaterial() failed: %v", err)
398 }
399
400
401
402
403
404 createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile))
405
406
407
408 sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
409 defer sc()
410 if _, err := distCh.Receive(sCtx); err == nil {
411 t.Fatal("new key material pushed to distributor when underlying files did not change")
412 }
413
414
415 km2, err := prov.KeyMaterial(ctx)
416 if err != nil {
417 t.Fatalf("provider.KeyMaterial() failed: %v", err)
418 }
419 if err := compareKeyMaterial(km1, km2); err != nil {
420 t.Fatalf("expected provider to not update key material: %v", err)
421 }
422
423
424 createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile))
425
426
427 if _, err := distCh.Receive(ctx); err != nil {
428 t.Fatal("timeout waiting for new key material to be pushed to the distributor")
429 }
430 km3, err := prov.KeyMaterial(ctx)
431 if err != nil {
432 t.Fatalf("provider.KeyMaterial() failed: %v", err)
433 }
434 if err := compareKeyMaterial(km2, km3); err == nil {
435 t.Fatal("expected provider to return new key material after update to underlying file")
436 }
437 }
438
View as plain text