1 package ratelimits
2
3 import (
4 "context"
5 "errors"
6 "net"
7 "time"
8
9 "github.com/jmhodges/clock"
10 "github.com/prometheus/client_golang/prometheus"
11 "github.com/redis/go-redis/v9"
12 )
13
14
15 var _ source = (*RedisSource)(nil)
16
17
18 type RedisSource struct {
19 client *redis.Ring
20 clk clock.Clock
21 latency *prometheus.HistogramVec
22 }
23
24
25
26 func NewRedisSource(client *redis.Ring, clk clock.Clock, stats prometheus.Registerer) *RedisSource {
27 latency := prometheus.NewHistogramVec(
28 prometheus.HistogramOpts{
29 Name: "ratelimits_latency",
30 Help: "Histogram of Redis call latencies labeled by call=[set|get|delete|ping] and result=[success|error]",
31
32 Buckets: prometheus.ExponentialBucketsRange(0.0005, 3, 8),
33 },
34 []string{"call", "result"},
35 )
36 stats.MustRegister(latency)
37
38 return &RedisSource{
39 client: client,
40 clk: clk,
41 latency: latency,
42 }
43 }
44
45
46
47 func resultForError(err error) string {
48 if errors.Is(redis.Nil, err) {
49
50 return "notFound"
51 } else if errors.Is(err, context.DeadlineExceeded) {
52
53 return "deadlineExceeded"
54 } else if errors.Is(err, context.Canceled) {
55
56 return "canceled"
57 }
58 var netErr net.Error
59 if errors.As(err, &netErr) && netErr.Timeout() {
60
61 return "timeout"
62 }
63 var redisErr redis.Error
64 if errors.Is(err, redisErr) {
65
66 return "redisError"
67 }
68 return "failed"
69 }
70
71
72
73
74 func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) error {
75 start := r.clk.Now()
76
77 err := r.client.Set(ctx, bucketKey, tat.UnixNano(), 0).Err()
78 if err != nil {
79 r.latency.With(prometheus.Labels{"call": "set", "result": resultForError(err)}).Observe(time.Since(start).Seconds())
80 return err
81 }
82
83 r.latency.With(prometheus.Labels{"call": "set", "result": "success"}).Observe(time.Since(start).Seconds())
84 return nil
85 }
86
87
88
89
90 func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) {
91 start := r.clk.Now()
92
93 tatNano, err := r.client.Get(ctx, bucketKey).Int64()
94 if err != nil {
95 if errors.Is(err, redis.Nil) {
96
97 r.latency.With(prometheus.Labels{"call": "get", "result": "notFound"}).Observe(time.Since(start).Seconds())
98 return time.Time{}, ErrBucketNotFound
99 }
100 r.latency.With(prometheus.Labels{"call": "get", "result": resultForError(err)}).Observe(time.Since(start).Seconds())
101 return time.Time{}, err
102 }
103
104 r.latency.With(prometheus.Labels{"call": "get", "result": "success"}).Observe(time.Since(start).Seconds())
105 return time.Unix(0, tatNano).UTC(), nil
106 }
107
108
109
110
111 func (r *RedisSource) Delete(ctx context.Context, bucketKey string) error {
112 start := r.clk.Now()
113
114 err := r.client.Del(ctx, bucketKey).Err()
115 if err != nil {
116 r.latency.With(prometheus.Labels{"call": "delete", "result": resultForError(err)}).Observe(time.Since(start).Seconds())
117 return err
118 }
119
120 r.latency.With(prometheus.Labels{"call": "delete", "result": "success"}).Observe(time.Since(start).Seconds())
121 return nil
122 }
123
124
125
126 func (r *RedisSource) Ping(ctx context.Context) error {
127 start := r.clk.Now()
128
129 err := r.client.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
130 return shard.Ping(ctx).Err()
131 })
132 if err != nil {
133 r.latency.With(prometheus.Labels{"call": "ping", "result": resultForError(err)}).Observe(time.Since(start).Seconds())
134 return err
135 }
136 r.latency.With(prometheus.Labels{"call": "ping", "result": "success"}).Observe(time.Since(start).Seconds())
137 return nil
138 }
139
View as plain text