...
1 package circuitbreaker_test
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "path/filepath"
8 "runtime"
9 "testing"
10 "time"
11
12 "github.com/go-kit/kit/endpoint"
13 )
14
15 func testFailingEndpoint(
16 t *testing.T,
17 breaker endpoint.Middleware,
18 primeWith int,
19 shouldPass func(int) bool,
20 requestDelay time.Duration,
21 openCircuitError string,
22 ) {
23 _, file, line, _ := runtime.Caller(1)
24 caller := fmt.Sprintf("%s:%d", filepath.Base(file), line)
25
26
27 m := mock{}
28 var e endpoint.Endpoint
29 e = m.endpoint
30 e = breaker(e)
31
32
33 for i := 0; i < primeWith; i++ {
34 if _, err := e(context.Background(), struct{}{}); err != nil {
35 t.Fatalf("%s: during priming, got error: %v", caller, err)
36 }
37 time.Sleep(requestDelay)
38 }
39
40
41 m.err = errors.New("tragedy+disaster")
42 m.through = 0
43
44
45 for i := 0; shouldPass(i); i++ {
46 if _, err := e(context.Background(), struct{}{}); err != m.err {
47 t.Fatalf("%s: want %v, have %v", caller, m.err, err)
48 }
49 time.Sleep(requestDelay)
50 }
51 through := m.through
52
53
54 for i := 0; i < 10; i++ {
55 if _, err := e(context.Background(), struct{}{}); err.Error() != openCircuitError {
56 t.Fatalf("%s: want %q, have %q", caller, openCircuitError, err.Error())
57 }
58 time.Sleep(requestDelay)
59 }
60
61
62 if want, have := through, m.through; want != have {
63 t.Errorf("%s: want %d, have %d", caller, want, have)
64 }
65 }
66
// mock is a stand-in endpoint for breaker tests. Every invocation of its
// endpoint method is counted in through, and err is handed back as the
// call's error, so a test can flip the endpoint between succeeding
// (err == nil) and failing.
type mock struct {
	through int   // number of calls that reached endpoint
	err     error // error returned by endpoint; nil means success
}

// endpoint records the call and returns an empty struct together with the
// currently configured error. The context and request are ignored. Its
// signature matches an endpoint function, so m.endpoint can be wrapped by a
// middleware.
func (mk *mock) endpoint(_ context.Context, _ interface{}) (interface{}, error) {
	mk.through++
	resp := struct{}{}
	return resp, mk.err
}
76
View as plain text