1 package ratelimits
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "time"
8
9 "github.com/jmhodges/clock"
10 "github.com/prometheus/client_golang/prometheus"
11 )
12
// Values used for the "decision" label of the spend latency metric.
const (
	// Allowed is recorded when a spend was permitted and stored.
	Allowed = "allowed"

	// Denied is recorded when a spend was not recorded as allowed; this
	// includes spends that were refused as well as spends that failed with an
	// error after the latency timer was started.
	Denied = "denied"
)
22
23
// Sentinel errors returned by the Limiter. They carry no formatted arguments,
// so they are built with errors.New; compare against them with errors.Is.
var (
	// ErrInvalidCost indicates that the cost passed to Spend or Refund was
	// <= 0.
	ErrInvalidCost = errors.New("invalid cost, must be > 0")

	// ErrInvalidCostForCheck indicates that the cost passed to Check was < 0.
	ErrInvalidCostForCheck = errors.New("invalid check cost, must be >= 0")

	// ErrInvalidCostOverLimit indicates that the cost passed to Check or Spend
	// exceeded the Burst of the applicable limit.
	ErrInvalidCostOverLimit = errors.New("invalid cost, must be <= limit.Burst")

	// errLimitDisabled indicates that the limit name is valid but that no
	// default limit is configured for it.
	errLimitDisabled = errors.New("limit disabled")
)
35
36
37
38 var disabledLimitDecision = &Decision{true, 0, 0, 0, time.Time{}}
39
40
41
42 type Limiter struct {
43
44 defaults limits
45
46
47 overrides limits
48
49
50 source source
51 clk clock.Clock
52
53 spendLatency *prometheus.HistogramVec
54 overrideUsageGauge *prometheus.GaugeVec
55 }
56
57
58
59
60
61 func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stats prometheus.Registerer) (*Limiter, error) {
62 limiter := &Limiter{source: source, clk: clk}
63
64 var err error
65 limiter.defaults, err = loadAndParseDefaultLimits(defaults)
66 if err != nil {
67 return nil, err
68 }
69
70 limiter.spendLatency = prometheus.NewHistogramVec(prometheus.HistogramOpts{
71 Name: "ratelimits_spend_latency",
72 Help: fmt.Sprintf("Latency of ratelimit checks labeled by limit=[name] and decision=[%s|%s], in seconds", Allowed, Denied),
73
74 Buckets: prometheus.ExponentialBuckets(0.0005, 3, 8),
75 }, []string{"limit", "decision"})
76 stats.MustRegister(limiter.spendLatency)
77
78 if overrides == "" {
79
80 limiter.overrides = make(limits)
81 return limiter, nil
82 }
83
84 limiter.overrides, err = loadAndParseOverrideLimits(overrides)
85 if err != nil {
86 return nil, err
87 }
88
89 limiter.overrideUsageGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
90 Name: "ratelimits_override_usage",
91 Help: "Proportion of override limit used, by limit name and client id.",
92 }, []string{"limit", "client_id"})
93 stats.MustRegister(limiter.overrideUsageGauge)
94
95 return limiter, nil
96 }
97
// Decision is the outcome of a Check, Spend, or Refund against a bucket.
//
// Field order is significant: the package builds at least one Decision with a
// positional (unkeyed) composite literal.
type Decision struct {
	// Allowed reports whether the requested operation was permitted.
	Allowed bool

	// Remaining is the capacity left in the bucket. Spend derives override
	// utilization from it as (Burst - Remaining) / Burst.
	Remaining int64

	// RetryIn is presumably how long the caller should wait before the same
	// request could succeed; it is populated by helpers outside this block.
	// TODO confirm.
	RetryIn time.Duration

	// ResetIn is presumably how long until the bucket returns to full
	// capacity; it is populated by helpers outside this block. TODO confirm.
	ResetIn time.Duration

	// newTAT is the bucket timestamp that results from the operation. The
	// Limiter writes it to the source after an allowed Spend or Refund and
	// when initializing a bucket. It is unexported, so only this package can
	// read or set it.
	newTAT time.Time
}
120
121
122
123
124
125
126
127
128
129
130
131
132 func (l *Limiter) Check(ctx context.Context, name Name, id string, cost int64) (*Decision, error) {
133 if cost < 0 {
134 return nil, ErrInvalidCostForCheck
135 }
136
137 limit, err := l.getLimit(name, id)
138 if err != nil {
139 if errors.Is(err, errLimitDisabled) {
140 return disabledLimitDecision, nil
141 }
142 return nil, err
143 }
144
145 if cost > limit.Burst {
146 return nil, ErrInvalidCostOverLimit
147 }
148
149
150
151 ctx = context.WithoutCancel(ctx)
152 tat, err := l.source.Get(ctx, bucketKey(name, id))
153 if err != nil {
154 if !errors.Is(err, ErrBucketNotFound) {
155 return nil, err
156 }
157
158
159 d, err := l.initialize(ctx, limit, name, id, 0)
160 if err != nil {
161 return nil, err
162 }
163 return maybeSpend(l.clk, limit, d.newTAT, cost), nil
164 }
165 return maybeSpend(l.clk, limit, tat, cost), nil
166 }
167
168
169
170
171
172
173
174
175
176
177
178 func (l *Limiter) Spend(ctx context.Context, name Name, id string, cost int64) (*Decision, error) {
179 if cost <= 0 {
180 return nil, ErrInvalidCost
181 }
182
183 limit, err := l.getLimit(name, id)
184 if err != nil {
185 if errors.Is(err, errLimitDisabled) {
186 return disabledLimitDecision, nil
187 }
188 return nil, err
189 }
190
191 if cost > limit.Burst {
192 return nil, ErrInvalidCostOverLimit
193 }
194
195 start := l.clk.Now()
196 status := Denied
197 defer func() {
198 l.spendLatency.WithLabelValues(name.String(), status).Observe(l.clk.Since(start).Seconds())
199 }()
200
201
202
203 ctx = context.WithoutCancel(ctx)
204 tat, err := l.source.Get(ctx, bucketKey(name, id))
205 if err != nil {
206 if errors.Is(err, ErrBucketNotFound) {
207
208 d, err := l.initialize(ctx, limit, name, id, cost)
209 if err != nil {
210 return nil, err
211 }
212 if d.Allowed {
213 status = Allowed
214 }
215 return d, nil
216 }
217 return nil, err
218 }
219
220 d := maybeSpend(l.clk, limit, tat, cost)
221
222 if limit.isOverride {
223
224
225 utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst)
226 l.overrideUsageGauge.WithLabelValues(name.String(), id).Set(utilization)
227 }
228
229 if !d.Allowed {
230 return d, nil
231 }
232
233 err = l.source.Set(ctx, bucketKey(name, id), d.newTAT)
234 if err != nil {
235 return nil, err
236 }
237 status = Allowed
238 return d, nil
239 }
240
241
242
243
244
245
246
247
248
249
250
251
252 func (l *Limiter) Refund(ctx context.Context, name Name, id string, cost int64) (*Decision, error) {
253 if cost <= 0 {
254 return nil, ErrInvalidCost
255 }
256
257 limit, err := l.getLimit(name, id)
258 if err != nil {
259 if errors.Is(err, errLimitDisabled) {
260 return disabledLimitDecision, nil
261 }
262 return nil, err
263 }
264
265
266
267 ctx = context.WithoutCancel(ctx)
268 tat, err := l.source.Get(ctx, bucketKey(name, id))
269 if err != nil {
270 return nil, err
271 }
272 d := maybeRefund(l.clk, limit, tat, cost)
273 if !d.Allowed {
274
275 return d, nil
276 }
277 return d, l.source.Set(ctx, bucketKey(name, id), d.newTAT)
278
279 }
280
281
282 func (l *Limiter) Reset(ctx context.Context, name Name, id string) error {
283
284
285 ctx = context.WithoutCancel(ctx)
286 return l.source.Delete(ctx, bucketKey(name, id))
287 }
288
289
290
291 func (l *Limiter) initialize(ctx context.Context, rl limit, name Name, id string, cost int64) (*Decision, error) {
292 d := maybeSpend(l.clk, rl, l.clk.Now(), cost)
293
294
295
296 ctx = context.WithoutCancel(ctx)
297 err := l.source.Set(ctx, bucketKey(name, id), d.newTAT)
298 if err != nil {
299 return nil, err
300 }
301 return d, nil
302
303 }
304
305
306
307
308
309 func (l *Limiter) getLimit(name Name, id string) (limit, error) {
310 if !name.isValid() {
311
312
313 return limit{}, fmt.Errorf("specified name enum %q, is invalid", name)
314 }
315 if id != "" {
316
317 ol, ok := l.overrides[bucketKey(name, id)]
318 if ok {
319 return ol, nil
320 }
321 }
322 dl, ok := l.defaults[nameToEnumString(name)]
323 if ok {
324 return dl, nil
325 }
326 return limit{}, errLimitDisabled
327 }
328
View as plain text