1 package ktest
2
3 import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/url"
8 "os"
9 "os/signal"
10 "strings"
11 "testing"
12
13 "golang.org/x/sys/unix"
14 v1 "k8s.io/api/core/v1"
15 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
16 "k8s.io/apimachinery/pkg/labels"
17 "k8s.io/apimachinery/pkg/types"
18 "k8s.io/cli-runtime/pkg/genericclioptions"
19 "k8s.io/client-go/rest"
20 "k8s.io/client-go/tools/portforward"
21 "k8s.io/client-go/transport/spdy"
22 sigclient "sigs.k8s.io/controller-runtime/pkg/client"
23
24 corev1 "k8s.io/api/core/v1"
25
26 "edge-infra.dev/pkg/lib/fog"
27 "edge-infra.dev/test/f2"
28 )
29
30
// PortForward forwards a local port to a port on a pod, using the rest config
// of the K8s handle stored in the test context. Start forwarding with Forward,
// ForwardPod or ForwardPodCtx, then call Retrieve to get the local address.
type PortForward struct {
	// Namespace of the target pod (and, for Forward, of the Service). When
	// empty, the Forward* methods fill it in from the K8s handle's namespace.
	Namespace string
	// fw is the forwarder created by portForwardAPod; Retrieve reads the
	// bound local port from it.
	fw *portforward.PortForwarder
}
38
// portForwardRequest bundles the inputs to portForwardAPod.
type portForwardRequest struct {
	// RestConfig supplies the API server host and the credentials used to
	// build the SPDY transport.
	RestConfig *rest.Config
	// Pod identifies the target pod; only Name and Namespace are read.
	Pod v1.Pod
	// LocalPort is the local port to listen on. podPortForward passes 0 so the
	// forwarder picks a free port (read it back with Retrieve).
	LocalPort int
	// PodPort is the port on the pod to forward to.
	PodPort int
	// Streams receives the forwarder's output: Out and ErrOut are passed to
	// portforward.New; In is not used by portForwardAPod.
	Streams genericclioptions.IOStreams
	// StopCh is passed to portforward.New; closing it stops forwarding.
	StopCh <-chan struct{}
	// ReadyCh is passed to portforward.New and is closed by the forwarder once
	// it is ready to accept traffic (podPortForward waits on it).
	ReadyCh chan struct{}
}
56
57 func (pf *PortForward) portForwardAPod(req portForwardRequest) error {
58 path := fmt.Sprintf("/api/v1/namespaces/%s/pods/%s/portforward",
59 req.Pod.Namespace, req.Pod.Name)
60 hostIP := strings.TrimRight(strings.TrimPrefix(req.RestConfig.Host, "https://"), "/")
61 transport, upgrader, err := spdy.RoundTripperFor(req.RestConfig)
62 if err != nil {
63 return err
64 }
65
66 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, &url.URL{Scheme: "https", Path: path, Host: hostIP})
67 fw, err := portforward.New(dialer, []string{fmt.Sprintf("%d:%d", req.LocalPort, req.PodPort)}, req.StopCh, req.ReadyCh, req.Streams.Out, req.Streams.ErrOut)
68 if err != nil {
69 return err
70 }
71 pf.fw = fw
72 return fw.ForwardPorts()
73 }
74
75
76 func (pf *PortForward) ForwardPodCtx(ctx f2.Context, podName string, portnr int) error {
77 k, err := FromContext(ctx)
78 if err != nil {
79 return err
80 }
81 if pf.Namespace == "" {
82 pf.Namespace = k.Namespace
83 }
84 pf.podPortForward(ctx, k.Env.Config, podName, portnr)
85 return nil
86 }
87
88
89 func (pf *PortForward) ForwardPod(podName string, portnr int) f2.StepFn {
90 return func(ctx f2.Context, t *testing.T) f2.Context {
91 k := FromContextT(ctx, t)
92 if pf.Namespace == "" {
93 pf.Namespace = k.Namespace
94 }
95 pf.podPortForward(ctx, k.Env.Config, podName, portnr)
96 return ctx
97 }
98 }
99
100
101
102 func (pf *PortForward) Forward(serviceName string, portnr int) f2.StepFn {
103 return func(ctx f2.Context, t *testing.T) f2.Context {
104 k := FromContextT(ctx, t)
105 if pf.Namespace == "" {
106 pf.Namespace = k.Namespace
107 }
108 podName, err := lookupPod(ctx, pf.Namespace, serviceName, k)
109 if err != nil {
110 t.Fatal(err)
111 }
112 pf.podPortForward(ctx, k.Env.Config, podName, portnr)
113 return ctx
114 }
115 }
116
117 func (pf *PortForward) podPortForward(ctx context.Context, config *rest.Config, podName string, portnr int) {
118
119
120 stopCh := make(chan struct{}, 1)
121
122 readyCh := make(chan struct{})
123
124
125
126 stream := genericclioptions.IOStreams{
127 In: os.Stdin,
128 Out: os.Stdout,
129 ErrOut: os.Stderr,
130 }
131
132 sigs := make(chan os.Signal, 1)
133 signal.Notify(sigs, unix.SIGINT, unix.SIGTERM)
134 go func() {
135 <-sigs
136 close(stopCh)
137 }()
138 log := fog.FromContext(ctx)
139 go func() {
140 log.Info("FORWARDING", "namespace", pf.Namespace, "name", podName)
141 err := pf.portForwardAPod(portForwardRequest{
142 RestConfig: config,
143 Pod: v1.Pod{
144 ObjectMeta: metav1.ObjectMeta{
145 Name: podName,
146 Namespace: pf.Namespace,
147 },
148 },
149 LocalPort: 0,
150 PodPort: portnr,
151 Streams: stream,
152 StopCh: stopCh,
153 ReadyCh: readyCh,
154 })
155 if err != nil {
156 log.Error(err, "error forwarding pod", "namespace", pf.Namespace, "name", podName)
157 panic(err)
158 }
159 }()
160
161
162 <-readyCh
163
164 log.Info("FORWARDING: Port forwarding is ready to get traffic.")
165 }
166
167
168 func (pf PortForward) Retrieve(t *testing.T) string {
169 ports, err := pf.fw.GetPorts()
170 if err != nil {
171 t.Fatal(err)
172 }
173 if len(ports) == 0 {
174 t.Fatalf("No ports returned from GetPorts")
175 }
176 assignedLocalPort := ports[0].Local
177 return fmt.Sprintf("localhost:%d", assignedLocalPort)
178 }
179
180
181 func lookupPod(ctx f2.Context, namespace string, serviceName string, k *K8s) (string, error) {
182 service := &corev1.Service{}
183 err := k.Client.Get(ctx, types.NamespacedName{
184 Namespace: namespace,
185 Name: serviceName,
186 }, service)
187 if err != nil {
188 return "", err
189 }
190
191
192 podList := &corev1.PodList{}
193 err = k.Client.List(ctx, podList, &sigclient.ListOptions{
194 LabelSelector: labels.SelectorFromSet(service.Spec.Selector),
195 Namespace: namespace,
196 })
197 if err != nil {
198 return "", err
199 }
200
201 for _, pod := range podList.Items {
202 return pod.Name, nil
203 }
204 return "", fmt.Errorf("No pod found in %s", serviceName)
205 }
206
207
208
209
View as plain text