1
16
17 package websocket
18
19 import (
20 "crypto/tls"
21 "errors"
22 "fmt"
23 "net/http"
24 "net/url"
25
26 gwebsocket "github.com/gorilla/websocket"
27
28 "k8s.io/apimachinery/pkg/util/httpstream"
29 "k8s.io/apimachinery/pkg/util/httpstream/wsstream"
30 utilnet "k8s.io/apimachinery/pkg/util/net"
31 restclient "k8s.io/client-go/rest"
32 "k8s.io/client-go/transport"
33 )
34
35 var (
36 _ utilnet.TLSClientConfigHolder = &RoundTripper{}
37 _ http.RoundTripper = &RoundTripper{}
38 )
39
40
41
42 type ConnectionHolder interface {
43 DataBufferSize() int
44 Connection() *gwebsocket.Conn
45 }
46
47
48
49 type RoundTripper struct {
50
51
52 TLSConfig *tls.Config
53
54
55
56
57
58 Proxier func(req *http.Request) (*url.URL, error)
59
60
61 Conn *gwebsocket.Conn
62 }
63
64
65 func (rt *RoundTripper) Connection() *gwebsocket.Conn {
66 return rt.Conn
67 }
68
69
70
71 func (rt *RoundTripper) DataBufferSize() int {
72 return 32 * 1024
73 }
74
75
76 func (rt *RoundTripper) TLSClientConfig() *tls.Config {
77 return rt.TLSConfig
78 }
79
80
81
82 func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response, retErr error) {
83 defer func() {
84 if request.Body != nil {
85 err := request.Body.Close()
86 if retErr == nil {
87 retErr = err
88 }
89 }
90 }()
91
92
93 protocolVersions := request.Header[wsstream.WebSocketProtocolHeader]
94 delete(request.Header, wsstream.WebSocketProtocolHeader)
95
96 dialer := gwebsocket.Dialer{
97 Proxy: rt.Proxier,
98 TLSClientConfig: rt.TLSConfig,
99 Subprotocols: protocolVersions,
100 ReadBufferSize: rt.DataBufferSize() + 1024,
101 WriteBufferSize: rt.DataBufferSize() + 1024,
102 }
103 switch request.URL.Scheme {
104 case "https":
105 request.URL.Scheme = "wss"
106 case "http":
107 request.URL.Scheme = "ws"
108 default:
109 return nil, fmt.Errorf("unknown url scheme: %s", request.URL.Scheme)
110 }
111 wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
112 if err != nil {
113 if errors.Is(err, gwebsocket.ErrBadHandshake) {
114 return nil, &httpstream.UpgradeFailureError{Cause: err}
115 }
116 return nil, err
117 }
118
119
120 foundProtocol := false
121 for _, protocolVersion := range protocolVersions {
122 if protocolVersion == wsConn.Subprotocol() {
123 foundProtocol = true
124 break
125 }
126 }
127 if !foundProtocol {
128 wsConn.Close()
129 return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())}
130 }
131
132 rt.Conn = wsConn
133
134 return resp, nil
135 }
136
137
138
139
140
141 func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHolder, error) {
142 transportCfg, err := config.TransportConfig()
143 if err != nil {
144 return nil, nil, err
145 }
146 tlsConfig, err := transport.TLSConfigFor(transportCfg)
147 if err != nil {
148 return nil, nil, err
149 }
150 proxy := config.Proxy
151 if proxy == nil {
152 proxy = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
153 }
154
155 upgradeRoundTripper := &RoundTripper{
156 TLSConfig: tlsConfig,
157 Proxier: proxy,
158 }
159 wrapper, err := transport.HTTPWrappersForConfig(transportCfg, upgradeRoundTripper)
160 if err != nil {
161 return nil, nil, err
162 }
163 return wrapper, upgradeRoundTripper, nil
164 }
165
166
167
168
169 func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) {
170
171 req.Header[wsstream.WebSocketProtocolHeader] = protocols
172 resp, err := rt.RoundTrip(req)
173 if err != nil {
174 return nil, err
175 }
176 err = resp.Body.Close()
177 if err != nil {
178 connectionInfo.Connection().Close()
179 return nil, fmt.Errorf("error closing response body: %v", err)
180 }
181 return connectionInfo.Connection(), nil
182 }
183
View as plain text