...
1
16
17 package wsstream
18
import (
	"encoding/base64"
	"errors"
	"io"
	"net/http"
	"sync"
	"time"

	"golang.org/x/net/websocket"

	"k8s.io/apimachinery/pkg/util/runtime"
)
30
31
32
33
// WebSocket subprotocol names understood by Reader.
const (
	// binaryWebSocketProtocol streams data as raw binary messages.
	binaryWebSocketProtocol = "binary.k8s.io"
	// base64BinaryWebSocketProtocol streams data base64-encoded in text messages.
	base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
)

// ReaderProtocolConfig describes how a Reader encodes outgoing data for one
// negotiated subprotocol.
type ReaderProtocolConfig struct {
	// Binary selects raw binary messages when true; when false the data is
	// base64-encoded before being sent.
	Binary bool
}

// NewDefaultReaderProtocols returns the subprotocols a Reader accepts by
// default. The empty protocol and binaryWebSocketProtocol send raw binary
// data; base64BinaryWebSocketProtocol sends base64 text. Every call returns a
// freshly allocated map that the caller may modify.
func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig {
	protocols := make(map[string]ReaderProtocolConfig, 3)
	protocols[""] = ReaderProtocolConfig{Binary: true}
	protocols[binaryWebSocketProtocol] = ReaderProtocolConfig{Binary: true}
	protocols[base64BinaryWebSocketProtocol] = ReaderProtocolConfig{Binary: false}
	return protocols
}
56
57
58 type Reader struct {
59 err chan error
60 r io.Reader
61 ping bool
62 timeout time.Duration
63 protocols map[string]ReaderProtocolConfig
64 selectedProtocol string
65
66 handleCrash func(additionalHandlers ...func(interface{}))
67 }
68
69
70
71
72
73
74
// NewReader prepares r for streaming to a WebSocket client. When ping is
// true, an empty message is sent to the client before any data. protocols
// maps each subprotocol the server is willing to negotiate to its encoding;
// NewDefaultReaderProtocols provides the standard set.
func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader {
	rd := &Reader{
		r:           r,
		ping:        ping,
		protocols:   protocols,
		err:         make(chan error),
		handleCrash: runtime.HandleCrash,
	}
	return rd
}

// SetIdleTimeout sets how far the connection deadline is pushed forward
// before each read of the source; a duration <= 0 disables deadlines.
// The field is read without synchronization, so call this before Copy.
func (r *Reader) SetIdleTimeout(duration time.Duration) {
	r.timeout = duration
}
90
// handshake is the websocket.Server handshake callback. It offers every key
// of r.protocols as a supported subprotocol and delegates the negotiation to
// the package-level handshake helper.
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
	allowed := make([]string, len(r.protocols))
	i := 0
	for name := range r.protocols {
		allowed[i] = name
		i++
	}
	return handshake(config, req, allowed)
}
98
99
100
101 func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
102 go func() {
103 defer r.handleCrash()
104 websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
105 }()
106 return <-r.err
107 }
108
109
// handle is the WebSocket handler used by Copy. It records the negotiated
// subprotocol, streams r.r to the client with messageCopy, and sends the
// outcome on r.err.
//
// A second goroutine runs IgnoreReceives on the connection. Whichever side
// finishes first closes the connection (exactly once). r.err is closed when
// handle returns.
func (r *Reader) handle(ws *websocket.Conn) {
	closeConnOnce := &sync.Once{}
	closeConn := func() {
		closeConnOnce.Do(func() {
			ws.Close()
		})
	}

	// Deferred LIFO: the connection is closed first, then r.err.
	defer close(r.err)
	defer closeConn()

	negotiated := ws.Config().Protocol
	if len(negotiated) == 0 {
		// The handshake helper presumably always leaves one negotiated
		// protocol; report an error instead of panicking on the index below.
		r.err <- errors.New("no websocket subprotocol was negotiated")
		return
	}
	r.selectedProtocol = negotiated[0]

	go func() {
		// Use the same injectable crash handler as Copy.
		defer r.handleCrash()

		// NOTE(review): IgnoreReceives is defined elsewhere; presumably it
		// blocks until the client closes the connection or a read fails.
		IgnoreReceives(ws, r.timeout)

		// Stop streaming once the receive side has finished.
		closeConn()
	}()

	r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout)
}
135
// resetTimeout moves the read/write deadline of ws to timeout from now.
// A timeout <= 0 leaves the deadline untouched. Any error from SetDeadline
// is discarded.
func resetTimeout(ws *websocket.Conn, timeout time.Duration) {
	if timeout <= 0 {
		return
	}
	_ = ws.SetDeadline(time.Now().Add(timeout))
}
141
// messageCopy forwards r to ws, sending one WebSocket message per successful
// read. It returns nil when r reports io.EOF. It returns the error if r fails
// with any other error or if a send fails.
//
// When base64Encode is true, each chunk is base64-encoded and sent as a
// string (text) message; otherwise it is sent as a []byte (binary) message.
// When ping is true, an empty message in the same encoding is sent before any
// data. A positive timeout extends the connection deadline before the ping
// and before every read.
func messageCopy(ws *websocket.Conn, r io.Reader, base64Encode, ping bool, timeout time.Duration) error {
	send := func(data []byte) error {
		if base64Encode {
			return websocket.Message.Send(ws, base64.StdEncoding.EncodeToString(data))
		}
		return websocket.Message.Send(ws, data)
	}

	if ping {
		resetTimeout(ws, timeout)
		// In base64 mode this sends "", matching an empty text message.
		if err := send([]byte{}); err != nil {
			return err
		}
	}

	buf := make([]byte, 2048)
	for {
		resetTimeout(ws, timeout)
		n, err := r.Read(buf)
		// Per the io.Reader contract, bytes returned together with an error
		// (including io.EOF) must be delivered before the error is handled.
		if n > 0 {
			if sendErr := send(buf[:n]); sendErr != nil {
				return sendErr
			}
		}
		if err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
	}
}
178
View as plain text