1 package f2
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "math/rand"
8 "os"
9 "runtime/debug"
10 "strings"
11 "sync"
12 "testing"
13
14 "slices"
15
16 "edge-infra.dev/pkg/lib/fog"
17 )
18
19 type Framework interface {
20 Setup(...FrameworkFn) Framework
21 Teardown(...FrameworkFn) Framework
22
23 BeforeEachFeature(...FeatureFn) Framework
24 AfterEachFeature(...FeatureFn) Framework
25
26 BeforeEachTest(...FrameworkTestFn) Framework
27 AfterEachTest(...FrameworkTestFn) Framework
28
29
30 Test(*testing.T, ...Feature)
31
32
33 TestInParallel(*testing.T, ...Feature)
34
35
36 Run(*testing.M) int
37
38
39
40
41 WithLabel(k string, v ...string) Framework
42 Component(c ...string) Framework
43 Priviledged(c ...string) Framework
44 WithID(c ...string) Framework
45 Slow() Framework
46 Disruptive() Framework
47 Serial() Framework
48 Flaky() Framework
49 }
50
51 type framework struct {
52 ctx Context
53 actions []action
54 labels map[string]string
55 exts []Extension
56 }
57
58 func New(ctx context.Context, opts ...Option) Framework {
59 options := makeOptions(opts...)
60
61
62 ctx = fog.IntoContext(ctx, fog.New().WithName("f2"))
63
64 f := &framework{
65 actions: []action{
66 {
67 fns: []FrameworkFn{
68 func(ctx Context) (Context, error) {
69
70 ctx.RunID = fmt.Sprintf("%08x", rand.Int31())
71 return ctx, nil
72 },
73 },
74 phase: phaseSetup,
75 },
76 },
77 ctx: Context{Context: ctx},
78 labels: make(map[string]string),
79 exts: options.extensions,
80 }
81
82
83
84 for _, e := range f.exts {
85 if b, ok := e.(FlagBinder); ok {
86 b.BindFlags(Flags)
87 }
88 }
89
90 if err := handleFlags(); err != nil {
91
92 panic(fmt.Errorf("failed to parse flags: %w", err))
93 }
94
95
96
97 for _, e := range f.exts {
98 e.RegisterFns(f)
99 f.ctx = e.IntoContext(f.ctx)
100
101 if l, ok := e.(Labeler); ok {
102 for k, v := range l.Labels() {
103 f.labels[k] = v
104 }
105 }
106 }
107
108 return f
109 }
110
111 func (f *framework) Run(m *testing.M) int {
112 return f.run(m)
113 }
114
115 func (f *framework) run(m TestingMain) (exitCode int) {
116
117 var err error
118 log := fog.FromContext(f.ctx)
119
120
121 defer func() {
122 if rErr := recover(); rErr != nil {
123 exitCode = 1
124 log.V(0).Info("Recovering from panic and running finish actions", "err", rErr, "stacktrace", string(debug.Stack()))
125 }
126 }()
127
128
129 defer func() {
130
131
132
133
134
135
136 for _, fin := range f.getActionsInPhase(phaseTeardown) {
137 if f.ctx, err = fin.run(f.ctx); err != nil {
138 exitCode = 1
139 log.V(0).Error(err, "Cleanup failed", "action", fin.phase)
140 }
141 }
142 }()
143
144 for _, setup := range f.getActionsInPhase(phaseSetup) {
145 f.ctx, err = setup.run(f.ctx)
146 switch {
147 case errors.Is(err, ErrSkip):
148 log.V(1).Info("Skip detected")
149 os.Exit(0)
150 case err != nil:
151 log.V(0).Info("Failure during phase", "phase", setup.phase, "err", err)
152 panic(err)
153 }
154 }
155
156
157 exitCode = m.Run()
158 return
159 }
160
161 func (f *framework) Setup(fns ...FrameworkFn) Framework {
162 f.actions = append(f.actions, action{phase: phaseSetup, fns: fns})
163 return f
164 }
165
166 func (f *framework) Teardown(fns ...FrameworkFn) Framework {
167 f.actions = append(f.actions, action{phase: phaseTeardown, fns: fns})
168 return f
169 }
170
171 func (f *framework) BeforeEachFeature(fns ...FeatureFn) Framework {
172 f.actions = append(f.actions, action{phase: phaseBeforeFeature, featureFns: fns})
173 return f
174 }
175
176 func (f *framework) AfterEachFeature(fns ...FeatureFn) Framework {
177 f.actions = append(f.actions, action{phase: phaseAfterFeature, featureFns: fns})
178 return f
179 }
180
181 func (f *framework) BeforeEachTest(fns ...FrameworkTestFn) Framework {
182 f.actions = append(f.actions, action{phase: phaseBeforeTest, testFns: fns})
183 return f
184 }
185
186 func (f *framework) AfterEachTest(fns ...FrameworkTestFn) Framework {
187 f.actions = append(f.actions, action{phase: phaseAfterTest, testFns: fns})
188 return f
189 }
190
191 func (f *framework) TestInParallel(t *testing.T, features ...Feature) {
192
193
194
195 f.processTests(t, true, features...)
196 }
197
198 func (f *framework) Test(t *testing.T, features ...Feature) {
199 f.processTests(t, false, features...)
200 }
201
202 func (f *framework) processTests(t *testing.T, parallel bool, features ...Feature) {
203 if len(features) == 0 {
204 t.Log("No test features provided, skipping")
205 return
206 }
207
208 f.processTestActions(t, f.getActionsInPhase(phaseBeforeTest))
209
210 var wg sync.WaitGroup
211 for _, feature := range features {
212 fcopy := feature
213 switch parallel {
214 case true:
215 wg.Add(1)
216 go func(w *sync.WaitGroup, feat Feature) {
217 defer w.Done()
218 f.processTestFeature(t, feat)
219 }(&wg, fcopy)
220 case false:
221 f.processTestFeature(t, fcopy)
222 }
223 }
224
225 if parallel {
226 wg.Wait()
227 }
228
229 f.processTestActions(t, f.getActionsInPhase(phaseAfterTest))
230 }
231
232 func (f *framework) processTestActions(t *testing.T, actions []action) {
233 var err error
234 for _, action := range actions {
235 if f.ctx, err = action.runWithT(f.ctx, t); err != nil {
236 t.Fatalf("%s failure: %s", action.phase, err)
237 }
238 }
239 }
240
241 func (f *framework) processFeatureActions(t *testing.T, feat Feature, actions []action) {
242 var err error
243 for _, action := range actions {
244 if f.ctx, err = action.runWithFeature(f.ctx, t, deepCopyFeature(feat)); err != nil {
245 t.Fatalf("%s failure: %s", action.phase, err)
246 }
247 }
248 }
249
250 func (f *framework) processTestFeature(t *testing.T, feat Feature) {
251 f.processFeatureActions(t, feat, f.getActionsInPhase(phaseBeforeFeature))
252
253 f.ctx = f.execFeature(f.ctx, t, feat.Name(), feat)
254
255 f.processFeatureActions(t, feat, f.getActionsInPhase(phaseAfterFeature))
256 }
257
258 func (f *framework) execFeature(ctx Context, t *testing.T, name string, feat Feature) Context {
259 t.Run(name, func(t *testing.T) {
260
261 mergedLabels, skippedLabels := mergeLabels(f.labels, feat.Labels())
262 if len(skippedLabels) > 0 {
263 t.Logf("ignoring the following framework labels because they're duplicated on the feature: %s", strings.Join(skippedLabels[:], ", "))
264 }
265
266 skip, msg := SkipBasedOnLabels(mergedLabels, Labels, SkipLabels)
267 if skip {
268 t.Skipf("skipping: %s", msg)
269 }
270
271 setupSteps := getStepsInPhase(feat, phaseBeforeFeature)
272 ctx = f.execSteps(ctx, t, setupSteps)
273
274 tests := getStepsInPhase(feat, phaseTest)
275 for _, test := range tests {
276 t.Run(test.Name, func(t *testing.T) {
277 ctx = f.execSteps(ctx, t, []Step{test})
278 })
279
280 }
281
282 teardownSteps := getStepsInPhase(feat, phaseAfterFeature)
283 ctx = f.execSteps(ctx, t, teardownSteps)
284 })
285
286 return ctx
287 }
288
289 func (f *framework) execSteps(ctx Context, t *testing.T, steps []Step) Context {
290 for _, s := range steps {
291 ctx = s.Fn(ctx, t)
292 }
293 return ctx
294 }
295
296 func (f *framework) getActionsInPhase(p Phase) []action {
297 if f.actions == nil {
298 return nil
299 }
300
301 result := make([]action, 0)
302 for _, a := range f.actions {
303 if a.phase == p {
304 result = append(result, a)
305 }
306 }
307
308 return result
309 }
310
311
312
313
314
315
316
317
318
319 func (f *framework) WithLabel(key string, values ...string) Framework {
320 if f.labels == nil {
321 f.labels = map[string]string{}
322 }
323 f.labels[key] = commaSepList(values...)
324
325 return f
326 }
327
328
329
330
331 func (f *framework) WithID(feat ...string) Framework {
332 return f.WithLabel("id", feat...)
333 }
334
335
336
337 func (f *framework) Priviledged(p ...string) Framework {
338 return f.WithLabel("priviledged", p...)
339 }
340
341
342
343 func (f *framework) Component(c ...string) Framework {
344 return f.WithLabel("component", c...)
345 }
346
347
348
349 func (f *framework) Slow() Framework {
350 return f.WithLabel("slow", "true")
351 }
352
353
354
355 func (f *framework) Disruptive() Framework {
356 return f.WithLabel("disruptive", "true")
357 }
358
359
360
361 func (f *framework) Serial() Framework {
362 return f.WithLabel("serial", "true")
363 }
364
365
366 func (f *framework) Flaky() Framework {
367 return f.WithLabel("flaky", "true")
368 }
369
370
371
372 func SkipBasedOnLabels(test, labels, skip map[string]string) (bool, string) {
373
374 if len(test) == 0 && len(labels) != 0 {
375 return true, "test without labels skipped due to -labels flag being provided"
376 }
377
378
379 if skip != nil && len(test) != 0 {
380 for k := range test {
381 if v, ok := skip[k]; ok && checkLabelList(test[k], v) {
382 return true, fmt.Sprintf("test with label %s=%s skipped due to -skip-labels flag", k, test[k])
383 }
384 }
385 }
386
387 if len(labels) != 0 {
388 matches := map[string]bool{}
389
390 for k := range labels {
391 matches[k] = false
392 if v, ok := test[k]; ok && checkLabelList(v, labels[k]) {
393 matches[k] = true
394 }
395 }
396 skip := false
397 for _, match := range matches {
398 if !match {
399 skip = true
400 }
401 }
402 if skip {
403 return true, fmt.Sprintf("test without labels skipped due to -labels flag: %v", labels)
404 }
405 }
406
407 return false, ""
408 }
409
410
// mergeLabels combines framework-wide labels with a feature's labels. The
// feature's value wins for any key present in both.
//
// It returns the merged labels together with a sorted list of the
// overridden framework entries, each formatted as "key = value". When feat is
// empty, framework itself is returned unchanged (callers must not mutate it).
// Otherwise a fresh map is built, so neither input map is modified. Writing
// into feat would leak framework labels into the feature's own label map.
func mergeLabels(framework, feat map[string]string) (map[string]string, []string) {
	if len(feat) == 0 {
		return framework, nil
	}

	merged := make(map[string]string, len(framework)+len(feat))
	for k, v := range feat {
		merged[k] = v
	}

	skipped := []string{}
	for k, v := range framework {
		if _, ok := feat[k]; ok {
			skipped = append(skipped, fmt.Sprintf("%s = %s", k, v))
			continue
		}
		merged[k] = v
	}
	// Map iteration order is random; sort so the log line is deterministic.
	slices.Sort(skipped)

	return merged, skipped
}
429
430
// checkLabelList reports whether the comma-separated lists testLabels and
// inputLabels share at least one element. Elements are compared exactly: no
// trimming, and empty elements count, so ("", "") matches.
func checkLabelList(testLabels, inputLabels string) bool {
	have := strings.Split(testLabels, ",")
	wanted := strings.Split(inputLabels, ",")
	return slices.ContainsFunc(wanted, func(w string) bool {
		return slices.Contains(have, w)
	})
}
440
View as plain text