1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package transport
16
17 import (
18 "context"
19 "errors"
20 "io"
21 "net/http"
22 "net/http/httptest"
23 "strings"
24 "testing"
25 "time"
26
27 "github.com/google/go-containerregistry/internal/retry"
28 )
29
// mockTransport is a scripted stand-in for an HTTP transport. Its RoundTrip
// method has the signature required by http.RoundTripper. Call number i
// (counting from zero) yields resps[i] and errs[i]; once i runs past the end
// of either slice, nil is returned for that result.
type mockTransport struct {
	// errs lists the error to return for each successive call.
	errs []error
	// resps lists the response to return for each successive call.
	resps []*http.Response
	// count is the number of RoundTrip calls made so far.
	count int
}

// RoundTrip ignores the request, looks up the scripted response and error for
// the current call, and advances the call counter.
func (t *mockTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	call := t.count
	t.count++

	var (
		out *http.Response
		err error
	)
	if call < len(t.resps) {
		out = t.resps[call]
	}
	if call < len(t.errs) {
		err = t.errs[call]
	}
	return out, err
}
46
// perm is an error value with no Temporary method. The retry table in this
// file expects a call that fails with perm not to be repeated.
type perm struct{}

// Error implements the error interface.
func (perm) Error() string { return "permanent error" }
52
// temp is an error value that reports itself as temporary through its
// Temporary method.
type temp struct{}

// Error implements the error interface.
func (temp) Error() string { return "temporary error" }

// Temporary always reports true.
func (temp) Temporary() bool { return true }
62
// resp builds a fresh response carrying the given status code and a
// two-byte body ("hi") whose Close is a no-op.
func resp(code int) *http.Response {
	r := new(http.Response)
	r.StatusCode = code
	r.Body = io.NopCloser(strings.NewReader("hi"))
	return r
}
69
70 func TestRetryTransport(t *testing.T) {
71 for _, test := range []struct {
72 errs []error
73 resps []*http.Response
74 ctx context.Context
75 count int
76 }{{
77
78 errs: []error{temp{}},
79 ctx: retry.Never(context.Background()),
80 count: 1,
81 }, {
82
83 errs: []error{perm{}},
84 count: 1,
85 }, {
86
87 errs: []error{temp{}, perm{}},
88 count: 2,
89 }, {
90
91 errs: []error{temp{}, temp{}, temp{}, temp{}, temp{}},
92 count: 3,
93 }, {
94
95 errs: []error{nil, nil, temp{}, temp{}, temp{}},
96 resps: []*http.Response{
97 resp(http.StatusRequestTimeout),
98 resp(http.StatusInternalServerError),
99 nil,
100 },
101 count: 3,
102 }} {
103 mt := mockTransport{
104 errs: test.errs,
105 resps: test.resps,
106 }
107
108 tr := NewRetry(&mt,
109 WithRetryBackoff(retry.Backoff{Steps: 3}),
110 WithRetryPredicate(retry.IsTemporary),
111 WithRetryStatusCodes(http.StatusRequestTimeout, http.StatusInternalServerError),
112 )
113
114 ctx := context.Background()
115 if test.ctx != nil {
116 ctx = test.ctx
117 }
118 req, err := http.NewRequestWithContext(ctx, "GET", "example.com", nil)
119 if err != nil {
120 t.Fatal(err)
121 }
122 tr.RoundTrip(req)
123 if mt.count != test.count {
124 t.Errorf("wrong count, wanted %d, got %d", test.count, mt.count)
125 }
126 }
127 }
128
129 func TestRetryDefaults(t *testing.T) {
130 tr := NewRetry(http.DefaultTransport)
131 rt, ok := tr.(*retryTransport)
132 if !ok {
133 t.Fatal("could not cast to retryTransport")
134 }
135
136 if rt.backoff != defaultBackoff {
137 t.Fatalf("default backoff wrong: %v", rt.backoff)
138 }
139
140 if rt.predicate == nil {
141 t.Fatal("default predicate not set")
142 }
143 }
144
145 func TestTimeoutContext(t *testing.T) {
146 tr := NewRetry(http.DefaultTransport)
147
148 slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
149
150 time.Sleep(time.Second * 1)
151 }))
152 defer func() { go func() { slowServer.Close() }() }()
153
154 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*20))
155 defer cancel()
156 req, err := http.NewRequest("GET", slowServer.URL, nil)
157 if err != nil {
158 t.Fatal(err)
159 }
160 req = req.WithContext(ctx)
161
162 result := make(chan error)
163
164 go func() {
165 _, err := tr.RoundTrip(req)
166 result <- err
167 }()
168
169 select {
170 case err := <-result:
171 if !errors.Is(err, context.DeadlineExceeded) {
172 t.Fatalf("got: %v, want: %v", err, context.DeadlineExceeded)
173 }
174 case <-time.After(time.Millisecond * 100):
175 t.Fatalf("deadline was not recognized by transport")
176 }
177 }
178
View as plain text