1 package notmain
2
3 import (
4 "context"
5 "fmt"
6 "math/rand"
7 "os"
8 "sync/atomic"
9 "time"
10
11 "github.com/jmhodges/clock"
12 capb "github.com/letsencrypt/boulder/ca/proto"
13 "github.com/letsencrypt/boulder/core"
14 "github.com/letsencrypt/boulder/db"
15 blog "github.com/letsencrypt/boulder/log"
16 "github.com/letsencrypt/boulder/rocsp"
17 "github.com/letsencrypt/boulder/sa"
18 "github.com/letsencrypt/boulder/test/ocsp/helper"
19 "golang.org/x/crypto/ocsp"
20 "google.golang.org/protobuf/types/known/timestamppb"
21 )
22
23 type client struct {
24 redis *rocsp.RWClient
25 db *db.WrappedMap
26 ocspGenerator capb.OCSPGeneratorClient
27 clk clock.Clock
28 scanBatchSize int
29 logger blog.Logger
30 }
31
32
33
34
// processResult is what a signing goroutine reports for one certificateStatus
// row: the row's database ID, plus the error that occurred while signing or
// storing its OCSP response (nil on success).
type processResult struct {
	id  uint64
	err error
}
39
40 func getStartingID(ctx context.Context, clk clock.Clock, db *db.WrappedMap) (int64, error) {
41
42
43
44
45
46 startTime := clk.Now().Add(-24 * time.Hour)
47 var minID *int64
48 err := db.QueryRowContext(
49 ctx,
50 "SELECT MIN(id) FROM certificateStatus WHERE notAfter >= ?",
51 startTime,
52 ).Scan(&minID)
53 if err != nil {
54 return 0, fmt.Errorf("selecting minID: %w", err)
55 }
56 if minID == nil {
57 return 0, fmt.Errorf("no entries in certificateStatus (where notAfter >= %s)", startTime)
58 }
59 return *minID, nil
60 }
61
62 func (cl *client) loadFromDB(ctx context.Context, speed ProcessingSpeed, startFromID int64) error {
63 prevID := startFromID
64 var err error
65 if prevID == 0 {
66 prevID, err = getStartingID(ctx, cl.clk, cl.db)
67 if err != nil {
68 return fmt.Errorf("getting starting ID: %w", err)
69 }
70 }
71
72
73
74 var maxID *int64
75 err = cl.db.QueryRowContext(
76 ctx,
77 "SELECT MAX(id) FROM certificateStatus",
78 ).Scan(&maxID)
79 if err != nil {
80 return fmt.Errorf("selecting maxID: %w", err)
81 }
82 if maxID == nil {
83 return fmt.Errorf("no entries in certificateStatus")
84 }
85
86
87 frequency := time.Duration(float64(time.Second) / float64(time.Duration(speed.RowsPerSecond)))
88
89 inflightIDs := newInflight()
90 statusesToSign := cl.scanFromDB(ctx, prevID, *maxID, frequency, inflightIDs)
91
92 results := make(chan processResult, speed.ParallelSigns)
93 var runningSigners int32
94 for i := 0; i < speed.ParallelSigns; i++ {
95 atomic.AddInt32(&runningSigners, 1)
96 go cl.signAndStoreResponses(ctx, statusesToSign, results, &runningSigners)
97 }
98
99 var successCount, errorCount int64
100
101 for result := range results {
102 inflightIDs.remove(result.id)
103 if result.err != nil {
104 errorCount++
105 if errorCount < 10 ||
106 (errorCount < 1000 && rand.Intn(1000) < 100) ||
107 (errorCount < 100000 && rand.Intn(1000) < 10) ||
108 (rand.Intn(1000) < 1) {
109 cl.logger.Errf("error: %s", result.err)
110 }
111 } else {
112 successCount++
113 }
114
115 total := successCount + errorCount
116 if total < 10 ||
117 (total < 1000 && rand.Intn(1000) < 100) ||
118 (total < 100000 && rand.Intn(1000) < 10) ||
119 (rand.Intn(1000) < 1) {
120 cl.logger.Infof("stored %d responses, %d errors", successCount, errorCount)
121 }
122 }
123
124 cl.logger.Infof("done. processed %d successes and %d errors\n", successCount, errorCount)
125 if inflightIDs.len() != 0 {
126 return fmt.Errorf("inflightIDs non-empty! has %d items, lowest %d", inflightIDs.len(), inflightIDs.min())
127 }
128
129 return nil
130 }
131
132
133
134
135
136 func (cl *client) scanFromDB(ctx context.Context, prevID int64, maxID int64, frequency time.Duration, inflightIDs *inflight) <-chan *sa.CertStatusMetadata {
137 statusesToSign := make(chan *sa.CertStatusMetadata)
138 go func() {
139 defer close(statusesToSign)
140
141 var err error
142 currentMin := prevID
143 for currentMin < maxID {
144 currentMin, err = cl.scanFromDBOneBatch(ctx, currentMin, frequency, statusesToSign, inflightIDs)
145 if err != nil {
146 cl.logger.Infof("error scanning rows: %s", err)
147 }
148 }
149 }()
150 return statusesToSign
151 }
152
153
154
155
156
157 func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequency time.Duration, output chan<- *sa.CertStatusMetadata, inflightIDs *inflight) (int64, error) {
158 rowTicker := time.NewTicker(frequency)
159
160 clauses := "WHERE id > ? ORDER BY id LIMIT ?"
161 params := []interface{}{prevID, cl.scanBatchSize}
162
163 selector, err := db.NewMappedSelector[sa.CertStatusMetadata](cl.db)
164 if err != nil {
165 return -1, fmt.Errorf("initializing db map: %w", err)
166 }
167
168 rows, err := selector.QueryContext(ctx, clauses, params...)
169 if err != nil {
170 return -1, fmt.Errorf("scanning certificateStatus: %w", err)
171 }
172 defer func() {
173 rerr := rows.Close()
174 if rerr != nil {
175 cl.logger.Infof("closing rows: %s", rerr)
176 }
177 }()
178
179 var scanned int
180 var previousID int64
181 for rows.Next() {
182 <-rowTicker.C
183
184 status, err := rows.Get()
185 if err != nil {
186 return -1, fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
187 }
188 scanned++
189 inflightIDs.add(uint64(status.ID))
190
191
192
193 if scanned%100000 == 0 {
194 cl.logger.Infof("scanned %d certificateStatus rows. minimum inflight ID %d", scanned, inflightIDs.min())
195 }
196 output <- status
197 previousID = status.ID
198 }
199 return previousID, nil
200 }
201
202
203
204
205 func (cl *client) signAndStoreResponses(ctx context.Context, input <-chan *sa.CertStatusMetadata, output chan processResult, runningSigners *int32) {
206 defer func() {
207 if atomic.AddInt32(runningSigners, -1) <= 0 {
208 close(output)
209 }
210 }()
211 for status := range input {
212 ocspReq := &capb.GenerateOCSPRequest{
213 Serial: status.Serial,
214 IssuerID: status.IssuerID,
215 Status: string(status.Status),
216 Reason: int32(status.RevokedReason),
217 RevokedAtNS: status.RevokedDate.UnixNano(),
218 RevokedAt: timestamppb.New(status.RevokedDate),
219 }
220 result, err := cl.ocspGenerator.GenerateOCSP(ctx, ocspReq)
221 if err != nil {
222 output <- processResult{id: uint64(status.ID), err: err}
223 continue
224 }
225 resp, err := ocsp.ParseResponse(result.Response, nil)
226 if err != nil {
227 output <- processResult{id: uint64(status.ID), err: err}
228 continue
229 }
230
231 err = cl.redis.StoreResponse(ctx, resp)
232 if err != nil {
233 output <- processResult{id: uint64(status.ID), err: err}
234 } else {
235 output <- processResult{id: uint64(status.ID), err: nil}
236 }
237 }
238 }
239
// expiredError reports that an OCSP response's NextUpdate had already passed
// when it was checked; ago is how long before that check it expired.
type expiredError struct {
	serial string
	ago    time.Duration
}

// Error implements the error interface, naming the serial and how long ago
// the response expired.
func (e expiredError) Error() string {
	return fmt.Sprintf("response for %v expired %v ago", e.serial, e.ago)
}
248
249 func (cl *client) storeResponsesFromFiles(ctx context.Context, files []string) error {
250 for _, respFile := range files {
251 respBytes, err := os.ReadFile(respFile)
252 if err != nil {
253 return fmt.Errorf("reading response file %q: %w", respFile, err)
254 }
255 err = cl.storeResponse(ctx, respBytes)
256 if err != nil {
257 return err
258 }
259 }
260 return nil
261 }
262
263 func (cl *client) storeResponse(ctx context.Context, respBytes []byte) error {
264 resp, err := ocsp.ParseResponse(respBytes, nil)
265 if err != nil {
266 return fmt.Errorf("parsing response: %w", err)
267 }
268
269 serial := core.SerialToString(resp.SerialNumber)
270
271 if resp.NextUpdate.Before(cl.clk.Now()) {
272 return expiredError{
273 serial: serial,
274 ago: cl.clk.Now().Sub(resp.NextUpdate),
275 }
276 }
277
278 cl.logger.Infof("storing response for %s, generated %s, ttl %g hours",
279 serial,
280 resp.ThisUpdate,
281 time.Until(resp.NextUpdate).Hours(),
282 )
283
284 err = cl.redis.StoreResponse(ctx, resp)
285 if err != nil {
286 return fmt.Errorf("storing response: %w", err)
287 }
288
289 retrievedResponse, err := cl.redis.GetResponse(ctx, serial)
290 if err != nil {
291 return fmt.Errorf("getting response: %w", err)
292 }
293
294 parsedRetrievedResponse, err := ocsp.ParseResponse(retrievedResponse, nil)
295 if err != nil {
296 return fmt.Errorf("parsing retrieved response: %w", err)
297 }
298 cl.logger.Infof("retrieved %s", helper.PrettyResponse(parsedRetrievedResponse))
299 return nil
300 }
301
View as plain text