1 package internal
2
3 import (
4 "context"
5 "fmt"
6 "reflect"
7 "sort"
8 "sync"
9 "time"
10
11 "github.com/onsi/ginkgo/v2/types"
12 )
13
var (
	_global_node_id_counter = uint(0)
	_global_id_mutex        = &sync.Mutex{}
)

// UniqueNodeID hands out the next node identifier.
//
// The shared counter is only touched while _global_id_mutex is held, so
// concurrent callers never receive the same value. The first ID returned is 1,
// which leaves 0 free to mark a Node that was never constructed.
func UniqueNodeID() uint {
	_global_id_mutex.Lock()
	_global_node_id_counter++
	id := _global_node_id_counter
	_global_id_mutex.Unlock()
	return id
}
25
// Node describes a single Ginkgo DSL node (for example a container, an It, a
// setup/teardown node, a reporting node, a Synchronized*Suite node or a
// DeferCleanup node) together with the decorations applied to it. Nodes are
// built by NewNode and NewCleanupNode; on error those return the zero Node.
type Node struct {
	// ID is assigned by UniqueNodeID in NewNode; 0 marks the zero Node.
	ID       uint
	NodeType types.NodeType

	// Text is the description passed to NewNode ("" for cleanup nodes).
	Text string
	// Body is the callback for node types that take a single body function.
	// User functions that do not accept a context are wrapped so that Body
	// always takes a SpecContext.
	Body func(SpecContext)
	// CodeLocation is where the node was declared; Offset and
	// types.CodeLocation decorations override the default.
	CodeLocation types.CodeLocation
	// NestingLevel is -1 when built by NewNode; presumably assigned later when
	// the node is placed in the spec tree — confirm against callers.
	NestingLevel int
	// HasContext reports whether the user-supplied Body (or report body)
	// accepted a context argument.
	HasContext bool

	// SynchronizedBeforeSuite bodies: the first function argument becomes
	// Proc1Body and the second AllProcsBody. AllProcsBody receives a []byte,
	// presumably the value returned by Proc1Body — confirm with the runner.
	// Each *HasContext flag records whether that user function took a context.
	SynchronizedBeforeSuiteProc1Body              func(SpecContext) []byte
	SynchronizedBeforeSuiteProc1BodyHasContext    bool
	SynchronizedBeforeSuiteAllProcsBody           func(SpecContext, []byte)
	SynchronizedBeforeSuiteAllProcsBodyHasContext bool

	// SynchronizedAfterSuite bodies: note the order is reversed relative to
	// SynchronizedBeforeSuite — the first function argument becomes
	// AllProcsBody and the second Proc1Body.
	SynchronizedAfterSuiteAllProcsBody           func(SpecContext)
	SynchronizedAfterSuiteAllProcsBodyHasContext bool
	SynchronizedAfterSuiteProc1Body              func(SpecContext)
	SynchronizedAfterSuiteProc1BodyHasContext    bool

	// ReportEachBody is the body of ReportBeforeEach/ReportAfterEach nodes.
	ReportEachBody func(SpecContext, types.SpecReport)
	// ReportSuiteBody is the body of ReportBeforeSuite/ReportAfterSuite nodes.
	ReportSuiteBody func(SpecContext, types.Report)

	// Decorations copied from NewNode's arguments. NewNode rejects a
	// decoration used on a node type that does not support it, Focus combined
	// with Pending, ContinueOnFailure without Ordered, and FlakeAttempts
	// combined with MustPassRepeatedly.
	MarkedFocus             bool
	MarkedPending           bool
	MarkedSerial            bool
	MarkedOrdered           bool
	MarkedContinueOnFailure bool
	MarkedOncePerOrdered    bool
	FlakeAttempts           int
	MustPassRepeatedly      int
	// Labels holds the deduplicated, validated labels from Label decorations.
	Labels Labels
	// PollProgressAfter, PollProgressInterval and GracePeriod are -1 unless set
	// by a decoration; presumably -1 means "not specified".
	PollProgressAfter    time.Duration
	PollProgressInterval time.Duration
	// NodeTimeout, SpecTimeout and a positive GracePeriod are only accepted
	// when the node's body takes a context.
	NodeTimeout time.Duration
	SpecTimeout time.Duration
	GracePeriod time.Duration

	// NodeIDWhereCleanupWasGenerated is not set by the constructors in this
	// file; presumably the ID of the node whose execution registered this
	// cleanup node — confirm.
	NodeIDWhereCleanupWasGenerated uint
}
66
67
// Marker types for the boolean decorators. Each one is a distinct named type so
// that NewNode and isDecoration can recognise the decorator from its dynamic
// type alone.
type (
	focusType                 bool
	pendingType               bool
	serialType                bool
	orderedType               bool
	continueOnFailureType     bool
	honorsOrderedType         bool
	suppressProgressReporting bool
)

// Boolean decorators. NewNode records each of them on the Node it builds,
// except SuppressProgressReporting, which is only tracked as a deprecation.
const (
	Focus                     = focusType(true)
	Pending                   = pendingType(true)
	Serial                    = serialType(true)
	Ordered                   = orderedType(true)
	ContinueOnFailure         = continueOnFailureType(true)
	OncePerOrdered            = honorsOrderedType(true)
	SuppressProgressReporting = suppressProgressReporting(true)
)

// Valued decorators. NewNode matches each by its exact type and copies the value
// onto the Node. Offset shifts the code location reported for the node, Labels
// carries label strings, and Done is the channel type received by deprecated
// asynchronous bodies of the form func(Done).
type (
	FlakeAttempts        uint
	MustPassRepeatedly   uint
	Offset               uint
	Done                 chan<- interface{}
	Labels               []string
	PollProgressInterval time.Duration
	PollProgressAfter    time.Duration
	NodeTimeout          time.Duration
	SpecTimeout          time.Duration
	GracePeriod          time.Duration
)
94
95 func (l Labels) MatchesLabelFilter(query string) bool {
96 return types.MustParseLabelFilter(query)(l)
97 }
98
99 func UnionOfLabels(labels ...Labels) Labels {
100 out := Labels{}
101 seen := map[string]bool{}
102 for _, labelSet := range labels {
103 for _, label := range labelSet {
104 if !seen[label] {
105 seen[label] = true
106 out = append(out, label)
107 }
108 }
109 }
110 return out
111 }
112
113 func PartitionDecorations(args ...interface{}) ([]interface{}, []interface{}) {
114 decorations := []interface{}{}
115 remainingArgs := []interface{}{}
116 for _, arg := range args {
117 if isDecoration(arg) {
118 decorations = append(decorations, arg)
119 } else {
120 remainingArgs = append(remainingArgs, arg)
121 }
122 }
123 return decorations, remainingArgs
124 }
125
126 func isDecoration(arg interface{}) bool {
127 switch t := reflect.TypeOf(arg); {
128 case t == nil:
129 return false
130 case t == reflect.TypeOf(Offset(0)):
131 return true
132 case t == reflect.TypeOf(types.CodeLocation{}):
133 return true
134 case t == reflect.TypeOf(Focus):
135 return true
136 case t == reflect.TypeOf(Pending):
137 return true
138 case t == reflect.TypeOf(Serial):
139 return true
140 case t == reflect.TypeOf(Ordered):
141 return true
142 case t == reflect.TypeOf(ContinueOnFailure):
143 return true
144 case t == reflect.TypeOf(OncePerOrdered):
145 return true
146 case t == reflect.TypeOf(SuppressProgressReporting):
147 return true
148 case t == reflect.TypeOf(FlakeAttempts(0)):
149 return true
150 case t == reflect.TypeOf(MustPassRepeatedly(0)):
151 return true
152 case t == reflect.TypeOf(Labels{}):
153 return true
154 case t == reflect.TypeOf(PollProgressInterval(0)):
155 return true
156 case t == reflect.TypeOf(PollProgressAfter(0)):
157 return true
158 case t == reflect.TypeOf(NodeTimeout(0)):
159 return true
160 case t == reflect.TypeOf(SpecTimeout(0)):
161 return true
162 case t == reflect.TypeOf(GracePeriod(0)):
163 return true
164 case t.Kind() == reflect.Slice && isSliceOfDecorations(arg):
165 return true
166 default:
167 return false
168 }
169 }
170
171 func isSliceOfDecorations(slice interface{}) bool {
172 vSlice := reflect.ValueOf(slice)
173 if vSlice.Len() == 0 {
174 return false
175 }
176 for i := 0; i < vSlice.Len(); i++ {
177 if !isDecoration(vSlice.Index(i).Interface()) {
178 return false
179 }
180 }
181 return true
182 }
183
184 var contextType = reflect.TypeOf(new(context.Context)).Elem()
185 var specContextType = reflect.TypeOf(new(SpecContext)).Elem()
186
187 func NewNode(deprecationTracker *types.DeprecationTracker, nodeType types.NodeType, text string, args ...interface{}) (Node, []error) {
188 baseOffset := 2
189 node := Node{
190 ID: UniqueNodeID(),
191 NodeType: nodeType,
192 Text: text,
193 Labels: Labels{},
194 CodeLocation: types.NewCodeLocation(baseOffset),
195 NestingLevel: -1,
196 PollProgressAfter: -1,
197 PollProgressInterval: -1,
198 GracePeriod: -1,
199 }
200
201 errors := []error{}
202 appendError := func(err error) {
203 if err != nil {
204 errors = append(errors, err)
205 }
206 }
207
208 args = unrollInterfaceSlice(args)
209
210 remainingArgs := []interface{}{}
211
212 for _, arg := range args {
213 switch v := arg.(type) {
214 case Offset:
215 node.CodeLocation = types.NewCodeLocation(baseOffset + int(v))
216 case types.CodeLocation:
217 node.CodeLocation = v
218 default:
219 remainingArgs = append(remainingArgs, arg)
220 }
221 }
222
223 labelsSeen := map[string]bool{}
224 trackedFunctionError := false
225 args = remainingArgs
226 remainingArgs = []interface{}{}
227
228 for _, arg := range args {
229 switch t := reflect.TypeOf(arg); {
230 case t == reflect.TypeOf(float64(0)):
231 break
232 case t == reflect.TypeOf(Focus):
233 node.MarkedFocus = bool(arg.(focusType))
234 if !nodeType.Is(types.NodeTypesForContainerAndIt) {
235 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Focus"))
236 }
237 case t == reflect.TypeOf(Pending):
238 node.MarkedPending = bool(arg.(pendingType))
239 if !nodeType.Is(types.NodeTypesForContainerAndIt) {
240 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Pending"))
241 }
242 case t == reflect.TypeOf(Serial):
243 node.MarkedSerial = bool(arg.(serialType))
244 if !nodeType.Is(types.NodeTypesForContainerAndIt) {
245 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Serial"))
246 }
247 case t == reflect.TypeOf(Ordered):
248 node.MarkedOrdered = bool(arg.(orderedType))
249 if !nodeType.Is(types.NodeTypeContainer) {
250 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Ordered"))
251 }
252 case t == reflect.TypeOf(ContinueOnFailure):
253 node.MarkedContinueOnFailure = bool(arg.(continueOnFailureType))
254 if !nodeType.Is(types.NodeTypeContainer) {
255 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "ContinueOnFailure"))
256 }
257 case t == reflect.TypeOf(OncePerOrdered):
258 node.MarkedOncePerOrdered = bool(arg.(honorsOrderedType))
259 if !nodeType.Is(types.NodeTypeBeforeEach | types.NodeTypeJustBeforeEach | types.NodeTypeAfterEach | types.NodeTypeJustAfterEach) {
260 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "OncePerOrdered"))
261 }
262 case t == reflect.TypeOf(SuppressProgressReporting):
263 deprecationTracker.TrackDeprecation(types.Deprecations.SuppressProgressReporting())
264 case t == reflect.TypeOf(FlakeAttempts(0)):
265 node.FlakeAttempts = int(arg.(FlakeAttempts))
266 if !nodeType.Is(types.NodeTypesForContainerAndIt) {
267 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "FlakeAttempts"))
268 }
269 case t == reflect.TypeOf(MustPassRepeatedly(0)):
270 node.MustPassRepeatedly = int(arg.(MustPassRepeatedly))
271 if !nodeType.Is(types.NodeTypesForContainerAndIt) {
272 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "MustPassRepeatedly"))
273 }
274 case t == reflect.TypeOf(PollProgressAfter(0)):
275 node.PollProgressAfter = time.Duration(arg.(PollProgressAfter))
276 if nodeType.Is(types.NodeTypeContainer) {
277 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "PollProgressAfter"))
278 }
279 case t == reflect.TypeOf(PollProgressInterval(0)):
280 node.PollProgressInterval = time.Duration(arg.(PollProgressInterval))
281 if nodeType.Is(types.NodeTypeContainer) {
282 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "PollProgressInterval"))
283 }
284 case t == reflect.TypeOf(NodeTimeout(0)):
285 node.NodeTimeout = time.Duration(arg.(NodeTimeout))
286 if nodeType.Is(types.NodeTypeContainer) {
287 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "NodeTimeout"))
288 }
289 case t == reflect.TypeOf(SpecTimeout(0)):
290 node.SpecTimeout = time.Duration(arg.(SpecTimeout))
291 if !nodeType.Is(types.NodeTypeIt) {
292 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "SpecTimeout"))
293 }
294 case t == reflect.TypeOf(GracePeriod(0)):
295 node.GracePeriod = time.Duration(arg.(GracePeriod))
296 if nodeType.Is(types.NodeTypeContainer) {
297 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "GracePeriod"))
298 }
299 case t == reflect.TypeOf(Labels{}):
300 if !nodeType.Is(types.NodeTypesForContainerAndIt) {
301 appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Label"))
302 }
303 for _, label := range arg.(Labels) {
304 if !labelsSeen[label] {
305 labelsSeen[label] = true
306 label, err := types.ValidateAndCleanupLabel(label, node.CodeLocation)
307 node.Labels = append(node.Labels, label)
308 appendError(err)
309 }
310 }
311 case t.Kind() == reflect.Func:
312 if nodeType.Is(types.NodeTypeContainer) {
313 if node.Body != nil {
314 appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
315 trackedFunctionError = true
316 break
317 }
318 if t.NumOut() > 0 || t.NumIn() > 0 {
319 appendError(types.GinkgoErrors.InvalidBodyTypeForContainer(t, node.CodeLocation, nodeType))
320 trackedFunctionError = true
321 break
322 }
323 body := arg.(func())
324 node.Body = func(SpecContext) { body() }
325 } else if nodeType.Is(types.NodeTypeReportBeforeEach | types.NodeTypeReportAfterEach) {
326 if node.ReportEachBody == nil {
327 if fn, ok := arg.(func(types.SpecReport)); ok {
328 node.ReportEachBody = func(_ SpecContext, r types.SpecReport) { fn(r) }
329 } else {
330 node.ReportEachBody = arg.(func(SpecContext, types.SpecReport))
331 node.HasContext = true
332 }
333 } else {
334 appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
335 trackedFunctionError = true
336 break
337 }
338 } else if nodeType.Is(types.NodeTypeReportBeforeSuite | types.NodeTypeReportAfterSuite) {
339 if node.ReportSuiteBody == nil {
340 if fn, ok := arg.(func(types.Report)); ok {
341 node.ReportSuiteBody = func(_ SpecContext, r types.Report) { fn(r) }
342 } else {
343 node.ReportSuiteBody = arg.(func(SpecContext, types.Report))
344 node.HasContext = true
345 }
346 } else {
347 appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
348 trackedFunctionError = true
349 break
350 }
351 } else if nodeType.Is(types.NodeTypeSynchronizedBeforeSuite) {
352 if node.SynchronizedBeforeSuiteProc1Body != nil && node.SynchronizedBeforeSuiteAllProcsBody != nil {
353 appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
354 trackedFunctionError = true
355 break
356 }
357 if node.SynchronizedBeforeSuiteProc1Body == nil {
358 body, hasContext := extractSynchronizedBeforeSuiteProc1Body(arg)
359 if body == nil {
360 appendError(types.GinkgoErrors.InvalidBodyTypeForSynchronizedBeforeSuiteProc1(t, node.CodeLocation))
361 trackedFunctionError = true
362 }
363 node.SynchronizedBeforeSuiteProc1Body, node.SynchronizedBeforeSuiteProc1BodyHasContext = body, hasContext
364 } else if node.SynchronizedBeforeSuiteAllProcsBody == nil {
365 body, hasContext := extractSynchronizedBeforeSuiteAllProcsBody(arg)
366 if body == nil {
367 appendError(types.GinkgoErrors.InvalidBodyTypeForSynchronizedBeforeSuiteAllProcs(t, node.CodeLocation))
368 trackedFunctionError = true
369 }
370 node.SynchronizedBeforeSuiteAllProcsBody, node.SynchronizedBeforeSuiteAllProcsBodyHasContext = body, hasContext
371 }
372 } else if nodeType.Is(types.NodeTypeSynchronizedAfterSuite) {
373 if node.SynchronizedAfterSuiteAllProcsBody != nil && node.SynchronizedAfterSuiteProc1Body != nil {
374 appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
375 trackedFunctionError = true
376 break
377 }
378 body, hasContext := extractBodyFunction(deprecationTracker, node.CodeLocation, arg)
379 if body == nil {
380 appendError(types.GinkgoErrors.InvalidBodyType(t, node.CodeLocation, nodeType))
381 trackedFunctionError = true
382 break
383 }
384 if node.SynchronizedAfterSuiteAllProcsBody == nil {
385 node.SynchronizedAfterSuiteAllProcsBody, node.SynchronizedAfterSuiteAllProcsBodyHasContext = body, hasContext
386 } else if node.SynchronizedAfterSuiteProc1Body == nil {
387 node.SynchronizedAfterSuiteProc1Body, node.SynchronizedAfterSuiteProc1BodyHasContext = body, hasContext
388 }
389 } else {
390 if node.Body != nil {
391 appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
392 trackedFunctionError = true
393 break
394 }
395 node.Body, node.HasContext = extractBodyFunction(deprecationTracker, node.CodeLocation, arg)
396 if node.Body == nil {
397 appendError(types.GinkgoErrors.InvalidBodyType(t, node.CodeLocation, nodeType))
398 trackedFunctionError = true
399 break
400 }
401 }
402 default:
403 remainingArgs = append(remainingArgs, arg)
404 }
405 }
406
407
408 if node.MarkedPending && node.MarkedFocus {
409 appendError(types.GinkgoErrors.InvalidDeclarationOfFocusedAndPending(node.CodeLocation, nodeType))
410 }
411
412 if node.MarkedContinueOnFailure && !node.MarkedOrdered {
413 appendError(types.GinkgoErrors.InvalidContinueOnFailureDecoration(node.CodeLocation))
414 }
415
416 hasContext := node.HasContext || node.SynchronizedAfterSuiteProc1BodyHasContext || node.SynchronizedAfterSuiteAllProcsBodyHasContext || node.SynchronizedBeforeSuiteProc1BodyHasContext || node.SynchronizedBeforeSuiteAllProcsBodyHasContext
417
418 if !hasContext && (node.NodeTimeout > 0 || node.SpecTimeout > 0 || node.GracePeriod > 0) && len(errors) == 0 {
419 appendError(types.GinkgoErrors.InvalidTimeoutOrGracePeriodForNonContextNode(node.CodeLocation, nodeType))
420 }
421
422 if !node.NodeType.Is(types.NodeTypeReportBeforeEach|types.NodeTypeReportAfterEach|types.NodeTypeSynchronizedBeforeSuite|types.NodeTypeSynchronizedAfterSuite|types.NodeTypeReportBeforeSuite|types.NodeTypeReportAfterSuite) && node.Body == nil && !node.MarkedPending && !trackedFunctionError {
423 appendError(types.GinkgoErrors.MissingBodyFunction(node.CodeLocation, nodeType))
424 }
425
426 if node.NodeType.Is(types.NodeTypeSynchronizedBeforeSuite) && !trackedFunctionError && (node.SynchronizedBeforeSuiteProc1Body == nil || node.SynchronizedBeforeSuiteAllProcsBody == nil) {
427 appendError(types.GinkgoErrors.MissingBodyFunction(node.CodeLocation, nodeType))
428 }
429
430 if node.NodeType.Is(types.NodeTypeSynchronizedAfterSuite) && !trackedFunctionError && (node.SynchronizedAfterSuiteProc1Body == nil || node.SynchronizedAfterSuiteAllProcsBody == nil) {
431 appendError(types.GinkgoErrors.MissingBodyFunction(node.CodeLocation, nodeType))
432 }
433
434 for _, arg := range remainingArgs {
435 appendError(types.GinkgoErrors.UnknownDecorator(node.CodeLocation, nodeType, arg))
436 }
437
438 if node.FlakeAttempts > 0 && node.MustPassRepeatedly > 0 {
439 appendError(types.GinkgoErrors.InvalidDeclarationOfFlakeAttemptsAndMustPassRepeatedly(node.CodeLocation, nodeType))
440 }
441
442 if len(errors) > 0 {
443 return Node{}, errors
444 }
445
446 return node, errors
447 }
448
449 var doneType = reflect.TypeOf(make(Done))
450
451 func extractBodyFunction(deprecationTracker *types.DeprecationTracker, cl types.CodeLocation, arg interface{}) (func(SpecContext), bool) {
452 t := reflect.TypeOf(arg)
453 if t.NumOut() > 0 || t.NumIn() > 1 {
454 return nil, false
455 }
456 if t.NumIn() == 1 {
457 if t.In(0) == doneType {
458 deprecationTracker.TrackDeprecation(types.Deprecations.Async(), cl)
459 deprecatedAsyncBody := arg.(func(Done))
460 return func(SpecContext) { deprecatedAsyncBody(make(Done)) }, false
461 } else if t.In(0).Implements(specContextType) {
462 return arg.(func(SpecContext)), true
463 } else if t.In(0).Implements(contextType) {
464 body := arg.(func(context.Context))
465 return func(c SpecContext) { body(c) }, true
466 }
467
468 return nil, false
469 }
470
471 body := arg.(func())
472 return func(SpecContext) { body() }, false
473 }
474
475 var byteType = reflect.TypeOf([]byte{})
476
477 func extractSynchronizedBeforeSuiteProc1Body(arg interface{}) (func(SpecContext) []byte, bool) {
478 t := reflect.TypeOf(arg)
479 v := reflect.ValueOf(arg)
480
481 if t.NumOut() > 1 || t.NumIn() > 1 {
482 return nil, false
483 } else if t.NumOut() == 1 && t.Out(0) != byteType {
484 return nil, false
485 } else if t.NumIn() == 1 && !t.In(0).Implements(contextType) {
486 return nil, false
487 }
488 hasContext := t.NumIn() == 1
489
490 return func(c SpecContext) []byte {
491 var out []reflect.Value
492 if hasContext {
493 out = v.Call([]reflect.Value{reflect.ValueOf(c)})
494 } else {
495 out = v.Call([]reflect.Value{})
496 }
497 if len(out) == 1 {
498 return (out[0].Interface()).([]byte)
499 } else {
500 return []byte{}
501 }
502 }, hasContext
503 }
504
505 func extractSynchronizedBeforeSuiteAllProcsBody(arg interface{}) (func(SpecContext, []byte), bool) {
506 t := reflect.TypeOf(arg)
507 v := reflect.ValueOf(arg)
508 hasContext, hasByte := false, false
509
510 if t.NumOut() > 0 || t.NumIn() > 2 {
511 return nil, false
512 } else if t.NumIn() == 2 && t.In(0).Implements(contextType) && t.In(1) == byteType {
513 hasContext, hasByte = true, true
514 } else if t.NumIn() == 1 && t.In(0).Implements(contextType) {
515 hasContext = true
516 } else if t.NumIn() == 1 && t.In(0) == byteType {
517 hasByte = true
518 } else if t.NumIn() != 0 {
519 return nil, false
520 }
521
522 return func(c SpecContext, b []byte) {
523 in := []reflect.Value{}
524 if hasContext {
525 in = append(in, reflect.ValueOf(c))
526 }
527 if hasByte {
528 in = append(in, reflect.ValueOf(b))
529 }
530 v.Call(in)
531 }, hasContext
532 }
533
534 var errInterface = reflect.TypeOf((*error)(nil)).Elem()
535
536 func NewCleanupNode(deprecationTracker *types.DeprecationTracker, fail func(string, types.CodeLocation), args ...interface{}) (Node, []error) {
537 decorations, remainingArgs := PartitionDecorations(args...)
538 baseOffset := 2
539 cl := types.NewCodeLocation(baseOffset)
540 finalArgs := []interface{}{}
541 for _, arg := range decorations {
542 switch t := reflect.TypeOf(arg); {
543 case t == reflect.TypeOf(Offset(0)):
544 cl = types.NewCodeLocation(baseOffset + int(arg.(Offset)))
545 case t == reflect.TypeOf(types.CodeLocation{}):
546 cl = arg.(types.CodeLocation)
547 default:
548 finalArgs = append(finalArgs, arg)
549 }
550 }
551 finalArgs = append(finalArgs, cl)
552
553 if len(remainingArgs) == 0 {
554 return Node{}, []error{types.GinkgoErrors.DeferCleanupInvalidFunction(cl)}
555 }
556
557 callback := reflect.ValueOf(remainingArgs[0])
558 if !(callback.Kind() == reflect.Func) {
559 return Node{}, []error{types.GinkgoErrors.DeferCleanupInvalidFunction(cl)}
560 }
561
562 callArgs := []reflect.Value{}
563 for _, arg := range remainingArgs[1:] {
564 callArgs = append(callArgs, reflect.ValueOf(arg))
565 }
566
567 hasContext := false
568 t := callback.Type()
569 if t.NumIn() > 0 {
570 if t.In(0).Implements(specContextType) {
571 hasContext = true
572 } else if t.In(0).Implements(contextType) && (len(callArgs) == 0 || !callArgs[0].Type().Implements(contextType)) {
573 hasContext = true
574 }
575 }
576
577 handleFailure := func(out []reflect.Value) {
578 if len(out) == 0 {
579 return
580 }
581 last := out[len(out)-1]
582 if last.Type().Implements(errInterface) && !last.IsNil() {
583 fail(fmt.Sprintf("DeferCleanup callback returned error: %v", last), cl)
584 }
585 }
586
587 if hasContext {
588 finalArgs = append(finalArgs, func(c SpecContext) {
589 out := callback.Call(append([]reflect.Value{reflect.ValueOf(c)}, callArgs...))
590 handleFailure(out)
591 })
592 } else {
593 finalArgs = append(finalArgs, func() {
594 out := callback.Call(callArgs)
595 handleFailure(out)
596 })
597 }
598
599 return NewNode(deprecationTracker, types.NodeTypeCleanupInvalid, "", finalArgs...)
600 }
601
602 func (n Node) IsZero() bool {
603 return n.ID == 0
604 }
605
606
607 type Nodes []Node
608
609 func (n Nodes) Clone() Nodes {
610 nodes := make(Nodes, len(n))
611 copy(nodes, n)
612 return nodes
613 }
614
615 func (n Nodes) CopyAppend(nodes ...Node) Nodes {
616 numN := len(n)
617 out := make(Nodes, numN+len(nodes))
618 copy(out, n)
619 for j, node := range nodes {
620 out[numN+j] = node
621 }
622 return out
623 }
624
625 func (n Nodes) SplitAround(pivot Node) (Nodes, Nodes) {
626 pivotIdx := len(n)
627 for i := range n {
628 if n[i].ID == pivot.ID {
629 pivotIdx = i
630 break
631 }
632 }
633 left := n[:pivotIdx]
634 right := Nodes{}
635 if pivotIdx+1 < len(n) {
636 right = n[pivotIdx+1:]
637 }
638
639 return left, right
640 }
641
642 func (n Nodes) FirstNodeWithType(nodeTypes types.NodeType) Node {
643 for i := range n {
644 if n[i].NodeType.Is(nodeTypes) {
645 return n[i]
646 }
647 }
648 return Node{}
649 }
650
651 func (n Nodes) WithType(nodeTypes types.NodeType) Nodes {
652 count := 0
653 for i := range n {
654 if n[i].NodeType.Is(nodeTypes) {
655 count++
656 }
657 }
658
659 out, j := make(Nodes, count), 0
660 for i := range n {
661 if n[i].NodeType.Is(nodeTypes) {
662 out[j] = n[i]
663 j++
664 }
665 }
666 return out
667 }
668
669 func (n Nodes) WithoutType(nodeTypes types.NodeType) Nodes {
670 count := 0
671 for i := range n {
672 if !n[i].NodeType.Is(nodeTypes) {
673 count++
674 }
675 }
676
677 out, j := make(Nodes, count), 0
678 for i := range n {
679 if !n[i].NodeType.Is(nodeTypes) {
680 out[j] = n[i]
681 j++
682 }
683 }
684 return out
685 }
686
687 func (n Nodes) WithoutNode(nodeToExclude Node) Nodes {
688 idxToExclude := len(n)
689 for i := range n {
690 if n[i].ID == nodeToExclude.ID {
691 idxToExclude = i
692 break
693 }
694 }
695 if idxToExclude == len(n) {
696 return n
697 }
698 out, j := make(Nodes, len(n)-1), 0
699 for i := range n {
700 if i == idxToExclude {
701 continue
702 }
703 out[j] = n[i]
704 j++
705 }
706 return out
707 }
708
709 func (n Nodes) Filter(filter func(Node) bool) Nodes {
710 trufa, count := make([]bool, len(n)), 0
711 for i := range n {
712 if filter(n[i]) {
713 trufa[i] = true
714 count += 1
715 }
716 }
717 out, j := make(Nodes, count), 0
718 for i := range n {
719 if trufa[i] {
720 out[j] = n[i]
721 j++
722 }
723 }
724 return out
725 }
726
727 func (n Nodes) FirstSatisfying(filter func(Node) bool) Node {
728 for i := range n {
729 if filter(n[i]) {
730 return n[i]
731 }
732 }
733 return Node{}
734 }
735
736 func (n Nodes) WithinNestingLevel(deepestNestingLevel int) Nodes {
737 count := 0
738 for i := range n {
739 if n[i].NestingLevel <= deepestNestingLevel {
740 count++
741 }
742 }
743 out, j := make(Nodes, count), 0
744 for i := range n {
745 if n[i].NestingLevel <= deepestNestingLevel {
746 out[j] = n[i]
747 j++
748 }
749 }
750 return out
751 }
752
753 func (n Nodes) SortedByDescendingNestingLevel() Nodes {
754 out := make(Nodes, len(n))
755 copy(out, n)
756 sort.SliceStable(out, func(i int, j int) bool {
757 return out[i].NestingLevel > out[j].NestingLevel
758 })
759
760 return out
761 }
762
763 func (n Nodes) SortedByAscendingNestingLevel() Nodes {
764 out := make(Nodes, len(n))
765 copy(out, n)
766 sort.SliceStable(out, func(i int, j int) bool {
767 return out[i].NestingLevel < out[j].NestingLevel
768 })
769
770 return out
771 }
772
773 func (n Nodes) FirstWithNestingLevel(level int) Node {
774 for i := range n {
775 if n[i].NestingLevel == level {
776 return n[i]
777 }
778 }
779 return Node{}
780 }
781
782 func (n Nodes) Reverse() Nodes {
783 out := make(Nodes, len(n))
784 for i := range n {
785 out[len(n)-1-i] = n[i]
786 }
787 return out
788 }
789
790 func (n Nodes) Texts() []string {
791 out := make([]string, len(n))
792 for i := range n {
793 out[i] = n[i].Text
794 }
795 return out
796 }
797
798 func (n Nodes) Labels() [][]string {
799 out := make([][]string, len(n))
800 for i := range n {
801 if n[i].Labels == nil {
802 out[i] = []string{}
803 } else {
804 out[i] = []string(n[i].Labels)
805 }
806 }
807 return out
808 }
809
810 func (n Nodes) UnionOfLabels() []string {
811 out := []string{}
812 seen := map[string]bool{}
813 for i := range n {
814 for _, label := range n[i].Labels {
815 if !seen[label] {
816 seen[label] = true
817 out = append(out, label)
818 }
819 }
820 }
821 return out
822 }
823
824 func (n Nodes) CodeLocations() []types.CodeLocation {
825 out := make([]types.CodeLocation, len(n))
826 for i := range n {
827 out[i] = n[i].CodeLocation
828 }
829 return out
830 }
831
832 func (n Nodes) BestTextFor(node Node) string {
833 if node.Text != "" {
834 return node.Text
835 }
836 parentNestingLevel := node.NestingLevel - 1
837 for i := range n {
838 if n[i].Text != "" && n[i].NestingLevel == parentNestingLevel {
839 return n[i].Text
840 }
841 }
842
843 return ""
844 }
845
846 func (n Nodes) ContainsNodeID(id uint) bool {
847 for i := range n {
848 if n[i].ID == id {
849 return true
850 }
851 }
852 return false
853 }
854
855 func (n Nodes) HasNodeMarkedPending() bool {
856 for i := range n {
857 if n[i].MarkedPending {
858 return true
859 }
860 }
861 return false
862 }
863
864 func (n Nodes) HasNodeMarkedFocus() bool {
865 for i := range n {
866 if n[i].MarkedFocus {
867 return true
868 }
869 }
870 return false
871 }
872
873 func (n Nodes) HasNodeMarkedSerial() bool {
874 for i := range n {
875 if n[i].MarkedSerial {
876 return true
877 }
878 }
879 return false
880 }
881
882 func (n Nodes) FirstNodeMarkedOrdered() Node {
883 for i := range n {
884 if n[i].MarkedOrdered {
885 return n[i]
886 }
887 }
888 return Node{}
889 }
890
891 func (n Nodes) IndexOfFirstNodeMarkedOrdered() int {
892 for i := range n {
893 if n[i].MarkedOrdered {
894 return i
895 }
896 }
897 return -1
898 }
899
900 func (n Nodes) GetMaxFlakeAttempts() int {
901 maxFlakeAttempts := 0
902 for i := range n {
903 if n[i].FlakeAttempts > 0 {
904 maxFlakeAttempts = n[i].FlakeAttempts
905 }
906 }
907 return maxFlakeAttempts
908 }
909
910 func (n Nodes) GetMaxMustPassRepeatedly() int {
911 maxMustPassRepeatedly := 0
912 for i := range n {
913 if n[i].MustPassRepeatedly > 0 {
914 maxMustPassRepeatedly = n[i].MustPassRepeatedly
915 }
916 }
917 return maxMustPassRepeatedly
918 }
919
920 func unrollInterfaceSlice(args interface{}) []interface{} {
921 v := reflect.ValueOf(args)
922 if v.Kind() != reflect.Slice {
923 return []interface{}{args}
924 }
925 out := []interface{}{}
926 for i := 0; i < v.Len(); i++ {
927 el := reflect.ValueOf(v.Index(i).Interface())
928 if el.Kind() == reflect.Slice && el.Type() != reflect.TypeOf(Labels{}) {
929 out = append(out, unrollInterfaceSlice(el.Interface())...)
930 } else {
931 out = append(out, v.Index(i).Interface())
932 }
933 }
934 return out
935 }
936
View as plain text