package novnc

import (
	"bytes"
	"context"
	"embed"
	"errors"
	"fmt"
	"io/fs"
	"net/http"
	"net/url"
	"os"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/assert"
	"gotest.tools/v3/poll"
	corev1 "k8s.io/api/core/v1"

	"edge-infra.dev/pkg/k8s/testing/kmp"
	"edge-infra.dev/pkg/sds/k8s/daemonsetdns/daemonsetdnstest"
	"edge-infra.dev/test/f2"
	"edge-infra.dev/test/f2/integration"
	"edge-infra.dev/test/f2/x/ktest"
	"edge-infra.dev/test/f2/x/ktest/envtest"
	"edge-infra.dev/test/f2/x/ktest/kustomization"
)

// manifests holds the embedded testdata directory containing the manifests
// applied by the tests in this file.
//
//go:embed testdata
var manifests embed.FS

// f is the test framework shared by every test in this package. It is
// initialised in TestMain.
var f f2.Framework

// novncManifests holds the raw contents of testdata/local_manifests.yaml.
var novncManifests []byte

// vncserverManifests holds the raw contents of
// testdata/kustomization_manifests.yaml.
var vncserverManifests []byte

// TestMain builds the shared f2 framework, registers the framework-level
// setup steps and exits the process with the framework's result code.
func TestMain(m *testing.M) {
	// Set up test framework in TestMain
	f = f2.New(
		context.Background(),
		f2.WithExtensions(
			// Include ktest extension for access to a k8s cluster
			ktest.New(
				ktest.WithEnvtestOptions(
					// Do not install any CRD into cluster
					envtest.WithoutCRDs(),
				),
			),
		),
	).
		Setup(func(ctx f2.Context) (f2.Context, error) {
			// Test execution should end here unless -integration-level=2 is passed to test
			if !integration.IsL2() {
				return ctx, fmt.Errorf("%w: requires L2 integration test level", f2.ErrSkip)
			}
			return ctx, nil
		}).
		Setup(func(ctx f2.Context) (f2.Context, error) {
			// Load wsserver and novnc manifests
			var err error
			novncManifests, err = fs.ReadFile(manifests, "testdata/local_manifests.yaml")
			if err != nil {
				return ctx, err
			}
			vncserverManifests, err = fs.ReadFile(manifests, "testdata/kustomization_manifests.yaml")
			if err != nil {
				return ctx, err
			}
			return ctx, nil
		}).
		Setup(
			daemonsetdnstest.LoadManifests,
			daemonsetdnstest.Install,
		)

	// It's also possible to run functions once for every test
	// f.BeforeEachTest()

	// Run the tests
	os.Exit(f.Run(m))
}

// TestNovnc deploys the mock VNC websocket server and the novnc proxy, port
// forwards to the proxy and then checks that websocket connections are
// upgraded, reach the node named in the token query parameter, and are
// rejected for unknown or malformed node names.
func TestNovnc(t *testing.T) {
	// Tests are broken down into features, defined inside a golang testing Test
	// definition.
	// For k8s tests (using the ktest framework extension) each golang Test
	// creates a new unique namespace for the tests to run in
	// Each feature can define setup and teardown, and can have multiple test cases
	var (
		addr          string
		nodenames     []string
		path          = "/ws"
		portforward   = ktest.PortForward{}
		queryTemplate = "token=%s"      // query string template
		sentMessage   = []byte("hello") // message for wsserver
	)

	// The helpers below are closures rather than package-level functions so
	// they can read addr (populated by a later setup step) and cannot collide
	// with identifiers declared in other files of this package.

	// dialUpgraded dials the proxy with the given query string and returns the
	// open connection. It stops the test immediately when the dial fails,
	// because neither the connection nor the response can be used safely in
	// that case.
	dialUpgraded := func(t *testing.T, query string) *websocket.Conn {
		t.Helper()
		c, resp, err := dialConnection(t, addr, path, query)
		if err != nil {
			t.Fatalf("Unexpected error dialing websocket: %v", err)
		}
		if resp.StatusCode != http.StatusSwitchingProtocols {
			c.Close()
			t.Fatalf("Unexpected response status code %d", resp.StatusCode)
		}
		return c
	}

	// closeNormally sends a normal-closure close frame on c.
	closeNormally := func(t *testing.T, c *websocket.Conn) {
		t.Helper()
		err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
		assert.NoError(t, err)
	}

	// checkUpgrade verifies the websocket upgrade for a single node. Running
	// it as a function scopes the deferred Close to one node instead of
	// keeping every connection open until the whole test returns.
	checkUpgrade := func(t *testing.T, node string) {
		t.Helper()
		c := dialUpgraded(t, fmt.Sprintf(queryTemplate, node))
		defer c.Close()
		closeNormally(t, c)
	}

	// checkNodeConnection sends sentMessage through the proxy and verifies the
	// reply is the name of the node that was requested.
	checkNodeConnection := func(t *testing.T, node string) {
		t.Helper()
		c := dialUpgraded(t, fmt.Sprintf(queryTemplate, node))
		defer c.Close()

		if err := c.WriteMessage(websocket.TextMessage, sentMessage); err != nil {
			t.Fatalf("Unexpected error writing message: %v", err)
		}

		mt, message, err := c.ReadMessage()
		if err != nil {
			t.Fatalf("Unexpected error reading message: %v", err)
		}
		if mt != websocket.TextMessage {
			t.Fatalf("Unexpected message type returned: %v", mt)
		}
		t.Logf("Response message: %s", message)

		// Expected response is name of connected node
		if !bytes.Equal(message, []byte(node)) {
			t.Fatalf("Unexpected message returned: %s", message)
		}

		closeNormally(t, c)
	}

	// expectHandshakeRejected dials the proxy with the given token and
	// verifies that the handshake is refused with the wanted HTTP status.
	expectHandshakeRejected := func(t *testing.T, token string, wantStatus int) {
		t.Helper()
		c, resp, err := dialConnection(t, addr, path, fmt.Sprintf(queryTemplate, token))
		if c != nil {
			// The handshake unexpectedly succeeded; do not leak the socket.
			defer c.Close()
		}
		if resp == nil {
			// No HTTP response at all (e.g. the proxy was unreachable), so
			// there is no status code to inspect.
			t.Fatalf("No HTTP response received, dial error: %v", err)
		}
		if resp.StatusCode != wantStatus {
			t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, wantStatus)
		}
		if !errors.Is(err, websocket.ErrBadHandshake) {
			t.Fatalf("Expected an ErrBadHandshake response, got %v", err)
		}
	}

	novnc := f2.NewFeature("Test NoVNC websocket proxy").
		Setup("Discover nodenames", func(ctx f2.Context, t *testing.T) f2.Context {
			k := ktest.FromContextT(ctx, t)

			nodes := corev1.NodeList{}
			err := k.Client.List(ctx, &nodes)
			assert.NoError(t, err)

			for _, node := range nodes.Items {
				if node.Name == "edge-control-plane" { // TODO should we instead be checking the taints
					continue
				}
				nodenames = append(nodenames, node.Name)
			}
			t.Logf("Discovered nodes: %s", nodenames)
			return ctx
		}).
		Setup("Create wsserver mock vnc server", func(ctx f2.Context, t *testing.T) f2.Context {
			k := ktest.FromContextT(ctx, t)

			objs, err := kustomization.ProcessManifests(ctx.RunID, vncserverManifests, k.Namespace)
			assert.NoError(t, err)

			for _, obj := range objs {
				err = k.Client.Create(ctx, obj)
				assert.NoError(t, err)
			}
			return ctx
		}).
		Setup("Create novnc resources", func(ctx f2.Context, t *testing.T) f2.Context {
			k := ktest.FromContextT(ctx, t)

			objs, err := kustomization.ProcessManifests(ctx.RunID, novncManifests, k.Namespace)
			assert.NoError(t, err)

			for _, obj := range objs {
				// Namespace objects in the manifests are not created.
				if obj.GetKind() == "Namespace" {
					continue
				}
				err = k.Client.Create(ctx, obj)
				assert.NoError(t, err)
			}
			return ctx
		}).
		Setup("Wait for wsserver daemonset", func(ctx f2.Context, t *testing.T) f2.Context {
			k := ktest.FromContextT(ctx, t)

			objs, err := kustomization.ProcessManifests(ctx.RunID, vncserverManifests, k.Namespace)
			assert.NoError(t, err)

			for _, obj := range objs {
				if obj.GetKind() != "DaemonSet" {
					continue
				}
				k.WaitOn(t, k.Check(obj, kmp.IsCurrent()))
			}
			return ctx
		}).
		Setup("Wait for novnc deployment", func(ctx f2.Context, t *testing.T) f2.Context {
			k := ktest.FromContextT(ctx, t)

			objs, err := kustomization.ProcessManifests(ctx.RunID, novncManifests, k.Namespace)
			assert.NoError(t, err)

			for _, obj := range objs {
				if obj.GetKind() != "Deployment" {
					continue
				}
				k.WaitOn(t, k.Check(obj, kmp.IsCurrent()), poll.WithTimeout(300*time.Second))
			}
			return ctx
		}).
		Setup("Port forwarding", portforward.Forward("novnc", 80)).
		Setup("Discover NoVNC address", func(ctx f2.Context, t *testing.T) f2.Context {
			time.Sleep(1 * time.Second)
			addr = portforward.Retrieve(t)
			time.Sleep(4 * time.Second)
			return ctx
		}).
		Test("Web Socket Upgrade", func(ctx f2.Context, t *testing.T) f2.Context {
			for _, node := range nodenames {
				checkUpgrade(t, node)
			}
			return ctx
		}).
		Test("Correct Node Connection", func(ctx f2.Context, t *testing.T) f2.Context {
			for _, node := range nodenames {
				checkNodeConnection(t, node)
			}
			return ctx
		}).
		Test("Missing Node Name", func(ctx f2.Context, t *testing.T) f2.Context {
			// Currently nginx returns a 502 BadGateway response as this is the
			// response when it is unable to resolve a hostname
			expectHandshakeRejected(t, "non-existent-node", http.StatusBadGateway)
			return ctx
		}).
		Test("Ilegal Name", func(ctx f2.Context, t *testing.T) f2.Context {
			expectHandshakeRejected(t, "edge-worker.vncserver.vnc.pod-lookup./", http.StatusBadRequest)
			return ctx
		}).
		Feature()

	// Run the tests
	f.Test(t, novnc)
}

// dialConnection opens a websocket connection to ws://addr/path?query using
// the default gorilla dialer and logs the URL being dialled.
//
// The dialer's results are returned unchanged, so callers must check err
// before using the connection or the response; either may be nil when the
// dial fails.
func dialConnection(t *testing.T, addr string, path string, query string) (*websocket.Conn, *http.Response, error) {
	u := url.URL{Scheme: "ws", Host: addr, Path: path, RawQuery: query}
	t.Logf("connecting to %s", u.String())

	return websocket.DefaultDialer.Dial(u.String(), nil)
}