...

Source file src/github.com/onsi/ginkgo/v2/internal/node.go

Documentation: github.com/onsi/ginkgo/v2/internal

     1  package internal
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"sort"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/onsi/ginkgo/v2/types"
    12  )
    13  
    14  var _global_node_id_counter = uint(0)
    15  var _global_id_mutex = &sync.Mutex{}
    16  
    17  func UniqueNodeID() uint {
    18  	// There's a reace in the internal integration tests if we don't make
    19  	// accessing _global_node_id_counter safe across goroutines.
    20  	_global_id_mutex.Lock()
    21  	defer _global_id_mutex.Unlock()
    22  	_global_node_id_counter += 1
    23  	return _global_node_id_counter
    24  }
    25  
    26  type Node struct {
    27  	ID       uint
    28  	NodeType types.NodeType
    29  
    30  	Text         string
    31  	Body         func(SpecContext)
    32  	CodeLocation types.CodeLocation
    33  	NestingLevel int
    34  	HasContext   bool
    35  
    36  	SynchronizedBeforeSuiteProc1Body              func(SpecContext) []byte
    37  	SynchronizedBeforeSuiteProc1BodyHasContext    bool
    38  	SynchronizedBeforeSuiteAllProcsBody           func(SpecContext, []byte)
    39  	SynchronizedBeforeSuiteAllProcsBodyHasContext bool
    40  
    41  	SynchronizedAfterSuiteAllProcsBody           func(SpecContext)
    42  	SynchronizedAfterSuiteAllProcsBodyHasContext bool
    43  	SynchronizedAfterSuiteProc1Body              func(SpecContext)
    44  	SynchronizedAfterSuiteProc1BodyHasContext    bool
    45  
    46  	ReportEachBody  func(SpecContext, types.SpecReport)
    47  	ReportSuiteBody func(SpecContext, types.Report)
    48  
    49  	MarkedFocus             bool
    50  	MarkedPending           bool
    51  	MarkedSerial            bool
    52  	MarkedOrdered           bool
    53  	MarkedContinueOnFailure bool
    54  	MarkedOncePerOrdered    bool
    55  	FlakeAttempts           int
    56  	MustPassRepeatedly      int
    57  	Labels                  Labels
    58  	PollProgressAfter       time.Duration
    59  	PollProgressInterval    time.Duration
    60  	NodeTimeout             time.Duration
    61  	SpecTimeout             time.Duration
    62  	GracePeriod             time.Duration
    63  
    64  	NodeIDWhereCleanupWasGenerated uint
    65  }
    66  
    67  // Decoration Types
    68  type focusType bool
    69  type pendingType bool
    70  type serialType bool
    71  type orderedType bool
    72  type continueOnFailureType bool
    73  type honorsOrderedType bool
    74  type suppressProgressReporting bool
    75  
    76  const Focus = focusType(true)
    77  const Pending = pendingType(true)
    78  const Serial = serialType(true)
    79  const Ordered = orderedType(true)
    80  const ContinueOnFailure = continueOnFailureType(true)
    81  const OncePerOrdered = honorsOrderedType(true)
    82  const SuppressProgressReporting = suppressProgressReporting(true)
    83  
    84  type FlakeAttempts uint
    85  type MustPassRepeatedly uint
    86  type Offset uint
    87  type Done chan<- interface{} // Deprecated Done Channel for asynchronous testing
    88  type Labels []string
    89  type PollProgressInterval time.Duration
    90  type PollProgressAfter time.Duration
    91  type NodeTimeout time.Duration
    92  type SpecTimeout time.Duration
    93  type GracePeriod time.Duration
    94  
    95  func (l Labels) MatchesLabelFilter(query string) bool {
    96  	return types.MustParseLabelFilter(query)(l)
    97  }
    98  
    99  func UnionOfLabels(labels ...Labels) Labels {
   100  	out := Labels{}
   101  	seen := map[string]bool{}
   102  	for _, labelSet := range labels {
   103  		for _, label := range labelSet {
   104  			if !seen[label] {
   105  				seen[label] = true
   106  				out = append(out, label)
   107  			}
   108  		}
   109  	}
   110  	return out
   111  }
   112  
   113  func PartitionDecorations(args ...interface{}) ([]interface{}, []interface{}) {
   114  	decorations := []interface{}{}
   115  	remainingArgs := []interface{}{}
   116  	for _, arg := range args {
   117  		if isDecoration(arg) {
   118  			decorations = append(decorations, arg)
   119  		} else {
   120  			remainingArgs = append(remainingArgs, arg)
   121  		}
   122  	}
   123  	return decorations, remainingArgs
   124  }
   125  
   126  func isDecoration(arg interface{}) bool {
   127  	switch t := reflect.TypeOf(arg); {
   128  	case t == nil:
   129  		return false
   130  	case t == reflect.TypeOf(Offset(0)):
   131  		return true
   132  	case t == reflect.TypeOf(types.CodeLocation{}):
   133  		return true
   134  	case t == reflect.TypeOf(Focus):
   135  		return true
   136  	case t == reflect.TypeOf(Pending):
   137  		return true
   138  	case t == reflect.TypeOf(Serial):
   139  		return true
   140  	case t == reflect.TypeOf(Ordered):
   141  		return true
   142  	case t == reflect.TypeOf(ContinueOnFailure):
   143  		return true
   144  	case t == reflect.TypeOf(OncePerOrdered):
   145  		return true
   146  	case t == reflect.TypeOf(SuppressProgressReporting):
   147  		return true
   148  	case t == reflect.TypeOf(FlakeAttempts(0)):
   149  		return true
   150  	case t == reflect.TypeOf(MustPassRepeatedly(0)):
   151  		return true
   152  	case t == reflect.TypeOf(Labels{}):
   153  		return true
   154  	case t == reflect.TypeOf(PollProgressInterval(0)):
   155  		return true
   156  	case t == reflect.TypeOf(PollProgressAfter(0)):
   157  		return true
   158  	case t == reflect.TypeOf(NodeTimeout(0)):
   159  		return true
   160  	case t == reflect.TypeOf(SpecTimeout(0)):
   161  		return true
   162  	case t == reflect.TypeOf(GracePeriod(0)):
   163  		return true
   164  	case t.Kind() == reflect.Slice && isSliceOfDecorations(arg):
   165  		return true
   166  	default:
   167  		return false
   168  	}
   169  }
   170  
   171  func isSliceOfDecorations(slice interface{}) bool {
   172  	vSlice := reflect.ValueOf(slice)
   173  	if vSlice.Len() == 0 {
   174  		return false
   175  	}
   176  	for i := 0; i < vSlice.Len(); i++ {
   177  		if !isDecoration(vSlice.Index(i).Interface()) {
   178  			return false
   179  		}
   180  	}
   181  	return true
   182  }
   183  
   184  var contextType = reflect.TypeOf(new(context.Context)).Elem()
   185  var specContextType = reflect.TypeOf(new(SpecContext)).Elem()
   186  
   187  func NewNode(deprecationTracker *types.DeprecationTracker, nodeType types.NodeType, text string, args ...interface{}) (Node, []error) {
   188  	baseOffset := 2
   189  	node := Node{
   190  		ID:                   UniqueNodeID(),
   191  		NodeType:             nodeType,
   192  		Text:                 text,
   193  		Labels:               Labels{},
   194  		CodeLocation:         types.NewCodeLocation(baseOffset),
   195  		NestingLevel:         -1,
   196  		PollProgressAfter:    -1,
   197  		PollProgressInterval: -1,
   198  		GracePeriod:          -1,
   199  	}
   200  
   201  	errors := []error{}
   202  	appendError := func(err error) {
   203  		if err != nil {
   204  			errors = append(errors, err)
   205  		}
   206  	}
   207  
   208  	args = unrollInterfaceSlice(args)
   209  
   210  	remainingArgs := []interface{}{}
   211  	// First get the CodeLocation up-to-date
   212  	for _, arg := range args {
   213  		switch v := arg.(type) {
   214  		case Offset:
   215  			node.CodeLocation = types.NewCodeLocation(baseOffset + int(v))
   216  		case types.CodeLocation:
   217  			node.CodeLocation = v
   218  		default:
   219  			remainingArgs = append(remainingArgs, arg)
   220  		}
   221  	}
   222  
   223  	labelsSeen := map[string]bool{}
   224  	trackedFunctionError := false
   225  	args = remainingArgs
   226  	remainingArgs = []interface{}{}
   227  	// now process the rest of the args
   228  	for _, arg := range args {
   229  		switch t := reflect.TypeOf(arg); {
   230  		case t == reflect.TypeOf(float64(0)):
   231  			break // ignore deprecated timeouts
   232  		case t == reflect.TypeOf(Focus):
   233  			node.MarkedFocus = bool(arg.(focusType))
   234  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   235  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Focus"))
   236  			}
   237  		case t == reflect.TypeOf(Pending):
   238  			node.MarkedPending = bool(arg.(pendingType))
   239  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   240  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Pending"))
   241  			}
   242  		case t == reflect.TypeOf(Serial):
   243  			node.MarkedSerial = bool(arg.(serialType))
   244  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   245  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Serial"))
   246  			}
   247  		case t == reflect.TypeOf(Ordered):
   248  			node.MarkedOrdered = bool(arg.(orderedType))
   249  			if !nodeType.Is(types.NodeTypeContainer) {
   250  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Ordered"))
   251  			}
   252  		case t == reflect.TypeOf(ContinueOnFailure):
   253  			node.MarkedContinueOnFailure = bool(arg.(continueOnFailureType))
   254  			if !nodeType.Is(types.NodeTypeContainer) {
   255  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "ContinueOnFailure"))
   256  			}
   257  		case t == reflect.TypeOf(OncePerOrdered):
   258  			node.MarkedOncePerOrdered = bool(arg.(honorsOrderedType))
   259  			if !nodeType.Is(types.NodeTypeBeforeEach | types.NodeTypeJustBeforeEach | types.NodeTypeAfterEach | types.NodeTypeJustAfterEach) {
   260  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "OncePerOrdered"))
   261  			}
   262  		case t == reflect.TypeOf(SuppressProgressReporting):
   263  			deprecationTracker.TrackDeprecation(types.Deprecations.SuppressProgressReporting())
   264  		case t == reflect.TypeOf(FlakeAttempts(0)):
   265  			node.FlakeAttempts = int(arg.(FlakeAttempts))
   266  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   267  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "FlakeAttempts"))
   268  			}
   269  		case t == reflect.TypeOf(MustPassRepeatedly(0)):
   270  			node.MustPassRepeatedly = int(arg.(MustPassRepeatedly))
   271  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   272  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "MustPassRepeatedly"))
   273  			}
   274  		case t == reflect.TypeOf(PollProgressAfter(0)):
   275  			node.PollProgressAfter = time.Duration(arg.(PollProgressAfter))
   276  			if nodeType.Is(types.NodeTypeContainer) {
   277  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "PollProgressAfter"))
   278  			}
   279  		case t == reflect.TypeOf(PollProgressInterval(0)):
   280  			node.PollProgressInterval = time.Duration(arg.(PollProgressInterval))
   281  			if nodeType.Is(types.NodeTypeContainer) {
   282  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "PollProgressInterval"))
   283  			}
   284  		case t == reflect.TypeOf(NodeTimeout(0)):
   285  			node.NodeTimeout = time.Duration(arg.(NodeTimeout))
   286  			if nodeType.Is(types.NodeTypeContainer) {
   287  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "NodeTimeout"))
   288  			}
   289  		case t == reflect.TypeOf(SpecTimeout(0)):
   290  			node.SpecTimeout = time.Duration(arg.(SpecTimeout))
   291  			if !nodeType.Is(types.NodeTypeIt) {
   292  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "SpecTimeout"))
   293  			}
   294  		case t == reflect.TypeOf(GracePeriod(0)):
   295  			node.GracePeriod = time.Duration(arg.(GracePeriod))
   296  			if nodeType.Is(types.NodeTypeContainer) {
   297  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "GracePeriod"))
   298  			}
   299  		case t == reflect.TypeOf(Labels{}):
   300  			if !nodeType.Is(types.NodeTypesForContainerAndIt) {
   301  				appendError(types.GinkgoErrors.InvalidDecoratorForNodeType(node.CodeLocation, nodeType, "Label"))
   302  			}
   303  			for _, label := range arg.(Labels) {
   304  				if !labelsSeen[label] {
   305  					labelsSeen[label] = true
   306  					label, err := types.ValidateAndCleanupLabel(label, node.CodeLocation)
   307  					node.Labels = append(node.Labels, label)
   308  					appendError(err)
   309  				}
   310  			}
   311  		case t.Kind() == reflect.Func:
   312  			if nodeType.Is(types.NodeTypeContainer) {
   313  				if node.Body != nil {
   314  					appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
   315  					trackedFunctionError = true
   316  					break
   317  				}
   318  				if t.NumOut() > 0 || t.NumIn() > 0 {
   319  					appendError(types.GinkgoErrors.InvalidBodyTypeForContainer(t, node.CodeLocation, nodeType))
   320  					trackedFunctionError = true
   321  					break
   322  				}
   323  				body := arg.(func())
   324  				node.Body = func(SpecContext) { body() }
   325  			} else if nodeType.Is(types.NodeTypeReportBeforeEach | types.NodeTypeReportAfterEach) {
   326  				if node.ReportEachBody == nil {
   327  					if fn, ok := arg.(func(types.SpecReport)); ok {
   328  						node.ReportEachBody = func(_ SpecContext, r types.SpecReport) { fn(r) }
   329  					} else {
   330  						node.ReportEachBody = arg.(func(SpecContext, types.SpecReport))
   331  						node.HasContext = true
   332  					}
   333  				} else {
   334  					appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
   335  					trackedFunctionError = true
   336  					break
   337  				}
   338  			} else if nodeType.Is(types.NodeTypeReportBeforeSuite | types.NodeTypeReportAfterSuite) {
   339  				if node.ReportSuiteBody == nil {
   340  					if fn, ok := arg.(func(types.Report)); ok {
   341  						node.ReportSuiteBody = func(_ SpecContext, r types.Report) { fn(r) }
   342  					} else {
   343  						node.ReportSuiteBody = arg.(func(SpecContext, types.Report))
   344  						node.HasContext = true
   345  					}
   346  				} else {
   347  					appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
   348  					trackedFunctionError = true
   349  					break
   350  				}
   351  			} else if nodeType.Is(types.NodeTypeSynchronizedBeforeSuite) {
   352  				if node.SynchronizedBeforeSuiteProc1Body != nil && node.SynchronizedBeforeSuiteAllProcsBody != nil {
   353  					appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
   354  					trackedFunctionError = true
   355  					break
   356  				}
   357  				if node.SynchronizedBeforeSuiteProc1Body == nil {
   358  					body, hasContext := extractSynchronizedBeforeSuiteProc1Body(arg)
   359  					if body == nil {
   360  						appendError(types.GinkgoErrors.InvalidBodyTypeForSynchronizedBeforeSuiteProc1(t, node.CodeLocation))
   361  						trackedFunctionError = true
   362  					}
   363  					node.SynchronizedBeforeSuiteProc1Body, node.SynchronizedBeforeSuiteProc1BodyHasContext = body, hasContext
   364  				} else if node.SynchronizedBeforeSuiteAllProcsBody == nil {
   365  					body, hasContext := extractSynchronizedBeforeSuiteAllProcsBody(arg)
   366  					if body == nil {
   367  						appendError(types.GinkgoErrors.InvalidBodyTypeForSynchronizedBeforeSuiteAllProcs(t, node.CodeLocation))
   368  						trackedFunctionError = true
   369  					}
   370  					node.SynchronizedBeforeSuiteAllProcsBody, node.SynchronizedBeforeSuiteAllProcsBodyHasContext = body, hasContext
   371  				}
   372  			} else if nodeType.Is(types.NodeTypeSynchronizedAfterSuite) {
   373  				if node.SynchronizedAfterSuiteAllProcsBody != nil && node.SynchronizedAfterSuiteProc1Body != nil {
   374  					appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
   375  					trackedFunctionError = true
   376  					break
   377  				}
   378  				body, hasContext := extractBodyFunction(deprecationTracker, node.CodeLocation, arg)
   379  				if body == nil {
   380  					appendError(types.GinkgoErrors.InvalidBodyType(t, node.CodeLocation, nodeType))
   381  					trackedFunctionError = true
   382  					break
   383  				}
   384  				if node.SynchronizedAfterSuiteAllProcsBody == nil {
   385  					node.SynchronizedAfterSuiteAllProcsBody, node.SynchronizedAfterSuiteAllProcsBodyHasContext = body, hasContext
   386  				} else if node.SynchronizedAfterSuiteProc1Body == nil {
   387  					node.SynchronizedAfterSuiteProc1Body, node.SynchronizedAfterSuiteProc1BodyHasContext = body, hasContext
   388  				}
   389  			} else {
   390  				if node.Body != nil {
   391  					appendError(types.GinkgoErrors.MultipleBodyFunctions(node.CodeLocation, nodeType))
   392  					trackedFunctionError = true
   393  					break
   394  				}
   395  				node.Body, node.HasContext = extractBodyFunction(deprecationTracker, node.CodeLocation, arg)
   396  				if node.Body == nil {
   397  					appendError(types.GinkgoErrors.InvalidBodyType(t, node.CodeLocation, nodeType))
   398  					trackedFunctionError = true
   399  					break
   400  				}
   401  			}
   402  		default:
   403  			remainingArgs = append(remainingArgs, arg)
   404  		}
   405  	}
   406  
   407  	// validations
   408  	if node.MarkedPending && node.MarkedFocus {
   409  		appendError(types.GinkgoErrors.InvalidDeclarationOfFocusedAndPending(node.CodeLocation, nodeType))
   410  	}
   411  
   412  	if node.MarkedContinueOnFailure && !node.MarkedOrdered {
   413  		appendError(types.GinkgoErrors.InvalidContinueOnFailureDecoration(node.CodeLocation))
   414  	}
   415  
   416  	hasContext := node.HasContext || node.SynchronizedAfterSuiteProc1BodyHasContext || node.SynchronizedAfterSuiteAllProcsBodyHasContext || node.SynchronizedBeforeSuiteProc1BodyHasContext || node.SynchronizedBeforeSuiteAllProcsBodyHasContext
   417  
   418  	if !hasContext && (node.NodeTimeout > 0 || node.SpecTimeout > 0 || node.GracePeriod > 0) && len(errors) == 0 {
   419  		appendError(types.GinkgoErrors.InvalidTimeoutOrGracePeriodForNonContextNode(node.CodeLocation, nodeType))
   420  	}
   421  
   422  	if !node.NodeType.Is(types.NodeTypeReportBeforeEach|types.NodeTypeReportAfterEach|types.NodeTypeSynchronizedBeforeSuite|types.NodeTypeSynchronizedAfterSuite|types.NodeTypeReportBeforeSuite|types.NodeTypeReportAfterSuite) && node.Body == nil && !node.MarkedPending && !trackedFunctionError {
   423  		appendError(types.GinkgoErrors.MissingBodyFunction(node.CodeLocation, nodeType))
   424  	}
   425  
   426  	if node.NodeType.Is(types.NodeTypeSynchronizedBeforeSuite) && !trackedFunctionError && (node.SynchronizedBeforeSuiteProc1Body == nil || node.SynchronizedBeforeSuiteAllProcsBody == nil) {
   427  		appendError(types.GinkgoErrors.MissingBodyFunction(node.CodeLocation, nodeType))
   428  	}
   429  
   430  	if node.NodeType.Is(types.NodeTypeSynchronizedAfterSuite) && !trackedFunctionError && (node.SynchronizedAfterSuiteProc1Body == nil || node.SynchronizedAfterSuiteAllProcsBody == nil) {
   431  		appendError(types.GinkgoErrors.MissingBodyFunction(node.CodeLocation, nodeType))
   432  	}
   433  
   434  	for _, arg := range remainingArgs {
   435  		appendError(types.GinkgoErrors.UnknownDecorator(node.CodeLocation, nodeType, arg))
   436  	}
   437  
   438  	if node.FlakeAttempts > 0 && node.MustPassRepeatedly > 0 {
   439  		appendError(types.GinkgoErrors.InvalidDeclarationOfFlakeAttemptsAndMustPassRepeatedly(node.CodeLocation, nodeType))
   440  	}
   441  
   442  	if len(errors) > 0 {
   443  		return Node{}, errors
   444  	}
   445  
   446  	return node, errors
   447  }
   448  
   449  var doneType = reflect.TypeOf(make(Done))
   450  
   451  func extractBodyFunction(deprecationTracker *types.DeprecationTracker, cl types.CodeLocation, arg interface{}) (func(SpecContext), bool) {
   452  	t := reflect.TypeOf(arg)
   453  	if t.NumOut() > 0 || t.NumIn() > 1 {
   454  		return nil, false
   455  	}
   456  	if t.NumIn() == 1 {
   457  		if t.In(0) == doneType {
   458  			deprecationTracker.TrackDeprecation(types.Deprecations.Async(), cl)
   459  			deprecatedAsyncBody := arg.(func(Done))
   460  			return func(SpecContext) { deprecatedAsyncBody(make(Done)) }, false
   461  		} else if t.In(0).Implements(specContextType) {
   462  			return arg.(func(SpecContext)), true
   463  		} else if t.In(0).Implements(contextType) {
   464  			body := arg.(func(context.Context))
   465  			return func(c SpecContext) { body(c) }, true
   466  		}
   467  
   468  		return nil, false
   469  	}
   470  
   471  	body := arg.(func())
   472  	return func(SpecContext) { body() }, false
   473  }
   474  
   475  var byteType = reflect.TypeOf([]byte{})
   476  
   477  func extractSynchronizedBeforeSuiteProc1Body(arg interface{}) (func(SpecContext) []byte, bool) {
   478  	t := reflect.TypeOf(arg)
   479  	v := reflect.ValueOf(arg)
   480  
   481  	if t.NumOut() > 1 || t.NumIn() > 1 {
   482  		return nil, false
   483  	} else if t.NumOut() == 1 && t.Out(0) != byteType {
   484  		return nil, false
   485  	} else if t.NumIn() == 1 && !t.In(0).Implements(contextType) {
   486  		return nil, false
   487  	}
   488  	hasContext := t.NumIn() == 1
   489  
   490  	return func(c SpecContext) []byte {
   491  		var out []reflect.Value
   492  		if hasContext {
   493  			out = v.Call([]reflect.Value{reflect.ValueOf(c)})
   494  		} else {
   495  			out = v.Call([]reflect.Value{})
   496  		}
   497  		if len(out) == 1 {
   498  			return (out[0].Interface()).([]byte)
   499  		} else {
   500  			return []byte{}
   501  		}
   502  	}, hasContext
   503  }
   504  
   505  func extractSynchronizedBeforeSuiteAllProcsBody(arg interface{}) (func(SpecContext, []byte), bool) {
   506  	t := reflect.TypeOf(arg)
   507  	v := reflect.ValueOf(arg)
   508  	hasContext, hasByte := false, false
   509  
   510  	if t.NumOut() > 0 || t.NumIn() > 2 {
   511  		return nil, false
   512  	} else if t.NumIn() == 2 && t.In(0).Implements(contextType) && t.In(1) == byteType {
   513  		hasContext, hasByte = true, true
   514  	} else if t.NumIn() == 1 && t.In(0).Implements(contextType) {
   515  		hasContext = true
   516  	} else if t.NumIn() == 1 && t.In(0) == byteType {
   517  		hasByte = true
   518  	} else if t.NumIn() != 0 {
   519  		return nil, false
   520  	}
   521  
   522  	return func(c SpecContext, b []byte) {
   523  		in := []reflect.Value{}
   524  		if hasContext {
   525  			in = append(in, reflect.ValueOf(c))
   526  		}
   527  		if hasByte {
   528  			in = append(in, reflect.ValueOf(b))
   529  		}
   530  		v.Call(in)
   531  	}, hasContext
   532  }
   533  
   534  var errInterface = reflect.TypeOf((*error)(nil)).Elem()
   535  
   536  func NewCleanupNode(deprecationTracker *types.DeprecationTracker, fail func(string, types.CodeLocation), args ...interface{}) (Node, []error) {
   537  	decorations, remainingArgs := PartitionDecorations(args...)
   538  	baseOffset := 2
   539  	cl := types.NewCodeLocation(baseOffset)
   540  	finalArgs := []interface{}{}
   541  	for _, arg := range decorations {
   542  		switch t := reflect.TypeOf(arg); {
   543  		case t == reflect.TypeOf(Offset(0)):
   544  			cl = types.NewCodeLocation(baseOffset + int(arg.(Offset)))
   545  		case t == reflect.TypeOf(types.CodeLocation{}):
   546  			cl = arg.(types.CodeLocation)
   547  		default:
   548  			finalArgs = append(finalArgs, arg)
   549  		}
   550  	}
   551  	finalArgs = append(finalArgs, cl)
   552  
   553  	if len(remainingArgs) == 0 {
   554  		return Node{}, []error{types.GinkgoErrors.DeferCleanupInvalidFunction(cl)}
   555  	}
   556  
   557  	callback := reflect.ValueOf(remainingArgs[0])
   558  	if !(callback.Kind() == reflect.Func) {
   559  		return Node{}, []error{types.GinkgoErrors.DeferCleanupInvalidFunction(cl)}
   560  	}
   561  
   562  	callArgs := []reflect.Value{}
   563  	for _, arg := range remainingArgs[1:] {
   564  		callArgs = append(callArgs, reflect.ValueOf(arg))
   565  	}
   566  
   567  	hasContext := false
   568  	t := callback.Type()
   569  	if t.NumIn() > 0 {
   570  		if t.In(0).Implements(specContextType) {
   571  			hasContext = true
   572  		} else if t.In(0).Implements(contextType) && (len(callArgs) == 0 || !callArgs[0].Type().Implements(contextType)) {
   573  			hasContext = true
   574  		}
   575  	}
   576  
   577  	handleFailure := func(out []reflect.Value) {
   578  		if len(out) == 0 {
   579  			return
   580  		}
   581  		last := out[len(out)-1]
   582  		if last.Type().Implements(errInterface) && !last.IsNil() {
   583  			fail(fmt.Sprintf("DeferCleanup callback returned error: %v", last), cl)
   584  		}
   585  	}
   586  
   587  	if hasContext {
   588  		finalArgs = append(finalArgs, func(c SpecContext) {
   589  			out := callback.Call(append([]reflect.Value{reflect.ValueOf(c)}, callArgs...))
   590  			handleFailure(out)
   591  		})
   592  	} else {
   593  		finalArgs = append(finalArgs, func() {
   594  			out := callback.Call(callArgs)
   595  			handleFailure(out)
   596  		})
   597  	}
   598  
   599  	return NewNode(deprecationTracker, types.NodeTypeCleanupInvalid, "", finalArgs...)
   600  }
   601  
   602  func (n Node) IsZero() bool {
   603  	return n.ID == 0
   604  }
   605  
   606  /* Nodes */
   607  type Nodes []Node
   608  
   609  func (n Nodes) Clone() Nodes {
   610  	nodes := make(Nodes, len(n))
   611  	copy(nodes, n)
   612  	return nodes
   613  }
   614  
   615  func (n Nodes) CopyAppend(nodes ...Node) Nodes {
   616  	numN := len(n)
   617  	out := make(Nodes, numN+len(nodes))
   618  	copy(out, n)
   619  	for j, node := range nodes {
   620  		out[numN+j] = node
   621  	}
   622  	return out
   623  }
   624  
   625  func (n Nodes) SplitAround(pivot Node) (Nodes, Nodes) {
   626  	pivotIdx := len(n)
   627  	for i := range n {
   628  		if n[i].ID == pivot.ID {
   629  			pivotIdx = i
   630  			break
   631  		}
   632  	}
   633  	left := n[:pivotIdx]
   634  	right := Nodes{}
   635  	if pivotIdx+1 < len(n) {
   636  		right = n[pivotIdx+1:]
   637  	}
   638  
   639  	return left, right
   640  }
   641  
   642  func (n Nodes) FirstNodeWithType(nodeTypes types.NodeType) Node {
   643  	for i := range n {
   644  		if n[i].NodeType.Is(nodeTypes) {
   645  			return n[i]
   646  		}
   647  	}
   648  	return Node{}
   649  }
   650  
   651  func (n Nodes) WithType(nodeTypes types.NodeType) Nodes {
   652  	count := 0
   653  	for i := range n {
   654  		if n[i].NodeType.Is(nodeTypes) {
   655  			count++
   656  		}
   657  	}
   658  
   659  	out, j := make(Nodes, count), 0
   660  	for i := range n {
   661  		if n[i].NodeType.Is(nodeTypes) {
   662  			out[j] = n[i]
   663  			j++
   664  		}
   665  	}
   666  	return out
   667  }
   668  
   669  func (n Nodes) WithoutType(nodeTypes types.NodeType) Nodes {
   670  	count := 0
   671  	for i := range n {
   672  		if !n[i].NodeType.Is(nodeTypes) {
   673  			count++
   674  		}
   675  	}
   676  
   677  	out, j := make(Nodes, count), 0
   678  	for i := range n {
   679  		if !n[i].NodeType.Is(nodeTypes) {
   680  			out[j] = n[i]
   681  			j++
   682  		}
   683  	}
   684  	return out
   685  }
   686  
   687  func (n Nodes) WithoutNode(nodeToExclude Node) Nodes {
   688  	idxToExclude := len(n)
   689  	for i := range n {
   690  		if n[i].ID == nodeToExclude.ID {
   691  			idxToExclude = i
   692  			break
   693  		}
   694  	}
   695  	if idxToExclude == len(n) {
   696  		return n
   697  	}
   698  	out, j := make(Nodes, len(n)-1), 0
   699  	for i := range n {
   700  		if i == idxToExclude {
   701  			continue
   702  		}
   703  		out[j] = n[i]
   704  		j++
   705  	}
   706  	return out
   707  }
   708  
   709  func (n Nodes) Filter(filter func(Node) bool) Nodes {
   710  	trufa, count := make([]bool, len(n)), 0
   711  	for i := range n {
   712  		if filter(n[i]) {
   713  			trufa[i] = true
   714  			count += 1
   715  		}
   716  	}
   717  	out, j := make(Nodes, count), 0
   718  	for i := range n {
   719  		if trufa[i] {
   720  			out[j] = n[i]
   721  			j++
   722  		}
   723  	}
   724  	return out
   725  }
   726  
   727  func (n Nodes) FirstSatisfying(filter func(Node) bool) Node {
   728  	for i := range n {
   729  		if filter(n[i]) {
   730  			return n[i]
   731  		}
   732  	}
   733  	return Node{}
   734  }
   735  
   736  func (n Nodes) WithinNestingLevel(deepestNestingLevel int) Nodes {
   737  	count := 0
   738  	for i := range n {
   739  		if n[i].NestingLevel <= deepestNestingLevel {
   740  			count++
   741  		}
   742  	}
   743  	out, j := make(Nodes, count), 0
   744  	for i := range n {
   745  		if n[i].NestingLevel <= deepestNestingLevel {
   746  			out[j] = n[i]
   747  			j++
   748  		}
   749  	}
   750  	return out
   751  }
   752  
   753  func (n Nodes) SortedByDescendingNestingLevel() Nodes {
   754  	out := make(Nodes, len(n))
   755  	copy(out, n)
   756  	sort.SliceStable(out, func(i int, j int) bool {
   757  		return out[i].NestingLevel > out[j].NestingLevel
   758  	})
   759  
   760  	return out
   761  }
   762  
   763  func (n Nodes) SortedByAscendingNestingLevel() Nodes {
   764  	out := make(Nodes, len(n))
   765  	copy(out, n)
   766  	sort.SliceStable(out, func(i int, j int) bool {
   767  		return out[i].NestingLevel < out[j].NestingLevel
   768  	})
   769  
   770  	return out
   771  }
   772  
   773  func (n Nodes) FirstWithNestingLevel(level int) Node {
   774  	for i := range n {
   775  		if n[i].NestingLevel == level {
   776  			return n[i]
   777  		}
   778  	}
   779  	return Node{}
   780  }
   781  
   782  func (n Nodes) Reverse() Nodes {
   783  	out := make(Nodes, len(n))
   784  	for i := range n {
   785  		out[len(n)-1-i] = n[i]
   786  	}
   787  	return out
   788  }
   789  
   790  func (n Nodes) Texts() []string {
   791  	out := make([]string, len(n))
   792  	for i := range n {
   793  		out[i] = n[i].Text
   794  	}
   795  	return out
   796  }
   797  
   798  func (n Nodes) Labels() [][]string {
   799  	out := make([][]string, len(n))
   800  	for i := range n {
   801  		if n[i].Labels == nil {
   802  			out[i] = []string{}
   803  		} else {
   804  			out[i] = []string(n[i].Labels)
   805  		}
   806  	}
   807  	return out
   808  }
   809  
   810  func (n Nodes) UnionOfLabels() []string {
   811  	out := []string{}
   812  	seen := map[string]bool{}
   813  	for i := range n {
   814  		for _, label := range n[i].Labels {
   815  			if !seen[label] {
   816  				seen[label] = true
   817  				out = append(out, label)
   818  			}
   819  		}
   820  	}
   821  	return out
   822  }
   823  
   824  func (n Nodes) CodeLocations() []types.CodeLocation {
   825  	out := make([]types.CodeLocation, len(n))
   826  	for i := range n {
   827  		out[i] = n[i].CodeLocation
   828  	}
   829  	return out
   830  }
   831  
   832  func (n Nodes) BestTextFor(node Node) string {
   833  	if node.Text != "" {
   834  		return node.Text
   835  	}
   836  	parentNestingLevel := node.NestingLevel - 1
   837  	for i := range n {
   838  		if n[i].Text != "" && n[i].NestingLevel == parentNestingLevel {
   839  			return n[i].Text
   840  		}
   841  	}
   842  
   843  	return ""
   844  }
   845  
   846  func (n Nodes) ContainsNodeID(id uint) bool {
   847  	for i := range n {
   848  		if n[i].ID == id {
   849  			return true
   850  		}
   851  	}
   852  	return false
   853  }
   854  
   855  func (n Nodes) HasNodeMarkedPending() bool {
   856  	for i := range n {
   857  		if n[i].MarkedPending {
   858  			return true
   859  		}
   860  	}
   861  	return false
   862  }
   863  
   864  func (n Nodes) HasNodeMarkedFocus() bool {
   865  	for i := range n {
   866  		if n[i].MarkedFocus {
   867  			return true
   868  		}
   869  	}
   870  	return false
   871  }
   872  
   873  func (n Nodes) HasNodeMarkedSerial() bool {
   874  	for i := range n {
   875  		if n[i].MarkedSerial {
   876  			return true
   877  		}
   878  	}
   879  	return false
   880  }
   881  
   882  func (n Nodes) FirstNodeMarkedOrdered() Node {
   883  	for i := range n {
   884  		if n[i].MarkedOrdered {
   885  			return n[i]
   886  		}
   887  	}
   888  	return Node{}
   889  }
   890  
   891  func (n Nodes) IndexOfFirstNodeMarkedOrdered() int {
   892  	for i := range n {
   893  		if n[i].MarkedOrdered {
   894  			return i
   895  		}
   896  	}
   897  	return -1
   898  }
   899  
   900  func (n Nodes) GetMaxFlakeAttempts() int {
   901  	maxFlakeAttempts := 0
   902  	for i := range n {
   903  		if n[i].FlakeAttempts > 0 {
   904  			maxFlakeAttempts = n[i].FlakeAttempts
   905  		}
   906  	}
   907  	return maxFlakeAttempts
   908  }
   909  
   910  func (n Nodes) GetMaxMustPassRepeatedly() int {
   911  	maxMustPassRepeatedly := 0
   912  	for i := range n {
   913  		if n[i].MustPassRepeatedly > 0 {
   914  			maxMustPassRepeatedly = n[i].MustPassRepeatedly
   915  		}
   916  	}
   917  	return maxMustPassRepeatedly
   918  }
   919  
   920  func unrollInterfaceSlice(args interface{}) []interface{} {
   921  	v := reflect.ValueOf(args)
   922  	if v.Kind() != reflect.Slice {
   923  		return []interface{}{args}
   924  	}
   925  	out := []interface{}{}
   926  	for i := 0; i < v.Len(); i++ {
   927  		el := reflect.ValueOf(v.Index(i).Interface())
   928  		if el.Kind() == reflect.Slice && el.Type() != reflect.TypeOf(Labels{}) {
   929  			out = append(out, unrollInterfaceSlice(el.Interface())...)
   930  		} else {
   931  			out = append(out, v.Index(i).Interface())
   932  		}
   933  	}
   934  	return out
   935  }
   936  

View as plain text