...
1
16
17 package controlplane
18
19 import (
20 "io"
21 "net/http"
22 "sync"
23 "testing"
24 "time"
25
26 "k8s.io/apimachinery/pkg/util/wait"
27 "k8s.io/client-go/rest"
28 kubeapiservertesting "k8s.io/kubernetes/cmd/kube-apiserver/app/testing"
29 "k8s.io/kubernetes/test/integration/framework"
30 )
31
32 func TestGracefulShutdown(t *testing.T) {
33 server := kubeapiservertesting.StartTestServerOrDie(t, nil, nil, framework.SharedEtcd())
34
35 tearDownOnce := sync.Once{}
36 defer tearDownOnce.Do(server.TearDownFn)
37
38 transport, err := rest.TransportFor(server.ClientConfig)
39 if err != nil {
40 t.Fatalf("unexpected error: %v", err)
41 }
42 client := http.Client{Transport: transport}
43
44 req, body, err := newBlockingRequest("POST", server.ClientConfig.Host+"/api/v1/namespaces")
45 if err != nil {
46 t.Fatal(err)
47 }
48 respErrCh := backgroundRoundtrip(transport, req)
49
50 t.Logf("server should be blocking request for data in body")
51 time.Sleep(time.Millisecond * 500)
52 select {
53 case respErr := <-respErrCh:
54 if respErr.err != nil {
55 t.Fatalf("unexpected error: %v", err)
56 }
57 bs, err := io.ReadAll(respErr.resp.Body)
58 if err != nil {
59 t.Fatal(err)
60 }
61 t.Fatalf("unexpected server answer: %d, body: %s", respErr.resp.StatusCode, string(bs))
62 default:
63 }
64
65 t.Logf("server should answer")
66 resp, err := client.Get(server.ClientConfig.Host + "/")
67 if err != nil {
68 t.Fatal(err)
69 }
70 resp.Body.Close()
71
72 t.Logf("shutting down server")
73
74
75 wg := sync.WaitGroup{}
76 wg.Add(1)
77 go func() {
78 defer wg.Done()
79 tearDownOnce.Do(server.TearDownFn)
80 }()
81
82 t.Logf("server should fail new requests")
83 if err := wait.Poll(time.Millisecond*100, wait.ForeverTestTimeout, func() (done bool, err error) {
84 resp, err := client.Get(server.ClientConfig.Host + "/")
85 if err != nil {
86 return true, nil
87 }
88 resp.Body.Close()
89 return false, nil
90 }); err != nil {
91 t.Fatalf("server did not shutdown")
92 }
93
94 t.Logf("server should answer pending request")
95 time.Sleep(time.Millisecond * 500)
96 if _, err := body.Write([]byte("garbage")); err != nil {
97 t.Fatal(err)
98 }
99 body.Close()
100 respErr := <-respErrCh
101 if respErr.err != nil {
102 t.Fatal(respErr.err)
103 }
104 defer respErr.resp.Body.Close()
105 bs, err := io.ReadAll(respErr.resp.Body)
106 if err != nil {
107 t.Fatal(err)
108 }
109 t.Logf("response: code %d, body: %s", respErr.resp.StatusCode, string(bs))
110
111 wg.Wait()
112 }
113
// responseErrorPair bundles both return values of a single RoundTrip call so
// they can be delivered together over one channel.
type responseErrorPair struct {
	// resp is the *http.Response value returned by RoundTrip.
	resp *http.Response
	// err is the error value returned by RoundTrip.
	err error
}
118
119 func backgroundRoundtrip(transport http.RoundTripper, req *http.Request) <-chan responseErrorPair {
120 ch := make(chan responseErrorPair)
121 go func() {
122 resp, err := transport.RoundTrip(req)
123 ch <- responseErrorPair{resp, err}
124 }()
125 return ch
126 }
127
// newBlockingRequest builds an HTTP request whose body is the read end of an
// io.Pipe, and returns the pipe's write end alongside it. Reads of the request
// body block until data is written to the returned writer or it is closed,
// which lets a caller hold a request open for as long as it wants.
//
// The writer is returned even when http.NewRequest fails, so check the error
// before using the other results.
func newBlockingRequest(method, url string) (*http.Request, io.WriteCloser, error) {
	pipeReader, pipeWriter := io.Pipe()
	var body io.WriteCloser = pipeWriter
	request, err := http.NewRequest(method, url, pipeReader)
	return request, body, err
}
133
View as plain text