1
16
17 package rest
18
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"k8s.io/apimachinery/pkg/runtime/schema"
	"k8s.io/apimachinery/pkg/runtime/serializer"
	utilnet "k8s.io/apimachinery/pkg/util/net"
	"k8s.io/apimachinery/pkg/util/wait"
)
38
// tcpLB is a minimal TCP proxy used by the reconnect tests: the client talks to
// ln, and every accepted connection is forwarded to serverURL.
type tcpLB struct {
	// ln is the listener clients connect to.
	ln net.Listener

	// serverURL is the host:port each accepted connection is dialed through to.
	serverURL string

	// t is the owning test, used for logging from the proxy goroutines.
	t *testing.T

	// dials counts accepted client connections; access it with sync/atomic.
	dials int32
}
45
46 func (lb *tcpLB) handleConnection(in net.Conn, stopCh chan struct{}) {
47 out, err := net.Dial("tcp", lb.serverURL)
48 if err != nil {
49 lb.t.Log(err)
50 return
51 }
52 go io.Copy(out, in)
53 go io.Copy(in, out)
54 <-stopCh
55 if err := out.Close(); err != nil {
56 lb.t.Fatalf("failed to close connection: %v", err)
57 }
58 }
59
60 func (lb *tcpLB) serve(stopCh chan struct{}) {
61 conn, err := lb.ln.Accept()
62 if err != nil {
63 lb.t.Fatalf("failed to accept: %v", err)
64 }
65 atomic.AddInt32(&lb.dials, 1)
66 go lb.handleConnection(conn, stopCh)
67 }
68
69 func newLB(t *testing.T, serverURL string) *tcpLB {
70 ln, err := net.Listen("tcp", "127.0.0.1:0")
71 if err != nil {
72 t.Fatalf("failed to bind: %v", err)
73 }
74 lb := tcpLB{
75 serverURL: serverURL,
76 ln: ln,
77 t: t,
78 }
79 return &lb
80 }
81
// HTTP/2 health-check settings, in seconds, for TestReconnectBrokenTCP. They
// are passed to the client through the HTTP2_READ_IDLE_TIMEOUT_SECONDS and
// HTTP2_PING_TIMEOUT_SECONDS environment variables, and the test sleeps long
// enough for both to elapse.
const (
	readIdleTimeout, pingTimeout int = 1, 1
)
86
87 func TestReconnectBrokenTCP(t *testing.T) {
88 t.Setenv("HTTP2_READ_IDLE_TIMEOUT_SECONDS", strconv.Itoa(readIdleTimeout))
89 t.Setenv("HTTP2_PING_TIMEOUT_SECONDS", strconv.Itoa(pingTimeout))
90 t.Setenv("DISABLE_HTTP2", "")
91 ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
92 fmt.Fprintf(w, "Hello, %s", r.Proto)
93 }))
94 ts.EnableHTTP2 = true
95 ts.StartTLS()
96 defer ts.Close()
97
98 u, err := url.Parse(ts.URL)
99 if err != nil {
100 t.Fatalf("failed to parse URL from %q: %v", ts.URL, err)
101 }
102 lb := newLB(t, u.Host)
103 defer lb.ln.Close()
104 stopCh := make(chan struct{})
105 go lb.serve(stopCh)
106 transport, ok := ts.Client().Transport.(*http.Transport)
107 if !ok {
108 t.Fatalf("failed to assert *http.Transport")
109 }
110 config := &Config{
111 Host: "https://" + lb.ln.Addr().String(),
112 Transport: utilnet.SetTransportDefaults(transport),
113 Timeout: 1 * time.Second,
114
115 ContentConfig: ContentConfig{
116 GroupVersion: &schema.GroupVersion{},
117 NegotiatedSerializer: &serializer.CodecFactory{},
118 },
119 }
120 client, err := RESTClientFor(config)
121 if err != nil {
122 t.Fatalf("failed to create REST client: %v", err)
123 }
124 data, err := client.Get().AbsPath("/").DoRaw(context.TODO())
125 if err != nil {
126 t.Fatalf("unexpected err: %s: %v", data, err)
127 }
128 if string(data) != "Hello, HTTP/2.0" {
129 t.Fatalf("unexpected response: %s", data)
130 }
131
132
133
134
135 close(stopCh)
136
137 stopCh = make(chan struct{})
138 go lb.serve(stopCh)
139
140
141 time.Sleep(time.Duration(1+readIdleTimeout+pingTimeout) * time.Second)
142
143
144
145
146 data, err = client.Get().AbsPath("/").DoRaw(context.TODO())
147 if err != nil {
148 t.Fatalf("unexpected err: %v", err)
149 }
150 if string(data) != "Hello, HTTP/2.0" {
151 t.Fatalf("unexpected response: %s", data)
152 }
153 dials := atomic.LoadInt32(&lb.dials)
154 if dials != 2 {
155 t.Fatalf("expected %d dials, got %d", 2, dials)
156 }
157 }
158
159
160
161
162
163
164 func TestReconnectBrokenTCP_HTTP1(t *testing.T) {
165 ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
166 fmt.Fprintf(w, "Hello, %s", r.Proto)
167 }))
168 ts.EnableHTTP2 = false
169 ts.StartTLS()
170 defer ts.Close()
171
172 u, err := url.Parse(ts.URL)
173 if err != nil {
174 t.Fatalf("failed to parse URL from %q: %v", ts.URL, err)
175 }
176 lb := newLB(t, u.Host)
177 defer lb.ln.Close()
178 stopCh := make(chan struct{})
179 go lb.serve(stopCh)
180 transport, ok := ts.Client().Transport.(*http.Transport)
181 if !ok {
182 t.Fatal("failed to assert *http.Transport")
183 }
184 config := &Config{
185 Host: "https://" + lb.ln.Addr().String(),
186 Transport: utilnet.SetTransportDefaults(transport),
187
188 Timeout: wait.ForeverTestTimeout,
189
190 ContentConfig: ContentConfig{
191 GroupVersion: &schema.GroupVersion{},
192 NegotiatedSerializer: &serializer.CodecFactory{},
193 },
194 }
195 config.TLSClientConfig.NextProtos = []string{"http/1.1"}
196 client, err := RESTClientFor(config)
197 if err != nil {
198 t.Fatalf("failed to create REST client: %v", err)
199 }
200
201 data, err := client.Get().AbsPath("/").DoRaw(context.TODO())
202 if err != nil {
203 t.Fatalf("unexpected err: %s: %v", data, err)
204 }
205 if string(data) != "Hello, HTTP/1.1" {
206 t.Fatalf("unexpected response: %s", data)
207 }
208
209
210
211
212 close(stopCh)
213
214 stopCh = make(chan struct{})
215 go lb.serve(stopCh)
216
217 utilnet.CloseIdleConnectionsFor(client.Client.Transport)
218
219
220
221
222
223 data, err = client.Get().AbsPath("/").DoRaw(context.TODO())
224 if err != nil {
225 t.Fatalf("unexpected err: %v", err)
226 }
227 if string(data) != "Hello, HTTP/1.1" {
228 t.Fatalf("unexpected response: %s", data)
229 }
230 dials := atomic.LoadInt32(&lb.dials)
231 if dials != 2 {
232 t.Fatalf("expected %d dials, got %d", 2, dials)
233 }
234 }
235
236
237
238
239
240 func TestReconnectBrokenTCPInFlight_HTTP1(t *testing.T) {
241 done := make(chan struct{})
242 defer close(done)
243 received := make(chan struct{})
244
245 ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
246 if r.URL.Path == "/hang" {
247 conn, _, _ := w.(http.Hijacker).Hijack()
248 close(received)
249 <-done
250 conn.Close()
251 }
252 fmt.Fprintf(w, "Hello, %s", r.Proto)
253 }))
254 ts.EnableHTTP2 = false
255 ts.StartTLS()
256 defer ts.Close()
257
258 u, err := url.Parse(ts.URL)
259 if err != nil {
260 t.Fatalf("failed to parse URL from %q: %v", ts.URL, err)
261 }
262
263 lb := newLB(t, u.Host)
264 defer lb.ln.Close()
265 stopCh := make(chan struct{})
266 go lb.serve(stopCh)
267
268 transport, ok := ts.Client().Transport.(*http.Transport)
269 if !ok {
270 t.Fatal("failed to assert *http.Transport")
271 }
272 config := &Config{
273 Host: "https://" + lb.ln.Addr().String(),
274 Transport: utilnet.SetTransportDefaults(transport),
275
276 Timeout: wait.ForeverTestTimeout,
277
278 ContentConfig: ContentConfig{
279 GroupVersion: &schema.GroupVersion{},
280 NegotiatedSerializer: &serializer.CodecFactory{},
281 },
282 }
283 config.TLSClientConfig.NextProtos = []string{"http/1.1"}
284
285 client, err := RESTClientFor(config)
286 if err != nil {
287 t.Fatalf("failed to create REST client: %v", err)
288 }
289
290
291
292
293 ctx, cancel := context.WithCancel(context.Background())
294 reqErrCh := make(chan error, 1)
295 defer close(reqErrCh)
296 go func() {
297 _, err = client.Get().AbsPath("/hang").DoRaw(ctx)
298 reqErrCh <- err
299 }()
300
301
302 select {
303 case <-received:
304 case <-time.After(wait.ForeverTestTimeout):
305 t.Fatal("Test timed out waiting for first request to fail")
306 }
307
308
309
310
311 close(stopCh)
312
313 stopCh = make(chan struct{})
314 go lb.serve(stopCh)
315
316
317 data, err := client.Get().AbsPath("/").DoRaw(context.Background())
318 if err != nil {
319 t.Fatalf("unexpected err: %v", err)
320 }
321 if string(data) != "Hello, HTTP/1.1" {
322 t.Fatalf("unexpected response: %s", data)
323 }
324 dials := atomic.LoadInt32(&lb.dials)
325 if dials != 2 {
326 t.Fatalf("expected %d dials, got %d", 2, dials)
327 }
328
329
330 cancel()
331 select {
332 case <-reqErrCh:
333 if err == nil {
334 t.Fatal("Connection succeeded but was expected to timeout")
335 }
336 case <-time.After(10 * time.Second):
337 t.Fatal("Test timed out waiting for the request to fail")
338 }
339
340 }
341
342 func TestRestClientTimeout(t *testing.T) {
343 ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
344 time.Sleep(2 * time.Second)
345 fmt.Fprintf(w, "Hello, %s", r.Proto)
346 }))
347 ts.Start()
348 defer ts.Close()
349
350 config := &Config{
351 Host: ts.URL,
352 Timeout: 1 * time.Second,
353
354 ContentConfig: ContentConfig{
355 GroupVersion: &schema.GroupVersion{},
356 NegotiatedSerializer: &serializer.CodecFactory{},
357 },
358 }
359 client, err := RESTClientFor(config)
360 if err != nil {
361 t.Fatalf("failed to create REST client: %v", err)
362 }
363 _, err = client.Get().AbsPath("/").DoRaw(context.TODO())
364 if err == nil {
365 t.Fatalf("timeout error expected")
366 }
367 if !strings.Contains(err.Error(), "deadline exceeded") {
368 t.Fatalf("timeout error expected, received %v", err)
369 }
370 }
371
View as plain text