1
29
30
31
32 package responder
33
34 import (
35 "context"
36 "crypto"
37 "crypto/sha256"
38 "encoding/base64"
39 "encoding/json"
40 "errors"
41 "fmt"
42 "io"
43 "math/rand"
44 "net/http"
45 "net/url"
46 "time"
47
48 "github.com/jmhodges/clock"
49 "github.com/prometheus/client_golang/prometheus"
50 "golang.org/x/crypto/ocsp"
51
52 "github.com/letsencrypt/boulder/core"
53 blog "github.com/letsencrypt/boulder/log"
54 )
55
56
57
// Sentinel errors that ServeHTTP recognises, via errors.Is, in the error
// returned by a Source.
var (
	// ErrNotFound reports that no OCSP response exists for the request.
	// ServeHTTP answers it with an OCSP "unauthorized" error response.
	ErrNotFound error = errors.New("request OCSP Response not found")

	// errOCSPResponseExpired reports that the only response available for
	// the request is expired. ServeHTTP answers it with HTTP status 533 and
	// an OCSP "internal error" response.
	errOCSPResponseExpired error = errors.New("OCSP response is expired")
)
64
65 var responseTypeToString = map[ocsp.ResponseStatus]string{
66 ocsp.Success: "Success",
67 ocsp.Malformed: "Malformed",
68 ocsp.InternalError: "InternalError",
69 ocsp.TryLater: "TryLater",
70 ocsp.SignatureRequired: "SignatureRequired",
71 ocsp.Unauthorized: "Unauthorized",
72 }
73
74
75 type Responder struct {
76 Source Source
77 timeout time.Duration
78 responseTypes *prometheus.CounterVec
79 responseAges prometheus.Histogram
80 requestSizes prometheus.Histogram
81 sampleRate int
82 clk clock.Clock
83 log blog.Logger
84 }
85
86
87 func NewResponder(source Source, timeout time.Duration, stats prometheus.Registerer, logger blog.Logger, sampleRate int) *Responder {
88 requestSizes := prometheus.NewHistogram(
89 prometheus.HistogramOpts{
90 Name: "ocsp_request_sizes",
91 Help: "Size of OCSP requests",
92 Buckets: []float64{1, 100, 200, 400, 800, 1200, 2000, 5000, 10000},
93 },
94 )
95 stats.MustRegister(requestSizes)
96
97
98 buckets := make([]float64, 14)
99 for i := range buckets {
100 buckets[i] = 43200 * float64(i)
101 }
102 responseAges := prometheus.NewHistogram(prometheus.HistogramOpts{
103 Name: "ocsp_response_ages",
104 Help: "How old are the OCSP responses when we serve them. Must stay well below 84 hours.",
105 Buckets: buckets,
106 })
107 stats.MustRegister(responseAges)
108
109 responseTypes := prometheus.NewCounterVec(
110 prometheus.CounterOpts{
111 Name: "ocsp_responses",
112 Help: "Number of OCSP responses returned by type",
113 },
114 []string{"type"},
115 )
116 stats.MustRegister(responseTypes)
117
118 return &Responder{
119 Source: source,
120 timeout: timeout,
121 responseTypes: responseTypes,
122 responseAges: responseAges,
123 requestSizes: requestSizes,
124 clk: clock.New(),
125 log: logger,
126 sampleRate: sampleRate,
127 }
128 }
129
// logEvent is the per-request record that ServeHTTP marshals to JSON and
// writes to the debug log once the request has been handled. Fields are
// emitted in declaration order; empty strings, a zero Took and nil Headers
// are omitted.
type logEvent struct {
	// Facts about the HTTP request and how it was handled.
	IP       string        `json:"ip,omitempty"`       // request.RemoteAddr
	UA       string        `json:"ua,omitempty"`       // User-Agent header
	Method   string        `json:"method,omitempty"`   // HTTP method
	Path     string        `json:"path,omitempty"`     // request URL path
	Body     string        `json:"body,omitempty"`     // base64 of the body; POST only
	Received time.Time     `json:"received,omitempty"` // when handling started
	Took     time.Duration `json:"took,omitempty"`     // handling time (marshalled as nanoseconds)
	Headers  http.Header   `json:"headers,omitempty"`  // response headers

	// Facts taken from the parsed OCSP request, hex-encoded where binary.
	Serial         string `json:"serial,omitempty"`
	IssuerKeyHash  string `json:"issuerKeyHash,omitempty"`
	IssuerNameHash string `json:"issuerNameHash,omitempty"`
	HashAlg        string `json:"hashAlg,omitempty"`
}
145
146
147
// hashToString gives the name recorded in the request log for the hash
// algorithm of an OCSP request. Algorithms not listed here map to the empty
// string.
var hashToString = map[crypto.Hash]string{
	// SHA-1.
	crypto.SHA1: "SHA1",

	// SHA-2 family.
	crypto.SHA256: "SHA256",
	crypto.SHA384: "SHA384",
	crypto.SHA512: "SHA512",
}
154
155 func SampledError(log blog.Logger, sampleRate int, format string, a ...interface{}) {
156 if sampleRate > 0 && rand.Intn(sampleRate) == 0 {
157 log.Errf(format, a...)
158 }
159 }
160
161 func (rs Responder) sampledError(format string, a ...interface{}) {
162 SampledError(rs.log, rs.sampleRate, format, a...)
163 }
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183 func (rs Responder) ServeHTTP(response http.ResponseWriter, request *http.Request) {
184
185
186
187 ctx := context.WithoutCancel(request.Context())
188 request = request.WithContext(ctx)
189
190 if rs.timeout != 0 {
191 var cancel func()
192 ctx, cancel = context.WithTimeout(ctx, rs.timeout)
193 defer cancel()
194 }
195
196 le := logEvent{
197 IP: request.RemoteAddr,
198 UA: request.UserAgent(),
199 Method: request.Method,
200 Path: request.URL.Path,
201 Received: time.Now(),
202 }
203
204 defer func() {
205 le.Headers = response.Header()
206 le.Took = time.Since(le.Received)
207 jb, err := json.Marshal(le)
208 if err != nil {
209
210
211 rs.log.Debugf("failed to marshal log event object: %s", err)
212 return
213 }
214 rs.log.Debugf("Received request: %s", string(jb))
215 }()
216
217
218
219
220 response.Header().Add("Cache-Control", "max-age=0, no-cache")
221
222 var requestBody []byte
223 var err error
224 switch request.Method {
225 case "GET":
226 base64Request, err := url.QueryUnescape(request.URL.Path)
227 if err != nil {
228 rs.log.Debugf("Error decoding URL: %s", request.URL.Path)
229 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.Malformed]}).Inc()
230 response.WriteHeader(http.StatusBadRequest)
231 return
232 }
233
234
235
236
237 base64RequestBytes := []byte(base64Request)
238 for i := range base64RequestBytes {
239 if base64RequestBytes[i] == ' ' {
240 base64RequestBytes[i] = '+'
241 }
242 }
243
244
245
246
247 if len(base64RequestBytes) > 0 && base64RequestBytes[0] == '/' {
248 base64RequestBytes = base64RequestBytes[1:]
249 }
250 requestBody, err = base64.StdEncoding.DecodeString(string(base64RequestBytes))
251 if err != nil {
252 rs.log.Debugf("Error decoding base64 from URL: %s", string(base64RequestBytes))
253 response.WriteHeader(http.StatusBadRequest)
254 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.Malformed]}).Inc()
255 return
256 }
257 case "POST":
258 requestBody, err = io.ReadAll(http.MaxBytesReader(nil, request.Body, 10000))
259 if err != nil {
260 rs.log.Errf("Problem reading body of POST: %s", err)
261 response.WriteHeader(http.StatusBadRequest)
262 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.Malformed]}).Inc()
263 return
264 }
265 rs.requestSizes.Observe(float64(len(requestBody)))
266 default:
267 response.WriteHeader(http.StatusMethodNotAllowed)
268 return
269 }
270 b64Body := base64.StdEncoding.EncodeToString(requestBody)
271 rs.log.Debugf("Received OCSP request: %s", b64Body)
272 if request.Method == http.MethodPost {
273 le.Body = b64Body
274 }
275
276
277
278
279 response.Header().Add("Content-Type", "application/ocsp-response")
280
281
282
283
284
285 ocspRequest, err := ocsp.ParseRequest(requestBody)
286 if err != nil {
287 rs.log.Debugf("Error decoding request body: %s", b64Body)
288 response.WriteHeader(http.StatusBadRequest)
289 response.Write(ocsp.MalformedRequestErrorResponse)
290 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.Malformed]}).Inc()
291 return
292 }
293 le.Serial = fmt.Sprintf("%x", ocspRequest.SerialNumber.Bytes())
294 le.IssuerKeyHash = fmt.Sprintf("%x", ocspRequest.IssuerKeyHash)
295 le.IssuerNameHash = fmt.Sprintf("%x", ocspRequest.IssuerNameHash)
296 le.HashAlg = hashToString[ocspRequest.HashAlgorithm]
297
298
299 ocspResponse, err := rs.Source.Response(ctx, ocspRequest)
300 if err != nil {
301 if errors.Is(err, ErrNotFound) {
302 response.Write(ocsp.UnauthorizedErrorResponse)
303 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.Unauthorized]}).Inc()
304 return
305 } else if errors.Is(err, errOCSPResponseExpired) {
306 rs.sampledError("Requested ocsp response is expired: serial %x, request body %s",
307 ocspRequest.SerialNumber, b64Body)
308
309 response.WriteHeader(533)
310 response.Write(ocsp.InternalErrorErrorResponse)
311 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.Unauthorized]}).Inc()
312 return
313 }
314 rs.sampledError("Error retrieving response for request: serial %x, request body %s, error: %s",
315 ocspRequest.SerialNumber, b64Body, err)
316 response.WriteHeader(http.StatusInternalServerError)
317 response.Write(ocsp.InternalErrorErrorResponse)
318 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.InternalError]}).Inc()
319 return
320 }
321
322
323 response.Header().Add("Last-Modified", ocspResponse.ThisUpdate.Format(time.RFC1123))
324 response.Header().Add("Expires", ocspResponse.NextUpdate.Format(time.RFC1123))
325 now := rs.clk.Now()
326 var maxAge int
327 if now.Before(ocspResponse.NextUpdate) {
328 maxAge = int(ocspResponse.NextUpdate.Sub(now) / time.Second)
329 } else {
330
331
332 maxAge = 0
333 }
334 response.Header().Set(
335 "Cache-Control",
336 fmt.Sprintf(
337 "max-age=%d, public, no-transform, must-revalidate",
338 maxAge,
339 ),
340 )
341 responseHash := sha256.Sum256(ocspResponse.Raw)
342 response.Header().Add("ETag", fmt.Sprintf("\"%X\"", responseHash))
343
344 serialString := core.SerialToString(ocspResponse.SerialNumber)
345 if len(serialString) > 2 {
346
347
348
349 response.Header().Add("Edge-Cache-Tag", serialString[len(serialString)-2:])
350 }
351
352
353
354
355 if etag := request.Header.Get("If-None-Match"); etag != "" {
356 if etag == fmt.Sprintf("\"%X\"", responseHash) {
357 response.WriteHeader(http.StatusNotModified)
358 return
359 }
360 }
361 response.WriteHeader(http.StatusOK)
362 response.Write(ocspResponse.Raw)
363 rs.responseAges.Observe(rs.clk.Now().Sub(ocspResponse.ThisUpdate).Seconds())
364 rs.responseTypes.With(prometheus.Labels{"type": responseTypeToString[ocsp.Success]}).Inc()
365 }
366
View as plain text