1 package lb
2
3 import (
4 "context"
5 "reflect"
6 "sync"
7 "sync/atomic"
8 "testing"
9 "time"
10
11 "github.com/go-kit/kit/endpoint"
12 "github.com/go-kit/kit/sd"
13 )
14
15 func TestRoundRobin(t *testing.T) {
16 var (
17 counts = []int{0, 0, 0}
18 endpoints = []endpoint.Endpoint{
19 func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
20 func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
21 func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
22 }
23 )
24
25 endpointer := sd.FixedEndpointer(endpoints)
26 balancer := NewRoundRobin(endpointer)
27
28 for i, want := range [][]int{
29 {1, 0, 0},
30 {1, 1, 0},
31 {1, 1, 1},
32 {2, 1, 1},
33 {2, 2, 1},
34 {2, 2, 2},
35 {3, 2, 2},
36 } {
37 endpoint, err := balancer.Endpoint()
38 if err != nil {
39 t.Fatal(err)
40 }
41 endpoint(context.Background(), struct{}{})
42 if have := counts; !reflect.DeepEqual(want, have) {
43 t.Fatalf("%d: want %v, have %v", i, want, have)
44 }
45 }
46 }
47
48 func TestRoundRobinNoEndpoints(t *testing.T) {
49 endpointer := sd.FixedEndpointer{}
50 balancer := NewRoundRobin(endpointer)
51 _, err := balancer.Endpoint()
52 if want, have := ErrNoEndpoints, err; want != have {
53 t.Errorf("want %v, have %v", want, have)
54 }
55 }
56
57 func TestRoundRobinNoRace(t *testing.T) {
58 balancer := NewRoundRobin(sd.FixedEndpointer([]endpoint.Endpoint{
59 endpoint.Nop,
60 endpoint.Nop,
61 endpoint.Nop,
62 endpoint.Nop,
63 endpoint.Nop,
64 }))
65
66 var (
67 n = 100
68 done = make(chan struct{})
69 wg sync.WaitGroup
70 count uint64
71 )
72
73 wg.Add(n)
74
75 for i := 0; i < n; i++ {
76 go func() {
77 defer wg.Done()
78 for {
79 select {
80 case <-done:
81 return
82 default:
83 _, _ = balancer.Endpoint()
84 atomic.AddUint64(&count, 1)
85 }
86 }
87 }()
88 }
89
90 time.Sleep(time.Second)
91 close(done)
92 wg.Wait()
93
94 t.Logf("made %d calls", atomic.LoadUint64(&count))
95 }
96
View as plain text