1 package k8s
2
3 import (
4 "context"
5 "fmt"
6 "io"
7 "net"
8 "net/http"
9 "net/url"
10 "os"
11 "strconv"
12
13 log "github.com/sirupsen/logrus"
14 corev1 "k8s.io/api/core/v1"
15 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
16 "k8s.io/client-go/rest"
17 "k8s.io/client-go/tools/portforward"
18 "k8s.io/client-go/transport/spdy"
19
20
21 _ "k8s.io/client-go/plugin/pkg/client/auth"
22 )
23
24
25 type PortForward struct {
26 method string
27 url *url.URL
28 host string
29 namespace string
30 podName string
31 localPort int
32 remotePort int
33 emitLogs bool
34 stopCh chan struct{}
35 readyCh chan struct{}
36 config *rest.Config
37 }
38
39
40
41
42 func NewContainerMetricsForward(
43 k8sAPI *KubernetesAPI,
44 pod corev1.Pod,
45 container corev1.Container,
46 emitLogs bool,
47 portName string,
48 ) (*PortForward, error) {
49 var port corev1.ContainerPort
50 for _, p := range container.Ports {
51 if p.Name == portName {
52 port = p
53 break
54 }
55 }
56 if port.Name != portName {
57 return nil, fmt.Errorf("no %s port found for container %s/%s", portName, pod.GetName(), container.Name)
58 }
59
60 return NewPodPortForward(k8sAPI, pod.GetNamespace(), pod.GetName(), "localhost", 0, int(port.ContainerPort), emitLogs)
61 }
62
63
64
65
66
67
68
69
70 func NewPortForward(
71 ctx context.Context,
72 k8sAPI *KubernetesAPI,
73 namespace, deployName string,
74 host string, localPort, remotePort int,
75 emitLogs bool,
76 ) (*PortForward, error) {
77 timeoutSeconds := int64(30)
78 podList, err := k8sAPI.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{TimeoutSeconds: &timeoutSeconds})
79 if err != nil {
80 return nil, err
81 }
82
83 podName := ""
84 for _, pod := range podList.Items {
85 if pod.Status.Phase == corev1.PodRunning {
86 grandparent, err := getDeploymentForPod(ctx, k8sAPI, pod)
87 if err != nil {
88 log.Warnf("Failed to get deploy for pod [%s]: %s", pod.Name, err)
89 continue
90 }
91 if grandparent == deployName {
92 podName = pod.Name
93 break
94 }
95 }
96 }
97
98 if podName == "" {
99 return nil, fmt.Errorf("no running pods found for %s", deployName)
100 }
101
102 return NewPodPortForward(k8sAPI, namespace, podName, host, localPort, remotePort, emitLogs)
103 }
104
105 func getDeploymentForPod(ctx context.Context, k8sAPI *KubernetesAPI, pod corev1.Pod) (string, error) {
106 parents := pod.GetOwnerReferences()
107 if len(parents) != 1 {
108 return "", nil
109 }
110 rs, err := k8sAPI.AppsV1().ReplicaSets(pod.Namespace).Get(ctx, parents[0].Name, metav1.GetOptions{})
111 if err != nil {
112 return "", err
113 }
114 grandparents := rs.GetOwnerReferences()
115 if len(grandparents) != 1 {
116 return "", nil
117 }
118 return grandparents[0].Name, nil
119 }
120
121
122
123 func NewPodPortForward(
124 k8sAPI *KubernetesAPI,
125 namespace, podName string,
126 host string, localPort, remotePort int,
127 emitLogs bool,
128 ) (*PortForward, error) {
129
130 restClient := k8sAPI.CoreV1().RESTClient()
131
132
133
134 if fakeRest, ok := restClient.(*rest.RESTClient); ok {
135 if fakeRest == nil {
136 return nil, nil
137 }
138 }
139
140 req := restClient.Post().
141 Resource("pods").
142 Namespace(namespace).
143 Name(podName).
144 SubResource("portforward")
145
146 var err error
147 if localPort == 0 {
148 if host != "localhost" {
149 return nil, fmt.Errorf("local port must be specified when host is not localhost")
150 }
151
152 localPort, err = getEphemeralPort()
153 if err != nil {
154 return nil, err
155 }
156 }
157
158 return &PortForward{
159 method: "POST",
160 url: req.URL(),
161 host: host,
162 namespace: namespace,
163 podName: podName,
164 localPort: localPort,
165 remotePort: remotePort,
166 emitLogs: emitLogs,
167 stopCh: make(chan struct{}, 1),
168 readyCh: make(chan struct{}),
169 config: k8sAPI.Config,
170 }, nil
171 }
172
173
174
175 func (pf *PortForward) run() error {
176 transport, upgrader, err := spdy.RoundTripperFor(pf.config)
177 if err != nil {
178 return err
179 }
180
181 out := io.Discard
182 errOut := io.Discard
183 if pf.emitLogs {
184 out = os.Stdout
185 errOut = os.Stderr
186 }
187
188 ports := []string{fmt.Sprintf("%d:%d", pf.localPort, pf.remotePort)}
189 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, pf.method, pf.url)
190
191 fw, err := portforward.NewOnAddresses(dialer, []string{pf.host}, ports, pf.stopCh, pf.readyCh, out, errOut)
192 if err != nil {
193 return err
194 }
195
196 err = fw.ForwardPorts()
197 if err != nil {
198 err = fmt.Errorf("%w for %s/%s", err, pf.namespace, pf.podName)
199 return err
200 }
201 return nil
202 }
203
204
205
206
207 func (pf *PortForward) Init() error {
208 log.Debugf("Starting port forward to %s %d:%d", pf.url, pf.localPort, pf.remotePort)
209
210 failure := make(chan error, 1)
211
212 go func() {
213 if err := pf.run(); err != nil {
214 failure <- err
215 }
216 }()
217
218
219
220
221 select {
222 case <-pf.readyCh:
223 log.Debug("Port forward initialised")
224 case err := <-failure:
225 log.Debugf("Port forward failed: %v", err)
226 return err
227 }
228
229 return nil
230 }
231
232
233
234 func (pf *PortForward) Stop() {
235 close(pf.stopCh)
236 }
237
238
239
240 func (pf *PortForward) GetStop() <-chan struct{} {
241 return pf.stopCh
242 }
243
244
245 func (pf *PortForward) URLFor(path string) string {
246 strPort := strconv.Itoa(pf.localPort)
247 urlAddress := net.JoinHostPort(pf.host, strPort)
248 return fmt.Sprintf("http://%s%s", urlAddress, path)
249 }
250
251
252 func (pf *PortForward) AddressAndPort() string {
253 strPort := strconv.Itoa(pf.localPort)
254 return net.JoinHostPort(pf.host, strPort)
255 }
256
257
258
// getEphemeralPort asks the OS for a free TCP port on localhost by listening on
// port 0, then closes the listener and returns the assigned port. The port is
// released before returning, so it is only likely (not guaranteed) to still be
// free when the caller binds it.
func getEphemeralPort() (int, error) {
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}
	defer listener.Close()

	addr := listener.Addr()
	if tcp, isTCP := addr.(*net.TCPAddr); isTCP {
		return tcp.Port, nil
	}
	return 0, fmt.Errorf("invalid listen address: %s", addr)
}
273
View as plain text