1 package grpc
2
import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmhodges/clock"
	"github.com/prometheus/client_golang/prometheus"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"

	"github.com/letsencrypt/boulder/cmd"
	berrors "github.com/letsencrypt/boulder/errors"
)
22
// returnOverhead is subtracted from the deadline of every incoming RPC before
// its handler runs, leaving headroom for the server to send its response.
const returnOverhead = 20 * time.Millisecond

// meaningfulWorkOverhead is the minimum time that must remain on the shortened
// deadline for a server interceptor to let the handler start at all.
const meaningfulWorkOverhead = 100 * time.Millisecond

// clientRequestTimeKey is the gRPC metadata key under which clients send the
// time they issued a request, as base-10 Unix nanoseconds.
const clientRequestTimeKey = "client-request-time"
28
29 type serverInterceptor interface {
30 Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error)
31 Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error
32 }
33
34
35
36 type noopServerInterceptor struct{}
37
38
39 func (n *noopServerInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
40 return handler(ctx, req)
41 }
42
43
44 func (n *noopServerInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
45 return handler(srv, ss)
46 }
47
48
49 var _ serverInterceptor = &noopServerInterceptor{}
50
51 type clientInterceptor interface {
52 Unary(ctx context.Context, method string, req interface{}, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error
53 Stream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error)
54 }
55
56
57
58
59 type serverMetadataInterceptor struct {
60 metrics serverMetrics
61 clk clock.Clock
62 }
63
64 func newServerMetadataInterceptor(metrics serverMetrics, clk clock.Clock) serverMetadataInterceptor {
65 return serverMetadataInterceptor{
66 metrics: metrics,
67 clk: clk,
68 }
69 }
70
71
72 func (smi *serverMetadataInterceptor) Unary(
73 ctx context.Context,
74 req interface{},
75 info *grpc.UnaryServerInfo,
76 handler grpc.UnaryHandler) (interface{}, error) {
77 if info == nil {
78 return nil, berrors.InternalServerError("passed nil *grpc.UnaryServerInfo")
79 }
80
81
82
83
84 if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 {
85 err := smi.observeLatency(md[clientRequestTimeKey][0])
86 if err != nil {
87 return nil, err
88 }
89 }
90
91
92
93
94
95
96
97
98
99 deadline, ok := ctx.Deadline()
100
101 if !ok {
102 deadline = time.Now().Add(100 * time.Second)
103 }
104 deadline = deadline.Add(-returnOverhead)
105 remaining := time.Until(deadline)
106 if remaining < meaningfulWorkOverhead {
107 return nil, status.Errorf(codes.DeadlineExceeded, "not enough time left on clock: %s", remaining)
108 }
109
110 localCtx, cancel := context.WithDeadline(ctx, deadline)
111 defer cancel()
112
113 resp, err := handler(localCtx, req)
114 if err != nil {
115 err = wrapError(localCtx, err)
116 }
117 return resp, err
118 }
119
120
121
122 type interceptedServerStream struct {
123 grpc.ServerStream
124 ctx context.Context
125 }
126
127
128 func (iss interceptedServerStream) Context() context.Context {
129 return iss.ctx
130 }
131
132
133 func (smi *serverMetadataInterceptor) Stream(
134 srv interface{},
135 ss grpc.ServerStream,
136 info *grpc.StreamServerInfo,
137 handler grpc.StreamHandler) error {
138 ctx := ss.Context()
139
140
141
142
143 if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 {
144 err := smi.observeLatency(md[clientRequestTimeKey][0])
145 if err != nil {
146 return err
147 }
148 }
149
150
151
152
153
154
155
156
157
158 deadline, ok := ctx.Deadline()
159
160 if !ok {
161 deadline = time.Now().Add(100 * time.Second)
162 }
163 deadline = deadline.Add(-returnOverhead)
164 remaining := time.Until(deadline)
165 if remaining < meaningfulWorkOverhead {
166 return status.Errorf(codes.DeadlineExceeded, "not enough time left on clock: %s", remaining)
167 }
168
169
170
171 localCtx, cancel := context.WithDeadline(ctx, deadline)
172 defer cancel()
173
174 err := handler(srv, interceptedServerStream{ss, localCtx})
175 if err != nil {
176 err = wrapError(localCtx, err)
177 }
178 return err
179 }
180
181
182
183
184
// splitMethodName splits a gRPC full method name such as
// "/package.Service/Method" into its service and method parts.
//
// One leading slash is optional. The text before the next slash is the
// service, and everything after it, including any further slashes, is the
// method. If no slash remains, it returns "unknown", "unknown".
func splitMethodName(fullMethodName string) (string, string) {
	parts := strings.SplitN(strings.TrimPrefix(fullMethodName, "/"), "/", 2)
	if len(parts) != 2 {
		return "unknown", "unknown"
	}
	return parts[0], parts[1]
}
192
193
194
195
196
197
198 func (smi *serverMetadataInterceptor) observeLatency(clientReqTime string) error {
199
200 reqTimeUnixNanos, err := strconv.ParseInt(clientReqTime, 10, 64)
201 if err != nil {
202 return berrors.InternalServerError("grpc metadata had illegal %s value: %q - %s",
203 clientRequestTimeKey, clientReqTime, err)
204 }
205
206 reqTime := time.Unix(0, reqTimeUnixNanos)
207 elapsed := smi.clk.Since(reqTime)
208
209 smi.metrics.rpcLag.Observe(elapsed.Seconds())
210 return nil
211 }
212
213
214 var _ serverInterceptor = (*serverMetadataInterceptor)(nil)
215
216
217
218
219
220
221
222
223 type clientMetadataInterceptor struct {
224 timeout time.Duration
225 metrics clientMetrics
226 clk clock.Clock
227
228 waitForReady bool
229 }
230
231
232 func (cmi *clientMetadataInterceptor) Unary(
233 ctx context.Context,
234 fullMethod string,
235 req,
236 reply interface{},
237 cc *grpc.ClientConn,
238 invoker grpc.UnaryInvoker,
239 opts ...grpc.CallOption) error {
240
241
242 if cmi.metrics.inFlightRPCs == nil {
243 return berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge")
244 }
245
246
247 localCtx, cancel := context.WithTimeout(ctx, cmi.timeout)
248 defer cancel()
249
250
251 nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10)
252
253
254 reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS})
255
256 localCtx = metadata.NewOutgoingContext(localCtx, reqMD)
257
258
259
260 opts = append(opts, grpc.WaitForReady(cmi.waitForReady))
261
262
263 respMD := metadata.New(nil)
264
265
266 opts = append(opts, grpc.Trailer(&respMD))
267
268
269
270
271 service, method := splitMethodName(fullMethod)
272
273 labels := prometheus.Labels{
274 "method": method,
275 "service": service,
276 }
277
278 cmi.metrics.inFlightRPCs.With(labels).Inc()
279
280 defer cmi.metrics.inFlightRPCs.With(labels).Dec()
281
282
283 begin := cmi.clk.Now()
284 err := invoker(localCtx, fullMethod, req, reply, cc, opts...)
285 if err != nil {
286 err = unwrapError(err, respMD)
287 if status.Code(err) == codes.DeadlineExceeded {
288 return deadlineDetails{
289 service: service,
290 method: method,
291 latency: cmi.clk.Since(begin),
292 }
293 }
294 }
295 return err
296 }
297
298
299
300 type interceptedClientStream struct {
301 grpc.ClientStream
302 finish func(error) error
303 }
304
305
306 func (ics interceptedClientStream) Header() (metadata.MD, error) {
307 md, err := ics.ClientStream.Header()
308 if err != nil {
309 err = ics.finish(err)
310 }
311 return md, err
312 }
313
314
315 func (ics interceptedClientStream) SendMsg(m interface{}) error {
316 err := ics.ClientStream.SendMsg(m)
317 if err != nil {
318 err = ics.finish(err)
319 }
320 return err
321 }
322
323
324 func (ics interceptedClientStream) RecvMsg(m interface{}) error {
325 err := ics.ClientStream.RecvMsg(m)
326 if err != nil {
327 err = ics.finish(err)
328 }
329 return err
330 }
331
332
333 func (ics interceptedClientStream) CloseSend() error {
334 err := ics.ClientStream.CloseSend()
335 if err != nil {
336 err = ics.finish(err)
337 }
338 return err
339 }
340
341
342 func (cmi *clientMetadataInterceptor) Stream(
343 ctx context.Context,
344 desc *grpc.StreamDesc,
345 cc *grpc.ClientConn,
346 fullMethod string,
347 streamer grpc.Streamer,
348 opts ...grpc.CallOption) (grpc.ClientStream, error) {
349
350
351 if cmi.metrics.inFlightRPCs == nil {
352 return nil, berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge")
353 }
354
355
356
357 localCtx, cancel := context.WithTimeout(ctx, cmi.timeout)
358
359
360 nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10)
361
362
363 reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS})
364
365 localCtx = metadata.NewOutgoingContext(localCtx, reqMD)
366
367
368
369 opts = append(opts, grpc.WaitForReady(cmi.waitForReady))
370
371
372 respMD := metadata.New(nil)
373
374
375 opts = append(opts, grpc.Trailer(&respMD))
376
377
378
379
380 service, method := splitMethodName(fullMethod)
381
382 labels := prometheus.Labels{
383 "method": method,
384 "service": service,
385 }
386
387 cmi.metrics.inFlightRPCs.With(labels).Inc()
388 begin := cmi.clk.Now()
389
390
391
392 finish := func(err error) error {
393 cancel()
394 cmi.metrics.inFlightRPCs.With(labels).Dec()
395 if err != nil {
396 err = unwrapError(err, respMD)
397 if status.Code(err) == codes.DeadlineExceeded {
398 return deadlineDetails{
399 service: service,
400 method: method,
401 latency: cmi.clk.Since(begin),
402 }
403 }
404 }
405 return err
406 }
407
408
409 cs, err := streamer(localCtx, desc, cc, fullMethod, opts...)
410 ics := interceptedClientStream{cs, finish}
411 return ics, err
412 }
413
414 var _ clientInterceptor = (*clientMetadataInterceptor)(nil)
415
416
417
// deadlineDetails is the error the client interceptors return in place of an
// RPC error whose code is codes.DeadlineExceeded. It records which RPC timed
// out and how long the client had been waiting.
type deadlineDetails struct {
	service string
	method  string
	latency time.Duration
}

// Error formats the details as "<service>.<method> timed out after <N> ms",
// with the latency truncated to whole milliseconds.
func (dd deadlineDetails) Error() string {
	ms := dd.latency.Milliseconds()
	return fmt.Sprintf("%s.%s timed out after %d ms", dd.service, dd.method, ms)
}
428
429
430
431
// authInterceptor is a serverInterceptor that admits an RPC only when the
// client's verified TLS certificate carries a DNS name allowed for the target
// service.
type authInterceptor struct {
	// serviceClientNames maps a gRPC service name to the set of client
	// certificate DNS names permitted to call it.
	serviceClientNames map[string]map[string]struct{}
}
439
440
441
442
443 func newServiceAuthChecker(c *cmd.GRPCServerConfig) *authInterceptor {
444 names := make(map[string]map[string]struct{})
445 for serviceName, service := range c.Services {
446 names[serviceName] = make(map[string]struct{})
447 for _, clientName := range service.ClientNames {
448 names[serviceName][clientName] = struct{}{}
449 }
450 }
451 return &authInterceptor{names}
452 }
453
454
455 func (ac *authInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
456 err := ac.checkContextAuth(ctx, info.FullMethod)
457 if err != nil {
458 return nil, err
459 }
460 return handler(ctx, req)
461 }
462
463
464 func (ac *authInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
465 err := ac.checkContextAuth(ss.Context(), info.FullMethod)
466 if err != nil {
467 return err
468 }
469 return handler(srv, ss)
470 }
471
472
473
474
475
476
477 func (ac *authInterceptor) checkContextAuth(ctx context.Context, fullMethod string) error {
478 serviceName, _ := splitMethodName(fullMethod)
479
480 allowedClientNames, ok := ac.serviceClientNames[serviceName]
481 if !ok || len(allowedClientNames) == 0 {
482 return fmt.Errorf("service %q has no allowed client names", serviceName)
483 }
484
485 p, ok := peer.FromContext(ctx)
486 if !ok {
487 return fmt.Errorf("unable to fetch peer info from grpc context")
488 }
489
490 if p.AuthInfo == nil {
491 return fmt.Errorf("grpc connection appears to be plaintext")
492 }
493
494 tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo)
495 if !ok {
496 return fmt.Errorf("connection is not TLS authed")
497 }
498
499 if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 {
500 return fmt.Errorf("connection auth not verified")
501 }
502
503 cert := tlsAuth.State.VerifiedChains[0][0]
504
505 for _, clientName := range cert.DNSNames {
506 _, ok := allowedClientNames[clientName]
507 if ok {
508 return nil
509 }
510 }
511
512 return fmt.Errorf(
513 "client names %v are not authorized for service %q (%v)",
514 cert.DNSNames, serviceName, allowedClientNames)
515 }
516
517
518 var _ serverInterceptor = (*authInterceptor)(nil)
519
View as plain text