...

Source file src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup.go

Documentation: k8s.io/apimachinery/pkg/util/waitgroup

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package waitgroup
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"sync"
    23  )
    24  
    25  // RateLimiter abstracts the rate limiter used by RateLimitedSafeWaitGroup.
    26  // The implementation must be thread-safe.
    27  type RateLimiter interface {
    28  	Wait(ctx context.Context) error
    29  }
    30  
    31  // RateLimiterFactoryFunc is used by the RateLimitedSafeWaitGroup to create a new
    32  // instance of a RateLimiter that will be used to rate limit the return rate
    33  // of the active number of request(s). 'count' is the number of requests in
    34  // flight that are expected to invoke 'Done' on this wait group.
    35  type RateLimiterFactoryFunc func(count int) (RateLimiter, context.Context, context.CancelFunc)
    36  
    37  // RateLimitedSafeWaitGroup must not be copied after first use.
    38  type RateLimitedSafeWaitGroup struct {
    39  	wg sync.WaitGroup
    40  	// Once Wait is initiated, all consecutive Done invocation will be
    41  	// rate limited using this rate limiter.
    42  	limiter RateLimiter
    43  	stopCtx context.Context
    44  
    45  	mu sync.Mutex
    46  	// wait indicate whether Wait is called, if true,
    47  	// then any Add with positive delta will return error.
    48  	wait bool
    49  	// number of request(s) currently using the wait group
    50  	count int
    51  }
    52  
    53  // Add adds delta, which may be negative, similar to sync.WaitGroup.
    54  // If Add with a positive delta happens after Wait, it will return error,
    55  // which prevent unsafe Add.
    56  func (wg *RateLimitedSafeWaitGroup) Add(delta int) error {
    57  	wg.mu.Lock()
    58  	defer wg.mu.Unlock()
    59  
    60  	if wg.wait && delta > 0 {
    61  		return fmt.Errorf("add with positive delta after Wait is forbidden")
    62  	}
    63  	wg.wg.Add(delta)
    64  	wg.count += delta
    65  	return nil
    66  }
    67  
    68  // Done decrements the WaitGroup counter, rate limiting is applied only
    69  // when the wait group is in waiting mode.
    70  func (wg *RateLimitedSafeWaitGroup) Done() {
    71  	var limiter RateLimiter
    72  	func() {
    73  		wg.mu.Lock()
    74  		defer wg.mu.Unlock()
    75  
    76  		wg.count -= 1
    77  		if wg.wait {
    78  			// we are using the limiter outside the scope of the lock
    79  			limiter = wg.limiter
    80  		}
    81  	}()
    82  
    83  	defer wg.wg.Done()
    84  	if limiter != nil {
    85  		limiter.Wait(wg.stopCtx)
    86  	}
    87  }
    88  
    89  // Wait blocks until the WaitGroup counter is zero or a hard limit has elapsed.
    90  // It returns the number of active request(s) accounted for at the time Wait
    91  // has been invoked, number of request(s) that have drianed (done using the
    92  // wait group immediately before Wait returns).
    93  // Ideally, the both numbers returned should be equal, to indicate that all
    94  // request(s) using the wait group have released their lock.
    95  func (wg *RateLimitedSafeWaitGroup) Wait(limiterFactory RateLimiterFactoryFunc) (int, int, error) {
    96  	if limiterFactory == nil {
    97  		return 0, 0, fmt.Errorf("rate limiter factory must be specified")
    98  	}
    99  
   100  	var cancel context.CancelFunc
   101  	var countNow, countAfter int
   102  	func() {
   103  		wg.mu.Lock()
   104  		defer wg.mu.Unlock()
   105  
   106  		wg.limiter, wg.stopCtx, cancel = limiterFactory(wg.count)
   107  		countNow = wg.count
   108  		wg.wait = true
   109  	}()
   110  
   111  	defer cancel()
   112  	// there should be a hard stop, in case request(s) are not responsive
   113  	// enough to invoke Done before the grace period is over.
   114  	waitDoneCh := make(chan struct{})
   115  	go func() {
   116  		defer close(waitDoneCh)
   117  		wg.wg.Wait()
   118  	}()
   119  
   120  	var err error
   121  	select {
   122  	case <-wg.stopCtx.Done():
   123  		err = wg.stopCtx.Err()
   124  	case <-waitDoneCh:
   125  	}
   126  
   127  	func() {
   128  		wg.mu.Lock()
   129  		defer wg.mu.Unlock()
   130  
   131  		countAfter = wg.count
   132  	}()
   133  	return countNow, countAfter, err
   134  }
   135  

View as plain text