...

Source file src/github.com/go-chi/chi/middleware/throttle.go

Documentation: github.com/go-chi/chi/middleware

     1  package middleware
     2  
     3  import (
     4  	"net/http"
     5  	"strconv"
     6  	"time"
     7  )
     8  
     9  const (
    10  	errCapacityExceeded = "Server capacity exceeded."
    11  	errTimedOut         = "Timed out while waiting for a pending request to complete."
    12  	errContextCanceled  = "Context was canceled."
    13  )
    14  
    15  var (
    16  	defaultBacklogTimeout = time.Second * 60
    17  )
    18  
    19  // ThrottleOpts represents a set of throttling options.
    20  type ThrottleOpts struct {
    21  	Limit          int
    22  	BacklogLimit   int
    23  	BacklogTimeout time.Duration
    24  	RetryAfterFn   func(ctxDone bool) time.Duration
    25  }
    26  
    27  // Throttle is a middleware that limits number of currently processed requests
    28  // at a time across all users. Note: Throttle is not a rate-limiter per user,
    29  // instead it just puts a ceiling on the number of currentl in-flight requests
    30  // being processed from the point from where the Throttle middleware is mounted.
    31  func Throttle(limit int) func(http.Handler) http.Handler {
    32  	return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout})
    33  }
    34  
    35  // ThrottleBacklog is a middleware that limits number of currently processed
    36  // requests at a time and provides a backlog for holding a finite number of
    37  // pending requests.
    38  func ThrottleBacklog(limit int, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler {
    39  	return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogLimit: backlogLimit, BacklogTimeout: backlogTimeout})
    40  }
    41  
    42  // ThrottleWithOpts is a middleware that limits number of currently processed requests using passed ThrottleOpts.
    43  func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
    44  	if opts.Limit < 1 {
    45  		panic("chi/middleware: Throttle expects limit > 0")
    46  	}
    47  
    48  	if opts.BacklogLimit < 0 {
    49  		panic("chi/middleware: Throttle expects backlogLimit to be positive")
    50  	}
    51  
    52  	t := throttler{
    53  		tokens:         make(chan token, opts.Limit),
    54  		backlogTokens:  make(chan token, opts.Limit+opts.BacklogLimit),
    55  		backlogTimeout: opts.BacklogTimeout,
    56  		retryAfterFn:   opts.RetryAfterFn,
    57  	}
    58  
    59  	// Filling tokens.
    60  	for i := 0; i < opts.Limit+opts.BacklogLimit; i++ {
    61  		if i < opts.Limit {
    62  			t.tokens <- token{}
    63  		}
    64  		t.backlogTokens <- token{}
    65  	}
    66  
    67  	return func(next http.Handler) http.Handler {
    68  		fn := func(w http.ResponseWriter, r *http.Request) {
    69  			ctx := r.Context()
    70  
    71  			select {
    72  
    73  			case <-ctx.Done():
    74  				t.setRetryAfterHeaderIfNeeded(w, true)
    75  				http.Error(w, errContextCanceled, http.StatusServiceUnavailable)
    76  				return
    77  
    78  			case btok := <-t.backlogTokens:
    79  				timer := time.NewTimer(t.backlogTimeout)
    80  
    81  				defer func() {
    82  					t.backlogTokens <- btok
    83  				}()
    84  
    85  				select {
    86  				case <-timer.C:
    87  					t.setRetryAfterHeaderIfNeeded(w, false)
    88  					http.Error(w, errTimedOut, http.StatusServiceUnavailable)
    89  					return
    90  				case <-ctx.Done():
    91  					timer.Stop()
    92  					t.setRetryAfterHeaderIfNeeded(w, true)
    93  					http.Error(w, errContextCanceled, http.StatusServiceUnavailable)
    94  					return
    95  				case tok := <-t.tokens:
    96  					defer func() {
    97  						timer.Stop()
    98  						t.tokens <- tok
    99  					}()
   100  					next.ServeHTTP(w, r)
   101  				}
   102  				return
   103  
   104  			default:
   105  				t.setRetryAfterHeaderIfNeeded(w, false)
   106  				http.Error(w, errCapacityExceeded, http.StatusServiceUnavailable)
   107  				return
   108  			}
   109  		}
   110  
   111  		return http.HandlerFunc(fn)
   112  	}
   113  }
   114  
   115  // token represents a request that is being processed.
   116  type token struct{}
   117  
   118  // throttler limits number of currently processed requests at a time.
   119  type throttler struct {
   120  	tokens         chan token
   121  	backlogTokens  chan token
   122  	backlogTimeout time.Duration
   123  	retryAfterFn   func(ctxDone bool) time.Duration
   124  }
   125  
   126  // setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.
   127  func (t throttler) setRetryAfterHeaderIfNeeded(w http.ResponseWriter, ctxDone bool) {
   128  	if t.retryAfterFn == nil {
   129  		return
   130  	}
   131  	w.Header().Set("Retry-After", strconv.Itoa(int(t.retryAfterFn(ctxDone).Seconds())))
   132  }
   133  

View as plain text