1
16
17 package remotecommand
18
19 import (
20 "bytes"
21 "context"
22 "crypto/rand"
23 "io"
24 "net/http"
25 "net/http/httptest"
26 "net/url"
27 "testing"
28 "time"
29
30 "github.com/stretchr/testify/assert"
31 "github.com/stretchr/testify/require"
32 "k8s.io/apimachinery/pkg/util/remotecommand"
33 "k8s.io/apimachinery/pkg/util/wait"
34 "k8s.io/client-go/rest"
35 )
36
37 func TestFallbackClient_WebSocketPrimarySucceeds(t *testing.T) {
38
39 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
40 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
41 if err != nil {
42 w.WriteHeader(http.StatusForbidden)
43 return
44 }
45 defer conns.conn.Close()
46
47 _, err = io.Copy(conns.stdoutStream, conns.stdinStream)
48 require.NoError(t, err)
49 }))
50 defer websocketServer.Close()
51
52
53
54 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
55 websocketLocation, err := url.Parse(websocketServer.URL)
56 require.NoError(t, err)
57 websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
58 require.NoError(t, err)
59 spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
60 require.NoError(t, err)
61
62 exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return false })
63 require.NoError(t, err)
64
65
66 randomSize := 1024 * 1024
67 randomData := make([]byte, randomSize)
68 if _, err := rand.Read(randomData); err != nil {
69 t.Errorf("unexpected error reading random data: %v", err)
70 }
71 var stdout bytes.Buffer
72 options := &StreamOptions{
73 Stdin: bytes.NewReader(randomData),
74 Stdout: &stdout,
75 }
76 errorChan := make(chan error)
77 go func() {
78
79 errorChan <- exec.StreamWithContext(context.Background(), *options)
80 }()
81
82 select {
83 case <-time.After(wait.ForeverTestTimeout):
84 t.Fatalf("expect stream to be closed after connection is closed.")
85 case err := <-errorChan:
86 if err != nil {
87 t.Errorf("unexpected error")
88 }
89 }
90
91 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
92 if err != nil {
93 t.Errorf("error reading the stream: %v", err)
94 return
95 }
96
97 if !bytes.Equal(randomData, data) {
98 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
99 }
100 }
101
102 func TestFallbackClient_SPDYSecondarySucceeds(t *testing.T) {
103
104 spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
105 var stdin, stdout bytes.Buffer
106 ctx, err := createHTTPStreams(w, req, &StreamOptions{
107 Stdin: &stdin,
108 Stdout: &stdout,
109 })
110 if err != nil {
111 w.WriteHeader(http.StatusForbidden)
112 return
113 }
114 defer ctx.conn.Close()
115 _, err = io.Copy(ctx.stdoutStream, ctx.stdinStream)
116 if err != nil {
117 t.Fatalf("error copying STDIN to STDOUT: %v", err)
118 }
119 }))
120 defer spdyServer.Close()
121
122 spdyLocation, err := url.Parse(spdyServer.URL)
123 require.NoError(t, err)
124 websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: spdyLocation.Host}, "GET", spdyServer.URL)
125 require.NoError(t, err)
126 spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: spdyLocation.Host}, "POST", spdyLocation)
127 require.NoError(t, err)
128
129 exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
130 require.NoError(t, err)
131
132
133 randomSize := 1024 * 1024
134 randomData := make([]byte, randomSize)
135 if _, err := rand.Read(randomData); err != nil {
136 t.Errorf("unexpected error reading random data: %v", err)
137 }
138 var stdout bytes.Buffer
139 options := &StreamOptions{
140 Stdin: bytes.NewReader(randomData),
141 Stdout: &stdout,
142 }
143 errorChan := make(chan error)
144 go func() {
145 errorChan <- exec.StreamWithContext(context.Background(), *options)
146 }()
147
148 select {
149 case <-time.After(wait.ForeverTestTimeout):
150 t.Fatalf("expect stream to be closed after connection is closed.")
151 case err := <-errorChan:
152 if err != nil {
153 t.Errorf("unexpected error")
154 }
155 }
156
157 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
158 if err != nil {
159 t.Errorf("error reading the stream: %v", err)
160 return
161 }
162
163 if !bytes.Equal(randomData, data) {
164 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
165 }
166 }
167
168 func TestFallbackClient_PrimaryAndSecondaryFail(t *testing.T) {
169
170 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
171 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
172 if err != nil {
173 w.WriteHeader(http.StatusForbidden)
174 return
175 }
176 defer conns.conn.Close()
177
178 _, err = io.Copy(conns.stdoutStream, conns.stdinStream)
179 require.NoError(t, err)
180 }))
181 defer websocketServer.Close()
182
183
184
185 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
186 websocketLocation, err := url.Parse(websocketServer.URL)
187 require.NoError(t, err)
188 websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
189 require.NoError(t, err)
190 spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
191 require.NoError(t, err)
192
193 exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
194 require.NoError(t, err)
195
196 fallbackExec, ok := exec.(*FallbackExecutor)
197 assert.True(t, ok, "error casting executor as FallbackExecutor")
198 websocketExec, ok := fallbackExec.primary.(*wsStreamExecutor)
199 assert.True(t, ok, "error casting executor as websocket executor")
200
201 websocketExec.protocols = []string{remotecommand.StreamProtocolV4Name}
202
203
204
205 randomSize := 1024 * 1024
206 randomData := make([]byte, randomSize)
207 if _, err := rand.Read(randomData); err != nil {
208 t.Errorf("unexpected error reading random data: %v", err)
209 }
210 var stdout bytes.Buffer
211 options := &StreamOptions{
212 Stdin: bytes.NewReader(randomData),
213 Stdout: &stdout,
214 }
215 errorChan := make(chan error)
216 go func() {
217 errorChan <- exec.StreamWithContext(context.Background(), *options)
218 }()
219
220 select {
221 case <-time.After(wait.ForeverTestTimeout):
222 t.Fatalf("expect stream to be closed after connection is closed.")
223 case err := <-errorChan:
224
225 require.Error(t, err)
226 }
227 }
228
View as plain text