1
16
17 package pod
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "io"
24 "net"
25 "net/http"
26 "regexp"
27 "strconv"
28 "time"
29
30 v1 "k8s.io/api/core/v1"
31 "k8s.io/apimachinery/pkg/runtime/schema"
32 "k8s.io/apimachinery/pkg/util/httpstream"
33 "k8s.io/client-go/kubernetes"
34 "k8s.io/client-go/kubernetes/scheme"
35 "k8s.io/client-go/rest"
36 "k8s.io/client-go/tools/portforward"
37 "k8s.io/client-go/transport/spdy"
38 "k8s.io/klog/v2"
39 )
40
41
42
43 func NewTransport(client kubernetes.Interface, restConfig *rest.Config) *http.Transport {
44 return &http.Transport{
45 DialContext: func(ctx context.Context, _, addr string) (net.Conn, error) {
46 dialer := NewDialer(client, restConfig)
47 a, err := ParseAddr(addr)
48 if err != nil {
49 return nil, err
50 }
51 return dialer.DialContainerPort(ctx, *a)
52 },
53 }
54 }
55
56
57 func NewDialer(client kubernetes.Interface, restConfig *rest.Config) *Dialer {
58 return &Dialer{
59 client: client,
60 restConfig: restConfig,
61 }
62 }
63
64
65 type Dialer struct {
66 client kubernetes.Interface
67 restConfig *rest.Config
68 }
69
70
71 func (d *Dialer) DialContainerPort(ctx context.Context, addr Addr) (conn net.Conn, finalErr error) {
72 restClient := d.client.CoreV1().RESTClient()
73 restConfig := d.restConfig
74 if restConfig.GroupVersion == nil {
75 restConfig.GroupVersion = &schema.GroupVersion{}
76 }
77 if restConfig.NegotiatedSerializer == nil {
78 restConfig.NegotiatedSerializer = scheme.Codecs
79 }
80
81
82
83 req := restClient.Post().
84 Resource("pods").
85 Namespace(addr.Namespace).
86 Name(addr.PodName).
87 SubResource("portforward")
88 transport, upgrader, err := spdy.RoundTripperFor(restConfig)
89 if err != nil {
90 return nil, fmt.Errorf("create round tripper: %w", err)
91 }
92 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())
93
94 streamConn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name)
95 if err != nil {
96 return nil, fmt.Errorf("dialer failed: %w", err)
97 }
98 requestID := "1"
99 defer func() {
100 if finalErr != nil {
101 streamConn.Close()
102 }
103 }()
104
105
106 headers := http.Header{}
107 headers.Set(v1.StreamType, v1.StreamTypeError)
108 headers.Set(v1.PortHeader, fmt.Sprintf("%d", addr.Port))
109 headers.Set(v1.PortForwardRequestIDHeader, requestID)
110
111
112
113 errorStream, err := streamConn.CreateStream(headers)
114 if err != nil {
115 return nil, fmt.Errorf("error creating error stream: %w", err)
116 }
117 errorStream.Close()
118 go func() {
119 message, err := io.ReadAll(errorStream)
120 switch {
121 case err != nil:
122 klog.ErrorS(err, "error reading from error stream")
123 case len(message) > 0:
124 klog.ErrorS(errors.New(string(message)), "an error occurred connecting to the remote port")
125 }
126 }()
127
128
129 headers.Set(v1.StreamType, v1.StreamTypeData)
130 dataStream, err := streamConn.CreateStream(headers)
131 if err != nil {
132 return nil, fmt.Errorf("error creating data stream: %w", err)
133 }
134
135 return &stream{
136 Stream: dataStream,
137 streamConn: streamConn,
138 }, nil
139 }
140
141
142
143
// Addr identifies a container port of a pod. It implements net.Addr; its
// String form "<namespace>.<pod>:<port>" is the format ParseAddr accepts.
type Addr struct {
	Namespace, PodName string
	Port               int
}

var _ net.Addr = Addr{}

// Network returns the pseudo network name "port-forwarding".
func (a Addr) Network() string {
	return "port-forwarding"
}

// String formats a as "<namespace>.<pod>:<port>".
func (a Addr) String() string {
	return fmt.Sprintf("%s.%s:%d", a.Namespace, a.PodName, a.Port)
}

// ParseAddr parses an address of the form "<namespace>.<pod>:<port>", the
// form produced by Addr.String. The namespace is everything before the first
// '.', so the pod name may itself contain dots but no ':'. The port must be a
// decimal number in the range 1-65535.
func ParseAddr(addr string) (*Addr, error) {
	parts := addrRegex.FindStringSubmatch(addr)
	if parts == nil {
		return nil, fmt.Errorf("%q: must match the format <namespace>.<pod>:<port number>", addr)
	}
	// The regex guarantees digits only, but the value can still overflow int
	// (Atoi then returns an error plus a clamped value) or lie outside the
	// valid TCP port range.
	port, err := strconv.Atoi(parts[3])
	if err != nil || port < 1 || port > 65535 {
		return nil, fmt.Errorf("%q: port must be in the range 1-65535", addr)
	}
	return &Addr{
		Namespace: parts[1],
		PodName:   parts[2],
		Port:      port,
	}, nil
}

// addrRegex captures the namespace (no dots), the pod name (no colons) and
// the port digits.
var addrRegex = regexp.MustCompile(`^([^\.]+)\.([^:]+):(\d+)$`)
175
176 type stream struct {
177 addr Addr
178 httpstream.Stream
179 streamConn httpstream.Connection
180 }
181
182 var _ net.Conn = &stream{}
183
184 func (s *stream) Close() error {
185 s.Stream.Close()
186 s.streamConn.Close()
187 return nil
188 }
189
190 func (s *stream) LocalAddr() net.Addr {
191 return LocalAddr{}
192 }
193
194 func (s *stream) RemoteAddr() net.Addr {
195 return s.addr
196 }
197
198 func (s *stream) SetDeadline(t time.Time) error {
199 return nil
200 }
201
202 func (s *stream) SetReadDeadline(t time.Time) error {
203 return nil
204 }
205
206 func (s *stream) SetWriteDeadline(t time.Time) error {
207 return nil
208 }
209
// LocalAddr is the net.Addr reported as the local end of port-forward
// connections. It carries no data.
type LocalAddr struct{}

var _ net.Addr = LocalAddr{}

// Network returns "port-forwarding".
func (LocalAddr) Network() string {
	const network = "port-forwarding"
	return network
}

// String returns "apiserver".
func (LocalAddr) String() string {
	const name = "apiserver"
	return name
}
216
View as plain text