...

Source file src/k8s.io/apimachinery/pkg/util/wait/wait_test.go

Documentation: k8s.io/apimachinery/pkg/util/wait

     1  /*
     2  Copyright 2014 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package wait
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"math/rand"
    24  	"sync"
    25  	"sync/atomic"
    26  	"testing"
    27  	"time"
    28  
    29  	"k8s.io/apimachinery/pkg/util/runtime"
    30  	"k8s.io/utils/clock"
    31  	testingclock "k8s.io/utils/clock/testing"
    32  )
    33  
    34  func TestUntil(t *testing.T) {
    35  	ch := make(chan struct{})
    36  	close(ch)
    37  	Until(func() {
    38  		t.Fatal("should not have been invoked")
    39  	}, 0, ch)
    40  
    41  	ch = make(chan struct{})
    42  	called := make(chan struct{})
    43  	go func() {
    44  		Until(func() {
    45  			called <- struct{}{}
    46  		}, 0, ch)
    47  		close(called)
    48  	}()
    49  	<-called
    50  	close(ch)
    51  	<-called
    52  }
    53  
    54  func TestUntilWithContext(t *testing.T) {
    55  	ctx, cancel := context.WithCancel(context.TODO())
    56  	cancel()
    57  	UntilWithContext(ctx, func(context.Context) {
    58  		t.Fatal("should not have been invoked")
    59  	}, 0)
    60  
    61  	ctx, cancel = context.WithCancel(context.TODO())
    62  	called := make(chan struct{})
    63  	go func() {
    64  		UntilWithContext(ctx, func(context.Context) {
    65  			called <- struct{}{}
    66  		}, 0)
    67  		close(called)
    68  	}()
    69  	<-called
    70  	cancel()
    71  	<-called
    72  }
    73  
    74  func TestNonSlidingUntil(t *testing.T) {
    75  	ch := make(chan struct{})
    76  	close(ch)
    77  	NonSlidingUntil(func() {
    78  		t.Fatal("should not have been invoked")
    79  	}, 0, ch)
    80  
    81  	ch = make(chan struct{})
    82  	called := make(chan struct{})
    83  	go func() {
    84  		NonSlidingUntil(func() {
    85  			called <- struct{}{}
    86  		}, 0, ch)
    87  		close(called)
    88  	}()
    89  	<-called
    90  	close(ch)
    91  	<-called
    92  }
    93  
    94  func TestNonSlidingUntilWithContext(t *testing.T) {
    95  	ctx, cancel := context.WithCancel(context.TODO())
    96  	cancel()
    97  	NonSlidingUntilWithContext(ctx, func(context.Context) {
    98  		t.Fatal("should not have been invoked")
    99  	}, 0)
   100  
   101  	ctx, cancel = context.WithCancel(context.TODO())
   102  	called := make(chan struct{})
   103  	go func() {
   104  		NonSlidingUntilWithContext(ctx, func(context.Context) {
   105  			called <- struct{}{}
   106  		}, 0)
   107  		close(called)
   108  	}()
   109  	<-called
   110  	cancel()
   111  	<-called
   112  }
   113  
   114  func TestUntilReturnsImmediately(t *testing.T) {
   115  	now := time.Now()
   116  	ch := make(chan struct{})
   117  	var attempts int
   118  	Until(func() {
   119  		attempts++
   120  		if attempts > 1 {
   121  			t.Fatalf("invoked after close of channel")
   122  		}
   123  		close(ch)
   124  	}, 30*time.Second, ch)
   125  	if now.Add(25 * time.Second).Before(time.Now()) {
   126  		t.Errorf("Until did not return immediately when the stop chan was closed inside the func")
   127  	}
   128  }
   129  
   130  func TestJitterUntil(t *testing.T) {
   131  	ch := make(chan struct{})
   132  	// if a channel is closed JitterUntil never calls function f
   133  	// and returns immediately
   134  	close(ch)
   135  	JitterUntil(func() {
   136  		t.Fatal("should not have been invoked")
   137  	}, 0, 1.0, true, ch)
   138  
   139  	ch = make(chan struct{})
   140  	called := make(chan struct{})
   141  	go func() {
   142  		JitterUntil(func() {
   143  			called <- struct{}{}
   144  		}, 0, 1.0, true, ch)
   145  		close(called)
   146  	}()
   147  	<-called
   148  	close(ch)
   149  	<-called
   150  }
   151  
   152  func TestJitterUntilWithContext(t *testing.T) {
   153  	ctx, cancel := context.WithCancel(context.TODO())
   154  	cancel()
   155  	JitterUntilWithContext(ctx, func(context.Context) {
   156  		t.Fatal("should not have been invoked")
   157  	}, 0, 1.0, true)
   158  
   159  	ctx, cancel = context.WithCancel(context.TODO())
   160  	called := make(chan struct{})
   161  	go func() {
   162  		JitterUntilWithContext(ctx, func(context.Context) {
   163  			called <- struct{}{}
   164  		}, 0, 1.0, true)
   165  		close(called)
   166  	}()
   167  	<-called
   168  	cancel()
   169  	<-called
   170  }
   171  
   172  func TestJitterUntilReturnsImmediately(t *testing.T) {
   173  	now := time.Now()
   174  	ch := make(chan struct{})
   175  	JitterUntil(func() {
   176  		close(ch)
   177  	}, 30*time.Second, 1.0, true, ch)
   178  	if now.Add(25 * time.Second).Before(time.Now()) {
   179  		t.Errorf("JitterUntil did not return immediately when the stop chan was closed inside the func")
   180  	}
   181  }
   182  
   183  func TestJitterUntilRecoversPanic(t *testing.T) {
   184  	// Save and restore crash handlers
   185  	originalReallyCrash := runtime.ReallyCrash
   186  	originalHandlers := runtime.PanicHandlers
   187  	defer func() {
   188  		runtime.ReallyCrash = originalReallyCrash
   189  		runtime.PanicHandlers = originalHandlers
   190  	}()
   191  
   192  	called := 0
   193  	handled := 0
   194  
   195  	// Hook up a custom crash handler to ensure it is called when a jitter function panics
   196  	runtime.ReallyCrash = false
   197  	runtime.PanicHandlers = []func(interface{}){
   198  		func(p interface{}) {
   199  			handled++
   200  		},
   201  	}
   202  
   203  	ch := make(chan struct{})
   204  	JitterUntil(func() {
   205  		called++
   206  		if called > 2 {
   207  			close(ch)
   208  			return
   209  		}
   210  		panic("TestJitterUntilRecoversPanic")
   211  	}, time.Millisecond, 1.0, true, ch)
   212  
   213  	if called != 3 {
   214  		t.Errorf("Expected panic recovers")
   215  	}
   216  }
   217  
   218  func TestJitterUntilNegativeFactor(t *testing.T) {
   219  	now := time.Now()
   220  	ch := make(chan struct{})
   221  	called := make(chan struct{})
   222  	received := make(chan struct{})
   223  	go func() {
   224  		JitterUntil(func() {
   225  			called <- struct{}{}
   226  			<-received
   227  		}, time.Second, -30.0, true, ch)
   228  	}()
   229  	// first loop
   230  	<-called
   231  	received <- struct{}{}
   232  	// second loop
   233  	<-called
   234  	close(ch)
   235  	received <- struct{}{}
   236  
   237  	// it should take at most 2 seconds + some overhead, not 3
   238  	if now.Add(3 * time.Second).Before(time.Now()) {
   239  		t.Errorf("JitterUntil did not returned after predefined period with negative jitter factor when the stop chan was closed inside the func")
   240  	}
   241  }
   242  
   243  func TestExponentialBackoff(t *testing.T) {
   244  	// exits immediately
   245  	i := 0
   246  	err := ExponentialBackoff(Backoff{Factor: 1.0}, func() (bool, error) {
   247  		i++
   248  		return false, nil
   249  	})
   250  	if err != ErrWaitTimeout || i != 0 {
   251  		t.Errorf("unexpected error: %v", err)
   252  	}
   253  
   254  	opts := Backoff{Factor: 1.0, Steps: 3}
   255  
   256  	// waits up to steps
   257  	i = 0
   258  	err = ExponentialBackoff(opts, func() (bool, error) {
   259  		i++
   260  		return false, nil
   261  	})
   262  	if err != ErrWaitTimeout || i != opts.Steps {
   263  		t.Errorf("unexpected error: %v", err)
   264  	}
   265  
   266  	// returns immediately
   267  	i = 0
   268  	err = ExponentialBackoff(opts, func() (bool, error) {
   269  		i++
   270  		return true, nil
   271  	})
   272  	if err != nil || i != 1 {
   273  		t.Errorf("unexpected error: %v", err)
   274  	}
   275  
   276  	// returns immediately on error
   277  	testErr := fmt.Errorf("some other error")
   278  	err = ExponentialBackoff(opts, func() (bool, error) {
   279  		return false, testErr
   280  	})
   281  	if err != testErr {
   282  		t.Errorf("unexpected error: %v", err)
   283  	}
   284  
   285  	// invoked multiple times
   286  	i = 1
   287  	err = ExponentialBackoff(opts, func() (bool, error) {
   288  		if i < opts.Steps {
   289  			i++
   290  			return false, nil
   291  		}
   292  		return true, nil
   293  	})
   294  	if err != nil || i != opts.Steps {
   295  		t.Errorf("unexpected error: %v", err)
   296  	}
   297  }
   298  
   299  func TestPoller(t *testing.T) {
   300  	ctx, cancel := context.WithCancel(context.Background())
   301  	defer cancel()
   302  	w := poller(time.Millisecond, 2*time.Millisecond)
   303  	ch := w(ctx)
   304  	count := 0
   305  DRAIN:
   306  	for {
   307  		select {
   308  		case _, open := <-ch:
   309  			if !open {
   310  				break DRAIN
   311  			}
   312  			count++
   313  		case <-time.After(ForeverTestTimeout):
   314  			t.Errorf("unexpected timeout after poll")
   315  		}
   316  	}
   317  	if count > 3 {
   318  		t.Errorf("expected up to three values, got %d", count)
   319  	}
   320  }
   321  
// fakePoller is the test fixture behind GetwaitFunc: it carries the tick
// limit, the tick counter and the WaitGroup that GetwaitFunc hands to
// fakeTicker.
type fakePoller struct {
	max  int            // tick limit passed to fakeTicker
	used int32          // accessed with atomics
	wg   sync.WaitGroup // incremented by GetwaitFunc; its Done is fakeTicker's doneFunc
}
   327  
   328  func fakeTicker(max int, used *int32, doneFunc func()) waitFunc {
   329  	return func(done <-chan struct{}) <-chan struct{} {
   330  		ch := make(chan struct{})
   331  		go func() {
   332  			defer doneFunc()
   333  			defer close(ch)
   334  			for i := 0; i < max; i++ {
   335  				select {
   336  				case ch <- struct{}{}:
   337  				case <-done:
   338  					return
   339  				}
   340  				if used != nil {
   341  					atomic.AddInt32(used, 1)
   342  				}
   343  			}
   344  		}()
   345  		return ch
   346  	}
   347  }
   348  
   349  func (fp *fakePoller) GetwaitFunc() waitFunc {
   350  	fp.wg.Add(1)
   351  	return fakeTicker(fp.max, &fp.used, fp.wg.Done)
   352  }
   353  
   354  func TestPoll(t *testing.T) {
   355  	invocations := 0
   356  	f := ConditionWithContextFunc(func(ctx context.Context) (bool, error) {
   357  		invocations++
   358  		return true, nil
   359  	})
   360  	fp := fakePoller{max: 1}
   361  
   362  	ctx, cancel := context.WithCancel(context.Background())
   363  	defer cancel()
   364  	if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f); err != nil {
   365  		t.Fatalf("unexpected error %v", err)
   366  	}
   367  	fp.wg.Wait()
   368  	if invocations != 1 {
   369  		t.Errorf("Expected exactly one invocation, got %d", invocations)
   370  	}
   371  	used := atomic.LoadInt32(&fp.used)
   372  	if used != 1 {
   373  		t.Errorf("Expected exactly one tick, got %d", used)
   374  	}
   375  }
   376  
   377  func TestPollError(t *testing.T) {
   378  	expectedError := errors.New("Expected error")
   379  	f := ConditionFunc(func() (bool, error) {
   380  		return false, expectedError
   381  	})
   382  	fp := fakePoller{max: 1}
   383  
   384  	ctx, cancel := context.WithCancel(context.Background())
   385  	defer cancel()
   386  	if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
   387  		t.Fatalf("Expected error %v, got none %v", expectedError, err)
   388  	}
   389  	fp.wg.Wait()
   390  	used := atomic.LoadInt32(&fp.used)
   391  	if used != 1 {
   392  		t.Errorf("Expected exactly one tick, got %d", used)
   393  	}
   394  }
   395  
   396  func TestPollImmediate(t *testing.T) {
   397  	invocations := 0
   398  	f := ConditionFunc(func() (bool, error) {
   399  		invocations++
   400  		return true, nil
   401  	})
   402  	fp := fakePoller{max: 0}
   403  
   404  	ctx, cancel := context.WithCancel(context.Background())
   405  	defer cancel()
   406  	if err := poll(ctx, true, fp.GetwaitFunc().WithContext(), f.WithContext()); err != nil {
   407  		t.Fatalf("unexpected error %v", err)
   408  	}
   409  	// We don't need to wait for fp.wg, as pollImmediate shouldn't call waitFunc at all.
   410  	if invocations != 1 {
   411  		t.Errorf("Expected exactly one invocation, got %d", invocations)
   412  	}
   413  	used := atomic.LoadInt32(&fp.used)
   414  	if used != 0 {
   415  		t.Errorf("Expected exactly zero ticks, got %d", used)
   416  	}
   417  }
   418  
   419  func TestPollImmediateError(t *testing.T) {
   420  	expectedError := errors.New("Expected error")
   421  	f := ConditionFunc(func() (bool, error) {
   422  		return false, expectedError
   423  	})
   424  	fp := fakePoller{max: 0}
   425  
   426  	ctx, cancel := context.WithCancel(context.Background())
   427  	defer cancel()
   428  	if err := poll(ctx, true, fp.GetwaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
   429  		t.Fatalf("Expected error %v, got none %v", expectedError, err)
   430  	}
   431  	// We don't need to wait for fp.wg, as pollImmediate shouldn't call waitFunc at all.
   432  	used := atomic.LoadInt32(&fp.used)
   433  	if used != 0 {
   434  		t.Errorf("Expected exactly zero ticks, got %d", used)
   435  	}
   436  }
   437  
// TestPollForever drives PollInfinite from a background goroutine whose
// condition reports every invocation on ch and only returns true once a value
// is available on done. The test first checks that invocations keep arriving,
// then signals done and waits for the goroutine to finish.
func TestPollForever(t *testing.T) {
	ch := make(chan struct{})
	// errc is buffered so the goroutine can report an error without blocking.
	errc := make(chan error, 1)
	// done is buffered so the test can signal completion without waiting for
	// the condition to reach its select.
	done := make(chan struct{}, 1)
	complete := make(chan struct{})
	go func() {
		f := ConditionFunc(func() (bool, error) {
			// Announce the invocation; blocks until the test receives it.
			ch <- struct{}{}
			// Non-blocking check: finish only once the test has signalled done.
			select {
			case <-done:
				return true, nil
			default:
			}
			return false, nil
		})

		if err := PollInfinite(time.Microsecond, f); err != nil {
			errc <- fmt.Errorf("unexpected error %v", err)
		}

		// Closing ch tells any reader that polling has stopped.
		close(ch)
		complete <- struct{}{}
	}()

	// ensure the condition is opened
	<-ch

	// ensure channel sends events
	for i := 0; i < 10; i++ {
		select {
		case _, open := <-ch:
			if !open {
				// ch is only closed after PollInfinite returns; surface its
				// error if one was recorded.
				if len(errc) != 0 {
					t.Fatalf("did not expect channel to be closed, %v", <-errc)
				}
				t.Fatal("did not expect channel to be closed")
			}
		case <-time.After(ForeverTestTimeout):
			t.Fatalf("channel did not return at least once within the poll interval")
		}
	}

	// at most one poll notification should be sent once we return from the condition
	done <- struct{}{}
	go func() {
		// Keep receiving so the condition's send on ch cannot block; ch must be
		// observed closed within two receives.
		for i := 0; i < 2; i++ {
			_, open := <-ch
			if !open {
				return
			}
		}
		t.Error("expected closed channel after two iterations")
	}()
	<-complete

	if len(errc) != 0 {
		t.Fatal(<-errc)
	}
}
   497  
   498  func Test_waitFor(t *testing.T) {
   499  	var invocations int
   500  	testCases := map[string]struct {
   501  		F       ConditionFunc
   502  		Ticks   int
   503  		Invoked int
   504  		Err     bool
   505  	}{
   506  		"invoked once": {
   507  			ConditionFunc(func() (bool, error) {
   508  				invocations++
   509  				return true, nil
   510  			}),
   511  			2,
   512  			1,
   513  			false,
   514  		},
   515  		"invoked and returns a timeout": {
   516  			ConditionFunc(func() (bool, error) {
   517  				invocations++
   518  				return false, nil
   519  			}),
   520  			2,
   521  			3, // the contract of waitFor() says the func is called once more at the end of the wait
   522  			true,
   523  		},
   524  		"returns immediately on error": {
   525  			ConditionFunc(func() (bool, error) {
   526  				invocations++
   527  				return false, errors.New("test")
   528  			}),
   529  			2,
   530  			1,
   531  			true,
   532  		},
   533  	}
   534  	for k, c := range testCases {
   535  		invocations = 0
   536  		ticker := fakeTicker(c.Ticks, nil, func() {})
   537  		err := func() error {
   538  			done := make(chan struct{})
   539  			defer close(done)
   540  			ctx := ContextForChannel(done)
   541  			return waitForWithContext(ctx, ticker.WithContext(), c.F.WithContext())
   542  		}()
   543  		switch {
   544  		case c.Err && err == nil:
   545  			t.Errorf("%s: Expected error, got nil", k)
   546  			continue
   547  		case !c.Err && err != nil:
   548  			t.Errorf("%s: Expected no error, got: %#v", k, err)
   549  			continue
   550  		}
   551  		if invocations != c.Invoked {
   552  			t.Errorf("%s: Expected %d invocations, got %d", k, c.Invoked, invocations)
   553  		}
   554  	}
   555  }
   556  
   557  // Test_waitForWithEarlyClosing_waitFunc tests WaitFor when the waitFunc closes its channel. The WaitFor should
   558  // always return ErrWaitTimeout.
   559  func Test_waitForWithEarlyClosing_waitFunc(t *testing.T) {
   560  	stopCh := make(chan struct{})
   561  	defer close(stopCh)
   562  
   563  	ctx := ContextForChannel(stopCh)
   564  	start := time.Now()
   565  	err := waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
   566  		c := make(chan struct{})
   567  		close(c)
   568  		return c
   569  	}, func(_ context.Context) (bool, error) {
   570  		return false, nil
   571  	})
   572  	duration := time.Since(start)
   573  
   574  	// The waitFor should return immediately, so the duration is close to 0s.
   575  	if duration >= ForeverTestTimeout/2 {
   576  		t.Errorf("expected short timeout duration")
   577  	}
   578  	if err != ErrWaitTimeout {
   579  		t.Errorf("expected ErrWaitTimeout from WaitFunc")
   580  	}
   581  }
   582  
   583  // Test_waitForWithClosedChannel tests waitFor when it receives a closed channel. The waitFor should
   584  // always return ErrWaitTimeout.
   585  func Test_waitForWithClosedChannel(t *testing.T) {
   586  	stopCh := make(chan struct{})
   587  	close(stopCh)
   588  	c := make(chan struct{})
   589  	defer close(c)
   590  	ctx := ContextForChannel(stopCh)
   591  
   592  	start := time.Now()
   593  	err := waitForWithContext(ctx, func(_ context.Context) <-chan struct{} {
   594  		return c
   595  	}, func(_ context.Context) (bool, error) {
   596  		return false, nil
   597  	})
   598  	duration := time.Since(start)
   599  	// The waitFor should return immediately, so the duration is close to 0s.
   600  	if duration >= ForeverTestTimeout/2 {
   601  		t.Errorf("expected short timeout duration")
   602  	}
   603  	// The interval of the poller is ForeverTestTimeout, so the waitFor should always return ErrWaitTimeout.
   604  	if err != ErrWaitTimeout {
   605  		t.Errorf("expected ErrWaitTimeout from WaitFunc")
   606  	}
   607  }
   608  
   609  // Test_waitForWithContextCancelsContext verifies that after the condition func returns true,
   610  // waitForWithContext cancels the context it supplies to the WaitWithContextFunc.
   611  func Test_waitForWithContextCancelsContext(t *testing.T) {
   612  	ctx, cancel := context.WithCancel(context.Background())
   613  	defer cancel()
   614  	waitFn := poller(time.Millisecond, ForeverTestTimeout)
   615  
   616  	var ctxPassedToWait context.Context
   617  	waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
   618  		ctxPassedToWait = ctx
   619  		return waitFn(ctx)
   620  	}, func(ctx context.Context) (bool, error) {
   621  		time.Sleep(10 * time.Millisecond)
   622  		return true, nil
   623  	})
   624  	// The polling goroutine should be closed after waitForWithContext returning.
   625  	if ctxPassedToWait.Err() != context.Canceled {
   626  		t.Errorf("expected the context passed to waitForWithContext to be closed with: %v, but got: %v", context.Canceled, ctxPassedToWait.Err())
   627  	}
   628  }
   629  
   630  func TestPollUntil(t *testing.T) {
   631  	stopCh := make(chan struct{})
   632  	called := make(chan bool)
   633  	pollDone := make(chan struct{})
   634  
   635  	go func() {
   636  		PollUntil(time.Microsecond, ConditionFunc(func() (bool, error) {
   637  			called <- true
   638  			return false, nil
   639  		}), stopCh)
   640  
   641  		close(pollDone)
   642  	}()
   643  
   644  	// make sure we're called once
   645  	<-called
   646  	// this should trigger a "done"
   647  	close(stopCh)
   648  
   649  	go func() {
   650  		// release the condition func if needed
   651  		for range called {
   652  		}
   653  	}()
   654  
   655  	// make sure we finished the poll
   656  	<-pollDone
   657  	close(called)
   658  }
   659  
   660  func TestBackoff_Step(t *testing.T) {
   661  	tests := []struct {
   662  		initial *Backoff
   663  		want    []time.Duration
   664  	}{
   665  		{initial: nil, want: []time.Duration{0, 0, 0, 0}},
   666  		{initial: &Backoff{Duration: time.Second, Steps: -1}, want: []time.Duration{time.Second, time.Second, time.Second}},
   667  		{initial: &Backoff{Duration: time.Second, Steps: 0}, want: []time.Duration{time.Second, time.Second, time.Second}},
   668  		{initial: &Backoff{Duration: time.Second, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}},
   669  		{initial: &Backoff{Duration: time.Second, Factor: 1.0, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}},
   670  		{initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 3}, want: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second}},
   671  		{initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 3, Cap: 3 * time.Second}, want: []time.Duration{1 * time.Second, 2 * time.Second, 3 * time.Second}},
   672  		{initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 2, Cap: 3 * time.Second, Jitter: 0.5}, want: []time.Duration{2 * time.Second, 3 * time.Second, 3 * time.Second}},
   673  		{initial: &Backoff{Duration: time.Second, Factor: 2, Steps: 6, Jitter: 4}, want: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second, 8 * time.Second, 16 * time.Second, 32 * time.Second}},
   674  	}
   675  	for seed := int64(0); seed < 5; seed++ {
   676  		for _, tt := range tests {
   677  			var initial *Backoff
   678  			if tt.initial != nil {
   679  				copied := *tt.initial
   680  				initial = &copied
   681  			} else {
   682  				initial = nil
   683  			}
   684  			t.Run(fmt.Sprintf("%#v seed=%d", initial, seed), func(t *testing.T) {
   685  				rand.Seed(seed)
   686  				for i := 0; i < len(tt.want); i++ {
   687  					got := initial.Step()
   688  					t.Logf("[%d]=%s", i, got)
   689  					if initial != nil && initial.Jitter > 0 {
   690  						if got == tt.want[i] {
   691  							// this is statistically unlikely to happen by chance
   692  							t.Errorf("Backoff.Step(%d) = %v, no jitter", i, got)
   693  							continue
   694  						}
   695  						diff := float64(tt.want[i]-got) / float64(tt.want[i])
   696  						if diff > initial.Jitter {
   697  							t.Errorf("Backoff.Step(%d) = %v, want %v, outside range", i, got, tt.want)
   698  							continue
   699  						}
   700  					} else {
   701  						if got != tt.want[i] {
   702  							t.Errorf("Backoff.Step(%d) = %v, want %v", i, got, tt.want)
   703  							continue
   704  						}
   705  					}
   706  				}
   707  			})
   708  		}
   709  	}
   710  }
   711  
   712  func TestContextForChannel(t *testing.T) {
   713  	var wg sync.WaitGroup
   714  	parentCh := make(chan struct{})
   715  	done := make(chan struct{})
   716  
   717  	for i := 0; i < 3; i++ {
   718  		wg.Add(1)
   719  		go func() {
   720  			defer wg.Done()
   721  			ctx := ContextForChannel(parentCh)
   722  			<-ctx.Done()
   723  		}()
   724  	}
   725  
   726  	go func() {
   727  		wg.Wait()
   728  		close(done)
   729  	}()
   730  
   731  	// Closing parent channel should cancel all children contexts
   732  	close(parentCh)
   733  
   734  	select {
   735  	case <-done:
   736  	case <-time.After(ForeverTestTimeout):
   737  		t.Errorf("unexpected timeout waiting for parent to cancel child contexts")
   738  	}
   739  }
   740  
   741  func TestExponentialBackoffManagerGetNextBackoff(t *testing.T) {
   742  	fc := testingclock.NewFakeClock(time.Now())
   743  	backoff := NewExponentialBackoffManager(1, 10, 10, 2.0, 0.0, fc)
   744  	durations := []time.Duration{1, 2, 4, 8, 10, 10, 10}
   745  	for i := 0; i < len(durations); i++ {
   746  		generatedBackoff := backoff.(*exponentialBackoffManagerImpl).getNextBackoff()
   747  		if generatedBackoff != durations[i] {
   748  			t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
   749  		}
   750  	}
   751  
   752  	fc.Step(11)
   753  	resetDuration := backoff.(*exponentialBackoffManagerImpl).getNextBackoff()
   754  	if resetDuration != 1 {
   755  		t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
   756  	}
   757  }
   758  
   759  func TestJitteredBackoffManagerGetNextBackoff(t *testing.T) {
   760  	// positive jitter
   761  	backoffMgr := NewJitteredBackoffManager(1, 1, testingclock.NewFakeClock(time.Now()))
   762  	for i := 0; i < 5; i++ {
   763  		backoff := backoffMgr.(*jitteredBackoffManagerImpl).getNextBackoff()
   764  		if backoff < 1 || backoff > 2 {
   765  			t.Errorf("backoff out of range: %d", backoff)
   766  		}
   767  	}
   768  
   769  	// negative jitter, shall be a fixed backoff
   770  	backoffMgr = NewJitteredBackoffManager(1, -1, testingclock.NewFakeClock(time.Now()))
   771  	backoff := backoffMgr.(*jitteredBackoffManagerImpl).getNextBackoff()
   772  	if backoff != 1 {
   773  		t.Errorf("backoff should be 1, but got %d", backoff)
   774  	}
   775  }
   776  
   777  func TestJitterBackoffManagerWithRealClock(t *testing.T) {
   778  	backoffMgr := NewJitteredBackoffManager(1*time.Millisecond, 0, &clock.RealClock{})
   779  	for i := 0; i < 5; i++ {
   780  		start := time.Now()
   781  		<-backoffMgr.Backoff().C()
   782  		passed := time.Since(start)
   783  		if passed < 1*time.Millisecond {
   784  			t.Errorf("backoff should be at least 1ms, but got %s", passed.String())
   785  		}
   786  	}
   787  }
   788  
   789  func TestExponentialBackoffManagerWithRealClock(t *testing.T) {
   790  	// backoff at least 1ms, 2ms, 4ms, 8ms, 10ms, 10ms, 10ms
   791  	durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10}
   792  	backoffMgr := NewExponentialBackoffManager(1*time.Millisecond, 10*time.Millisecond, 1*time.Hour, 2.0, 0.0, &clock.RealClock{})
   793  
   794  	for i := range durationFactors {
   795  		start := time.Now()
   796  		<-backoffMgr.Backoff().C()
   797  		passed := time.Since(start)
   798  		if passed < durationFactors[i]*time.Millisecond {
   799  			t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String())
   800  		}
   801  	}
   802  }
   803  
   804  func TestBackoffDelayWithResetExponential(t *testing.T) {
   805  	fc := testingclock.NewFakeClock(time.Now())
   806  	backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 10)
   807  	durations := []time.Duration{1, 2, 4, 8, 10, 10, 10}
   808  	for i := 0; i < len(durations); i++ {
   809  		generatedBackoff := backoff()
   810  		if generatedBackoff != durations[i] {
   811  			t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
   812  		}
   813  	}
   814  
   815  	fc.Step(11)
   816  	resetDuration := backoff()
   817  	if resetDuration != 1 {
   818  		t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
   819  	}
   820  }
   821  
   822  func TestBackoffDelayWithResetEmpty(t *testing.T) {
   823  	fc := testingclock.NewFakeClock(time.Now())
   824  	backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 0)
   825  	// we reset to initial duration because the resetInterval is 0, immediate
   826  	durations := []time.Duration{1, 1, 1, 1, 1, 1, 1}
   827  	for i := 0; i < len(durations); i++ {
   828  		generatedBackoff := backoff()
   829  		if generatedBackoff != durations[i] {
   830  			t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i])
   831  		}
   832  	}
   833  
   834  	fc.Step(11)
   835  	resetDuration := backoff()
   836  	if resetDuration != 1 {
   837  		t.Errorf("after reset, backoff should be 1, but got %d", resetDuration)
   838  	}
   839  }
   840  
   841  func TestBackoffDelayWithResetJitter(t *testing.T) {
   842  	// positive jitter
   843  	backoff := Backoff{Duration: 1, Jitter: 1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0)
   844  	for i := 0; i < 5; i++ {
   845  		value := backoff()
   846  		if value < 1 || value > 2 {
   847  			t.Errorf("backoff out of range: %d", value)
   848  		}
   849  	}
   850  
   851  	// negative jitter, shall be a fixed backoff
   852  	backoff = Backoff{Duration: 1, Jitter: -1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0)
   853  	value := backoff()
   854  	if value != 1 {
   855  		t.Errorf("backoff should be 1, but got %d", value)
   856  	}
   857  }
   858  
   859  func TestBackoffDelayWithResetWithRealClockJitter(t *testing.T) {
   860  	backoff := Backoff{Duration: 1 * time.Millisecond, Jitter: 0}.DelayWithReset(&clock.RealClock{}, 0)
   861  	for i := 0; i < 5; i++ {
   862  		start := time.Now()
   863  		<-RealTimer(backoff()).C()
   864  		passed := time.Since(start)
   865  		if passed < 1*time.Millisecond {
   866  			t.Errorf("backoff should be at least 1ms, but got %s", passed.String())
   867  		}
   868  	}
   869  }
   870  
   871  func TestBackoffDelayWithResetWithRealClockExponential(t *testing.T) {
   872  	// backoff at least 1ms, 2ms, 4ms, 8ms, 10ms, 10ms, 10ms
   873  	durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10}
   874  	backoff := Backoff{Duration: 1 * time.Millisecond, Cap: 10 * time.Millisecond, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(&clock.RealClock{}, 1*time.Hour)
   875  
   876  	for i := range durationFactors {
   877  		start := time.Now()
   878  		<-RealTimer(backoff()).C()
   879  		passed := time.Since(start)
   880  		if passed < durationFactors[i]*time.Millisecond {
   881  			t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String())
   882  		}
   883  	}
   884  }
   885  
   886  func defaultContext() (context.Context, context.CancelFunc) {
   887  	return context.WithCancel(context.Background())
   888  }
   889  func cancelledContext() (context.Context, context.CancelFunc) {
   890  	ctx, cancel := context.WithCancel(context.Background())
   891  	cancel()
   892  	return ctx, cancel
   893  }
   894  func deadlinedContext() (context.Context, context.CancelFunc) {
   895  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
   896  	for ctx.Err() != context.DeadlineExceeded {
   897  		time.Sleep(501 * time.Microsecond)
   898  	}
   899  	return ctx, cancel
   900  }
   901  
   902  func TestExponentialBackoffWithContext(t *testing.T) {
   903  	defaultCallback := func(_ int) (bool, error) {
   904  		return false, nil
   905  	}
   906  
   907  	conditionErr := errors.New("condition failed")
   908  
   909  	tests := []struct {
   910  		name               string
   911  		steps              int
   912  		zeroDuration       bool
   913  		context            func() (context.Context, context.CancelFunc)
   914  		callback           func(calls int) (bool, error)
   915  		cancelContextAfter int
   916  		attemptsExpected   int
   917  		errExpected        error
   918  	}{
   919  		{
   920  			name:             "no attempts expected with zero backoff steps",
   921  			steps:            0,
   922  			callback:         defaultCallback,
   923  			attemptsExpected: 0,
   924  			errExpected:      ErrWaitTimeout,
   925  		},
   926  		{
   927  			name:             "condition returns false with single backoff step",
   928  			steps:            1,
   929  			callback:         defaultCallback,
   930  			attemptsExpected: 1,
   931  			errExpected:      ErrWaitTimeout,
   932  		},
   933  		{
   934  			name:  "condition returns true with single backoff step",
   935  			steps: 1,
   936  			callback: func(_ int) (bool, error) {
   937  				return true, nil
   938  			},
   939  			attemptsExpected: 1,
   940  			errExpected:      nil,
   941  		},
   942  		{
   943  			name:             "condition always returns false with multiple backoff steps",
   944  			steps:            5,
   945  			callback:         defaultCallback,
   946  			attemptsExpected: 5,
   947  			errExpected:      ErrWaitTimeout,
   948  		},
   949  		{
   950  			name:  "condition returns true after certain attempts with multiple backoff steps",
   951  			steps: 5,
   952  			callback: func(attempts int) (bool, error) {
   953  				if attempts == 3 {
   954  					return true, nil
   955  				}
   956  				return false, nil
   957  			},
   958  			attemptsExpected: 3,
   959  			errExpected:      nil,
   960  		},
   961  		{
   962  			name:  "condition returns error no further attempts expected",
   963  			steps: 5,
   964  			callback: func(_ int) (bool, error) {
   965  				return true, conditionErr
   966  			},
   967  			attemptsExpected: 1,
   968  			errExpected:      conditionErr,
   969  		},
   970  		{
   971  			name:             "context already canceled no attempts expected",
   972  			steps:            5,
   973  			context:          cancelledContext,
   974  			callback:         defaultCallback,
   975  			attemptsExpected: 0,
   976  			errExpected:      context.Canceled,
   977  		},
   978  		{
   979  			name:             "context at deadline no attempts expected",
   980  			steps:            5,
   981  			context:          deadlinedContext,
   982  			callback:         defaultCallback,
   983  			attemptsExpected: 0,
   984  			errExpected:      context.DeadlineExceeded,
   985  		},
   986  		{
   987  			name:             "no attempts expected with zero backoff steps",
   988  			steps:            0,
   989  			callback:         defaultCallback,
   990  			attemptsExpected: 0,
   991  			errExpected:      ErrWaitTimeout,
   992  		},
   993  		{
   994  			name:             "condition returns false with single backoff step",
   995  			steps:            1,
   996  			callback:         defaultCallback,
   997  			attemptsExpected: 1,
   998  			errExpected:      ErrWaitTimeout,
   999  		},
  1000  		{
  1001  			name:  "condition returns true with single backoff step",
  1002  			steps: 1,
  1003  			callback: func(_ int) (bool, error) {
  1004  				return true, nil
  1005  			},
  1006  			attemptsExpected: 1,
  1007  			errExpected:      nil,
  1008  		},
  1009  		{
  1010  			name:               "condition always returns false with multiple backoff steps but is cancelled at step 4",
  1011  			steps:              5,
  1012  			callback:           defaultCallback,
  1013  			attemptsExpected:   4,
  1014  			cancelContextAfter: 4,
  1015  			errExpected:        context.Canceled,
  1016  		},
  1017  		{
  1018  			name:         "condition returns true after certain attempts with multiple backoff steps and zero duration",
  1019  			steps:        5,
  1020  			zeroDuration: true,
  1021  			callback: func(attempts int) (bool, error) {
  1022  				if attempts == 3 {
  1023  					return true, nil
  1024  				}
  1025  				return false, nil
  1026  			},
  1027  			attemptsExpected: 3,
  1028  			errExpected:      nil,
  1029  		},
  1030  		{
  1031  			name:  "condition returns error no further attempts expected",
  1032  			steps: 5,
  1033  			callback: func(_ int) (bool, error) {
  1034  				return true, conditionErr
  1035  			},
  1036  			attemptsExpected: 1,
  1037  			errExpected:      conditionErr,
  1038  		},
  1039  		{
  1040  			name:             "context already canceled no attempts expected",
  1041  			steps:            5,
  1042  			context:          cancelledContext,
  1043  			callback:         defaultCallback,
  1044  			attemptsExpected: 0,
  1045  			errExpected:      context.Canceled,
  1046  		},
  1047  		{
  1048  			name:             "context at deadline no attempts expected",
  1049  			steps:            5,
  1050  			context:          deadlinedContext,
  1051  			callback:         defaultCallback,
  1052  			attemptsExpected: 0,
  1053  			errExpected:      context.DeadlineExceeded,
  1054  		},
  1055  	}
  1056  
  1057  	for _, test := range tests {
  1058  		t.Run(test.name, func(t *testing.T) {
  1059  			backoff := Backoff{
  1060  				Duration: 1 * time.Microsecond,
  1061  				Factor:   1.0,
  1062  				Steps:    test.steps,
  1063  			}
  1064  			if test.zeroDuration {
  1065  				backoff.Duration = 0
  1066  			}
  1067  
  1068  			contextFn := test.context
  1069  			if contextFn == nil {
  1070  				contextFn = defaultContext
  1071  			}
  1072  			ctx, cancel := contextFn()
  1073  			defer cancel()
  1074  
  1075  			attempts := 0
  1076  			err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) {
  1077  				attempts++
  1078  				defer func() {
  1079  					if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
  1080  						cancel()
  1081  					}
  1082  				}()
  1083  				return test.callback(attempts)
  1084  			})
  1085  
  1086  			if test.errExpected != err {
  1087  				t.Errorf("expected error: %v but got: %v", test.errExpected, err)
  1088  			}
  1089  
  1090  			if test.attemptsExpected != attempts {
  1091  				t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
  1092  			}
  1093  		})
  1094  	}
  1095  }
  1096  
  1097  func BenchmarkExponentialBackoffWithContext(b *testing.B) {
  1098  	backoff := Backoff{
  1099  		Duration: 0,
  1100  		Factor:   0,
  1101  		Steps:    101,
  1102  	}
  1103  	ctx := context.Background()
  1104  
  1105  	b.ResetTimer()
  1106  	for i := 0; i < b.N; i++ {
  1107  		attempts := 0
  1108  		if err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) {
  1109  			attempts++
  1110  			return attempts >= 100, nil
  1111  		}); err != nil {
  1112  			b.Fatalf("unexpected err: %v", err)
  1113  		}
  1114  	}
  1115  }
  1116  
  1117  func TestPollImmediateUntilWithContext(t *testing.T) {
  1118  	fakeErr := errors.New("my error")
  1119  	tests := []struct {
  1120  		name                         string
  1121  		condition                    func(int) ConditionWithContextFunc
  1122  		context                      func() (context.Context, context.CancelFunc)
  1123  		cancelContextAfterNthAttempt int
  1124  		errExpected                  error
  1125  		attemptsExpected             int
  1126  	}{
  1127  		{
  1128  			name: "condition throws error on immediate attempt, no retry is attempted",
  1129  			condition: func(int) ConditionWithContextFunc {
  1130  				return func(context.Context) (done bool, err error) {
  1131  					return false, fakeErr
  1132  				}
  1133  			},
  1134  			errExpected:      fakeErr,
  1135  			attemptsExpected: 1,
  1136  		},
  1137  		{
  1138  			name: "condition returns done=true on immediate attempt, no retry is attempted",
  1139  			condition: func(int) ConditionWithContextFunc {
  1140  				return func(context.Context) (done bool, err error) {
  1141  					return true, nil
  1142  				}
  1143  			},
  1144  			errExpected:      nil,
  1145  			attemptsExpected: 1,
  1146  		},
  1147  		{
  1148  			name: "condition returns done=false on immediate attempt, context is already cancelled, no retry is attempted",
  1149  			condition: func(int) ConditionWithContextFunc {
  1150  				return func(context.Context) (done bool, err error) {
  1151  					return false, nil
  1152  				}
  1153  			},
  1154  			context:          cancelledContext,
  1155  			errExpected:      ErrWaitTimeout, // this should be context.Canceled but that would break callers that assume all errors are ErrWaitTimeout
  1156  			attemptsExpected: 1,
  1157  		},
  1158  		{
  1159  			name: "condition returns done=false on immediate attempt, context is not cancelled, retry is attempted",
  1160  			condition: func(attempts int) ConditionWithContextFunc {
  1161  				return func(context.Context) (done bool, err error) {
  1162  					// let first 3 attempts fail and the last one succeed
  1163  					if attempts <= 3 {
  1164  						return false, nil
  1165  					}
  1166  					return true, nil
  1167  				}
  1168  			},
  1169  			errExpected:      nil,
  1170  			attemptsExpected: 4,
  1171  		},
  1172  		{
  1173  			name: "condition always returns done=false, context gets cancelled after N attempts",
  1174  			condition: func(attempts int) ConditionWithContextFunc {
  1175  				return func(ctx context.Context) (done bool, err error) {
  1176  					return false, nil
  1177  				}
  1178  			},
  1179  			cancelContextAfterNthAttempt: 4,
  1180  			errExpected:                  ErrWaitTimeout, // this should be context.Canceled, but this method cannot change
  1181  			attemptsExpected:             4,
  1182  		},
  1183  	}
  1184  
  1185  	for _, test := range tests {
  1186  		t.Run(test.name, func(t *testing.T) {
  1187  			contextFn := test.context
  1188  			if contextFn == nil {
  1189  				contextFn = defaultContext
  1190  			}
  1191  			ctx, cancel := contextFn()
  1192  			defer cancel()
  1193  
  1194  			var attempts int
  1195  			conditionWrapper := func(ctx context.Context) (done bool, err error) {
  1196  				attempts++
  1197  				defer func() {
  1198  					if test.cancelContextAfterNthAttempt == attempts {
  1199  						cancel()
  1200  					}
  1201  				}()
  1202  
  1203  				c := test.condition(attempts)
  1204  				return c(ctx)
  1205  			}
  1206  
  1207  			err := PollImmediateUntilWithContext(ctx, time.Millisecond, conditionWrapper)
  1208  			if test.errExpected != err {
  1209  				t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
  1210  			}
  1211  			if test.attemptsExpected != attempts {
  1212  				t.Errorf("Expected ConditionFunc to be invoked: %d times, but got: %d", test.attemptsExpected, attempts)
  1213  			}
  1214  		})
  1215  	}
  1216  }
  1217  
  1218  func Test_waitForWithContext(t *testing.T) {
  1219  	fakeErr := errors.New("fake error")
  1220  	tests := []struct {
  1221  		name             string
  1222  		context          func() (context.Context, context.CancelFunc)
  1223  		condition        ConditionWithContextFunc
  1224  		waitFunc         func() waitFunc
  1225  		attemptsExpected int
  1226  		errExpected      error
  1227  	}{
  1228  		{
  1229  			name:    "condition returns done=true on first attempt, no retry is attempted",
  1230  			context: defaultContext,
  1231  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1232  				return true, nil
  1233  			}),
  1234  			waitFunc:         func() waitFunc { return fakeTicker(2, nil, func() {}) },
  1235  			attemptsExpected: 1,
  1236  			errExpected:      nil,
  1237  		},
  1238  		{
  1239  			name:    "condition always returns done=false, timeout error expected",
  1240  			context: defaultContext,
  1241  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1242  				return false, nil
  1243  			}),
  1244  			waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) },
  1245  			// the contract of waitForWithContext() says the func is called once more at the end of the wait
  1246  			attemptsExpected: 3,
  1247  			errExpected:      ErrWaitTimeout,
  1248  		},
  1249  		{
  1250  			name:    "condition returns an error on first attempt, the error is returned",
  1251  			context: defaultContext,
  1252  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1253  				return false, fakeErr
  1254  			}),
  1255  			waitFunc:         func() waitFunc { return fakeTicker(2, nil, func() {}) },
  1256  			attemptsExpected: 1,
  1257  			errExpected:      fakeErr,
  1258  		},
  1259  		{
  1260  			name:    "context is cancelled, context cancelled error expected",
  1261  			context: cancelledContext,
  1262  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1263  				return false, nil
  1264  			}),
  1265  			waitFunc: func() waitFunc {
  1266  				return func(done <-chan struct{}) <-chan struct{} {
  1267  					ch := make(chan struct{})
  1268  					// never tick on this channel
  1269  					return ch
  1270  				}
  1271  			},
  1272  			attemptsExpected: 0,
  1273  			errExpected:      ErrWaitTimeout,
  1274  		},
  1275  	}
  1276  
  1277  	for _, test := range tests {
  1278  		t.Run(test.name, func(t *testing.T) {
  1279  			var attempts int
  1280  			conditionWrapper := func(ctx context.Context) (done bool, err error) {
  1281  				attempts++
  1282  				return test.condition(ctx)
  1283  			}
  1284  
  1285  			ticker := test.waitFunc()
  1286  			err := func() error {
  1287  				contextFn := test.context
  1288  				if contextFn == nil {
  1289  					contextFn = defaultContext
  1290  				}
  1291  				ctx, cancel := contextFn()
  1292  				defer cancel()
  1293  
  1294  				return waitForWithContext(ctx, ticker.WithContext(), conditionWrapper)
  1295  			}()
  1296  
  1297  			if test.errExpected != err {
  1298  				t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
  1299  			}
  1300  			if test.attemptsExpected != attempts {
  1301  				t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts)
  1302  			}
  1303  		})
  1304  	}
  1305  }
  1306  
  1307  func Test_poll(t *testing.T) {
  1308  	fakeErr := errors.New("fake error")
  1309  	tests := []struct {
  1310  		name               string
  1311  		context            func() (context.Context, context.CancelFunc)
  1312  		immediate          bool
  1313  		waitFunc           func() waitFunc
  1314  		condition          ConditionWithContextFunc
  1315  		cancelContextAfter int
  1316  		attemptsExpected   int
  1317  		errExpected        error
  1318  	}{
  1319  		{
  1320  			name:      "immediate is true, condition returns an error",
  1321  			immediate: true,
  1322  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1323  				return false, fakeErr
  1324  			}),
  1325  			waitFunc:         nil,
  1326  			attemptsExpected: 1,
  1327  			errExpected:      fakeErr,
  1328  		},
  1329  		{
  1330  			name:      "immediate is true, condition returns true",
  1331  			immediate: true,
  1332  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1333  				return true, nil
  1334  			}),
  1335  			waitFunc:         nil,
  1336  			attemptsExpected: 1,
  1337  			errExpected:      nil,
  1338  		},
  1339  		{
  1340  			name:      "immediate is true, context is cancelled, condition return false",
  1341  			immediate: true,
  1342  			context:   cancelledContext,
  1343  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1344  				return false, nil
  1345  			}),
  1346  			waitFunc:         nil,
  1347  			attemptsExpected: 1,
  1348  			errExpected:      ErrWaitTimeout,
  1349  		},
  1350  		{
  1351  			name:      "immediate is false, context is cancelled",
  1352  			immediate: false,
  1353  			context:   cancelledContext,
  1354  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1355  				return false, nil
  1356  			}),
  1357  			waitFunc:         nil,
  1358  			attemptsExpected: 0,
  1359  			errExpected:      ErrWaitTimeout,
  1360  		},
  1361  		{
  1362  			name:      "immediate is false, condition returns an error",
  1363  			immediate: false,
  1364  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1365  				return false, fakeErr
  1366  			}),
  1367  			waitFunc:         func() waitFunc { return fakeTicker(5, nil, func() {}) },
  1368  			attemptsExpected: 1,
  1369  			errExpected:      fakeErr,
  1370  		},
  1371  		{
  1372  			name:      "immediate is false, condition returns true",
  1373  			immediate: false,
  1374  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1375  				return true, nil
  1376  			}),
  1377  			waitFunc:         func() waitFunc { return fakeTicker(5, nil, func() {}) },
  1378  			attemptsExpected: 1,
  1379  			errExpected:      nil,
  1380  		},
  1381  		{
  1382  			name:      "immediate is false, ticker channel is closed, condition returns true",
  1383  			immediate: false,
  1384  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1385  				return true, nil
  1386  			}),
  1387  			waitFunc: func() waitFunc {
  1388  				return func(done <-chan struct{}) <-chan struct{} {
  1389  					ch := make(chan struct{})
  1390  					close(ch)
  1391  					return ch
  1392  				}
  1393  			},
  1394  			attemptsExpected: 1,
  1395  			errExpected:      nil,
  1396  		},
  1397  		{
  1398  			name:      "immediate is false, ticker channel is closed, condition returns error",
  1399  			immediate: false,
  1400  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1401  				return false, fakeErr
  1402  			}),
  1403  			waitFunc: func() waitFunc {
  1404  				return func(done <-chan struct{}) <-chan struct{} {
  1405  					ch := make(chan struct{})
  1406  					close(ch)
  1407  					return ch
  1408  				}
  1409  			},
  1410  			attemptsExpected: 1,
  1411  			errExpected:      fakeErr,
  1412  		},
  1413  		{
  1414  			name:      "immediate is false, ticker channel is closed, condition returns false",
  1415  			immediate: false,
  1416  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1417  				return false, nil
  1418  			}),
  1419  			waitFunc: func() waitFunc {
  1420  				return func(done <-chan struct{}) <-chan struct{} {
  1421  					ch := make(chan struct{})
  1422  					close(ch)
  1423  					return ch
  1424  				}
  1425  			},
  1426  			attemptsExpected: 1,
  1427  			errExpected:      ErrWaitTimeout,
  1428  		},
  1429  		{
  1430  			name:      "condition always returns false, timeout error expected",
  1431  			immediate: false,
  1432  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1433  				return false, nil
  1434  			}),
  1435  			waitFunc: func() waitFunc { return fakeTicker(2, nil, func() {}) },
  1436  			// the contract of waitForWithContext() says the func is called once more at the end of the wait
  1437  			attemptsExpected: 3,
  1438  			errExpected:      ErrWaitTimeout,
  1439  		},
  1440  		{
  1441  			name:      "context is cancelled after N attempts, timeout error expected",
  1442  			immediate: false,
  1443  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1444  				return false, nil
  1445  			}),
  1446  			waitFunc: func() waitFunc {
  1447  				return func(done <-chan struct{}) <-chan struct{} {
  1448  					ch := make(chan struct{})
  1449  					// just tick twice
  1450  					go func() {
  1451  						ch <- struct{}{}
  1452  						ch <- struct{}{}
  1453  					}()
  1454  					return ch
  1455  				}
  1456  			},
  1457  			cancelContextAfter: 2,
  1458  			attemptsExpected:   2,
  1459  			errExpected:        ErrWaitTimeout,
  1460  		},
  1461  		{
  1462  			name:      "context is cancelled after N attempts, context error not expected (legacy behavior)",
  1463  			immediate: false,
  1464  			condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
  1465  				return false, nil
  1466  			}),
  1467  			waitFunc: func() waitFunc {
  1468  				return func(done <-chan struct{}) <-chan struct{} {
  1469  					ch := make(chan struct{})
  1470  					// just tick twice
  1471  					go func() {
  1472  						ch <- struct{}{}
  1473  						ch <- struct{}{}
  1474  					}()
  1475  					return ch
  1476  				}
  1477  			},
  1478  			cancelContextAfter: 2,
  1479  			attemptsExpected:   2,
  1480  			errExpected:        ErrWaitTimeout,
  1481  		},
  1482  	}
  1483  
  1484  	for _, test := range tests {
  1485  		t.Run(test.name, func(t *testing.T) {
  1486  			var attempts int
  1487  			ticker := waitFunc(func(done <-chan struct{}) <-chan struct{} {
  1488  				return nil
  1489  			})
  1490  			if test.waitFunc != nil {
  1491  				ticker = test.waitFunc()
  1492  			}
  1493  			err := func() error {
  1494  				contextFn := test.context
  1495  				if contextFn == nil {
  1496  					contextFn = defaultContext
  1497  				}
  1498  				ctx, cancel := contextFn()
  1499  				defer cancel()
  1500  
  1501  				conditionWrapper := func(ctx context.Context) (done bool, err error) {
  1502  					attempts++
  1503  
  1504  					defer func() {
  1505  						if test.cancelContextAfter == attempts {
  1506  							cancel()
  1507  						}
  1508  					}()
  1509  
  1510  					return test.condition(ctx)
  1511  				}
  1512  
  1513  				return poll(ctx, test.immediate, ticker.WithContext(), conditionWrapper)
  1514  			}()
  1515  
  1516  			if test.errExpected != err {
  1517  				t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
  1518  			}
  1519  			if test.attemptsExpected != attempts {
  1520  				t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts)
  1521  			}
  1522  		})
  1523  	}
  1524  }
  1525  
  1526  func Benchmark_poll(b *testing.B) {
  1527  	ctx := context.Background()
  1528  	b.ResetTimer()
  1529  	for i := 0; i < b.N; i++ {
  1530  		attempts := 0
  1531  		if err := poll(ctx, true, poller(time.Microsecond, 0), func(_ context.Context) (bool, error) {
  1532  			attempts++
  1533  			return attempts >= 100, nil
  1534  		}); err != nil {
  1535  			b.Fatalf("unexpected err: %v", err)
  1536  		}
  1537  	}
  1538  }
  1539  

View as plain text