...

Source file src/k8s.io/client-go/tools/portforward/portforward.go

Documentation: k8s.io/client-go/tools/portforward

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package portforward
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"net"
    24  	"net/http"
    25  	"sort"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  
    30  	v1 "k8s.io/api/core/v1"
    31  	"k8s.io/apimachinery/pkg/util/httpstream"
    32  	"k8s.io/apimachinery/pkg/util/runtime"
    33  	netutils "k8s.io/utils/net"
    34  )
    35  
    36  // PortForwardProtocolV1Name is the subprotocol used for port forwarding.
    37  // TODO move to API machinery and re-unify with kubelet/server/portfoward
    38  const PortForwardProtocolV1Name = "portforward.k8s.io"
    39  
    40  var ErrLostConnectionToPod = errors.New("lost connection to pod")
    41  
    42  // PortForwarder knows how to listen for local connections and forward them to
    43  // a remote pod via an upgraded HTTP request.
    44  type PortForwarder struct {
    45  	addresses []listenAddress
    46  	ports     []ForwardedPort
    47  	stopChan  <-chan struct{}
    48  
    49  	dialer        httpstream.Dialer
    50  	streamConn    httpstream.Connection
    51  	listeners     []io.Closer
    52  	Ready         chan struct{}
    53  	requestIDLock sync.Mutex
    54  	requestID     int
    55  	out           io.Writer
    56  	errOut        io.Writer
    57  }
    58  
    59  // ForwardedPort contains a Local:Remote port pairing.
    60  type ForwardedPort struct {
    61  	Local  uint16
    62  	Remote uint16
    63  }
    64  
    65  /*
    66  valid port specifications:
    67  
    68  5000
    69  - forwards from localhost:5000 to pod:5000
    70  
    71  8888:5000
    72  - forwards from localhost:8888 to pod:5000
    73  
    74  0:5000
    75  :5000
    76    - selects a random available local port,
    77      forwards from localhost:<random port> to pod:5000
    78  */
    79  func parsePorts(ports []string) ([]ForwardedPort, error) {
    80  	var forwards []ForwardedPort
    81  	for _, portString := range ports {
    82  		parts := strings.Split(portString, ":")
    83  		var localString, remoteString string
    84  		if len(parts) == 1 {
    85  			localString = parts[0]
    86  			remoteString = parts[0]
    87  		} else if len(parts) == 2 {
    88  			localString = parts[0]
    89  			if localString == "" {
    90  				// support :5000
    91  				localString = "0"
    92  			}
    93  			remoteString = parts[1]
    94  		} else {
    95  			return nil, fmt.Errorf("invalid port format '%s'", portString)
    96  		}
    97  
    98  		localPort, err := strconv.ParseUint(localString, 10, 16)
    99  		if err != nil {
   100  			return nil, fmt.Errorf("error parsing local port '%s': %s", localString, err)
   101  		}
   102  
   103  		remotePort, err := strconv.ParseUint(remoteString, 10, 16)
   104  		if err != nil {
   105  			return nil, fmt.Errorf("error parsing remote port '%s': %s", remoteString, err)
   106  		}
   107  		if remotePort == 0 {
   108  			return nil, fmt.Errorf("remote port must be > 0")
   109  		}
   110  
   111  		forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
   112  	}
   113  
   114  	return forwards, nil
   115  }
   116  
   117  type listenAddress struct {
   118  	address     string
   119  	protocol    string
   120  	failureMode string
   121  }
   122  
   123  func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
   124  	var addresses []listenAddress
   125  	parsed := make(map[string]listenAddress)
   126  	for _, address := range addressesToParse {
   127  		if address == "localhost" {
   128  			if _, exists := parsed["127.0.0.1"]; !exists {
   129  				ip := listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"}
   130  				parsed[ip.address] = ip
   131  			}
   132  			if _, exists := parsed["::1"]; !exists {
   133  				ip := listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"}
   134  				parsed[ip.address] = ip
   135  			}
   136  		} else if netutils.ParseIPSloppy(address).To4() != nil {
   137  			parsed[address] = listenAddress{address: address, protocol: "tcp4", failureMode: "any"}
   138  		} else if netutils.ParseIPSloppy(address) != nil {
   139  			parsed[address] = listenAddress{address: address, protocol: "tcp6", failureMode: "any"}
   140  		} else {
   141  			return nil, fmt.Errorf("%s is not a valid IP", address)
   142  		}
   143  	}
   144  	addresses = make([]listenAddress, len(parsed))
   145  	id := 0
   146  	for _, v := range parsed {
   147  		addresses[id] = v
   148  		id++
   149  	}
   150  	// Sort addresses before returning to get a stable order
   151  	sort.Slice(addresses, func(i, j int) bool { return addresses[i].address < addresses[j].address })
   152  
   153  	return addresses, nil
   154  }
   155  
   156  // New creates a new PortForwarder with localhost listen addresses.
   157  func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
   158  	return NewOnAddresses(dialer, []string{"localhost"}, ports, stopChan, readyChan, out, errOut)
   159  }
   160  
   161  // NewOnAddresses creates a new PortForwarder with custom listen addresses.
   162  func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
   163  	if len(addresses) == 0 {
   164  		return nil, errors.New("you must specify at least 1 address")
   165  	}
   166  	parsedAddresses, err := parseAddresses(addresses)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	if len(ports) == 0 {
   171  		return nil, errors.New("you must specify at least 1 port")
   172  	}
   173  	parsedPorts, err := parsePorts(ports)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  	return &PortForwarder{
   178  		dialer:    dialer,
   179  		addresses: parsedAddresses,
   180  		ports:     parsedPorts,
   181  		stopChan:  stopChan,
   182  		Ready:     readyChan,
   183  		out:       out,
   184  		errOut:    errOut,
   185  	}, nil
   186  }
   187  
   188  // ForwardPorts formats and executes a port forwarding request. The connection will remain
   189  // open until stopChan is closed.
   190  func (pf *PortForwarder) ForwardPorts() error {
   191  	defer pf.Close()
   192  
   193  	var err error
   194  	var protocol string
   195  	pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
   196  	if err != nil {
   197  		return fmt.Errorf("error upgrading connection: %s", err)
   198  	}
   199  	defer pf.streamConn.Close()
   200  	if protocol != PortForwardProtocolV1Name {
   201  		return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
   202  	}
   203  
   204  	return pf.forward()
   205  }
   206  
   207  // forward dials the remote host specific in req, upgrades the request, starts
   208  // listeners for each port specified in ports, and forwards local connections
   209  // to the remote host via streams.
   210  func (pf *PortForwarder) forward() error {
   211  	var err error
   212  
   213  	listenSuccess := false
   214  	for i := range pf.ports {
   215  		port := &pf.ports[i]
   216  		err = pf.listenOnPort(port)
   217  		switch {
   218  		case err == nil:
   219  			listenSuccess = true
   220  		default:
   221  			if pf.errOut != nil {
   222  				fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
   223  			}
   224  		}
   225  	}
   226  
   227  	if !listenSuccess {
   228  		return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports)
   229  	}
   230  
   231  	if pf.Ready != nil {
   232  		close(pf.Ready)
   233  	}
   234  
   235  	// wait for interrupt or conn closure
   236  	select {
   237  	case <-pf.stopChan:
   238  	case <-pf.streamConn.CloseChan():
   239  		return ErrLostConnectionToPod
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  // listenOnPort delegates listener creation and waits for connections on requested bind addresses.
   246  // An error is raised based on address groups (default and localhost) and their failure modes
   247  func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
   248  	var errors []error
   249  	failCounters := make(map[string]int, 2)
   250  	successCounters := make(map[string]int, 2)
   251  	for _, addr := range pf.addresses {
   252  		err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address)
   253  		if err != nil {
   254  			errors = append(errors, err)
   255  			failCounters[addr.failureMode]++
   256  		} else {
   257  			successCounters[addr.failureMode]++
   258  		}
   259  	}
   260  	if successCounters["all"] == 0 && failCounters["all"] > 0 {
   261  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   262  	}
   263  	if failCounters["any"] > 0 {
   264  		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errors)
   265  	}
   266  	return nil
   267  }
   268  
   269  // listenOnPortAndAddress delegates listener creation and waits for new connections
   270  // in the background f
   271  func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
   272  	listener, err := pf.getListener(protocol, address, port)
   273  	if err != nil {
   274  		return err
   275  	}
   276  	pf.listeners = append(pf.listeners, listener)
   277  	go pf.waitForConnection(listener, *port)
   278  	return nil
   279  }
   280  
   281  // getListener creates a listener on the interface targeted by the given hostname on the given port with
   282  // the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6
   283  func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
   284  	listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
   285  	if err != nil {
   286  		return nil, fmt.Errorf("unable to create listener: Error %s", err)
   287  	}
   288  	listenerAddress := listener.Addr().String()
   289  	host, localPort, _ := net.SplitHostPort(listenerAddress)
   290  	localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
   291  
   292  	if err != nil {
   293  		fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, localPortUInt, port.Remote)
   294  		return nil, fmt.Errorf("error parsing local port: %s from %s (%s)", err, listenerAddress, host)
   295  	}
   296  	port.Local = uint16(localPortUInt)
   297  	if pf.out != nil {
   298  		fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
   299  	}
   300  
   301  	return listener, nil
   302  }
   303  
   304  // waitForConnection waits for new connections to listener and handles them in
   305  // the background.
   306  func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
   307  	for {
   308  		select {
   309  		case <-pf.streamConn.CloseChan():
   310  			return
   311  		default:
   312  			conn, err := listener.Accept()
   313  			if err != nil {
   314  				// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
   315  				if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
   316  					runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
   317  				}
   318  				return
   319  			}
   320  			go pf.handleConnection(conn, port)
   321  		}
   322  	}
   323  }
   324  
   325  func (pf *PortForwarder) nextRequestID() int {
   326  	pf.requestIDLock.Lock()
   327  	defer pf.requestIDLock.Unlock()
   328  	id := pf.requestID
   329  	pf.requestID++
   330  	return id
   331  }
   332  
   333  // handleConnection copies data between the local connection and the stream to
   334  // the remote server.
   335  func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
   336  	defer conn.Close()
   337  
   338  	if pf.out != nil {
   339  		fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
   340  	}
   341  
   342  	requestID := pf.nextRequestID()
   343  
   344  	// create error stream
   345  	headers := http.Header{}
   346  	headers.Set(v1.StreamType, v1.StreamTypeError)
   347  	headers.Set(v1.PortHeader, fmt.Sprintf("%d", port.Remote))
   348  	headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
   349  	errorStream, err := pf.streamConn.CreateStream(headers)
   350  	if err != nil {
   351  		runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
   352  		return
   353  	}
   354  	// we're not writing to this stream
   355  	errorStream.Close()
   356  	defer pf.streamConn.RemoveStreams(errorStream)
   357  
   358  	errorChan := make(chan error)
   359  	go func() {
   360  		message, err := io.ReadAll(errorStream)
   361  		switch {
   362  		case err != nil:
   363  			errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
   364  		case len(message) > 0:
   365  			errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
   366  		}
   367  		close(errorChan)
   368  	}()
   369  
   370  	// create data stream
   371  	headers.Set(v1.StreamType, v1.StreamTypeData)
   372  	dataStream, err := pf.streamConn.CreateStream(headers)
   373  	if err != nil {
   374  		runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
   375  		return
   376  	}
   377  	defer pf.streamConn.RemoveStreams(dataStream)
   378  
   379  	localError := make(chan struct{})
   380  	remoteDone := make(chan struct{})
   381  
   382  	go func() {
   383  		// Copy from the remote side to the local port.
   384  		if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
   385  			runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
   386  		}
   387  
   388  		// inform the select below that the remote copy is done
   389  		close(remoteDone)
   390  	}()
   391  
   392  	go func() {
   393  		// inform server we're not sending any more data after copy unblocks
   394  		defer dataStream.Close()
   395  
   396  		// Copy from the local port to the remote side.
   397  		if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
   398  			runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
   399  			// break out of the select below without waiting for the other copy to finish
   400  			close(localError)
   401  		}
   402  	}()
   403  
   404  	// wait for either a local->remote error or for copying from remote->local to finish
   405  	select {
   406  	case <-remoteDone:
   407  	case <-localError:
   408  	}
   409  
   410  	// always expect something on errorChan (it may be nil)
   411  	err = <-errorChan
   412  	if err != nil {
   413  		runtime.HandleError(err)
   414  		pf.streamConn.Close()
   415  	}
   416  }
   417  
   418  // Close stops all listeners of PortForwarder.
   419  func (pf *PortForwarder) Close() {
   420  	// stop all listeners
   421  	for _, l := range pf.listeners {
   422  		if err := l.Close(); err != nil {
   423  			runtime.HandleError(fmt.Errorf("error closing listener: %v", err))
   424  		}
   425  	}
   426  }
   427  
   428  // GetPorts will return the ports that were forwarded; this can be used to
   429  // retrieve the locally-bound port in cases where the input was port 0. This
   430  // function will signal an error if the Ready channel is nil or if the
   431  // listeners are not ready yet; this function will succeed after the Ready
   432  // channel has been closed.
   433  func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
   434  	if pf.Ready == nil {
   435  		return nil, fmt.Errorf("no Ready channel provided")
   436  	}
   437  	select {
   438  	case <-pf.Ready:
   439  		return pf.ports, nil
   440  	default:
   441  		return nil, fmt.Errorf("listeners not ready")
   442  	}
   443  }
   444  

View as plain text