1 package novnc
2
3 import (
4 "bytes"
5 "context"
6 "embed"
7 "errors"
8 "fmt"
9 "io/fs"
10 "net/http"
11 "net/url"
12 "os"
13 "testing"
14 "time"
15
16 "github.com/gorilla/websocket"
17 "github.com/stretchr/testify/assert"
18 "gotest.tools/v3/poll"
19 corev1 "k8s.io/api/core/v1"
20
21 "edge-infra.dev/pkg/k8s/testing/kmp"
22 "edge-infra.dev/pkg/sds/k8s/daemonsetdns/daemonsetdnstest"
23 "edge-infra.dev/test/f2"
24 "edge-infra.dev/test/f2/integration"
25 "edge-infra.dev/test/f2/x/ktest"
26 "edge-infra.dev/test/f2/x/ktest/envtest"
27 "edge-infra.dev/test/f2/x/ktest/kustomization"
28 )
29
30
31 var manifests embed.FS
32
33 var f f2.Framework
34
35 var novncManifests []byte
36
37 var vncserverManifests []byte
38
39 func TestMain(m *testing.M) {
40
41 f = f2.New(
42 context.Background(),
43 f2.WithExtensions(
44
45 ktest.New(
46 ktest.WithEnvtestOptions(
47
48 envtest.WithoutCRDs(),
49 ),
50 ),
51 ),
52 ).
53 Setup(func(ctx f2.Context) (f2.Context, error) {
54
55 if !integration.IsL2() {
56 return ctx, fmt.Errorf("%w: requires L2 integration test level", f2.ErrSkip)
57 }
58
59 return ctx, nil
60 }).
61 Setup(func(ctx f2.Context) (f2.Context, error) {
62
63 var err error
64 novncManifests, err = fs.ReadFile(manifests, "testdata/local_manifests.yaml")
65 if err != nil {
66 return ctx, err
67 }
68
69 vncserverManifests, err = fs.ReadFile(manifests, "testdata/kustomization_manifests.yaml")
70 if err != nil {
71 return ctx, err
72 }
73
74 return ctx, nil
75 }).
76 Setup(
77 daemonsetdnstest.LoadManifests,
78 daemonsetdnstest.Install,
79 )
80
81
82
83
84
85 os.Exit(f.Run(m))
86 }
87
88 func TestNovnc(t *testing.T) {
89
90
91
92
93
94
95 var (
96 addr string
97 nodenames []string
98 path = "/ws"
99 portforward = ktest.PortForward{}
100 queryTemplate = "token=%s"
101 sentMessage = []byte{'h', 'e', 'l', 'l', 'o'}
102 )
103
104 novnc := f2.NewFeature("Test NoVNC websocket proxy").
105 Setup("Discover nodenames", func(ctx f2.Context, t *testing.T) f2.Context {
106 k := ktest.FromContextT(ctx, t)
107
108 nodes := corev1.NodeList{}
109
110 err := k.Client.List(ctx, &nodes)
111 assert.NoError(t, err)
112
113 for _, node := range nodes.Items {
114 if node.Name == "edge-control-plane" {
115
116 continue
117 }
118 nodenames = append(nodenames, node.Name)
119 }
120
121 t.Logf("Discovered nodes: %s", nodenames)
122
123 return ctx
124 }).
125 Setup("Create wsserver mock vnc server", func(ctx f2.Context, t *testing.T) f2.Context {
126 k := ktest.FromContextT(ctx, t)
127
128 manifests, err := kustomization.ProcessManifests(ctx.RunID, vncserverManifests, k.Namespace)
129 assert.NoError(t, err)
130
131 for _, manifest := range manifests {
132 err = k.Client.Create(ctx, manifest)
133 assert.NoError(t, err)
134 }
135
136 return ctx
137 }).
138 Setup("Create novnc resources", func(ctx f2.Context, t *testing.T) f2.Context {
139 k := ktest.FromContextT(ctx, t)
140
141 manifests, err := kustomization.ProcessManifests(ctx.RunID, novncManifests, k.Namespace)
142 assert.NoError(t, err)
143
144 for _, manifest := range manifests {
145 if manifest.GetKind() == "Namespace" {
146 continue
147 }
148 err = k.Client.Create(ctx, manifest)
149 assert.NoError(t, err)
150 }
151
152 return ctx
153 }).
154 Setup("Wait for wsserver daemonset", func(ctx f2.Context, t *testing.T) f2.Context {
155 k := ktest.FromContextT(ctx, t)
156
157 manifests, err := kustomization.ProcessManifests(ctx.RunID, vncserverManifests, k.Namespace)
158 assert.NoError(t, err)
159
160 for _, manifest := range manifests {
161 if manifest.GetKind() != "DaemonSet" {
162 continue
163 }
164 k.WaitOn(t, k.Check(manifest, kmp.IsCurrent()))
165 }
166
167 return ctx
168 }).
169 Setup("Wait for novnc deployment", func(ctx f2.Context, t *testing.T) f2.Context {
170 k := ktest.FromContextT(ctx, t)
171
172 manifests, err := kustomization.ProcessManifests(ctx.RunID, novncManifests, k.Namespace)
173 assert.NoError(t, err)
174
175 for _, manifest := range manifests {
176 if manifest.GetKind() != "Deployment" {
177 continue
178 }
179 k.WaitOn(t, k.Check(manifest, kmp.IsCurrent()), poll.WithTimeout(300*time.Second))
180 }
181 return ctx
182 }).
183 Setup("Port forwarding", portforward.Forward("novnc", 80)).
184 Setup("Discover NoVNC address", func(ctx f2.Context, t *testing.T) f2.Context {
185 time.Sleep(1 * time.Second)
186 addr = portforward.Retrieve(t)
187 time.Sleep(4 * time.Second)
188 return ctx
189 }).
190 Test("Web Socket Upgrade", func(ctx f2.Context, t *testing.T) f2.Context {
191 for _, node := range nodenames {
192 query := fmt.Sprintf(queryTemplate, node)
193 c, resp, err := dialConnection(t, addr, path, query)
194 assert.NoError(t, err)
195 defer c.Close()
196
197 if resp.StatusCode != 101 {
198 t.Fatalf("Unexpected response status code %d", resp.StatusCode)
199 }
200
201 err = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
202 assert.NoError(t, err)
203 }
204 return ctx
205 }).
206 Test("Correct Node Connection", func(ctx f2.Context, t *testing.T) f2.Context {
207 for _, node := range nodenames {
208 query := fmt.Sprintf(queryTemplate, node)
209 c, resp, err := dialConnection(t, addr, path, query)
210 assert.NoError(t, err)
211 defer c.Close()
212
213 if resp.StatusCode != 101 {
214 t.Fatalf("Unexpected response status code %d", resp.StatusCode)
215 }
216
217 err = c.WriteMessage(websocket.TextMessage, sentMessage)
218 assert.NoError(t, err)
219
220 mt, message, err := c.ReadMessage()
221 assert.NoError(t, err)
222 if mt != websocket.TextMessage {
223 t.Fatalf("Unexpected message type returned: %v", mt)
224 }
225 t.Logf("Response message: %s", message)
226
227 if !bytes.Equal(message, []byte(node)) {
228 t.Fatalf("Unexpected message returned: %s", message)
229 }
230
231 err = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
232 assert.NoError(t, err)
233 }
234 return ctx
235 }).
236 Test("Missing Node Name", func(ctx f2.Context, t *testing.T) f2.Context {
237 query := fmt.Sprintf(queryTemplate, "non-existent-node")
238 _, resp, err := dialConnection(t, addr, path, query)
239
240
241
242 if resp.StatusCode != http.StatusBadGateway {
243 t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusBadGateway)
244 }
245 if !errors.Is(err, websocket.ErrBadHandshake) {
246 t.Fatalf("Expected an ErrBadHandshake response, got %v", err)
247 }
248 return ctx
249 }).
250 Test("Ilegal Name", func(ctx f2.Context, t *testing.T) f2.Context {
251 query := fmt.Sprintf(queryTemplate, "edge-worker.vncserver.vnc.pod-lookup./")
252 _, resp, err := dialConnection(t, addr, path, query)
253 if resp.StatusCode != 400 {
254 t.Fatalf("Unexpected status code: %d, expected 400", resp.StatusCode)
255 }
256 if !errors.Is(err, websocket.ErrBadHandshake) {
257 t.Fatalf("Expected an ErrBadHandshake response, got %v", err)
258 }
259 return ctx
260 }).
261 Feature()
262
263
264 f.Test(t, novnc)
265 }
266
267 func dialConnection(t *testing.T, addr string, path string, query string) (*websocket.Conn, *http.Response, error) {
268 u := url.URL{Scheme: "ws", Host: addr, Path: path, RawQuery: query}
269 t.Logf("connecting to %s", u.String())
270 c, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
271 return c, resp, err
272 }
273
View as plain text