1 /* 2 Copyright 2024 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package net 18 19 import ( 20 "context" 21 "fmt" 22 "net" 23 "sync" 24 ) 25 26 // connErrPair pairs conn and error which is returned by accept on sub-listeners. 27 type connErrPair struct { 28 conn net.Conn 29 err error 30 } 31 32 // multiListener implements net.Listener 33 type multiListener struct { 34 listeners []net.Listener 35 wg sync.WaitGroup 36 37 // connCh passes accepted connections, from child listeners to parent. 38 connCh chan connErrPair 39 // stopCh communicates from parent to child listeners. 40 stopCh chan struct{} 41 } 42 43 // compile time check to ensure *multiListener implements net.Listener 44 var _ net.Listener = &multiListener{} 45 46 // MultiListen returns net.Listener which can listen on and accept connections for 47 // the given network on multiple addresses. Internally it uses stdlib to create 48 // sub-listener and multiplexes connection requests using go-routines. 49 // The network must be "tcp", "tcp4" or "tcp6". 50 // It follows the semantics of net.Listen that primarily means: 51 // 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen 52 // listens on all available unicast and anycast IP addresses of the local system. 53 // 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively. 54 // 3. The host can accept names (e.g, localhost) and it will create a listener for at 55 // most one of the host's IP. 56 func MultiListen(ctx context.Context, network string, addrs ...string) (net.Listener, error) { 57 var lc net.ListenConfig 58 return multiListen( 59 ctx, 60 network, 61 addrs, 62 func(ctx context.Context, network, address string) (net.Listener, error) { 63 return lc.Listen(ctx, network, address) 64 }) 65 } 66 67 // multiListen implements MultiListen by consuming stdlib functions as dependency allowing 68 // mocking for unit-testing. 69 func multiListen( 70 ctx context.Context, 71 network string, 72 addrs []string, 73 listenFunc func(ctx context.Context, network, address string) (net.Listener, error), 74 ) (net.Listener, error) { 75 if !(network == "tcp" || network == "tcp4" || network == "tcp6") { 76 return nil, fmt.Errorf("network %q not supported", network) 77 } 78 if len(addrs) == 0 { 79 return nil, fmt.Errorf("no address provided to listen on") 80 } 81 82 ml := &multiListener{ 83 connCh: make(chan connErrPair), 84 stopCh: make(chan struct{}), 85 } 86 for _, addr := range addrs { 87 l, err := listenFunc(ctx, network, addr) 88 if err != nil { 89 // close all the sub-listeners and exit 90 _ = ml.Close() 91 return nil, err 92 } 93 ml.listeners = append(ml.listeners, l) 94 } 95 96 for _, l := range ml.listeners { 97 ml.wg.Add(1) 98 go func(l net.Listener) { 99 defer ml.wg.Done() 100 for { 101 // Accept() is blocking, unless ml.Close() is called, in which 102 // case it will return immediately with an error. 103 conn, err := l.Accept() 104 // This assumes that ANY error from Accept() will terminate the 105 // sub-listener. We could maybe be more precise, but it 106 // doesn't seem necessary. 107 terminate := err != nil 108 109 select { 110 case ml.connCh <- connErrPair{conn: conn, err: err}: 111 case <-ml.stopCh: 112 // In case we accepted a connection AND were stopped, and 113 // this select-case was chosen, just throw away the 114 // connection. This avoids potentially blocking on connCh 115 // or leaking a connection. 116 if conn != nil { 117 _ = conn.Close() 118 } 119 terminate = true 120 } 121 // Make sure we don't loop on Accept() returning an error and 122 // the select choosing the channel case. 123 if terminate { 124 return 125 } 126 } 127 }(l) 128 } 129 return ml, nil 130 } 131 132 // Accept implements net.Listener. It waits for and returns a connection from 133 // any of the sub-listener. 134 func (ml *multiListener) Accept() (net.Conn, error) { 135 // wait for any sub-listener to enqueue an accepted connection 136 connErr, ok := <-ml.connCh 137 if !ok { 138 // The channel will be closed only when Close() is called on the 139 // multiListener. Closing of this channel implies that all 140 // sub-listeners are also closed, which causes a "use of closed 141 // network connection" error on their Accept() calls. We return the 142 // same error for multiListener.Accept() if multiListener.Close() 143 // has already been called. 144 return nil, fmt.Errorf("use of closed network connection") 145 } 146 return connErr.conn, connErr.err 147 } 148 149 // Close implements net.Listener. It will close all sub-listeners and wait for 150 // the go-routines to exit. 151 func (ml *multiListener) Close() error { 152 // Make sure this can be called repeatedly without explosions. 153 select { 154 case <-ml.stopCh: 155 return fmt.Errorf("use of closed network connection") 156 default: 157 } 158 159 // Tell all sub-listeners to stop. 160 close(ml.stopCh) 161 162 // Closing the listeners causes Accept() to immediately return an error in 163 // the sub-listener go-routines. 164 for _, l := range ml.listeners { 165 _ = l.Close() 166 } 167 168 // Wait for all the sub-listener go-routines to exit. 169 ml.wg.Wait() 170 close(ml.connCh) 171 172 // Drain any already-queued connections. 173 for connErr := range ml.connCh { 174 if connErr.conn != nil { 175 _ = connErr.conn.Close() 176 } 177 } 178 return nil 179 } 180 181 // Addr is an implementation of the net.Listener interface. It always returns 182 // the address of the first listener. Callers should use conn.LocalAddr() to 183 // obtain the actual local address of the sub-listener. 184 func (ml *multiListener) Addr() net.Addr { 185 return ml.listeners[0].Addr() 186 } 187 188 // Addrs is like Addr, but returns the address for all registered listeners. 189 func (ml *multiListener) Addrs() []net.Addr { 190 var ret []net.Addr 191 for _, l := range ml.listeners { 192 ret = append(ret, l.Addr()) 193 } 194 return ret 195 } 196