1
18
19 package testutils
20
21 import (
22 "context"
23 "errors"
24 "fmt"
25 "testing"
26
27 "google.golang.org/grpc/balancer"
28 "google.golang.org/grpc/connectivity"
29 "google.golang.org/grpc/internal/grpcsync"
30 "google.golang.org/grpc/resolver"
31 )
32
33
// TestSubConn is a fake balancer.SubConn used by balancer tests. It lets a
// test observe calls to Connect (via ConnectCh) and Shutdown (via the owning
// BalancerClientConn's ShutdownSubConnCh), and push states to the balancer
// with UpdateState.
type TestSubConn struct {
	// tcc is the BalancerClientConn that created this SubConn. It is nil for
	// SubConns built with NewTestSubConn.
	tcc *BalancerClientConn
	// id is the identifier returned by String.
	id string
	// ConnectCh receives a value whenever Connect manages a non-blocking send.
	// The constructors in this file give it a buffer of one.
	ConnectCh chan struct{}
	// stateListener, when non-nil, is invoked with every state passed to
	// UpdateState.
	stateListener func(balancer.SubConnState)
	// connectCalled is fired by Connect; UpdateState waits on it.
	connectCalled *grpcsync.Event
}
41
42
43
44
45 func NewTestSubConn(id string) *TestSubConn {
46 return &TestSubConn{
47 ConnectCh: make(chan struct{}, 1),
48 connectCalled: grpcsync.NewEvent(),
49 id: id,
50 }
51 }
52
53
54 func (tsc *TestSubConn) UpdateAddresses([]resolver.Address) {}
55
56
57 func (tsc *TestSubConn) Connect() {
58 tsc.connectCalled.Fire()
59 select {
60 case tsc.ConnectCh <- struct{}{}:
61 default:
62 }
63 }
64
65
66 func (tsc *TestSubConn) GetOrBuildProducer(balancer.ProducerBuilder) (balancer.Producer, func()) {
67 return nil, nil
68 }
69
70
71 func (tsc *TestSubConn) UpdateState(state balancer.SubConnState) {
72 <-tsc.connectCalled.Done()
73 if tsc.stateListener != nil {
74 tsc.stateListener(state)
75 return
76 }
77 }
78
79
80
81 func (tsc *TestSubConn) Shutdown() {
82 tsc.tcc.logger.Logf("SubConn %s: Shutdown", tsc)
83 select {
84 case tsc.tcc.ShutdownSubConnCh <- tsc:
85 default:
86 }
87 }
88
89
90 func (tsc *TestSubConn) String() string {
91 return tsc.id
92 }
93
94
// BalancerClientConn is a fake ClientConn handed to a balancer under test.
// It records the balancer's calls on its exported channels so the test can
// inspect them.
type BalancerClientConn struct {
	// logger receives log output; NewBalancerClientConn sets it to the
	// test's *testing.T.
	logger Logger

	// NewSubConnAddrsCh receives the addresses passed to NewSubConn.
	NewSubConnAddrsCh chan []resolver.Address
	// NewSubConnCh receives each SubConn created by NewSubConn.
	NewSubConnCh chan *TestSubConn
	// ShutdownSubConnCh receives SubConns on which Shutdown was called.
	ShutdownSubConnCh chan *TestSubConn
	// UpdateAddressesAddrsCh receives the addresses passed to UpdateAddresses.
	UpdateAddressesAddrsCh chan []resolver.Address

	// NewPickerCh receives pickers from UpdateState. UpdateState discards any
	// unread picker before sending the new one.
	NewPickerCh chan balancer.Picker
	// NewStateCh receives connectivity states from UpdateState. UpdateState
	// discards any unread state before sending the new one.
	NewStateCh chan connectivity.State
	// ResolveNowCh receives the options passed to ResolveNow, replacing any
	// unread value.
	ResolveNowCh chan resolver.ResolveNowOptions

	// subConnIdx numbers the SubConns created by NewSubConn ("sc0", "sc1", ...).
	subConnIdx int
}
109
110
111 func NewBalancerClientConn(t *testing.T) *BalancerClientConn {
112 return &BalancerClientConn{
113 logger: t,
114
115 NewSubConnAddrsCh: make(chan []resolver.Address, 10),
116 NewSubConnCh: make(chan *TestSubConn, 10),
117 ShutdownSubConnCh: make(chan *TestSubConn, 10),
118 UpdateAddressesAddrsCh: make(chan []resolver.Address, 1),
119
120 NewPickerCh: make(chan balancer.Picker, 1),
121 NewStateCh: make(chan connectivity.State, 1),
122 ResolveNowCh: make(chan resolver.ResolveNowOptions, 1),
123 }
124 }
125
126
127 func (tcc *BalancerClientConn) NewSubConn(a []resolver.Address, o balancer.NewSubConnOptions) (balancer.SubConn, error) {
128 sc := &TestSubConn{
129 tcc: tcc,
130 id: fmt.Sprintf("sc%d", tcc.subConnIdx),
131 ConnectCh: make(chan struct{}, 1),
132 stateListener: o.StateListener,
133 connectCalled: grpcsync.NewEvent(),
134 }
135 tcc.subConnIdx++
136 tcc.logger.Logf("testClientConn: NewSubConn(%v, %+v) => %s", a, o, sc)
137 select {
138 case tcc.NewSubConnAddrsCh <- a:
139 default:
140 }
141
142 select {
143 case tcc.NewSubConnCh <- sc:
144 default:
145 }
146
147 return sc, nil
148 }
149
150
151
152 func (tcc *BalancerClientConn) RemoveSubConn(sc balancer.SubConn) {
153 tcc.logger.Errorf("RemoveSubConn(%v) called unexpectedly", sc)
154 }
155
156
157 func (tcc *BalancerClientConn) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
158 tcc.logger.Logf("testutils.BalancerClientConn: UpdateAddresses(%v, %+v)", sc, addrs)
159 select {
160 case tcc.UpdateAddressesAddrsCh <- addrs:
161 default:
162 }
163 }
164
165
166 func (tcc *BalancerClientConn) UpdateState(bs balancer.State) {
167 tcc.logger.Logf("testutils.BalancerClientConn: UpdateState(%v)", bs)
168 select {
169 case <-tcc.NewStateCh:
170 default:
171 }
172 tcc.NewStateCh <- bs.ConnectivityState
173
174 select {
175 case <-tcc.NewPickerCh:
176 default:
177 }
178 tcc.NewPickerCh <- bs.Picker
179 }
180
181
182 func (tcc *BalancerClientConn) ResolveNow(o resolver.ResolveNowOptions) {
183 select {
184 case <-tcc.ResolveNowCh:
185 default:
186 }
187 tcc.ResolveNowCh <- o
188 }
189
190
191 func (tcc *BalancerClientConn) Target() string {
192 panic("not implemented")
193 }
194
195
196
197
198 func (tcc *BalancerClientConn) WaitForErrPicker(ctx context.Context) error {
199 select {
200 case <-ctx.Done():
201 return errors.New("timeout when waiting for an error picker")
202 case picker := <-tcc.NewPickerCh:
203 if _, perr := picker.Pick(balancer.PickInfo{}); perr == nil {
204 return fmt.Errorf("balancer returned a picker which is not an error picker")
205 }
206 }
207 return nil
208 }
209
210
211
212
213
214 func (tcc *BalancerClientConn) WaitForPickerWithErr(ctx context.Context, want error) error {
215 lastErr := errors.New("received no picker")
216 for {
217 select {
218 case <-ctx.Done():
219 return fmt.Errorf("timeout when waiting for an error picker with %v; last picker error: %v", want, lastErr)
220 case picker := <-tcc.NewPickerCh:
221 if _, lastErr = picker.Pick(balancer.PickInfo{}); lastErr != nil && lastErr.Error() == want.Error() {
222 return nil
223 }
224 }
225 }
226 }
227
228
229
230
231 func (tcc *BalancerClientConn) WaitForConnectivityState(ctx context.Context, want connectivity.State) error {
232 var lastState connectivity.State = -1
233 for {
234 select {
235 case <-ctx.Done():
236 return fmt.Errorf("timeout when waiting for state to be %s; last state: %s", want, lastState)
237 case s := <-tcc.NewStateCh:
238 if s == want {
239 return nil
240 }
241 lastState = s
242 }
243 }
244 }
245
246
247
248
249
250
251 func (tcc *BalancerClientConn) WaitForRoundRobinPicker(ctx context.Context, want ...balancer.SubConn) error {
252 lastErr := errors.New("received no picker")
253 for {
254 select {
255 case <-ctx.Done():
256 return fmt.Errorf("timeout when waiting for round robin picker with %v; last error: %v", want, lastErr)
257 case p := <-tcc.NewPickerCh:
258 s := connectivity.Ready
259 select {
260 case s = <-tcc.NewStateCh:
261 default:
262 }
263 if s != connectivity.Ready {
264 lastErr = fmt.Errorf("received state %v instead of ready", s)
265 break
266 }
267 var pickerErr error
268 if err := IsRoundRobin(want, func() balancer.SubConn {
269 sc, err := p.Pick(balancer.PickInfo{})
270 if err != nil {
271 pickerErr = err
272 } else if sc.Done != nil {
273 sc.Done(balancer.DoneInfo{})
274 }
275 return sc.SubConn
276 }); pickerErr != nil {
277 lastErr = pickerErr
278 continue
279 } else if err != nil {
280 lastErr = err
281 continue
282 }
283 return nil
284 }
285 }
286 }
287
288
289
290 func (tcc *BalancerClientConn) WaitForPicker(ctx context.Context, f func(balancer.Picker) error) error {
291 lastErr := errors.New("received no picker")
292 for {
293 select {
294 case <-ctx.Done():
295 return fmt.Errorf("timeout when waiting for picker; last error: %v", lastErr)
296 case p := <-tcc.NewPickerCh:
297 if err := f(p); err != nil {
298 lastErr = err
299 continue
300 }
301 return nil
302 }
303 }
304 }
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326 func IsRoundRobin(want []balancer.SubConn, f func() balancer.SubConn) error {
327 wantSet := make(map[balancer.SubConn]int)
328 for _, sc := range want {
329 wantSet[sc]++
330 }
331
332
333
334
335
336
337 gotSliceFirstIteration := make([]balancer.SubConn, 0, len(want))
338 for range want {
339 got := f()
340 gotSliceFirstIteration = append(gotSliceFirstIteration, got)
341 wantSet[got]--
342 if wantSet[got] < 0 {
343 return fmt.Errorf("non-roundrobin want: %v, result: %v", want, gotSliceFirstIteration)
344 }
345 }
346
347
348 var gotSliceSecondIteration []balancer.SubConn
349 for i := 0; i < 2; i++ {
350 for _, w := range gotSliceFirstIteration {
351 g := f()
352 gotSliceSecondIteration = append(gotSliceSecondIteration, g)
353 if w != g {
354 return fmt.Errorf("non-roundrobin, first iter: %v, second iter: %v", gotSliceFirstIteration, gotSliceSecondIteration)
355 }
356 }
357 }
358
359 return nil
360 }
361
362
363
364
365 func SubConnFromPicker(p balancer.Picker) func() balancer.SubConn {
366 return func() balancer.SubConn {
367 scst, _ := p.Pick(balancer.PickInfo{})
368 return scst.SubConn
369 }
370 }
371
372
373 var ErrTestConstPicker = fmt.Errorf("const picker error")
374
375
// TestConstPicker is a picker whose Pick always gives the same answer: Err
// when it is non-nil, otherwise a PickResult carrying SC.
type TestConstPicker struct {
	// Err, when non-nil, is returned by every call to Pick.
	Err error
	// SC is the SubConn returned by Pick when Err is nil.
	SC balancer.SubConn
}
380
381
382 func (tcp *TestConstPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
383 if tcp.Err != nil {
384 return balancer.PickResult{}, tcp.Err
385 }
386 return balancer.PickResult{SubConn: tcp.SC}, nil
387 }
388
View as plain text