/*
Copyright 2023 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package waitgroup

import (
	"context"
	"fmt"
	"sync"
)

// RateLimiter abstracts the rate limiter used by RateLimitedSafeWaitGroup.
// Once Wait has been called, every Done invokes the limiter from its own
// goroutine without holding the wait group's lock, so the implementation
// must be thread-safe.
type RateLimiter interface {
	// Wait is called by Done with the stop context returned by the
	// RateLimiterFactoryFunc. Its returned error is not inspected.
	Wait(ctx context.Context) error
}

// RateLimiterFactoryFunc is used by the RateLimitedSafeWaitGroup to create a new
// instance of a RateLimiter that will be used to rate limit the return rate
// of the active number of request(s). 'count' is the number of requests in
// flight that are expected to invoke 'Done' on this wait group.
//
// The factory is invoked by Wait while the wait group's lock is held, so it
// must not call back into the wait group. It returns:
//   - the RateLimiter used by subsequent Done calls (nil disables pacing);
//   - the stop context: the hard deadline for Wait, also handed to the
//     limiter by Done. It must not be nil;
//   - a CancelFunc that Wait calls when it returns. It must not be nil.
type RateLimiterFactoryFunc func(count int) (RateLimiter, context.Context, context.CancelFunc)

// RateLimitedSafeWaitGroup is a wait group with two differences from
// sync.WaitGroup. First, once Wait has been called, it refuses further
// positive Add calls. Second, it paces the Done calls that follow through
// a RateLimiter, bounded by a hard stop context.
//
// The zero value is ready to use. RateLimitedSafeWaitGroup must not be
// copied after first use.
type RateLimitedSafeWaitGroup struct {
	// wg tracks the request(s) in flight; Wait blocks on it.
	wg sync.WaitGroup
	// Once Wait is initiated, all subsequent Done invocations are
	// rate limited using this rate limiter. Assigned by Wait while mu is held.
	limiter RateLimiter
	// stopCtx is the hard-stop context returned by the factory, assigned by
	// Wait while mu is held.
	stopCtx context.Context

	// mu guards wait and count. Wait also holds it while assigning limiter
	// and stopCtx.
	mu sync.Mutex
	// wait indicates whether Wait has been called. When it is true, any Add
	// with a positive delta returns an error.
	wait bool
	// count is the number of request(s) currently using the wait group
	// (added and not yet Done).
	count int
}

// Add adds delta, which may be negative, similar to sync.WaitGroup.
// If Add is called with a positive delta after Wait, it returns an error
// instead, which prevents an unsafe Add. Negative deltas are still accepted.
56 func (wg *RateLimitedSafeWaitGroup) Add(delta int) error { 57 wg.mu.Lock() 58 defer wg.mu.Unlock() 59 60 if wg.wait && delta > 0 { 61 return fmt.Errorf("add with positive delta after Wait is forbidden") 62 } 63 wg.wg.Add(delta) 64 wg.count += delta 65 return nil 66 } 67 68 // Done decrements the WaitGroup counter, rate limiting is applied only 69 // when the wait group is in waiting mode. 70 func (wg *RateLimitedSafeWaitGroup) Done() { 71 var limiter RateLimiter 72 func() { 73 wg.mu.Lock() 74 defer wg.mu.Unlock() 75 76 wg.count -= 1 77 if wg.wait { 78 // we are using the limiter outside the scope of the lock 79 limiter = wg.limiter 80 } 81 }() 82 83 defer wg.wg.Done() 84 if limiter != nil { 85 limiter.Wait(wg.stopCtx) 86 } 87 } 88 89 // Wait blocks until the WaitGroup counter is zero or a hard limit has elapsed. 90 // It returns the number of active request(s) accounted for at the time Wait 91 // has been invoked, number of request(s) that have drianed (done using the 92 // wait group immediately before Wait returns). 93 // Ideally, the both numbers returned should be equal, to indicate that all 94 // request(s) using the wait group have released their lock. 95 func (wg *RateLimitedSafeWaitGroup) Wait(limiterFactory RateLimiterFactoryFunc) (int, int, error) { 96 if limiterFactory == nil { 97 return 0, 0, fmt.Errorf("rate limiter factory must be specified") 98 } 99 100 var cancel context.CancelFunc 101 var countNow, countAfter int 102 func() { 103 wg.mu.Lock() 104 defer wg.mu.Unlock() 105 106 wg.limiter, wg.stopCtx, cancel = limiterFactory(wg.count) 107 countNow = wg.count 108 wg.wait = true 109 }() 110 111 defer cancel() 112 // there should be a hard stop, in case request(s) are not responsive 113 // enough to invoke Done before the grace period is over. 
114 waitDoneCh := make(chan struct{}) 115 go func() { 116 defer close(waitDoneCh) 117 wg.wg.Wait() 118 }() 119 120 var err error 121 select { 122 case <-wg.stopCtx.Done(): 123 err = wg.stopCtx.Err() 124 case <-waitDoneCh: 125 } 126 127 func() { 128 wg.mu.Lock() 129 defer wg.mu.Unlock() 130 131 countAfter = wg.count 132 }() 133 return countNow, countAfter, err 134 } 135