...

Source file src/github.com/letsencrypt/boulder/cmd/rocsp-tool/client.go

Documentation: github.com/letsencrypt/boulder/cmd/rocsp-tool

     1  package notmain
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"os"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/jmhodges/clock"
    12  	capb "github.com/letsencrypt/boulder/ca/proto"
    13  	"github.com/letsencrypt/boulder/core"
    14  	"github.com/letsencrypt/boulder/db"
    15  	blog "github.com/letsencrypt/boulder/log"
    16  	"github.com/letsencrypt/boulder/rocsp"
    17  	"github.com/letsencrypt/boulder/sa"
    18  	"github.com/letsencrypt/boulder/test/ocsp/helper"
    19  	"golang.org/x/crypto/ocsp"
    20  	"google.golang.org/protobuf/types/known/timestamppb"
    21  )
    22  
// client bundles the dependencies used by the rocsp-tool operations in this
// file: loading OCSP responses from the database into Redis and storing
// responses read from files.
type client struct {
	// redis is where OCSP responses are stored (and read back for verification).
	redis         *rocsp.RWClient
	// db is queried for certificateStatus rows when loading from the database.
	db            *db.WrappedMap // optional
	// ocspGenerator signs an OCSP response for each scanned certificateStatus row.
	ocspGenerator capb.OCSPGeneratorClient
	// clk is the time source used for the scan starting point and expiry checks.
	clk           clock.Clock
	// scanBatchSize is the LIMIT applied to each certificateStatus batch query.
	scanBatchSize int
	// logger receives progress and error messages.
	logger        blog.Logger
}
    31  
// processResult represents the result of attempting to sign and store status
// for a single certificateStatus ID. If `err` is non-nil, it indicates the
// attempt failed. One processResult is emitted per row consumed by
// signAndStoreResponses, whether or not the attempt succeeded.
type processResult struct {
	// id is the certificateStatus row ID the attempt was made for.
	id  uint64
	// err is nil on success; otherwise it is the error returned while
	// generating, parsing, or storing the OCSP response.
	err error
}
    39  
    40  func getStartingID(ctx context.Context, clk clock.Clock, db *db.WrappedMap) (int64, error) {
    41  	// To scan the DB efficiently, we want to select only currently-valid certificates. There's a
    42  	// handy expires index, but for selecting a large set of rows, using the primary key will be
    43  	// more efficient. So first we find a good id to start with, then scan from there. Note: since
    44  	// AUTO_INCREMENT can skip around a bit, we add padding to ensure we get all currently-valid
    45  	// certificates.
    46  	startTime := clk.Now().Add(-24 * time.Hour)
    47  	var minID *int64
    48  	err := db.QueryRowContext(
    49  		ctx,
    50  		"SELECT MIN(id) FROM certificateStatus WHERE notAfter >= ?",
    51  		startTime,
    52  	).Scan(&minID)
    53  	if err != nil {
    54  		return 0, fmt.Errorf("selecting minID: %w", err)
    55  	}
    56  	if minID == nil {
    57  		return 0, fmt.Errorf("no entries in certificateStatus (where notAfter >= %s)", startTime)
    58  	}
    59  	return *minID, nil
    60  }
    61  
    62  func (cl *client) loadFromDB(ctx context.Context, speed ProcessingSpeed, startFromID int64) error {
    63  	prevID := startFromID
    64  	var err error
    65  	if prevID == 0 {
    66  		prevID, err = getStartingID(ctx, cl.clk, cl.db)
    67  		if err != nil {
    68  			return fmt.Errorf("getting starting ID: %w", err)
    69  		}
    70  	}
    71  
    72  	// Find the current maximum id in certificateStatus. We do this because the table is always
    73  	// growing. If we scanned until we saw a batch with no rows, we would scan forever.
    74  	var maxID *int64
    75  	err = cl.db.QueryRowContext(
    76  		ctx,
    77  		"SELECT MAX(id) FROM certificateStatus",
    78  	).Scan(&maxID)
    79  	if err != nil {
    80  		return fmt.Errorf("selecting maxID: %w", err)
    81  	}
    82  	if maxID == nil {
    83  		return fmt.Errorf("no entries in certificateStatus")
    84  	}
    85  
    86  	// Limit the rate of reading rows.
    87  	frequency := time.Duration(float64(time.Second) / float64(time.Duration(speed.RowsPerSecond)))
    88  	// a set of all inflight certificate statuses, indexed by their `ID`.
    89  	inflightIDs := newInflight()
    90  	statusesToSign := cl.scanFromDB(ctx, prevID, *maxID, frequency, inflightIDs)
    91  
    92  	results := make(chan processResult, speed.ParallelSigns)
    93  	var runningSigners int32
    94  	for i := 0; i < speed.ParallelSigns; i++ {
    95  		atomic.AddInt32(&runningSigners, 1)
    96  		go cl.signAndStoreResponses(ctx, statusesToSign, results, &runningSigners)
    97  	}
    98  
    99  	var successCount, errorCount int64
   100  
   101  	for result := range results {
   102  		inflightIDs.remove(result.id)
   103  		if result.err != nil {
   104  			errorCount++
   105  			if errorCount < 10 ||
   106  				(errorCount < 1000 && rand.Intn(1000) < 100) ||
   107  				(errorCount < 100000 && rand.Intn(1000) < 10) ||
   108  				(rand.Intn(1000) < 1) {
   109  				cl.logger.Errf("error: %s", result.err)
   110  			}
   111  		} else {
   112  			successCount++
   113  		}
   114  
   115  		total := successCount + errorCount
   116  		if total < 10 ||
   117  			(total < 1000 && rand.Intn(1000) < 100) ||
   118  			(total < 100000 && rand.Intn(1000) < 10) ||
   119  			(rand.Intn(1000) < 1) {
   120  			cl.logger.Infof("stored %d responses, %d errors", successCount, errorCount)
   121  		}
   122  	}
   123  
   124  	cl.logger.Infof("done. processed %d successes and %d errors\n", successCount, errorCount)
   125  	if inflightIDs.len() != 0 {
   126  		return fmt.Errorf("inflightIDs non-empty! has %d items, lowest %d", inflightIDs.len(), inflightIDs.min())
   127  	}
   128  
   129  	return nil
   130  }
   131  
   132  // scanFromDB scans certificateStatus rows from the DB, starting with `minID`, and writes them to
   133  // its output channel at a maximum frequency of `frequency`. When it's read all available rows, it
   134  // closes its output channel and exits.
   135  // If there is an error, it logs the error, closes its output channel, and exits.
   136  func (cl *client) scanFromDB(ctx context.Context, prevID int64, maxID int64, frequency time.Duration, inflightIDs *inflight) <-chan *sa.CertStatusMetadata {
   137  	statusesToSign := make(chan *sa.CertStatusMetadata)
   138  	go func() {
   139  		defer close(statusesToSign)
   140  
   141  		var err error
   142  		currentMin := prevID
   143  		for currentMin < maxID {
   144  			currentMin, err = cl.scanFromDBOneBatch(ctx, currentMin, frequency, statusesToSign, inflightIDs)
   145  			if err != nil {
   146  				cl.logger.Infof("error scanning rows: %s", err)
   147  			}
   148  		}
   149  	}()
   150  	return statusesToSign
   151  }
   152  
   153  // scanFromDBOneBatch scans up to `cl.scanBatchSize` rows from certificateStatus, in order, and
   154  // writes them to `output`. When done, it returns the highest `id` it saw during the scan.
   155  // We do this in batches because if we tried to scan the whole table in a single query, MariaDB
   156  // would terminate the query after a certain amount of data transferred.
   157  func (cl *client) scanFromDBOneBatch(ctx context.Context, prevID int64, frequency time.Duration, output chan<- *sa.CertStatusMetadata, inflightIDs *inflight) (int64, error) {
   158  	rowTicker := time.NewTicker(frequency)
   159  
   160  	clauses := "WHERE id > ? ORDER BY id LIMIT ?"
   161  	params := []interface{}{prevID, cl.scanBatchSize}
   162  
   163  	selector, err := db.NewMappedSelector[sa.CertStatusMetadata](cl.db)
   164  	if err != nil {
   165  		return -1, fmt.Errorf("initializing db map: %w", err)
   166  	}
   167  
   168  	rows, err := selector.QueryContext(ctx, clauses, params...)
   169  	if err != nil {
   170  		return -1, fmt.Errorf("scanning certificateStatus: %w", err)
   171  	}
   172  	defer func() {
   173  		rerr := rows.Close()
   174  		if rerr != nil {
   175  			cl.logger.Infof("closing rows: %s", rerr)
   176  		}
   177  	}()
   178  
   179  	var scanned int
   180  	var previousID int64
   181  	for rows.Next() {
   182  		<-rowTicker.C
   183  
   184  		status, err := rows.Get()
   185  		if err != nil {
   186  			return -1, fmt.Errorf("scanning row %d (previous ID %d): %w", scanned, previousID, err)
   187  		}
   188  		scanned++
   189  		inflightIDs.add(uint64(status.ID))
   190  		// Emit a log line every 100000 rows. For our current ~215M rows, that
   191  		// will emit about 2150 log lines. This probably strikes a good balance
   192  		// between too spammy and having a reasonably frequent checkpoint.
   193  		if scanned%100000 == 0 {
   194  			cl.logger.Infof("scanned %d certificateStatus rows. minimum inflight ID %d", scanned, inflightIDs.min())
   195  		}
   196  		output <- status
   197  		previousID = status.ID
   198  	}
   199  	return previousID, nil
   200  }
   201  
   202  // signAndStoreResponses consumes cert statuses on its input channel and writes them to its output
   203  // channel. Before returning, it atomically decrements the provided runningSigners int. If the
   204  // result is 0, indicating this was the last running signer, it closes its output channel.
   205  func (cl *client) signAndStoreResponses(ctx context.Context, input <-chan *sa.CertStatusMetadata, output chan processResult, runningSigners *int32) {
   206  	defer func() {
   207  		if atomic.AddInt32(runningSigners, -1) <= 0 {
   208  			close(output)
   209  		}
   210  	}()
   211  	for status := range input {
   212  		ocspReq := &capb.GenerateOCSPRequest{
   213  			Serial:      status.Serial,
   214  			IssuerID:    status.IssuerID,
   215  			Status:      string(status.Status),
   216  			Reason:      int32(status.RevokedReason),
   217  			RevokedAtNS: status.RevokedDate.UnixNano(),
   218  			RevokedAt:   timestamppb.New(status.RevokedDate),
   219  		}
   220  		result, err := cl.ocspGenerator.GenerateOCSP(ctx, ocspReq)
   221  		if err != nil {
   222  			output <- processResult{id: uint64(status.ID), err: err}
   223  			continue
   224  		}
   225  		resp, err := ocsp.ParseResponse(result.Response, nil)
   226  		if err != nil {
   227  			output <- processResult{id: uint64(status.ID), err: err}
   228  			continue
   229  		}
   230  
   231  		err = cl.redis.StoreResponse(ctx, resp)
   232  		if err != nil {
   233  			output <- processResult{id: uint64(status.ID), err: err}
   234  		} else {
   235  			output <- processResult{id: uint64(status.ID), err: nil}
   236  		}
   237  	}
   238  }
   239  
   240  type expiredError struct {
   241  	serial string
   242  	ago    time.Duration
   243  }
   244  
   245  func (e expiredError) Error() string {
   246  	return fmt.Sprintf("response for %s expired %s ago", e.serial, e.ago)
   247  }
   248  
   249  func (cl *client) storeResponsesFromFiles(ctx context.Context, files []string) error {
   250  	for _, respFile := range files {
   251  		respBytes, err := os.ReadFile(respFile)
   252  		if err != nil {
   253  			return fmt.Errorf("reading response file %q: %w", respFile, err)
   254  		}
   255  		err = cl.storeResponse(ctx, respBytes)
   256  		if err != nil {
   257  			return err
   258  		}
   259  	}
   260  	return nil
   261  }
   262  
   263  func (cl *client) storeResponse(ctx context.Context, respBytes []byte) error {
   264  	resp, err := ocsp.ParseResponse(respBytes, nil)
   265  	if err != nil {
   266  		return fmt.Errorf("parsing response: %w", err)
   267  	}
   268  
   269  	serial := core.SerialToString(resp.SerialNumber)
   270  
   271  	if resp.NextUpdate.Before(cl.clk.Now()) {
   272  		return expiredError{
   273  			serial: serial,
   274  			ago:    cl.clk.Now().Sub(resp.NextUpdate),
   275  		}
   276  	}
   277  
   278  	cl.logger.Infof("storing response for %s, generated %s, ttl %g hours",
   279  		serial,
   280  		resp.ThisUpdate,
   281  		time.Until(resp.NextUpdate).Hours(),
   282  	)
   283  
   284  	err = cl.redis.StoreResponse(ctx, resp)
   285  	if err != nil {
   286  		return fmt.Errorf("storing response: %w", err)
   287  	}
   288  
   289  	retrievedResponse, err := cl.redis.GetResponse(ctx, serial)
   290  	if err != nil {
   291  		return fmt.Errorf("getting response: %w", err)
   292  	}
   293  
   294  	parsedRetrievedResponse, err := ocsp.ParseResponse(retrievedResponse, nil)
   295  	if err != nil {
   296  		return fmt.Errorf("parsing retrieved response: %w", err)
   297  	}
   298  	cl.logger.Infof("retrieved %s", helper.PrettyResponse(parsedRetrievedResponse))
   299  	return nil
   300  }
   301  

View as plain text