1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package retry
19
20 import (
21 "context"
22 "errors"
23 "math"
24 "sync"
25 "testing"
26 "time"
27
28 "github.com/cenkalti/backoff/v4"
29 "github.com/stretchr/testify/assert"
30 )
31
32 func TestWait(t *testing.T) {
33 tests := []struct {
34 ctx context.Context
35 delay time.Duration
36 expected error
37 }{
38 {
39 ctx: context.Background(),
40 delay: time.Duration(0),
41 },
42 {
43 ctx: context.Background(),
44 delay: time.Duration(1),
45 },
46 {
47 ctx: context.Background(),
48 delay: time.Duration(-1),
49 },
50 {
51 ctx: func() context.Context {
52 ctx, cancel := context.WithCancel(context.Background())
53 cancel()
54 return ctx
55 }(),
56
57 delay: 1 * time.Hour,
58 expected: context.Canceled,
59 },
60 }
61
62 for _, test := range tests {
63 err := wait(test.ctx, test.delay)
64 if test.expected == nil {
65 assert.NoError(t, err)
66 } else {
67 assert.ErrorIs(t, err, test.expected)
68 }
69 }
70 }
71
72 func TestNonRetryableError(t *testing.T) {
73 ev := func(error) (bool, time.Duration) { return false, 0 }
74
75 reqFunc := Config{
76 Enabled: true,
77 InitialInterval: 1 * time.Nanosecond,
78 MaxInterval: 1 * time.Nanosecond,
79
80 MaxElapsedTime: 0,
81 }.RequestFunc(ev)
82 ctx := context.Background()
83 assert.NoError(t, reqFunc(ctx, func(context.Context) error {
84 return nil
85 }))
86 assert.ErrorIs(t, reqFunc(ctx, func(context.Context) error {
87 return assert.AnError
88 }), assert.AnError)
89 }
90
91 func TestThrottledRetry(t *testing.T) {
92
93 throttleDelay, backoffDelay := time.Second, time.Nanosecond
94
95 ev := func(error) (bool, time.Duration) {
96
97 return true, throttleDelay
98 }
99
100 reqFunc := Config{
101 Enabled: true,
102 InitialInterval: backoffDelay,
103 MaxInterval: backoffDelay,
104
105 MaxElapsedTime: 0,
106 }.RequestFunc(ev)
107
108 origWait := waitFunc
109 var done bool
110 waitFunc = func(_ context.Context, delay time.Duration) error {
111 assert.Equal(t, throttleDelay, delay, "retry not throttled")
112
113 if done {
114 return assert.AnError
115 }
116 done = true
117 return nil
118 }
119 defer func() { waitFunc = origWait }()
120
121 ctx := context.Background()
122 assert.ErrorIs(t, reqFunc(ctx, func(context.Context) error {
123 return errors.New("not this error")
124 }), assert.AnError)
125 }
126
127 func TestBackoffRetry(t *testing.T) {
128 ev := func(error) (bool, time.Duration) { return true, 0 }
129
130 delay := time.Nanosecond
131 reqFunc := Config{
132 Enabled: true,
133 InitialInterval: delay,
134 MaxInterval: delay,
135
136 MaxElapsedTime: 0,
137 }.RequestFunc(ev)
138
139 origWait := waitFunc
140 var done bool
141 waitFunc = func(_ context.Context, d time.Duration) error {
142 delta := math.Ceil(float64(delay) * backoff.DefaultRandomizationFactor)
143 assert.InDelta(t, delay, d, delta, "retry not backoffed")
144
145 if done {
146 return assert.AnError
147 }
148 done = true
149 return nil
150 }
151 t.Cleanup(func() { waitFunc = origWait })
152
153 ctx := context.Background()
154 assert.ErrorIs(t, reqFunc(ctx, func(context.Context) error {
155 return errors.New("not this error")
156 }), assert.AnError)
157 }
158
159 func TestBackoffRetryCanceledContext(t *testing.T) {
160 ev := func(error) (bool, time.Duration) { return true, 0 }
161
162 delay := time.Millisecond
163 reqFunc := Config{
164 Enabled: true,
165 InitialInterval: delay,
166 MaxInterval: delay,
167
168 MaxElapsedTime: 10 * time.Millisecond,
169 }.RequestFunc(ev)
170
171 ctx, cancel := context.WithCancel(context.Background())
172 count := 0
173 cancel()
174 err := reqFunc(ctx, func(context.Context) error {
175 count++
176 return assert.AnError
177 })
178
179 assert.ErrorIs(t, err, context.Canceled)
180 assert.Contains(t, err.Error(), assert.AnError.Error())
181 assert.Equal(t, 1, count)
182 }
183
184 func TestThrottledRetryGreaterThanMaxElapsedTime(t *testing.T) {
185
186 tDelay, bDelay := time.Hour, time.Nanosecond
187 ev := func(error) (bool, time.Duration) { return true, tDelay }
188 reqFunc := Config{
189 Enabled: true,
190 InitialInterval: bDelay,
191 MaxInterval: bDelay,
192 MaxElapsedTime: tDelay - (time.Nanosecond),
193 }.RequestFunc(ev)
194
195 ctx := context.Background()
196 assert.Contains(t, reqFunc(ctx, func(context.Context) error {
197 return assert.AnError
198 }).Error(), "max retry time would elapse: ")
199 }
200
201 func TestMaxElapsedTime(t *testing.T) {
202 ev := func(error) (bool, time.Duration) { return true, 0 }
203 delay := time.Nanosecond
204 reqFunc := Config{
205 Enabled: true,
206
207 InitialInterval: 2 * delay,
208 MaxElapsedTime: delay,
209 }.RequestFunc(ev)
210
211 ctx := context.Background()
212 assert.Contains(t, reqFunc(ctx, func(context.Context) error {
213 return assert.AnError
214 }).Error(), "max retry time elapsed: ")
215 }
216
217 func TestRetryNotEnabled(t *testing.T) {
218 ev := func(error) (bool, time.Duration) {
219 t.Error("evaluated retry when not enabled")
220 return false, 0
221 }
222
223 reqFunc := Config{}.RequestFunc(ev)
224 ctx := context.Background()
225 assert.NoError(t, reqFunc(ctx, func(context.Context) error {
226 return nil
227 }))
228 assert.ErrorIs(t, reqFunc(ctx, func(context.Context) error {
229 return assert.AnError
230 }), assert.AnError)
231 }
232
233 func TestRetryConcurrentSafe(t *testing.T) {
234 ev := func(error) (bool, time.Duration) { return true, 0 }
235 reqFunc := Config{
236 Enabled: true,
237 }.RequestFunc(ev)
238
239 var wg sync.WaitGroup
240 ctx := context.Background()
241
242 for i := 1; i < 5; i++ {
243 wg.Add(1)
244
245 go func() {
246 defer wg.Done()
247
248 var done bool
249 assert.NoError(t, reqFunc(ctx, func(context.Context) error {
250 if !done {
251 done = true
252 return assert.AnError
253 }
254
255 return nil
256 }))
257 }()
258 }
259
260 wg.Wait()
261 }
262
View as plain text