1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package gax
31
32 import (
33 "context"
34 "errors"
35 "testing"
36 "time"
37
38 "github.com/google/go-cmp/cmp"
39 "github.com/google/go-cmp/cmp/cmpopts"
40 "github.com/googleapis/gax-go/v2/apierror"
41 "google.golang.org/genproto/googleapis/rpc/errdetails"
42 "google.golang.org/grpc/codes"
43 "google.golang.org/grpc/status"
44 )
45
// canceledContext is a context that is already canceled by the time any test
// runs; it is built once, at package initialization.
var canceledContext = func() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	return ctx
}()
53
54
// recordSleeper counts how many times its sleep method has been called.
type recordSleeper int

// sleep never actually waits: it bumps the call counter, ignores the
// requested duration, and reports the current error state of ctx.
func (rs *recordSleeper) sleep(ctx context.Context, _ time.Duration) error {
	*rs = *rs + 1
	return ctx.Err()
}
61
// boolRetryer is a fixed-answer retry policy: its underlying bool is the
// retry decision it gives for every error.
type boolRetryer bool

// Retry ignores err and always answers with a zero pause and the receiver's
// boolean value as the "should retry" flag.
func (r boolRetryer) Retry(err error) (time.Duration, bool) {
	var pause time.Duration
	return pause, bool(r)
}
65
66 func TestInvokeSuccess(t *testing.T) {
67 apiCall := func(context.Context, CallSettings) error { return nil }
68 var sp recordSleeper
69 err := invoke(context.Background(), apiCall, CallSettings{}, sp.sleep)
70
71 if err != nil {
72 t.Errorf("found error %s, want nil", err)
73 }
74 if sp != 0 {
75 t.Errorf("slept %d times, should not have slept since the call succeeded", int(sp))
76 }
77 }
78
79 func TestInvokeCertificateError(t *testing.T) {
80 stat := status.New(codes.Unavailable, "x509: certificate signed by unknown authority")
81 apiErr := stat.Err()
82 apiCall := func(context.Context, CallSettings) error { return apiErr }
83 var sp recordSleeper
84 err := invoke(context.Background(), apiCall, CallSettings{}, sp.sleep)
85 if diff := cmp.Diff(err, apiErr, cmpopts.EquateErrors()); diff != "" {
86 t.Errorf("got(-), want(+): \n%s", diff)
87 }
88 }
89
90 func TestInvokeAPIError(t *testing.T) {
91 qf := &errdetails.QuotaFailure{
92 Violations: []*errdetails.QuotaFailure_Violation{{Subject: "Foo", Description: "Bar"}},
93 }
94 stat, _ := status.New(codes.ResourceExhausted, "Per user quota has been exhausted").WithDetails(qf)
95 apiErr, _ := apierror.FromError(stat.Err())
96 apiCall := func(context.Context, CallSettings) error { return stat.Err() }
97 var sp recordSleeper
98 err := invoke(context.Background(), apiCall, CallSettings{}, sp.sleep)
99 if diff := cmp.Diff(err.Error(), apiErr.Error()); diff != "" {
100 t.Errorf("got(-), want(+): \n%s", diff)
101 }
102 if sp != 0 {
103 t.Errorf("slept %d times, should not have slept since the call succeeded", int(sp))
104 }
105 }
106
107 func TestInvokeCtxError(t *testing.T) {
108 ctxErr := context.DeadlineExceeded
109 apiCall := func(context.Context, CallSettings) error { return ctxErr }
110 var sp recordSleeper
111 err := invoke(context.Background(), apiCall, CallSettings{}, sp.sleep)
112 if err != ctxErr {
113 t.Errorf("found error %s, want %s", err, ctxErr)
114 }
115 if sp != 0 {
116 t.Errorf("slept %d times, should not have slept since the call succeeded", int(sp))
117 }
118
119 }
120
121 func TestInvokeNoRetry(t *testing.T) {
122 apiErr := errors.New("foo error")
123 apiCall := func(context.Context, CallSettings) error { return apiErr }
124 var sp recordSleeper
125 err := invoke(context.Background(), apiCall, CallSettings{}, sp.sleep)
126
127 if err != apiErr {
128 t.Errorf("found error %s, want %s", err, apiErr)
129 }
130 if sp != 0 {
131 t.Errorf("slept %d times, should not have slept since retry is not specified", int(sp))
132 }
133 }
134
135 func TestInvokeNilRetry(t *testing.T) {
136 apiErr := errors.New("foo error")
137 apiCall := func(context.Context, CallSettings) error { return apiErr }
138 var settings CallSettings
139 WithRetry(func() Retryer { return nil }).Resolve(&settings)
140 var sp recordSleeper
141 err := invoke(context.Background(), apiCall, settings, sp.sleep)
142
143 if err != apiErr {
144 t.Errorf("found error %s, want %s", err, apiErr)
145 }
146 if sp != 0 {
147 t.Errorf("slept %d times, should not have slept since retry is not specified", int(sp))
148 }
149 }
150
151 func TestInvokeNeverRetry(t *testing.T) {
152 apiErr := errors.New("foo error")
153 apiCall := func(context.Context, CallSettings) error { return apiErr }
154 var settings CallSettings
155 WithRetry(func() Retryer { return boolRetryer(false) }).Resolve(&settings)
156 var sp recordSleeper
157 err := invoke(context.Background(), apiCall, settings, sp.sleep)
158
159 if err != apiErr {
160 t.Errorf("found error %s, want %s", err, apiErr)
161 }
162 if sp != 0 {
163 t.Errorf("slept %d times, should not have slept since retry is not specified", int(sp))
164 }
165 }
166
167 func TestInvokeRetry(t *testing.T) {
168 const target = 3
169
170 retryNum := 0
171 apiErr := errors.New("foo error")
172 apiCall := func(context.Context, CallSettings) error {
173 retryNum++
174 if retryNum < target {
175 return apiErr
176 }
177 return nil
178 }
179 var settings CallSettings
180 WithRetry(func() Retryer { return boolRetryer(true) }).Resolve(&settings)
181 var sp recordSleeper
182 err := invoke(context.Background(), apiCall, settings, sp.sleep)
183
184 if err != nil {
185 t.Errorf("found error %s, want nil, call should have succeeded after %d tries", err, target)
186 }
187 if sp != target-1 {
188 t.Errorf("retried %d times, want %d", int(sp), int(target-1))
189 }
190 }
191
192 func TestInvokeRetryTimeout(t *testing.T) {
193 apiErr := errors.New("foo error")
194 apiCall := func(context.Context, CallSettings) error { return apiErr }
195 var settings CallSettings
196 WithRetry(func() Retryer { return boolRetryer(true) }).Resolve(&settings)
197 var sp recordSleeper
198
199 err := invoke(canceledContext, apiCall, settings, sp.sleep)
200
201 if err != context.Canceled {
202 t.Errorf("found error %s, want %s", err, context.Canceled)
203 }
204 }
205
206 func TestInvokeWithTimeout(t *testing.T) {
207
208
209
210
211 sleepingCall := func(sleep time.Duration) APICall {
212 return func(ctx context.Context, _ CallSettings) error {
213 time.Sleep(sleep)
214 return ctx.Err()
215 }
216 }
217
218 bg := context.Background()
219 preset, pcc := context.WithTimeout(bg, 10*time.Millisecond)
220 defer pcc()
221
222 for _, tst := range []struct {
223 name string
224 timeout time.Duration
225 sleep time.Duration
226 ctx context.Context
227 want error
228 }{
229 {
230 name: "success",
231 timeout: 10 * time.Millisecond,
232 sleep: 1 * time.Millisecond,
233 ctx: bg,
234 want: nil,
235 },
236 {
237 name: "respect_context_deadline",
238 timeout: 1 * time.Millisecond,
239 sleep: 3 * time.Millisecond,
240 ctx: preset,
241 want: nil,
242 },
243 {
244 name: "with_timeout_deadline_exceeded",
245 timeout: 1 * time.Millisecond,
246 sleep: 3 * time.Millisecond,
247 ctx: bg,
248 want: context.DeadlineExceeded,
249 },
250 } {
251 t.Run(tst.name, func(t *testing.T) {
252
253
254 var sp recordSleeper
255 var settings CallSettings
256
257 WithTimeout(tst.timeout).Resolve(&settings)
258
259 err := invoke(tst.ctx, sleepingCall(tst.sleep), settings, sp.sleep)
260
261 if err != tst.want {
262 t.Errorf("found error %v, want %v", err, tst.want)
263 }
264 })
265 }
266 }
267
View as plain text