1
16
17 package portforward
18
19 import (
20 "errors"
21 "fmt"
22 "io"
23 "net"
24 "net/http"
25 "sort"
26 "strconv"
27 "strings"
28 "sync"
29
30 v1 "k8s.io/api/core/v1"
31 "k8s.io/apimachinery/pkg/util/httpstream"
32 "k8s.io/apimachinery/pkg/util/runtime"
33 netutils "k8s.io/utils/net"
34 )
35
36
37
// PortForwardProtocolV1Name is the subprotocol name offered to the dialer when
// the port forwarding stream connection is established; ForwardPorts fails if
// the server negotiates anything else.
const PortForwardProtocolV1Name = "portforward.k8s.io"

// ErrLostConnectionToPod is returned when the stream connection's close
// channel fires before the stop channel does.
var ErrLostConnectionToPod = errors.New("lost connection to pod")
41
42
43
// PortForwarder listens on a set of local addresses and ports and relays every
// accepted connection over streams created on a single httpstream connection.
type PortForwarder struct {
	// addresses are the local addresses every forwarded port is bound on.
	addresses []listenAddress
	// ports are the requested forwards; Local is rewritten with the actually
	// bound port once a listener has been created.
	ports []ForwardedPort
	// stopChan makes forward return nil as soon as a receive on it succeeds.
	stopChan <-chan struct{}

	// dialer opens the stream connection in ForwardPorts.
	dialer httpstream.Dialer
	// streamConn is the connection returned by dialer; it is set by ForwardPorts.
	streamConn httpstream.Connection
	// listeners collects every listener created so that Close can close them.
	listeners []io.Closer
	// Ready is closed by forward once at least one port is being listened on.
	// It may be nil, in which case GetPorts always returns an error.
	Ready chan struct{}
	// requestIDLock guards requestID.
	requestIDLock sync.Mutex
	// requestID is the next ID handed out by nextRequestID.
	requestID int
	// out receives progress messages ("Forwarding from ...", "Handling connection ...").
	out io.Writer
	// errOut receives per-port listen failures; it is skipped when nil.
	errOut io.Writer
}
58
59
// ForwardedPort is a Local:Remote port pairing.
type ForwardedPort struct {
	// Local is the port to listen on locally. It is handed to net.Listen
	// unchanged and then overwritten with the port the listener actually bound.
	Local uint16
	// Remote is the port sent in the port header of each forwarded stream.
	// parsePorts rejects a Remote of 0.
	Remote uint16
}
64
65
79 func parsePorts(ports []string) ([]ForwardedPort, error) {
80 var forwards []ForwardedPort
81 for _, portString := range ports {
82 parts := strings.Split(portString, ":")
83 var localString, remoteString string
84 if len(parts) == 1 {
85 localString = parts[0]
86 remoteString = parts[0]
87 } else if len(parts) == 2 {
88 localString = parts[0]
89 if localString == "" {
90
91 localString = "0"
92 }
93 remoteString = parts[1]
94 } else {
95 return nil, fmt.Errorf("invalid port format '%s'", portString)
96 }
97
98 localPort, err := strconv.ParseUint(localString, 10, 16)
99 if err != nil {
100 return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
101 }
102
103 remotePort, err := strconv.ParseUint(remoteString, 10, 16)
104 if err != nil {
105 return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
106 }
107 if remotePort == 0 {
108 return nil, fmt.Errorf("remote port must be > 0")
109 }
110
111 forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
112 }
113
114 return forwards, nil
115 }
116
// listenAddress describes one local address to bind and how a bind failure on
// it is weighed by listenOnPort.
type listenAddress struct {
	// address is the IP literal joined with the port and given to net.Listen.
	address string
	// protocol is the network given to net.Listen ("tcp4" or "tcp6").
	protocol string
	// failureMode is "all" or "any". For "all" addresses listenOnPort reports
	// an error only when none of them could be bound; for "any" addresses a
	// single failed bind is reported as an error.
	failureMode string
}
122
123 func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
124 var addresses []listenAddress
125 parsed := make(map[string]listenAddress)
126 for _, address := range addressesToParse {
127 if address == "localhost" {
128 if _, exists := parsed["127.0.0.1"]; !exists {
129 ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
130 parsed[ip.address] = ip
131 }
132 if _, exists := parsed["::1"]; !exists {
133 ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
134 parsed[ip.address] = ip
135 }
136 } else if netutils.ParseIPSloppy(address).To4() != nil {
137 parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
138 } else if netutils.ParseIPSloppy(address) != nil {
139 parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
140 } else {
141 return nil, fmt.Errorf("%s is not a valid IP", address)
142 }
143 }
144 addresses = make([]listenAddress, len(parsed))
145 id := 0
146 for _, v := range parsed {
147 addresses[id] = v
148 id++
149 }
150
151 sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
152
153 return addresses, nil
154 }
155
156
// New creates a PortForwarder that listens on "localhost" only. It is
// shorthand for NewOnAddresses with []string{"localhost"}; see that function
// for the meaning of the remaining arguments and the possible errors.
func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
	return NewOnAddresses(dialer, []string{"localhost"}, ports, stopChan, readyChan, out, errOut)
}
160
161
162 func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
163 if len(addresses) == 0 {
164 return nil, errors.New("you must specify at least 1 address")
165 }
166 parsedAddresses, err := parseAddresses(addresses)
167 if err != nil {
168 return nil, err
169 }
170 if len(ports) == 0 {
171 return nil, errors.New("you must specify at least 1 port")
172 }
173 parsedPorts, err := parsePorts(ports)
174 if err != nil {
175 return nil, err
176 }
177 return &PortForwarder{
178 dialer: dialer,
179 addresses: parsedAddresses,
180 ports: parsedPorts,
181 stopChan: stopChan,
182 Ready: readyChan,
183 out: out,
184 errOut: errOut,
185 }, nil
186 }
187
188
189
190 func (pf *PortForwarder) ForwardPorts() error {
191 defer pf.Close()
192
193 var err error
194 var protocol string
195 pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
196 if err != nil {
197 return fmt.Errorf("error upgrading connection: %s", err)
198 }
199 defer pf.streamConn.Close()
200 if protocol != PortForwardProtocolV1Name {
201 return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
202 }
203
204 return pf.forward()
205 }
206
207
208
209
210 func (pf *PortForwarder) forward() error {
211 var err error
212
213 listenSuccess := false
214 for i := range pf.ports {
215 port := &pf.ports[i]
216 err = pf.listenOnPort(port)
217 switch {
218 case err == nil:
219 listenSuccess = true
220 default:
221 if pf.errOut != nil {
222 fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
223 }
224 }
225 }
226
227 if !listenSuccess {
228 return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports)
229 }
230
231 if pf.Ready != nil {
232 close(pf.Ready)
233 }
234
235
236 select {
237 case <-pf.stopChan:
238 case <-pf.streamConn.CloseChan():
239 return ErrLostConnectionToPod
240 }
241
242 return nil
243 }
244
245
246
247 func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
248 var errors []error
249 failCounters := make(map[string]int, 2)
250 successCounters := make(map[string]int, 2)
251 for _, addr := range pf.addresses {
252 err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
253 if err != nil {
254 errors = append(errors, err)
255 failCounters[addr.failureMode]++
256 } else {
257 successCounters[addr.failureMode]++
258 }
259 }
260 if successCounters["all"] == 0 && failCounters["all"] > 0 {
261 return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
262 }
263 if failCounters["any"] > 0 {
264 return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
265 }
266 return nil
267 }
268
269
270
271 func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
272 listener, err := pf.getListener(protocol, address, port)
273 if err != nil {
274 return err
275 }
276 pf.listeners = append(pf.listeners, listener)
277 go pf.waitForConnection(listener, *port)
278 return nil
279 }
280
281
282
283 func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
284 listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
285 if err != nil {
286 return nil, fmt.Errorf("unable to create listener: Error %s", err)
287 }
288 listenerAddress := listener.Addr().String()
289 host, localPort, _ := net.SplitHostPort(listenerAddress)
290 localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
291
292 if err != nil {
293 fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, localPortUInt, port.Remote)
294 return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
295 }
296 port.Local = uint16(localPortUInt)
297 if pf.out != nil {
298 fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
299 }
300
301 return listener, nil
302 }
303
304
305
306 func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
307 for {
308 select {
309 case <-pf.streamConn.CloseChan():
310 return
311 default:
312 conn, err := listener.Accept()
313 if err != nil {
314
315 if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
316 runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
317 }
318 return
319 }
320 go pf.handleConnection(conn, port)
321 }
322 }
323 }
324
325 func (pf *PortForwarder) nextRequestID() int {
326 pf.requestIDLock.Lock()
327 defer pf.requestIDLock.Unlock()
328 id := pf.requestID
329 pf.requestID++
330 return id
331 }
332
333
334
335 func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
336 defer conn.Close()
337
338 if pf.out != nil {
339 fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
340 }
341
342 requestID := pf.nextRequestID()
343
344
345 headers := http.Header{}
346 headers.Set(v1.StreamType, v1.StreamTypeError)
347 headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
348 headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
349 errorStream, err := pf.streamConn.CreateStream(headers)
350 if err != nil {
351 runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
352 return
353 }
354
355 errorStream.Close()
356 defer pf.streamConn.RemoveStreams(errorStream)
357
358 errorChan := make(chan error)
359 go func() {
360 message, err := io.ReadAll(errorStream)
361 switch {
362 case err != nil:
363 errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
364 case len(message) > 0:
365 errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
366 }
367 close(errorChan)
368 }()
369
370
371 headers.Set(v1.StreamType, v1.StreamTypeData)
372 dataStream, err := pf.streamConn.CreateStream(headers)
373 if err != nil {
374 runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
375 return
376 }
377 defer pf.streamConn.RemoveStreams(dataStream)
378
379 localError := make(chan struct{})
380 remoteDone := make(chan struct{})
381
382 go func() {
383
384 if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
385 runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
386 }
387
388
389 close(remoteDone)
390 }()
391
392 go func() {
393
394 defer dataStream.Close()
395
396
397 if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
398 runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
399
400 close(localError)
401 }
402 }()
403
404
405 select {
406 case <-remoteDone:
407 case <-localError:
408 }
409
410
411 err = <-errorChan
412 if err != nil {
413 runtime.HandleError(err)
414 pf.streamConn.Close()
415 }
416 }
417
418
419 func (pf *PortForwarder) Close() {
420
421 for _, l := range pf.listeners {
422 if err := l.Close(); err != nil {
423 runtime.HandleError(fmt.Errorf("error closing listener: %v", err))
424 }
425 }
426 }
427
428
429
430
431
432
433 func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
434 if pf.Ready == nil {
435 return nil, fmt.Errorf("no Ready channel provided")
436 }
437 select {
438 case <-pf.Ready:
439 return pf.ports, nil
440 default:
441 return nil, fmt.Errorf("listeners not ready")
442 }
443 }
444
View as plain text