1
16
17 package remotecommand
18
19 import (
20 "errors"
21 "io"
22 "net/http"
23 "strings"
24 "testing"
25 "time"
26
27 "k8s.io/api/core/v1"
28 "k8s.io/apimachinery/pkg/util/httpstream"
29 "k8s.io/apimachinery/pkg/util/wait"
30 )
31
// fakeReader is an io.Reader that never yields any data. Every Read returns
// the configured err; when err is nil it reports io.EOF instead, so consumers
// such as io.ReadAll or io.Copy terminate rather than spinning on (0, nil).
type fakeReader struct {
	err error
}

// Read ignores the buffer and returns (0, r.err), or (0, io.EOF) when no
// error was configured.
func (r *fakeReader) Read([]byte) (int, error) {
	if r.err == nil {
		return 0, io.EOF
	}
	return 0, r.err
}
37
// fakeWriter is an io.Writer that silently discards everything written to it.
type fakeWriter struct{}

// Write discards p and reports it as fully written. The io.Writer contract
// requires n == len(p) whenever the returned error is nil.
func (*fakeWriter) Write(p []byte) (int, error) { return len(p), nil }
41
42 type fakeStreamCreator struct {
43 created map[string]bool
44 errors map[string]error
45 }
46
47 var _ streamCreator = &fakeStreamCreator{}
48
49 func (f *fakeStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, error) {
50 streamType := headers.Get(v1.StreamType)
51 f.created[streamType] = true
52 return nil, f.errors[streamType]
53 }
54
55 func TestV2CreateStreams(t *testing.T) {
56 tests := []struct {
57 name string
58 stdin bool
59 stdinError error
60 stdout bool
61 stdoutError error
62 stderr bool
63 stderrError error
64 errorError error
65 tty bool
66 expectError bool
67 }{
68 {
69 name: "stdin error",
70 stdin: true,
71 stdinError: errors.New("stdin error"),
72 expectError: true,
73 },
74 {
75 name: "stdout error",
76 stdout: true,
77 stdoutError: errors.New("stdout error"),
78 expectError: true,
79 },
80 {
81 name: "stderr error",
82 stderr: true,
83 stderrError: errors.New("stderr error"),
84 expectError: true,
85 },
86 {
87 name: "error stream error",
88 stdin: true,
89 stdout: true,
90 stderr: true,
91 errorError: errors.New("error stream error"),
92 expectError: true,
93 },
94 {
95 name: "no errors",
96 stdin: true,
97 stdout: true,
98 stderr: true,
99 expectError: false,
100 },
101 {
102 name: "no errors, stderr & tty set, don't expect stderr",
103 stdin: true,
104 stdout: true,
105 stderr: true,
106 tty: true,
107 expectError: false,
108 },
109 }
110 for _, test := range tests {
111 conn := &fakeStreamCreator{
112 created: make(map[string]bool),
113 errors: map[string]error{
114 v1.StreamTypeStdin: test.stdinError,
115 v1.StreamTypeStdout: test.stdoutError,
116 v1.StreamTypeStderr: test.stderrError,
117 v1.StreamTypeError: test.errorError,
118 },
119 }
120
121 opts := StreamOptions{Tty: test.tty}
122 if test.stdin {
123 opts.Stdin = &fakeReader{}
124 }
125 if test.stdout {
126 opts.Stdout = &fakeWriter{}
127 }
128 if test.stderr {
129 opts.Stderr = &fakeWriter{}
130 }
131
132 h := newStreamProtocolV2(opts).(*streamProtocolV2)
133 err := h.createStreams(conn)
134
135 if test.expectError {
136 if err == nil {
137 t.Errorf("%s: expected error", test.name)
138 continue
139 }
140 if e, a := test.stdinError, err; test.stdinError != nil && e != a {
141 t.Errorf("%s: expected %v, got %v", test.name, e, a)
142 }
143 if e, a := test.stdoutError, err; test.stdoutError != nil && e != a {
144 t.Errorf("%s: expected %v, got %v", test.name, e, a)
145 }
146 if e, a := test.stderrError, err; test.stderrError != nil && e != a {
147 t.Errorf("%s: expected %v, got %v", test.name, e, a)
148 }
149 if e, a := test.errorError, err; test.errorError != nil && e != a {
150 t.Errorf("%s: expected %v, got %v", test.name, e, a)
151 }
152 continue
153 }
154
155 if !test.expectError && err != nil {
156 t.Errorf("%s: unexpected error: %v", test.name, err)
157 continue
158 }
159
160 if test.stdin && !conn.created[v1.StreamTypeStdin] {
161 t.Errorf("%s: expected stdin stream", test.name)
162 }
163 if test.stdout && !conn.created[v1.StreamTypeStdout] {
164 t.Errorf("%s: expected stdout stream", test.name)
165 }
166 if test.stderr {
167 if test.tty && conn.created[v1.StreamTypeStderr] {
168 t.Errorf("%s: unexpected stderr stream because tty is set", test.name)
169 } else if !test.tty && !conn.created[v1.StreamTypeStderr] {
170 t.Errorf("%s: expected stderr stream", test.name)
171 }
172 }
173 if !conn.created[v1.StreamTypeError] {
174 t.Errorf("%s: expected error stream", test.name)
175 }
176
177 }
178 }
179
180 func TestV2ErrorStreamReading(t *testing.T) {
181 tests := []struct {
182 name string
183 stream io.Reader
184 expectedError error
185 }{
186 {
187 name: "error reading from stream",
188 stream: &fakeReader{errors.New("foo")},
189 expectedError: errors.New("error reading from error stream: foo"),
190 },
191 {
192 name: "stream returns an error",
193 stream: strings.NewReader("some error"),
194 expectedError: errors.New("error executing remote command: some error"),
195 },
196 }
197
198 for _, test := range tests {
199 h := newStreamProtocolV2(StreamOptions{}).(*streamProtocolV2)
200 h.errorStream = test.stream
201
202 ch := watchErrorStream(h.errorStream, &errorDecoderV2{})
203 if ch == nil {
204 t.Fatalf("%s: unexpected nil channel", test.name)
205 }
206
207 var err error
208 select {
209 case err = <-ch:
210 case <-time.After(wait.ForeverTestTimeout):
211 t.Fatalf("%s: timed out", test.name)
212 }
213
214 if test.expectedError != nil {
215 if err == nil {
216 t.Errorf("%s: expected an error", test.name)
217 } else if e, a := test.expectedError, err; e.Error() != a.Error() {
218 t.Errorf("%s: expected %q, got %q", test.name, e, a)
219 }
220 continue
221 }
222
223 if test.expectedError == nil && err != nil {
224 t.Errorf("%s: unexpected error: %v", test.name, err)
225 continue
226 }
227 }
228 }
229
View as plain text