...

Source file src/github.com/grpc-ecosystem/go-grpc-middleware/retry/retry.go

Documentation: github.com/grpc-ecosystem/go-grpc-middleware/retry
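
A minimal usage sketch for wiring the interceptors below into a client connection. The option constructors (WithMax, WithCodes, WithPerRetryTimeout, WithBackoff, BackoffLinear) come from this package's options and backoff files; the target address is illustrative, and error handling is simplified:

	package main

	import (
		"time"

		grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
		"google.golang.org/grpc"
		"google.golang.org/grpc/codes"
	)

	func main() {
		// Retry policy shared by every call on this connection; per-call
		// grpc.CallOptions can still override or disable it.
		retryOpts := []grpc_retry.CallOption{
			grpc_retry.WithMax(3),
			grpc_retry.WithPerRetryTimeout(2 * time.Second),
			grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted),
			grpc_retry.WithBackoff(grpc_retry.BackoffLinear(100 * time.Millisecond)),
		}
		conn, err := grpc.Dial(
			"localhost:50051",   // illustrative address
			grpc.WithInsecure(), // plaintext, for the sketch only
			grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(retryOpts...)),
			grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(retryOpts...)),
		)
		if err != nil {
			panic(err) // handle the dial error appropriately in real code
		}
		defer conn.Close()
		_ = conn // use conn to construct generated service clients
	}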

     1  // Copyright 2016 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_retry
     5  
     6  import (
     7  	"context"
     8  	"io"
     9  	"strconv"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
    14  	"golang.org/x/net/trace"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/codes"
    17  	"google.golang.org/grpc/metadata"
    18  	"google.golang.org/grpc/status"
    19  )
    20  
    21  const (
    22  	AttemptMetadataKey = "x-retry-attempty"
    23  )
    24  
    25  // UnaryClientInterceptor returns a new retrying unary client interceptor.
    26  //
    27  // The default configuration of the interceptor is to not retry *at all*. This behaviour can be
    28  // changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
    29  func UnaryClientInterceptor(optFuncs ...CallOption) grpc.UnaryClientInterceptor {
    30  	intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
    31  	return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    32  		grpcOpts, retryOpts := filterCallOptions(opts)
    33  		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
    34  		// short circuit for simplicity, and avoiding allocations.
    35  		if callOpts.max == 0 {
    36  			return invoker(parentCtx, method, req, reply, cc, grpcOpts...)
    37  		}
    38  		var lastErr error
    39  		for attempt := uint(0); attempt < callOpts.max; attempt++ {
    40  			if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
    41  				return err
    42  			}
    43  			callCtx := perCallContext(parentCtx, callOpts, attempt)
    44  			lastErr = invoker(callCtx, method, req, reply, cc, grpcOpts...)
    45  			// TODO(mwitkow): Maybe dial and transport errors should be retriable?
    46  			if lastErr == nil {
    47  				return nil
    48  			}
    49  			logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
    50  			if isContextError(lastErr) {
    51  				if parentCtx.Err() != nil {
    52  					logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
    53  					// it's the parent context's deadline or cancellation.
    54  					return lastErr
    55  				} else if callOpts.perCallTimeout != 0 {
    56  					// We have set a perCallTimeout in the retry middleware, which would result in a context error if
    57  					// the deadline was exceeded, in which case try again.
    58  					logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
    59  					continue
    60  				}
    61  			}
    62  			if !isRetriable(lastErr, callOpts) {
    63  				return lastErr
    64  			}
    65  		}
    66  		return lastErr
    67  	}
    68  }
    69  
    70  // StreamClientInterceptor returns a new retrying stream client interceptor for server side streaming calls.
    71  //
    72  // The default configuration of the interceptor is to not retry *at all*. This behaviour can be
    73  // changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
    74  //
    75  // Retry logic is available *only for ServerStreams*, i.e. 1:n streams, as the internal logic needs
    76  // to buffer the messages sent by the client. If retry is enabled on any other streams (ClientStreams,
    77  // BidiStreams), the retry interceptor will fail the call.
    78  func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientInterceptor {
    79  	intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
    80  	return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    81  		grpcOpts, retryOpts := filterCallOptions(opts)
    82  		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
    83  		// short circuit for simplicity, and to avoid allocations.
    84  		if callOpts.max == 0 {
    85  			return streamer(parentCtx, desc, cc, method, grpcOpts...)
    86  		}
    87  		if desc.ClientStreams {
    88  			return nil, status.Errorf(codes.Unimplemented, "grpc_retry: cannot retry on ClientStreams, set grpc_retry.Disable()")
    89  		}
    90  
    91  		var lastErr error
    92  		for attempt := uint(0); attempt < callOpts.max; attempt++ {
    93  			if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
    94  				return nil, err
    95  			}
    96  			callCtx := perCallContext(parentCtx, callOpts, 0)
    97  
    98  			var newStreamer grpc.ClientStream
    99  			newStreamer, lastErr = streamer(callCtx, desc, cc, method, grpcOpts...)
   100  			if lastErr == nil {
   101  				retryingStreamer := &serverStreamingRetryingStream{
   102  					ClientStream: newStreamer,
   103  					callOpts:     callOpts,
   104  					parentCtx:    parentCtx,
   105  					streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
   106  						return streamer(ctx, desc, cc, method, grpcOpts...)
   107  					},
   108  				}
   109  				return retryingStreamer, nil
   110  			}
   111  
   112  			logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
   113  			if isContextError(lastErr) {
   114  				if parentCtx.Err() != nil {
   115  					logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
   116  					// it's the parent context's deadline or cancellation.
   117  					return nil, lastErr
   118  				} else if callOpts.perCallTimeout != 0 {
   119  					// We have set a perCallTimeout in the retry middleware, which would result in a context error if
   120  					// the deadline was exceeded, in which case try again.
   121  					logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
   122  					continue
   123  				}
   124  			}
   125  			if !isRetriable(lastErr, callOpts) {
   126  				return nil, lastErr
   127  			}
   128  		}
   129  		return nil, lastErr
   130  	}
   131  }
   132  
   133  // serverStreamingRetryingStream is the implementation of grpc.ClientStream that acts as a
   134  // proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish
   135  // a new ClientStream according to the retry policy.
   136  type serverStreamingRetryingStream struct {
   137  	grpc.ClientStream
   138  	bufferedSends []interface{} // single message that the client can send
   139  	wasClosedSend bool          // indicates that CloseSend was called
   140  	parentCtx     context.Context
   141  	callOpts      *options
   142  	streamerCall  func(ctx context.Context) (grpc.ClientStream, error)
   143  	mu            sync.RWMutex
   144  }
   145  
   146  func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) {
   147  	s.mu.Lock()
   148  	s.ClientStream = clientStream
   149  	s.mu.Unlock()
   150  }
   151  
   152  func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream {
   153  	s.mu.RLock()
   154  	defer s.mu.RUnlock()
   155  	return s.ClientStream
   156  }
   157  
   158  func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error {
   159  	s.mu.Lock()
   160  	s.bufferedSends = append(s.bufferedSends, m)
   161  	s.mu.Unlock()
   162  	return s.getStream().SendMsg(m)
   163  }
   164  
   165  func (s *serverStreamingRetryingStream) CloseSend() error {
   166  	s.mu.Lock()
   167  	s.wasClosedSend = true
   168  	s.mu.Unlock()
   169  	return s.getStream().CloseSend()
   170  }
   171  
   172  func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
   173  	return s.getStream().Header()
   174  }
   175  
   176  func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
   177  	return s.getStream().Trailer()
   178  }
   179  
   180  func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error {
   181  	attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m)
   182  	if !attemptRetry {
   183  		return lastErr // success or hard failure
   184  	}
   185  	// We start off from attempt 1, because the zeroth attempt was already made by the RecvMsg call above.
   186  	for attempt := uint(1); attempt < s.callOpts.max; attempt++ {
   187  		if err := waitRetryBackoff(attempt, s.parentCtx, s.callOpts); err != nil {
   188  			return err
   189  		}
   190  		callCtx := perCallContext(s.parentCtx, s.callOpts, attempt)
   191  		newStream, err := s.reestablishStreamAndResendBuffer(callCtx)
   192  		if err != nil {
   193  			// Retry dial and transport errors when establishing the stream, since gRPC itself doesn't retry them.
   194  			if isRetriable(err, s.callOpts) {
   195  				continue
   196  			}
   197  			return err
   198  		}
   199  
   200  		s.setStream(newStream)
   201  		attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m)
   202  		//fmt.Printf("Received message and indicate: %v  %v\n", attemptRetry, lastErr)
   203  		if !attemptRetry {
   204  			return lastErr
   205  		}
   206  	}
   207  	return lastErr
   208  }
   209  
   210  func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) {
   211  	err := s.getStream().RecvMsg(m)
   212  	if err == nil || err == io.EOF {
   213  		return false, err
   214  	}
   215  	if isContextError(err) {
   216  		if s.parentCtx.Err() != nil {
   217  			logTrace(s.parentCtx, "grpc_retry parent context error: %v", s.parentCtx.Err())
   218  			return false, err
   219  		} else if s.callOpts.perCallTimeout != 0 {
   220  			// We have set a perCallTimeout in the retry middleware, which would result in a context error if
   221  			// the deadline was exceeded, in which case try again.
   222  			logTrace(s.parentCtx, "grpc_retry context error from retry call")
   223  			return true, err
   224  		}
   225  	}
   226  	return isRetriable(err, s.callOpts), err
   227  }
   228  
   229  func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer(
   230  	callCtx context.Context,
   231  ) (grpc.ClientStream, error) {
   232  	s.mu.RLock()
   233  	bufferedSends := s.bufferedSends
   234  	s.mu.RUnlock()
   235  	newStream, err := s.streamerCall(callCtx)
   236  	if err != nil {
   237  		logTrace(callCtx, "grpc_retry failed redialing new stream: %v", err)
   238  		return nil, err
   239  	}
   240  	for _, msg := range bufferedSends {
   241  		if err := newStream.SendMsg(msg); err != nil {
   242  			logTrace(callCtx, "grpc_retry failed resending message: %v", err)
   243  			return nil, err
   244  		}
   245  	}
   246  	if err := newStream.CloseSend(); err != nil {
   247  		logTrace(callCtx, "grpc_retry failed CloseSend on new stream: %v", err)
   248  		return nil, err
   249  	}
   250  	return newStream, nil
   251  }
   252  
   253  func waitRetryBackoff(attempt uint, parentCtx context.Context, callOpts *options) error {
   254  	var waitTime time.Duration = 0
   255  	if attempt > 0 {
   256  		waitTime = callOpts.backoffFunc(parentCtx, attempt)
   257  	}
   258  	if waitTime > 0 {
   259  		logTrace(parentCtx, "grpc_retry attempt: %d, backoff for %v", attempt, waitTime)
   260  		timer := time.NewTimer(waitTime)
   261  		select {
   262  		case <-parentCtx.Done():
   263  			timer.Stop()
   264  			return contextErrToGrpcErr(parentCtx.Err())
   265  		case <-timer.C:
   266  		}
   267  	}
   268  	return nil
   269  }
   270  
   271  func isRetriable(err error, callOpts *options) bool {
   272  	errCode := status.Code(err)
   273  	if isContextError(err) {
   274  		// context errors are never retriable, regardless of the user-configured codes.
   275  		return false
   276  	}
   277  	for _, code := range callOpts.codes {
   278  		if code == errCode {
   279  			return true
   280  		}
   281  	}
   282  	return false
   283  }
   284  
   285  func isContextError(err error) bool {
   286  	code := status.Code(err)
   287  	return code == codes.DeadlineExceeded || code == codes.Canceled
   288  }
   289  
   290  func perCallContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context {
   291  	ctx := parentCtx
   292  	if callOpts.perCallTimeout != 0 {
   293  		ctx, _ = context.WithTimeout(ctx, callOpts.perCallTimeout)
   294  	}
   295  	if attempt > 0 && callOpts.includeHeader {
   296  		mdClone := metautils.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, strconv.FormatUint(uint64(attempt), 10))
   297  		ctx = mdClone.ToOutgoing(ctx)
   298  	}
   299  	return ctx
   300  }
   301  
   302  func contextErrToGrpcErr(err error) error {
   303  	switch err {
   304  	case context.DeadlineExceeded:
   305  		return status.Error(codes.DeadlineExceeded, err.Error())
   306  	case context.Canceled:
   307  		return status.Error(codes.Canceled, err.Error())
   308  	default:
   309  		return status.Error(codes.Unknown, err.Error())
   310  	}
   311  }
   312  
   313  func logTrace(ctx context.Context, format string, a ...interface{}) {
   314  	tr, ok := trace.FromContext(ctx)
   315  	if !ok {
   316  		return
   317  	}
   318  	tr.LazyPrintf(format, a...)
   319  }
   320  
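perCallContext above stamps each retried attempt with AttemptMetadataKey when the includeHeader option is enabled; a hedged sketch of how a server-side handler could read that header back (the helper name attemptFromIncomingContext is illustrative):

	package main

	import (
		"context"
		"strconv"

		grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
		"google.golang.org/grpc/metadata"
	)

	// attemptFromIncomingContext reports which retry attempt produced the incoming
	// request. The header is only written for attempt > 0, so a missing header
	// means the first try.
	func attemptFromIncomingContext(ctx context.Context) uint {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return 0
		}
		vals := md.Get(grpc_retry.AttemptMetadataKey)
		if len(vals) == 0 {
			return 0
		}
		attempt, err := strconv.ParseUint(vals[0], 10, 64)
		if err != nil {
			return 0
		}
		return uint(attempt)
	}

	func main() {
		_ = attemptFromIncomingContext(context.Background())
	}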

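Retry options can also be supplied per call, which filterCallOptions above separates from the ordinary gRPC options; a sketch assuming an already-dialled *grpc.ClientConn (the method names and stream descriptor are illustrative, and generated stubs accept the same CallOptions):

	package main

	import (
		"context"
		"time"

		grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
		"google.golang.org/grpc"
	)

	// callWithOverrides bumps the retry budget for one unary call and disables
	// retries on a client-streaming call, which the interceptor cannot replay.
	func callWithOverrides(ctx context.Context, conn *grpc.ClientConn, req, reply interface{}) error {
		if err := conn.Invoke(ctx, "/example.Echo/UnaryEcho", req, reply,
			grpc_retry.WithMax(5),
			grpc_retry.WithPerRetryTimeout(time.Second),
		); err != nil {
			return err
		}
		desc := &grpc.StreamDesc{StreamName: "Upload", ClientStreams: true}
		_, err := conn.NewStream(ctx, desc, "/example.Echo/Upload", grpc_retry.Disable())
		return err
	}

	func main() {}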