1
16
17 package wsstream
18
19 import (
20 "encoding/base64"
21 "fmt"
22 "io"
23 "net/http"
24 "strings"
25 "time"
26
27 "golang.org/x/net/websocket"
28
29 "k8s.io/apimachinery/pkg/util/httpstream"
30 "k8s.io/apimachinery/pkg/util/portforward"
31 "k8s.io/apimachinery/pkg/util/remotecommand"
32 "k8s.io/apimachinery/pkg/util/runtime"
33 "k8s.io/klog/v2"
34 )
35
// WebSocketProtocolHeader is the HTTP header in which a client lists the websocket
// subprotocols it requests, as a comma-separated list.
const WebSocketProtocolHeader = "Sec-Websocket-Protocol"
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// ChannelWebSocketProtocol is the subprotocol name for channel-multiplexed binary messages.
// NewDefaultChannelProtocols configures it with Binary: true, so each message is the
// channel number as a single byte followed by the raw payload.
const ChannelWebSocketProtocol = "channel.k8s.io"
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Base64ChannelWebSocketProtocol is the subprotocol name for channel-multiplexed string
// messages. NewDefaultChannelProtocols configures it with Binary: false, so each message
// is the channel number as an ASCII digit ('0'+n) followed by the payload encoded with
// base64.StdEncoding.
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
67
// codecType selects how channel frames are encoded on the wire.
type codecType int

const (
	// rawCodec frames are a []byte: the channel number byte followed by the raw payload.
	rawCodec codecType = 0
	// base64Codec frames are a string: the ASCII digit '0'+channel followed by the
	// base64-encoded payload.
	base64Codec codecType = 1
)
74
// ChannelType describes which directions a single multiplexed channel carries data in.
type ChannelType int

const (
	// IgnoreChannel neither delivers received frames nor sends writes: Read returns
	// io.EOF and Write discards its input.
	IgnoreChannel ChannelType = 0
	// ReadChannel delivers frames received from the peer to Read; writes are discarded.
	ReadChannel ChannelType = 1
	// WriteChannel sends writes to the peer; received frames are dropped and Read
	// returns io.EOF.
	WriteChannel ChannelType = 2
	// ReadWriteChannel both delivers received frames and sends writes.
	ReadWriteChannel ChannelType = 3
)
83
84
85
86 func IsWebSocketRequest(req *http.Request) bool {
87 if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
88 return false
89 }
90 return httpstream.IsUpgradeRequest(req)
91 }
92
93
94
95
96 func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool {
97 if !IsWebSocketRequest(req) {
98 return false
99 }
100 requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
101 for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
102 if protocolSupportsStreamClose(strings.TrimSpace(requestedProtocol)) {
103 return true
104 }
105 }
106
107 return false
108 }
109
110
111
112
113 func IsWebSocketRequestWithTunnelingProtocol(req *http.Request) bool {
114 if !IsWebSocketRequest(req) {
115 return false
116 }
117 requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
118 for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
119 if protocolSupportsWebsocketTunneling(strings.TrimSpace(requestedProtocol)) {
120 return true
121 }
122 }
123
124 return false
125 }
126
127
128
129 func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
130 defer runtime.HandleCrash()
131 var data []byte
132 for {
133 resetTimeout(ws, timeout)
134 if err := websocket.Message.Receive(ws, &data); err != nil {
135 return
136 }
137 }
138 }
139
140
141
142 func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
143 protocols := config.Protocol
144 if len(protocols) == 0 {
145 protocols = []string{""}
146 }
147
148 for _, protocol := range protocols {
149 for _, allow := range allowed {
150 if allow == protocol {
151 config.Protocol = []string{protocol}
152 return nil
153 }
154 }
155 }
156
157 return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
158 }
159
160
// ChannelProtocolConfig describes how a Conn serves one websocket subprotocol.
type ChannelProtocolConfig struct {
	// Binary selects raw byte frames when true and base64-encoded string frames when false.
	Binary bool
	// Channels lists the type of each channel; the slice index is the channel number.
	Channels []ChannelType
}
165
166
167
168
169 func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
170 return map[string]ChannelProtocolConfig{
171 "": {Binary: true, Channels: channels},
172 ChannelWebSocketProtocol: {Binary: true, Channels: channels},
173 Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
174 }
175 }
176
177
// Conn multiplexes several logical channels over one websocket connection. Create it
// with NewConn and establish it with Open.
type Conn struct {
	// protocols maps each supported subprotocol name to its channel configuration.
	protocols map[string]ChannelProtocolConfig
	// selectedProtocol is the subprotocol negotiated during the handshake.
	selectedProtocol string
	// channels holds one entry per configured channel, indexed by channel number.
	channels []*websocketChannel
	// codec is the frame encoding implied by the selected protocol's Binary flag.
	codec codecType
	// ready is closed by initialize once the fields above and ws are populated.
	ready chan struct{}
	// ws is the underlying websocket; nil until initialize runs.
	ws *websocket.Conn
	// timeout is the idle timeout applied before each receive and write; <= 0 disables it.
	timeout time.Duration
}
187
188
189
190
191
192
193
194
195 func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
196 return &Conn{
197 ready: make(chan struct{}),
198 protocols: protocols,
199 }
200 }
201
202
203
// SetIdleTimeout sets the idle timeout used by resetTimeout. Before every receive and
// every write, the websocket deadline is pushed this far into the future; a value <= 0
// disables this. The field is read by the serving goroutine without synchronization, so
// set it before calling Open.
func (conn *Conn) SetIdleTimeout(duration time.Duration) {
	conn.timeout = duration
}
207
208
209
210
211 func (conn *Conn) SetWriteDeadline(duration time.Duration) {
212 conn.ws.SetWriteDeadline(time.Now().Add(duration))
213 }
214
215
216
217 func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
218
219 serveHTTPComplete := make(chan struct{})
220
221 panicChan := make(chan any, 1)
222 go func() {
223
224
225 defer func() {
226 if p := recover(); p != nil {
227 panicChan <- p
228 } else {
229 close(serveHTTPComplete)
230 }
231 }()
232 websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
233 }()
234
235
236
237 select {
238 case <-conn.ready:
239 klog.V(8).Infof("websocket server initialized--serving")
240 case <-serveHTTPComplete:
241
242 conn.closeNonThreadSafe()
243 return "", nil, fmt.Errorf("websocket server finished before becoming ready")
244 case p := <-panicChan:
245 panic(p)
246 }
247
248 rwc := make([]io.ReadWriteCloser, len(conn.channels))
249 for i := range conn.channels {
250 rwc[i] = conn.channels[i]
251 }
252 return conn.selectedProtocol, rwc, nil
253 }
254
255 func (conn *Conn) initialize(ws *websocket.Conn) {
256 negotiated := ws.Config().Protocol
257 conn.selectedProtocol = negotiated[0]
258 p := conn.protocols[conn.selectedProtocol]
259 if p.Binary {
260 conn.codec = rawCodec
261 } else {
262 conn.codec = base64Codec
263 }
264 conn.ws = ws
265 conn.channels = make([]*websocketChannel, len(p.Channels))
266 for i, t := range p.Channels {
267 switch t {
268 case ReadChannel:
269 conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
270 case WriteChannel:
271 conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
272 case ReadWriteChannel:
273 conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
274 case IgnoreChannel:
275 conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
276 }
277 }
278
279 close(conn.ready)
280 }
281
282 func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
283 supportedProtocols := make([]string, 0, len(conn.protocols))
284 for p := range conn.protocols {
285 supportedProtocols = append(supportedProtocols, p)
286 }
287 return handshake(config, req, supportedProtocols)
288 }
289
290 func (conn *Conn) resetTimeout() {
291 if conn.timeout > 0 {
292 conn.ws.SetDeadline(time.Now().Add(conn.timeout))
293 }
294 }
295
296
297
298 func (conn *Conn) closeNonThreadSafe() error {
299 for _, s := range conn.channels {
300 s.Close()
301 }
302 var err error
303 if conn.ws != nil {
304 err = conn.ws.Close()
305 }
306 return err
307 }
308
309
// Close waits until the connection has been initialized (conn.ready is closed), then
// closes all channels and the websocket, returning the websocket's Close error.
// conn.ready is closed only during initialization. If that never happens (e.g. Open
// returned an error), Close blocks forever.
func (conn *Conn) Close() error {
	<-conn.ready
	return conn.closeNonThreadSafe()
}
314
315
316
317
318 func protocolSupportsStreamClose(protocol string) bool {
319 return protocol == remotecommand.StreamProtocolV5Name
320 }
321
322
323
324 func protocolSupportsWebsocketTunneling(protocol string) bool {
325 return strings.HasPrefix(protocol, portforward.WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, portforward.KubernetesSuffix)
326 }
327
328
329 func (conn *Conn) handle(ws *websocket.Conn) {
330 conn.initialize(ws)
331 defer conn.Close()
332 supportsStreamClose := protocolSupportsStreamClose(conn.selectedProtocol)
333
334 for {
335 conn.resetTimeout()
336 var data []byte
337 if err := websocket.Message.Receive(ws, &data); err != nil {
338 if err != io.EOF {
339 klog.Errorf("Error on socket receive: %v", err)
340 }
341 break
342 }
343 if len(data) == 0 {
344 continue
345 }
346 if supportsStreamClose && data[0] == remotecommand.StreamClose {
347 if len(data) != 2 {
348 klog.Errorf("Single channel byte should follow stream close signal. Got %d bytes", len(data)-1)
349 break
350 } else {
351 channel := data[1]
352 if int(channel) >= len(conn.channels) {
353 klog.Errorf("Close is targeted for a channel %d that is not valid, possible protocol error", channel)
354 break
355 }
356 klog.V(4).Infof("Received half-close signal from client; close %d stream", channel)
357 conn.channels[channel].Close()
358 }
359 continue
360 }
361 channel := data[0]
362 if conn.codec == base64Codec {
363 channel = channel - '0'
364 }
365 data = data[1:]
366 if int(channel) >= len(conn.channels) {
367 klog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel)
368 continue
369 }
370 if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
371 klog.Errorf("Unable to write frame (%d bytes) to %d: %v", len(data), channel, err)
372 continue
373 }
374 }
375 }
376
377
378 func (conn *Conn) write(num byte, data []byte) (int, error) {
379 conn.resetTimeout()
380 switch conn.codec {
381 case rawCodec:
382 frame := make([]byte, len(data)+1)
383 frame[0] = num
384 copy(frame[1:], data)
385 if err := websocket.Message.Send(conn.ws, frame); err != nil {
386 return 0, err
387 }
388 case base64Codec:
389 frame := string('0'+num) + base64.StdEncoding.EncodeToString(data)
390 if err := websocket.Message.Send(conn.ws, frame); err != nil {
391 return 0, err
392 }
393 }
394 return len(data), nil
395 }
396
397
// websocketChannel is one logical channel of a Conn. Data received from the socket is
// written to w and consumed by Read through r. Writes are framed and sent through conn.
// Field order matters: newWebsocketChannel builds this with a positional literal.
type websocketChannel struct {
	conn *Conn
	// num is the channel number placed at the start of every outbound frame.
	num byte
	// r and w are the read and write ends of the pipe carrying inbound data.
	r io.Reader
	w io.WriteCloser

	// read and write enable the inbound and outbound directions respectively.
	read, write bool
}
406
407
408
409
410 func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel {
411 r, w := io.Pipe()
412 return &websocketChannel{conn, num, r, w, read, write}
413 }
414
415 func (p *websocketChannel) Write(data []byte) (int, error) {
416 if !p.write {
417 return len(data), nil
418 }
419 return p.conn.write(p.num, data)
420 }
421
422
423
424 func (p *websocketChannel) DataFromSocket(data []byte) (int, error) {
425 if !p.read {
426 return len(data), nil
427 }
428
429 switch p.conn.codec {
430 case rawCodec:
431 return p.w.Write(data)
432 case base64Codec:
433 dst := make([]byte, len(data))
434 n, err := base64.StdEncoding.Decode(dst, data)
435 if err != nil {
436 return 0, err
437 }
438 return p.w.Write(dst[:n])
439 }
440 return 0, nil
441 }
442
443 func (p *websocketChannel) Read(data []byte) (int, error) {
444 if !p.read {
445 return 0, io.EOF
446 }
447 return p.r.Read(data)
448 }
449
// Close closes the write end of the channel's inbound pipe, so Reads on this channel
// observe end-of-stream. It only affects the local side; nothing is sent to the peer.
func (p *websocketChannel) Close() error {
	return p.w.Close()
}
453
View as plain text