1
16
17 package proxy
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "io"
24 "net"
25 "net/http"
26 "sync"
27 "time"
28
29 v1 "k8s.io/api/core/v1"
30 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
31 "k8s.io/apimachinery/pkg/runtime/schema"
32 "k8s.io/apimachinery/pkg/util/httpstream"
33 "k8s.io/client-go/kubernetes"
34 "k8s.io/client-go/kubernetes/scheme"
35 "k8s.io/client-go/rest"
36 "k8s.io/client-go/tools/portforward"
37 "k8s.io/client-go/transport/spdy"
38 "k8s.io/klog/v2"
39 )
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
const (
	// maxConcurrentConnections is the number of connection slots a listener
	// has; the listener holds at most this many port-forwarding connections
	// at a time.
	maxConcurrentConnections = 10

	// connectionPollInterval is how often the listener tries to fill a free
	// connection slot and how often Accept checks for a new connection.
	connectionPollInterval = 100 * time.Millisecond
)
66
67
68
69
70
71
72
73
74
75
76
77 func Listen(ctx context.Context, clientset kubernetes.Interface, restConfig *rest.Config, addr Addr) (net.Listener, error) {
78
79
80
81
82
83 restClient := clientset.CoreV1().RESTClient()
84 if restConfig.GroupVersion == nil {
85 restConfig.GroupVersion = &schema.GroupVersion{}
86 }
87 if restConfig.NegotiatedSerializer == nil {
88 restConfig.NegotiatedSerializer = scheme.Codecs
89 }
90
91
92
93 req := restClient.Post().
94 Resource("pods").
95 Namespace(addr.Namespace).
96 Name(addr.PodName).
97 SubResource("portforward")
98 transport, upgrader, err := spdy.RoundTripperFor(restConfig)
99 if err != nil {
100 return nil, fmt.Errorf("create round tripper: %w", err)
101 }
102 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())
103
104 prefix := fmt.Sprintf("port forwarding for %s", addr)
105 ctx, cancel := context.WithCancel(ctx)
106 l := &listener{
107 ctx: ctx,
108 cancel: cancel,
109 addr: addr,
110 }
111
112 var connectionsCreated int
113
114 runForwarding := func() {
115 klog.V(2).Infof("%s: starting connection polling", prefix)
116 defer klog.V(2).Infof("%s: connection polling ended", prefix)
117
118 tryConnect := time.NewTicker(connectionPollInterval)
119 defer tryConnect.Stop()
120 for {
121 select {
122 case <-ctx.Done():
123 return
124 case <-tryConnect.C:
125 func() {
126 l.mutex.Lock()
127 defer l.mutex.Unlock()
128
129 for i, c := range l.connections {
130 if c == nil {
131 klog.V(5).Infof("%s: trying to create a new connection #%d", prefix, connectionsCreated)
132 stream, err := dial(ctx, fmt.Sprintf("%s #%d", prefix, connectionsCreated), dialer, addr.Port)
133 if err != nil {
134 klog.Errorf("%s: no connection: %v", prefix, err)
135 return
136 }
137
138 klog.V(5).Infof("%s: created a new connection #%d", prefix, connectionsCreated)
139 c := &connection{
140 l: l,
141 stream: stream,
142 addr: addr,
143 counter: connectionsCreated,
144 }
145 l.connections[i] = c
146 connectionsCreated++
147 return
148 }
149 }
150 }()
151 }
152 }
153 }
154
155
156 go func() {
157 for {
158 running := false
159 pod, err := clientset.CoreV1().Pods(addr.Namespace).Get(ctx, addr.PodName, metav1.GetOptions{})
160 if err != nil {
161 klog.V(5).Infof("checking for container %q in pod %s/%s: %v", addr.ContainerName, addr.Namespace, addr.PodName, err)
162 }
163 for i, status := range pod.Status.ContainerStatuses {
164 if pod.Spec.Containers[i].Name == addr.ContainerName &&
165 status.State.Running != nil {
166 running = true
167 break
168 }
169 }
170
171 if running {
172 klog.V(2).Infof("container %q in pod %s/%s is running", addr.ContainerName, addr.Namespace, addr.PodName)
173 runForwarding()
174 }
175
176 select {
177 case <-ctx.Done():
178 return
179
180
181
182 case <-time.After(1 * time.Second):
183 }
184 }
185 }()
186
187 return l, nil
188 }
189
190
191
// Addr identifies a port-forwarding target: a container in a pod and the
// port to connect to. It implements net.Addr.
type Addr struct {
	Namespace     string
	PodName       string
	ContainerName string
	Port          int
}

var _ net.Addr = (*Addr)(nil)

// Network always returns "port-forwarding".
func (a Addr) Network() string { return "port-forwarding" }

// String formats the address as <namespace>/<pod>:<port>. The container
// name is not part of the result.
func (a Addr) String() string {
	return a.Namespace + "/" + a.PodName + ":" + fmt.Sprint(a.Port)
}
206
207 type stream struct {
208 httpstream.Stream
209 streamConn httpstream.Connection
210 }
211
212 func dial(ctx context.Context, prefix string, dialer httpstream.Dialer, port int) (s *stream, finalErr error) {
213 streamConn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name)
214 if err != nil {
215 return nil, fmt.Errorf("dialer failed: %w", err)
216 }
217 requestID := "1"
218 defer func() {
219 if finalErr != nil {
220 streamConn.Close()
221 }
222 }()
223
224
225 headers := http.Header{}
226 headers.Set(v1.StreamType, v1.StreamTypeError)
227 headers.Set(v1.PortHeader, fmt.Sprintf("%d", port))
228 headers.Set(v1.PortForwardRequestIDHeader, requestID)
229
230
231
232 errorStream, err := streamConn.CreateStream(headers)
233 if err != nil {
234 return nil, fmt.Errorf("error creating error stream: %w", err)
235 }
236 errorStream.Close()
237 go func() {
238 message, err := io.ReadAll(errorStream)
239 switch {
240 case err != nil:
241 klog.Errorf("%s: error reading from error stream: %v", prefix, err)
242 case len(message) > 0:
243 klog.Errorf("%s: an error occurred connecting to the remote port: %v", prefix, string(message))
244 }
245 }()
246
247
248 headers.Set(v1.StreamType, v1.StreamTypeData)
249 dataStream, err := streamConn.CreateStream(headers)
250 if err != nil {
251 return nil, fmt.Errorf("error creating data stream: %w", err)
252 }
253
254 return &stream{
255 Stream: dataStream,
256 streamConn: streamConn,
257 }, nil
258 }
259
260 func (s *stream) Close() {
261 s.Stream.Close()
262 s.streamConn.Close()
263 }
264
265 type listener struct {
266 addr Addr
267 ctx context.Context
268 cancel func()
269
270 mutex sync.Mutex
271 connections [maxConcurrentConnections]*connection
272 }
273
274 var _ net.Listener = &listener{}
275
276 func (l *listener) Close() error {
277 klog.V(5).Infof("forward listener for %s: closing", l.addr)
278 l.cancel()
279
280 l.mutex.Lock()
281 defer l.mutex.Unlock()
282 for _, c := range l.connections {
283 if c != nil {
284 c.stream.Close()
285 }
286 }
287
288 return nil
289 }
290
291 func (l *listener) Accept() (net.Conn, error) {
292 tryAccept := time.NewTicker(connectionPollInterval)
293 defer tryAccept.Stop()
294 for {
295 select {
296 case <-l.ctx.Done():
297 return nil, errors.New("listening was stopped")
298 case <-tryAccept.C:
299 conn := func() net.Conn {
300 l.mutex.Lock()
301 defer l.mutex.Unlock()
302
303 for _, c := range l.connections {
304 if c != nil && !c.accepted {
305 klog.V(5).Infof("forward listener for %s: got a new connection #%d", l.addr, c.counter)
306 c.accepted = true
307 return c
308 }
309 }
310 return nil
311 }()
312 if conn != nil {
313 return conn, nil
314 }
315 }
316 }
317 }
318
319 type connection struct {
320 l *listener
321 stream *stream
322 addr Addr
323 counter int
324 mutex sync.Mutex
325 accepted, closed bool
326 }
327
328 var _ net.Conn = &connection{}
329
330 func (c *connection) LocalAddr() net.Addr {
331 return c.addr
332 }
333
334 func (c *connection) RemoteAddr() net.Addr {
335 return c.addr
336 }
337
338 func (c *connection) SetDeadline(t time.Time) error {
339 return nil
340 }
341
342 func (c *connection) SetReadDeadline(t time.Time) error {
343 return nil
344 }
345
346 func (c *connection) SetWriteDeadline(t time.Time) error {
347 return nil
348 }
349
350 func (c *connection) Read(b []byte) (int, error) {
351 n, err := c.stream.Read(b)
352 if errors.Is(err, io.EOF) {
353 klog.V(5).Infof("forward connection #%d for %s: remote side closed the stream", c.counter, c.addr)
354 }
355 return n, err
356 }
357
358 func (c *connection) Write(b []byte) (int, error) {
359 n, err := c.stream.Write(b)
360 if errors.Is(err, io.EOF) {
361 klog.V(5).Infof("forward connection #%d for %s: remote side closed the stream", c.counter, c.addr)
362 }
363 return n, err
364 }
365
366 func (c *connection) Close() error {
367 c.mutex.Lock()
368 defer c.mutex.Unlock()
369 if !c.closed {
370
371 klog.V(5).Infof("forward connection #%d for %s: closing our side", c.counter, c.addr)
372
373 c.l.mutex.Lock()
374 defer c.l.mutex.Unlock()
375 for i, c2 := range c.l.connections {
376 if c2 == c {
377 c.l.connections[i] = nil
378 break
379 }
380 }
381 }
382 c.stream.Close()
383
384 return nil
385 }
386
387 func (l *listener) Addr() net.Addr {
388 return l.addr
389 }
390
View as plain text