1
16
17 package wait
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "math/rand"
24 "sync"
25 "sync/atomic"
26 "testing"
27 "time"
28
29 "k8s.io/apimachinery/pkg/util/runtime"
30 "k8s.io/utils/clock"
31 testingclock "k8s.io/utils/clock/testing"
32 )
33
34 func TestUntil(t *testing.T) {
35 ch := make(chan struct{})
36 close(ch)
37 Until(func() {
38 t.Fatal("should not have been invoked")
39 }, 0, ch)
40
41 ch = make(chan struct{})
42 called := make(chan struct{})
43 go func() {
44 Until(func() {
45 called <- struct{}{}
46 }, 0, ch)
47 close(called)
48 }()
49 <-called
50 close(ch)
51 <-called
52 }
53
54 func TestUntilWithContext(t *testing.T) {
55 ctx, cancel := context.WithCancel(context.TODO())
56 cancel()
57 UntilWithContext(ctx, func(context.Context) {
58 t.Fatal("should not have been invoked")
59 }, 0)
60
61 ctx, cancel = context.WithCancel(context.TODO())
62 called := make(chan struct{})
63 go func() {
64 UntilWithContext(ctx, func(context.Context) {
65 called <- struct{}{}
66 }, 0)
67 close(called)
68 }()
69 <-called
70 cancel()
71 <-called
72 }
73
74 func TestNonSlidingUntil(t *testing.T) {
75 ch := make(chan struct{})
76 close(ch)
77 NonSlidingUntil(func() {
78 t.Fatal("should not have been invoked")
79 }, 0, ch)
80
81 ch = make(chan struct{})
82 called := make(chan struct{})
83 go func() {
84 NonSlidingUntil(func() {
85 called <- struct{}{}
86 }, 0, ch)
87 close(called)
88 }()
89 <-called
90 close(ch)
91 <-called
92 }
93
94 func TestNonSlidingUntilWithContext(t *testing.T) {
95 ctx, cancel := context.WithCancel(context.TODO())
96 cancel()
97 NonSlidingUntilWithContext(ctx, func(context.Context) {
98 t.Fatal("should not have been invoked")
99 }, 0)
100
101 ctx, cancel = context.WithCancel(context.TODO())
102 called := make(chan struct{})
103 go func() {
104 NonSlidingUntilWithContext(ctx, func(context.Context) {
105 called <- struct{}{}
106 }, 0)
107 close(called)
108 }()
109 <-called
110 cancel()
111 <-called
112 }
113
114 func TestUntilReturnsImmediately(t *testing.T) {
115 now := time.Now()
116 ch := make(chan struct{})
117 var attempts int
118 Until(func() {
119 attempts++
120 if attempts > 1 {
121 t.Fatalf("invoked after close of channel")
122 }
123 close(ch)
124 }, 30*time.Second, ch)
125 if now.Add(25 * time.Second).Before(time.Now()) {
126 t.Errorf("Until did not return immediately when the stop chan was closed inside the func")
127 }
128 }
129
130 func TestJitterUntil(t *testing.T) {
131 ch := make(chan struct{})
132
133
134 close(ch)
135 JitterUntil(func() {
136 t.Fatal("should not have been invoked")
137 }, 0, 1.0, true, ch)
138
139 ch = make(chan struct{})
140 called := make(chan struct{})
141 go func() {
142 JitterUntil(func() {
143 called <- struct{}{}
144 }, 0, 1.0, true, ch)
145 close(called)
146 }()
147 <-called
148 close(ch)
149 <-called
150 }
151
152 func TestJitterUntilWithContext(t *testing.T) {
153 ctx, cancel := context.WithCancel(context.TODO())
154 cancel()
155 JitterUntilWithContext(ctx, func(context.Context) {
156 t.Fatal("should not have been invoked")
157 }, 0, 1.0, true)
158
159 ctx, cancel = context.WithCancel(context.TODO())
160 called := make(chan struct{})
161 go func() {
162 JitterUntilWithContext(ctx, func(context.Context) {
163 called <- struct{}{}
164 }, 0, 1.0, true)
165 close(called)
166 }()
167 <-called
168 cancel()
169 <-called
170 }
171
172 func TestJitterUntilReturnsImmediately(t *testing.T) {
173 now := time.Now()
174 ch := make(chan struct{})
175 JitterUntil(func() {
176 close(ch)
177 }, 30*time.Second, 1.0, true, ch)
178 if now.Add(25 * time.Second).Before(time.Now()) {
179 t.Errorf("JitterUntil did not return immediately when the stop chan was closed inside the func")
180 }
181 }
182
183 func TestJitterUntilRecoversPanic(t *testing.T) {
184
185 originalReallyCrash := runtime.ReallyCrash
186 originalHandlers := runtime.PanicHandlers
187 defer func() {
188 runtime.ReallyCrash = originalReallyCrash
189 runtime.PanicHandlers = originalHandlers
190 }()
191
192 called := 0
193 handled := 0
194
195
196 runtime.ReallyCrash = false
197 runtime.PanicHandlers = []func(interface{}){
198 func(p interface{}) {
199 handled++
200 },
201 }
202
203 ch := make(chan struct{})
204 JitterUntil(func() {
205 called++
206 if called > 2 {
207 close(ch)
208 return
209 }
210 panic("TestJitterUntilRecoversPanic")
211 }, time.Millisecond, 1.0, true, ch)
212
213 if called != 3 {
214 t.Errorf("Expected panic recovers")
215 }
216 }
217
218 func TestJitterUntilNegativeFactor(t *testing.T) {
219 now := time.Now()
220 ch := make(chan struct{})
221 called := make(chan struct{})
222 received := make(chan struct{})
223 go func() {
224 JitterUntil(func() {
225 called <- struct{}{}
226 <-received
227 }, time.Second, -30.0, true, ch)
228 }()
229
230 <-called
231 received <- struct{}{}
232
233 <-called
234 close(ch)
235 received <- struct{}{}
236
237
238 if now.Add(3 * time.Second).Before(time.Now()) {
239 t.Errorf("JitterUntil did not returned after predefined period with negative jitter factor when the stop chan was closed inside the func")
240 }
241 }
242
243 func TestExponentialBackoff(t *testing.T) {
244
245 i := 0
246 err := ExponentialBackoff(Backoff{Factor: 1.0}, func() (bool, error) {
247 i++
248 return false, nil
249 })
250 if err != ErrWaitTimeout || i != 0 {
251 t.Errorf("unexpected error: %v", err)
252 }
253
254 opts := Backoff{Factor: 1.0, Steps: 3}
255
256
257 i = 0
258 err = ExponentialBackoff(opts, func() (bool, error) {
259 i++
260 return false, nil
261 })
262 if err != ErrWaitTimeout || i != opts.Steps {
263 t.Errorf("unexpected error: %v", err)
264 }
265
266
267 i = 0
268 err = ExponentialBackoff(opts, func() (bool, error) {
269 i++
270 return true, nil
271 })
272 if err != nil || i != 1 {
273 t.Errorf("unexpected error: %v", err)
274 }
275
276
277 testErr := fmt.Errorf("some other error")
278 err = ExponentialBackoff(opts, func() (bool, error) {
279 return false, testErr
280 })
281 if err != testErr {
282 t.Errorf("unexpected error: %v", err)
283 }
284
285
286 i = 1
287 err = ExponentialBackoff(opts, func() (bool, error) {
288 if i < opts.Steps {
289 i++
290 return false, nil
291 }
292 return true, nil
293 })
294 if err != nil || i != opts.Steps {
295 t.Errorf("unexpected error: %v", err)
296 }
297 }
298
299 func TestPoller(t *testing.T) {
300 ctx, cancel := context.WithCancel(context.Background())
301 defer cancel()
302 w := poller(time.Millisecond, 2*time.Millisecond)
303 ch := w(ctx)
304 count := 0
305 DRAIN:
306 for {
307 select {
308 case _, open := <-ch:
309 if !open {
310 break DRAIN
311 }
312 count++
313 case <-time.After(ForeverTestTimeout):
314 t.Errorf("unexpected timeout after poll")
315 }
316 }
317 if count > 3 {
318 t.Errorf("expected up to three values, got %d", count)
319 }
320 }
321
// fakePoller is a test helper that hands out fake tickers (see GetwaitFunc)
// and records how they were used.
type fakePoller struct {
	// max is the maximum number of ticks each handed-out ticker emits.
	max int
	// used counts delivered ticks; it is updated atomically.
	used int32
	// wg tracks ticker goroutines so tests can wait for them to finish.
	wg sync.WaitGroup
}
327
328 func fakeTicker(max int, used *int32, doneFunc func()) waitFunc {
329 return func(done <-chan struct{}) <-chan struct{} {
330 ch := make(chan struct{})
331 go func() {
332 defer doneFunc()
333 defer close(ch)
334 for i := 0; i < max; i++ {
335 select {
336 case ch <- struct{}{}:
337 case <-done:
338 return
339 }
340 if used != nil {
341 atomic.AddInt32(used, 1)
342 }
343 }
344 }()
345 return ch
346 }
347 }
348
349 func (fp *fakePoller) GetwaitFunc() waitFunc {
350 fp.wg.Add(1)
351 return fakeTicker(fp.max, &fp.used, fp.wg.Done)
352 }
353
354 func TestPoll(t *testing.T) {
355 invocations := 0
356 f := ConditionWithContextFunc(func(ctx context.Context) (bool, error) {
357 invocations++
358 return true, nil
359 })
360 fp := fakePoller{max: 1}
361
362 ctx, cancel := context.WithCancel(context.Background())
363 defer cancel()
364 if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f); err != nil {
365 t.Fatalf("unexpected error %v", err)
366 }
367 fp.wg.Wait()
368 if invocations != 1 {
369 t.Errorf("Expected exactly one invocation, got %d", invocations)
370 }
371 used := atomic.LoadInt32(&fp.used)
372 if used != 1 {
373 t.Errorf("Expected exactly one tick, got %d", used)
374 }
375 }
376
377 func TestPollError(t *testing.T) {
378 expectedError := errors.New("Expected error")
379 f := ConditionFunc(func() (bool, error) {
380 return false, expectedError
381 })
382 fp := fakePoller{max: 1}
383
384 ctx, cancel := context.WithCancel(context.Background())
385 defer cancel()
386 if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
387 t.Fatalf("Expected error %v, got none %v", expectedError, err)
388 }
389 fp.wg.Wait()
390 used := atomic.LoadInt32(&fp.used)
391 if used != 1 {
392 t.Errorf("Expected exactly one tick, got %d", used)
393 }
394 }
395
396 func TestPollImmediate(t *testing.T) {
397 invocations := 0
398 f := ConditionFunc(func() (bool, error) {
399 invocations++
400 return true, nil
401 })
402 fp := fakePoller{max: 0}
403
404 ctx, cancel := context.WithCancel(context.Background())
405 defer cancel()
406 if err := poll(ctx, true, fp.GetwaitFunc().WithContext(), f.WithContext()); err != nil {
407 t.Fatalf("unexpected error %v", err)
408 }
409
410 if invocations != 1 {
411 t.Errorf("Expected exactly one invocation, got %d", invocations)
412 }
413 used := atomic.LoadInt32(&fp.used)
414 if used != 0 {
415 t.Errorf("Expected exactly zero ticks, got %d", used)
416 }
417 }
418
419 func TestPollImmediateError(t *testing.T) {
420 expectedError := errors.New("Expected error")
421 f := ConditionFunc(func() (bool, error) {
422 return false, expectedError
423 })
424 fp := fakePoller{max: 0}
425
426 ctx, cancel := context.WithCancel(context.Background())
427 defer cancel()
428 if err := poll(ctx, true, fp.GetwaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
429 t.Fatalf("Expected error %v, got none %v", expectedError, err)
430 }
431
432 used := atomic.LoadInt32(&fp.used)
433 if used != 0 {
434 t.Errorf("Expected exactly zero ticks, got %d", used)
435 }
436 }
437
438 func TestPollForever(t *testing.T) {
439 ch := make(chan struct{})
440 errc := make(chan error, 1)
441 done := make(chan struct{}, 1)
442 complete := make(chan struct{})
443 go func() {
444 f := ConditionFunc(func() (bool, error) {
445 ch <- struct{}{}
446 select {
447 case <-done:
448 return true, nil
449 default:
450 }
451 return false, nil
452 })
453
454 if err := PollInfinite(time.Microsecond, f); err != nil {
455 errc <- fmt.Errorf("unexpected error %v", err)
456 }
457
458 close(ch)
459 complete <- struct{}{}
460 }()
461
462
463 <-ch
464
465
466 for i := 0; i < 10; i++ {
467 select {
468 case _, open := <-ch:
469 if !open {
470 if len(errc) != 0 {
471 t.Fatalf("did not expect channel to be closed, %v", <-errc)
472 }
473 t.Fatal("did not expect channel to be closed")
474 }
475 case <-time.After(ForeverTestTimeout):
476 t.Fatalf("channel did not return at least once within the poll interval")
477 }
478 }
479
480
481 done <- struct{}{}
482 go func() {
483 for i := 0; i < 2; i++ {
484 _, open := <-ch
485 if !open {
486 return
487 }
488 }
489 t.Error("expected closed channel after two iterations")
490 }()
491 <-complete
492
493 if len(errc) != 0 {
494 t.Fatal(<-errc)
495 }
496 }
497
498 func Test_waitFor(t *testing.T) {
499 var invocations int
500 testCases := map[string]struct {
501 F ConditionFunc
502 Ticks int
503 Invoked int
504 Err bool
505 }{
506 "invoked once": {
507 ConditionFunc(func() (bool, error) {
508 invocations++
509 return true, nil
510 }),
511 2,
512 1,
513 false,
514 },
515 "invoked and returns a timeout": {
516 ConditionFunc(func() (bool, error) {
517 invocations++
518 return false, nil
519 }),
520 2,
521 3,
522 true,
523 },
524 "returns immediately on error": {
525 ConditionFunc(func() (bool, error) {
526 invocations++
527 return false, errors.New("test")
528 }),
529 2,
530 1,
531 true,
532 },
533 }
534 for k, c := range testCases {
535 invocations = 0
536 ticker := fakeTicker(c.Ticks, nil, func() {})
537 err := func() error {
538 done := make(chan struct{})
539 defer close(done)
540 ctx := ContextForChannel(done)
541 return waitForWithContext(ctx, ticker.WithContext(), c.F.WithContext())
542 }()
543 switch {
544 case c.Err && err == nil:
545 t.Errorf("%s: Expected error, got nil", k)
546 continue
547 case !c.Err && err != nil:
548 t.Errorf("%s: Expected no error, got: %#v", k, err)
549 continue
550 }
551 if invocations != c.Invoked {
552 t.Errorf("%s: Expected %d invocations, got %d", k, c.Invoked, invocations)
553 }
554 }
555 }
556
557
558
559 func Test_waitForWithEarlyClosing_waitFunc(t *testing.T) {
560 stopCh := make(chan struct{})
561 defer close(stopCh)
562
563 ctx := ContextForChannel(stopCh)
564 start := time.Now()
565 err := waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
566 c := make(chan struct{})
567 close(c)
568 return c
569 }, func(_ context.Context) (bool, error) {
570 return false, nil
571 })
572 duration := time.Since(start)
573
574
575 if duration >= ForeverTestTimeout/2 {
576 t.Errorf("expected short timeout duration")
577 }
578 if err != ErrWaitTimeout {
579 t.Errorf("expected ErrWaitTimeout from WaitFunc")
580 }
581 }
582
583
584
585 func Test_waitForWithClosedChannel(t *testing.T) {
586 stopCh := make(chan struct{})
587 close(stopCh)
588 c := make(chan struct{})
589 defer close(c)
590 ctx := ContextForChannel(stopCh)
591
592 start := time.Now()
593 err := waitForWithContext(ctx, func(_ context.Context) <-chan struct{} {
594 return c
595 }, func(_ context.Context) (bool, error) {
596 return false, nil
597 })
598 duration := time.Since(start)
599
600 if duration >= ForeverTestTimeout/2 {
601 t.Errorf("expected short timeout duration")
602 }
603
604 if err != ErrWaitTimeout {
605 t.Errorf("expected ErrWaitTimeout from WaitFunc")
606 }
607 }
608
609
610
611 func Test_waitForWithContextCancelsContext(t *testing.T) {
612 ctx, cancel := context.WithCancel(context.Background())
613 defer cancel()
614 waitFn := poller(time.Millisecond, ForeverTestTimeout)
615
616 var ctxPassedToWait context.Context
617 waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
618 ctxPassedToWait = ctx
619 return waitFn(ctx)
620 }, func(ctx context.Context) (bool, error) {
621 time.Sleep(10 * time.Millisecond)
622 return true, nil
623 })
624
625 if ctxPassedToWait.Err() != context.Canceled {
626 t.Errorf("expected the context passed to waitForWithContext to be closed with: %v, but got: %v", context.Canceled, ctxPassedToWait.Err())
627 }
628 }
629
630 func TestPollUntil(t *testing.T) {
631 stopCh := make(chan struct{})
632 called := make(chan bool)
633 pollDone := make(chan struct{})
634
635 go func() {
636 PollUntil(time.Microsecond, ConditionFunc(func() (bool, error) {
637 called <- true
638 return false, nil
639 }), stopCh)
640
641 close(pollDone)
642 }()
643
644
645 <-called
646
647 close(stopCh)
648
649 go func() {
650
651 for range called {
652 }
653 }()
654
655
656 <-pollDone
657 close(called)
658 }
659
660 func TestBackoff_Step(t *testing.T) {
661 tests := []struct {
662 initial *Backoff
663 want []time.Duration
664 }{
665 {initial: nil, want: []time.Duration{0, 0, 0, 0}},
666 {initial: &Backoff{Duration: time.Second, Steps: -1}, want: []time.Duration{time.Second, time.Second, time.Second}},
667 {initial: &Backoff{Duration: time.Second, Steps: 0}, want: []time.Duration{time.Second, time.Second, time.Second}},
668 {initial: &Backoff{Duration: time.Second, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}},
669 {initial: &Backoff{Duration: time.Second, Factor: 1.0, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}},
670 {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 3}, want: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second}},
671 {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 3, Cap: 3 * time.Second}, want: []time.Duration{1 * time.Second, 2 * time.Second, 3 * time.Second}},
672 {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 2, Cap: 3 * time.Second, Jitter: 0.5}, want: []time.Duration{2 * time.Second, 3 * time.Second, 3 * time.Second}},
673 {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 6, Jitter: 4}, want: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second, 8 * time.Second, 16 * time.Second, 32 * time.Second}},
674 }
675 for seed := int64(0); seed < 5; seed++ {
676 for _, tt := range tests {
677 var initial *Backoff
678 if tt.initial != nil {
679 copied := *tt.initial
680 initial = &copied
681 } else {
682 initial = nil
683 }
684 t.Run(fmt.Sprintf("%#v seed=%d", initial, seed), func(t *testing.T) {
685 rand.Seed(seed)
686 for i := 0; i < len(tt.want); i++ {
687 got := initial.Step()
688 t.Logf("[%d]=%s", i, got)
689 if initial != nil && initial.Jitter > 0 {
690 if got == tt.want[i] {
691
692 t.Errorf("Backoff.Step(%d) = %v, no jitter", i, got)
693 continue
694 }
695 diff := float64(tt.want[i]-got) / float64(tt.want[i])
696 if diff > initial.Jitter {
697 t.Errorf("Backoff.Step(%d) = %v, want %v, outside range", i, got, tt.want)
698 continue
699 }
700 } else {
701 if got != tt.want[i] {
702 t.Errorf("Backoff.Step(%d) = %v, want %v", i, got, tt.want)
703 continue
704 }
705 }
706 }
707 })
708 }
709 }
710 }
711
712 func TestContextForChannel(t *testing.T) {
713 var wg sync.WaitGroup
714 parentCh := make(chan struct{})
715 done := make(chan struct{})
716
717 for i := 0; i < 3; i++ {
718 wg.Add(1)
719 go func() {
720 defer wg.Done()
721 ctx := ContextForChannel(parentCh)
722 <-ctx.Done()
723 }()
724 }
725
726 go func() {
727 wg.Wait()
728 close(done)
729 }()
730
731
732 close(parentCh)
733
734 select {
735 case <-done:
736 case <-time.After(ForeverTestTimeout):
737 t.Errorf("unexpected timeout waiting for parent to cancel child contexts")
738 }
739 }
740
741 func TestExponentialBackoffManagerGetNextBackoff(t *testing.T) {
742 fc := testingclock.NewFakeClock(time.Now())
743 backoff := NewExponentialBackoffManager(1, 10, 10, 2.0, 0.0, fc)
744 durations := []time.Duration{1, 2, 4, 8, 10, 10, 10}
745 for i := 0; i < len(durations); i++ {
746 generatedBackoff := backoff.(*exponentialBackoffManagerImpl).getNextBackoff()
747 if generatedBackoff != durations[i] {
748 t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
749 }
750 }
751
752 fc.Step(11)
753 resetDuration := backoff.(*exponentialBackoffManagerImpl).getNextBackoff()
754 if resetDuration != 1 {
755 t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
756 }
757 }
758
759 func TestJitteredBackoffManagerGetNextBackoff(t *testing.T) {
760
761 backoffMgr := NewJitteredBackoffManager(1, 1, testingclock.NewFakeClock(time.Now()))
762 for i := 0; i < 5; i++ {
763 backoff := backoffMgr.(*jitteredBackoffManagerImpl).getNextBackoff()
764 if backoff < 1 || backoff > 2 {
765 t.Errorf("backoff out of range: %d", backoff)
766 }
767 }
768
769
770 backoffMgr = NewJitteredBackoffManager(1, -1, testingclock.NewFakeClock(time.Now()))
771 backoff := backoffMgr.(*jitteredBackoffManagerImpl).getNextBackoff()
772 if backoff != 1 {
773 t.Errorf("backoff should be 1, but got %d", backoff)
774 }
775 }
776
777 func TestJitterBackoffManagerWithRealClock(t *testing.T) {
778 backoffMgr := NewJitteredBackoffManager(1*time.Millisecond, 0, &clock.RealClock{})
779 for i := 0; i < 5; i++ {
780 start := time.Now()
781 <-backoffMgr.Backoff().C()
782 passed := time.Since(start)
783 if passed < 1*time.Millisecond {
784 t.Errorf("backoff should be at least 1ms, but got %s", passed.String())
785 }
786 }
787 }
788
789 func TestExponentialBackoffManagerWithRealClock(t *testing.T) {
790
791 durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10}
792 backoffMgr := NewExponentialBackoffManager(1*time.Millisecond, 10*time.Millisecond, 1*time.Hour, 2.0, 0.0, &clock.RealClock{})
793
794 for i := range durationFactors {
795 start := time.Now()
796 <-backoffMgr.Backoff().C()
797 passed := time.Since(start)
798 if passed < durationFactors[i]*time.Millisecond {
799 t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String())
800 }
801 }
802 }
803
804 func TestBackoffDelayWithResetExponential(t *testing.T) {
805 fc := testingclock.NewFakeClock(time.Now())
806 backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 10)
807 durations := []time.Duration{1, 2, 4, 8, 10, 10, 10}
808 for i := 0; i < len(durations); i++ {
809 generatedBackoff := backoff()
810 if generatedBackoff != durations[i] {
811 t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
812 }
813 }
814
815 fc.Step(11)
816 resetDuration := backoff()
817 if resetDuration != 1 {
818 t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
819 }
820 }
821
822 func TestBackoffDelayWithResetEmpty(t *testing.T) {
823 fc := testingclock.NewFakeClock(time.Now())
824 backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 0)
825
826 durations := []time.Duration{1, 1, 1, 1, 1, 1, 1}
827 for i := 0; i < len(durations); i++ {
828 generatedBackoff := backoff()
829 if generatedBackoff != durations[i] {
830 t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
831 }
832 }
833
834 fc.Step(11)
835 resetDuration := backoff()
836 if resetDuration != 1 {
837 t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
838 }
839 }
840
841 func TestBackoffDelayWithResetJitter(t *testing.T) {
842
843 backoff := Backoff{Duration: 1, Jitter: 1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0)
844 for i := 0; i < 5; i++ {
845 value := backoff()
846 if value < 1 || value > 2 {
847 t.Errorf("backoff out of range: %d", value)
848 }
849 }
850
851
852 backoff = Backoff{Duration: 1, Jitter: -1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0)
853 value := backoff()
854 if value != 1 {
855 t.Errorf("backoff should be 1, but got %d", value)
856 }
857 }
858
859 func TestBackoffDelayWithResetWithRealClockJitter(t *testing.T) {
860 backoff := Backoff{Duration: 1 * time.Millisecond, Jitter: 0}.DelayWithReset(&clock.RealClock{}, 0)
861 for i := 0; i < 5; i++ {
862 start := time.Now()
863 <-RealTimer(backoff()).C()
864 passed := time.Since(start)
865 if passed < 1*time.Millisecond {
866 t.Errorf("backoff should be at least 1ms, but got %s", passed.String())
867 }
868 }
869 }
870
871 func TestBackoffDelayWithResetWithRealClockExponential(t *testing.T) {
872
873 durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10}
874 backoff := Backoff{Duration: 1 * time.Millisecond, Cap: 10 * time.Millisecond, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(&clock.RealClock{}, 1*time.Hour)
875
876 for i := range durationFactors {
877 start := time.Now()
878 <-RealTimer(backoff()).C()
879 passed := time.Since(start)
880 if passed < durationFactors[i]*time.Millisecond {
881 t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String())
882 }
883 }
884 }
885
// defaultContext returns a fresh cancellable context derived from
// context.Background, together with its cancel function.
func defaultContext() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(context.Background())
	return ctx, cancel
}
// cancelledContext returns a context that has already been cancelled, together
// with its (already invoked) cancel function.
func cancelledContext() (context.Context, context.CancelFunc) {
	dead, stop := context.WithCancel(context.Background())
	// Cancel before handing the context out so callers always see it done.
	stop()
	return dead, stop
}
// deadlinedContext returns a context whose one millisecond deadline has
// already passed, together with its cancel function.
//
// It blocks on ctx.Done() rather than polling ctx.Err() with sleeps. The
// cancel function is not called before returning, so the only thing that can
// end the context here is the deadline, and ctx.Err() is therefore
// context.DeadlineExceeded.
func deadlinedContext() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	<-ctx.Done()
	return ctx, cancel
}
901
902 func TestExponentialBackoffWithContext(t *testing.T) {
903 defaultCallback := func(_ int) (bool, error) {
904 return false, nil
905 }
906
907 conditionErr := errors.New("condition failed")
908
909 tests := []struct {
910 name string
911 steps int
912 zeroDuration bool
913 context func() (context.Context, context.CancelFunc)
914 callback func(calls int) (bool, error)
915 cancelContextAfter int
916 attemptsExpected int
917 errExpected error
918 }{
919 {
920 name: "no attempts expected with zero backoff steps",
921 steps: 0,
922 callback: defaultCallback,
923 attemptsExpected: 0,
924 errExpected: ErrWaitTimeout,
925 },
926 {
927 name: "condition returns false with single backoff step",
928 steps: 1,
929 callback: defaultCallback,
930 attemptsExpected: 1,
931 errExpected: ErrWaitTimeout,
932 },
933 {
934 name: "condition returns true with single backoff step",
935 steps: 1,
936 callback: func(_ int) (bool, error) {
937 return true, nil
938 },
939 attemptsExpected: 1,
940 errExpected: nil,
941 },
942 {
943 name: "condition always returns false with multiple backoff steps",
944 steps: 5,
945 callback: defaultCallback,
946 attemptsExpected: 5,
947 errExpected: ErrWaitTimeout,
948 },
949 {
950 name: "condition returns true after certain attempts with multiple backoff steps",
951 steps: 5,
952 callback: func(attempts int) (bool, error) {
953 if attempts == 3 {
954 return true, nil
955 }
956 return false, nil
957 },
958 attemptsExpected: 3,
959 errExpected: nil,
960 },
961 {
962 name: "condition returns error no further attempts expected",
963 steps: 5,
964 callback: func(_ int) (bool, error) {
965 return true, conditionErr
966 },
967 attemptsExpected: 1,
968 errExpected: conditionErr,
969 },
970 {
971 name: "context already canceled no attempts expected",
972 steps: 5,
973 context: cancelledContext,
974 callback: defaultCallback,
975 attemptsExpected: 0,
976 errExpected: context.Canceled,
977 },
978 {
979 name: "context at deadline no attempts expected",
980 steps: 5,
981 context: deadlinedContext,
982 callback: defaultCallback,
983 attemptsExpected: 0,
984 errExpected: context.DeadlineExceeded,
985 },
986 {
987 name: "no attempts expected with zero backoff steps",
988 steps: 0,
989 callback: defaultCallback,
990 attemptsExpected: 0,
991 errExpected: ErrWaitTimeout,
992 },
993 {
994 name: "condition returns false with single backoff step",
995 steps: 1,
996 callback: defaultCallback,
997 attemptsExpected: 1,
998 errExpected: ErrWaitTimeout,
999 },
1000 {
1001 name: "condition returns true with single backoff step",
1002 steps: 1,
1003 callback: func(_ int) (bool, error) {
1004 return true, nil
1005 },
1006 attemptsExpected: 1,
1007 errExpected: nil,
1008 },
1009 {
1010 name: "condition always returns false with multiple backoff steps but is cancelled at step 4",
1011 steps: 5,
1012 callback: defaultCallback,
1013 attemptsExpected: 4,
1014 cancelContextAfter: 4,
1015 errExpected: context.Canceled,
1016 },
1017 {
1018 name: "condition returns true after certain attempts with multiple backoff steps and zero duration",
1019 steps: 5,
1020 zeroDuration: true,
1021 callback: func(attempts int) (bool, error) {
1022 if attempts == 3 {
1023 return true, nil
1024 }
1025 return false, nil
1026 },
1027 attemptsExpected: 3,
1028 errExpected: nil,
1029 },
1030 {
1031 name: "condition returns error no further attempts expected",
1032 steps: 5,
1033 callback: func(_ int) (bool, error) {
1034 return true, conditionErr
1035 },
1036 attemptsExpected: 1,
1037 errExpected: conditionErr,
1038 },
1039 {
1040 name: "context already canceled no attempts expected",
1041 steps: 5,
1042 context: cancelledContext,
1043 callback: defaultCallback,
1044 attemptsExpected: 0,
1045 errExpected: context.Canceled,
1046 },
1047 {
1048 name: "context at deadline no attempts expected",
1049 steps: 5,
1050 context: deadlinedContext,
1051 callback: defaultCallback,
1052 attemptsExpected: 0,
1053 errExpected: context.DeadlineExceeded,
1054 },
1055 }
1056
1057 for _, test := range tests {
1058 t.Run(test.name, func(t *testing.T) {
1059 backoff := Backoff{
1060 Duration: 1 * time.Microsecond,
1061 Factor: 1.0,
1062 Steps: test.steps,
1063 }
1064 if test.zeroDuration {
1065 backoff.Duration = 0
1066 }
1067
1068 contextFn := test.context
1069 if contextFn == nil {
1070 contextFn = defaultContext
1071 }
1072 ctx, cancel := contextFn()
1073 defer cancel()
1074
1075 attempts := 0
1076 err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) {
1077 attempts++
1078 defer func() {
1079 if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
1080 cancel()
1081 }
1082 }()
1083 return test.callback(attempts)
1084 })
1085
1086 if test.errExpected != err {
1087 t.Errorf("expected error: %v but got: %v", test.errExpected, err)
1088 }
1089
1090 if test.attemptsExpected != attempts {
1091 t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
1092 }
1093 })
1094 }
1095 }
1096
1097 func BenchmarkExponentialBackoffWithContext(b *testing.B) {
1098 backoff := Backoff{
1099 Duration: 0,
1100 Factor: 0,
1101 Steps: 101,
1102 }
1103 ctx := context.Background()
1104
1105 b.ResetTimer()
1106 for i := 0; i < b.N; i++ {
1107 attempts := 0
1108 if err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) {
1109 attempts++
1110 return attempts >= 100, nil
1111 }); err != nil {
1112 b.Fatalf("unexpected err: %v", err)
1113 }
1114 }
1115 }
1116
1117 func TestPollImmediateUntilWithContext(t *testing.T) {
1118 fakeErr := errors.New("my error")
1119 tests := []struct {
1120 name string
1121 condition func(int) ConditionWithContextFunc
1122 context func() (context.Context, context.CancelFunc)
1123 cancelContextAfterNthAttempt int
1124 errExpected error
1125 attemptsExpected int
1126 }{
1127 {
1128 name: "condition throws error on immediate attempt, no retry is attempted",
1129 condition: func(int) ConditionWithContextFunc {
1130 return func(context.Context) (done bool, err error) {
1131 return false, fakeErr
1132 }
1133 },
1134 errExpected: fakeErr,
1135 attemptsExpected: 1,
1136 },
1137 {
1138 name: "condition returns done=true on immediate attempt, no retry is attempted",
1139 condition: func(int) ConditionWithContextFunc {
1140 return func(context.Context) (done bool, err error) {
1141 return true, nil
1142 }
1143 },
1144 errExpected: nil,
1145 attemptsExpected: 1,
1146 },
1147 {
1148 name: "condition returns done=false on immediate attempt, context is already cancelled, no retry is attempted",
1149 condition: func(int) ConditionWithContextFunc {
1150 return func(context.Context) (done bool, err error) {
1151 return false, nil
1152 }
1153 },
1154 context: cancelledContext,
1155 errExpected: ErrWaitTimeout,
1156 attemptsExpected: 1,
1157 },
1158 {
1159 name: "condition returns done=false on immediate attempt, context is not cancelled, retry is attempted",
1160 condition: func(attempts int) ConditionWithContextFunc {
1161 return func(context.Context) (done bool, err error) {
1162
1163 if attempts <= 3 {
1164 return false, nil
1165 }
1166 return true, nil
1167 }
1168 },
1169 errExpected: nil,
1170 attemptsExpected: 4,
1171 },
1172 {
1173 name: "condition always returns done=false, context gets cancelled after N attempts",
1174 condition: func(attempts int) ConditionWithContextFunc {
1175 return func(ctx context.Context) (done bool, err error) {
1176 return false, nil
1177 }
1178 },
1179 cancelContextAfterNthAttempt: 4,
1180 errExpected: ErrWaitTimeout,
1181 attemptsExpected: 4,
1182 },
1183 }
1184
1185 for _, test := range tests {
1186 t.Run(test.name, func(t *testing.T) {
1187 contextFn := test.context
1188 if contextFn == nil {
1189 contextFn = defaultContext
1190 }
1191 ctx, cancel := contextFn()
1192 defer cancel()
1193
1194 var attempts int
1195 conditionWrapper := func(ctx context.Context) (done bool, err error) {
1196 attempts++
1197 defer func() {
1198 if test.cancelContextAfterNthAttempt == attempts {
1199 cancel()
1200 }
1201 }()
1202
1203 c := test.condition(attempts)
1204 return c(ctx)
1205 }
1206
1207 err := PollImmediateUntilWithContext(ctx, time.Millisecond, conditionWrapper)
1208 if test.errExpected != err {
1209 t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
1210 }
1211 if test.attemptsExpected != attempts {
1212 t.Errorf("Expected ConditionFunc to be invoked: %d times, but got: %d", test.attemptsExpected, attempts)
1213 }
1214 })
1215 }
1216 }
1217
1218 func Test_waitForWithContext(t *testing.T) {
1219 fakeErr := errors.New("fake error")
1220 tests := []struct {
1221 name string
1222 context func() (context.Context, context.CancelFunc)
1223 condition ConditionWithContextFunc
1224 waitFunc func() waitFunc
1225 attemptsExpected int
1226 errExpected error
1227 }{
1228 {
1229 name: "condition returns done=true on first attempt, no retry is attempted",
1230 context: defaultContext,
1231 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1232 return true, nil
1233 }),
1234 waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) },
1235 attemptsExpected: 1,
1236 errExpected: nil,
1237 },
1238 {
1239 name: "condition always returns done=false, timeout error expected",
1240 context: defaultContext,
1241 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1242 return false, nil
1243 }),
1244 waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) },
1245
1246 attemptsExpected: 3,
1247 errExpected: ErrWaitTimeout,
1248 },
1249 {
1250 name: "condition returns an error on first attempt, the error is returned",
1251 context: defaultContext,
1252 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1253 return false, fakeErr
1254 }),
1255 waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) },
1256 attemptsExpected: 1,
1257 errExpected: fakeErr,
1258 },
1259 {
1260 name: "context is cancelled, context cancelled error expected",
1261 context: cancelledContext,
1262 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1263 return false, nil
1264 }),
1265 waitFunc: func() waitFunc {
1266 return func(done <-chan struct{}) <-chan struct{} {
1267 ch := make(chan struct{})
1268
1269 return ch
1270 }
1271 },
1272 attemptsExpected: 0,
1273 errExpected: ErrWaitTimeout,
1274 },
1275 }
1276
1277 for _, test := range tests {
1278 t.Run(test.name, func(t *testing.T) {
1279 var attempts int
1280 conditionWrapper := func(ctx context.Context) (done bool, err error) {
1281 attempts++
1282 return test.condition(ctx)
1283 }
1284
1285 ticker := test.waitFunc()
1286 err := func() error {
1287 contextFn := test.context
1288 if contextFn == nil {
1289 contextFn = defaultContext
1290 }
1291 ctx, cancel := contextFn()
1292 defer cancel()
1293
1294 return waitForWithContext(ctx, ticker.WithContext(), conditionWrapper)
1295 }()
1296
1297 if test.errExpected != err {
1298 t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
1299 }
1300 if test.attemptsExpected != attempts {
1301 t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts)
1302 }
1303 })
1304 }
1305 }
1306
1307 func Test_poll(t *testing.T) {
1308 fakeErr := errors.New("fake error")
1309 tests := []struct {
1310 name string
1311 context func() (context.Context, context.CancelFunc)
1312 immediate bool
1313 waitFunc func() waitFunc
1314 condition ConditionWithContextFunc
1315 cancelContextAfter int
1316 attemptsExpected int
1317 errExpected error
1318 }{
1319 {
1320 name: "immediate is true, condition returns an error",
1321 immediate: true,
1322 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1323 return false, fakeErr
1324 }),
1325 waitFunc: nil,
1326 attemptsExpected: 1,
1327 errExpected: fakeErr,
1328 },
1329 {
1330 name: "immediate is true, condition returns true",
1331 immediate: true,
1332 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1333 return true, nil
1334 }),
1335 waitFunc: nil,
1336 attemptsExpected: 1,
1337 errExpected: nil,
1338 },
1339 {
1340 name: "immediate is true, context is cancelled, condition return false",
1341 immediate: true,
1342 context: cancelledContext,
1343 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1344 return false, nil
1345 }),
1346 waitFunc: nil,
1347 attemptsExpected: 1,
1348 errExpected: ErrWaitTimeout,
1349 },
1350 {
1351 name: "immediate is false, context is cancelled",
1352 immediate: false,
1353 context: cancelledContext,
1354 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1355 return false, nil
1356 }),
1357 waitFunc: nil,
1358 attemptsExpected: 0,
1359 errExpected: ErrWaitTimeout,
1360 },
1361 {
1362 name: "immediate is false, condition returns an error",
1363 immediate: false,
1364 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1365 return false, fakeErr
1366 }),
1367 waitFunc: func() waitFunc { return fakeTicker(5, nil, func() {}) },
1368 attemptsExpected: 1,
1369 errExpected: fakeErr,
1370 },
1371 {
1372 name: "immediate is false, condition returns true",
1373 immediate: false,
1374 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1375 return true, nil
1376 }),
1377 waitFunc: func() waitFunc { return fakeTicker(5, nil, func() {}) },
1378 attemptsExpected: 1,
1379 errExpected: nil,
1380 },
1381 {
1382 name: "immediate is false, ticker channel is closed, condition returns true",
1383 immediate: false,
1384 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1385 return true, nil
1386 }),
1387 waitFunc: func() waitFunc {
1388 return func(done <-chan struct{}) <-chan struct{} {
1389 ch := make(chan struct{})
1390 close(ch)
1391 return ch
1392 }
1393 },
1394 attemptsExpected: 1,
1395 errExpected: nil,
1396 },
1397 {
1398 name: "immediate is false, ticker channel is closed, condition returns error",
1399 immediate: false,
1400 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1401 return false, fakeErr
1402 }),
1403 waitFunc: func() waitFunc {
1404 return func(done <-chan struct{}) <-chan struct{} {
1405 ch := make(chan struct{})
1406 close(ch)
1407 return ch
1408 }
1409 },
1410 attemptsExpected: 1,
1411 errExpected: fakeErr,
1412 },
1413 {
1414 name: "immediate is false, ticker channel is closed, condition returns false",
1415 immediate: false,
1416 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1417 return false, nil
1418 }),
1419 waitFunc: func() waitFunc {
1420 return func(done <-chan struct{}) <-chan struct{} {
1421 ch := make(chan struct{})
1422 close(ch)
1423 return ch
1424 }
1425 },
1426 attemptsExpected: 1,
1427 errExpected: ErrWaitTimeout,
1428 },
1429 {
1430 name: "condition always returns false, timeout error expected",
1431 immediate: false,
1432 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1433 return false, nil
1434 }),
1435 waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) },
1436
1437 attemptsExpected: 3,
1438 errExpected: ErrWaitTimeout,
1439 },
1440 {
1441 name: "context is cancelled after N attempts, timeout error expected",
1442 immediate: false,
1443 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1444 return false, nil
1445 }),
1446 waitFunc: func() waitFunc {
1447 return func(done <-chan struct{}) <-chan struct{} {
1448 ch := make(chan struct{})
1449
1450 go func() {
1451 ch <- struct{}{}
1452 ch <- struct{}{}
1453 }()
1454 return ch
1455 }
1456 },
1457 cancelContextAfter: 2,
1458 attemptsExpected: 2,
1459 errExpected: ErrWaitTimeout,
1460 },
1461 {
1462 name: "context is cancelled after N attempts, context error not expected (legacy behavior)",
1463 immediate: false,
1464 condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
1465 return false, nil
1466 }),
1467 waitFunc: func() waitFunc {
1468 return func(done <-chan struct{}) <-chan struct{} {
1469 ch := make(chan struct{})
1470
1471 go func() {
1472 ch <- struct{}{}
1473 ch <- struct{}{}
1474 }()
1475 return ch
1476 }
1477 },
1478 cancelContextAfter: 2,
1479 attemptsExpected: 2,
1480 errExpected: ErrWaitTimeout,
1481 },
1482 }
1483
1484 for _, test := range tests {
1485 t.Run(test.name, func(t *testing.T) {
1486 var attempts int
1487 ticker := waitFunc(func(done <-chan struct{}) <-chan struct{} {
1488 return nil
1489 })
1490 if test.waitFunc != nil {
1491 ticker = test.waitFunc()
1492 }
1493 err := func() error {
1494 contextFn := test.context
1495 if contextFn == nil {
1496 contextFn = defaultContext
1497 }
1498 ctx, cancel := contextFn()
1499 defer cancel()
1500
1501 conditionWrapper := func(ctx context.Context) (done bool, err error) {
1502 attempts++
1503
1504 defer func() {
1505 if test.cancelContextAfter == attempts {
1506 cancel()
1507 }
1508 }()
1509
1510 return test.condition(ctx)
1511 }
1512
1513 return poll(ctx, test.immediate, ticker.WithContext(), conditionWrapper)
1514 }()
1515
1516 if test.errExpected != err {
1517 t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
1518 }
1519 if test.attemptsExpected != attempts {
1520 t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts)
1521 }
1522 })
1523 }
1524 }
1525
1526 func Benchmark_poll(b *testing.B) {
1527 ctx := context.Background()
1528 b.ResetTimer()
1529 for i := 0; i < b.N; i++ {
1530 attempts := 0
1531 if err := poll(ctx, true, poller(time.Microsecond, 0), func(_ context.Context) (bool, error) {
1532 attempts++
1533 return attempts >= 100, nil
1534 }); err != nil {
1535 b.Fatalf("unexpected err: %v", err)
1536 }
1537 }
1538 }
1539
View as plain text