1
16
17 package remotecommand
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "io"
24 "net"
25 "net/http"
26 "sync"
27 "time"
28
29 gwebsocket "github.com/gorilla/websocket"
30
31 v1 "k8s.io/api/core/v1"
32 "k8s.io/apimachinery/pkg/util/httpstream"
33 "k8s.io/apimachinery/pkg/util/remotecommand"
34 restclient "k8s.io/client-go/rest"
35 "k8s.io/client-go/transport/websocket"
36 "k8s.io/klog/v2"
37 )
38
39
40
// writeDeadline bounds how long a single data message write on the websocket
// connection may take. stream.Write applies it before writing each message.
const writeDeadline = time.Minute
42
43 var (
44 _ Executor = &wsStreamExecutor{}
45 _ streamCreator = &wsStreamCreator{}
46 _ httpstream.Stream = &stream{}
47
48 streamType2streamID = map[string]byte{
49 v1.StreamTypeStdin: remotecommand.StreamStdIn,
50 v1.StreamTypeStdout: remotecommand.StreamStdOut,
51 v1.StreamTypeStderr: remotecommand.StreamStdErr,
52 v1.StreamTypeError: remotecommand.StreamErr,
53 v1.StreamTypeResize: remotecommand.StreamResize,
54 }
55 )
56
// pingPeriod is the default interval between heartbeat pings; it is the value
// NewWebSocketExecutorForProtocols stores as heartbeatPeriod.
const pingPeriod = 5 * time.Second

// pingReadDeadline is the default read deadline for the websocket connection:
// twelve ping periods plus one second. It is also used as the write deadline
// for each ping control message.
const pingReadDeadline = 12*pingPeriod + time.Second
67
68
// wsStreamExecutor is the websocket-based Executor. It carries everything
// needed to open a websocket connection, negotiate a remote command
// subprotocol and keep the connection alive with heartbeats.
type wsStreamExecutor struct {
	// transport is handed to websocket.Negotiate to perform the upgrade request.
	transport http.RoundTripper
	// upgrader is handed to websocket.Negotiate and also supplies the read
	// buffer size (DataBufferSize) for the demultiplexing loop.
	upgrader websocket.ConnectionHolder
	// method is the HTTP method of the upgrade request.
	method string
	// url is the target of the upgrade request.
	url string
	// protocols are the subprotocols offered to the server during negotiation.
	protocols []string
	// negotiated is the subprotocol reported by the connection; it is set by
	// StreamWithContext and may be empty.
	negotiated string
	// heartbeatPeriod is the ping interval passed to readDemuxLoop.
	heartbeatPeriod time.Duration
	// heartbeatDeadline is the read deadline passed to readDemuxLoop.
	heartbeatDeadline time.Duration
}
83
// NewWebSocketExecutor returns a websocket-based Executor for the given
// method and url that offers only the v5 remote command subprotocol
// (remotecommand.StreamProtocolV5Name). It delegates to
// NewWebSocketExecutorForProtocols and returns whatever that returns.
func NewWebSocketExecutor(config *restclient.Config, method, url string) (Executor, error) {
	return NewWebSocketExecutorForProtocols(config, method, url, remotecommand.StreamProtocolV5Name)
}
91
92
93 func NewWebSocketExecutorForProtocols(config *restclient.Config, method, url string, protocols ...string) (Executor, error) {
94 transport, upgrader, err := websocket.RoundTripperFor(config)
95 if err != nil {
96 return nil, fmt.Errorf("error creating websocket transports: %v", err)
97 }
98 return &wsStreamExecutor{
99 transport: transport,
100 upgrader: upgrader,
101 method: method,
102 url: url,
103 protocols: protocols,
104 heartbeatPeriod: pingPeriod,
105 heartbeatDeadline: pingReadDeadline,
106 }, nil
107 }
108
109
110
// Stream runs the remote command streams described by options without any
// caller-supplied cancellation: it calls StreamWithContext with
// context.Background() and returns its result.
func (e *wsStreamExecutor) Stream(options StreamOptions) error {
	return e.StreamWithContext(context.Background(), options)
}
114
115
116
117
118
119 func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error {
120 req, err := http.NewRequestWithContext(ctx, e.method, e.url, nil)
121 if err != nil {
122 return err
123 }
124 conn, err := websocket.Negotiate(e.transport, e.upgrader, req, e.protocols...)
125 if err != nil {
126 return err
127 }
128 if conn == nil {
129 panic(fmt.Errorf("websocket connection is nil"))
130 }
131 defer conn.Close()
132 e.negotiated = conn.Subprotocol()
133 klog.V(4).Infof("The subprotocol is %s", e.negotiated)
134
135 var streamer streamProtocolHandler
136 switch e.negotiated {
137 case remotecommand.StreamProtocolV5Name:
138 streamer = newStreamProtocolV5(options)
139 case remotecommand.StreamProtocolV4Name:
140 streamer = newStreamProtocolV4(options)
141 case remotecommand.StreamProtocolV3Name:
142 streamer = newStreamProtocolV3(options)
143 case remotecommand.StreamProtocolV2Name:
144 streamer = newStreamProtocolV2(options)
145 case "":
146 klog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", remotecommand.StreamProtocolV1Name)
147 fallthrough
148 case remotecommand.StreamProtocolV1Name:
149 streamer = newStreamProtocolV1(options)
150 }
151
152 panicChan := make(chan any, 1)
153 errorChan := make(chan error, 1)
154 go func() {
155 defer func() {
156 if p := recover(); p != nil {
157 panicChan <- p
158 }
159 }()
160 creator := newWSStreamCreator(conn)
161 go creator.readDemuxLoop(
162 e.upgrader.DataBufferSize(),
163 e.heartbeatPeriod,
164 e.heartbeatDeadline,
165 )
166 errorChan <- streamer.stream(creator)
167 }()
168
169 select {
170 case p := <-panicChan:
171 panic(p)
172 case err := <-errorChan:
173 return err
174 case <-ctx.Done():
175 return ctx.Err()
176 }
177 }
178
// wsStreamCreator creates the logical streams that share one websocket
// connection and routes incoming messages to them.
type wsStreamCreator struct {
	// conn is the websocket connection shared by every stream.
	conn *gwebsocket.Conn
	// connWriteLock is shared (by pointer) with every created stream, which
	// holds it while writing to conn in Write and Close.
	connWriteLock sync.Mutex
	// streams holds the registered streams keyed by their one-byte stream id.
	streams map[byte]*stream
	// streamsMu is held while accessing streams and setStreamErr.
	streamsMu sync.Mutex
	// setStreamErr is set by closeAllStreamReaders; once non-nil, setStream
	// returns it instead of registering a new stream.
	setStreamErr error
}
190
// newWSStreamCreator returns a wsStreamCreator bound to conn with an empty
// stream table.
func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
	return &wsStreamCreator{
		conn:    conn,
		streams: map[byte]*stream{},
	}
}
197
// getStream returns the stream registered under id, or nil if there is none.
// The lookup is done while holding streamsMu.
func (c *wsStreamCreator) getStream(id byte) *stream {
	c.streamsMu.Lock()
	defer c.streamsMu.Unlock()
	return c.streams[id]
}
203
204 func (c *wsStreamCreator) setStream(id byte, s *stream) error {
205 c.streamsMu.Lock()
206 defer c.streamsMu.Unlock()
207 if c.setStreamErr != nil {
208 return c.setStreamErr
209 }
210 c.streams[id] = s
211 return nil
212 }
213
214
215
216 func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, error) {
217 streamType := headers.Get(v1.StreamType)
218 id, ok := streamType2streamID[streamType]
219 if !ok {
220 return nil, fmt.Errorf("unknown stream type: %s", streamType)
221 }
222 if s := c.getStream(id); s != nil {
223 return nil, fmt.Errorf("duplicate stream for type %s", streamType)
224 }
225 reader, writer := io.Pipe()
226 s := &stream{
227 headers: headers,
228 readPipe: reader,
229 writePipe: writer,
230 conn: c.conn,
231 connWriteLock: &c.connWriteLock,
232 id: id,
233 }
234 if err := c.setStream(id, s); err != nil {
235 _ = s.writePipe.Close()
236 _ = s.readPipe.Close()
237 return nil, err
238 }
239 return s, nil
240 }
241
242
243
244
245
246
247 func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, deadline time.Duration) {
248
249 h := newHeartbeat(c.conn, period, deadline)
250
251 if err := c.conn.SetReadDeadline(time.Now().Add(deadline)); err != nil {
252 klog.Errorf("Websocket initial setting read deadline failed %v", err)
253 return
254 }
255 go h.start()
256
257
258
259 readBuffer := make([]byte, bufferSize)
260 for {
261
262
263
264
265
266
267
268
269
270 messageType, r, err := c.conn.NextReader()
271 if err != nil {
272 websocketErr, ok := err.(*gwebsocket.CloseError)
273 if ok && websocketErr.Code == gwebsocket.CloseNormalClosure {
274 err = nil
275 } else {
276 err = fmt.Errorf("next reader: %w", err)
277 }
278 c.closeAllStreamReaders(err)
279 return
280 }
281
282 if messageType != gwebsocket.BinaryMessage {
283 c.closeAllStreamReaders(fmt.Errorf("unexpected message type: %d", messageType))
284 return
285 }
286
287
288 _, err = io.ReadFull(r, readBuffer[:1])
289 if err != nil {
290 c.closeAllStreamReaders(fmt.Errorf("read stream id: %w", err))
291 return
292 }
293 streamID := readBuffer[0]
294 s := c.getStream(streamID)
295 if s == nil {
296 klog.Errorf("Unknown stream id %d, discarding message", streamID)
297 continue
298 }
299 for {
300 nr, errRead := r.Read(readBuffer)
301 if nr > 0 {
302
303 _, errWrite := s.writePipe.Write(readBuffer[:nr])
304 if errWrite != nil {
305
306
307 break
308 }
309 }
310 if errRead != nil {
311 if errRead == io.EOF {
312 break
313 }
314 c.closeAllStreamReaders(fmt.Errorf("read message: %w", err))
315 return
316 }
317 }
318 }
319 }
320
321
322
323 func (c *wsStreamCreator) closeAllStreamReaders(err error) {
324 c.streamsMu.Lock()
325 defer c.streamsMu.Unlock()
326 for _, s := range c.streams {
327
328 _ = s.writePipe.CloseWithError(err)
329 }
330
331 if err != nil {
332 c.setStreamErr = err
333 } else {
334 c.setStreamErr = fmt.Errorf("closed all streams")
335 }
336 }
337
// stream is one logical stream multiplexed over a shared websocket
// connection. Incoming data reaches it through an in-memory pipe; outgoing
// data is written to the connection prefixed with the stream id.
type stream struct {
	// headers are the headers the stream was created with; returned by Headers.
	headers http.Header
	// readPipe is the read side of the pipe, consumed by Read.
	readPipe *io.PipeReader
	// writePipe is the write side of the pipe, fed by readDemuxLoop.
	writePipe *io.PipeWriter
	// conn is the shared websocket connection; Close sets it to nil, after
	// which Write and Close return an error.
	conn *gwebsocket.Conn
	// connWriteLock points at the creator's lock and is held while writing to
	// conn.
	connWriteLock *sync.Mutex
	// id is the one-byte stream id that prefixes every message of this stream.
	id byte
}
350
// Read reads data for this stream from the read side of its pipe and returns
// whatever the pipe's Read returns.
func (s *stream) Read(p []byte) (n int, err error) {
	return s.readPipe.Read(p)
}
354
355
356 func (s *stream) Write(p []byte) (n int, err error) {
357 klog.V(4).Infof("Write() on stream %d", s.id)
358 defer klog.V(4).Infof("Write() done on stream %d", s.id)
359 s.connWriteLock.Lock()
360 defer s.connWriteLock.Unlock()
361 if s.conn == nil {
362 return 0, fmt.Errorf("write on closed stream %d", s.id)
363 }
364 err = s.conn.SetWriteDeadline(time.Now().Add(writeDeadline))
365 if err != nil {
366 klog.V(7).Infof("Websocket setting write deadline failed %v", err)
367 return 0, err
368 }
369
370
371 w, err := s.conn.NextWriter(gwebsocket.BinaryMessage)
372 if err != nil {
373 return 0, err
374 }
375 defer func() {
376 if w != nil {
377 w.Close()
378 }
379 }()
380 _, err = w.Write([]byte{s.id})
381 if err != nil {
382 return 0, err
383 }
384 n, err = w.Write(p)
385 if err != nil {
386 return n, err
387 }
388 err = w.Close()
389 w = nil
390 return n, err
391 }
392
393
394 func (s *stream) Close() error {
395 klog.V(4).Infof("Close() on stream %d", s.id)
396 defer klog.V(4).Infof("Close() done on stream %d", s.id)
397 s.connWriteLock.Lock()
398 defer s.connWriteLock.Unlock()
399 if s.conn == nil {
400 return fmt.Errorf("Close() on already closed stream %d", s.id)
401 }
402
403 err := s.conn.WriteMessage(gwebsocket.BinaryMessage, []byte{remotecommand.StreamClose, s.id})
404 s.conn = nil
405 return err
406 }
407
408 func (s *stream) Reset() error {
409 klog.V(4).Infof("Reset() on stream %d", s.id)
410 defer klog.V(4).Infof("Reset() done on stream %d", s.id)
411 s.Close()
412 return s.writePipe.Close()
413 }
414
// Headers returns the headers the stream was created with.
func (s *stream) Headers() http.Header {
	return s.headers
}
418
// Identifier returns the stream's one-byte id widened to uint32.
func (s *stream) Identifier() uint32 {
	return uint32(s.id)
}
422
423
424
425
426
427
428
// heartbeat periodically sends websocket ping messages on a connection and
// extends the connection's read deadline whenever a pong is received.
type heartbeat struct {
	// conn is the websocket connection the pings are written to.
	conn *gwebsocket.Conn
	// period is the interval between pings.
	period time.Duration
	// closer is closed by the close handler installed in newHeartbeat; start
	// returns once it is closed.
	closer chan struct{}
	// message is the payload sent with every ping; set through setMessage.
	message []byte
	// pongMessage is the payload of the most recent non-empty pong received.
	pongMessage []byte
}
440
441
442
443
444 func newHeartbeat(conn *gwebsocket.Conn, period time.Duration, deadline time.Duration) *heartbeat {
445 h := &heartbeat{
446 conn: conn,
447 period: period,
448 closer: make(chan struct{}),
449 }
450
451
452
453 h.conn.SetPongHandler(func(msg string) error {
454
455 klog.V(8).Infof("Pong message received (%s)--resetting read deadline", msg)
456 err := h.conn.SetReadDeadline(time.Now().Add(deadline))
457 if err != nil {
458 klog.Errorf("Websocket setting read deadline failed %v", err)
459 return err
460 }
461 if len(msg) > 0 {
462 h.pongMessage = []byte(msg)
463 }
464 return nil
465 })
466
467 closeHandler := h.conn.CloseHandler()
468 h.conn.SetCloseHandler(func(code int, text string) error {
469 close(h.closer)
470 return closeHandler(code, text)
471 })
472 return h
473 }
474
475
476
// setMessage sets the payload that start sends with every ping.
func (h *heartbeat) setMessage(msg string) {
	h.message = []byte(msg)
}
480
481
482
483 func (h *heartbeat) start() {
484
485 t := time.NewTicker(h.period)
486 defer t.Stop()
487 for {
488 select {
489 case <-h.closer:
490 klog.V(8).Infof("closed channel--returning")
491 return
492 case <-t.C:
493
494
495
496 if err := h.conn.WriteControl(gwebsocket.PingMessage, h.message, time.Now().Add(pingReadDeadline)); err == nil {
497 klog.V(8).Infof("Websocket Ping succeeeded")
498 } else {
499 klog.Errorf("Websocket Ping failed: %v", err)
500 if errors.Is(err, gwebsocket.ErrCloseSent) {
501
502 continue
503 } else if e, ok := err.(net.Error); ok && e.Timeout() {
504
505
506
507
508
509 continue
510 }
511 return
512 }
513 }
514 }
515 }
516
View as plain text