1
2
3
4 package grpc_retry
5
6 import (
7 "context"
8 "io"
9 "strconv"
10 "sync"
11 "time"
12
13 "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
14 "golang.org/x/net/trace"
15 "google.golang.org/grpc"
16 "google.golang.org/grpc/codes"
17 "google.golang.org/grpc/metadata"
18 "google.golang.org/grpc/status"
19 )
20
// AttemptMetadataKey is the outgoing metadata key under which perCallContext
// records the retry attempt number. It is only attached for attempts after
// the first, and only when header inclusion is enabled in the call options.
const AttemptMetadataKey = "x-retry-attempty"
24
25
26
27
28
29 func UnaryClientInterceptor(optFuncs ...CallOption) grpc.UnaryClientInterceptor {
30 intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
31 return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
32 grpcOpts, retryOpts := filterCallOptions(opts)
33 callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
34
35 if callOpts.max == 0 {
36 return invoker(parentCtx, method, req, reply, cc, grpcOpts...)
37 }
38 var lastErr error
39 for attempt := uint(0); attempt < callOpts.max; attempt++ {
40 if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
41 return err
42 }
43 callCtx := perCallContext(parentCtx, callOpts, attempt)
44 lastErr = invoker(callCtx, method, req, reply, cc, grpcOpts...)
45
46 if lastErr == nil {
47 return nil
48 }
49 logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
50 if isContextError(lastErr) {
51 if parentCtx.Err() != nil {
52 logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
53
54 return lastErr
55 } else if callOpts.perCallTimeout != 0 {
56
57
58 logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
59 continue
60 }
61 }
62 if !isRetriable(lastErr, callOpts) {
63 return lastErr
64 }
65 }
66 return lastErr
67 }
68 }
69
70
71
72
73
74
75
76
77
78 func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientInterceptor {
79 intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
80 return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
81 grpcOpts, retryOpts := filterCallOptions(opts)
82 callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
83
84 if callOpts.max == 0 {
85 return streamer(parentCtx, desc, cc, method, grpcOpts...)
86 }
87 if desc.ClientStreams {
88 return nil, status.Errorf(codes.Unimplemented, "grpc_retry: cannot retry on ClientStreams, set grpc_retry.Disable()")
89 }
90
91 var lastErr error
92 for attempt := uint(0); attempt < callOpts.max; attempt++ {
93 if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
94 return nil, err
95 }
96 callCtx := perCallContext(parentCtx, callOpts, 0)
97
98 var newStreamer grpc.ClientStream
99 newStreamer, lastErr = streamer(callCtx, desc, cc, method, grpcOpts...)
100 if lastErr == nil {
101 retryingStreamer := &serverStreamingRetryingStream{
102 ClientStream: newStreamer,
103 callOpts: callOpts,
104 parentCtx: parentCtx,
105 streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
106 return streamer(ctx, desc, cc, method, grpcOpts...)
107 },
108 }
109 return retryingStreamer, nil
110 }
111
112 logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
113 if isContextError(lastErr) {
114 if parentCtx.Err() != nil {
115 logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
116
117 return nil, lastErr
118 } else if callOpts.perCallTimeout != 0 {
119
120
121 logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
122 continue
123 }
124 }
125 if !isRetriable(lastErr, callOpts) {
126 return nil, lastErr
127 }
128 }
129 return nil, lastErr
130 }
131 }
132
133
134
135
136 type serverStreamingRetryingStream struct {
137 grpc.ClientStream
138 bufferedSends []interface{}
139 wasClosedSend bool
140 parentCtx context.Context
141 callOpts *options
142 streamerCall func(ctx context.Context) (grpc.ClientStream, error)
143 mu sync.RWMutex
144 }
145
146 func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) {
147 s.mu.Lock()
148 s.ClientStream = clientStream
149 s.mu.Unlock()
150 }
151
152 func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream {
153 s.mu.RLock()
154 defer s.mu.RUnlock()
155 return s.ClientStream
156 }
157
158 func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error {
159 s.mu.Lock()
160 s.bufferedSends = append(s.bufferedSends, m)
161 s.mu.Unlock()
162 return s.getStream().SendMsg(m)
163 }
164
165 func (s *serverStreamingRetryingStream) CloseSend() error {
166 s.mu.Lock()
167 s.wasClosedSend = true
168 s.mu.Unlock()
169 return s.getStream().CloseSend()
170 }
171
172 func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
173 return s.getStream().Header()
174 }
175
176 func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
177 return s.getStream().Trailer()
178 }
179
180 func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error {
181 attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m)
182 if !attemptRetry {
183 return lastErr
184 }
185
186 for attempt := uint(1); attempt < s.callOpts.max; attempt++ {
187 if err := waitRetryBackoff(attempt, s.parentCtx, s.callOpts); err != nil {
188 return err
189 }
190 callCtx := perCallContext(s.parentCtx, s.callOpts, attempt)
191 newStream, err := s.reestablishStreamAndResendBuffer(callCtx)
192 if err != nil {
193
194 if isRetriable(err, s.callOpts) {
195 continue
196 }
197 return err
198 }
199
200 s.setStream(newStream)
201 attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m)
202
203 if !attemptRetry {
204 return lastErr
205 }
206 }
207 return lastErr
208 }
209
210 func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) {
211 err := s.getStream().RecvMsg(m)
212 if err == nil || err == io.EOF {
213 return false, err
214 }
215 if isContextError(err) {
216 if s.parentCtx.Err() != nil {
217 logTrace(s.parentCtx, "grpc_retry parent context error: %v", s.parentCtx.Err())
218 return false, err
219 } else if s.callOpts.perCallTimeout != 0 {
220
221
222 logTrace(s.parentCtx, "grpc_retry context error from retry call")
223 return true, err
224 }
225 }
226 return isRetriable(err, s.callOpts), err
227 }
228
229 func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer(
230 callCtx context.Context,
231 ) (grpc.ClientStream, error) {
232 s.mu.RLock()
233 bufferedSends := s.bufferedSends
234 s.mu.RUnlock()
235 newStream, err := s.streamerCall(callCtx)
236 if err != nil {
237 logTrace(callCtx, "grpc_retry failed redialing new stream: %v", err)
238 return nil, err
239 }
240 for _, msg := range bufferedSends {
241 if err := newStream.SendMsg(msg); err != nil {
242 logTrace(callCtx, "grpc_retry failed resending message: %v", err)
243 return nil, err
244 }
245 }
246 if err := newStream.CloseSend(); err != nil {
247 logTrace(callCtx, "grpc_retry failed CloseSend on new stream %v", err)
248 return nil, err
249 }
250 return newStream, nil
251 }
252
253 func waitRetryBackoff(attempt uint, parentCtx context.Context, callOpts *options) error {
254 var waitTime time.Duration = 0
255 if attempt > 0 {
256 waitTime = callOpts.backoffFunc(parentCtx, attempt)
257 }
258 if waitTime > 0 {
259 logTrace(parentCtx, "grpc_retry attempt: %d, backoff for %v", attempt, waitTime)
260 timer := time.NewTimer(waitTime)
261 select {
262 case <-parentCtx.Done():
263 timer.Stop()
264 return contextErrToGrpcErr(parentCtx.Err())
265 case <-timer.C:
266 }
267 }
268 return nil
269 }
270
271 func isRetriable(err error, callOpts *options) bool {
272 errCode := status.Code(err)
273 if isContextError(err) {
274
275 return false
276 }
277 for _, code := range callOpts.codes {
278 if code == errCode {
279 return true
280 }
281 }
282 return false
283 }
284
285 func isContextError(err error) bool {
286 code := status.Code(err)
287 return code == codes.DeadlineExceeded || code == codes.Canceled
288 }
289
290 func perCallContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context {
291 ctx := parentCtx
292 if callOpts.perCallTimeout != 0 {
293 ctx, _ = context.WithTimeout(ctx, callOpts.perCallTimeout)
294 }
295 if attempt > 0 && callOpts.includeHeader {
296 mdClone := metautils.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, strconv.FormatUint(uint64(attempt), 10))
297 ctx = mdClone.ToOutgoing(ctx)
298 }
299 return ctx
300 }
301
302 func contextErrToGrpcErr(err error) error {
303 switch err {
304 case context.DeadlineExceeded:
305 return status.Error(codes.DeadlineExceeded, err.Error())
306 case context.Canceled:
307 return status.Error(codes.Canceled, err.Error())
308 default:
309 return status.Error(codes.Unknown, err.Error())
310 }
311 }
312
313 func logTrace(ctx context.Context, format string, a ...interface{}) {
314 tr, ok := trace.FromContext(ctx)
315 if !ok {
316 return
317 }
318 tr.LazyPrintf(format, a...)
319 }
320
View as plain text