...
1 package middleware
2
3 import (
4 "net/http"
5 "strconv"
6 "time"
7 )
8
// Response bodies written by the throttle middleware when it rejects a
// request. Every rejection is sent with 503 Service Unavailable.
const (
	// errContextCanceled: the request's context was done before a slot freed up.
	errContextCanceled = "Context was canceled."
	// errCapacityExceeded: no room was left even in the backlog.
	errCapacityExceeded = "Server capacity exceeded."
	// errTimedOut: the request waited in the backlog for longer than allowed.
	errTimedOut = "Timed out while waiting for a pending request to complete."
)

// defaultBacklogTimeout is the backlog wait used by Throttle.
var defaultBacklogTimeout = 60 * time.Second
18
19
// ThrottleOpts configures the middleware returned by ThrottleWithOpts.
type ThrottleOpts struct {
	// Limit is the maximum number of requests allowed to run the wrapped
	// handler at the same time. It must be at least 1.
	Limit int

	// BacklogLimit is the number of additional requests allowed to wait for
	// a free slot. It must not be negative; with zero, requests beyond Limit
	// are rejected instead of queued.
	BacklogLimit int

	// BacklogTimeout is how long a queued request waits for a free slot
	// before it is rejected.
	BacklogTimeout time.Duration

	// RetryAfterFn, when non-nil, supplies the Retry-After header sent with a
	// rejected request. ctxDone reports whether the rejection happened because
	// the request's context was done. The duration is sent as whole seconds.
	RetryAfterFn func(ctxDone bool) time.Duration
}
26
27
28
29
30
31 func Throttle(limit int) func(http.Handler) http.Handler {
32 return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout})
33 }
34
35
36
37
38 func ThrottleBacklog(limit int, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler {
39 return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogLimit: backlogLimit, BacklogTimeout: backlogTimeout})
40 }
41
42
43 func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
44 if opts.Limit < 1 {
45 panic("chi/middleware: Throttle expects limit > 0")
46 }
47
48 if opts.BacklogLimit < 0 {
49 panic("chi/middleware: Throttle expects backlogLimit to be positive")
50 }
51
52 t := throttler{
53 tokens: make(chan token, opts.Limit),
54 backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit),
55 backlogTimeout: opts.BacklogTimeout,
56 retryAfterFn: opts.RetryAfterFn,
57 }
58
59
60 for i := 0; i < opts.Limit+opts.BacklogLimit; i++ {
61 if i < opts.Limit {
62 t.tokens <- token{}
63 }
64 t.backlogTokens <- token{}
65 }
66
67 return func(next http.Handler) http.Handler {
68 fn := func(w http.ResponseWriter, r *http.Request) {
69 ctx := r.Context()
70
71 select {
72
73 case <-ctx.Done():
74 t.setRetryAfterHeaderIfNeeded(w, true)
75 http.Error(w, errContextCanceled, http.StatusServiceUnavailable)
76 return
77
78 case btok := <-t.backlogTokens:
79 timer := time.NewTimer(t.backlogTimeout)
80
81 defer func() {
82 t.backlogTokens <- btok
83 }()
84
85 select {
86 case <-timer.C:
87 t.setRetryAfterHeaderIfNeeded(w, false)
88 http.Error(w, errTimedOut, http.StatusServiceUnavailable)
89 return
90 case <-ctx.Done():
91 timer.Stop()
92 t.setRetryAfterHeaderIfNeeded(w, true)
93 http.Error(w, errContextCanceled, http.StatusServiceUnavailable)
94 return
95 case tok := <-t.tokens:
96 defer func() {
97 timer.Stop()
98 t.tokens <- tok
99 }()
100 next.ServeHTTP(w, r)
101 }
102 return
103
104 default:
105 t.setRetryAfterHeaderIfNeeded(w, false)
106 http.Error(w, errCapacityExceeded, http.StatusServiceUnavailable)
107 return
108 }
109 }
110
111 return http.HandlerFunc(fn)
112 }
113 }
114
115
// token is a unit of capacity passed through the throttler's channels.
type token struct{}

// throttler holds the shared state of one throttle middleware instance.
type throttler struct {
	// tokens holds one value per free processing slot; a request must take
	// one before running the wrapped handler.
	tokens chan token
	// backlogTokens holds one value per free admission slot (running or
	// waiting); a request that cannot take one is rejected at once.
	backlogTokens chan token
	// backlogTimeout bounds how long an admitted request waits for a token.
	backlogTimeout time.Duration
	// retryAfterFn, when non-nil, supplies the Retry-After delay for rejections.
	retryAfterFn func(ctxDone bool) time.Duration
}

// setRetryAfterHeaderIfNeeded sets the Retry-After response header from
// t.retryAfterFn(ctxDone), and does nothing when retryAfterFn is nil.
// ctxDone reports whether the rejection was caused by the request's context
// being done.
func (t throttler) setRetryAfterHeaderIfNeeded(w http.ResponseWriter, ctxDone bool) {
	if t.retryAfterFn == nil {
		return
	}
	secs := retryAfterSeconds(t.retryAfterFn(ctxDone))
	w.Header().Set("Retry-After", strconv.FormatInt(secs, 10))
}

// retryAfterSeconds converts d to the non-negative whole number of seconds
// required by the Retry-After delay-seconds syntax. Non-positive durations
// become 0, and a fractional second is rounded up so clients never retry
// earlier than requested.
func retryAfterSeconds(d time.Duration) int64 {
	if d <= 0 {
		return 0
	}
	secs := int64(d / time.Second)
	if d%time.Second != 0 {
		secs++
	}
	return secs
}
133
View as plain text