1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package fixchain
16
17 import (
18 "bytes"
19 "context"
20 "encoding/json"
21 "fmt"
22 "log"
23 "strings"
24 "testing"
25
26 "github.com/google/certificate-transparency-go/x509"
27 "github.com/google/certificate-transparency-go/x509/pkix"
28 "github.com/google/certificate-transparency-go/x509util"
29 )
30
// nilLimiter is a limiter whose Wait never blocks and never fails.
type nilLimiter struct{}

// Wait returns nil immediately; ctx is not consulted.
func (*nilLimiter) Wait(ctx context.Context) error {
	return nil
}

// newNilLimiter returns a pointer to a fresh nilLimiter.
func newNilLimiter() *nilLimiter {
	return new(nilLimiter)
}
40
// bytesReadCloser embeds a *bytes.Reader and adds a Close method, so a value
// exposes both the reader's promoted methods and Close.
type bytesReadCloser struct {
	*bytes.Reader
}

// Close always returns nil and leaves the embedded reader untouched.
func (bytesReadCloser) Close() error {
	return nil
}
48
49
50
51
52 func GetTestCertificateFromPEM(t *testing.T, pemBytes string) *x509.Certificate {
53 cert, err := x509util.CertificateFromPEM([]byte(pemBytes))
54 if x509.IsFatal(err) {
55 t.Errorf("Failed to parse leaf: %s", err)
56 }
57 return cert
58 }
59
60 func nameToKey(name *pkix.Name) string {
61 return fmt.Sprintf("%s/%s/%s/%s", strings.Join(name.Country, ","),
62 strings.Join(name.Organization, ","),
63 strings.Join(name.OrganizationalUnit, ","), name.CommonName)
64 }
65
66 func chainToDebugString(chain []*x509.Certificate) string {
67 var chainStr string
68 for _, cert := range chain {
69 if len(chainStr) > 0 {
70 chainStr += " -> "
71 }
72 chainStr += nameToKey(&cert.Subject)
73 }
74 return chainStr
75 }
76
77 func matchTestChainList(t *testing.T, i int, want [][]string, got [][]*x509.Certificate) {
78 if len(want) != len(got) {
79 t.Errorf("#%d: Wanted %d chains, got back %d", i, len(want), len(got))
80 }
81
82 seen := make([]bool, len(want))
83 NextOutputChain:
84 for _, chain := range got {
85 TryNextExpected:
86 for j, expChain := range want {
87 if seen[j] {
88 continue
89 }
90 if len(chain) != len(expChain) {
91 continue
92 }
93 for k, cert := range chain {
94 if !strings.Contains(nameToKey(&cert.Subject), expChain[k]) {
95 continue TryNextExpected
96 }
97 }
98 seen[j] = true
99 continue NextOutputChain
100 }
101 t.Errorf("#%d: No expected chain matched output chain %s", i,
102 chainToDebugString(chain))
103 }
104
105 for j, val := range seen {
106 if !val {
107 t.Errorf("#%d: No output chain matched expected chain %s", i,
108 strings.Join(want[j], " -> "))
109 }
110 }
111 }
112
113 func matchTestErrorList(t *testing.T, i int, want []errorType, got []*FixError) {
114 if len(want) != len(got) {
115 t.Errorf("#%d: Wanted %d errors, got back %d", i, len(want), len(got))
116 }
117
118 seen := make([]bool, len(want))
119 NextOutputErr:
120 for _, err := range got {
121 for j, expErr := range want {
122 if seen[j] {
123 continue
124 }
125 if err.Type == expErr {
126 seen[j] = true
127 continue NextOutputErr
128 }
129 }
130 t.Errorf("#%d: No expected error matched output error %s", i, err.TypeString())
131 }
132
133 for j, val := range seen {
134 if !val {
135 t.Errorf("#%d: No output error matched expected error %s", i,
136 FixError{Type: want[j]}.TypeString())
137 }
138 }
139 }
140
141 func matchTestChain(t *testing.T, i int, want []string, got []*x509.Certificate) {
142 if len(got) != len(want) {
143 t.Errorf("#%d: Expected a chain of length %d, got one of length %d",
144 i, len(want), len(got))
145 return
146 }
147
148 if want != nil {
149 for j, cert := range got {
150 if !strings.Contains(nameToKey(&cert.Subject), want[j]) {
151 t.Errorf("#%d: Chain does not match expected chain at position %d", i, j)
152 }
153 }
154 }
155 }
156
157 func matchTestRoots(t *testing.T, i int, want []string, got *x509.CertPool) {
158 if len(got.Subjects()) != len(want) {
159 t.Errorf("#%d: received %d roots, expected %d", i, len(got.Subjects()), len(want))
160 }
161 testRoots := extractTestChain(t, i, want)
162 seen := make([]bool, len(testRoots))
163 NextRoot:
164 for _, rootSub := range got.Subjects() {
165 for j, testRoot := range testRoots {
166 if seen[j] {
167 continue
168 }
169 if bytes.Equal(rootSub, testRoot.RawSubject) {
170 seen[j] = true
171 continue NextRoot
172 }
173 }
174 t.Errorf("#%d: No expected root matches one of the output roots", i)
175 }
176
177 for j, val := range seen {
178 if !val {
179 t.Errorf("#%d: No output root matches expected root %s", i, nameToKey(&testRoots[j].Subject))
180 }
181 }
182 }
183
184 func extractTestChain(t *testing.T, _ int, testChain []string) []*x509.Certificate {
185 var chain []*x509.Certificate
186 for _, cert := range testChain {
187 chain = append(chain, GetTestCertificateFromPEM(t, cert))
188 }
189 return chain
190
191 }
192
193 func extractTestRoots(t *testing.T, i int, testRoots []string) *x509.CertPool {
194 roots := x509.NewCertPool()
195 for j, cert := range testRoots {
196 ok := roots.AppendCertsFromPEM([]byte(cert))
197 if !ok {
198 t.Errorf("#%d: Failed to parse root #%d", i, j)
199 }
200 }
201 return roots
202 }
203
204 func testChains(t *testing.T, i int, expectedChains [][]string, chains chan []*x509.Certificate) {
205 var allChains [][]*x509.Certificate
206 for chain := range chains {
207 allChains = append(allChains, chain)
208 }
209 matchTestChainList(t, i, expectedChains, allChains)
210 }
211
212 func testErrors(t *testing.T, i int, expectedErrs []errorType, errors chan *FixError) {
213 var allFerrs []*FixError
214 for ferr := range errors {
215 allFerrs = append(allFerrs, ferr)
216 }
217 matchTestErrorList(t, i, expectedErrs, allFerrs)
218 }
219
220 func stringRootsToJSON(roots []string) []byte {
221 type Roots struct {
222 Certs [][]byte `json:"certificates"`
223 }
224 var r Roots
225 for _, root := range roots {
226 cert, err := x509util.CertificateFromPEM([]byte(root))
227 if err != nil {
228 log.Fatalf("Failed to parse certificate: %s", err)
229 }
230 r.Certs = append(r.Certs, cert.Raw)
231 }
232 b, err := json.Marshal(r)
233 if err != nil {
234 log.Fatalf("Can't marshal JSON: %s", err)
235 }
236 return b
237 }
238
View as plain text