1
2
3
4 package retryablehttp
5
6 import (
7 "context"
8 "errors"
9 "io/ioutil"
10 "net"
11 "net/http"
12 "net/http/httptest"
13 "net/url"
14 "reflect"
15 "sync/atomic"
16 "testing"
17 )
18
19 func TestRoundTripper_implements(t *testing.T) {
20
21 var _ http.RoundTripper = &RoundTripper{}
22 }
23
24 func TestRoundTripper_init(t *testing.T) {
25 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26 w.WriteHeader(200)
27 }))
28 defer ts.Close()
29
30
31 rt := &RoundTripper{}
32
33
34 req, _ := http.NewRequest("GET", ts.URL, nil)
35 if _, err := rt.RoundTrip(req); err != nil {
36 t.Fatal(err)
37 }
38
39
40 if rt.Client == nil {
41 t.Fatal("expected rt.Client to be initialized")
42 }
43
44
45 initialClient := rt.Client
46
47
48 req, _ = http.NewRequest("GET", ts.URL, nil)
49 if _, err := rt.RoundTrip(req); err != nil {
50 t.Fatal(err)
51 }
52
53
54 if rt.Client != initialClient {
55 t.Fatalf("expected %v, got %v", initialClient, rt.Client)
56 }
57 }
58
59 func TestRoundTripper_RoundTrip(t *testing.T) {
60 var reqCount int32 = 0
61
62 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
63 reqNo := atomic.AddInt32(&reqCount, 1)
64 if reqNo < 3 {
65 w.WriteHeader(404)
66 } else {
67 w.WriteHeader(200)
68 w.Write([]byte("success!"))
69 }
70 }))
71 defer ts.Close()
72
73
74 retryClient := NewClient()
75 retryClient.CheckRetry = func(_ context.Context, resp *http.Response, _ error) (bool, error) {
76 return resp.StatusCode == 404, nil
77 }
78
79
80 client := retryClient.StandardClient()
81 resp, err := client.Get(ts.URL)
82 if err != nil {
83 t.Fatal(err)
84 }
85 defer resp.Body.Close()
86
87
88 if resp.StatusCode != 200 {
89 t.Fatalf("expected 200, got %d", resp.StatusCode)
90 }
91 if v, err := ioutil.ReadAll(resp.Body); err != nil {
92 t.Fatal(err)
93 } else if string(v) != "success!" {
94 t.Fatalf("expected %q, got %q", "success!", v)
95 }
96 }
97
98 func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {
99
100 retryClient := NewClient()
101 retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
102 if err != nil {
103 return true, err
104 }
105
106 return false, nil
107 }
108
109 retryClient.ErrorHandler = PassthroughErrorHandler
110
111 expectedError := &url.Error{
112 Op: "Get",
113 URL: "http://999.999.999.999:999/",
114 Err: &net.OpError{
115 Op: "dial",
116 Net: "tcp",
117 Err: &net.DNSError{
118 Name: "999.999.999.999",
119 Err: "no such host",
120 IsNotFound: true,
121 },
122 },
123 }
124
125
126 client := retryClient.StandardClient()
127 _, err := client.Get("http://999.999.999.999:999/")
128
129
130 if !reflect.DeepEqual(expectedError, normalizeError(err)) {
131 t.Fatalf("expected %q, got %q", expectedError, err)
132 }
133 }
134
// normalizeError prepares err for a deterministic comparison. If err's chain
// contains a *net.DNSError, that error's Server field is cleared in place.
// The server address presumably depends on the host's resolver configuration.
// The same err value is always returned, whether or not anything was changed.
func normalizeError(err error) error {
	var target *net.DNSError
	if !errors.As(err, &target) {
		return err
	}

	target.Server = ""
	return err
}
145
View as plain text