...

Source file src/edge-infra.dev/cmd/sds/novnc/novnc_test.go

Documentation: edge-infra.dev/cmd/sds/novnc

     1  package novnc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"embed"
     7  	"errors"
     8  	"fmt"
     9  	"io/fs"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/gorilla/websocket"
    17  	"github.com/stretchr/testify/assert"
    18  	"gotest.tools/v3/poll"
    19  	corev1 "k8s.io/api/core/v1"
    20  
    21  	"edge-infra.dev/pkg/k8s/testing/kmp"
    22  	"edge-infra.dev/pkg/sds/k8s/daemonsetdns/daemonsetdnstest"
    23  	"edge-infra.dev/test/f2"
    24  	"edge-infra.dev/test/f2/integration"
    25  	"edge-infra.dev/test/f2/x/ktest"
    26  	"edge-infra.dev/test/f2/x/ktest/envtest"
    27  	"edge-infra.dev/test/f2/x/ktest/kustomization"
    28  )
    29  
    30  //go:embed testdata
    31  var manifests embed.FS
    32  
    33  var f f2.Framework
    34  
    35  var novncManifests []byte
    36  
    37  var vncserverManifests []byte
    38  
    39  func TestMain(m *testing.M) {
    40  	// Set up test framework in TestMain
    41  	f = f2.New(
    42  		context.Background(),
    43  		f2.WithExtensions(
    44  			// Include ktest extension for access to a k8s cluster
    45  			ktest.New(
    46  				ktest.WithEnvtestOptions(
    47  					// Do not install any CRD into cluster
    48  					envtest.WithoutCRDs(),
    49  				),
    50  			),
    51  		),
    52  	).
    53  		Setup(func(ctx f2.Context) (f2.Context, error) {
    54  			// Test execution should end here unless -integration-level=2 is passed to test
    55  			if !integration.IsL2() {
    56  				return ctx, fmt.Errorf("%w: requires L2 integration test level", f2.ErrSkip)
    57  			}
    58  
    59  			return ctx, nil
    60  		}).
    61  		Setup(func(ctx f2.Context) (f2.Context, error) {
    62  			// Load wsserver and novnc manifests
    63  			var err error
    64  			novncManifests, err = fs.ReadFile(manifests, "testdata/local_manifests.yaml")
    65  			if err != nil {
    66  				return ctx, err
    67  			}
    68  
    69  			vncserverManifests, err = fs.ReadFile(manifests, "testdata/kustomization_manifests.yaml")
    70  			if err != nil {
    71  				return ctx, err
    72  			}
    73  
    74  			return ctx, nil
    75  		}).
    76  		Setup(
    77  			daemonsetdnstest.LoadManifests,
    78  			daemonsetdnstest.Install,
    79  		)
    80  
    81  	// It's also possible to run functions once for every test
    82  	// f.BeforeEachTest()
    83  
    84  	// Run the tests
    85  	os.Exit(f.Run(m))
    86  }
    87  
    88  func TestNovnc(t *testing.T) {
    89  	// Tests are broken down into features, defined inside a golang testing Test
    90  	// definition.
    91  	// For k8s tests (using the ktest framework extension) each golang Test
    92  	// creates a new unique namespace for the tests to run in
    93  	// Each feature can define setup and teardown, and can have multiple test cases
    94  
    95  	var (
    96  		addr          string
    97  		nodenames     []string
    98  		path          = "/ws"
    99  		portforward   = ktest.PortForward{}
   100  		queryTemplate = "token=%s"                      // query string template
   101  		sentMessage   = []byte{'h', 'e', 'l', 'l', 'o'} // message for wsserver
   102  	)
   103  
   104  	novnc := f2.NewFeature("Test NoVNC websocket proxy").
   105  		Setup("Discover nodenames", func(ctx f2.Context, t *testing.T) f2.Context {
   106  			k := ktest.FromContextT(ctx, t)
   107  
   108  			nodes := corev1.NodeList{}
   109  
   110  			err := k.Client.List(ctx, &nodes)
   111  			assert.NoError(t, err)
   112  
   113  			for _, node := range nodes.Items {
   114  				if node.Name == "edge-control-plane" {
   115  					// TODO should we instead be checking the taints
   116  					continue
   117  				}
   118  				nodenames = append(nodenames, node.Name)
   119  			}
   120  
   121  			t.Logf("Discovered nodes: %s", nodenames)
   122  
   123  			return ctx
   124  		}).
   125  		Setup("Create wsserver mock vnc server", func(ctx f2.Context, t *testing.T) f2.Context {
   126  			k := ktest.FromContextT(ctx, t)
   127  
   128  			manifests, err := kustomization.ProcessManifests(ctx.RunID, vncserverManifests, k.Namespace)
   129  			assert.NoError(t, err)
   130  
   131  			for _, manifest := range manifests {
   132  				err = k.Client.Create(ctx, manifest)
   133  				assert.NoError(t, err)
   134  			}
   135  
   136  			return ctx
   137  		}).
   138  		Setup("Create novnc resources", func(ctx f2.Context, t *testing.T) f2.Context {
   139  			k := ktest.FromContextT(ctx, t)
   140  
   141  			manifests, err := kustomization.ProcessManifests(ctx.RunID, novncManifests, k.Namespace)
   142  			assert.NoError(t, err)
   143  
   144  			for _, manifest := range manifests {
   145  				if manifest.GetKind() == "Namespace" {
   146  					continue
   147  				}
   148  				err = k.Client.Create(ctx, manifest)
   149  				assert.NoError(t, err)
   150  			}
   151  
   152  			return ctx
   153  		}).
   154  		Setup("Wait for wsserver daemonset", func(ctx f2.Context, t *testing.T) f2.Context {
   155  			k := ktest.FromContextT(ctx, t)
   156  
   157  			manifests, err := kustomization.ProcessManifests(ctx.RunID, vncserverManifests, k.Namespace)
   158  			assert.NoError(t, err)
   159  
   160  			for _, manifest := range manifests {
   161  				if manifest.GetKind() != "DaemonSet" {
   162  					continue
   163  				}
   164  				k.WaitOn(t, k.Check(manifest, kmp.IsCurrent()))
   165  			}
   166  
   167  			return ctx
   168  		}).
   169  		Setup("Wait for novnc deployment", func(ctx f2.Context, t *testing.T) f2.Context {
   170  			k := ktest.FromContextT(ctx, t)
   171  
   172  			manifests, err := kustomization.ProcessManifests(ctx.RunID, novncManifests, k.Namespace)
   173  			assert.NoError(t, err)
   174  
   175  			for _, manifest := range manifests {
   176  				if manifest.GetKind() != "Deployment" {
   177  					continue
   178  				}
   179  				k.WaitOn(t, k.Check(manifest, kmp.IsCurrent()), poll.WithTimeout(300*time.Second))
   180  			}
   181  			return ctx
   182  		}).
   183  		Setup("Port forwarding", portforward.Forward("novnc", 80)).
   184  		Setup("Discover NoVNC address", func(ctx f2.Context, t *testing.T) f2.Context {
   185  			time.Sleep(1 * time.Second)
   186  			addr = portforward.Retrieve(t)
   187  			time.Sleep(4 * time.Second)
   188  			return ctx
   189  		}).
   190  		Test("Web Socket Upgrade", func(ctx f2.Context, t *testing.T) f2.Context {
   191  			for _, node := range nodenames {
   192  				query := fmt.Sprintf(queryTemplate, node)
   193  				c, resp, err := dialConnection(t, addr, path, query)
   194  				assert.NoError(t, err)
   195  				defer c.Close()
   196  
   197  				if resp.StatusCode != 101 {
   198  					t.Fatalf("Unexpected response status code %d", resp.StatusCode)
   199  				}
   200  
   201  				err = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
   202  				assert.NoError(t, err)
   203  			}
   204  			return ctx
   205  		}).
   206  		Test("Correct Node Connection", func(ctx f2.Context, t *testing.T) f2.Context {
   207  			for _, node := range nodenames {
   208  				query := fmt.Sprintf(queryTemplate, node)
   209  				c, resp, err := dialConnection(t, addr, path, query)
   210  				assert.NoError(t, err)
   211  				defer c.Close()
   212  
   213  				if resp.StatusCode != 101 {
   214  					t.Fatalf("Unexpected response status code %d", resp.StatusCode)
   215  				}
   216  
   217  				err = c.WriteMessage(websocket.TextMessage, sentMessage)
   218  				assert.NoError(t, err)
   219  
   220  				mt, message, err := c.ReadMessage()
   221  				assert.NoError(t, err)
   222  				if mt != websocket.TextMessage {
   223  					t.Fatalf("Unexpected message type returned: %v", mt)
   224  				}
   225  				t.Logf("Response message: %s", message)
   226  				// Expected response is name of connected node
   227  				if !bytes.Equal(message, []byte(node)) {
   228  					t.Fatalf("Unexpected message returned: %s", message)
   229  				}
   230  
   231  				err = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
   232  				assert.NoError(t, err)
   233  			}
   234  			return ctx
   235  		}).
   236  		Test("Missing Node Name", func(ctx f2.Context, t *testing.T) f2.Context {
   237  			query := fmt.Sprintf(queryTemplate, "non-existent-node")
   238  			_, resp, err := dialConnection(t, addr, path, query)
   239  
   240  			// Currently nginx returns a 502 BadGateway response as this is the
   241  			// response when it is unable to resolve a hostname
   242  			if resp.StatusCode != http.StatusBadGateway {
   243  				t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusBadGateway)
   244  			}
   245  			if !errors.Is(err, websocket.ErrBadHandshake) {
   246  				t.Fatalf("Expected an ErrBadHandshake response, got %v", err)
   247  			}
   248  			return ctx
   249  		}).
   250  		Test("Ilegal Name", func(ctx f2.Context, t *testing.T) f2.Context {
   251  			query := fmt.Sprintf(queryTemplate, "edge-worker.vncserver.vnc.pod-lookup./")
   252  			_, resp, err := dialConnection(t, addr, path, query)
   253  			if resp.StatusCode != 400 {
   254  				t.Fatalf("Unexpected status code: %d, expected 400", resp.StatusCode)
   255  			}
   256  			if !errors.Is(err, websocket.ErrBadHandshake) {
   257  				t.Fatalf("Expected an ErrBadHandshake response, got %v", err)
   258  			}
   259  			return ctx
   260  		}).
   261  		Feature()
   262  
   263  	// Run the tests
   264  	f.Test(t, novnc)
   265  }
   266  
   267  func dialConnection(t *testing.T, addr string, path string, query string) (*websocket.Conn, *http.Response, error) {
   268  	u := url.URL{Scheme: "ws", Host: addr, Path: path, RawQuery: query}
   269  	t.Logf("connecting to %s", u.String())
   270  	c, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
   271  	return c, resp, err
   272  }
   273  

View as plain text