1 package lb_test
2
3 import (
4 "context"
5 "errors"
6 "testing"
7 "time"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/sd"
11 "github.com/go-kit/kit/sd/lb"
12 )
13
14 func TestRetryMaxTotalFail(t *testing.T) {
15 var (
16 endpoints = sd.FixedEndpointer{}
17 rr = lb.NewRoundRobin(endpoints)
18 retry = lb.Retry(999, time.Second, rr)
19 ctx = context.Background()
20 )
21 if _, err := retry(ctx, struct{}{}); err == nil {
22 t.Errorf("expected error, got none")
23 }
24 }
25
26 func TestRetryMaxPartialFail(t *testing.T) {
27 var (
28 endpoints = []endpoint.Endpoint{
29 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
30 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
31 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
32 }
33 endpointer = sd.FixedEndpointer{
34 0: endpoints[0],
35 1: endpoints[1],
36 2: endpoints[2],
37 }
38 retries = len(endpoints) - 1
39 rr = lb.NewRoundRobin(endpointer)
40 ctx = context.Background()
41 )
42 if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err == nil {
43 t.Errorf("expected error two, got none")
44 }
45 }
46
47 func TestRetryMaxSuccess(t *testing.T) {
48 var (
49 endpoints = []endpoint.Endpoint{
50 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
51 func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
52 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
53 }
54 endpointer = sd.FixedEndpointer{
55 0: endpoints[0],
56 1: endpoints[1],
57 2: endpoints[2],
58 }
59 retries = len(endpoints)
60 rr = lb.NewRoundRobin(endpointer)
61 ctx = context.Background()
62 )
63 if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err != nil {
64 t.Error(err)
65 }
66 }
67
68 func TestRetryTimeout(t *testing.T) {
69 var (
70 step = make(chan struct{})
71 e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
72 timeout = time.Millisecond
73 retry = lb.Retry(999, timeout, lb.NewRoundRobin(sd.FixedEndpointer{0: e}))
74 errs = make(chan error, 1)
75 invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
76 )
77
78 go func() { step <- struct{}{} }()
79 invoke()
80 if err := <-errs; err != nil {
81 t.Error(err)
82 }
83
84 go func() { time.Sleep(10 * timeout); step <- struct{}{} }()
85 invoke()
86 if err := <-errs; err != context.DeadlineExceeded {
87 t.Errorf("wanted %v, got none", context.DeadlineExceeded)
88 }
89 }
90
91 func TestAbortEarlyCustomMessage(t *testing.T) {
92 var (
93 myErr = errors.New("aborting early")
94 cb = func(int, error) (bool, error) { return false, myErr }
95 endpoints = sd.FixedEndpointer{}
96 rr = lb.NewRoundRobin(endpoints)
97 retry = lb.RetryWithCallback(time.Second, rr, cb)
98 ctx = context.Background()
99 )
100 _, err := retry(ctx, struct{}{})
101 if want, have := myErr, err.(lb.RetryError).Final; want != have {
102 t.Errorf("want %v, have %v", want, have)
103 }
104 }
105
106 func TestErrorPassedUnchangedToCallback(t *testing.T) {
107 var (
108 myErr = errors.New("my custom error")
109 cb = func(_ int, err error) (bool, error) {
110 if want, have := myErr, err; want != have {
111 t.Errorf("want %v, have %v", want, have)
112 }
113 return false, nil
114 }
115 endpoint = func(ctx context.Context, request interface{}) (interface{}, error) {
116 return nil, myErr
117 }
118 endpoints = sd.FixedEndpointer{endpoint}
119 rr = lb.NewRoundRobin(endpoints)
120 retry = lb.RetryWithCallback(time.Second, rr, cb)
121 ctx = context.Background()
122 )
123 _, err := retry(ctx, struct{}{})
124 if want, have := myErr, err.(lb.RetryError).Final; want != have {
125 t.Errorf("want %v, have %v", want, have)
126 }
127 }
128
129 func TestHandleNilCallback(t *testing.T) {
130 var (
131 endpointer = sd.FixedEndpointer{
132 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
133 }
134 rr = lb.NewRoundRobin(endpointer)
135 ctx = context.Background()
136 )
137 retry := lb.RetryWithCallback(time.Second, rr, nil)
138 if _, err := retry(ctx, struct{}{}); err != nil {
139 t.Error(err)
140 }
141 }
142
View as plain text