1 package middleware
2
3 import (
4 "io/ioutil"
5 "net/http"
6 "net/http/httptest"
7 "strings"
8 "sync"
9 "testing"
10 "time"
11
12 "github.com/go-chi/chi"
13 )
14
// testContent is the payload every test handler in this file writes back on
// success; tests compare the bytes they read from the response against it.
var testContent = []byte("Hello world!")
16
17 func TestThrottleBacklog(t *testing.T) {
18 r := chi.NewRouter()
19
20 r.Use(ThrottleBacklog(10, 50, time.Second*10))
21
22 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
23 w.WriteHeader(http.StatusOK)
24 time.Sleep(time.Second * 1)
25 w.Write(testContent)
26 })
27
28 server := httptest.NewServer(r)
29 defer server.Close()
30
31 client := http.Client{
32 Timeout: time.Second * 5,
33 }
34
35 var wg sync.WaitGroup
36
37
38
39
40 for i := 0; i < 40; i++ {
41 wg.Add(1)
42 go func(i int) {
43 defer wg.Done()
44
45 res, err := client.Get(server.URL)
46 assertNoError(t, err)
47
48 assertEqual(t, http.StatusOK, res.StatusCode)
49 buf, err := ioutil.ReadAll(res.Body)
50 assertNoError(t, err)
51 assertEqual(t, testContent, buf)
52 }(i)
53 }
54
55 wg.Wait()
56 }
57
58 func TestThrottleClientTimeout(t *testing.T) {
59 r := chi.NewRouter()
60
61 r.Use(ThrottleBacklog(10, 50, time.Second*10))
62
63 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
64 w.WriteHeader(http.StatusOK)
65 time.Sleep(time.Second * 5)
66 w.Write(testContent)
67 })
68
69 server := httptest.NewServer(r)
70 defer server.Close()
71
72 client := http.Client{
73 Timeout: time.Second * 3,
74 }
75
76 var wg sync.WaitGroup
77
78 for i := 0; i < 10; i++ {
79 wg.Add(1)
80 go func(i int) {
81 defer wg.Done()
82 _, err := client.Get(server.URL)
83 assertError(t, err)
84 }(i)
85 }
86
87 wg.Wait()
88 }
89
90 func TestThrottleTriggerGatewayTimeout(t *testing.T) {
91 r := chi.NewRouter()
92
93 r.Use(ThrottleBacklog(50, 100, time.Second*5))
94
95 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
96 w.WriteHeader(http.StatusOK)
97 time.Sleep(time.Second * 10)
98 w.Write(testContent)
99 })
100
101 server := httptest.NewServer(r)
102 defer server.Close()
103
104 client := http.Client{
105 Timeout: time.Second * 60,
106 }
107
108 var wg sync.WaitGroup
109
110
111 for i := 0; i < 50; i++ {
112 wg.Add(1)
113 go func(i int) {
114 defer wg.Done()
115
116 res, err := client.Get(server.URL)
117 assertNoError(t, err)
118 assertEqual(t, http.StatusOK, res.StatusCode)
119
120 }(i)
121 }
122
123 time.Sleep(time.Second * 1)
124
125
126
127 for i := 0; i < 50; i++ {
128 wg.Add(1)
129 go func(i int) {
130 defer wg.Done()
131
132 res, err := client.Get(server.URL)
133 assertNoError(t, err)
134
135 buf, err := ioutil.ReadAll(res.Body)
136 assertNoError(t, err)
137 assertEqual(t, http.StatusServiceUnavailable, res.StatusCode)
138 assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))
139
140 }(i)
141 }
142
143 wg.Wait()
144 }
145
146 func TestThrottleMaximum(t *testing.T) {
147 r := chi.NewRouter()
148
149 r.Use(ThrottleBacklog(50, 50, time.Second*5))
150
151 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
152 w.WriteHeader(http.StatusOK)
153 time.Sleep(time.Second * 2)
154 w.Write(testContent)
155 })
156
157 server := httptest.NewServer(r)
158 defer server.Close()
159
160 client := http.Client{
161 Timeout: time.Second * 60,
162 }
163
164 var wg sync.WaitGroup
165
166 for i := 0; i < 100; i++ {
167 wg.Add(1)
168 go func(i int) {
169 defer wg.Done()
170
171 res, err := client.Get(server.URL)
172 assertNoError(t, err)
173 assertEqual(t, http.StatusOK, res.StatusCode)
174
175 buf, err := ioutil.ReadAll(res.Body)
176 assertNoError(t, err)
177 assertEqual(t, testContent, buf)
178
179 }(i)
180 }
181
182
183 time.Sleep(time.Second * 1)
184
185
186
187 for i := 0; i < 100; i++ {
188 wg.Add(1)
189 go func(i int) {
190 defer wg.Done()
191
192 res, err := client.Get(server.URL)
193 assertNoError(t, err)
194
195 buf, err := ioutil.ReadAll(res.Body)
196 assertNoError(t, err)
197 assertEqual(t, http.StatusServiceUnavailable, res.StatusCode)
198 assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))
199
200 }(i)
201 }
202
203 wg.Wait()
204 }
205
206 func TestThrottleRetryAfter(t *testing.T) {
207 r := chi.NewRouter()
208
209 retryAfterFn := func(ctxDone bool) time.Duration { return time.Hour * 1 }
210 r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 10, RetryAfterFn: retryAfterFn}))
211
212 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
213 w.WriteHeader(http.StatusOK)
214 time.Sleep(time.Second * 3)
215 w.Write(testContent)
216 })
217
218 server := httptest.NewServer(r)
219 defer server.Close()
220
221 client := http.Client{
222 Timeout: time.Second * 60,
223 }
224
225 var wg sync.WaitGroup
226
227 for i := 0; i < 10; i++ {
228 wg.Add(1)
229 go func(i int) {
230 defer wg.Done()
231
232 res, err := client.Get(server.URL)
233 assertNoError(t, err)
234 assertEqual(t, http.StatusOK, res.StatusCode)
235 }(i)
236 }
237
238 time.Sleep(time.Second * 1)
239
240 for i := 0; i < 10; i++ {
241 wg.Add(1)
242 go func(i int) {
243 defer wg.Done()
244
245 res, err := client.Get(server.URL)
246 assertNoError(t, err)
247 assertEqual(t, http.StatusServiceUnavailable, res.StatusCode)
248 assertEqual(t, res.Header.Get("Retry-After"), "3600")
249 }(i)
250 }
251
252 wg.Wait()
253 }
254
View as plain text