...

Source file src/k8s.io/utils/net/multi_listen.go

Documentation: k8s.io/utils/net

     1  /*
     2  Copyright 2024 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package net
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"net"
    23  	"sync"
    24  )
    25  
    26  // connErrPair pairs conn and error which is returned by accept on sub-listeners.
    27  type connErrPair struct {
    28  	conn net.Conn
    29  	err  error
    30  }
    31  
    32  // multiListener implements net.Listener
    33  type multiListener struct {
    34  	listeners []net.Listener
    35  	wg        sync.WaitGroup
    36  
    37  	// connCh passes accepted connections, from child listeners to parent.
    38  	connCh chan connErrPair
    39  	// stopCh communicates from parent to child listeners.
    40  	stopCh chan struct{}
    41  }
    42  
    43  // compile time check to ensure *multiListener implements net.Listener
    44  var _ net.Listener = &multiListener{}
    45  
    46  // MultiListen returns net.Listener which can listen on and accept connections for
    47  // the given network on multiple addresses. Internally it uses stdlib to create
    48  // sub-listener and multiplexes connection requests using go-routines.
    49  // The network must be "tcp", "tcp4" or "tcp6".
    50  // It follows the semantics of net.Listen that primarily means:
    51  //  1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen
    52  //     listens on all available unicast and anycast IP addresses of the local system.
    53  //  2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively.
    54  //  3. The host can accept names (e.g, localhost) and it will create a listener for at
    55  //     most one of the host's IP.
    56  func MultiListen(ctx context.Context, network string, addrs ...string) (net.Listener, error) {
    57  	var lc net.ListenConfig
    58  	return multiListen(
    59  		ctx,
    60  		network,
    61  		addrs,
    62  		func(ctx context.Context, network, address string) (net.Listener, error) {
    63  			return lc.Listen(ctx, network, address)
    64  		})
    65  }
    66  
    67  // multiListen implements MultiListen by consuming stdlib functions as dependency allowing
    68  // mocking for unit-testing.
    69  func multiListen(
    70  	ctx context.Context,
    71  	network string,
    72  	addrs []string,
    73  	listenFunc func(ctx context.Context, network, address string) (net.Listener, error),
    74  ) (net.Listener, error) {
    75  	if !(network == "tcp" || network == "tcp4" || network == "tcp6") {
    76  		return nil, fmt.Errorf("network %q not supported", network)
    77  	}
    78  	if len(addrs) == 0 {
    79  		return nil, fmt.Errorf("no address provided to listen on")
    80  	}
    81  
    82  	ml := &multiListener{
    83  		connCh: make(chan connErrPair),
    84  		stopCh: make(chan struct{}),
    85  	}
    86  	for _, addr := range addrs {
    87  		l, err := listenFunc(ctx, network, addr)
    88  		if err != nil {
    89  			// close all the sub-listeners and exit
    90  			_ = ml.Close()
    91  			return nil, err
    92  		}
    93  		ml.listeners = append(ml.listeners, l)
    94  	}
    95  
    96  	for _, l := range ml.listeners {
    97  		ml.wg.Add(1)
    98  		go func(l net.Listener) {
    99  			defer ml.wg.Done()
   100  			for {
   101  				// Accept() is blocking, unless ml.Close() is called, in which
   102  				// case it will return immediately with an error.
   103  				conn, err := l.Accept()
   104  				// This assumes that ANY error from Accept() will terminate the
   105  				// sub-listener. We could maybe be more precise, but it
   106  				// doesn't seem necessary.
   107  				terminate := err != nil
   108  
   109  				select {
   110  				case ml.connCh <- connErrPair{conn: conn, err: err}:
   111  				case <-ml.stopCh:
   112  					// In case we accepted a connection AND were stopped, and
   113  					// this select-case was chosen, just throw away the
   114  					// connection.  This avoids potentially blocking on connCh
   115  					// or leaking a connection.
   116  					if conn != nil {
   117  						_ = conn.Close()
   118  					}
   119  					terminate = true
   120  				}
   121  				// Make sure we don't loop on Accept() returning an error and
   122  				// the select choosing the channel case.
   123  				if terminate {
   124  					return
   125  				}
   126  			}
   127  		}(l)
   128  	}
   129  	return ml, nil
   130  }
   131  
   132  // Accept implements net.Listener. It waits for and returns a connection from
   133  // any of the sub-listener.
   134  func (ml *multiListener) Accept() (net.Conn, error) {
   135  	// wait for any sub-listener to enqueue an accepted connection
   136  	connErr, ok := <-ml.connCh
   137  	if !ok {
   138  		// The channel will be closed only when Close() is called on the
   139  		// multiListener. Closing of this channel implies that all
   140  		// sub-listeners are also closed, which causes a "use of closed
   141  		// network connection" error on their Accept() calls. We return the
   142  		// same error for multiListener.Accept() if multiListener.Close()
   143  		// has already been called.
   144  		return nil, fmt.Errorf("use of closed network connection")
   145  	}
   146  	return connErr.conn, connErr.err
   147  }
   148  
   149  // Close implements net.Listener. It will close all sub-listeners and wait for
   150  // the go-routines to exit.
   151  func (ml *multiListener) Close() error {
   152  	// Make sure this can be called repeatedly without explosions.
   153  	select {
   154  	case <-ml.stopCh:
   155  		return fmt.Errorf("use of closed network connection")
   156  	default:
   157  	}
   158  
   159  	// Tell all sub-listeners to stop.
   160  	close(ml.stopCh)
   161  
   162  	// Closing the listeners causes Accept() to immediately return an error in
   163  	// the sub-listener go-routines.
   164  	for _, l := range ml.listeners {
   165  		_ = l.Close()
   166  	}
   167  
   168  	// Wait for all the sub-listener go-routines to exit.
   169  	ml.wg.Wait()
   170  	close(ml.connCh)
   171  
   172  	// Drain any already-queued connections.
   173  	for connErr := range ml.connCh {
   174  		if connErr.conn != nil {
   175  			_ = connErr.conn.Close()
   176  		}
   177  	}
   178  	return nil
   179  }
   180  
   181  // Addr is an implementation of the net.Listener interface.  It always returns
   182  // the address of the first listener.  Callers should  use conn.LocalAddr() to
   183  // obtain the actual local address of the sub-listener.
   184  func (ml *multiListener) Addr() net.Addr {
   185  	return ml.listeners[0].Addr()
   186  }
   187  
   188  // Addrs is like Addr, but returns the address for all registered listeners.
   189  func (ml *multiListener) Addrs() []net.Addr {
   190  	var ret []net.Addr
   191  	for _, l := range ml.listeners {
   192  		ret = append(ret, l.Addr())
   193  	}
   194  	return ret
   195  }
   196  

View as plain text