1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package managedwriter
16
17 import (
18 "context"
19 "errors"
20 "fmt"
21 "testing"
22 "time"
23
24 "golang.org/x/sync/errgroup"
25 )
26
27 func TestFlowControllerCancel(t *testing.T) {
28
29 t.Parallel()
30 wantInsertBytes := 10
31 fc := newFlowController(3, wantInsertBytes)
32 if fc.maxInsertBytes != 10 {
33 t.Fatalf("maxInsertBytes mismatch, got %d want %d", fc.maxInsertBytes, wantInsertBytes)
34 }
35 if err := fc.acquire(context.Background(), 5); err != nil {
36 t.Fatal(err)
37 }
38
39 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
40 defer cancel()
41 if err := fc.acquire(ctx, 6); err != context.DeadlineExceeded {
42 t.Fatalf("got %v, expected DeadlineExceeded", err)
43 }
44
45 go func() {
46 time.Sleep(5 * time.Millisecond)
47 fc.release(5)
48 }()
49 if err := fc.acquire(context.Background(), 6); err != nil {
50 t.Errorf("got %v, expected nil", err)
51 }
52 }
53
54 func TestFlowControllerLargeRequest(t *testing.T) {
55
56 t.Parallel()
57 fc := newFlowController(3, 10)
58 err := fc.acquire(context.Background(), 11)
59 if err != nil {
60 t.Fatal(err)
61 }
62 }
63
64 func TestFlowControllerNoStarve(t *testing.T) {
65
66
67 t.Parallel()
68 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
69 defer cancel()
70 fc := newFlowController(10, 10)
71 first := make(chan int)
72 for i := 0; i < 20; i++ {
73 go func() {
74 for {
75 if err := fc.acquire(ctx, 1); err != nil {
76 if err != context.Canceled {
77 t.Error(err)
78 }
79 return
80 }
81 select {
82 case first <- 1:
83 default:
84 }
85 fc.release(1)
86 }
87 }()
88 }
89 <-first
90 if err := fc.acquire(ctx, 11); err != nil {
91 t.Errorf("got %v, want nil", err)
92 }
93 }
94
95 func TestFlowControllerSaturation(t *testing.T) {
96 t.Parallel()
97 const (
98 maxCount = 6
99 maxSize = 10
100 )
101 for _, test := range []struct {
102 acquireSize int
103 wantCount, wantSize int64
104 }{
105 {
106
107 acquireSize: 1,
108 wantCount: 6,
109 wantSize: 6,
110 },
111 {
112
113
114 acquireSize: 2,
115 wantCount: 5,
116 wantSize: 10,
117 },
118 {
119
120
121 acquireSize: 3,
122 wantCount: 3,
123 wantSize: 9,
124 },
125 } {
126 fc := newFlowController(maxCount, maxSize)
127 success := errors.New("")
128
129 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
130 defer cancel()
131 g, ctx := errgroup.WithContext(ctx)
132 for i := 0; i < 10; i++ {
133 g.Go(func() error {
134 var hitCount, hitSize bool
135
136
137
138 for i := 0; i < 100 || !hitCount || !hitSize; i++ {
139 select {
140 case <-ctx.Done():
141 return ctx.Err()
142 default:
143 }
144 if err := fc.acquire(ctx, test.acquireSize); err != nil {
145 return err
146 }
147 c := int64(fc.count())
148 if c > test.wantCount {
149 return fmt.Errorf("count %d exceeds want %d", c, test.wantCount)
150 }
151 if c == test.wantCount {
152 hitCount = true
153 }
154 s := int64(fc.bytes())
155 if s > test.wantSize {
156 return fmt.Errorf("size %d exceeds want %d", s, test.wantSize)
157 }
158 if s == test.wantSize {
159 hitSize = true
160 }
161 time.Sleep(5 * time.Millisecond)
162 if fc.bytes() < 0 {
163 return errors.New("negative size")
164 }
165 fc.release(test.acquireSize)
166 }
167 return success
168 })
169 }
170 if err := g.Wait(); err != success {
171 t.Errorf("%+v: %v", test, err)
172 continue
173 }
174 }
175 }
176
177 func TestFlowControllerTryAcquire(t *testing.T) {
178 t.Parallel()
179 fc := newFlowController(3, 10)
180
181
182 if !fc.tryAcquire(4) {
183 t.Error("got false, wanted true")
184 }
185
186
187 if fc.tryAcquire(7) {
188 t.Error("got true, wanted false")
189 }
190
191
192 if !fc.tryAcquire(6) {
193 t.Error("got false, wanted true")
194 }
195 }
196
197 func TestFlowControllerUnboundedCount(t *testing.T) {
198 t.Parallel()
199 ctx := context.Background()
200 fc := newFlowController(0, 10)
201
202
203 if err := fc.acquire(ctx, 4); err != nil {
204 t.Errorf("got %v, wanted no error", err)
205 }
206
207
208 if !fc.tryAcquire(4) {
209 t.Error("got false, wanted true")
210 }
211 wantBytes := int64(8)
212 if gotB := int64(fc.bytes()); gotB != wantBytes {
213 t.Fatalf("got bytes %d, want %d", gotB, wantBytes)
214 }
215
216
217 if fc.tryAcquire(3) {
218 t.Error("got true, wanted false")
219 }
220
221 if gotB := int64(fc.bytes()); gotB != wantBytes {
222 t.Fatalf("got bytes %d, want %d", gotB, wantBytes)
223 }
224
225 }
226
227 func TestFlowControllerUnboundedCount2(t *testing.T) {
228 t.Parallel()
229 ctx := context.Background()
230 fc := newFlowController(0, 0)
231
232 if err := fc.acquire(ctx, 4); err != nil {
233 t.Errorf("got %v, wanted no error", err)
234 }
235 wantBytes := int64(0)
236 if gotB := int64(fc.bytes()); gotB != wantBytes {
237 t.Fatalf("got bytes %d, want %d", gotB, wantBytes)
238 }
239 fc.release(1)
240 fc.release(1)
241 fc.release(1)
242 wantCount := int64(-2)
243 if c := int64(fc.count()); c != wantCount {
244 t.Fatalf("got count %d, want %d", c, wantCount)
245 }
246 if gotB := int64(fc.bytes()); gotB != wantBytes {
247 t.Fatalf("got bytes %d, want %d", gotB, wantBytes)
248 }
249 }
250
251 func TestFlowControllerUnboundedBytes(t *testing.T) {
252 t.Parallel()
253 ctx := context.Background()
254 fc := newFlowController(2, 0)
255
256
257 if err := fc.acquire(ctx, 4e9); err != nil {
258 t.Errorf("got %v, wanted no error", err)
259 }
260
261
262 if !fc.tryAcquire(4e9) {
263 t.Error("got false, wanted true")
264 }
265
266
267 if fc.tryAcquire(3) {
268 t.Error("got true, wanted false")
269 }
270 }
271
272 func TestCopyFlowController(t *testing.T) {
273 testcases := []struct {
274 description string
275 in *flowController
276 wantMaxRequests int
277 wantMaxBytes int
278 }{
279 {
280 description: "nil source",
281 wantMaxRequests: 0,
282 wantMaxBytes: 0,
283 },
284 {
285 description: "no limit",
286 in: newFlowController(0, 0),
287 wantMaxRequests: 0,
288 wantMaxBytes: 0,
289 },
290 {
291 description: "bounded",
292 in: newFlowController(10, 1024),
293 wantMaxRequests: 10,
294 wantMaxBytes: 1024,
295 },
296 }
297
298 for _, tc := range testcases {
299 fc := copyFlowController(tc.in)
300 if fc.maxInsertBytes != tc.wantMaxBytes {
301 t.Errorf("%s: max bytes mismatch, got %d want %d ", tc.description, fc.maxInsertBytes, tc.wantMaxBytes)
302 }
303 if fc.maxInsertCount != tc.wantMaxRequests {
304 t.Errorf("%s: max requests mismatch, got %d want %d ", tc.description, fc.maxInsertBytes, tc.wantMaxBytes)
305 }
306 }
307 }
308
View as plain text