/*
Copyright 2014 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package wait

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/utils/clock"
	testingclock "k8s.io/utils/clock/testing"
)

// TestUntil checks Until in two situations: with a stop channel that is
// already closed the function must never be invoked, and with an open stop
// channel the function must be invoked at least once before the channel is
// closed.
func TestUntil(t *testing.T) {
	ch := make(chan struct{})
	close(ch)
	Until(func() {
		t.Fatal("should not have been invoked")
	}, 0, ch)

	ch = make(chan struct{})
	called := make(chan struct{})
	go func() {
		Until(func() {
			called <- struct{}{}
		}, 0, ch)
		// Closing called unblocks the final receive below once Until has returned.
		close(called)
	}()
	<-called
	close(ch)
	// Satisfied either by one more invocation or by close(called) above.
	<-called
}

// TestUntilWithContext mirrors TestUntil for UntilWithContext, using context
// cancellation in place of a stop channel.
func TestUntilWithContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.TODO())
	cancel()
	UntilWithContext(ctx, func(context.Context) {
		t.Fatal("should not have been invoked")
	}, 0)

	ctx, cancel = context.WithCancel(context.TODO())
	called := make(chan struct{})
	go func() {
		UntilWithContext(ctx, func(context.Context) {
			called <- struct{}{}
		}, 0)
		close(called)
	}()
	<-called
	cancel()
	// Satisfied either by one more invocation or by close(called) above.
	<-called
}

// TestNonSlidingUntil mirrors TestUntil for NonSlidingUntil.
func TestNonSlidingUntil(t *testing.T) {
	ch := make(chan struct{})
	close(ch)
	NonSlidingUntil(func() {
		t.Fatal("should not have been invoked")
	}, 0, ch)

	ch = make(chan struct{})
	called := make(chan struct{})
	go func() {
		NonSlidingUntil(func() {
			called <- struct{}{}
		}, 0, ch)
		close(called)
	}()
	<-called
	close(ch)
	// Satisfied either by one more invocation or by close(called) above.
	<-called
}

// TestNonSlidingUntilWithContext mirrors TestUntilWithContext for
// NonSlidingUntilWithContext.
func TestNonSlidingUntilWithContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.TODO())
	cancel()
	NonSlidingUntilWithContext(ctx, func(context.Context) {
		t.Fatal("should not have been invoked")
	}, 0)

ctx, cancel = context.WithCancel(context.TODO())
	called := make(chan struct{})
	go func() {
		NonSlidingUntilWithContext(ctx, func(context.Context) {
			called <- struct{}{}
		}, 0)
		close(called)
	}()
	<-called
	cancel()
	// Satisfied either by one more invocation or by close(called) above.
	<-called
}

// TestUntilReturnsImmediately closes the stop channel from inside the first
// invocation of a function scheduled with a 30s period. It fails if the
// function is invoked a second time or if more than 25s have elapsed by the
// time Until returns.
func TestUntilReturnsImmediately(t *testing.T) {
	now := time.Now()
	ch := make(chan struct{})
	var attempts int
	Until(func() {
		attempts++
		if attempts > 1 {
			t.Fatalf("invoked after close of channel")
		}
		close(ch)
	}, 30*time.Second, ch)
	if now.Add(25 * time.Second).Before(time.Now()) {
		t.Errorf("Until did not return immediately when the stop chan was closed inside the func")
	}
}

// TestJitterUntil checks JitterUntil (jitter factor 1.0, sliding enabled) the
// same way TestUntil checks Until: no invocation with a closed stop channel,
// and at least one invocation before an open stop channel is closed.
func TestJitterUntil(t *testing.T) {
	ch := make(chan struct{})
	// if a channel is closed JitterUntil never calls function f
	// and returns immediately
	close(ch)
	JitterUntil(func() {
		t.Fatal("should not have been invoked")
	}, 0, 1.0, true, ch)

	ch = make(chan struct{})
	called := make(chan struct{})
	go func() {
		JitterUntil(func() {
			called <- struct{}{}
		}, 0, 1.0, true, ch)
		close(called)
	}()
	<-called
	close(ch)
	// Satisfied either by one more invocation or by close(called) above.
	<-called
}

// TestJitterUntilWithContext mirrors TestJitterUntil for
// JitterUntilWithContext, using context cancellation in place of a stop
// channel.
func TestJitterUntilWithContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.TODO())
	cancel()
	JitterUntilWithContext(ctx, func(context.Context) {
		t.Fatal("should not have been invoked")
	}, 0, 1.0, true)

	ctx, cancel = context.WithCancel(context.TODO())
	called := make(chan struct{})
	go func() {
		JitterUntilWithContext(ctx, func(context.Context) {
			called <- struct{}{}
		}, 0, 1.0, true)
		close(called)
	}()
	<-called
	cancel()
	// Satisfied either by one more invocation or by close(called) above.
	<-called
}

// TestJitterUntilReturnsImmediately closes the stop channel from inside the
// function passed to JitterUntil (30s period) and fails if more than 25s have
// elapsed by the time JitterUntil returns.
func TestJitterUntilReturnsImmediately(t *testing.T) {
	now := time.Now()
	ch := make(chan struct{})
	JitterUntil(func() {
		close(ch)
	}, 30*time.Second, 1.0, true, ch)
	if now.Add(25 * time.Second).Before(time.Now()) {
		t.Errorf("JitterUntil did not return immediately when the stop chan was closed inside the func")
	}
}

// TestJitterUntilRecoversPanic runs JitterUntil with a function that panics on
// its first two calls and closes the stop channel on the third, with
// runtime.ReallyCrash disabled and a counting panic handler installed. It
// expects the function to have been called three times in total. The original
// crash settings are restored when the test exits.
func TestJitterUntilRecoversPanic(t *testing.T) {
	// Save and restore crash handlers
	originalReallyCrash := runtime.ReallyCrash
	originalHandlers := runtime.PanicHandlers
	defer func() {
		runtime.ReallyCrash =
originalReallyCrash
		runtime.PanicHandlers = originalHandlers
	}()

	called := 0
	handled := 0

	// Hook up a custom crash handler to ensure it is called when a jitter function panics
	runtime.ReallyCrash = false
	runtime.PanicHandlers = []func(interface{}){
		func(p interface{}) {
			handled++
		},
	}

	ch := make(chan struct{})
	JitterUntil(func() {
		called++
		if called > 2 {
			close(ch)
			return
		}
		panic("TestJitterUntilRecoversPanic")
	}, time.Millisecond, 1.0, true, ch)

	if called != 3 {
		t.Errorf("Expected panic recovers")
	}
}

// TestJitterUntilNegativeFactor runs JitterUntil with a 1s period and a jitter
// factor of -30.0, steps the function through two invocations by hand, closes
// the stop channel during the second one, and fails if more than 3s have
// elapsed since the start of the test.
func TestJitterUntilNegativeFactor(t *testing.T) {
	now := time.Now()
	ch := make(chan struct{})
	called := make(chan struct{})
	received := make(chan struct{})
	go func() {
		JitterUntil(func() {
			called <- struct{}{}
			<-received
		}, time.Second, -30.0, true, ch)
	}()
	// first loop
	<-called
	received <- struct{}{}
	// second loop
	<-called
	close(ch)
	received <- struct{}{}

	// it should take at most 2 seconds + some overhead, not 3
	if now.Add(3 * time.Second).Before(time.Now()) {
		t.Errorf("JitterUntil did not returned after predefined period with negative jitter factor when the stop chan was closed inside the func")
	}
}

// TestExponentialBackoff exercises ExponentialBackoff with Factor 1.0: zero
// steps (no invocation, ErrWaitTimeout), a condition that never succeeds
// (invoked opts.Steps times, ErrWaitTimeout), a condition that succeeds at
// once, a condition that returns an error, and a condition that succeeds on
// the last step.
//
// Every check that also depends on the attempt counter reports the counter,
// so a count mismatch is not reported as "unexpected error: <nil>".
func TestExponentialBackoff(t *testing.T) {
	// exits immediately
	i := 0
	err := ExponentialBackoff(Backoff{Factor: 1.0}, func() (bool, error) {
		i++
		return false, nil
	})
	if err != ErrWaitTimeout || i != 0 {
		t.Errorf("unexpected result: err=%v, attempts=%d", err, i)
	}

	opts := Backoff{Factor: 1.0, Steps: 3}

	// waits up to steps
	i = 0
	err = ExponentialBackoff(opts, func() (bool, error) {
		i++
		return false, nil
	})
	if err != ErrWaitTimeout || i != opts.Steps {
		t.Errorf("unexpected result: err=%v, attempts=%d", err, i)
	}

	// returns immediately
	i = 0
	err = ExponentialBackoff(opts, func() (bool, error) {
		i++
		return true, nil
	})
	if err != nil || i != 1 {
		t.Errorf("unexpected result: err=%v, attempts=%d", err, i)
	}

	// returns immediately on error
	testErr := fmt.Errorf("some other error")
	err = ExponentialBackoff(opts, func() (bool, error) {
		return false, testErr
	})
	if err != testErr {
		t.Errorf("unexpected error: %v", err)
	}

	// invoked multiple times
	i = 1
	err = ExponentialBackoff(opts, func() (bool, error) {
		if i < opts.Steps {
			i++
			return false, nil
		}
		return true, nil
	})
	if err != nil || i != opts.Steps {
		t.Errorf("unexpected result: err=%v, attempts=%d", err, i)
	}
}

// TestPoller drains the channel produced by poller(1ms, 2ms) until it is
// closed and expects at most three values to have been delivered.
func TestPoller(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	w := poller(time.Millisecond, 2*time.Millisecond)
	ch := w(ctx)
	count := 0
DRAIN:
	for {
		select {
		case _, open := <-ch:
			if !open {
				break DRAIN
			}
			count++
		case <-time.After(ForeverTestTimeout):
			// Abort instead of only recording the failure: staying in the
			// loop would wait on the same silent channel again and again.
			t.Fatal("unexpected timeout after poll")
		}
	}
	if count > 3 {
		t.Errorf("expected up to three values, got %d", count)
	}
}

// fakePoller hands out fake tick sources and records how they were used.
type fakePoller struct {
	// max is the number of ticks each tick source offers.
	max int
	// used counts the ticks that were actually delivered.
	used int32 // accessed with atomics
	// wg is incremented by GetwaitFunc and released by the ticker goroutine.
	wg sync.WaitGroup
}

// fakeTicker returns a waitFunc whose channel offers up to max ticks. After
// each delivered tick *used is incremented atomically when used is non-nil.
// The ticking goroutine stops early when done becomes readable; on exit it
// closes the tick channel and then calls doneFunc.
func fakeTicker(max int, used *int32, doneFunc func()) waitFunc {
	return func(done <-chan struct{}) <-chan struct{} {
		ch := make(chan struct{})
		go func() {
			// Deferred calls run in reverse order: close(ch) first, then doneFunc.
			defer doneFunc()
			defer close(ch)
			for i := 0; i < max; i++ {
				select {
				case ch <- struct{}{}:
				case <-done:
					return
				}
				if used != nil {
					atomic.AddInt32(used, 1)
				}
			}
		}()
		return ch
	}
}

// GetwaitFunc adds one to fp.wg and returns a fakeTicker that counts ticks in
// fp.used. fp.wg.Done is only called by the ticker goroutine, which starts
// when the returned waitFunc is invoked.
func (fp *fakePoller) GetwaitFunc() waitFunc {
	fp.wg.Add(1)
	return fakeTicker(fp.max, &fp.used, fp.wg.Done)
}

// TestPoll runs poll (immediate=false) with a single available tick and a
// condition that succeeds, and expects exactly one condition invocation and
// exactly one delivered tick.
func TestPoll(t *testing.T) {
	invocations := 0
	f := ConditionWithContextFunc(func(ctx context.Context) (bool, error) {
		invocations++
		return true, nil
	})
	fp := fakePoller{max: 1}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f); err != nil {
		t.Fatalf("unexpected error %v", err)
	}
	// Wait for the ticker goroutine so fp.used is final before it is read.
	fp.wg.Wait()
	if invocations != 1 {
		t.Errorf("Expected exactly one invocation, got %d", invocations)
	}
	used := atomic.LoadInt32(&fp.used)
	if used != 1 {
		t.Errorf("Expected exactly one tick, got %d", used)
	}
}

// TestPollError runs poll (immediate=false) with a condition that returns an
// error and expects that same error back after exactly one delivered tick.
func TestPollError(t *testing.T) {
	expectedError := errors.New("Expected error")
	f := ConditionFunc(func() (bool, error) {
		return false, expectedError
	})
	fp := fakePoller{max: 1}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	if err :=
poll(ctx, false, fp.GetwaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
		t.Fatalf("Expected error %v, got none %v", expectedError, err)
	}
	// Wait for the ticker goroutine so fp.used is final before it is read.
	fp.wg.Wait()
	used := atomic.LoadInt32(&fp.used)
	if used != 1 {
		t.Errorf("Expected exactly one tick, got %d", used)
	}
}

// TestPollImmediate runs poll with immediate=true and a condition that
// succeeds, and expects exactly one condition invocation and no delivered
// ticks.
func TestPollImmediate(t *testing.T) {
	invocations := 0
	f := ConditionFunc(func() (bool, error) {
		invocations++
		return true, nil
	})
	fp := fakePoller{max: 0}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	if err := poll(ctx, true, fp.GetwaitFunc().WithContext(), f.WithContext()); err != nil {
		t.Fatalf("unexpected error %v", err)
	}
	// We don't need to wait for fp.wg, as pollImmediate shouldn't call waitFunc at all.
	if invocations != 1 {
		t.Errorf("Expected exactly one invocation, got %d", invocations)
	}
	used := atomic.LoadInt32(&fp.used)
	if used != 0 {
		t.Errorf("Expected exactly zero ticks, got %d", used)
	}
}

// TestPollImmediateError runs poll with immediate=true and a condition that
// returns an error, and expects that same error back with no delivered ticks.
func TestPollImmediateError(t *testing.T) {
	expectedError := errors.New("Expected error")
	f := ConditionFunc(func() (bool, error) {
		return false, expectedError
	})
	fp := fakePoller{max: 0}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	if err := poll(ctx, true, fp.GetwaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
		t.Fatalf("Expected error %v, got none %v", expectedError, err)
	}
	// We don't need to wait for fp.wg, as pollImmediate shouldn't call waitFunc at all.
used := atomic.LoadInt32(&fp.used)
	if used != 0 {
		t.Errorf("Expected exactly zero ticks, got %d", used)
	}
}

// TestPollForever runs PollInfinite in a goroutine with a condition that
// reports every invocation on ch and only succeeds once a value is available
// on done. The test observes the first invocation plus ten more, then
// signals done and expects ch to be closed within two further receives.
func TestPollForever(t *testing.T) {
	ch := make(chan struct{})
	errc := make(chan error, 1)
	done := make(chan struct{}, 1)
	complete := make(chan struct{})
	go func() {
		f := ConditionFunc(func() (bool, error) {
			// Blocks until the test goroutine receives, so every invocation is observed.
			ch <- struct{}{}
			select {
			case <-done:
				return true, nil
			default:
			}
			return false, nil
		})

		if err := PollInfinite(time.Microsecond, f); err != nil {
			errc <- fmt.Errorf("unexpected error %v", err)
		}

		close(ch)
		complete <- struct{}{}
	}()

	// ensure the condition is opened
	<-ch

	// ensure channel sends events
	for i := 0; i < 10; i++ {
		select {
		case _, open := <-ch:
			if !open {
				if len(errc) != 0 {
					t.Fatalf("did not expect channel to be closed, %v", <-errc)
				}
				t.Fatal("did not expect channel to be closed")
			}
		case <-time.After(ForeverTestTimeout):
			t.Fatalf("channel did not return at least once within the poll interval")
		}
	}

	// at most one poll notification should be sent once we return from the condition
	done <- struct{}{}
	go func() {
		for i := 0; i < 2; i++ {
			_, open := <-ch
			if !open {
				return
			}
		}
		t.Error("expected closed channel after two iterations")
	}()
	// The polling goroutine sends on complete only after it has closed ch.
	<-complete

	if len(errc) != 0 {
		t.Fatal(<-errc)
	}
}

// Test_waitFor drives waitForWithContext with a fakeTicker and table-driven
// conditions, checking for each case whether an error is returned and how
// often the condition was invoked. All cases share the invocations counter,
// which is reset before each case.
func Test_waitFor(t *testing.T) {
	var invocations int
	testCases := map[string]struct {
		F       ConditionFunc
		Ticks   int
		Invoked int
		Err     bool
	}{
		"invoked once": {
			ConditionFunc(func() (bool, error) {
				invocations++
				return true, nil
			}),
			2,
			1,
			false,
		},
		"invoked and returns a timeout": {
			ConditionFunc(func() (bool, error) {
				invocations++
				return false, nil
			}),
			2,
			3, // the contract of waitFor() says the func is called once more at the end of the wait
			true,
		},
		"returns immediately on error": {
			ConditionFunc(func() (bool, error) {
				invocations++
				return false, errors.New("test")
			}),
			2,
			1,
			true,
		},
	}
	for k, c := range testCases {
		invocations = 0
		ticker := fakeTicker(c.Ticks, nil, func() {})
		err := func() error {
			// done is closed when this closure returns, after the wait has finished.
			done := make(chan struct{})
			defer close(done)
			ctx :=
ContextForChannel(done)
			return waitForWithContext(ctx, ticker.WithContext(), c.F.WithContext())
		}()
		switch {
		case c.Err && err == nil:
			t.Errorf("%s: Expected error, got nil", k)
			continue
		case !c.Err && err != nil:
			t.Errorf("%s: Expected no error, got: %#v", k, err)
			continue
		}
		if invocations != c.Invoked {
			t.Errorf("%s: Expected %d invocations, got %d", k, c.Invoked, invocations)
		}
	}
}

// Test_waitForWithEarlyClosing_waitFunc tests waitForWithContext when the wait
// func hands back a channel that is already closed while the condition keeps
// returning false. The call is expected to return ErrWaitTimeout in less than
// ForeverTestTimeout/2.
func Test_waitForWithEarlyClosing_waitFunc(t *testing.T) {
	stopCh := make(chan struct{})
	defer close(stopCh)

	ctx := ContextForChannel(stopCh)
	start := time.Now()
	err := waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
		c := make(chan struct{})
		close(c)
		return c
	}, func(_ context.Context) (bool, error) {
		return false, nil
	})
	duration := time.Since(start)

	// The waitFor should return immediately, so the duration is close to 0s.
	if duration >= ForeverTestTimeout/2 {
		t.Errorf("expected short timeout duration")
	}
	if err != ErrWaitTimeout {
		t.Errorf("expected ErrWaitTimeout from WaitFunc")
	}
}

// Test_waitForWithClosedChannel tests waitForWithContext when the stop channel
// backing its context is already closed and the wait func's channel never
// ticks. The call is expected to return ErrWaitTimeout in less than
// ForeverTestTimeout/2.
func Test_waitForWithClosedChannel(t *testing.T) {
	stopCh := make(chan struct{})
	close(stopCh)
	// c stays open and silent for the duration of the wait.
	c := make(chan struct{})
	defer close(c)
	ctx := ContextForChannel(stopCh)

	start := time.Now()
	err := waitForWithContext(ctx, func(_ context.Context) <-chan struct{} {
		return c
	}, func(_ context.Context) (bool, error) {
		return false, nil
	})
	duration := time.Since(start)
	// The waitFor should return immediately, so the duration is close to 0s.
	if duration >= ForeverTestTimeout/2 {
		t.Errorf("expected short timeout duration")
	}
	// The interval of the poller is ForeverTestTimeout, so the waitFor should always return ErrWaitTimeout.
if err != ErrWaitTimeout {
		t.Errorf("expected ErrWaitTimeout from WaitFunc")
	}
}

// Test_waitForWithContextCancelsContext verifies that after the condition func returns true,
// waitForWithContext cancels the context it supplies to the WaitWithContextFunc.
func Test_waitForWithContextCancelsContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	waitFn := poller(time.Millisecond, ForeverTestTimeout)

	// Captures the context that waitForWithContext hands to the wait func.
	var ctxPassedToWait context.Context
	waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
		ctxPassedToWait = ctx
		return waitFn(ctx)
	}, func(ctx context.Context) (bool, error) {
		time.Sleep(10 * time.Millisecond)
		return true, nil
	})
	// The polling goroutine should be closed after waitForWithContext returning.
	if ctxPassedToWait.Err() != context.Canceled {
		t.Errorf("expected the context passed to waitForWithContext to be closed with: %v, but got: %v", context.Canceled, ctxPassedToWait.Err())
	}
}

// TestPollUntil starts PollUntil with a condition that never succeeds and
// reports each invocation on called, closes the stop channel after the first
// invocation, and waits for PollUntil to return.
func TestPollUntil(t *testing.T) {
	stopCh := make(chan struct{})
	called := make(chan bool)
	pollDone := make(chan struct{})

	go func() {
		PollUntil(time.Microsecond, ConditionFunc(func() (bool, error) {
			called <- true
			return false, nil
		}), stopCh)

		close(pollDone)
	}()

	// make sure we're called once
	<-called
	// this should trigger a "done"
	close(stopCh)

	go func() {
		// release the condition func  if needed
		for range called {
		}
	}()

	// make sure we finished the poll
	<-pollDone
	// Ends the draining goroutine above.
	close(called)
}

// TestBackoff_Step calls Step repeatedly on a copy of each Backoff in the
// table, for five random seeds, and compares the returned durations with the
// expected sequence. For entries with Jitter > 0 the result must differ from
// the expected value and pass a relative-difference check against Jitter; for
// all other entries (including a nil *Backoff) it must match exactly.
func TestBackoff_Step(t *testing.T) {
	tests := []struct {
		initial *Backoff
		want    []time.Duration
	}{
		{initial: nil, want: []time.Duration{0, 0, 0, 0}},
		{initial: &Backoff{Duration: time.Second, Steps: -1}, want: []time.Duration{time.Second, time.Second, time.Second}},
		{initial: &Backoff{Duration: time.Second, Steps: 0}, want: []time.Duration{time.Second, time.Second, time.Second}},
		{initial: &Backoff{Duration: time.Second, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}},
		{initial: &Backoff{Duration: time.Second, Factor: 1.0,
Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}}, {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 3}, want: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second}}, {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 3, Cap: 3 * time.Second}, want: []time.Duration{1 * time.Second, 2 * time.Second, 3 * time.Second}}, {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 2, Cap: 3 * time.Second, Jitter: 0.5}, want: []time.Duration{2 * time.Second, 3 * time.Second, 3 * time.Second}}, {initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 6, Jitter: 4}, want: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second, 8 * time.Second, 16 * time.Second, 32 * time.Second}}, } for seed := int64(0); seed < 5; seed++ { for _, tt := range tests { var initial *Backoff if tt.initial != nil { copied := *tt.initial initial = &copied } else { initial = nil } t.Run(fmt.Sprintf("%#v seed=%d", initial, seed), func(t *testing.T) { rand.Seed(seed) for i := 0; i < len(tt.want); i++ { got := initial.Step() t.Logf("[%d]=%s", i, got) if initial != nil && initial.Jitter > 0 { if got == tt.want[i] { // this is statistically unlikely to happen by chance t.Errorf("Backoff.Step(%d) = %v, no jitter", i, got) continue } diff := float64(tt.want[i]-got) / float64(tt.want[i]) if diff > initial.Jitter { t.Errorf("Backoff.Step(%d) = %v, want %v, outside range", i, got, tt.want) continue } } else { if got != tt.want[i] { t.Errorf("Backoff.Step(%d) = %v, want %v", i, got, tt.want) continue } } } }) } } } func TestContextForChannel(t *testing.T) { var wg sync.WaitGroup parentCh := make(chan struct{}) done := make(chan struct{}) for i := 0; i < 3; i++ { wg.Add(1) go func() { defer wg.Done() ctx := ContextForChannel(parentCh) <-ctx.Done() }() } go func() { wg.Wait() close(done) }() // Closing parent channel should cancel all children contexts close(parentCh) select { case <-done: case <-time.After(ForeverTestTimeout): 
t.Errorf("unexpected timeout waiting for parent to cancel child contexts")
	}
}

// TestExponentialBackoffManagerGetNextBackoff uses a fake clock and expects
// getNextBackoff to yield 1, 2, 4, 8, 10, 10, 10 in sequence and, after the
// fake clock has been stepped by 11, to yield 1 again.
func TestExponentialBackoffManagerGetNextBackoff(t *testing.T) {
	fc := testingclock.NewFakeClock(time.Now())
	backoff := NewExponentialBackoffManager(1, 10, 10, 2.0, 0.0, fc)
	durations := []time.Duration{1, 2, 4, 8, 10, 10, 10}
	for i := 0; i < len(durations); i++ {
		generatedBackoff := backoff.(*exponentialBackoffManagerImpl).getNextBackoff()
		if generatedBackoff != durations[i] {
			t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
		}
	}

	fc.Step(11)
	resetDuration := backoff.(*exponentialBackoffManagerImpl).getNextBackoff()
	if resetDuration != 1 {
		t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
	}
}

// TestJitteredBackoffManagerGetNextBackoff expects getNextBackoff to stay
// within [1, 2] for a manager built with arguments (1, 1) and to return
// exactly 1 for a manager built with arguments (1, -1).
func TestJitteredBackoffManagerGetNextBackoff(t *testing.T) {
	// positive jitter
	backoffMgr := NewJitteredBackoffManager(1, 1, testingclock.NewFakeClock(time.Now()))
	for i := 0; i < 5; i++ {
		backoff := backoffMgr.(*jitteredBackoffManagerImpl).getNextBackoff()
		if backoff < 1 || backoff > 2 {
			t.Errorf("backoff out of range: %d", backoff)
		}
	}

	// negative jitter, shall be a fixed backoff
	backoffMgr = NewJitteredBackoffManager(1, -1, testingclock.NewFakeClock(time.Now()))
	backoff := backoffMgr.(*jitteredBackoffManagerImpl).getNextBackoff()
	if backoff != 1 {
		t.Errorf("backoff should be 1, but got %d", backoff)
	}
}

// TestJitterBackoffManagerWithRealClock waits five times on the timer
// returned by Backoff() of a jittered manager built with a 1ms duration and a
// real clock, and expects each wait to take at least 1ms.
func TestJitterBackoffManagerWithRealClock(t *testing.T) {
	backoffMgr := NewJitteredBackoffManager(1*time.Millisecond, 0, &clock.RealClock{})
	for i := 0; i < 5; i++ {
		start := time.Now()
		<-backoffMgr.Backoff().C()
		passed := time.Since(start)
		if passed < 1*time.Millisecond {
			t.Errorf("backoff should be at least 1ms, but got %s", passed.String())
		}
	}
}

// TestExponentialBackoffManagerWithRealClock waits on successive timers from
// an exponential manager driven by a real clock and expects the i-th wait to
// take at least durationFactors[i] milliseconds.
func TestExponentialBackoffManagerWithRealClock(t *testing.T) {
	// backoff at least 1ms, 2ms, 4ms, 8ms, 10ms, 10ms, 10ms
	durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10}
	backoffMgr := NewExponentialBackoffManager(1*time.Millisecond, 10*time.Millisecond, 1*time.Hour, 2.0, 0.0,
&clock.RealClock{})

	for i := range durationFactors {
		start := time.Now()
		<-backoffMgr.Backoff().C()
		passed := time.Since(start)
		if passed < durationFactors[i]*time.Millisecond {
			t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String())
		}
	}
}

// TestBackoffDelayWithResetExponential uses a fake clock and expects the
// function returned by DelayWithReset(fc, 10) to yield 1, 2, 4, 8, 10, 10, 10
// in sequence and, after the fake clock has been stepped by 11, to yield 1
// again.
func TestBackoffDelayWithResetExponential(t *testing.T) {
	fc := testingclock.NewFakeClock(time.Now())
	backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 10)
	durations := []time.Duration{1, 2, 4, 8, 10, 10, 10}
	for i := 0; i < len(durations); i++ {
		generatedBackoff := backoff()
		if generatedBackoff != durations[i] {
			t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
		}
	}

	fc.Step(11)
	resetDuration := backoff()
	if resetDuration != 1 {
		t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
	}
}

// TestBackoffDelayWithResetEmpty repeats the exponential scenario with a
// reset interval of 0 and expects every returned delay to be 1, both before
// and after the fake clock is stepped.
func TestBackoffDelayWithResetEmpty(t *testing.T) {
	fc := testingclock.NewFakeClock(time.Now())
	backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 0)
	// we reset to initial duration because the resetInterval is 0, immediate
	durations := []time.Duration{1, 1, 1, 1, 1, 1, 1}
	for i := 0; i < len(durations); i++ {
		generatedBackoff := backoff()
		if generatedBackoff != durations[i] {
			t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
		}
	}

	fc.Step(11)
	resetDuration := backoff()
	if resetDuration != 1 {
		t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
	}
}

// TestBackoffDelayWithResetJitter expects delays within [1, 2] for
// Backoff{Duration: 1, Jitter: 1} and a delay of exactly 1 for
// Backoff{Duration: 1, Jitter: -1}.
func TestBackoffDelayWithResetJitter(t *testing.T) {
	// positive jitter
	backoff := Backoff{Duration: 1, Jitter: 1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0)
	for i := 0; i < 5; i++ {
		value := backoff()
		if value < 1 || value > 2 {
			t.Errorf("backoff out of range: %d", value)
		}
	}

	// negative jitter, shall be a fixed backoff
	backoff = Backoff{Duration: 1, Jitter: -1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0)
	value := backoff()
	if value !=
1 {
		t.Errorf("backoff should be 1, but got %d", value)
	}
}

// TestBackoffDelayWithResetWithRealClockJitter waits five times on a
// RealTimer armed with the delay returned for a 1ms Backoff and expects each
// wait to take at least 1ms.
func TestBackoffDelayWithResetWithRealClockJitter(t *testing.T) {
	backoff := Backoff{Duration: 1 * time.Millisecond, Jitter: 0}.DelayWithReset(&clock.RealClock{}, 0)
	for i := 0; i < 5; i++ {
		start := time.Now()
		<-RealTimer(backoff()).C()
		passed := time.Since(start)
		if passed < 1*time.Millisecond {
			t.Errorf("backoff should be at least 1ms, but got %s", passed.String())
		}
	}
}

// TestBackoffDelayWithResetWithRealClockExponential waits on RealTimers armed
// with successive delays and expects the i-th wait to take at least
// durationFactors[i] milliseconds.
func TestBackoffDelayWithResetWithRealClockExponential(t *testing.T) {
	// backoff at least 1ms, 2ms, 4ms, 8ms, 10ms, 10ms, 10ms
	durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10}
	backoff := Backoff{Duration: 1 * time.Millisecond, Cap: 10 * time.Millisecond, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(&clock.RealClock{}, 1*time.Hour)

	for i := range durationFactors {
		start := time.Now()
		<-RealTimer(backoff()).C()
		passed := time.Since(start)
		if passed < durationFactors[i]*time.Millisecond {
			t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String())
		}
	}
}

// defaultContext returns a cancellable context derived from
// context.Background, together with its cancel function.
func defaultContext() (context.Context, context.CancelFunc) {
	return context.WithCancel(context.Background())
}

// cancelledContext returns a context that has already been cancelled,
// together with its cancel function.
func cancelledContext() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	return ctx, cancel
}

// deadlinedContext returns a context with a 1ms timeout and does not return
// until ctx.Err() reports context.DeadlineExceeded.
func deadlinedContext() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	// Sleep in short steps until the deadline has been reached.
	for ctx.Err() != context.DeadlineExceeded {
		time.Sleep(501 * time.Microsecond)
	}
	return ctx, cancel
}

// TestExponentialBackoffWithContext is a table-driven test of
// ExponentialBackoffWithContext. Each case sets the number of steps, the
// context to run under (defaultContext when unset), the condition callback,
// and optionally the attempt after which the context is cancelled; it then
// checks the returned error and the number of attempts made.
func TestExponentialBackoffWithContext(t *testing.T) {
	defaultCallback := func(_ int) (bool, error) {
		return false, nil
	}

	conditionErr := errors.New("condition failed")

	tests := []struct {
		name               string
		steps              int
		zeroDuration       bool
		context            func() (context.Context, context.CancelFunc)
		callback           func(calls int) (bool, error)
		cancelContextAfter int
		attemptsExpected   int
		errExpected        error
	}{
		{
			name:             "no attempts expected with zero backoff steps",
			steps:            0,
			callback:         defaultCallback,
			attemptsExpected: 0,
			errExpected:      ErrWaitTimeout,
		},
		{
			name:             "condition returns false with single backoff step",
			steps:            1,
			callback:         defaultCallback,
			attemptsExpected: 1,
			errExpected:      ErrWaitTimeout,
		},
		{
			name:  "condition returns true with single backoff step",
			steps: 1,
			callback: func(_ int) (bool, error) {
				return true, nil
			},
			attemptsExpected: 1,
			errExpected:      nil,
		},
		{
			name:             "condition always returns false with multiple backoff steps",
			steps:            5,
			callback:         defaultCallback,
			attemptsExpected: 5,
			errExpected:      ErrWaitTimeout,
		},
		{
			name:  "condition returns true after certain attempts with multiple backoff steps",
			steps: 5,
			callback: func(attempts int) (bool, error) {
				if attempts == 3 {
					return true, nil
				}
				return false, nil
			},
			attemptsExpected: 3,
			errExpected:      nil,
		},
		{
			name:  "condition returns error no further attempts expected",
			steps: 5,
			callback: func(_ int) (bool, error) {
				return true, conditionErr
			},
			attemptsExpected: 1,
			errExpected:      conditionErr,
		},
		{
			name:             "context already canceled no attempts expected",
			steps:            5,
			context:          cancelledContext,
			callback:         defaultCallback,
			attemptsExpected: 0,
			errExpected:      context.Canceled,
		},
		{
			name:             "context at deadline no attempts expected",
			steps:            5,
			context:          deadlinedContext,
			callback:         defaultCallback,
			attemptsExpected: 0,
			errExpected:      context.DeadlineExceeded,
		},
		{
			name:             "no attempts expected with zero backoff steps",
			steps:            0,
			callback:         defaultCallback,
			attemptsExpected: 0,
			errExpected:      ErrWaitTimeout,
		},
		{
			name:             "condition returns false with single backoff step",
			steps:            1,
			callback:         defaultCallback,
			attemptsExpected: 1,
			errExpected:      ErrWaitTimeout,
		},
		{
			name:  "condition returns true with single backoff step",
			steps: 1,
			callback: func(_ int) (bool, error) {
				return true, nil
			},
			attemptsExpected: 1,
			errExpected:      nil,
		},
		{
			name:               "condition always returns false with multiple backoff steps but is cancelled at step 4",
			steps:              5,
			callback:           defaultCallback,
			attemptsExpected:   4,
			cancelContextAfter: 4,
			errExpected:        context.Canceled,
		},
		{
			name:         "condition returns true after certain attempts with multiple backoff steps and zero duration",
			steps:        5,
			zeroDuration: true,
			callback: func(attempts int) (bool, error) {
				if attempts == 3 {
					return true, nil
				}
				return false, nil
			},
			attemptsExpected: 3,
			errExpected:      nil,
		},
		{
			name:  "condition returns error no further attempts expected",
			steps: 5,
			callback: func(_ int) (bool, error) {
				return true, conditionErr
			},
			attemptsExpected: 1,
			errExpected:      conditionErr,
		},
		{
			name:             "context already canceled no attempts expected",
			steps:            5,
			context:          cancelledContext,
			callback:         defaultCallback,
			attemptsExpected: 0,
			errExpected:      context.Canceled,
		},
		{
			name:             "context at deadline no attempts expected",
			steps:            5,
			context:          deadlinedContext,
			callback:         defaultCallback,
			attemptsExpected: 0,
			errExpected:      context.DeadlineExceeded,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			backoff := Backoff{
				Duration: 1 * time.Microsecond,
				Factor:   1.0,
				Steps:    test.steps,
			}
			if test.zeroDuration {
				backoff.Duration = 0
			}

			contextFn := test.context
			if contextFn == nil {
				contextFn = defaultContext
			}
			ctx, cancel := contextFn()
			defer cancel()

			attempts := 0
			err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) {
				attempts++
				// Cancel only after the callback for the chosen attempt has returned.
				defer func() {
					if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
						cancel()
					}
				}()
				return test.callback(attempts)
			})

			if test.errExpected != err {
				t.Errorf("expected error: %v but got: %v", test.errExpected, err)
			}

			if test.attemptsExpected != attempts {
				t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
			}
		})
	}
}

// BenchmarkExponentialBackoffWithContext measures ExponentialBackoffWithContext
// with zero Duration and Factor, 101 steps, and a condition that succeeds on
// its 100th attempt.
func BenchmarkExponentialBackoffWithContext(b *testing.B) {
	backoff := Backoff{
		Duration: 0,
		Factor:   0,
		Steps:    101,
	}
	ctx := context.Background()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		attempts := 0
		if err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) {
			attempts++
			return attempts >= 100, nil
}); err != nil {
			b.Fatalf("unexpected err: %v", err)
		}
	}
}

// TestPollImmediateUntilWithContext is a table-driven test of
// PollImmediateUntilWithContext with a 1ms interval. Each case supplies a
// factory that builds the condition for the current attempt number, the
// context to run under (defaultContext when unset), and optionally the
// attempt after which the context is cancelled; it then checks the returned
// error and the number of attempts made.
func TestPollImmediateUntilWithContext(t *testing.T) {
	fakeErr := errors.New("my error")
	tests := []struct {
		name                         string
		condition                    func(int) ConditionWithContextFunc
		context                      func() (context.Context, context.CancelFunc)
		cancelContextAfterNthAttempt int
		errExpected                  error
		attemptsExpected             int
	}{
		{
			name: "condition throws error on immediate attempt, no retry is attempted",
			condition: func(int) ConditionWithContextFunc {
				return func(context.Context) (done bool, err error) {
					return false, fakeErr
				}
			},
			errExpected:      fakeErr,
			attemptsExpected: 1,
		},
		{
			name: "condition returns done=true on immediate attempt, no retry is attempted",
			condition: func(int) ConditionWithContextFunc {
				return func(context.Context) (done bool, err error) {
					return true, nil
				}
			},
			errExpected:      nil,
			attemptsExpected: 1,
		},
		{
			name: "condition returns done=false on immediate attempt, context is already cancelled, no retry is attempted",
			condition: func(int) ConditionWithContextFunc {
				return func(context.Context) (done bool, err error) {
					return false, nil
				}
			},
			context:          cancelledContext,
			errExpected:      ErrWaitTimeout, // this should be context.Canceled but that would break callers that assume all errors are ErrWaitTimeout
			attemptsExpected: 1,
		},
		{
			name: "condition returns done=false on immediate attempt, context is not cancelled, retry is attempted",
			condition: func(attempts int) ConditionWithContextFunc {
				return func(context.Context) (done bool, err error) {
					// let first 3 attempts fail and the last one succeed
					if attempts <= 3 {
						return false, nil
					}
					return true, nil
				}
			},
			errExpected:      nil,
			attemptsExpected: 4,
		},
		{
			name: "condition always returns done=false, context gets cancelled after N attempts",
			condition: func(attempts int) ConditionWithContextFunc {
				return func(ctx context.Context) (done bool, err error) {
					return false, nil
				}
			},
			cancelContextAfterNthAttempt: 4,
			errExpected:                  ErrWaitTimeout, // this should be context.Canceled, but this method cannot change
			attemptsExpected:             4,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			contextFn := test.context
			if contextFn == nil {
				contextFn = defaultContext
			}
			ctx, cancel := contextFn()
			defer cancel()

			var attempts int
			conditionWrapper := func(ctx context.Context) (done bool, err error) {
				attempts++
				// Cancel only after the condition for the chosen attempt has returned.
				defer func() {
					if test.cancelContextAfterNthAttempt == attempts {
						cancel()
					}
				}()

				c := test.condition(attempts)
				return c(ctx)
			}

			err := PollImmediateUntilWithContext(ctx, time.Millisecond, conditionWrapper)
			if test.errExpected != err {
				t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
			}
			if test.attemptsExpected != attempts {
				t.Errorf("Expected ConditionFunc to be invoked: %d times, but got: %d", test.attemptsExpected, attempts)
			}
		})
	}
}

// Test_waitForWithContext is a table-driven test of waitForWithContext. Each
// case supplies the context to run under, the condition, and a factory for
// the wait func that produces ticks; it then checks the returned error and
// the number of times the condition was invoked.
func Test_waitForWithContext(t *testing.T) {
	fakeErr := errors.New("fake error")
	tests := []struct {
		name             string
		context          func() (context.Context, context.CancelFunc)
		condition        ConditionWithContextFunc
		waitFunc         func() waitFunc
		attemptsExpected int
		errExpected      error
	}{
		{
			name:    "condition returns done=true on first attempt, no retry is attempted",
			context: defaultContext,
			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
				return true, nil
			}),
			waitFunc:         func() waitFunc { return fakeTicker(2, nil, func() {}) },
			attemptsExpected: 1,
			errExpected:      nil,
		},
		{
			name:    "condition always returns done=false, timeout error expected",
			context: defaultContext,
			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
				return false, nil
			}),
			waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) },
			// the contract of waitForWithContext() says the func is called once more at the end of the wait
			attemptsExpected: 3,
			errExpected:      ErrWaitTimeout,
		},
		{
			name:    "condition returns an error on first attempt, the error is returned",
			context: defaultContext,
			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
				return false, fakeErr
			}),
			waitFunc: func() waitFunc {
return fakeTicker(2, nil, func() {})
			},
			attemptsExpected: 1,
			errExpected:      fakeErr,
		},
		{
			name:    "context is cancelled, context cancelled error expected",
			context: cancelledContext,
			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
				return false, nil
			}),
			waitFunc: func() waitFunc {
				return func(done <-chan struct{}) <-chan struct{} {
					ch := make(chan struct{})
					// never tick on this channel
					return ch
				}
			},
			attemptsExpected: 0,
			errExpected:      ErrWaitTimeout,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Counts how often waitForWithContext invokes the condition.
			var attempts int
			conditionWrapper := func(ctx context.Context) (done bool, err error) {
				attempts++
				return test.condition(ctx)
			}

			ticker := test.waitFunc()
			// Run the wait inside a closure so the context is cancelled as
			// soon as waitForWithContext returns.
			err := func() error {
				contextFn := test.context
				if contextFn == nil {
					contextFn = defaultContext
				}
				ctx, cancel := contextFn()
				defer cancel()

				return waitForWithContext(ctx, ticker.WithContext(), conditionWrapper)
			}()

			if test.errExpected != err {
				t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
			}
			if test.attemptsExpected != attempts {
				t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts)
			}
		})
	}
}

func Test_poll(t *testing.T) {
	fakeErr := errors.New("fake error")
	tests := []struct {
		name               string
		context            func() (context.Context, context.CancelFunc)
		immediate          bool
		waitFunc           func() waitFunc
		condition          ConditionWithContextFunc
		cancelContextAfter int
		attemptsExpected   int
		errExpected        error
	}{
		{
			name:      "immediate is true, condition returns an error",
			immediate: true,
			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
				return false, fakeErr
			}),
			waitFunc:         nil,
			attemptsExpected: 1,
			errExpected:      fakeErr,
		},
		{
			name:      "immediate is true, condition returns true",
			immediate: true,
			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
				return true, nil
			}),
			waitFunc:         nil,
			attemptsExpected: 1,
			errExpected:      nil,
		},
		{
			name:      "immediate is true, context is cancelled, condition return false",
			immediate: true,
			context:
cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), waitFunc: nil, attemptsExpected: 1, errExpected: ErrWaitTimeout, }, { name: "immediate is false, context is cancelled", immediate: false, context: cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), waitFunc: nil, attemptsExpected: 0, errExpected: ErrWaitTimeout, }, { name: "immediate is false, condition returns an error", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), waitFunc: func() waitFunc { return fakeTicker(5, nil, func() {}) }, attemptsExpected: 1, errExpected: fakeErr, }, { name: "immediate is false, condition returns true", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), waitFunc: func() waitFunc { return fakeTicker(5, nil, func() {}) }, attemptsExpected: 1, errExpected: nil, }, { name: "immediate is false, ticker channel is closed, condition returns true", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), waitFunc: func() waitFunc { return func(done <-chan struct{}) <-chan struct{} { ch := make(chan struct{}) close(ch) return ch } }, attemptsExpected: 1, errExpected: nil, }, { name: "immediate is false, ticker channel is closed, condition returns error", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), waitFunc: func() waitFunc { return func(done <-chan struct{}) <-chan struct{} { ch := make(chan struct{}) close(ch) return ch } }, attemptsExpected: 1, errExpected: fakeErr, }, { name: "immediate is false, ticker channel is closed, condition returns false", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), waitFunc: func() waitFunc { return func(done <-chan 
struct{}) <-chan struct{} { ch := make(chan struct{}) close(ch) return ch } }, attemptsExpected: 1, errExpected: ErrWaitTimeout, }, { name: "condition always returns false, timeout error expected", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) }, // the contract of waitForWithContext() says the func is called once more at the end of the wait attemptsExpected: 3, errExpected: ErrWaitTimeout, }, { name: "context is cancelled after N attempts, timeout error expected", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), waitFunc: func() waitFunc { return func(done <-chan struct{}) <-chan struct{} { ch := make(chan struct{}) // just tick twice go func() { ch <- struct{}{} ch <- struct{}{} }() return ch } }, cancelContextAfter: 2, attemptsExpected: 2, errExpected: ErrWaitTimeout, }, { name: "context is cancelled after N attempts, context error not expected (legacy behavior)", immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), waitFunc: func() waitFunc { return func(done <-chan struct{}) <-chan struct{} { ch := make(chan struct{}) // just tick twice go func() { ch <- struct{}{} ch <- struct{}{} }() return ch } }, cancelContextAfter: 2, attemptsExpected: 2, errExpected: ErrWaitTimeout, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var attempts int ticker := waitFunc(func(done <-chan struct{}) <-chan struct{} { return nil }) if test.waitFunc != nil { ticker = test.waitFunc() } err := func() error { contextFn := test.context if contextFn == nil { contextFn = defaultContext } ctx, cancel := contextFn() defer cancel() conditionWrapper := func(ctx context.Context) (done bool, err error) { attempts++ defer func() { if test.cancelContextAfter == attempts { cancel() } }() return test.condition(ctx) } 
return poll(ctx, test.immediate, ticker.WithContext(), conditionWrapper) }() if test.errExpected != err { t.Errorf("Expected error: %v, but got: %v", test.errExpected, err) } if test.attemptsExpected != attempts { t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts) } }) } } func Benchmark_poll(b *testing.B) { ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { attempts := 0 if err := poll(ctx, true, poller(time.Microsecond, 0), func(_ context.Context) (bool, error) { attempts++ return attempts >= 100, nil }); err != nil { b.Fatalf("unexpected err: %v", err) } } }