1 package ratelimits
2
3 import (
4 "context"
5 "math/rand"
6 "net"
7 "testing"
8 "time"
9
10 "github.com/jmhodges/clock"
11 "github.com/letsencrypt/boulder/metrics"
12 "github.com/letsencrypt/boulder/test"
13 "github.com/prometheus/client_golang/prometheus"
14 )
15
16
17
// tenZeroZeroTwo is the fixed client IP used by the override tests. Unlike the
// random per-test IP from setup, this address is stable, presumably because the
// override fixture keys a custom limit on it (TODO confirm against
// testdata/working_override.yml).
const (
	tenZeroZeroTwo = "10.0.0.2"
)
19
20
21
22
23 func newTestLimiter(t *testing.T, s source, clk clock.FakeClock) *Limiter {
24 l, err := NewLimiter(clk, s, "testdata/working_default.yml", "testdata/working_override.yml", metrics.NoopRegisterer)
25 test.AssertNotError(t, err, "should not error")
26 return l
27 }
28
29 func setup(t *testing.T) (context.Context, map[string]*Limiter, clock.FakeClock, string) {
30 testCtx := context.Background()
31 clk := clock.NewFake()
32
33
34
35 randIP := make(net.IP, 4)
36 for i := 0; i < 4; i++ {
37 randIP[i] = byte(rand.Intn(256))
38 }
39
40
41 return testCtx, map[string]*Limiter{
42 "inmem": newInmemTestLimiter(t, clk),
43 "redis": newRedisTestLimiter(t, clk),
44 }, clk, randIP.String()
45 }
46
47 func Test_Limiter_WithBadLimitsPath(t *testing.T) {
48 t.Parallel()
49 _, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/does-not-exist.yml", "", metrics.NoopRegisterer)
50 test.AssertError(t, err, "should error")
51
52 _, err = NewLimiter(clock.NewFake(), newInmem(), "testdata/defaults.yml", "testdata/does-not-exist.yml", metrics.NoopRegisterer)
53 test.AssertError(t, err, "should error")
54 }
55
56 func Test_Limiter_getLimitNoExist(t *testing.T) {
57 t.Parallel()
58 l, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/working_default.yml", "", metrics.NoopRegisterer)
59 test.AssertNotError(t, err, "should not error")
60 _, err = l.getLimit(Name(9999), "")
61 test.AssertError(t, err, "should error")
62
63 }
64
65 func Test_Limiter_CheckWithLimitNoExist(t *testing.T) {
66 t.Parallel()
67 testCtx, limiters, _, testIP := setup(t)
68 for name, l := range limiters {
69 t.Run(name, func(t *testing.T) {
70 _, err := l.Check(testCtx, Name(9999), testIP, 1)
71 test.AssertError(t, err, "should error")
72 })
73 }
74 }
75
76 func Test_Limiter_CheckWithLimitOverrides(t *testing.T) {
77 t.Parallel()
78 testCtx, limiters, clk, _ := setup(t)
79 for name, l := range limiters {
80 t.Run(name, func(t *testing.T) {
81
82
83 test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{
84 "limit": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 0)
85
86
87
88 _, err := l.Check(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
89 test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
90
91
92
93 _, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
94 test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
95
96
97 d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40)
98 test.AssertNotError(t, err, "should not error")
99 test.Assert(t, d.Allowed, "should be allowed")
100
101
102 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
103 test.AssertNotError(t, err, "should not error")
104 test.Assert(t, !d.Allowed, "should not be allowed")
105 test.AssertEquals(t, d.Remaining, int64(0))
106 test.AssertEquals(t, d.ResetIn, time.Second)
107
108
109
110 test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{
111 "limit_name": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 1.0)
112
113
114
115 test.AssertEquals(t, d.RetryIn, time.Millisecond*25)
116
117
118 clk.Add(d.RetryIn)
119
120
121 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
122 test.AssertNotError(t, err, "should not error")
123 test.Assert(t, d.Allowed, "should be allowed")
124 test.AssertEquals(t, d.Remaining, int64(0))
125 test.AssertEquals(t, d.ResetIn, time.Second)
126
127
128 clk.Add(d.ResetIn)
129
130
131 for i := 0; i < 40; i++ {
132 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
133 test.AssertNotError(t, err, "should not error")
134 test.Assert(t, d.Allowed, "should be allowed")
135 test.AssertEquals(t, d.Remaining, int64(39-i))
136 }
137
138
139 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
140 test.AssertNotError(t, err, "should not error")
141 test.Assert(t, !d.Allowed, "should not be allowed")
142 test.AssertEquals(t, d.Remaining, int64(0))
143 test.AssertEquals(t, d.ResetIn, time.Second)
144
145
146 err = l.Reset(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo)
147 test.AssertNotError(t, err, "should not error")
148 })
149 }
150 }
151
152 func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) {
153 t.Parallel()
154 testCtx, limiters, _, testIP := setup(t)
155 for name, l := range limiters {
156 t.Run(name, func(t *testing.T) {
157
158
159 d, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
160 test.AssertNotError(t, err, "should not error")
161 test.Assert(t, d.Allowed, "should be allowed")
162 test.AssertEquals(t, d.Remaining, int64(19))
163
164
165 test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
166 test.AssertEquals(t, d.RetryIn, time.Duration(0))
167
168
169
170 d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
171 test.AssertNotError(t, err, "should not error")
172 test.Assert(t, d.Allowed, "should be allowed")
173 test.AssertEquals(t, d.Remaining, int64(20))
174 test.AssertEquals(t, d.ResetIn, time.Duration(0))
175 test.AssertEquals(t, d.RetryIn, time.Duration(0))
176
177
178 err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP)
179 test.AssertNotError(t, err, "should not error")
180
181
182
183
184 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
185 test.AssertNotError(t, err, "should not error")
186 test.Assert(t, d.Allowed, "should be allowed")
187 test.AssertEquals(t, d.Remaining, int64(19))
188
189
190 test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
191 test.AssertEquals(t, d.RetryIn, time.Duration(0))
192
193
194
195 d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
196 test.AssertNotError(t, err, "should not error")
197 test.Assert(t, d.Allowed, "should be allowed")
198 test.AssertEquals(t, d.Remaining, int64(19))
199
200
201 test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
202 test.AssertEquals(t, d.RetryIn, time.Duration(0))
203 })
204 }
205 }
206
207 func Test_Limiter_RefundAndSpendCostErr(t *testing.T) {
208 t.Parallel()
209 testCtx, limiters, _, testIP := setup(t)
210 for name, l := range limiters {
211 t.Run(name, func(t *testing.T) {
212
213 _, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
214 test.AssertErrorIs(t, err, ErrInvalidCost)
215
216
217 _, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, -1)
218 test.AssertErrorIs(t, err, ErrInvalidCost)
219
220
221 _, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
222 test.AssertErrorIs(t, err, ErrInvalidCost)
223
224
225 _, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, -1)
226 test.AssertErrorIs(t, err, ErrInvalidCost)
227 })
228 }
229 }
230
231 func Test_Limiter_CheckWithBadCost(t *testing.T) {
232 t.Parallel()
233 testCtx, limiters, _, testIP := setup(t)
234 for name, l := range limiters {
235 t.Run(name, func(t *testing.T) {
236 _, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, -1)
237 test.AssertErrorIs(t, err, ErrInvalidCostForCheck)
238 })
239 }
240 }
241
242 func Test_Limiter_DefaultLimits(t *testing.T) {
243 t.Parallel()
244 testCtx, limiters, clk, testIP := setup(t)
245 for name, l := range limiters {
246 t.Run(name, func(t *testing.T) {
247
248
249 _, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 21)
250 test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
251
252
253 d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20)
254 test.AssertNotError(t, err, "should not error")
255 test.Assert(t, d.Allowed, "should be allowed")
256 test.AssertEquals(t, d.Remaining, int64(0))
257 test.AssertEquals(t, d.ResetIn, time.Second)
258
259
260 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
261 test.AssertNotError(t, err, "should not error")
262 test.Assert(t, !d.Allowed, "should not be allowed")
263 test.AssertEquals(t, d.Remaining, int64(0))
264 test.AssertEquals(t, d.ResetIn, time.Second)
265
266
267
268 test.AssertEquals(t, d.RetryIn, time.Millisecond*50)
269
270
271 clk.Add(d.RetryIn)
272
273
274 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
275 test.AssertNotError(t, err, "should not error")
276 test.Assert(t, d.Allowed, "should be allowed")
277 test.AssertEquals(t, d.Remaining, int64(0))
278 test.AssertEquals(t, d.ResetIn, time.Second)
279
280
281 clk.Add(d.ResetIn)
282
283
284 for i := 0; i < 20; i++ {
285 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
286 test.AssertNotError(t, err, "should not error")
287 test.Assert(t, d.Allowed, "should be allowed")
288 test.AssertEquals(t, d.Remaining, int64(19-i))
289 }
290
291
292 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
293 test.AssertNotError(t, err, "should not error")
294 test.Assert(t, !d.Allowed, "should not be allowed")
295 test.AssertEquals(t, d.Remaining, int64(0))
296 test.AssertEquals(t, d.ResetIn, time.Second)
297 })
298 }
299 }
300
301 func Test_Limiter_RefundAndReset(t *testing.T) {
302 t.Parallel()
303 testCtx, limiters, clk, testIP := setup(t)
304 for name, l := range limiters {
305 t.Run(name, func(t *testing.T) {
306
307 d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20)
308 test.AssertNotError(t, err, "should not error")
309 test.Assert(t, d.Allowed, "should be allowed")
310 test.AssertEquals(t, d.Remaining, int64(0))
311 test.AssertEquals(t, d.ResetIn, time.Second)
312
313
314 d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 10)
315 test.AssertNotError(t, err, "should not error")
316 test.AssertEquals(t, d.Remaining, int64(10))
317
318
319 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 10)
320 test.AssertNotError(t, err, "should not error")
321 test.Assert(t, d.Allowed, "should be allowed")
322 test.AssertEquals(t, d.Remaining, int64(0))
323 test.AssertEquals(t, d.ResetIn, time.Second)
324
325 err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP)
326 test.AssertNotError(t, err, "should not error")
327
328
329 d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20)
330 test.AssertNotError(t, err, "should not error")
331 test.Assert(t, d.Allowed, "should be allowed")
332 test.AssertEquals(t, d.Remaining, int64(0))
333 test.AssertEquals(t, d.ResetIn, time.Second)
334
335
336 clk.Add(d.ResetIn)
337
338
339 d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
340 test.AssertNotError(t, err, "should not error")
341 test.Assert(t, !d.Allowed, "should not be allowed")
342 test.AssertEquals(t, d.Remaining, int64(20))
343 })
344 }
345 }
346
View as plain text