1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package submission
16
17 import (
18 "context"
19 "encoding/json"
20 "errors"
21 "fmt"
22 "regexp"
23 "testing"
24 "time"
25
26 "github.com/google/certificate-transparency-go/client"
27 "github.com/google/certificate-transparency-go/ctpolicy"
28 "github.com/google/certificate-transparency-go/loglist3"
29 "github.com/google/certificate-transparency-go/schedule"
30 "github.com/google/certificate-transparency-go/testdata"
31 "github.com/google/certificate-transparency-go/x509"
32 "github.com/google/certificate-transparency-go/x509util"
33 "github.com/google/go-cmp/cmp"
34 "github.com/google/go-cmp/cmp/cmpopts"
35 "github.com/google/trillian/monitoring"
36 "k8s.io/klog/v2"
37 )
38
39 func newLocalStubLogClient(log *loglist3.Log) (client.AddLogClient, error) {
40 return newRootedStubLogClient(log, RootsCerts)
41 }
42
43 func ExampleDistributor() {
44 ctx, cancel := context.WithCancel(context.Background())
45 defer cancel()
46
47 d, err := NewDistributor(sampleValidLogList(), buildStubCTPolicy(1), newLocalStubLogClient, monitoring.InertMetricFactory{})
48 if err != nil {
49 panic(err)
50 }
51
52
53
54 refresh := make(chan struct{})
55 go schedule.Every(ctx, time.Hour, func(ctx context.Context) {
56 if errs := d.RefreshRoots(ctx); len(errs) > 0 {
57 klog.Error(errs)
58 }
59 refresh <- struct{}{}
60 })
61
62 select {
63 case <-refresh:
64 break
65 case <-ctx.Done():
66 panic("Context expired")
67 }
68
69 scts, err := d.AddPreChain(ctx, pemFileToDERChain("../trillian/testdata/subleaf-pre.chain"), false )
70 if err != nil {
71 panic(err)
72 }
73 for _, sct := range scts {
74 fmt.Printf("%s\n", *sct)
75 }
76
77
78 }
79
80 var (
81 RootsCerts = map[string][]rootInfo{
82 "https://ct.googleapis.com/aviator/": {
83 rootInfo{filename: "../trillian/testdata/fake-ca-1.cert"},
84 rootInfo{filename: "testdata/some.cert"},
85 },
86 "https://ct.googleapis.com/rocketeer/": {
87 rootInfo{filename: "../trillian/testdata/fake-ca.cert"},
88 rootInfo{filename: "../trillian/testdata/fake-ca-1.cert"},
89 rootInfo{filename: "testdata/some.cert"},
90 rootInfo{filename: "testdata/another.cert"},
91 },
92 "https://ct.googleapis.com/icarus/": {
93 rootInfo{raw: []byte("invalid000")},
94 rootInfo{filename: "testdata/another.cert"},
95 },
96 "uncollectable-roots/log/": {
97 rootInfo{raw: []byte("invalid")},
98 },
99 }
100 )
101
102
103 func newNoLogClient(_ *loglist3.Log) (client.AddLogClient, error) {
104 return nil, errors.New("bad log-client builder")
105 }
106
107 func sampleLogList() *loglist3.LogList {
108 var ll loglist3.LogList
109 if err := json.Unmarshal([]byte(testdata.SampleLogList3), &ll); err != nil {
110 panic(fmt.Errorf("unable to Unmarshal testdata.SampleLogList3: %v", err))
111 }
112 return &ll
113 }
114
115 func sampleValidLogList() *loglist3.LogList {
116 ll := sampleLogList()
117
118 inval := 2
119 ll.Operators[0].Logs = append(ll.Operators[0].Logs[:inval], ll.Operators[0].Logs[inval+1:]...)
120 return ll
121 }
122
123 func sampleUncollectableLogList() *loglist3.LogList {
124 ll := sampleValidLogList()
125
126 ll.Operators[0].Logs = append(ll.Operators[0].Logs, &loglist3.Log{
127 Description: "Does not return roots", Key: []byte("VW5jb2xsZWN0YWJsZUxvZ0xpc3Q="),
128 URL: "uncollectable-roots/log/",
129 DNS: "uncollectable.ct.googleapis.com",
130 MMD: 123,
131 State: &loglist3.LogStates{Usable: &loglist3.LogState{}},
132 })
133 return ll
134 }
135
136 func TestNewDistributorLogClients(t *testing.T) {
137 testCases := []struct {
138 name string
139 ll *loglist3.LogList
140 lcBuilder LogClientBuilder
141 errRegexp *regexp.Regexp
142 }{
143 {
144 name: "ValidLogClients",
145 ll: sampleValidLogList(),
146 lcBuilder: newEmptyStubLogClient,
147 },
148 {
149 name: "NoLogClients",
150 ll: sampleValidLogList(),
151 lcBuilder: newNoLogClient,
152 errRegexp: regexp.MustCompile("failed to create log client"),
153 },
154 {
155 name: "NoLogClientsEmptyLogList",
156 ll: &loglist3.LogList{},
157 lcBuilder: newNoLogClient,
158 },
159 }
160
161 for _, tc := range testCases {
162 t.Run(tc.name, func(t *testing.T) {
163 _, err := NewDistributor(tc.ll, ctpolicy.ChromeCTPolicy{}, tc.lcBuilder, monitoring.InertMetricFactory{})
164 if gotErr, wantErr := err != nil, tc.errRegexp != nil; gotErr != wantErr {
165 var unwantedErr string
166 if gotErr {
167 unwantedErr = fmt.Sprintf(" (%q)", err)
168 }
169 t.Errorf("Got error = %v%s, expected error = %v", gotErr, unwantedErr, wantErr)
170 } else if tc.errRegexp != nil && !tc.errRegexp.MatchString(err.Error()) {
171 t.Errorf("Error %q did not match expected regexp %q", err, tc.errRegexp)
172 }
173 })
174 }
175 }
176
177 func TestNewDistributorRootPools(t *testing.T) {
178 testCases := []struct {
179 name string
180 ll *loglist3.LogList
181 rootNum map[string]int
182 wantErrs int
183 }{
184 {
185 name: "InactiveZeroRoots",
186 ll: sampleValidLogList(),
187
188 rootNum: map[string]int{"https://ct.googleapis.com/aviator/": 0, "https://ct.googleapis.com/rocketeer/": 4, "https://ct.googleapis.com/icarus/": 1},
189 wantErrs: 1,
190 },
191 {
192 name: "CouldNotCollect",
193 ll: sampleUncollectableLogList(),
194
195 rootNum: map[string]int{"https://ct.googleapis.com/aviator/": 0, "https://ct.googleapis.com/rocketeer/": 4, "https://ct.googleapis.com/icarus/": 1, "uncollectable-roots/log/": 0},
196 wantErrs: 2,
197 },
198 }
199
200 for _, tc := range testCases {
201 t.Run(tc.name, func(t *testing.T) {
202 ctx := context.Background()
203 dist, _ := NewDistributor(tc.ll, ctpolicy.ChromeCTPolicy{}, newLocalStubLogClient, monitoring.InertMetricFactory{})
204
205 if errs := dist.RefreshRoots(ctx); len(errs) != tc.wantErrs {
206 t.Errorf("dist.RefreshRoots() = %v, want %d errors", errs, tc.wantErrs)
207 }
208
209 for logURL, wantNum := range tc.rootNum {
210 gotNum := 0
211 if roots, ok := dist.logRoots[logURL]; ok {
212 gotNum = len(roots.RawCertificates())
213 }
214 if wantNum != gotNum {
215 t.Errorf("Expected %d root(s) for Log %s, got %d", wantNum, logURL, gotNum)
216 }
217 }
218 })
219 }
220 }
221
222 func pemFileToDERChain(filename string) [][]byte {
223 if len(filename) == 0 {
224 return nil
225 }
226 rawChain, err := x509util.ReadPossiblePEMFile(filename, "CERTIFICATE")
227 if err != nil {
228 panic(err)
229 }
230 return rawChain
231 }
232
233
// stubCTPolicy is a test CTPolicy whose LogsByGroup returns only the group
// produced by ctpolicy.BaseGroupFor.
type stubCTPolicy struct {
	// baseNum is passed to ctpolicy.BaseGroupFor as its count argument
	// (presumably the required number of inclusions — confirm).
	baseNum int
}

// buildStubCTPolicy returns a stubCTPolicy configured with baseNum n.
func buildStubCTPolicy(n int) stubCTPolicy {
	var p stubCTPolicy
	p.baseNum = n
	return p
}
242
243 func (stubP stubCTPolicy) LogsByGroup(cert *x509.Certificate, approved *loglist3.LogList) (ctpolicy.LogPolicyData, error) {
244 baseGroup, err := ctpolicy.BaseGroupFor(approved, stubP.baseNum)
245 groups := ctpolicy.LogPolicyData{baseGroup.Name: baseGroup}
246 return groups, err
247 }
248
249 func (stubP stubCTPolicy) Name() string {
250 return "stub"
251 }
252
253 func TestDistributorAddChain(t *testing.T) {
254 testCases := []struct {
255 name string
256 ll *loglist3.LogList
257 plc ctpolicy.CTPolicy
258 pemChainFile string
259 getRoots bool
260 scts []*AssignedSCT
261 wantErr bool
262 }{
263 {
264 name: "MalformedChainRequest with log roots available",
265 ll: sampleValidLogList(),
266 plc: ctpolicy.ChromeCTPolicy{},
267 pemChainFile: "../trillian/testdata/subleaf.misordered.chain",
268 getRoots: true,
269 scts: nil,
270 wantErr: true,
271 },
272 {
273 name: "MalformedChainRequest without log roots available",
274 ll: sampleValidLogList(),
275 plc: ctpolicy.ChromeCTPolicy{},
276 pemChainFile: "../trillian/testdata/subleaf.misordered.chain",
277 getRoots: false,
278 scts: nil,
279 wantErr: true,
280 },
281 {
282 name: "CallBeforeInit",
283 ll: sampleValidLogList(),
284 plc: ctpolicy.ChromeCTPolicy{},
285 pemChainFile: "",
286 scts: nil,
287 wantErr: true,
288 },
289 {
290 name: "InsufficientSCTsForPolicy",
291 ll: sampleValidLogList(),
292 plc: ctpolicy.AppleCTPolicy{},
293 pemChainFile: "../trillian/testdata/subleaf.chain",
294 getRoots: true,
295 scts: []*AssignedSCT{},
296 wantErr: true,
297 },
298 {
299 name: "FullChain1Policy",
300 ll: sampleValidLogList(),
301 plc: buildStubCTPolicy(1),
302 pemChainFile: "../trillian/testdata/subleaf.chain",
303 getRoots: true,
304 scts: []*AssignedSCT{
305 {
306 LogURL: "https://ct.googleapis.com/rocketeer/",
307 SCT: testSCT("https://ct.googleapis.com/rocketeer/"),
308 },
309 },
310 wantErr: false,
311 },
312
313 }
314
315 for _, tc := range testCases {
316 t.Run(tc.name, func(t *testing.T) {
317 dist, _ := NewDistributor(tc.ll, tc.plc, newLocalStubLogClient, monitoring.InertMetricFactory{})
318 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
319 defer cancel()
320
321 if tc.getRoots {
322 if errs := dist.RefreshRoots(ctx); len(errs) != 1 || errs["https://ct.googleapis.com/icarus/"] == nil {
323
324 t.Fatalf("dist.RefreshRoots() = %v, want 1 error for 'https://ct.googleapis.com/icarus/'", errs)
325 }
326 }
327
328 scts, err := dist.AddChain(context.Background(), pemFileToDERChain(tc.pemChainFile), false )
329
330 if gotErr := err != nil; gotErr != tc.wantErr {
331 t.Fatalf("dist.AddChain(from %q) = (_, error: %v), want err? %t", tc.pemChainFile, err, tc.wantErr)
332 } else if gotErr {
333 return
334 }
335
336 if got, want := len(scts), len(tc.scts); got != want {
337 t.Errorf("dist.AddChain(from %q) = %d SCTs, want %d SCTs", tc.pemChainFile, got, want)
338 }
339 if diff := cmp.Diff(scts, tc.scts, cmpopts.SortSlices(func(x, y *AssignedSCT) bool {
340 return x.LogURL < y.LogURL
341 })); diff != "" {
342 t.Errorf("dist.AddChain(from %q): diff -want +got\n%s", tc.pemChainFile, diff)
343 }
344 })
345 }
346 }
347
348
349 func TestDistributorAddPreChain(t *testing.T) {
350 testCases := []struct {
351 name string
352 ll *loglist3.LogList
353 plc ctpolicy.CTPolicy
354 pemChainFile string
355 getRoots bool
356 scts []*AssignedSCT
357 wantErr bool
358 }{
359 {
360 name: "MalformedChainRequest with log roots available",
361 ll: sampleValidLogList(),
362 plc: ctpolicy.ChromeCTPolicy{},
363 pemChainFile: "../trillian/testdata/subleaf-pre.misordered.chain",
364 getRoots: true,
365 scts: nil,
366 wantErr: true,
367 },
368 {
369 name: "MalformedChainRequest without log roots available",
370 ll: sampleValidLogList(),
371 plc: ctpolicy.ChromeCTPolicy{},
372 pemChainFile: "../trillian/testdata/subleaf-pre.misordered.chain",
373 getRoots: false,
374 scts: nil,
375 wantErr: true,
376 },
377 {
378 name: "CallBeforeInit",
379 ll: sampleValidLogList(),
380 plc: ctpolicy.ChromeCTPolicy{},
381 pemChainFile: "",
382 scts: nil,
383 wantErr: true,
384 },
385 {
386 name: "InsufficientSCTsForPolicy",
387 ll: sampleValidLogList(),
388 plc: ctpolicy.AppleCTPolicy{},
389 pemChainFile: "../trillian/testdata/subleaf-pre.chain",
390 getRoots: true,
391 scts: []*AssignedSCT{},
392 wantErr: true,
393 },
394 {
395 name: "FullChain1Policy",
396 ll: sampleValidLogList(),
397 plc: buildStubCTPolicy(1),
398 pemChainFile: "../trillian/testdata/subleaf-pre.chain",
399 getRoots: true,
400 scts: []*AssignedSCT{
401 {
402 LogURL: "https://ct.googleapis.com/rocketeer/",
403 SCT: testSCT("https://ct.googleapis.com/rocketeer/"),
404 },
405 },
406 wantErr: false,
407 },
408
409 }
410
411 for _, tc := range testCases {
412 t.Run(tc.name, func(t *testing.T) {
413 dist, _ := NewDistributor(tc.ll, tc.plc, newLocalStubLogClient, monitoring.InertMetricFactory{})
414 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
415 defer cancel()
416
417 if tc.getRoots {
418 if errs := dist.RefreshRoots(ctx); len(errs) != 1 || errs["https://ct.googleapis.com/icarus/"] == nil {
419
420 t.Fatalf("dist.RefreshRoots() = %v, want 1 error for 'https://ct.googleapis.com/icarus/'", errs)
421 }
422 }
423
424 scts, err := dist.AddPreChain(context.Background(), pemFileToDERChain(tc.pemChainFile), true )
425
426 if gotErr := err != nil; gotErr != tc.wantErr {
427 t.Fatalf("dist.AddPreChain(from %q) = (_, error: %v), want err? %t", tc.pemChainFile, err, tc.wantErr)
428 } else if gotErr {
429 return
430 }
431
432 if got, want := len(scts), len(tc.scts); got != want {
433 t.Errorf("dist.AddPreChain(from %q) = %d SCTs, want %d SCTs", tc.pemChainFile, got, want)
434 }
435 if diff := cmp.Diff(scts, tc.scts, cmpopts.SortSlices(func(x, y *AssignedSCT) bool {
436 return x.LogURL < y.LogURL
437 })); diff != "" {
438 t.Errorf("dist.AddPreChain(from %q): diff -want +got\n%s", tc.pemChainFile, diff)
439 }
440 })
441 }
442 }
443
444 func TestDistributorAddTypeMismatch(t *testing.T) {
445 testCases := []struct {
446 name string
447 asPreChain bool
448 pemChainFile string
449 scts []*AssignedSCT
450 wantErr bool
451 }{
452 {
453 name: "FullChain1PolicyCertToPreAdd",
454 asPreChain: true,
455 pemChainFile: "../trillian/testdata/subleaf.chain",
456 scts: nil,
457 wantErr: true,
458 },
459 {
460 name: "FullChain1PolicyPreCertToAdd",
461 asPreChain: false,
462 pemChainFile: "../trillian/testdata/subleaf-pre.chain",
463 scts: nil,
464 wantErr: true,
465 },
466 }
467
468 for _, tc := range testCases {
469 t.Run(tc.name, func(t *testing.T) {
470 dist, _ := NewDistributor(sampleValidLogList(), buildStubCTPolicy(1), newLocalStubLogClient, monitoring.InertMetricFactory{})
471 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
472 defer cancel()
473
474 if errs := dist.RefreshRoots(ctx); len(errs) != 1 || errs["https://ct.googleapis.com/icarus/"] == nil {
475
476 t.Fatalf("dist.RefreshRoots() = %v, want 1 error for 'https://ct.googleapis.com/icarus/'", errs)
477 }
478
479 var scts []*AssignedSCT
480 var err error
481 if tc.asPreChain {
482 scts, err = dist.AddPreChain(context.Background(), pemFileToDERChain(tc.pemChainFile), false )
483 } else {
484 scts, err = dist.AddChain(context.Background(), pemFileToDERChain(tc.pemChainFile), false )
485 }
486
487 pre := ""
488 if tc.asPreChain {
489 pre = "Pre"
490 }
491 if gotErr := err != nil; gotErr != tc.wantErr {
492 t.Fatalf("dist.Add%sChain(from %q) = (_, error: %v), want err? %t", pre, tc.pemChainFile, err, tc.wantErr)
493 } else if gotErr {
494 return
495 }
496
497 if got, want := len(scts), len(tc.scts); got != want {
498 t.Errorf("dist.Add%sChain(from %q) = %d SCTs, want %d SCTs", pre, tc.pemChainFile, got, want)
499 }
500 if diff := cmp.Diff(scts, tc.scts, cmpopts.SortSlices(func(x, y *AssignedSCT) bool {
501 return x.LogURL < y.LogURL
502 })); diff != "" {
503 t.Errorf("dist.Add%sChain(from %q): diff -want +got\n%s", pre, tc.pemChainFile, diff)
504 }
505 })
506 }
507 }
508
View as plain text