...

Source file src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup_test.go

Documentation: k8s.io/apimachinery/pkg/util/waitgroup

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package waitgroup
    18  
    19  import (
    20  	"context"
    21  	"strings"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"golang.org/x/time/rate"
    27  	"k8s.io/apimachinery/pkg/util/wait"
    28  )
    29  
    30  func TestRateLimitedSafeWaitGroup(t *testing.T) {
    31  	// we want to keep track of how many times rate limiter Wait method is
    32  	// being invoked, both before and after the wait group is in waiting mode.
    33  	limiter := &limiterWrapper{}
    34  
    35  	// we expect the context passed by the factory to be used
    36  	var cancelInvoked int
    37  	factory := &factory{
    38  		limiter: limiter,
    39  		grace:   2 * time.Second,
    40  		ctx:     context.Background(),
    41  		cancel: func() {
    42  			cancelInvoked++
    43  		},
    44  	}
    45  	target := &rateLimitedSafeWaitGroupWrapper{
    46  		RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{limiter: limiter},
    47  	}
    48  
    49  	// two set of requests
    50  	//  - n1: this set will finish using this waitgroup before Wait is invoked
    51  	//  - n2: this set will be in flight after Wait is invoked
    52  	n1, n2 := 100, 101
    53  
    54  	// so we know when all requests in n1 are done using the waitgroup
    55  	n1DoneWG := sync.WaitGroup{}
    56  
    57  	// so we know when all requests in n2 have called Add,
    58  	// but not finished with the waitgroup yet.
    59  	// this will allow the test to invoke 'Wait' once all requests
    60  	// in n2 have called `Add`, but none has called `Done` yet.
    61  	n2BeforeWaitWG := sync.WaitGroup{}
    62  	// so we know when all requests in n2 have called Done and
    63  	// are finished using the waitgroup
    64  	n2DoneWG := sync.WaitGroup{}
    65  
    66  	startCh, blockedCh := make(chan struct{}), make(chan struct{})
    67  	n1DoneWG.Add(n1)
    68  	for i := 0; i < n1; i++ {
    69  		go func() {
    70  			defer n1DoneWG.Done()
    71  			<-startCh
    72  
    73  			target.Add(1)
    74  			// let's finish using the waitgroup immediately
    75  			target.Done()
    76  		}()
    77  	}
    78  
    79  	n2BeforeWaitWG.Add(n2)
    80  	n2DoneWG.Add(n2)
    81  	for i := 0; i < n2; i++ {
    82  		go func() {
    83  			func() {
    84  				defer n2BeforeWaitWG.Done()
    85  				<-startCh
    86  
    87  				target.Add(1)
    88  			}()
    89  
    90  			func() {
    91  				defer n2DoneWG.Done()
    92  				// let's wait for the test to instruct the requests in n2
    93  				// that it is time to finish using the waitgroup.
    94  				<-blockedCh
    95  
    96  				target.Done()
    97  			}()
    98  		}()
    99  	}
   100  
   101  	// initially the count should be zero
   102  	if count := target.Count(); count != 0 {
   103  		t.Errorf("expected count to be zero, but got: %d", count)
   104  	}
   105  	// start the test
   106  	close(startCh)
   107  	// wait for the first set of requests (n1) to be done
   108  	n1DoneWG.Wait()
   109  
   110  	// after the first set of requests (n1) are done, the count should be zero
   111  	if invoked := limiter.invoked(); invoked != 0 {
   112  		t.Errorf("expected no call to rate limiter before Wait is called, but got: %d", invoked)
   113  	}
   114  
   115  	// make sure all requetss in the second group (n2) have started using the
   116  	// waitgroup (Add invoked) but no request is done using the waitgroup yet.
   117  	n2BeforeWaitWG.Wait()
   118  
   119  	// count should be n2, since every request in n2 is still using the waitgroup
   120  	if count := target.Count(); count != n2 {
   121  		t.Errorf("expected count to be: %d, but got: %d", n2, count)
   122  	}
   123  
   124  	// time for us to mark the waitgroup as `Waiting`
   125  	waitDoneCh := make(chan waitResult)
   126  	go func() {
   127  		factory.grace = 2 * time.Second
   128  		before, after, err := target.Wait(factory.NewRateLimiter)
   129  		waitDoneCh <- waitResult{before: before, after: after, err: err}
   130  	}()
   131  
   132  	// make sure there is no flake in the test due to this race condition
   133  	var waitingGot bool
   134  	wait.PollImmediate(500*time.Millisecond, wait.ForeverTestTimeout, func() (done bool, err error) {
   135  		if waiting := target.Waiting(); waiting {
   136  			waitingGot = true
   137  			return true, nil
   138  		}
   139  		return false, nil
   140  	})
   141  	// verify that the waitgroup is in 'Waiting' mode
   142  	if !waitingGot {
   143  		t.Errorf("expected to be in waiting")
   144  	}
   145  
   146  	// we should not allow any new request to use this waitgroup any longer
   147  	if err := target.Add(1); err == nil ||
   148  		!strings.Contains(err.Error(), "add with positive delta after Wait is forbidden") {
   149  		t.Errorf("expected Add to return error while in waiting mode: %v", err)
   150  	}
   151  
   152  	// make sure that RateLimitedSafeWaitGroup passes the right
   153  	// request count to the limiter factory.
   154  	if factory.countGot != n2 {
   155  		t.Errorf("expected count passed to factory to be: %d, but got: %d", n2, factory.countGot)
   156  	}
   157  
   158  	// indicate to all requests (each request in n2) that are
   159  	// currently using this waitgroup that they can go ahead
   160  	// and invoke 'Done' to finish using this waitgroup.
   161  	close(blockedCh)
   162  	n2DoneWG.Wait()
   163  
   164  	if invoked := limiter.invoked(); invoked != n2 {
   165  		t.Errorf("expected rate limiter to be called %d times, but got: %d", n2, invoked)
   166  	}
   167  
   168  	waitResult := <-waitDoneCh
   169  	if count := target.Count(); count != 0 {
   170  		t.Errorf("expected count to be zero, but got: %d", count)
   171  	}
   172  	if waitResult.before != n2 {
   173  		t.Errorf("expected count before Wait to be: %d, but got: %d", n2, waitResult.before)
   174  	}
   175  	if waitResult.after != 0 {
   176  		t.Errorf("expected count after Wait to be zero, but got: %d", waitResult.after)
   177  	}
   178  	if cancelInvoked != 1 {
   179  		t.Errorf("expected context cancel to be invoked once, but got: %d", cancelInvoked)
   180  	}
   181  }
   182  
   183  func TestRateLimitedSafeWaitGroupWithHardTimeout(t *testing.T) {
   184  	target := &rateLimitedSafeWaitGroupWrapper{
   185  		RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{},
   186  	}
   187  	n := 10
   188  	wg := sync.WaitGroup{}
   189  	wg.Add(n)
   190  	for i := 0; i < n; i++ {
   191  		go func() {
   192  			defer wg.Done()
   193  			target.Add(1)
   194  		}()
   195  	}
   196  
   197  	wg.Wait()
   198  	if count := target.Count(); count != n {
   199  		t.Errorf("expected count to be: %d, but got: %d", n, count)
   200  	}
   201  
   202  	ctx, cancel := context.WithCancel(context.Background())
   203  	cancel()
   204  	activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) {
   205  		return nil, ctx, cancel
   206  	})
   207  	if activeAt != n {
   208  		t.Errorf("expected active at Wait to be: %d, but got: %d", n, activeAt)
   209  	}
   210  	if activeNow != n {
   211  		t.Errorf("expected active after Wait to be: %d, but got: %d", n, activeNow)
   212  	}
   213  	if err != context.Canceled {
   214  		t.Errorf("expected error: %v, but got: %v", context.Canceled, err)
   215  	}
   216  }
   217  
   218  func TestRateLimitedSafeWaitGroupWithBurstOfOne(t *testing.T) {
   219  	target := &rateLimitedSafeWaitGroupWrapper{
   220  		RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{},
   221  	}
   222  	n := 200
   223  	grace := 5 * time.Second
   224  	wg := sync.WaitGroup{}
   225  	wg.Add(n)
   226  	for i := 0; i < n; i++ {
   227  		go func() {
   228  			defer wg.Done()
   229  			target.Add(1)
   230  		}()
   231  	}
   232  	wg.Wait()
   233  
   234  	waitingCh := make(chan struct{})
   235  	wg.Add(n)
   236  	for i := 0; i < n; i++ {
   237  		go func() {
   238  			defer wg.Done()
   239  
   240  			<-waitingCh
   241  			target.Done()
   242  		}()
   243  	}
   244  	defer wg.Wait()
   245  
   246  	now := time.Now()
   247  	t.Logf("Wait starting, N=%d, grace: %s, at: %s", n, grace, now)
   248  	activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) {
   249  		defer close(waitingCh)
   250  		// no deadline in context, Wait will wait forever, we want to measure
   251  		// how long it takes for the requests to drain.
   252  		return rate.NewLimiter(rate.Limit(n/int(grace.Seconds())), 1), context.Background(), func() {}
   253  	})
   254  	took := time.Since(now)
   255  	t.Logf("Wait finished, count(before): %d, count(after): %d, took: %s, err: %v", activeAt, activeNow, took, err)
   256  
   257  	// in CPU starved environment, the go routines may not finish in time
   258  	if took > 2*grace {
   259  		t.Errorf("expected Wait to take: %s, but it took: %s", grace, took)
   260  	}
   261  }
   262  
   263  type waitResult struct {
   264  	before, after int
   265  	err           error
   266  }
   267  
   268  type rateLimitedSafeWaitGroupWrapper struct {
   269  	*RateLimitedSafeWaitGroup
   270  }
   271  
   272  // used by test only
   273  func (wg *rateLimitedSafeWaitGroupWrapper) Count() int {
   274  	wg.mu.Lock()
   275  	defer wg.mu.Unlock()
   276  
   277  	return wg.count
   278  }
   279  func (wg *rateLimitedSafeWaitGroupWrapper) Waiting() bool {
   280  	wg.mu.Lock()
   281  	defer wg.mu.Unlock()
   282  
   283  	return wg.wait
   284  }
   285  
   286  type limiterWrapper struct {
   287  	delegate RateLimiter
   288  	lock     sync.Mutex
   289  	invokedN int
   290  }
   291  
   292  func (w *limiterWrapper) invoked() int {
   293  	w.lock.Lock()
   294  	defer w.lock.Unlock()
   295  	return w.invokedN
   296  }
   297  func (w *limiterWrapper) Wait(ctx context.Context) error {
   298  	w.lock.Lock()
   299  	w.invokedN++
   300  	w.lock.Unlock()
   301  
   302  	if w.delegate != nil {
   303  		w.delegate.Wait(ctx)
   304  	}
   305  	return nil
   306  }
   307  
   308  type factory struct {
   309  	limiter  *limiterWrapper
   310  	grace    time.Duration
   311  	ctx      context.Context
   312  	cancel   context.CancelFunc
   313  	countGot int
   314  }
   315  
   316  func (f *factory) NewRateLimiter(count int) (RateLimiter, context.Context, context.CancelFunc) {
   317  	f.countGot = count
   318  	f.limiter.delegate = rate.NewLimiter(rate.Limit(count/int(f.grace.Seconds())), 20)
   319  	return f.limiter, f.ctx, f.cancel
   320  }
   321  

View as plain text