1 package backoff
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "log"
9 "testing"
10 "time"
11 )
12
// testTimer is a test double for the timer handed to RetryNotifyWithTimer and
// RetryNotifyWithTimerAndData. Whatever duration it is asked to wait, it fires
// right away, so the retry loops under test run without real sleeping.
type testTimer struct {
	timer *time.Timer
}

// Start arms a brand-new timer. The requested duration is deliberately
// ignored; a zero-length timer is created so its channel fires immediately.
func (tt *testTimer) Start(duration time.Duration) {
	tt.timer = time.NewTimer(0)
}

// Stop halts the most recently started timer. It is a no-op when Start has
// never been called.
func (tt *testTimer) Stop() {
	if tt.timer == nil {
		return
	}
	tt.timer.Stop()
}

// C returns the channel of the timer created by the latest Start call.
// It must only be called after Start: before that the timer is nil.
func (tt *testTimer) C() <-chan time.Time {
	return tt.timer.C
}
30
31 func TestRetry(t *testing.T) {
32 const successOn = 3
33 var i = 0
34
35
36 f := func() error {
37 i++
38 log.Printf("function is called %d. time\n", i)
39
40 if i == successOn {
41 log.Println("OK")
42 return nil
43 }
44
45 log.Println("error")
46 return errors.New("error")
47 }
48
49 err := RetryNotifyWithTimer(f, NewExponentialBackOff(), nil, &testTimer{})
50 if err != nil {
51 t.Errorf("unexpected error: %s", err.Error())
52 }
53 if i != successOn {
54 t.Errorf("invalid number of retries: %d", i)
55 }
56 }
57
58 func TestRetryWithData(t *testing.T) {
59 const successOn = 3
60 var i = 0
61
62
63 f := func() (int, error) {
64 i++
65 log.Printf("function is called %d. time\n", i)
66
67 if i == successOn {
68 log.Println("OK")
69 return 42, nil
70 }
71
72 log.Println("error")
73 return 1, errors.New("error")
74 }
75
76 res, err := RetryNotifyWithTimerAndData(f, NewExponentialBackOff(), nil, &testTimer{})
77 if err != nil {
78 t.Errorf("unexpected error: %s", err.Error())
79 }
80 if i != successOn {
81 t.Errorf("invalid number of retries: %d", i)
82 }
83 if res != 42 {
84 t.Errorf("invalid data in response: %d, expected 42", res)
85 }
86 }
87
88 func TestRetryContext(t *testing.T) {
89 var cancelOn = 3
90 var i = 0
91
92 ctx, cancel := context.WithCancel(context.Background())
93 defer cancel()
94
95
96 f := func() error {
97 i++
98 log.Printf("function is called %d. time\n", i)
99
100
101
102 if i == cancelOn {
103 cancel()
104 }
105
106 log.Println("error")
107 return fmt.Errorf("error (%d)", i)
108 }
109
110 err := RetryNotifyWithTimer(f, WithContext(NewConstantBackOff(time.Millisecond), ctx), nil, &testTimer{})
111 if err == nil {
112 t.Errorf("error is unexpectedly nil")
113 }
114 if !errors.Is(err, context.Canceled) {
115 t.Errorf("unexpected error: %s", err.Error())
116 }
117 if i != cancelOn {
118 t.Errorf("invalid number of retries: %d", i)
119 }
120 }
121
122 func TestRetryPermanent(t *testing.T) {
123 ensureRetries := func(test string, shouldRetry bool, f func() (int, error), expectRes int) {
124 numRetries := -1
125 maxRetries := 1
126
127 res, _ := RetryNotifyWithTimerAndData(
128 func() (int, error) {
129 numRetries++
130 if numRetries >= maxRetries {
131 return -1, Permanent(errors.New("forced"))
132 }
133 return f()
134 },
135 NewExponentialBackOff(),
136 nil,
137 &testTimer{},
138 )
139
140 if shouldRetry && numRetries == 0 {
141 t.Errorf("Test: '%s', backoff should have retried", test)
142 }
143
144 if !shouldRetry && numRetries > 0 {
145 t.Errorf("Test: '%s', backoff should not have retried", test)
146 }
147
148 if res != expectRes {
149 t.Errorf("Test: '%s', got res %d but expected %d", test, res, expectRes)
150 }
151 }
152
153 for _, testCase := range []struct {
154 name string
155 f func() (int, error)
156 shouldRetry bool
157 res int
158 }{
159 {
160 "nil test",
161 func() (int, error) {
162 return 1, nil
163 },
164 false,
165 1,
166 },
167 {
168 "io.EOF",
169 func() (int, error) {
170 return 2, io.EOF
171 },
172 true,
173 -1,
174 },
175 {
176 "Permanent(io.EOF)",
177 func() (int, error) {
178 return 3, Permanent(io.EOF)
179 },
180 false,
181 3,
182 },
183 {
184 "Wrapped: Permanent(io.EOF)",
185 func() (int, error) {
186 return 4, fmt.Errorf("Wrapped error: %w", Permanent(io.EOF))
187 },
188 false,
189 4,
190 },
191 } {
192 ensureRetries(testCase.name, testCase.shouldRetry, testCase.f, testCase.res)
193 }
194 }
195
196 func TestPermanent(t *testing.T) {
197 want := errors.New("foo")
198 other := errors.New("bar")
199 var err error = Permanent(want)
200
201 got := errors.Unwrap(err)
202 if got != want {
203 t.Errorf("got %v, want %v", got, want)
204 }
205
206 if is := errors.Is(err, want); !is {
207 t.Errorf("err: %v is not %v", err, want)
208 }
209
210 if is := errors.Is(err, other); is {
211 t.Errorf("err: %v is %v", err, other)
212 }
213
214 wrapped := fmt.Errorf("wrapped: %w", err)
215 var permanent *PermanentError
216 if !errors.As(wrapped, &permanent) {
217 t.Errorf("errors.As(%v, %v)", wrapped, permanent)
218 }
219
220 err = Permanent(nil)
221 if err != nil {
222 t.Errorf("got %v, want nil", err)
223 }
224 }
225
View as plain text