1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25 package retryablehttp
26
27 import (
28 "bytes"
29 "context"
30 "crypto/x509"
31 "fmt"
32 "io"
33 "io/ioutil"
34 "log"
35 "math"
36 "math/rand"
37 "net/http"
38 "net/url"
39 "os"
40 "regexp"
41 "strconv"
42 "strings"
43 "sync"
44 "time"
45
46 cleanhttp "github.com/hashicorp/go-cleanhttp"
47 )
48
49 var (
50
51 defaultRetryWaitMin = 1 * time.Second
52 defaultRetryWaitMax = 30 * time.Second
53 defaultRetryMax = 4
54
55
56 defaultLogger = log.New(os.Stderr, "", log.LstdFlags)
57
58
59
60 defaultClient = NewClient()
61
62
63
64 respReadLimit = int64(4096)
65
66
67
68
69 redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`)
70
71
72
73
74 schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)
75
76
77
78
79 notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`)
80 )
81
82
83 type ReaderFunc func() (io.Reader, error)
84
85
86
87
88
89
90
91
92
93
94
95 type ResponseHandlerFunc func(*http.Response) error
96
97
98
99 type LenReader interface {
100 Len() int
101 }
102
103
104 type Request struct {
105
106
107 body ReaderFunc
108
109 responseHandler ResponseHandlerFunc
110
111
112
113 *http.Request
114 }
115
116
117
// WithContext returns a shallow copy of r whose underlying *http.Request is
// bound to ctx. The copy shares r's body source and response handler.
func (r *Request) WithContext(ctx context.Context) *Request {
	clone := new(Request)
	clone.body = r.body
	clone.responseHandler = r.responseHandler
	clone.Request = r.Request.WithContext(ctx)
	return clone
}

// SetResponseHandler installs fn to run on each response the retry policy
// accepts; an error returned by fn is fed back into the retry decision.
func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) {
	r.responseHandler = fn
}
130
131
132
133
134
135
136
// BodyBytes returns a copy of the request body as a byte slice, or (nil, nil)
// when the request has no body.
//
// A fresh reader is obtained from the body source on every call. If that
// reader implements io.Closer it is closed before returning, as WriteTo does,
// so body sources that open a resource per call (e.g. a file) do not leak it.
func (r *Request) BodyBytes() ([]byte, error) {
	if r.body == nil {
		return nil, nil
	}
	body, err := r.body()
	if err != nil {
		return nil, err
	}
	if c, ok := body.(io.Closer); ok {
		// Best-effort release; the bytes have already been copied out by
		// the time this runs, so a close error does not affect the result.
		defer c.Close()
	}
	buf := new(bytes.Buffer)
	if _, err := buf.ReadFrom(body); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
152
153
154
155
// SetBody replaces the request body with rawBody (any type accepted by
// getBodyReaderAndContentLength), sets ContentLength to the length it
// reports, and installs a GetBody that returns a fresh io.ReadCloser per
// call, or http.NoBody when rawBody carries no body.
func (r *Request) SetBody(rawBody interface{}) error {
	source, length, err := getBodyReaderAndContentLength(rawBody)
	if err != nil {
		return err
	}

	r.body = source
	r.ContentLength = length

	if source == nil {
		r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil }
		return nil
	}

	r.GetBody = func() (io.ReadCloser, error) {
		reader, err := source()
		if err != nil {
			return nil, err
		}
		rc, isCloser := reader.(io.ReadCloser)
		if !isCloser {
			// Wrap plain readers so callers always get something closable.
			rc = io.NopCloser(reader)
		}
		return rc, nil
	}
	return nil
}
179
180
181
182
183
184
185
// WriteTo writes the request body to w and returns the number of bytes
// written. It takes a fresh reader from the body source, so it may be called
// repeatedly, and closes that reader afterwards if it implements io.Closer.
// A request with no body writes nothing and returns (0, nil).
func (r *Request) WriteTo(w io.Writer) (int64, error) {
	// A nil body source means "no body"; invoking it would panic.
	if r.body == nil {
		return 0, nil
	}
	body, err := r.body()
	if err != nil {
		return 0, err
	}
	if c, ok := body.(io.Closer); ok {
		defer c.Close()
	}
	return io.Copy(w, body)
}
196
// getBodyReaderAndContentLength turns rawBody into a ReaderFunc that yields a
// fresh body reader on every call, together with the body length when it can
// be determined up front (0 otherwise).
//
// Supported inputs:
//   - nil, or a typed nil ReaderFunc/func/*bytes.Buffer/*bytes.Reader: no body
//     (a nil ReaderFunc is returned).
//   - ReaderFunc or func() (io.Reader, error): used as the source directly. It
//     is invoked once here to probe for a LenReader, and the probed reader is
//     closed if it implements io.Closer.
//   - []byte: each call returns a reader over the same slice (not copied).
//   - *bytes.Buffer: each call returns a reader over the buffer's current
//     unread bytes without consuming them; the length is taken now.
//   - *bytes.Reader: the unread remainder is read into memory once and
//     replayed from that copy.
//   - any other io.ReadSeeker: rewound to offset 0 before each call; the
//     length is known only if it also implements LenReader.
//   - any other io.Reader: read fully into memory once and replayed; an empty
//     stream is replayed as http.NoBody.
//
// Any other type yields an error.
func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, error) {
	switch body := rawBody.(type) {
	case nil:
		return nil, 0, nil

	case ReaderFunc:
		return probeBodyFunc(body)

	case func() (io.Reader, error):
		return probeBodyFunc(body)

	case []byte:
		buf := body
		return func() (io.Reader, error) {
			return bytes.NewReader(buf), nil
		}, int64(len(buf)), nil

	case *bytes.Buffer:
		if body == nil {
			return nil, 0, nil
		}
		buf := body
		return func() (io.Reader, error) {
			return bytes.NewReader(buf.Bytes()), nil
		}, int64(buf.Len()), nil

	case *bytes.Reader:
		if body == nil {
			return nil, 0, nil
		}
		// Snapshot the remainder so every attempt starts at the same point
		// instead of at the reader's advancing cursor.
		buf, err := ioutil.ReadAll(body)
		if err != nil {
			return nil, 0, err
		}
		return func() (io.Reader, error) {
			return bytes.NewReader(buf), nil
		}, int64(len(buf)), nil

	case io.ReadSeeker:
		raw := body
		var contentLength int64
		if lr, ok := raw.(LenReader); ok {
			contentLength = int64(lr.Len())
		}
		return func() (io.Reader, error) {
			_, err := raw.Seek(0, io.SeekStart)
			return ioutil.NopCloser(raw), err
		}, contentLength, nil

	case io.Reader:
		// Not seekable: buffer the whole stream so it can be replayed.
		buf, err := ioutil.ReadAll(body)
		if err != nil {
			return nil, 0, err
		}
		if len(buf) == 0 {
			// http.NoBody explicitly signals a zero-byte body to net/http.
			return func() (io.Reader, error) {
				return http.NoBody, nil
			}, 0, nil
		}
		return func() (io.Reader, error) {
			return bytes.NewReader(buf), nil
		}, int64(len(buf)), nil

	default:
		return nil, 0, fmt.Errorf("cannot handle type %T", rawBody)
	}
}

// probeBodyFunc adopts fn as a body source. fn is called once up front to
// learn the body length (when its reader is a LenReader), and that probe
// reader is closed if it implements io.Closer. A nil fn means no body.
func probeBodyFunc(fn ReaderFunc) (ReaderFunc, int64, error) {
	if fn == nil {
		return nil, 0, nil
	}
	probe, err := fn()
	if err != nil {
		return nil, 0, err
	}
	var contentLength int64
	if lr, ok := probe.(LenReader); ok {
		contentLength = int64(lr.Len())
	}
	if c, ok := probe.(io.Closer); ok {
		c.Close()
	}
	return fn, contentLength, nil
}
298
299
// FromRequest wraps an existing *http.Request so it can be sent with retries.
// The request's current Body is converted into a replayable body source
// (bodies that cannot seek are read into memory at this point). The
// request's ContentLength is left as it was.
func FromRequest(r *http.Request) (*Request, error) {
	source, _, err := getBodyReaderAndContentLength(r.Body)
	if err != nil {
		return nil, err
	}
	wrapped := &Request{Request: r}
	wrapped.body = source
	return wrapped, nil
}
308
309
// NewRequest is NewRequestWithContext using context.Background().
func NewRequest(method, url string, rawBody interface{}) (*Request, error) {
	ctx := context.Background()
	return NewRequestWithContext(ctx, method, url, rawBody)
}

// NewRequestWithContext builds a retryable request bound to ctx. rawBody may
// be any value accepted by SetBody (nil for no body). An invalid method/URL,
// an unsupported body type, or a failure while preparing the body is
// returned as an error.
func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) {
	base, err := http.NewRequestWithContext(ctx, method, url, nil)
	if err != nil {
		return nil, err
	}

	req := &Request{Request: base}
	if err = req.SetBody(rawBody); err != nil {
		return nil, err
	}
	return req, nil
}
333
334
335
336 type Logger interface {
337 Printf(string, ...interface{})
338 }
339
340
341
342
343
344
345 type LeveledLogger interface {
346 Error(msg string, keysAndValues ...interface{})
347 Info(msg string, keysAndValues ...interface{})
348 Debug(msg string, keysAndValues ...interface{})
349 Warn(msg string, keysAndValues ...interface{})
350 }
351
352
353
354 type hookLogger struct {
355 LeveledLogger
356 }
357
358 func (h hookLogger) Printf(s string, args ...interface{}) {
359 h.Info(fmt.Sprintf(s, args...))
360 }
361
362
363
364
365
366 type RequestLogHook func(Logger, *http.Request, int)
367
368
369
370
371
372
373 type ResponseLogHook func(Logger, *http.Response)
374
375
376
377
378
379
380
381
382
383 type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool, error)
384
385
386
387
388 type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration
389
390
391
392
393
394 type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)
395
396
397
398 type Client struct {
399 HTTPClient *http.Client
400 Logger interface{}
401
402 RetryWaitMin time.Duration
403 RetryWaitMax time.Duration
404 RetryMax int
405
406
407
408 RequestLogHook RequestLogHook
409
410
411
412 ResponseLogHook ResponseLogHook
413
414
415
416 CheckRetry CheckRetry
417
418
419 Backoff Backoff
420
421
422 ErrorHandler ErrorHandler
423
424 loggerInit sync.Once
425 clientInit sync.Once
426 }
427
428
429 func NewClient() *Client {
430 return &Client{
431 HTTPClient: cleanhttp.DefaultPooledClient(),
432 Logger: defaultLogger,
433 RetryWaitMin: defaultRetryWaitMin,
434 RetryWaitMax: defaultRetryWaitMax,
435 RetryMax: defaultRetryMax,
436 CheckRetry: DefaultRetryPolicy,
437 Backoff: DefaultBackoff,
438 }
439 }
440
441 func (c *Client) logger() interface{} {
442 c.loggerInit.Do(func() {
443 if c.Logger == nil {
444 return
445 }
446
447 switch c.Logger.(type) {
448 case Logger, LeveledLogger:
449
450 default:
451
452 panic(fmt.Sprintf("invalid logger type passed, must be Logger or LeveledLogger, was %T", c.Logger))
453 }
454 })
455
456 return c.Logger
457 }
458
459
460
// DefaultRetryPolicy is the standard CheckRetry. Once ctx is done it stops
// and returns the context's error. Otherwise it applies baseRetryPolicy:
// - transport errors are retried unless recognized as permanent (too many
//   redirects, unsupported scheme, untrusted certificate);
// - responses with status 429 are retried;
// - status 0 and 5xx responses other than 501 are retried.
// The descriptive error baseRetryPolicy may produce is discarded, so the
// caller ultimately sees the underlying transport or handler error.
func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return false, ctxErr
	}
	retry, _ := baseRetryPolicy(resp, err)
	return retry, nil
}

// ErrorPropagatedRetryPolicy is DefaultRetryPolicy except that it also
// returns baseRetryPolicy's error (for example "unexpected HTTP status ..."),
// so the client reports why it stopped.
func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return false, ctxErr
	}
	return baseRetryPolicy(resp, err)
}
483
// baseRetryPolicy holds the retry rules shared by DefaultRetryPolicy and
// ErrorPropagatedRetryPolicy. It reports whether another attempt should be
// made and, where useful, an error describing the outcome.
//
// Errors are retried unless they are *url.Error values that cannot succeed
// on retry:
//   - too many redirects;
//   - an unsupported URL scheme;
//   - an untrusted TLS certificate, detected by message text or by an
//     x509.UnknownAuthorityError anywhere in the wrapped error chain.
//
// For responses:
//   - 429 is retried;
//   - status 0 and any 5xx other than 501 are retried, with an error naming
//     the status;
//   - everything else is final.
func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
	if err != nil {
		if v, ok := err.(*url.Error); ok {
			msg := v.Error()
			switch {
			case redirectsErrorRe.MatchString(msg),
				schemeErrorRe.MatchString(msg),
				notTrustedErrorRe.MatchString(msg),
				isUnknownAuthorityError(v.Err):
				return false, v
			}
		}
		// Any other error is treated as transient.
		return true, nil
	}

	switch {
	case resp.StatusCode == http.StatusTooManyRequests:
		return true, nil
	case resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != http.StatusNotImplemented):
		return true, fmt.Errorf("unexpected HTTP status %s", resp.Status)
	}
	return false, nil
}

// isUnknownAuthorityError reports whether err, or any error reachable through
// its Unwrap chain, is an x509.UnknownAuthorityError. Since Go 1.20 crypto/tls
// wraps verification failures in *tls.CertificateVerificationError, so a
// plain type assertion on url.Error.Err no longer matches.
func isUnknownAuthorityError(err error) bool {
	for err != nil {
		if _, ok := err.(x509.UnknownAuthorityError); ok {
			return true
		}
		if _, ok := err.(*x509.UnknownAuthorityError); ok {
			return true
		}
		u, ok := err.(interface{ Unwrap() error })
		if !ok {
			return false
		}
		err = u.Unwrap()
	}
	return false
}
527
528
529
530
531
532
533
534
// DefaultBackoff is the standard Backoff: exponential growth from min, capped
// at max, honoring server hints.
//
// Server hints: for 429 Too Many Requests and 503 Service Unavailable
// responses with a usable Retry-After header, the server's requested delay
// is returned as-is (it is not clamped to max). Both forms allowed by
// RFC 9110 are understood:
//   - a non-negative number of seconds;
//   - an HTTP-date, converted to the time remaining until it.
//
// Missing, empty, malformed, negative, overflowing or already-past values
// are ignored.
//
// Otherwise the wait is min * 2^attemptNum, capped at max. The cap also
// applies when the product is not representable as a time.Duration.
func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	if resp != nil && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable) {
		// Header.Get tolerates a present-but-empty value slice, which
		// indexing [0] would not.
		if wait, ok := parseRetryAfter(resp.Header.Get("Retry-After")); ok {
			return wait
		}
	}

	mult := math.Pow(2, float64(attemptNum)) * float64(min)
	sleep := time.Duration(mult)
	if float64(sleep) != mult || sleep > max {
		sleep = max
	}
	return sleep
}

// parseRetryAfter interprets a Retry-After header value and reports the
// delay it requests and whether the value was usable.
func parseRetryAfter(value string) (time.Duration, bool) {
	if value == "" {
		return 0, false
	}
	if secs, err := strconv.ParseInt(value, 10, 64); err == nil {
		// Reject negatives and values whose Duration would overflow.
		if secs < 0 || secs > int64(math.MaxInt64/time.Second) {
			return 0, false
		}
		return time.Duration(secs) * time.Second, true
	}
	if at, err := http.ParseTime(value); err == nil {
		if wait := time.Until(at); wait > 0 {
			return wait, true
		}
	}
	return 0, false
}
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
// linearJitterRand is the process-wide random source for LinearJitterBackoff.
// It is seeded once, at package initialization. *rand.Rand is not safe for
// concurrent use, so every access holds linearJitterMu.
var (
	linearJitterMu   sync.Mutex
	linearJitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)

// LinearJitterBackoff waits a linearly growing, randomly jittered amount.
// For the zero-based attemptNum it draws a duration uniformly from
// [min, max) and multiplies it by attemptNum+1. The response is ignored.
//
// If max <= min there is no range to jitter over, and the result is exactly
// min * (attemptNum+1).
//
// A single seeded source is shared across calls. Concurrent retries computed
// at the same instant therefore still receive independent jitter, and no
// generator is allocated per call.
func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	// attemptNum is zero-based; scale by the one-based attempt count.
	attemptNum++

	if max <= min {
		return min * time.Duration(attemptNum)
	}

	linearJitterMu.Lock()
	frac := linearJitterRand.Float64()
	linearJitterMu.Unlock()

	jitter := frac * float64(max-min)
	jitterMin := int64(jitter) + int64(min)
	return time.Duration(jitterMin * int64(attemptNum))
}
591
592
593
594
// PassthroughErrorHandler is an ErrorHandler that returns the final response
// and error unchanged. Client.Do then skips its default give-up handling:
// the body is not drained and the error is not wrapped. The attempt count
// is ignored.
func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Response, error) {
	return resp, err
}
598
599
// Do sends req, retrying according to the client's policy, and returns the
// final response.
//
// Each attempt proceeds as follows:
//   - obtain a fresh body from req's body source;
//   - run RequestLogHook;
//   - perform the round trip with HTTPClient;
//   - ask CheckRetry whether to try again. If the policy accepts the outcome,
//     no transport error occurred, and req has a response handler, the
//     handler runs and its error is put through CheckRetry as well.
//
// Before a retry, the discarded response body (when the attempt had no
// transport error) is drained and closed. The client then sleeps for the
// duration chosen by Backoff, returning the context's error early if req's
// context is done. At most RetryMax retries follow the first attempt.
//
// On success the response is returned with its body open. Otherwise:
//   - if ErrorHandler is set, its results are returned as-is and the last
//     response is left undrained;
//   - if not, the last response body is drained and an error of the form
//     "<METHOD> <URL> giving up after N attempt(s)" is returned. It wraps
//     the first non-nil of CheckRetry's error, the response handler's error,
//     and the transport error.
//
// Every non-success return also closes HTTPClient's idle connections.
//
// A nil CheckRetry or Backoff (as on a zero-value Client) falls back to
// DefaultRetryPolicy or DefaultBackoff instead of panicking.
func (c *Client) Do(req *Request) (*http.Response, error) {
	c.clientInit.Do(func() {
		if c.HTTPClient == nil {
			c.HTTPClient = cleanhttp.DefaultPooledClient()
		}
	})

	// Resolve the policy functions per call, mirroring the HTTPClient
	// defaulting above. Locals leave the caller's Client fields untouched.
	checkRetry := c.CheckRetry
	if checkRetry == nil {
		checkRetry = DefaultRetryPolicy
	}
	backoff := c.Backoff
	if backoff == nil {
		backoff = DefaultBackoff
	}

	logger := c.logger()

	if logger != nil {
		switch v := logger.(type) {
		case LeveledLogger:
			v.Debug("performing request", "method", req.Method, "url", req.URL)
		case Logger:
			v.Printf("[DEBUG] %s %s", req.Method, req.URL)
		}
	}

	var resp *http.Response
	var attempt int
	var shouldRetry bool
	var doErr, respErr, checkErr error

	for i := 0; ; i++ {
		doErr, respErr = nil, nil
		attempt++

		// Give each attempt its own fresh reader over the body.
		if req.body != nil {
			body, err := req.body()
			if err != nil {
				c.HTTPClient.CloseIdleConnections()
				return resp, err
			}
			if rc, ok := body.(io.ReadCloser); ok {
				req.Body = rc
			} else {
				req.Body = ioutil.NopCloser(body)
			}
		}

		if c.RequestLogHook != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				c.RequestLogHook(hookLogger{v}, req.Request, i)
			case Logger:
				c.RequestLogHook(v, req.Request, i)
			default:
				c.RequestLogHook(nil, req.Request, i)
			}
		}

		resp, doErr = c.HTTPClient.Do(req.Request)

		// The policy judges the transport outcome first; if it accepts a
		// successful round trip, the response handler gets a say too.
		shouldRetry, checkErr = checkRetry(req.Context(), resp, doErr)
		if !shouldRetry && doErr == nil && req.responseHandler != nil {
			respErr = req.responseHandler(resp)
			shouldRetry, checkErr = checkRetry(req.Context(), resp, respErr)
		}

		err := doErr
		if respErr != nil {
			err = respErr
		}
		if err != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				v.Error("request failed", "error", err, "method", req.Method, "url", req.URL)
			case Logger:
				v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err)
			}
		} else if c.ResponseLogHook != nil {
			switch v := logger.(type) {
			case LeveledLogger:
				c.ResponseLogHook(hookLogger{v}, resp)
			case Logger:
				c.ResponseLogHook(v, resp)
			default:
				c.ResponseLogHook(nil, resp)
			}
		}

		if !shouldRetry {
			break
		}

		// Stop once the retry budget (RetryMax retries after the first
		// attempt) is used up.
		remain := c.RetryMax - i
		if remain <= 0 {
			break
		}

		// This response is being discarded in favour of a retry: drain (up
		// to respReadLimit bytes) and close its body.
		if doErr == nil {
			c.drainBody(resp.Body)
		}

		wait := backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp)
		if logger != nil {
			desc := fmt.Sprintf("%s %s", req.Method, req.URL)
			if resp != nil {
				desc = fmt.Sprintf("%s (status: %d)", desc, resp.StatusCode)
			}
			switch v := logger.(type) {
			case LeveledLogger:
				v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain)
			case Logger:
				v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain)
			}
		}
		timer := time.NewTimer(wait)
		select {
		case <-req.Context().Done():
			timer.Stop()
			c.HTTPClient.CloseIdleConnections()
			return nil, req.Context().Err()
		case <-timer.C:
		}

		// Use a shallow copy of the http.Request for the next attempt, so
		// the Body swap at the top of the loop does not mutate the value
		// the transport was handed for this attempt.
		httpreq := *req.Request
		req.Request = &httpreq
	}

	// Clean success: the policy is satisfied and nothing reported an error.
	if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
		return resp, nil
	}

	defer c.HTTPClient.CloseIdleConnections()

	// Prefer the policy's error, then the response handler's, then the
	// transport's.
	var err error
	if checkErr != nil {
		err = checkErr
	} else if respErr != nil {
		err = respErr
	} else {
		err = doErr
	}

	if c.ErrorHandler != nil {
		return c.ErrorHandler(resp, err, attempt)
	}

	// The caller never receives resp on this path, so release its body.
	if resp != nil {
		c.drainBody(resp.Body)
	}

	// err is nil when retries ran out on a retryable response for which the
	// policy returned no error (e.g. DefaultRetryPolicy on a 503).
	if err == nil {
		return nil, fmt.Errorf("%s %s giving up after %d attempt(s)",
			req.Method, req.URL, attempt)
	}

	return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w",
		req.Method, req.URL, attempt, err)
}
769
770
// drainBody reads and discards at most respReadLimit bytes from body and then
// closes it. A read error is logged through the client's logger (if any) and
// otherwise ignored; the body is closed in every case.
func (c *Client) drainBody(body io.ReadCloser) {
	defer body.Close()

	_, copyErr := io.Copy(ioutil.Discard, io.LimitReader(body, respReadLimit))
	if copyErr == nil {
		return
	}

	switch l := c.logger().(type) {
	case LeveledLogger:
		l.Error("error reading response body", "error", copyErr)
	case Logger:
		l.Printf("[ERR] error reading response body: %v", copyErr)
	}
}
785
786
// Get issues a GET request to url through the package-level default client.
func Get(url string) (*http.Response, error) {
	return defaultClient.Get(url)
}

// Get issues a GET request to url through c, with retries.
func (c *Client) Get(url string) (*http.Response, error) {
	r, err := NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	return c.Do(r)
}

// Head issues a HEAD request to url through the package-level default client.
func Head(url string) (*http.Response, error) {
	return defaultClient.Head(url)
}

// Head issues a HEAD request to url through c, with retries.
func (c *Client) Head(url string) (*http.Response, error) {
	r, err := NewRequest("HEAD", url, nil)
	if err != nil {
		return nil, err
	}
	return c.Do(r)
}

// Post sends body (any value accepted by NewRequest) with the given
// Content-Type as a POST through the package-level default client.
func Post(url, bodyType string, body interface{}) (*http.Response, error) {
	return defaultClient.Post(url, bodyType, body)
}

// Post sends body with the given Content-Type as a POST through c.
func (c *Client) Post(url, bodyType string, body interface{}) (*http.Response, error) {
	r, err := NewRequest("POST", url, body)
	if err != nil {
		return nil, err
	}
	r.Header.Set("Content-Type", bodyType)
	return c.Do(r)
}

// PostForm URL-encodes data and POSTs it as application/x-www-form-urlencoded
// through the package-level default client.
func PostForm(url string, data url.Values) (*http.Response, error) {
	return defaultClient.PostForm(url, data)
}

// PostForm URL-encodes data and POSTs it as application/x-www-form-urlencoded
// through c.
func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) {
	form := strings.NewReader(data.Encode())
	return c.Post(url, "application/x-www-form-urlencoded", form)
}

// StandardClient returns a plain *http.Client whose Transport is a
// RoundTripper wrapping c, for code that only accepts a net/http client.
// NOTE(review): RoundTripper is defined elsewhere; this assumes it delegates
// each round trip to c's retry logic — confirm.
func (c *Client) StandardClient() *http.Client {
	transport := &RoundTripper{Client: c}
	return &http.Client{Transport: transport}
}
848
View as plain text