...

Source file src/github.com/onsi/gomega/internal/async_assertion.go

Documentation: github.com/onsi/gomega/internal

     1  package internal
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"runtime"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/onsi/gomega/format"
    13  	"github.com/onsi/gomega/types"
    14  )
    15  
    16  var errInterface = reflect.TypeOf((*error)(nil)).Elem()
    17  var gomegaType = reflect.TypeOf((*types.Gomega)(nil)).Elem()
    18  var contextType = reflect.TypeOf(new(context.Context)).Elem()
    19  
    20  type formattedGomegaError interface {
    21  	FormattedGomegaError() string
    22  }
    23  
    24  type asyncPolledActualError struct {
    25  	message string
    26  }
    27  
    28  func (err *asyncPolledActualError) Error() string {
    29  	return err.message
    30  }
    31  
    32  func (err *asyncPolledActualError) FormattedGomegaError() string {
    33  	return err.message
    34  }
    35  
    36  type contextWithAttachProgressReporter interface {
    37  	AttachProgressReporter(func() string) func()
    38  }
    39  
    40  type asyncGomegaHaltExecutionError struct{}
    41  
    42  func (a asyncGomegaHaltExecutionError) GinkgoRecoverShouldIgnoreThisPanic() {}
    43  func (a asyncGomegaHaltExecutionError) Error() string {
    44  	return `An assertion has failed in a goroutine.  You should call 
    45  
    46      defer GinkgoRecover()
    47  
    48  at the top of the goroutine that caused this panic.  This will allow Ginkgo and Gomega to correctly capture and manage this panic.`
    49  }
    50  
    51  type AsyncAssertionType uint
    52  
    53  const (
    54  	AsyncAssertionTypeEventually AsyncAssertionType = iota
    55  	AsyncAssertionTypeConsistently
    56  )
    57  
    58  func (at AsyncAssertionType) String() string {
    59  	switch at {
    60  	case AsyncAssertionTypeEventually:
    61  		return "Eventually"
    62  	case AsyncAssertionTypeConsistently:
    63  		return "Consistently"
    64  	}
    65  	return "INVALID ASYNC ASSERTION TYPE"
    66  }
    67  
    68  type AsyncAssertion struct {
    69  	asyncType AsyncAssertionType
    70  
    71  	actualIsFunc  bool
    72  	actual        interface{}
    73  	argsToForward []interface{}
    74  
    75  	timeoutInterval    time.Duration
    76  	pollingInterval    time.Duration
    77  	mustPassRepeatedly int
    78  	ctx                context.Context
    79  	offset             int
    80  	g                  *Gomega
    81  }
    82  
    83  func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput interface{}, g *Gomega, timeoutInterval time.Duration, pollingInterval time.Duration, mustPassRepeatedly int, ctx context.Context, offset int) *AsyncAssertion {
    84  	out := &AsyncAssertion{
    85  		asyncType:          asyncType,
    86  		timeoutInterval:    timeoutInterval,
    87  		pollingInterval:    pollingInterval,
    88  		mustPassRepeatedly: mustPassRepeatedly,
    89  		offset:             offset,
    90  		ctx:                ctx,
    91  		g:                  g,
    92  	}
    93  
    94  	out.actual = actualInput
    95  	if actualInput != nil && reflect.TypeOf(actualInput).Kind() == reflect.Func {
    96  		out.actualIsFunc = true
    97  	}
    98  
    99  	return out
   100  }
   101  
   102  func (assertion *AsyncAssertion) WithOffset(offset int) types.AsyncAssertion {
   103  	assertion.offset = offset
   104  	return assertion
   105  }
   106  
   107  func (assertion *AsyncAssertion) WithTimeout(interval time.Duration) types.AsyncAssertion {
   108  	assertion.timeoutInterval = interval
   109  	return assertion
   110  }
   111  
   112  func (assertion *AsyncAssertion) WithPolling(interval time.Duration) types.AsyncAssertion {
   113  	assertion.pollingInterval = interval
   114  	return assertion
   115  }
   116  
   117  func (assertion *AsyncAssertion) Within(timeout time.Duration) types.AsyncAssertion {
   118  	assertion.timeoutInterval = timeout
   119  	return assertion
   120  }
   121  
   122  func (assertion *AsyncAssertion) ProbeEvery(interval time.Duration) types.AsyncAssertion {
   123  	assertion.pollingInterval = interval
   124  	return assertion
   125  }
   126  
   127  func (assertion *AsyncAssertion) WithContext(ctx context.Context) types.AsyncAssertion {
   128  	assertion.ctx = ctx
   129  	return assertion
   130  }
   131  
   132  func (assertion *AsyncAssertion) WithArguments(argsToForward ...interface{}) types.AsyncAssertion {
   133  	assertion.argsToForward = argsToForward
   134  	return assertion
   135  }
   136  
   137  func (assertion *AsyncAssertion) MustPassRepeatedly(count int) types.AsyncAssertion {
   138  	assertion.mustPassRepeatedly = count
   139  	return assertion
   140  }
   141  
   142  func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
   143  	assertion.g.THelper()
   144  	vetOptionalDescription("Asynchronous assertion", optionalDescription...)
   145  	return assertion.match(matcher, true, optionalDescription...)
   146  }
   147  
   148  func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
   149  	assertion.g.THelper()
   150  	vetOptionalDescription("Asynchronous assertion", optionalDescription...)
   151  	return assertion.match(matcher, false, optionalDescription...)
   152  }
   153  
   154  func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interface{}) string {
   155  	switch len(optionalDescription) {
   156  	case 0:
   157  		return ""
   158  	case 1:
   159  		if describe, ok := optionalDescription[0].(func() string); ok {
   160  			return describe() + "\n"
   161  		}
   162  	}
   163  	return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
   164  }
   165  
   166  func (assertion *AsyncAssertion) processReturnValues(values []reflect.Value) (interface{}, error) {
   167  	if len(values) == 0 {
   168  		return nil, &asyncPolledActualError{
   169  			message: fmt.Sprintf("The function passed to %s did not return any values", assertion.asyncType),
   170  		}
   171  	}
   172  
   173  	actual := values[0].Interface()
   174  	if _, ok := AsPollingSignalError(actual); ok {
   175  		return actual, actual.(error)
   176  	}
   177  
   178  	var err error
   179  	for i, extraValue := range values[1:] {
   180  		extra := extraValue.Interface()
   181  		if extra == nil {
   182  			continue
   183  		}
   184  		if _, ok := AsPollingSignalError(extra); ok {
   185  			return actual, extra.(error)
   186  		}
   187  		extraType := reflect.TypeOf(extra)
   188  		zero := reflect.Zero(extraType).Interface()
   189  		if reflect.DeepEqual(extra, zero) {
   190  			continue
   191  		}
   192  		if i == len(values)-2 && extraType.Implements(errInterface) {
   193  			err = extra.(error)
   194  		}
   195  		if err == nil {
   196  			err = &asyncPolledActualError{
   197  				message: fmt.Sprintf("The function passed to %s had an unexpected non-nil/non-zero return value at index %d:\n%s", assertion.asyncType, i+1, format.Object(extra, 1)),
   198  			}
   199  		}
   200  	}
   201  
   202  	return actual, err
   203  }
   204  
   205  func (assertion *AsyncAssertion) invalidFunctionError(t reflect.Type) error {
   206  	return fmt.Errorf(`The function passed to %s had an invalid signature of %s.  Functions passed to %s must either:
   207  
   208  	(a) have return values or
   209  	(b) take a Gomega interface as their first argument and use that Gomega instance to make assertions.
   210  
   211  You can learn more at https://onsi.github.io/gomega/#eventually
   212  `, assertion.asyncType, t, assertion.asyncType)
   213  }
   214  
   215  func (assertion *AsyncAssertion) noConfiguredContextForFunctionError() error {
   216  	return fmt.Errorf(`The function passed to %s requested a context.Context, but no context has been provided.  Please pass one in using %s().WithContext().
   217  
   218  You can learn more at https://onsi.github.io/gomega/#eventually
   219  `, assertion.asyncType, assertion.asyncType)
   220  }
   221  
   222  func (assertion *AsyncAssertion) argumentMismatchError(t reflect.Type, numProvided int) error {
   223  	have := "have"
   224  	if numProvided == 1 {
   225  		have = "has"
   226  	}
   227  	return fmt.Errorf(`The function passed to %s has signature %s takes %d arguments but %d %s been provided.  Please use %s().WithArguments() to pass the corect set of arguments.
   228  
   229  You can learn more at https://onsi.github.io/gomega/#eventually
   230  `, assertion.asyncType, t, t.NumIn(), numProvided, have, assertion.asyncType)
   231  }
   232  
   233  func (assertion *AsyncAssertion) invalidMustPassRepeatedlyError(reason string) error {
   234  	return fmt.Errorf(`Invalid use of MustPassRepeatedly with %s %s
   235  
   236  You can learn more at https://onsi.github.io/gomega/#eventually
   237  `, assertion.asyncType, reason)
   238  }
   239  
   240  func (assertion *AsyncAssertion) buildActualPoller() (func() (interface{}, error), error) {
   241  	if !assertion.actualIsFunc {
   242  		return func() (interface{}, error) { return assertion.actual, nil }, nil
   243  	}
   244  	actualValue := reflect.ValueOf(assertion.actual)
   245  	actualType := reflect.TypeOf(assertion.actual)
   246  	numIn, numOut, isVariadic := actualType.NumIn(), actualType.NumOut(), actualType.IsVariadic()
   247  
   248  	if numIn == 0 && numOut == 0 {
   249  		return nil, assertion.invalidFunctionError(actualType)
   250  	}
   251  	takesGomega, takesContext := false, false
   252  	if numIn > 0 {
   253  		takesGomega, takesContext = actualType.In(0).Implements(gomegaType), actualType.In(0).Implements(contextType)
   254  	}
   255  	if takesGomega && numIn > 1 && actualType.In(1).Implements(contextType) {
   256  		takesContext = true
   257  	}
   258  	if takesContext && len(assertion.argsToForward) > 0 && reflect.TypeOf(assertion.argsToForward[0]).Implements(contextType) {
   259  		takesContext = false
   260  	}
   261  	if !takesGomega && numOut == 0 {
   262  		return nil, assertion.invalidFunctionError(actualType)
   263  	}
   264  	if takesContext && assertion.ctx == nil {
   265  		return nil, assertion.noConfiguredContextForFunctionError()
   266  	}
   267  
   268  	var assertionFailure error
   269  	inValues := []reflect.Value{}
   270  	if takesGomega {
   271  		inValues = append(inValues, reflect.ValueOf(NewGomega(assertion.g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) {
   272  			skip := 0
   273  			if len(callerSkip) > 0 {
   274  				skip = callerSkip[0]
   275  			}
   276  			_, file, line, _ := runtime.Caller(skip + 1)
   277  			assertionFailure = &asyncPolledActualError{
   278  				message: fmt.Sprintf("The function passed to %s failed at %s:%d with:\n%s", assertion.asyncType, file, line, message),
   279  			}
   280  			// we throw an asyncGomegaHaltExecutionError so that defer GinkgoRecover() can catch this error if the user makes an assertion in a goroutine
   281  			panic(asyncGomegaHaltExecutionError{})
   282  		})))
   283  	}
   284  	if takesContext {
   285  		inValues = append(inValues, reflect.ValueOf(assertion.ctx))
   286  	}
   287  	for _, arg := range assertion.argsToForward {
   288  		inValues = append(inValues, reflect.ValueOf(arg))
   289  	}
   290  
   291  	if !isVariadic && numIn != len(inValues) {
   292  		return nil, assertion.argumentMismatchError(actualType, len(inValues))
   293  	} else if isVariadic && len(inValues) < numIn-1 {
   294  		return nil, assertion.argumentMismatchError(actualType, len(inValues))
   295  	}
   296  
   297  	if assertion.mustPassRepeatedly != 1 && assertion.asyncType != AsyncAssertionTypeEventually {
   298  		return nil, assertion.invalidMustPassRepeatedlyError("it can only be used with Eventually")
   299  	}
   300  	if assertion.mustPassRepeatedly < 1 {
   301  		return nil, assertion.invalidMustPassRepeatedlyError("parameter can't be < 1")
   302  	}
   303  
   304  	return func() (actual interface{}, err error) {
   305  		var values []reflect.Value
   306  		assertionFailure = nil
   307  		defer func() {
   308  			if numOut == 0 && takesGomega {
   309  				actual = assertionFailure
   310  			} else {
   311  				actual, err = assertion.processReturnValues(values)
   312  				_, isAsyncError := AsPollingSignalError(err)
   313  				if assertionFailure != nil && !isAsyncError {
   314  					err = assertionFailure
   315  				}
   316  			}
   317  			if e := recover(); e != nil {
   318  				if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
   319  					err = e.(error)
   320  				} else if assertionFailure == nil {
   321  					panic(e)
   322  				}
   323  			}
   324  		}()
   325  		values = actualValue.Call(inValues)
   326  		return
   327  	}, nil
   328  }
   329  
   330  func (assertion *AsyncAssertion) afterTimeout() <-chan time.Time {
   331  	if assertion.timeoutInterval >= 0 {
   332  		return time.After(assertion.timeoutInterval)
   333  	}
   334  
   335  	if assertion.asyncType == AsyncAssertionTypeConsistently {
   336  		return time.After(assertion.g.DurationBundle.ConsistentlyDuration)
   337  	} else {
   338  		if assertion.ctx == nil {
   339  			return time.After(assertion.g.DurationBundle.EventuallyTimeout)
   340  		} else {
   341  			return nil
   342  		}
   343  	}
   344  }
   345  
   346  func (assertion *AsyncAssertion) afterPolling() <-chan time.Time {
   347  	if assertion.pollingInterval >= 0 {
   348  		return time.After(assertion.pollingInterval)
   349  	}
   350  	if assertion.asyncType == AsyncAssertionTypeConsistently {
   351  		return time.After(assertion.g.DurationBundle.ConsistentlyPollingInterval)
   352  	} else {
   353  		return time.After(assertion.g.DurationBundle.EventuallyPollingInterval)
   354  	}
   355  }
   356  
   357  func (assertion *AsyncAssertion) matcherSaysStopTrying(matcher types.GomegaMatcher, value interface{}) bool {
   358  	if assertion.actualIsFunc || types.MatchMayChangeInTheFuture(matcher, value) {
   359  		return false
   360  	}
   361  	return true
   362  }
   363  
   364  func (assertion *AsyncAssertion) pollMatcher(matcher types.GomegaMatcher, value interface{}) (matches bool, err error) {
   365  	defer func() {
   366  		if e := recover(); e != nil {
   367  			if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
   368  				err = e.(error)
   369  			} else {
   370  				panic(e)
   371  			}
   372  		}
   373  	}()
   374  
   375  	matches, err = matcher.Match(value)
   376  
   377  	return
   378  }
   379  
   380  func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool {
   381  	timer := time.Now()
   382  	timeout := assertion.afterTimeout()
   383  	lock := sync.Mutex{}
   384  
   385  	var matches, hasLastValidActual bool
   386  	var actual, lastValidActual interface{}
   387  	var actualErr, matcherErr error
   388  	var oracleMatcherSaysStop bool
   389  
   390  	assertion.g.THelper()
   391  
   392  	pollActual, buildActualPollerErr := assertion.buildActualPoller()
   393  	if buildActualPollerErr != nil {
   394  		assertion.g.Fail(buildActualPollerErr.Error(), 2+assertion.offset)
   395  		return false
   396  	}
   397  
   398  	actual, actualErr = pollActual()
   399  	if actualErr == nil {
   400  		lastValidActual = actual
   401  		hasLastValidActual = true
   402  		oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
   403  		matches, matcherErr = assertion.pollMatcher(matcher, actual)
   404  	}
   405  
   406  	renderError := func(preamble string, err error) string {
   407  		message := ""
   408  		if pollingSignalErr, ok := AsPollingSignalError(err); ok {
   409  			message = err.Error()
   410  			for _, attachment := range pollingSignalErr.Attachments {
   411  				message += fmt.Sprintf("\n%s:\n", attachment.Description)
   412  				message += format.Object(attachment.Object, 1)
   413  			}
   414  		} else {
   415  			message = preamble + "\n" + format.Object(err, 1)
   416  		}
   417  		return message
   418  	}
   419  
   420  	messageGenerator := func() string {
   421  		// can be called out of band by Ginkgo if the user requests a progress report
   422  		lock.Lock()
   423  		defer lock.Unlock()
   424  		message := ""
   425  
   426  		if actualErr == nil {
   427  			if matcherErr == nil {
   428  				if desiredMatch != matches {
   429  					if desiredMatch {
   430  						message += matcher.FailureMessage(actual)
   431  					} else {
   432  						message += matcher.NegatedFailureMessage(actual)
   433  					}
   434  				} else {
   435  					if assertion.asyncType == AsyncAssertionTypeConsistently {
   436  						message += "There is no failure as the matcher passed to Consistently has not yet failed"
   437  					} else {
   438  						message += "There is no failure as the matcher passed to Eventually succeeded on its most recent iteration"
   439  					}
   440  				}
   441  			} else {
   442  				var fgErr formattedGomegaError
   443  				if errors.As(actualErr, &fgErr) {
   444  					message += fgErr.FormattedGomegaError() + "\n"
   445  				} else {
   446  					message += renderError(fmt.Sprintf("The matcher passed to %s returned the following error:", assertion.asyncType), matcherErr)
   447  				}
   448  			}
   449  		} else {
   450  			var fgErr formattedGomegaError
   451  			if errors.As(actualErr, &fgErr) {
   452  				message += fgErr.FormattedGomegaError() + "\n"
   453  			} else {
   454  				message += renderError(fmt.Sprintf("The function passed to %s returned the following error:", assertion.asyncType), actualErr)
   455  			}
   456  			if hasLastValidActual {
   457  				message += fmt.Sprintf("\nAt one point, however, the function did return successfully.\nYet, %s failed because", assertion.asyncType)
   458  				_, e := matcher.Match(lastValidActual)
   459  				if e != nil {
   460  					message += renderError(" the matcher returned the following error:", e)
   461  				} else {
   462  					message += " the matcher was not satisfied:\n"
   463  					if desiredMatch {
   464  						message += matcher.FailureMessage(lastValidActual)
   465  					} else {
   466  						message += matcher.NegatedFailureMessage(lastValidActual)
   467  					}
   468  				}
   469  			}
   470  		}
   471  
   472  		description := assertion.buildDescription(optionalDescription...)
   473  		return fmt.Sprintf("%s%s", description, message)
   474  	}
   475  
   476  	fail := func(preamble string) {
   477  		assertion.g.THelper()
   478  		assertion.g.Fail(fmt.Sprintf("%s after %.3fs.\n%s", preamble, time.Since(timer).Seconds(), messageGenerator()), 3+assertion.offset)
   479  	}
   480  
   481  	var contextDone <-chan struct{}
   482  	if assertion.ctx != nil {
   483  		contextDone = assertion.ctx.Done()
   484  		if v, ok := assertion.ctx.Value("GINKGO_SPEC_CONTEXT").(contextWithAttachProgressReporter); ok {
   485  			detach := v.AttachProgressReporter(messageGenerator)
   486  			defer detach()
   487  		}
   488  	}
   489  
   490  	// Used to count the number of times in a row a step passed
   491  	passedRepeatedlyCount := 0
   492  	for {
   493  		var nextPoll <-chan time.Time = nil
   494  		var isTryAgainAfterError = false
   495  
   496  		for _, err := range []error{actualErr, matcherErr} {
   497  			if pollingSignalErr, ok := AsPollingSignalError(err); ok {
   498  				if pollingSignalErr.IsStopTrying() {
   499  					fail("Told to stop trying")
   500  					return false
   501  				}
   502  				if pollingSignalErr.IsTryAgainAfter() {
   503  					nextPoll = time.After(pollingSignalErr.TryAgainDuration())
   504  					isTryAgainAfterError = true
   505  				}
   506  			}
   507  		}
   508  
   509  		if actualErr == nil && matcherErr == nil && matches == desiredMatch {
   510  			if assertion.asyncType == AsyncAssertionTypeEventually {
   511  				passedRepeatedlyCount += 1
   512  				if passedRepeatedlyCount == assertion.mustPassRepeatedly {
   513  					return true
   514  				}
   515  			}
   516  		} else if !isTryAgainAfterError {
   517  			if assertion.asyncType == AsyncAssertionTypeConsistently {
   518  				fail("Failed")
   519  				return false
   520  			}
   521  			// Reset the consecutive pass count
   522  			passedRepeatedlyCount = 0
   523  		}
   524  
   525  		if oracleMatcherSaysStop {
   526  			if assertion.asyncType == AsyncAssertionTypeEventually {
   527  				fail("No future change is possible.  Bailing out early")
   528  				return false
   529  			} else {
   530  				return true
   531  			}
   532  		}
   533  
   534  		if nextPoll == nil {
   535  			nextPoll = assertion.afterPolling()
   536  		}
   537  
   538  		select {
   539  		case <-nextPoll:
   540  			a, e := pollActual()
   541  			lock.Lock()
   542  			actual, actualErr = a, e
   543  			lock.Unlock()
   544  			if actualErr == nil {
   545  				lock.Lock()
   546  				lastValidActual = actual
   547  				hasLastValidActual = true
   548  				lock.Unlock()
   549  				oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
   550  				m, e := assertion.pollMatcher(matcher, actual)
   551  				lock.Lock()
   552  				matches, matcherErr = m, e
   553  				lock.Unlock()
   554  			}
   555  		case <-contextDone:
   556  			err := context.Cause(assertion.ctx)
   557  			if err != nil && err != context.Canceled {
   558  				fail(fmt.Sprintf("Context was cancelled (cause: %s)", err))
   559  			} else {
   560  				fail("Context was cancelled")
   561  			}
   562  			return false
   563  		case <-timeout:
   564  			if assertion.asyncType == AsyncAssertionTypeEventually {
   565  				fail("Timed out")
   566  				return false
   567  			} else {
   568  				if isTryAgainAfterError {
   569  					fail("Timed out while waiting on TryAgainAfter")
   570  					return false
   571  				}
   572  				return true
   573  			}
   574  		}
   575  	}
   576  }
   577  

View as plain text