1 package rocsp
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "time"
8
9 "github.com/letsencrypt/boulder/core"
10
11 "github.com/jmhodges/clock"
12 "github.com/prometheus/client_golang/prometheus"
13 "github.com/redis/go-redis/v9"
14 "golang.org/x/crypto/ocsp"
15 )
16
var (
	// ErrRedisNotFound is returned by GetResponse when Redis reports that no
	// value exists for the requested serial (the GET came back as redis.Nil).
	ErrRedisNotFound = errors.New("redis key not found")
)
18
19
// ROClient is a read-only client for OCSP responses kept in a sharded Redis
// ring. It carries the per-request timeout, the clock that timestamps requests
// for latency measurement, and the histogram that tracks GetResponse latency.
type ROClient struct {
	// rdb is the Redis ring; keys are spread across its shards.
	rdb *redis.Ring

	// timeout bounds individual requests such as Ping and GetResponse.
	timeout time.Duration

	// clk supplies the start time used for latency measurements.
	clk clock.Clock

	// getLatency records GetResponse latency, labeled by "result".
	getLatency *prometheus.HistogramVec
}
26
27
28
29
30 func NewReadingClient(rdb *redis.Ring, timeout time.Duration, clk clock.Clock, stats prometheus.Registerer) *ROClient {
31 getLatency := prometheus.NewHistogramVec(
32 prometheus.HistogramOpts{
33 Name: "rocsp_get_latency",
34 Help: "Histogram of latencies of rocsp.GetResponse calls with result",
35
36 Buckets: prometheus.ExponentialBucketsRange(0.0005, 2, 8),
37 },
38 []string{"result"},
39 )
40 stats.MustRegister(getLatency)
41
42 return &ROClient{
43 rdb: rdb,
44 timeout: timeout,
45 clk: clk,
46 getLatency: getLatency,
47 }
48 }
49
50
51
52 func (c *ROClient) Ping(ctx context.Context) error {
53 ctx, cancel := context.WithTimeout(ctx, c.timeout)
54 defer cancel()
55
56 err := c.rdb.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
57 return shard.Ping(ctx).Err()
58 })
59 if err != nil {
60 return err
61 }
62 return nil
63 }
64
65
// RWClient is a read-write client. It embeds *ROClient, so every read method
// is available on it, and adds a histogram that tracks StoreResponse latency.
type RWClient struct {
	*ROClient

	// storeResponseLatency records StoreResponse latency, labeled by "result".
	storeResponseLatency *prometheus.HistogramVec
}
70
71
72 func NewWritingClient(rdb *redis.Ring, timeout time.Duration, clk clock.Clock, stats prometheus.Registerer) *RWClient {
73 storeResponseLatency := prometheus.NewHistogramVec(
74 prometheus.HistogramOpts{
75 Name: "rocsp_store_response_latency",
76 Help: "Histogram of latencies of rocsp.StoreResponse calls with result labels",
77 },
78 []string{"result"},
79 )
80 stats.MustRegister(storeResponseLatency)
81 return &RWClient{NewReadingClient(rdb, timeout, clk, stats), storeResponseLatency}
82 }
83
84
85
86
87 func (c *RWClient) StoreResponse(ctx context.Context, resp *ocsp.Response) error {
88 start := c.clk.Now()
89 ctx, cancel := context.WithTimeout(ctx, c.timeout)
90 defer cancel()
91
92 serial := core.SerialToString(resp.SerialNumber)
93
94
95 ttl := time.Until(resp.NextUpdate)
96
97 err := c.rdb.Set(ctx, serial, resp.Raw, ttl).Err()
98 if err != nil {
99 state := "failed"
100 if errors.Is(err, context.DeadlineExceeded) {
101 state = "deadlineExceeded"
102 } else if errors.Is(err, context.Canceled) {
103 state = "canceled"
104 }
105 c.storeResponseLatency.With(prometheus.Labels{"result": state}).Observe(time.Since(start).Seconds())
106 return fmt.Errorf("setting response: %w", err)
107 }
108
109 c.storeResponseLatency.With(prometheus.Labels{"result": "success"}).Observe(time.Since(start).Seconds())
110 return nil
111 }
112
113
114
115 func (c *ROClient) GetResponse(ctx context.Context, serial string) ([]byte, error) {
116 start := c.clk.Now()
117 ctx, cancel := context.WithTimeout(ctx, c.timeout)
118 defer cancel()
119
120 resp, err := c.rdb.Get(ctx, serial).Result()
121 if err != nil {
122
123
124 if errors.Is(err, redis.Nil) {
125 c.getLatency.With(prometheus.Labels{"result": "notFound"}).Observe(time.Since(start).Seconds())
126 return nil, ErrRedisNotFound
127 }
128
129 state := "failed"
130 if errors.Is(err, context.DeadlineExceeded) {
131 state = "deadlineExceeded"
132 } else if errors.Is(err, context.Canceled) {
133 state = "canceled"
134 }
135 c.getLatency.With(prometheus.Labels{"result": state}).Observe(time.Since(start).Seconds())
136 return nil, fmt.Errorf("getting response: %w", err)
137 }
138
139 c.getLatency.With(prometheus.Labels{"result": "success"}).Observe(time.Since(start).Seconds())
140 return []byte(resp), nil
141 }
142
143
144
145
146
// ScanResponsesResult is one item emitted on the channel returned by
// ScanResponses. A successfully read key yields Serial and Body; a failure
// yields a result with Err set.
type ScanResponsesResult struct {
	// Serial is the Redis key that was read (it may be empty when Err is set).
	Serial string

	// Body holds the raw bytes stored under Serial.
	Body []byte

	// Err is non-nil when reading a key, or the scan itself, failed.
	Err error
}
152
153
154
155
156 func (c *ROClient) ScanResponses(ctx context.Context, serialPattern string) <-chan ScanResponsesResult {
157 pattern := fmt.Sprintf("r{%s}", serialPattern)
158 results := make(chan ScanResponsesResult)
159 go func() {
160 defer close(results)
161 err := c.rdb.ForEachShard(ctx, func(ctx context.Context, rdb *redis.Client) error {
162 iter := rdb.Scan(ctx, 0, pattern, 0).Iterator()
163 for iter.Next(ctx) {
164 key := iter.Val()
165 val, err := c.rdb.Get(ctx, key).Result()
166 if err != nil {
167 results <- ScanResponsesResult{Err: fmt.Errorf("getting response: %w", err)}
168 continue
169 }
170 results <- ScanResponsesResult{Serial: key, Body: []byte(val)}
171 }
172 return iter.Err()
173 })
174 if err != nil {
175 results <- ScanResponsesResult{Err: err}
176 return
177 }
178 }()
179 return results
180 }
181
View as plain text