...
1
2
3
4 package grpc_recovery
5
6 import (
7 "context"
8
9 "google.golang.org/grpc"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/status"
12 )
13
14
// RecoveryHandlerFunc maps a value recovered from a panic (p) to an error.
//
// This variant does not receive the context of the call that panicked; see
// RecoveryHandlerFuncContext for the context-aware form.
type RecoveryHandlerFunc func(p interface{}) error
16
17
18
// RecoveryHandlerFuncContext maps a value recovered from a panic (p) to an
// error, with access to the context of the call that panicked.
//
// When one is supplied to recoverFrom, its return value is used verbatim as
// the error the interceptor reports for the RPC.
type RecoveryHandlerFuncContext func(ctx context.Context, p interface{}) error
20
21
22 func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
23 o := evaluateOptions(opts)
24 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
25 panicked := true
26
27 defer func() {
28 if r := recover(); r != nil || panicked {
29 err = recoverFrom(ctx, r, o.recoveryHandlerFunc)
30 }
31 }()
32
33 resp, err := handler(ctx, req)
34 panicked = false
35 return resp, err
36 }
37 }
38
39
40 func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
41 o := evaluateOptions(opts)
42 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
43 panicked := true
44
45 defer func() {
46 if r := recover(); r != nil || panicked {
47 err = recoverFrom(stream.Context(), r, o.recoveryHandlerFunc)
48 }
49 }()
50
51 err = handler(srv, stream)
52 panicked = false
53 return err
54 }
55 }
56
57 func recoverFrom(ctx context.Context, p interface{}, r RecoveryHandlerFuncContext) error {
58 if r == nil {
59 return status.Errorf(codes.Internal, "%v", p)
60 }
61 return r(ctx, p)
62 }
63
View as plain text