1 package internal
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "reflect"
8 "runtime"
9 "sync"
10 "time"
11
12 "github.com/onsi/gomega/format"
13 "github.com/onsi/gomega/types"
14 )
15
16 var errInterface = reflect.TypeOf((*error)(nil)).Elem()
17 var gomegaType = reflect.TypeOf((*types.Gomega)(nil)).Elem()
18 var contextType = reflect.TypeOf(new(context.Context)).Elem()
19
// formattedGomegaError is implemented by errors whose message is already
// formatted for a Gomega failure report. When the async assertion builds its
// failure message it prints FormattedGomegaError() directly instead of using
// the generic error rendering.
type formattedGomegaError interface {
	FormattedGomegaError() string
}

// asyncPolledActualError describes a problem with the value produced by a
// polled function: no return values, an unexpected non-zero extra return
// value, or a failure reported through an injected Gomega instance.
type asyncPolledActualError struct {
	message string
}

// Compile-time checks that asyncPolledActualError is usable both as a plain
// error and as a pre-formatted Gomega error.
var (
	_ error                = (*asyncPolledActualError)(nil)
	_ formattedGomegaError = (*asyncPolledActualError)(nil)
)

// Error returns the stored message.
func (e *asyncPolledActualError) Error() string { return e.message }

// FormattedGomegaError returns the stored message unchanged; it is already
// suitable for direct inclusion in a failure report.
func (e *asyncPolledActualError) FormattedGomegaError() string { return e.message }
35
// contextWithAttachProgressReporter is the optional capability that match
// looks for on the value stored under the "GINKGO_SPEC_CONTEXT" key of the
// assertion's context. match passes its failure-message generator to
// AttachProgressReporter and defers the returned function so the generator is
// detached when the assertion finishes.
type contextWithAttachProgressReporter interface {
	AttachProgressReporter(reporter func() string) (detach func())
}
39
// asyncGomegaHaltExecutionError is the panic value used to abort a polled
// function as soon as the Gomega instance injected into it records a failure.
// The poller recovers this panic itself; it only escapes when the assertion
// ran somewhere the poller cannot recover it (for example another goroutine).
type asyncGomegaHaltExecutionError struct{}

// GinkgoRecoverShouldIgnoreThisPanic is a marker method.
// NOTE(review): presumably Ginkgo's GinkgoRecover checks for it to avoid
// reporting this panic as a separate failure — confirm against Ginkgo.
func (asyncGomegaHaltExecutionError) GinkgoRecoverShouldIgnoreThisPanic() {}

// Error explains that an assertion failed inside a goroutine that is missing
// `defer GinkgoRecover()`.
func (asyncGomegaHaltExecutionError) Error() string {
	return `An assertion has failed in a goroutine. You should call

defer GinkgoRecover()

at the top of the goroutine that caused this panic. This will allow Ginkgo and Gomega to correctly capture and manage this panic.`
}
50
// AsyncAssertionType distinguishes the two flavours of asynchronous assertion.
type AsyncAssertionType uint

const (
	// AsyncAssertionTypeEventually polls until the matcher is satisfied.
	AsyncAssertionTypeEventually AsyncAssertionType = iota
	// AsyncAssertionTypeConsistently polls to check the matcher stays satisfied.
	AsyncAssertionTypeConsistently
)

// String returns the user-facing name of the assertion type ("Eventually" or
// "Consistently"); any other value renders as "INVALID ASYNC ASSERTION TYPE".
func (at AsyncAssertionType) String() string {
	names := [...]string{
		AsyncAssertionTypeEventually:   "Eventually",
		AsyncAssertionTypeConsistently: "Consistently",
	}
	if at < AsyncAssertionType(len(names)) {
		return names[at]
	}
	return "INVALID ASYNC ASSERTION TYPE"
}
67
68 type AsyncAssertion struct {
69 asyncType AsyncAssertionType
70
71 actualIsFunc bool
72 actual interface{}
73 argsToForward []interface{}
74
75 timeoutInterval time.Duration
76 pollingInterval time.Duration
77 mustPassRepeatedly int
78 ctx context.Context
79 offset int
80 g *Gomega
81 }
82
83 func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput interface{}, g *Gomega, timeoutInterval time.Duration, pollingInterval time.Duration, mustPassRepeatedly int, ctx context.Context, offset int) *AsyncAssertion {
84 out := &AsyncAssertion{
85 asyncType: asyncType,
86 timeoutInterval: timeoutInterval,
87 pollingInterval: pollingInterval,
88 mustPassRepeatedly: mustPassRepeatedly,
89 offset: offset,
90 ctx: ctx,
91 g: g,
92 }
93
94 out.actual = actualInput
95 if actualInput != nil && reflect.TypeOf(actualInput).Kind() == reflect.Func {
96 out.actualIsFunc = true
97 }
98
99 return out
100 }
101
102 func (assertion *AsyncAssertion) WithOffset(offset int) types.AsyncAssertion {
103 assertion.offset = offset
104 return assertion
105 }
106
107 func (assertion *AsyncAssertion) WithTimeout(interval time.Duration) types.AsyncAssertion {
108 assertion.timeoutInterval = interval
109 return assertion
110 }
111
112 func (assertion *AsyncAssertion) WithPolling(interval time.Duration) types.AsyncAssertion {
113 assertion.pollingInterval = interval
114 return assertion
115 }
116
117 func (assertion *AsyncAssertion) Within(timeout time.Duration) types.AsyncAssertion {
118 assertion.timeoutInterval = timeout
119 return assertion
120 }
121
122 func (assertion *AsyncAssertion) ProbeEvery(interval time.Duration) types.AsyncAssertion {
123 assertion.pollingInterval = interval
124 return assertion
125 }
126
127 func (assertion *AsyncAssertion) WithContext(ctx context.Context) types.AsyncAssertion {
128 assertion.ctx = ctx
129 return assertion
130 }
131
132 func (assertion *AsyncAssertion) WithArguments(argsToForward ...interface{}) types.AsyncAssertion {
133 assertion.argsToForward = argsToForward
134 return assertion
135 }
136
137 func (assertion *AsyncAssertion) MustPassRepeatedly(count int) types.AsyncAssertion {
138 assertion.mustPassRepeatedly = count
139 return assertion
140 }
141
142 func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
143 assertion.g.THelper()
144 vetOptionalDescription("Asynchronous assertion", optionalDescription...)
145 return assertion.match(matcher, true, optionalDescription...)
146 }
147
148 func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
149 assertion.g.THelper()
150 vetOptionalDescription("Asynchronous assertion", optionalDescription...)
151 return assertion.match(matcher, false, optionalDescription...)
152 }
153
154 func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interface{}) string {
155 switch len(optionalDescription) {
156 case 0:
157 return ""
158 case 1:
159 if describe, ok := optionalDescription[0].(func() string); ok {
160 return describe() + "\n"
161 }
162 }
163 return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
164 }
165
166 func (assertion *AsyncAssertion) processReturnValues(values []reflect.Value) (interface{}, error) {
167 if len(values) == 0 {
168 return nil, &asyncPolledActualError{
169 message: fmt.Sprintf("The function passed to %s did not return any values", assertion.asyncType),
170 }
171 }
172
173 actual := values[0].Interface()
174 if _, ok := AsPollingSignalError(actual); ok {
175 return actual, actual.(error)
176 }
177
178 var err error
179 for i, extraValue := range values[1:] {
180 extra := extraValue.Interface()
181 if extra == nil {
182 continue
183 }
184 if _, ok := AsPollingSignalError(extra); ok {
185 return actual, extra.(error)
186 }
187 extraType := reflect.TypeOf(extra)
188 zero := reflect.Zero(extraType).Interface()
189 if reflect.DeepEqual(extra, zero) {
190 continue
191 }
192 if i == len(values)-2 && extraType.Implements(errInterface) {
193 err = extra.(error)
194 }
195 if err == nil {
196 err = &asyncPolledActualError{
197 message: fmt.Sprintf("The function passed to %s had an unexpected non-nil/non-zero return value at index %d:\n%s", assertion.asyncType, i+1, format.Object(extra, 1)),
198 }
199 }
200 }
201
202 return actual, err
203 }
204
205 func (assertion *AsyncAssertion) invalidFunctionError(t reflect.Type) error {
206 return fmt.Errorf(`The function passed to %s had an invalid signature of %s. Functions passed to %s must either:
207
208 (a) have return values or
209 (b) take a Gomega interface as their first argument and use that Gomega instance to make assertions.
210
211 You can learn more at https://onsi.github.io/gomega/#eventually
212 `, assertion.asyncType, t, assertion.asyncType)
213 }
214
215 func (assertion *AsyncAssertion) noConfiguredContextForFunctionError() error {
216 return fmt.Errorf(`The function passed to %s requested a context.Context, but no context has been provided. Please pass one in using %s().WithContext().
217
218 You can learn more at https://onsi.github.io/gomega/#eventually
219 `, assertion.asyncType, assertion.asyncType)
220 }
221
222 func (assertion *AsyncAssertion) argumentMismatchError(t reflect.Type, numProvided int) error {
223 have := "have"
224 if numProvided == 1 {
225 have = "has"
226 }
227 return fmt.Errorf(`The function passed to %s has signature %s takes %d arguments but %d %s been provided. Please use %s().WithArguments() to pass the corect set of arguments.
228
229 You can learn more at https://onsi.github.io/gomega/#eventually
230 `, assertion.asyncType, t, t.NumIn(), numProvided, have, assertion.asyncType)
231 }
232
233 func (assertion *AsyncAssertion) invalidMustPassRepeatedlyError(reason string) error {
234 return fmt.Errorf(`Invalid use of MustPassRepeatedly with %s %s
235
236 You can learn more at https://onsi.github.io/gomega/#eventually
237 `, assertion.asyncType, reason)
238 }
239
240 func (assertion *AsyncAssertion) buildActualPoller() (func() (interface{}, error), error) {
241 if !assertion.actualIsFunc {
242 return func() (interface{}, error) { return assertion.actual, nil }, nil
243 }
244 actualValue := reflect.ValueOf(assertion.actual)
245 actualType := reflect.TypeOf(assertion.actual)
246 numIn, numOut, isVariadic := actualType.NumIn(), actualType.NumOut(), actualType.IsVariadic()
247
248 if numIn == 0 && numOut == 0 {
249 return nil, assertion.invalidFunctionError(actualType)
250 }
251 takesGomega, takesContext := false, false
252 if numIn > 0 {
253 takesGomega, takesContext = actualType.In(0).Implements(gomegaType), actualType.In(0).Implements(contextType)
254 }
255 if takesGomega && numIn > 1 && actualType.In(1).Implements(contextType) {
256 takesContext = true
257 }
258 if takesContext && len(assertion.argsToForward) > 0 && reflect.TypeOf(assertion.argsToForward[0]).Implements(contextType) {
259 takesContext = false
260 }
261 if !takesGomega && numOut == 0 {
262 return nil, assertion.invalidFunctionError(actualType)
263 }
264 if takesContext && assertion.ctx == nil {
265 return nil, assertion.noConfiguredContextForFunctionError()
266 }
267
268 var assertionFailure error
269 inValues := []reflect.Value{}
270 if takesGomega {
271 inValues = append(inValues, reflect.ValueOf(NewGomega(assertion.g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) {
272 skip := 0
273 if len(callerSkip) > 0 {
274 skip = callerSkip[0]
275 }
276 _, file, line, _ := runtime.Caller(skip + 1)
277 assertionFailure = &asyncPolledActualError{
278 message: fmt.Sprintf("The function passed to %s failed at %s:%d with:\n%s", assertion.asyncType, file, line, message),
279 }
280
281 panic(asyncGomegaHaltExecutionError{})
282 })))
283 }
284 if takesContext {
285 inValues = append(inValues, reflect.ValueOf(assertion.ctx))
286 }
287 for _, arg := range assertion.argsToForward {
288 inValues = append(inValues, reflect.ValueOf(arg))
289 }
290
291 if !isVariadic && numIn != len(inValues) {
292 return nil, assertion.argumentMismatchError(actualType, len(inValues))
293 } else if isVariadic && len(inValues) < numIn-1 {
294 return nil, assertion.argumentMismatchError(actualType, len(inValues))
295 }
296
297 if assertion.mustPassRepeatedly != 1 && assertion.asyncType != AsyncAssertionTypeEventually {
298 return nil, assertion.invalidMustPassRepeatedlyError("it can only be used with Eventually")
299 }
300 if assertion.mustPassRepeatedly < 1 {
301 return nil, assertion.invalidMustPassRepeatedlyError("parameter can't be < 1")
302 }
303
304 return func() (actual interface{}, err error) {
305 var values []reflect.Value
306 assertionFailure = nil
307 defer func() {
308 if numOut == 0 && takesGomega {
309 actual = assertionFailure
310 } else {
311 actual, err = assertion.processReturnValues(values)
312 _, isAsyncError := AsPollingSignalError(err)
313 if assertionFailure != nil && !isAsyncError {
314 err = assertionFailure
315 }
316 }
317 if e := recover(); e != nil {
318 if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
319 err = e.(error)
320 } else if assertionFailure == nil {
321 panic(e)
322 }
323 }
324 }()
325 values = actualValue.Call(inValues)
326 return
327 }, nil
328 }
329
330 func (assertion *AsyncAssertion) afterTimeout() <-chan time.Time {
331 if assertion.timeoutInterval >= 0 {
332 return time.After(assertion.timeoutInterval)
333 }
334
335 if assertion.asyncType == AsyncAssertionTypeConsistently {
336 return time.After(assertion.g.DurationBundle.ConsistentlyDuration)
337 } else {
338 if assertion.ctx == nil {
339 return time.After(assertion.g.DurationBundle.EventuallyTimeout)
340 } else {
341 return nil
342 }
343 }
344 }
345
346 func (assertion *AsyncAssertion) afterPolling() <-chan time.Time {
347 if assertion.pollingInterval >= 0 {
348 return time.After(assertion.pollingInterval)
349 }
350 if assertion.asyncType == AsyncAssertionTypeConsistently {
351 return time.After(assertion.g.DurationBundle.ConsistentlyPollingInterval)
352 } else {
353 return time.After(assertion.g.DurationBundle.EventuallyPollingInterval)
354 }
355 }
356
357 func (assertion *AsyncAssertion) matcherSaysStopTrying(matcher types.GomegaMatcher, value interface{}) bool {
358 if assertion.actualIsFunc || types.MatchMayChangeInTheFuture(matcher, value) {
359 return false
360 }
361 return true
362 }
363
364 func (assertion *AsyncAssertion) pollMatcher(matcher types.GomegaMatcher, value interface{}) (matches bool, err error) {
365 defer func() {
366 if e := recover(); e != nil {
367 if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
368 err = e.(error)
369 } else {
370 panic(e)
371 }
372 }
373 }()
374
375 matches, err = matcher.Match(value)
376
377 return
378 }
379
380 func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool {
381 timer := time.Now()
382 timeout := assertion.afterTimeout()
383 lock := sync.Mutex{}
384
385 var matches, hasLastValidActual bool
386 var actual, lastValidActual interface{}
387 var actualErr, matcherErr error
388 var oracleMatcherSaysStop bool
389
390 assertion.g.THelper()
391
392 pollActual, buildActualPollerErr := assertion.buildActualPoller()
393 if buildActualPollerErr != nil {
394 assertion.g.Fail(buildActualPollerErr.Error(), 2+assertion.offset)
395 return false
396 }
397
398 actual, actualErr = pollActual()
399 if actualErr == nil {
400 lastValidActual = actual
401 hasLastValidActual = true
402 oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
403 matches, matcherErr = assertion.pollMatcher(matcher, actual)
404 }
405
406 renderError := func(preamble string, err error) string {
407 message := ""
408 if pollingSignalErr, ok := AsPollingSignalError(err); ok {
409 message = err.Error()
410 for _, attachment := range pollingSignalErr.Attachments {
411 message += fmt.Sprintf("\n%s:\n", attachment.Description)
412 message += format.Object(attachment.Object, 1)
413 }
414 } else {
415 message = preamble + "\n" + format.Object(err, 1)
416 }
417 return message
418 }
419
420 messageGenerator := func() string {
421
422 lock.Lock()
423 defer lock.Unlock()
424 message := ""
425
426 if actualErr == nil {
427 if matcherErr == nil {
428 if desiredMatch != matches {
429 if desiredMatch {
430 message += matcher.FailureMessage(actual)
431 } else {
432 message += matcher.NegatedFailureMessage(actual)
433 }
434 } else {
435 if assertion.asyncType == AsyncAssertionTypeConsistently {
436 message += "There is no failure as the matcher passed to Consistently has not yet failed"
437 } else {
438 message += "There is no failure as the matcher passed to Eventually succeeded on its most recent iteration"
439 }
440 }
441 } else {
442 var fgErr formattedGomegaError
443 if errors.As(actualErr, &fgErr) {
444 message += fgErr.FormattedGomegaError() + "\n"
445 } else {
446 message += renderError(fmt.Sprintf("The matcher passed to %s returned the following error:", assertion.asyncType), matcherErr)
447 }
448 }
449 } else {
450 var fgErr formattedGomegaError
451 if errors.As(actualErr, &fgErr) {
452 message += fgErr.FormattedGomegaError() + "\n"
453 } else {
454 message += renderError(fmt.Sprintf("The function passed to %s returned the following error:", assertion.asyncType), actualErr)
455 }
456 if hasLastValidActual {
457 message += fmt.Sprintf("\nAt one point, however, the function did return successfully.\nYet, %s failed because", assertion.asyncType)
458 _, e := matcher.Match(lastValidActual)
459 if e != nil {
460 message += renderError(" the matcher returned the following error:", e)
461 } else {
462 message += " the matcher was not satisfied:\n"
463 if desiredMatch {
464 message += matcher.FailureMessage(lastValidActual)
465 } else {
466 message += matcher.NegatedFailureMessage(lastValidActual)
467 }
468 }
469 }
470 }
471
472 description := assertion.buildDescription(optionalDescription...)
473 return fmt.Sprintf("%s%s", description, message)
474 }
475
476 fail := func(preamble string) {
477 assertion.g.THelper()
478 assertion.g.Fail(fmt.Sprintf("%s after %.3fs.\n%s", preamble, time.Since(timer).Seconds(), messageGenerator()), 3+assertion.offset)
479 }
480
481 var contextDone <-chan struct{}
482 if assertion.ctx != nil {
483 contextDone = assertion.ctx.Done()
484 if v, ok := assertion.ctx.Value("GINKGO_SPEC_CONTEXT").(contextWithAttachProgressReporter); ok {
485 detach := v.AttachProgressReporter(messageGenerator)
486 defer detach()
487 }
488 }
489
490
491 passedRepeatedlyCount := 0
492 for {
493 var nextPoll <-chan time.Time = nil
494 var isTryAgainAfterError = false
495
496 for _, err := range []error{actualErr, matcherErr} {
497 if pollingSignalErr, ok := AsPollingSignalError(err); ok {
498 if pollingSignalErr.IsStopTrying() {
499 fail("Told to stop trying")
500 return false
501 }
502 if pollingSignalErr.IsTryAgainAfter() {
503 nextPoll = time.After(pollingSignalErr.TryAgainDuration())
504 isTryAgainAfterError = true
505 }
506 }
507 }
508
509 if actualErr == nil && matcherErr == nil && matches == desiredMatch {
510 if assertion.asyncType == AsyncAssertionTypeEventually {
511 passedRepeatedlyCount += 1
512 if passedRepeatedlyCount == assertion.mustPassRepeatedly {
513 return true
514 }
515 }
516 } else if !isTryAgainAfterError {
517 if assertion.asyncType == AsyncAssertionTypeConsistently {
518 fail("Failed")
519 return false
520 }
521
522 passedRepeatedlyCount = 0
523 }
524
525 if oracleMatcherSaysStop {
526 if assertion.asyncType == AsyncAssertionTypeEventually {
527 fail("No future change is possible. Bailing out early")
528 return false
529 } else {
530 return true
531 }
532 }
533
534 if nextPoll == nil {
535 nextPoll = assertion.afterPolling()
536 }
537
538 select {
539 case <-nextPoll:
540 a, e := pollActual()
541 lock.Lock()
542 actual, actualErr = a, e
543 lock.Unlock()
544 if actualErr == nil {
545 lock.Lock()
546 lastValidActual = actual
547 hasLastValidActual = true
548 lock.Unlock()
549 oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
550 m, e := assertion.pollMatcher(matcher, actual)
551 lock.Lock()
552 matches, matcherErr = m, e
553 lock.Unlock()
554 }
555 case <-contextDone:
556 err := context.Cause(assertion.ctx)
557 if err != nil && err != context.Canceled {
558 fail(fmt.Sprintf("Context was cancelled (cause: %s)", err))
559 } else {
560 fail("Context was cancelled")
561 }
562 return false
563 case <-timeout:
564 if assertion.asyncType == AsyncAssertionTypeEventually {
565 fail("Timed out")
566 return false
567 } else {
568 if isTryAgainAfterError {
569 fail("Timed out while waiting on TryAgainAfter")
570 return false
571 }
572 return true
573 }
574 }
575 }
576 }
577
View as plain text