1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package ctutil
16
17 import (
18 "encoding/base64"
19 "testing"
20
21 ct "github.com/google/certificate-transparency-go"
22 "github.com/google/certificate-transparency-go/testdata"
23 "github.com/google/certificate-transparency-go/tls"
24 "github.com/google/certificate-transparency-go/x509util"
25 )
26
27 func TestLeafHash(t *testing.T) {
28 tests := []struct {
29 desc string
30 chainPEM string
31 sct []byte
32 embedded bool
33 want string
34 }{
35 {
36 desc: "cert",
37 chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
38 sct: testdata.TestCertProof,
39 want: testdata.TestCertB64LeafHash,
40 },
41 {
42 desc: "precert",
43 chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
44 sct: testdata.TestPreCertProof,
45 want: testdata.TestPreCertB64LeafHash,
46 },
47 {
48 desc: "cert with embedded SCT",
49 chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
50 sct: testdata.TestPreCertProof,
51 embedded: true,
52 want: testdata.TestPreCertB64LeafHash,
53 },
54 }
55
56 for _, test := range tests {
57 t.Run(test.desc, func(t *testing.T) {
58
59 chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
60 if err != nil {
61 t.Fatalf("error parsing certificate chain: %s", err)
62 }
63
64
65 var sct ct.SignedCertificateTimestamp
66 if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
67 t.Fatalf("error tls-unmarshalling sct: %s", err)
68 }
69
70
71 wantSl, err := base64.StdEncoding.DecodeString(test.want)
72 if err != nil {
73 t.Fatalf("error base64-decoding leaf hash %q: %s", test.want, err)
74 }
75 var want [32]byte
76 copy(want[:], wantSl)
77
78 got, err := LeafHash(chain, &sct, test.embedded)
79 if got != want || err != nil {
80 t.Errorf("LeafHash(_,_) = %v, %v, want %v, nil", got, err, want)
81 }
82
83
84 gotB64, err := LeafHashB64(chain, &sct, test.embedded)
85 if gotB64 != test.want || err != nil {
86 t.Errorf("LeafHashB64(_,_) = %v, %v, want %v, nil", gotB64, err, test.want)
87 }
88 })
89 }
90 }
91
92 func TestLeafHashErrors(t *testing.T) {
93 tests := []struct {
94 desc string
95 chainPEM string
96 sct []byte
97 embedded bool
98 }{
99 {
100 desc: "empty chain",
101 chainPEM: "",
102 sct: testdata.TestCertProof,
103 },
104 {
105 desc: "nil SCT",
106 chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
107 sct: nil,
108 },
109 {
110 desc: "no SCTs embedded in cert, embedded true",
111 chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
112 sct: testdata.TestInvalidProof,
113 embedded: true,
114 },
115 {
116 desc: "cert contains embedded SCTs, but not the SCT provided",
117 chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
118 sct: testdata.TestInvalidProof,
119 embedded: true,
120 },
121 }
122
123 for _, test := range tests {
124 t.Run(test.desc, func(t *testing.T) {
125
126 chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
127 if err != nil {
128 t.Fatalf("error parsing certificate chain: %s", err)
129 }
130
131
132 var sct *ct.SignedCertificateTimestamp
133 if test.sct != nil {
134 sct = &ct.SignedCertificateTimestamp{}
135 if _, err = tls.Unmarshal(test.sct, sct); err != nil {
136 t.Fatalf("error tls-unmarshalling sct: %s", err)
137 }
138 }
139
140
141 got, err := LeafHash(chain, sct, test.embedded)
142 if got != emptyHash || err == nil {
143 t.Errorf("LeafHash(_,_) = %s, %v, want %v, error", got, err, emptyHash)
144 }
145
146
147 gotB64, err := LeafHashB64(chain, sct, test.embedded)
148 if gotB64 != "" || err == nil {
149 t.Errorf("LeafHashB64(_,_) = %s, %v, want \"\", error", gotB64, err)
150 }
151 })
152 }
153 }
154
155 func TestVerifySCT(t *testing.T) {
156 tests := []struct {
157 desc string
158 chainPEM string
159 sct []byte
160 embedded bool
161 wantErr bool
162 }{
163 {
164 desc: "cert",
165 chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
166 sct: testdata.TestCertProof,
167 },
168 {
169 desc: "precert",
170 chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
171 sct: testdata.TestPreCertProof,
172 },
173 {
174 desc: "invalid SCT",
175 chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
176 sct: testdata.TestCertProof,
177 wantErr: true,
178 },
179 {
180 desc: "cert with embedded SCT",
181 chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
182 sct: testdata.TestPreCertProof,
183 embedded: true,
184 },
185 {
186 desc: "cert with invalid embedded SCT",
187 chainPEM: testdata.TestInvalidEmbeddedCertPEM + testdata.CACertPEM,
188 sct: testdata.TestInvalidProof,
189 embedded: true,
190 wantErr: true,
191 },
192 }
193
194 for _, test := range tests {
195 t.Run(test.desc, func(t *testing.T) {
196
197 chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
198 if err != nil {
199 t.Fatalf("error parsing certificate chain: %s", err)
200 }
201
202
203 var sct ct.SignedCertificateTimestamp
204 if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
205 t.Fatalf("error tls-unmarshalling sct: %s", err)
206 }
207
208
209 pk, err := ct.PublicKeyFromB64(testdata.LogPublicKeyB64)
210 if err != nil {
211 t.Errorf("error parsing public key: %s", err)
212 }
213
214 err = VerifySCT(pk, chain, &sct, test.embedded)
215 if gotErr := err != nil; gotErr != test.wantErr {
216 t.Errorf("VerifySCT(_,_,_, %t) = %v, want error? %t", test.embedded, err, test.wantErr)
217 }
218 })
219 }
220 }
221
222 func TestVerifySCTWithVerifier(t *testing.T) {
223
224 pk, err := ct.PublicKeyFromB64(testdata.LogPublicKeyB64)
225 if err != nil {
226 t.Errorf("error parsing public key: %s", err)
227 }
228
229
230 sv, err := ct.NewSignatureVerifier(pk)
231 if err != nil {
232 t.Errorf("couldn't create signature verifier: %s", err)
233 }
234
235 tests := []struct {
236 desc string
237 sv *ct.SignatureVerifier
238 chainPEM string
239 sct []byte
240 embedded bool
241 wantErr bool
242 }{
243 {
244 desc: "nil signature verifier",
245 sv: nil,
246 chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
247 sct: testdata.TestCertProof,
248 wantErr: true,
249 },
250 {
251 desc: "cert",
252 sv: sv,
253 chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
254 sct: testdata.TestCertProof,
255 },
256 {
257 desc: "precert",
258 sv: sv,
259 chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
260 sct: testdata.TestPreCertProof,
261 },
262 {
263 desc: "invalid SCT",
264 sv: sv,
265 chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
266 sct: testdata.TestCertProof,
267 wantErr: true,
268 },
269 {
270 desc: "cert with embedded SCT",
271 sv: sv,
272 chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
273 sct: testdata.TestPreCertProof,
274 embedded: true,
275 },
276 {
277 desc: "cert with invalid embedded SCT",
278 sv: sv,
279 chainPEM: testdata.TestInvalidEmbeddedCertPEM + testdata.CACertPEM,
280 sct: testdata.TestInvalidProof,
281 embedded: true,
282 wantErr: true,
283 },
284 }
285
286 for _, test := range tests {
287 t.Run(test.desc, func(t *testing.T) {
288
289 chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
290 if err != nil {
291 t.Fatalf("error parsing certificate chain: %s", err)
292 }
293
294
295 var sct ct.SignedCertificateTimestamp
296 if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
297 t.Fatalf("error tls-unmarshalling sct: %s", err)
298 }
299
300
301 err = VerifySCTWithVerifier(test.sv, chain, &sct, test.embedded)
302 if gotErr := err != nil; gotErr != test.wantErr {
303 t.Errorf("VerifySCT(_,_,_, %t) = %v, want error? %t", test.embedded, err, test.wantErr)
304 }
305 })
306 }
307 }
308
309 func TestContainsSCT(t *testing.T) {
310 tests := []struct {
311 desc string
312 certPEM string
313 sct []byte
314 want bool
315 }{
316 {
317 desc: "cert doesn't contain any SCTs",
318 certPEM: testdata.TestCertPEM,
319 sct: testdata.TestPreCertProof,
320 want: false,
321 },
322 {
323 desc: "cert contains SCT but not specified SCT",
324 certPEM: testdata.TestEmbeddedCertPEM,
325 sct: testdata.TestInvalidProof,
326 want: false,
327 },
328 {
329 desc: "cert contains SCT",
330 certPEM: testdata.TestEmbeddedCertPEM,
331 sct: testdata.TestPreCertProof,
332 want: true,
333 },
334 }
335
336 for _, test := range tests {
337 t.Run(test.desc, func(t *testing.T) {
338
339 cert, err := x509util.CertificateFromPEM([]byte(test.certPEM))
340 if err != nil {
341 t.Fatalf("error parsing certificate: %s", err)
342 }
343
344
345 var sct ct.SignedCertificateTimestamp
346 if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
347 t.Fatalf("error tls-unmarshalling sct: %s", err)
348 }
349
350
351 got, err := ContainsSCT(cert, &sct)
352 if err != nil {
353 t.Fatalf("ContainsSCT(_,_) = false, %s, want no error", err)
354 }
355
356 if got != test.want {
357 t.Errorf("ContainsSCT(_,_) = %t, nil, want %t, nil", got, test.want)
358 }
359 })
360 }
361 }
362
View as plain text