1
2
3
4 package grpc_recovery_test
5
import (
	"context"
	"fmt"
	"testing"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
	grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing"
	pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
21
22 var (
23 goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
24 panicPing = &pb_testproto.PingRequest{Value: "panic", SleepTimeMs: 9999}
25 nilPanicPing = &pb_testproto.PingRequest{Value: "nilpanic", SleepTimeMs: 9999}
26 )
27
28 type recoveryAssertService struct {
29 pb_testproto.TestServiceServer
30 }
31
32 func (s *recoveryAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) {
33 if ping.Value == "panic" {
34 panic("very bad thing happened")
35 }
36 if ping.Value == "nilpanic" {
37 panic(nil)
38 }
39 return s.TestServiceServer.Ping(ctx, ping)
40 }
41
42 func (s *recoveryAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
43 if ping.Value == "panic" {
44 panic("very bad thing happened")
45 }
46 if ping.Value == "nilpanic" {
47 panic(nil)
48 }
49 return s.TestServiceServer.PingList(ping, stream)
50 }
51
52 func TestRecoverySuite(t *testing.T) {
53 s := &RecoverySuite{
54 InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
55 TestService: &recoveryAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}},
56 ServerOpts: []grpc.ServerOption{
57 grpc_middleware.WithStreamServerChain(
58 grpc_recovery.StreamServerInterceptor()),
59 grpc_middleware.WithUnaryServerChain(
60 grpc_recovery.UnaryServerInterceptor()),
61 },
62 },
63 }
64 suite.Run(t, s)
65 }
66
67 type RecoverySuite struct {
68 *grpc_testing.InterceptorTestSuite
69 }
70
71 func (s *RecoverySuite) TestUnary_SuccessfulRequest() {
72 _, err := s.Client.Ping(s.SimpleCtx(), goodPing)
73 require.NoError(s.T(), err, "no error must occur")
74 }
75
76 func (s *RecoverySuite) TestUnary_PanickingRequest() {
77 _, err := s.Client.Ping(s.SimpleCtx(), panicPing)
78 require.Error(s.T(), err, "there must be an error")
79 assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
80 assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message")
81 }
82
83 func (s *RecoverySuite) TestUnary_NilPanickingRequest() {
84 _, err := s.Client.Ping(s.SimpleCtx(), nilPanicPing)
85 require.Error(s.T(), err, "there must be an error")
86 assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
87 assert.Equal(s.T(), "<nil>", status.Convert(err).Message(), "must error with <nil>")
88 }
89
90 func (s *RecoverySuite) TestStream_SuccessfulReceive() {
91 stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
92 require.NoError(s.T(), err, "should not fail on establishing the stream")
93 pong, err := stream.Recv()
94 require.NoError(s.T(), err, "no error must occur")
95 require.NotNil(s.T(), pong, "pong must not be nil")
96 }
97
98 func (s *RecoverySuite) TestStream_PanickingReceive() {
99 stream, err := s.Client.PingList(s.SimpleCtx(), panicPing)
100 require.NoError(s.T(), err, "should not fail on establishing the stream")
101 _, err = stream.Recv()
102 require.Error(s.T(), err, "there must be an error")
103 assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
104 assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message")
105 }
106
107 func (s *RecoverySuite) TestStream_NilPanickingReceive() {
108 stream, err := s.Client.PingList(s.SimpleCtx(), nilPanicPing)
109 require.NoError(s.T(), err, "should not fail on establishing the stream")
110 _, err = stream.Recv()
111 require.Error(s.T(), err, "there must be an error")
112 assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
113 assert.Equal(s.T(), "<nil>", status.Convert(err).Message(), "must error with <nil>")
114 }
115
116 func TestRecoveryOverrideSuite(t *testing.T) {
117 opts := []grpc_recovery.Option{
118 grpc_recovery.WithRecoveryHandler(func(p interface{}) (err error) {
119 return status.Errorf(codes.Unknown, "panic triggered: %v", p)
120 }),
121 }
122 s := &RecoveryOverrideSuite{
123 InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
124 TestService: &recoveryAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}},
125 ServerOpts: []grpc.ServerOption{
126 grpc_middleware.WithStreamServerChain(
127 grpc_recovery.StreamServerInterceptor(opts...)),
128 grpc_middleware.WithUnaryServerChain(
129 grpc_recovery.UnaryServerInterceptor(opts...)),
130 },
131 },
132 }
133 suite.Run(t, s)
134 }
135
136 type RecoveryOverrideSuite struct {
137 *grpc_testing.InterceptorTestSuite
138 }
139
140 func (s *RecoveryOverrideSuite) TestUnary_SuccessfulRequest() {
141 _, err := s.Client.Ping(s.SimpleCtx(), goodPing)
142 require.NoError(s.T(), err, "no error must occur")
143 }
144
145 func (s *RecoveryOverrideSuite) TestUnary_PanickingRequest() {
146 _, err := s.Client.Ping(s.SimpleCtx(), panicPing)
147 require.Error(s.T(), err, "there must be an error")
148 assert.Equal(s.T(), codes.Unknown, status.Code(err), "must error with unknown")
149 assert.Equal(s.T(), "panic triggered: very bad thing happened", status.Convert(err).Message(), "must error with message")
150 }
151
152 func (s *RecoveryOverrideSuite) TestStream_SuccessfulReceive() {
153 stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
154 require.NoError(s.T(), err, "should not fail on establishing the stream")
155 pong, err := stream.Recv()
156 require.NoError(s.T(), err, "no error must occur")
157 require.NotNil(s.T(), pong, "pong must not be nil")
158 }
159
160 func (s *RecoveryOverrideSuite) TestStream_PanickingReceive() {
161 stream, err := s.Client.PingList(s.SimpleCtx(), panicPing)
162 require.NoError(s.T(), err, "should not fail on establishing the stream")
163 _, err = stream.Recv()
164 require.Error(s.T(), err, "there must be an error")
165 assert.Equal(s.T(), codes.Unknown, status.Code(err), "must error with unknown")
166 assert.Equal(s.T(), "panic triggered: very bad thing happened", status.Convert(err).Message(), "must error with message")
167 }
168
View as plain text