1 package redis
2
3 import (
4 "context"
5 "errors"
6 "reflect"
7 "sync"
8
9 "github.com/prometheus/client_golang/prometheus"
10 "golang.org/x/crypto/ocsp"
11
12 "github.com/letsencrypt/boulder/core"
13 "github.com/letsencrypt/boulder/db"
14 berrors "github.com/letsencrypt/boulder/errors"
15 blog "github.com/letsencrypt/boulder/log"
16 "github.com/letsencrypt/boulder/ocsp/responder"
17 "github.com/letsencrypt/boulder/sa"
18 sapb "github.com/letsencrypt/boulder/sa/proto"
19 )
20
21
22
// dbSelector is the minimal database capability checkedRedisSource needs
// for its direct-DB fallback path: the handle stored here is passed to
// sa.SelectRevocationStatus when no SA gRPC client is configured.
type dbSelector interface {
	// SelectOne runs query with the given params and stores the result into
	// dest (presumably a single row, per the method name).
	SelectOne(ctx context.Context, dest interface{}, query string, params ...interface{}) error
}
26
27
28
29 type rocspSourceInterface interface {
30 Response(ctx context.Context, req *ocsp.Request) (*responder.Response, error)
31 signAndSave(ctx context.Context, req *ocsp.Request, cause signAndSaveCause) (*responder.Response, error)
32 }
33
34
35
36
37
38
39
40
// checkedRedisSource serves OCSP responses obtained from base, but only after
// confirming that each response's revocation status agrees with an
// authoritative lookup made through either sac or dbMap.
type checkedRedisSource struct {
	// base supplies responses and can produce freshly signed ones.
	base rocspSourceInterface

	// dbMap is used for direct revocation-status lookups when sac is nil.
	dbMap dbSelector

	// sac, when non-nil, is preferred over dbMap for revocation-status lookups.
	sac sapb.StorageAuthorityReadOnlyClient

	// counter counts Response outcomes, labelled by "result".
	counter *prometheus.CounterVec

	// log is the logger supplied at construction.
	log blog.Logger
}
48
49
50
51 func NewCheckedRedisSource(base *redisSource, dbMap dbSelector, sac sapb.StorageAuthorityReadOnlyClient, stats prometheus.Registerer, log blog.Logger) (*checkedRedisSource, error) {
52 if base == nil {
53 return nil, errors.New("base was nil")
54 }
55
56
57
58
59
60
61 if (reflect.TypeOf(sac) == nil || reflect.ValueOf(sac).IsNil()) &&
62 (reflect.TypeOf(dbMap) == nil || reflect.ValueOf(dbMap).IsNil()) {
63 return nil, errors.New("either SA gRPC or direct DB connection must be provided")
64 }
65
66 return newCheckedRedisSource(base, dbMap, sac, stats, log), nil
67 }
68
69
70
71 func newCheckedRedisSource(base rocspSourceInterface, dbMap dbSelector, sac sapb.StorageAuthorityReadOnlyClient, stats prometheus.Registerer, log blog.Logger) *checkedRedisSource {
72 counter := prometheus.NewCounterVec(prometheus.CounterOpts{
73 Name: "checked_rocsp_responses",
74 Help: "Count of OCSP requests/responses from checkedRedisSource, by result",
75 }, []string{"result"})
76 stats.MustRegister(counter)
77
78 return &checkedRedisSource{
79 base: base,
80 dbMap: dbMap,
81 sac: sac,
82 counter: counter,
83 log: log,
84 }
85 }
86
87
88
89
90 func (src *checkedRedisSource) Response(ctx context.Context, req *ocsp.Request) (*responder.Response, error) {
91 serialString := core.SerialToString(req.SerialNumber)
92
93 var wg sync.WaitGroup
94 wg.Add(2)
95 var dbStatus *sapb.RevocationStatus
96 var redisResult *responder.Response
97 var redisErr, dbErr error
98 go func() {
99 defer wg.Done()
100 if src.sac != nil {
101 dbStatus, dbErr = src.sac.GetRevocationStatus(ctx, &sapb.Serial{Serial: serialString})
102 } else {
103 dbStatus, dbErr = sa.SelectRevocationStatus(ctx, src.dbMap, serialString)
104 }
105 }()
106 go func() {
107 defer wg.Done()
108 redisResult, redisErr = src.base.Response(ctx, req)
109 }()
110 wg.Wait()
111
112 if dbErr != nil {
113
114
115 if db.IsNoRows(dbErr) || errors.Is(dbErr, berrors.NotFound) {
116 src.counter.WithLabelValues("not_found").Inc()
117 return nil, responder.ErrNotFound
118 }
119
120 src.counter.WithLabelValues("db_error").Inc()
121 return nil, dbErr
122 }
123
124 if redisErr != nil {
125 src.counter.WithLabelValues("redis_error").Inc()
126 return nil, redisErr
127 }
128
129
130 if agree(dbStatus, redisResult.Response) {
131 src.counter.WithLabelValues("success").Inc()
132 return redisResult, nil
133 }
134
135
136 freshResult, err := src.base.signAndSave(ctx, req, causeMismatch)
137 if err != nil {
138 src.counter.WithLabelValues("revocation_re_sign_error").Inc()
139 return nil, err
140 }
141
142 if agree(dbStatus, freshResult.Response) {
143 src.counter.WithLabelValues("revocation_re_sign_success").Inc()
144 return freshResult, nil
145 }
146
147
148
149 src.counter.WithLabelValues("revocation_re_sign_mismatch").Inc()
150 return nil, errors.New("freshly signed status did not match DB")
151
152 }
153
154
155 func agree(dbStatus *sapb.RevocationStatus, redisResult *ocsp.Response) bool {
156 return dbStatus.Status == int64(redisResult.Status) &&
157 dbStatus.RevokedReason == int64(redisResult.RevocationReason) &&
158 dbStatus.RevokedDate.AsTime().Equal(redisResult.RevokedAt)
159 }
160
View as plain text