1
16
17 package waitgroup
18
import (
	"context"
	"errors"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/time/rate"
	"k8s.io/apimachinery/pkg/util/wait"
)
29
30 func TestRateLimitedSafeWaitGroup(t *testing.T) {
31
32
33 limiter := &limiterWrapper{}
34
35
36 var cancelInvoked int
37 factory := &factory{
38 limiter: limiter,
39 grace: 2 * time.Second,
40 ctx: context.Background(),
41 cancel: func() {
42 cancelInvoked++
43 },
44 }
45 target := &rateLimitedSafeWaitGroupWrapper{
46 RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{limiter: limiter},
47 }
48
49
50
51
52 n1, n2 := 100, 101
53
54
55 n1DoneWG := sync.WaitGroup{}
56
57
58
59
60
61 n2BeforeWaitWG := sync.WaitGroup{}
62
63
64 n2DoneWG := sync.WaitGroup{}
65
66 startCh, blockedCh := make(chan struct{}), make(chan struct{})
67 n1DoneWG.Add(n1)
68 for i := 0; i < n1; i++ {
69 go func() {
70 defer n1DoneWG.Done()
71 <-startCh
72
73 target.Add(1)
74
75 target.Done()
76 }()
77 }
78
79 n2BeforeWaitWG.Add(n2)
80 n2DoneWG.Add(n2)
81 for i := 0; i < n2; i++ {
82 go func() {
83 func() {
84 defer n2BeforeWaitWG.Done()
85 <-startCh
86
87 target.Add(1)
88 }()
89
90 func() {
91 defer n2DoneWG.Done()
92
93
94 <-blockedCh
95
96 target.Done()
97 }()
98 }()
99 }
100
101
102 if count := target.Count(); count != 0 {
103 t.Errorf("expected count to be zero, but got: %d", count)
104 }
105
106 close(startCh)
107
108 n1DoneWG.Wait()
109
110
111 if invoked := limiter.invoked(); invoked != 0 {
112 t.Errorf("expected no call to rate limiter before Wait is called, but got: %d", invoked)
113 }
114
115
116
117 n2BeforeWaitWG.Wait()
118
119
120 if count := target.Count(); count != n2 {
121 t.Errorf("expected count to be: %d, but got: %d", n2, count)
122 }
123
124
125 waitDoneCh := make(chan waitResult)
126 go func() {
127 factory.grace = 2 * time.Second
128 before, after, err := target.Wait(factory.NewRateLimiter)
129 waitDoneCh <- waitResult{before: before, after: after, err: err}
130 }()
131
132
133 var waitingGot bool
134 wait.PollImmediate(500*time.Millisecond, wait.ForeverTestTimeout, func() (done bool, err error) {
135 if waiting := target.Waiting(); waiting {
136 waitingGot = true
137 return true, nil
138 }
139 return false, nil
140 })
141
142 if !waitingGot {
143 t.Errorf("expected to be in waiting")
144 }
145
146
147 if err := target.Add(1); err == nil ||
148 !strings.Contains(err.Error(), "add with positive delta after Wait is forbidden") {
149 t.Errorf("expected Add to return error while in waiting mode: %v", err)
150 }
151
152
153
154 if factory.countGot != n2 {
155 t.Errorf("expected count passed to factory to be: %d, but got: %d", n2, factory.countGot)
156 }
157
158
159
160
161 close(blockedCh)
162 n2DoneWG.Wait()
163
164 if invoked := limiter.invoked(); invoked != n2 {
165 t.Errorf("expected rate limiter to be called %d times, but got: %d", n2, invoked)
166 }
167
168 waitResult := <-waitDoneCh
169 if count := target.Count(); count != 0 {
170 t.Errorf("expected count to be zero, but got: %d", count)
171 }
172 if waitResult.before != n2 {
173 t.Errorf("expected count before Wait to be: %d, but got: %d", n2, waitResult.before)
174 }
175 if waitResult.after != 0 {
176 t.Errorf("expected count after Wait to be zero, but got: %d", waitResult.after)
177 }
178 if cancelInvoked != 1 {
179 t.Errorf("expected context cancel to be invoked once, but got: %d", cancelInvoked)
180 }
181 }
182
183 func TestRateLimitedSafeWaitGroupWithHardTimeout(t *testing.T) {
184 target := &rateLimitedSafeWaitGroupWrapper{
185 RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{},
186 }
187 n := 10
188 wg := sync.WaitGroup{}
189 wg.Add(n)
190 for i := 0; i < n; i++ {
191 go func() {
192 defer wg.Done()
193 target.Add(1)
194 }()
195 }
196
197 wg.Wait()
198 if count := target.Count(); count != n {
199 t.Errorf("expected count to be: %d, but got: %d", n, count)
200 }
201
202 ctx, cancel := context.WithCancel(context.Background())
203 cancel()
204 activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) {
205 return nil, ctx, cancel
206 })
207 if activeAt != n {
208 t.Errorf("expected active at Wait to be: %d, but got: %d", n, activeAt)
209 }
210 if activeNow != n {
211 t.Errorf("expected active after Wait to be: %d, but got: %d", n, activeNow)
212 }
213 if err != context.Canceled {
214 t.Errorf("expected error: %v, but got: %v", context.Canceled, err)
215 }
216 }
217
218 func TestRateLimitedSafeWaitGroupWithBurstOfOne(t *testing.T) {
219 target := &rateLimitedSafeWaitGroupWrapper{
220 RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{},
221 }
222 n := 200
223 grace := 5 * time.Second
224 wg := sync.WaitGroup{}
225 wg.Add(n)
226 for i := 0; i < n; i++ {
227 go func() {
228 defer wg.Done()
229 target.Add(1)
230 }()
231 }
232 wg.Wait()
233
234 waitingCh := make(chan struct{})
235 wg.Add(n)
236 for i := 0; i < n; i++ {
237 go func() {
238 defer wg.Done()
239
240 <-waitingCh
241 target.Done()
242 }()
243 }
244 defer wg.Wait()
245
246 now := time.Now()
247 t.Logf("Wait starting, N=%d, grace: %s, at: %s", n, grace, now)
248 activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) {
249 defer close(waitingCh)
250
251
252 return rate.NewLimiter(rate.Limit(n/int(grace.Seconds())), 1), context.Background(), func() {}
253 })
254 took := time.Since(now)
255 t.Logf("Wait finished, count(before): %d, count(after): %d, took: %s, err: %v", activeAt, activeNow, took, err)
256
257
258 if took > 2*grace {
259 t.Errorf("expected Wait to take: %s, but it took: %s", grace, took)
260 }
261 }
262
// waitResult carries the three values returned by
// RateLimitedSafeWaitGroup.Wait from the goroutine that called it back to
// the test.
type waitResult struct {
	// before is the first value returned by Wait.
	before int
	// after is the second value returned by Wait.
	after int
	// err is the error returned by Wait.
	err error
}
267
268 type rateLimitedSafeWaitGroupWrapper struct {
269 *RateLimitedSafeWaitGroup
270 }
271
272
273 func (wg *rateLimitedSafeWaitGroupWrapper) Count() int {
274 wg.mu.Lock()
275 defer wg.mu.Unlock()
276
277 return wg.count
278 }
279 func (wg *rateLimitedSafeWaitGroupWrapper) Waiting() bool {
280 wg.mu.Lock()
281 defer wg.mu.Unlock()
282
283 return wg.wait
284 }
285
286 type limiterWrapper struct {
287 delegate RateLimiter
288 lock sync.Mutex
289 invokedN int
290 }
291
292 func (w *limiterWrapper) invoked() int {
293 w.lock.Lock()
294 defer w.lock.Unlock()
295 return w.invokedN
296 }
297 func (w *limiterWrapper) Wait(ctx context.Context) error {
298 w.lock.Lock()
299 w.invokedN++
300 w.lock.Unlock()
301
302 if w.delegate != nil {
303 w.delegate.Wait(ctx)
304 }
305 return nil
306 }
307
308 type factory struct {
309 limiter *limiterWrapper
310 grace time.Duration
311 ctx context.Context
312 cancel context.CancelFunc
313 countGot int
314 }
315
316 func (f *factory) NewRateLimiter(count int) (RateLimiter, context.Context, context.CancelFunc) {
317 f.countGot = count
318 f.limiter.delegate = rate.NewLimiter(rate.Limit(count/int(f.grace.Seconds())), 20)
319 return f.limiter, f.ctx, f.cancel
320 }
321
View as plain text