1 package test
2
3 import (
4 "bytes"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "reflect"
9 "strings"
10 "testing"
11 "time"
12
13 "github.com/prometheus/client_golang/prometheus"
14 io_prometheus_client "github.com/prometheus/client_model/go"
15 )
16
17
// Assert stops the test with message when result is false; a true result is a
// no-op.
func Assert(t *testing.T, result bool, message string) {
	t.Helper()
	if result {
		return
	}
	t.Fatal(message)
}
24
25
26
// AssertNil stops the test with message unless obj is the nil interface value
// itself. A non-nil interface that wraps a nil pointer, map, slice, etc. (a
// "boxed" nil) is not nil by this check.
func AssertNil(t *testing.T, obj interface{}, message string) {
	t.Helper()
	if obj == nil {
		return
	}
	t.Fatal(message)
}
33
34
35
36
37
// AssertNotNil stops the test with message if obj is nil. Beyond the plain
// obj == nil comparison, it also treats an interface value that wraps a nil
// chan, func, map, pointer, or slice as nil. Values of other kinds are never
// nil and always pass.
func AssertNotNil(t *testing.T, obj interface{}, message string) {
	t.Helper()
	if obj == nil {
		t.Fatal(message)
	}
	// obj is non-nil here, so the reflected Value is valid and its Kind
	// matches the dynamic type's Kind.
	v := reflect.ValueOf(obj)
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		// Only these kinds support IsNil; calling it on others would panic.
		if v.IsNil() {
			t.Fatal(message)
		}
	}
}
51
52
53
// AssertBoxedNil stops the test with message unless obj is nil. Unlike a plain
// obj == nil comparison, it looks inside the interface: a non-nil interface
// wrapping a nil chan, func, map, pointer, or slice (a "boxed" or typed nil)
// counts as nil. An untyped nil is also accepted as nil.
//
// Values of any other kind can never be nil, so passing one fails the test
// with a message naming the offending kind.
func AssertBoxedNil(t *testing.T, obj interface{}, message string) {
	t.Helper()
	// reflect.TypeOf(nil) returns a nil reflect.Type, and calling Kind on it
	// panics, so an untyped nil must be handled before reflecting.
	if obj == nil {
		return
	}
	typ := reflect.TypeOf(obj).Kind()
	switch typ {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		// Only these kinds support IsNil; calling it on others would panic.
		if !reflect.ValueOf(obj).IsNil() {
			t.Fatal(message)
		}
	default:
		t.Fatalf("Cannot check type \"%s\". Needs to be of type chan, func, interface, map, pointer, or slice.", typ)
	}
}
67
68
// AssertNotError stops the test if err is non-nil. The failure reads
// "<message>: <error text>".
func AssertNotError(t *testing.T, err error, message string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", message, err)
}
75
76
// AssertError stops the test if err is nil, i.e. when an operation that should
// have failed succeeded instead.
func AssertError(t *testing.T, err error, message string) {
	t.Helper()
	if err != nil {
		return
	}
	t.Fatalf("%s: expected error but received none", message)
}
83
84
85
// AssertErrorWraps stops the test unless err, or some error in its wrap chain,
// matches target according to errors.As.
//
// target must be a non-nil pointer to an error-implementing type or to an
// interface type; errors.As panics otherwise. On success errors.As has stored
// the matching error in *target, so callers can inspect it afterwards.
func AssertErrorWraps(t *testing.T, err error, target interface{}) {
	t.Helper()
	// errors.As reports false for a nil err. Catch that case here so the failure
	// message below never calls Error on a nil interface (which would panic).
	if err == nil {
		t.Fatal("err was unexpectedly nil and should not have been")
	}
	if !errors.As(err, target) {
		t.Fatalf("error does not wrap an error of the expected type: %q !> %T", err.Error(), target)
	}
}
92
93
// AssertErrorIs stops the test unless err is non-nil and errors.Is(err, target)
// reports true. A nil err gets its own failure message, distinct from the
// "does not wrap" one.
func AssertErrorIs(t *testing.T, err error, target error) {
	t.Helper()

	switch {
	case err == nil:
		t.Fatal("err was unexpectedly nil and should not have been")
	case !errors.Is(err, target):
		t.Fatalf("error does not wrap expected error: %q !> %q", err.Error(), target.Error())
	}
}
105
106
// AssertEquals stops the test unless one and two have the same dynamic type
// and compare equal with ==. Comparing values of a non-comparable type (slice,
// map, func) with == panics at run time; use AssertDeepEquals for those.
func AssertEquals(t *testing.T, one interface{}, two interface{}) {
	t.Helper()
	oneType, twoType := reflect.TypeOf(one), reflect.TypeOf(two)
	if oneType != twoType {
		t.Fatalf("cannot test equality of different types: %T != %T", one, two)
	}
	if one == two {
		return
	}
	t.Fatalf("%#v != %#v", one, two)
}
116
117
// AssertDeepEquals stops the test unless reflect.DeepEqual considers one and
// two equal. This works for slices, maps, and nested structures that == cannot
// compare.
func AssertDeepEquals(t *testing.T, one interface{}, two interface{}) {
	t.Helper()
	if reflect.DeepEqual(one, two) {
		return
	}
	t.Fatalf("[%#v] !(deep)= [%#v]", one, two)
}
124
125
126
// AssertMarshaledEquals stops the test unless json.Marshal produces
// byte-identical output for one and two. A marshaling failure for either
// argument also stops the test, reporting which argument failed.
func AssertMarshaledEquals(t *testing.T, one interface{}, two interface{}) {
	t.Helper()
	// mustMarshal encodes v, failing the test with "<what>: <err>" on error.
	mustMarshal := func(v interface{}, what string) []byte {
		t.Helper()
		raw, err := json.Marshal(v)
		if err != nil {
			t.Fatalf("%s: %s", what, err)
		}
		return raw
	}
	first := mustMarshal(one, "Could not marshal 1st argument")
	second := mustMarshal(two, "Could not marshal 2nd argument")
	if bytes.Equal(first, second) {
		return
	}
	t.Fatalf("[%s] !(json)= [%s]", first, second)
}
138
139
140
141
142 func AssertUnmarshaledEquals(t *testing.T, got, expected string) {
143 t.Helper()
144 var gotMap, expectedMap map[string]interface{}
145 err := json.Unmarshal([]byte(got), &gotMap)
146 AssertNotError(t, err, "Could not unmarshal 'got'")
147 err = json.Unmarshal([]byte(expected), &expectedMap)
148 AssertNotError(t, err, "Could not unmarshal 'expected'")
149 if len(gotMap) != len(expectedMap) {
150 t.Errorf("Expected had %d keys, got had %d", len(gotMap), len(expectedMap))
151 }
152 for k, v := range expectedMap {
153 if !reflect.DeepEqual(v, gotMap[k]) {
154 t.Errorf("Field %q: Expected \"%v\", got \"%v\"", k, v, gotMap[k])
155 }
156 }
157 }
158
159
160
// AssertNotEquals stops the test if one == two. No type check is done first,
// so values of different dynamic types (e.g. int and int64) always pass.
func AssertNotEquals(t *testing.T, one interface{}, two interface{}) {
	t.Helper()
	if one != two {
		return
	}
	t.Fatalf("%#v == %#v", one, two)
}
167
168
// AssertByteEquals stops the test unless one and two hold the same bytes
// (bytes.Equal treats nil and empty as equal). On mismatch both slices are
// printed base64-encoded so binary content stays legible.
func AssertByteEquals(t *testing.T, one []byte, two []byte) {
	t.Helper()
	if bytes.Equal(one, two) {
		return
	}
	enc := base64.StdEncoding
	t.Fatalf("Byte [%s] != [%s]", enc.EncodeToString(one), enc.EncodeToString(two))
}
177
178
// AssertContains stops the test unless needle occurs as a substring of
// haystack.
func AssertContains(t *testing.T, haystack string, needle string) {
	t.Helper()
	if strings.Contains(haystack, needle) {
		return
	}
	t.Fatalf("String [%s] does not contain [%s]", haystack, needle)
}
185
186
// AssertNotContains stops the test if needle occurs as a substring of
// haystack.
func AssertNotContains(t *testing.T, haystack string, needle string) {
	t.Helper()
	if !strings.Contains(haystack, needle) {
		return
	}
	t.Fatalf("String [%s] contains [%s]", haystack, needle)
}
193
194
// AssertSliceContains stops the test unless some element of haystack is == to
// needle.
func AssertSliceContains[T comparable](t *testing.T, haystack []T, needle T) {
	t.Helper()
	for i := range haystack {
		if haystack[i] == needle {
			return
		}
	}
	t.Fatalf("Slice %v does not contain %v", haystack, needle)
}
204
205
206
207
208
209
210
211
212
213
214 func AssertMetricWithLabelsEquals(t *testing.T, c prometheus.Collector, l prometheus.Labels, expected float64) {
215 t.Helper()
216 ch := make(chan prometheus.Metric)
217 done := make(chan struct{})
218 go func() {
219 c.Collect(ch)
220 close(done)
221 }()
222 var total float64
223 timeout := time.After(time.Second)
224 loop:
225 for {
226 metric:
227 select {
228 case <-timeout:
229 t.Fatal("timed out collecting metrics")
230 case <-done:
231 break loop
232 case m := <-ch:
233 var iom io_prometheus_client.Metric
234 _ = m.Write(&iom)
235 for _, lp := range iom.Label {
236
237
238 val, ok := l[lp.GetName()]
239 if ok && lp.GetValue() != val {
240 break metric
241 }
242 }
243
244
245 total += iom.Counter.GetValue()
246 total += iom.Gauge.GetValue()
247 total += float64(iom.Histogram.GetSampleCount())
248 }
249 }
250 AssertEquals(t, total, expected)
251 }
252
253
254
255
256
257
258 func AssertImplementsGRPCServer(t *testing.T, impl any, unimpl any) {
259
260
261
262 unimplType := reflect.TypeOf(unimpl)
263
264
265
266
267
268
269
270 implType := reflect.TypeOf(impl).Elem()
271
272
273
274
275
276
277
278
279 for i := 0; i < implType.NumMethod(); i++ {
280 method := implType.Method(i)
281 _, ok := unimplType.MethodByName(method.Name)
282 if ok {
283
284
285 t.Errorf("%s does not implement method %s", implType.Name(), method.Name)
286 } else {
287
288
289 t.Errorf("%s.%s has non-pointer receiver", implType.Name(), method.Name)
290 }
291 }
292 }
293
View as plain text