1
16
17 package tests
18
19 import (
20 "bytes"
21 "context"
22 "errors"
23 "fmt"
24 "io"
25 "io/ioutil"
26 "net/http"
27 "net/http/httptest"
28 "net/url"
29 "strings"
30 "testing"
31 "time"
32
33 "github.com/stretchr/testify/require"
34
35 "k8s.io/apimachinery/pkg/runtime"
36 "k8s.io/apimachinery/pkg/runtime/schema"
37 "k8s.io/apimachinery/pkg/types"
38 "k8s.io/apimachinery/pkg/util/httpstream"
39 remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
40 restclient "k8s.io/client-go/rest"
41 remoteclient "k8s.io/client-go/tools/remotecommand"
42 "k8s.io/client-go/transport/spdy"
43 "k8s.io/kubelet/pkg/cri/streaming/remotecommand"
44 "k8s.io/kubernetes/pkg/api/legacyscheme"
45 api "k8s.io/kubernetes/pkg/apis/core"
46 )
47
// fakeExecutor stands in for a container runtime behind fakeServer. Its run
// method checks the pod/uid/command it is handed, replays canned output (or a
// canned error), and records what it observed so the test can inspect it.
type fakeExecutor struct {
	// Test plumbing.
	t        *testing.T
	testName string

	// Expectations and canned behaviour configured by fakeServer.
	exec         bool   // true for exec requests; run then expects the command "ls /"
	expectStdin  bool   // when set, run drains its stdin reader into stdinReceived
	messageCount int    // how many times stdoutData and stderrData are each written
	errorData    string // when non-empty, run returns it as an error before any I/O
	stdoutData   string
	stderrData   string

	// Seeded by fakeServer (tty) and then overwritten by run with what it received.
	tty           bool
	command       []string
	stdinReceived bytes.Buffer
}
61
62 func (ex *fakeExecutor) ExecInContainer(_ context.Context, name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize, timeout time.Duration) error {
63 return ex.run(name, uid, container, cmd, in, out, err, tty)
64 }
65
66 func (ex *fakeExecutor) AttachContainer(_ context.Context, name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize) error {
67 return ex.run(name, uid, container, nil, in, out, err, tty)
68 }
69
70 func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
71 ex.command = cmd
72 ex.tty = tty
73
74 if e, a := "pod", name; e != a {
75 ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a)
76 }
77 if e, a := "uid", uid; e != string(a) {
78 ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a)
79 }
80 if ex.exec {
81 if e, a := "ls /", strings.Join(ex.command, " "); e != a {
82 ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a)
83 }
84 } else {
85 if len(ex.command) > 0 {
86 ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command)
87 }
88 }
89
90 if len(ex.errorData) > 0 {
91 return errors.New(ex.errorData)
92 }
93
94 if len(ex.stdoutData) > 0 {
95 for i := 0; i < ex.messageCount; i++ {
96 fmt.Fprint(out, ex.stdoutData)
97 }
98 }
99
100 if len(ex.stderrData) > 0 {
101 for i := 0; i < ex.messageCount; i++ {
102 fmt.Fprint(err, ex.stderrData)
103 }
104 }
105
106 if ex.expectStdin {
107 io.Copy(&ex.stdinReceived, in)
108 }
109
110 return nil
111 }
112
113 func fakeServer(t *testing.T, requestReceived chan struct{}, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc {
114 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
115 executor := &fakeExecutor{
116 t: t,
117 testName: testName,
118 errorData: errorData,
119 stdoutData: stdoutData,
120 stderrData: stderrData,
121 expectStdin: len(stdinData) > 0,
122 tty: tty,
123 messageCount: messageCount,
124 exec: exec,
125 }
126
127 opts, err := remotecommand.NewOptions(req)
128 require.NoError(t, err)
129 if exec {
130 cmd := req.URL.Query()[api.ExecCommandParam]
131 remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", cmd, opts, 0, 10*time.Second, serverProtocols)
132 } else {
133 remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", opts, 0, 10*time.Second, serverProtocols)
134 }
135
136 if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a {
137 t.Errorf("%s: stdin: expected %q, got %q", testName, e, a)
138 }
139 close(requestReceived)
140 })
141 }
142
143 func TestStream(t *testing.T) {
144 testCases := []struct {
145 TestName string
146 Stdin string
147 Stdout string
148 Stderr string
149 Error string
150 Tty bool
151 MessageCount int
152 ClientProtocols []string
153 ServerProtocols []string
154 }{
155 {
156 TestName: "error",
157 Error: "bail",
158 Stdout: "a",
159 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
160 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
161 },
162 {
163 TestName: "in/out/err",
164 Stdin: "a",
165 Stdout: "b",
166 Stderr: "c",
167 MessageCount: 100,
168 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
169 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
170 },
171 {
172 TestName: "oversized stdin",
173 Stdin: strings.Repeat("a", 20*1024*1024),
174 Stdout: "b",
175 Stderr: "",
176 MessageCount: 1,
177 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
178 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
179 },
180 {
181 TestName: "in/out/tty",
182 Stdin: "a",
183 Stdout: "b",
184 Tty: true,
185 MessageCount: 100,
186 ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
187 ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
188 },
189 }
190
191 for _, testCase := range testCases {
192 for _, exec := range []bool{true, false} {
193 var name string
194 if exec {
195 name = testCase.TestName + " (exec)"
196 } else {
197 name = testCase.TestName + " (attach)"
198 }
199
200 t.Run(name, func(t *testing.T) {
201 var (
202 streamIn io.Reader
203 streamOut, streamErr io.Writer
204 )
205 localOut := &bytes.Buffer{}
206 localErr := &bytes.Buffer{}
207
208 requestReceived := make(chan struct{})
209 server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
210 defer server.Close()
211
212 url, _ := url.ParseRequestURI(server.URL)
213 config := restclient.ClientContentConfig{
214 GroupVersion: schema.GroupVersion{Group: "x"},
215 Negotiator: runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
216 }
217 c, err := restclient.NewRESTClient(url, "", config, nil, nil)
218 if err != nil {
219 t.Fatalf("failed to create a client: %v", err)
220 }
221 req := c.Post().Resource("testing")
222
223 if exec {
224 req.Param("command", "ls")
225 req.Param("command", "/")
226 }
227
228 if len(testCase.Stdin) > 0 {
229 req.Param(api.ExecStdinParam, "1")
230 streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
231 }
232
233 if len(testCase.Stdout) > 0 {
234 req.Param(api.ExecStdoutParam, "1")
235 streamOut = localOut
236 }
237
238 if testCase.Tty {
239 req.Param(api.ExecTTYParam, "1")
240 } else if len(testCase.Stderr) > 0 {
241 req.Param(api.ExecStderrParam, "1")
242 streamErr = localErr
243 }
244
245 conf := &restclient.Config{
246 Host: server.URL,
247 }
248 transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
249 if err != nil {
250 t.Fatalf("%s: unexpected error: %v", name, err)
251 }
252 e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
253 if err != nil {
254 t.Fatalf("%s: unexpected error: %v", name, err)
255 }
256 err = e.StreamWithContext(context.Background(), remoteclient.StreamOptions{
257 Stdin: streamIn,
258 Stdout: streamOut,
259 Stderr: streamErr,
260 Tty: testCase.Tty,
261 })
262 hasErr := err != nil
263
264 if len(testCase.Error) > 0 {
265 if !hasErr {
266 t.Errorf("%s: expected an error", name)
267 } else {
268 if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
269 t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
270 }
271 }
272 return
273 }
274
275 if hasErr {
276 t.Fatalf("%s: unexpected error: %v", name, err)
277 }
278
279 if len(testCase.Stdout) > 0 {
280 if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
281 t.Fatalf("%s: expected stdout data %q, got %q", name, e, a)
282 }
283 }
284
285 if testCase.Stderr != "" {
286 if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
287 t.Fatalf("%s: expected stderr data %q, got %q", name, e, a)
288 }
289 }
290
291 select {
292 case <-requestReceived:
293 case <-time.After(time.Minute):
294 t.Errorf("%s: expected fakeServerInstance to receive request", name)
295 }
296 })
297 }
298 }
299 }
300
301 type fakeUpgrader struct {
302 req *http.Request
303 resp *http.Response
304 conn httpstream.Connection
305 err, connErr error
306 checkResponse bool
307 called bool
308
309 t *testing.T
310 }
311
312 func (u *fakeUpgrader) RoundTrip(req *http.Request) (*http.Response, error) {
313 u.called = true
314 u.req = req
315 return u.resp, u.err
316 }
317
318 func (u *fakeUpgrader) NewConnection(resp *http.Response) (httpstream.Connection, error) {
319 if u.checkResponse && u.resp != resp {
320 u.t.Errorf("response objects passed did not match: %#v", resp)
321 }
322 return u.conn, u.connErr
323 }
324
325 type fakeConnection struct {
326 httpstream.Connection
327 }
328
329
330
331
332 func TestDial(t *testing.T) {
333 upgrader := &fakeUpgrader{
334 t: t,
335 checkResponse: true,
336 conn: &fakeConnection{},
337 resp: &http.Response{
338 StatusCode: http.StatusSwitchingProtocols,
339 Body: ioutil.NopCloser(&bytes.Buffer{}),
340 },
341 }
342 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: upgrader}, "POST", &url.URL{Host: "something.com", Scheme: "https"})
343 conn, protocol, err := dialer.Dial("protocol1")
344 if err != nil {
345 t.Fatal(err)
346 }
347 if conn != upgrader.conn {
348 t.Errorf("unexpected connection: %#v", conn)
349 }
350 if !upgrader.called {
351 t.Errorf("request not called")
352 }
353 _ = protocol
354 }
355
View as plain text