1
18
19 package certprovider
20
21 import (
22 "context"
23 "crypto/tls"
24 "crypto/x509"
25 "errors"
26 "fmt"
27 "os"
28 "testing"
29 "time"
30
31 "google.golang.org/grpc/internal/grpctest"
32 "google.golang.org/grpc/internal/testutils"
33 "google.golang.org/grpc/testdata"
34 )
35
// Names under which init registers the two fake provider builders, and the
// config string most tests hand to them.
const (
	fakeProvider1Name = "fake-certificate-provider-1"
	fakeProvider2Name = "fake-certificate-provider-2"
	fakeConfig        = "my fake config"
)

// Test timeouts: defaultTestTimeout bounds waits for events that are expected
// to happen; defaultTestShortTimeout bounds waits for events that are expected
// NOT to happen.
const (
	defaultTestShortTimeout = 10 * time.Millisecond
	defaultTestTimeout      = 5 * time.Second
)
43
44 var fpb1, fpb2 *fakeProviderBuilder
45
46 func init() {
47 fpb1 = &fakeProviderBuilder{
48 name: fakeProvider1Name,
49 providerChan: testutils.NewChannel(),
50 }
51 fpb2 = &fakeProviderBuilder{
52 name: fakeProvider2Name,
53 providerChan: testutils.NewChannel(),
54 }
55 Register(fpb1)
56 Register(fpb2)
57 }
58
59 type s struct {
60 grpctest.Tester
61 }
62
63 func Test(t *testing.T) {
64 grpctest.RunSubTests(t, s{})
65 }
66
67
68
69 type fakeProviderBuilder struct {
70 name string
71 providerChan *testutils.Channel
72 }
73
74 func (b *fakeProviderBuilder) ParseConfig(config any) (*BuildableConfig, error) {
75 s, ok := config.(string)
76 if !ok {
77 return nil, fmt.Errorf("providerBuilder %s received config of type %T, want string", b.name, config)
78 }
79 return NewBuildableConfig(b.name, []byte(s), func(BuildOptions) Provider {
80 fp := &fakeProvider{
81 Distributor: NewDistributor(),
82 config: s,
83 }
84 b.providerChan.Send(fp)
85 return fp
86 }), nil
87 }
88
89 func (b *fakeProviderBuilder) Name() string {
90 return b.name
91 }
92
93
94
95 type fakeProvider struct {
96 *Distributor
97 config string
98 }
99
100 func (p *fakeProvider) Start(BuildOptions) Provider {
101
102
103 return p
104 }
105
106
107
108 func (p *fakeProvider) newKeyMaterial(km *KeyMaterial, err error) {
109 p.Distributor.Set(km, err)
110 }
111
112
113 func (p *fakeProvider) Close() {
114 p.Distributor.Stop()
115 }
116
117
118
119 func loadKeyMaterials(t *testing.T, cert, key, ca string) *KeyMaterial {
120 t.Helper()
121
122 certs, err := tls.LoadX509KeyPair(testdata.Path(cert), testdata.Path(key))
123 if err != nil {
124 t.Fatalf("Failed to load keyPair: %v", err)
125 }
126
127 pemData, err := os.ReadFile(testdata.Path(ca))
128 if err != nil {
129 t.Fatal(err)
130 }
131 roots := x509.NewCertPool()
132 roots.AppendCertsFromPEM(pemData)
133 return &KeyMaterial{Certs: []tls.Certificate{certs}, Roots: roots}
134 }
135
136
137
138
139 type kmReader interface {
140 KeyMaterial(context.Context) (*KeyMaterial, error)
141 }
142
143
144
145 func readAndVerifyKeyMaterial(ctx context.Context, kmr kmReader, wantKM *KeyMaterial) error {
146 gotKM, err := kmr.KeyMaterial(ctx)
147 if err != nil {
148 return fmt.Errorf("KeyMaterial(ctx) failed: %w", err)
149 }
150 return compareKeyMaterial(gotKM, wantKM)
151 }
152
153 func compareKeyMaterial(got, want *KeyMaterial) error {
154 if len(got.Certs) != len(want.Certs) {
155 return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
156 }
157 for i := 0; i < len(got.Certs); i++ {
158 if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) {
159 return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
160 }
161 }
162
163 if gotR, wantR := got.Roots, want.Roots; !gotR.Equal(wantR) {
164 return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR)
165 }
166
167 return nil
168 }
169
170 func createProvider(t *testing.T, name, config string, opts BuildOptions) Provider {
171 t.Helper()
172 prov, err := GetProvider(name, config, opts)
173 if err != nil {
174 t.Fatalf("GetProvider(%s, %s, %v) failed: %v", name, config, opts, err)
175 }
176 return prov
177 }
178
179
180
181 func (s) TestStoreSingleProvider(t *testing.T) {
182 prov := createProvider(t, fakeProvider1Name, fakeConfig, BuildOptions{CertName: "default"})
183 defer prov.Close()
184
185
186
187 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
188 defer cancel()
189 p, err := fpb1.providerChan.Receive(ctx)
190 if err != nil {
191 t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
192 }
193 fakeProv := p.(*fakeProvider)
194
195
196
197
198 sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
199 defer sCancel()
200 if err := readAndVerifyKeyMaterial(sCtx, prov, nil); !errors.Is(err, context.DeadlineExceeded) {
201 t.Fatal(err)
202 }
203
204
205
206 testKM1 := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
207 fakeProv.newKeyMaterial(testKM1, nil)
208 if err := readAndVerifyKeyMaterial(ctx, prov, testKM1); err != nil {
209 t.Fatal(err)
210 }
211
212
213
214 testKM2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")
215 fakeProv.newKeyMaterial(testKM2, nil)
216 if err := readAndVerifyKeyMaterial(ctx, prov, testKM2); err != nil {
217 t.Fatal(err)
218 }
219 }
220
221
222
223
224
225 func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) {
226
227
228 optsFoo := BuildOptions{CertName: "foo"}
229 provFoo1 := createProvider(t, fakeProvider1Name, fakeConfig, optsFoo)
230 provFoo2 := createProvider(t, fakeProvider1Name, fakeConfig, optsFoo)
231 defer func() {
232 provFoo1.Close()
233 provFoo2.Close()
234 }()
235
236
237
238 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
239 defer cancel()
240 p, err := fpb1.providerChan.Receive(ctx)
241 if err != nil {
242 t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
243 }
244 fakeProvFoo := p.(*fakeProvider)
245
246
247
248 sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
249 defer sCancel()
250 if _, err := fpb1.providerChan.Receive(sCtx); !errors.Is(err, context.DeadlineExceeded) {
251 t.Fatalf("A second provider created when expected to be shared by the store")
252 }
253
254 optsBar := BuildOptions{CertName: "bar"}
255 provBar1 := createProvider(t, fakeProvider1Name, fakeConfig, optsBar)
256 defer provBar1.Close()
257
258
259 p, err = fpb1.providerChan.Receive(ctx)
260 if err != nil {
261 t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
262 }
263 fakeProvBar := p.(*fakeProvider)
264
265
266
267 fooKM := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
268 fakeProvFoo.newKeyMaterial(fooKM, nil)
269 if err := readAndVerifyKeyMaterial(ctx, provFoo1, fooKM); err != nil {
270 t.Fatal(err)
271 }
272 if err := readAndVerifyKeyMaterial(ctx, provFoo2, fooKM); err != nil {
273 t.Fatal(err)
274 }
275 sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
276 defer sCancel()
277 if err := readAndVerifyKeyMaterial(sCtx, provBar1, nil); !errors.Is(err, context.DeadlineExceeded) {
278 t.Fatal(err)
279 }
280
281
282
283 barKM := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")
284 fakeProvBar.newKeyMaterial(barKM, nil)
285 if err := readAndVerifyKeyMaterial(ctx, provBar1, barKM); err != nil {
286 t.Fatal(err)
287 }
288
289
290 if err := readAndVerifyKeyMaterial(ctx, provFoo1, fooKM); err != nil {
291 t.Fatal(err)
292 }
293 }
294
295
296
297
298
299 func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) {
300
301 opts := BuildOptions{CertName: "foo"}
302 cfg1 := fakeConfig + "1111"
303 cfg2 := fakeConfig + "2222"
304
305 prov1 := createProvider(t, fakeProvider1Name, cfg1, opts)
306 defer prov1.Close()
307
308
309 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
310 defer cancel()
311 p1, err := fpb1.providerChan.Receive(ctx)
312 if err != nil {
313 t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
314 }
315 fakeProv1 := p1.(*fakeProvider)
316
317 prov2 := createProvider(t, fakeProvider1Name, cfg2, opts)
318 defer prov2.Close()
319
320 p2, err := fpb1.providerChan.Receive(ctx)
321 if err != nil {
322 t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
323 }
324 fakeProv2 := p2.(*fakeProvider)
325
326
327
328 km1 := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
329 fakeProv1.newKeyMaterial(km1, nil)
330 fakeProv2.newKeyMaterial(km1, nil)
331 if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil {
332 t.Fatal(err)
333 }
334 if err := readAndVerifyKeyMaterial(ctx, prov2, km1); err != nil {
335 t.Fatal(err)
336 }
337
338
339
340
341 km2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")
342 fakeProv2.newKeyMaterial(km2, nil)
343 if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil {
344 t.Fatal(err)
345 }
346 if err := readAndVerifyKeyMaterial(ctx, prov2, km2); err != nil {
347 t.Fatal(err)
348 }
349
350
351 prov1.Close()
352 if err := readAndVerifyKeyMaterial(ctx, prov2, km2); err != nil {
353 t.Fatal(err)
354 }
355 }
356
357
358
359 func (s) TestStoreMultipleProviders(t *testing.T) {
360 opts := BuildOptions{CertName: "foo"}
361 prov1 := createProvider(t, fakeProvider1Name, fakeConfig, opts)
362 defer prov1.Close()
363
364
365 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
366 defer cancel()
367 p1, err := fpb1.providerChan.Receive(ctx)
368 if err != nil {
369 t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
370 }
371 fakeProv1 := p1.(*fakeProvider)
372
373 prov2 := createProvider(t, fakeProvider2Name, fakeConfig, opts)
374 defer prov2.Close()
375
376 p2, err := fpb2.providerChan.Receive(ctx)
377 if err != nil {
378 t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider2Name)
379 }
380 fakeProv2 := p2.(*fakeProvider)
381
382
383
384 km1 := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
385 fakeProv1.newKeyMaterial(km1, nil)
386 km2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")
387 fakeProv2.newKeyMaterial(km2, nil)
388 if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil {
389 t.Fatal(err)
390 }
391 if err := readAndVerifyKeyMaterial(ctx, prov2, km2); err != nil {
392 t.Fatal(err)
393 }
394
395
396 prov1.Close()
397 if err := readAndVerifyKeyMaterial(ctx, prov2, km2); err != nil {
398 t.Fatal(err)
399 }
400 }
401
View as plain text