1
18
19 package xdsclient
20
21 import (
22 "sync"
23 "sync/atomic"
24 "testing"
25 )
26
// testService is the service name that every test in this file pairs with a
// cluster name when calling GetClusterRequestsCounter.
const (
	testService = "test-service-name"
)
28
// counterTest is one table-driven scenario for testCounter: how many
// concurrent StartRequest calls to make, the limit each call passes, and how
// many of those calls are expected to be admitted or rejected.
type counterTest struct {
	// name labels the subtest and is also used as the cluster name passed to
	// GetClusterRequestsCounter.
	name string
	// maxRequests is the limit handed to every StartRequest call.
	maxRequests uint32
	// numRequests is the number of concurrent StartRequest calls issued.
	numRequests uint32
	// expectedSuccesses and expectedErrors are the anticipated counts of
	// admitted and rejected StartRequest calls, respectively.
	expectedSuccesses, expectedErrors uint32
}
36
37 var tests = []counterTest{
38 {
39 name: "does-not-exceed-max-requests",
40 maxRequests: 1024,
41 numRequests: 1024,
42 expectedSuccesses: 1024,
43 expectedErrors: 0,
44 },
45 {
46 name: "exceeds-max-requests",
47 maxRequests: 32,
48 numRequests: 64,
49 expectedSuccesses: 32,
50 expectedErrors: 32,
51 },
52 }
53
54 func resetClusterRequestsCounter() {
55 src = &clusterRequestsCounter{
56 clusters: make(map[clusterNameAndServiceName]*ClusterRequestsCounter),
57 }
58 }
59
60 func testCounter(t *testing.T, test counterTest) {
61 requestsStarted := make(chan struct{})
62 requestsSent := sync.WaitGroup{}
63 requestsSent.Add(int(test.numRequests))
64 requestsDone := sync.WaitGroup{}
65 requestsDone.Add(int(test.numRequests))
66 var lastError atomic.Value
67 var successes, errors uint32
68 for i := 0; i < int(test.numRequests); i++ {
69 go func() {
70 counter := GetClusterRequestsCounter(test.name, testService)
71 defer requestsDone.Done()
72 err := counter.StartRequest(test.maxRequests)
73 if err == nil {
74 atomic.AddUint32(&successes, 1)
75 } else {
76 atomic.AddUint32(&errors, 1)
77 lastError.Store(err)
78 }
79 requestsSent.Done()
80 if err == nil {
81 <-requestsStarted
82 counter.EndRequest()
83 }
84 }()
85 }
86 requestsSent.Wait()
87 close(requestsStarted)
88 requestsDone.Wait()
89 loadedError := lastError.Load()
90 if test.expectedErrors > 0 && loadedError == nil {
91 t.Error("no error when error expected")
92 }
93 if test.expectedErrors == 0 && loadedError != nil {
94 t.Errorf("error starting request: %v", loadedError.(error))
95 }
96
97
98
99
100 if successes < test.expectedSuccesses || errors > test.expectedErrors {
101 t.Errorf("unexpected number of (successes, errors), expected (%v, %v), encountered (%v, %v)", test.expectedSuccesses, test.expectedErrors, successes, errors)
102 }
103 }
104
105 func (s) TestRequestsCounter(t *testing.T) {
106 defer resetClusterRequestsCounter()
107 for _, test := range tests {
108 t.Run(test.name, func(t *testing.T) {
109 testCounter(t, test)
110 })
111 }
112 }
113
114 func (s) TestGetClusterRequestsCounter(t *testing.T) {
115 defer resetClusterRequestsCounter()
116 for _, test := range tests {
117 counterA := GetClusterRequestsCounter(test.name, testService)
118 counterB := GetClusterRequestsCounter(test.name, testService)
119 if counterA != counterB {
120 t.Errorf("counter %v %v != counter %v %v", counterA, *counterA, counterB, *counterB)
121 }
122 }
123 }
124
125 func startRequests(t *testing.T, n uint32, max uint32, counter *ClusterRequestsCounter) {
126 for i := uint32(0); i < n; i++ {
127 if err := counter.StartRequest(max); err != nil {
128 t.Fatalf("error starting initial request: %v", err)
129 }
130 }
131 }
132
133 func (s) TestSetMaxRequestsIncreased(t *testing.T) {
134 defer resetClusterRequestsCounter()
135 const clusterName string = "set-max-requests-increased"
136 var initialMax uint32 = 16
137
138 counter := GetClusterRequestsCounter(clusterName, testService)
139 startRequests(t, initialMax, initialMax, counter)
140 if err := counter.StartRequest(initialMax); err == nil {
141 t.Fatal("unexpected success on start request after max met")
142 }
143
144 newMax := initialMax + 1
145 if err := counter.StartRequest(newMax); err != nil {
146 t.Fatalf("unexpected error on start request after max increased: %v", err)
147 }
148 }
149
150 func (s) TestSetMaxRequestsDecreased(t *testing.T) {
151 defer resetClusterRequestsCounter()
152 const clusterName string = "set-max-requests-decreased"
153 var initialMax uint32 = 16
154
155 counter := GetClusterRequestsCounter(clusterName, testService)
156 startRequests(t, initialMax-1, initialMax, counter)
157
158 newMax := initialMax - 1
159 if err := counter.StartRequest(newMax); err == nil {
160 t.Fatalf("unexpected success on start request after max decreased: %v", err)
161 }
162 }
163
View as plain text