...

Source file src/github.com/Microsoft/hcsshim/internal/gcs/guestconnection_test.go

Documentation: github.com/Microsoft/hcsshim/internal/gcs

     1  //go:build windows
     2  
     3  package gcs
     4  
     5  import (
     6  	"context"
     7  	"encoding/base64"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/Microsoft/go-winio"
    18  	"github.com/Microsoft/go-winio/pkg/guid"
    19  	"github.com/sirupsen/logrus"
    20  	"go.opencensus.io/trace"
    21  	"go.opencensus.io/trace/tracestate"
    22  
    23  	"github.com/Microsoft/hcsshim/internal/oc"
    24  )
    25  
    26  const pipePortFmt = `\\.\pipe\gctest-port-%d`
    27  
    28  func npipeIoListen(port uint32) (net.Listener, error) {
    29  	return winio.ListenPipe(fmt.Sprintf(pipePortFmt, port), &winio.PipeConfig{
    30  		MessageMode: true,
    31  	})
    32  }
    33  
    34  func dialPort(port uint32) (net.Conn, error) {
    35  	return winio.DialPipe(fmt.Sprintf(pipePortFmt, port), nil)
    36  }
    37  
    38  func simpleGcs(t *testing.T, rwc io.ReadWriteCloser) {
    39  	t.Helper()
    40  	defer rwc.Close()
    41  	err := simpleGcsLoop(t, rwc)
    42  	if err != nil {
    43  		t.Error(err)
    44  	}
    45  }
    46  
// simpleGcsLoop is a minimal fake GCS. It reads bridge messages from rw and
// answers the few RPCs exercised by the tests in this file, until the stream
// ends or an error occurs.
//
// Per-RPC behavior:
//   - rpcNegotiateProtocol: replies with protocolVersion and a "linux"
//     RuntimeOsType capability.
//   - rpcCreate: replies with an empty containerCreateResponse.
//   - rpcExecuteProcess: dials (via dialPort) each stdio relay port the
//     process parameters ask for. When both stdin and stdout were requested,
//     it copies stdin to stdout on a goroutine (an echo). It replies with
//     ProcessID 42.
//   - rpcWaitForProcess: sends no reply at all.
//   - rpcShutdownForced: replies, sleeps 50ms, then sends a container
//     notification (message ID 0) carrying the request's ContainerID.
//
// Return value:
//   - nil when readMessage reports io.EOF or io.ErrClosedPipe.
//   - otherwise the first read, decode, dial, or send error.
//   - an "unsupported msg" error for any other RPC.
func simpleGcsLoop(t *testing.T, rw io.ReadWriter) error {
	t.Helper()
	for {
		id, typ, b, err := readMessage(rw)
		if err != nil {
			// A closed stream is the normal way for the fake to stop.
			// NOTE(review): these are compared with ==, so a wrapped EOF or
			// ErrClosedPipe would be surfaced as an error; consider errors.Is.
			if err == io.EOF || err == io.ErrClosedPipe {
				err = nil
			}
			return err
		}
		// Clear the request bit from the message type to recover the RPC.
		switch proc := rpcProc(typ &^ msgTypeRequest); proc {
		case rpcNegotiateProtocol:
			err := sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &negotiateProtocolResponse{
				Version: protocolVersion,
				Capabilities: gcsCapabilities{
					RuntimeOsType: "linux",
				},
			})
			if err != nil {
				return err
			}
		case rpcCreate:
			err := sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &containerCreateResponse{})
			if err != nil {
				return err
			}
		case rpcExecuteProcess:
			var req containerExecuteProcess
			var params baseProcessParams
			// Point the process-parameters wrapper at params before decoding.
			// The code below reads params, so decoding is expected to fill it.
			// Presumably the wrapper decodes into Value; confirm.
			req.Settings.ProcessParameters.Value = &params
			err := json.Unmarshal(b, &req)
			if err != nil {
				return err
			}
			var stdin, stdout, stderr net.Conn
			// NOTE: the defers below run when simpleGcsLoop returns, not when
			// this request finishes. Unless the echo goroutine closes them
			// first, the dialed connections stay open for the fake's lifetime.
			if params.CreateStdInPipe {
				stdin, err = dialPort(req.Settings.VsockStdioRelaySettings.StdIn)
				if err != nil {
					return err
				}
				defer stdin.Close()
			}
			if params.CreateStdOutPipe {
				stdout, err = dialPort(req.Settings.VsockStdioRelaySettings.StdOut)
				if err != nil {
					return err
				}
				defer stdout.Close()
			}
			if params.CreateStdErrPipe {
				stderr, err = dialPort(req.Settings.VsockStdioRelaySettings.StdErr)
				if err != nil {
					return err
				}
				defer stderr.Close()
			}
			if stdin != nil && stdout != nil {
				// Echo everything written to stdin back out on stdout, then
				// close both ends.
				// NOTE(review): t.Error here runs on a background goroutine
				// and may fire after the test has finished; confirm this
				// cannot happen in practice.
				go func() {
					_, err := io.Copy(stdout, stdin)
					if err != nil {
						t.Error(err)
					}
					stdin.Close()
					stdout.Close()
				}()
			}
			err = sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &containerExecuteProcessResponse{
				ProcessID: 42,
			})
			if err != nil {
				return err
			}
		case rpcWaitForProcess:
			// nothing: deliberately no response is sent for a process wait.
		case rpcShutdownForced:
			var req requestBase
			err = json.Unmarshal(b, &req)
			if err != nil {
				return err
			}
			err = sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &responseBase{})
			if err != nil {
				return err
			}
			// NOTE(review): presumably delays the notification so the host
			// handles the response first; confirm.
			time.Sleep(50 * time.Millisecond)
			// Notifications are unsolicited, so they carry message ID 0.
			err = sendJSON(t, rw, msgType(msgTypeNotify|notifyContainer), 0, &containerNotification{
				requestBase: requestBase{
					ContainerID: req.ContainerID,
				},
			})
			if err != nil {
				return err
			}
		default:
			return fmt.Errorf("unsupported msg %s", typ)
		}
	}
}
   145  
   146  func connectGcs(ctx context.Context, t *testing.T) *GuestConnection {
   147  	t.Helper()
   148  	s, c := pipeConn()
   149  	if ctx != context.Background() && ctx != context.TODO() {
   150  		go func() {
   151  			<-ctx.Done()
   152  			c.Close()
   153  		}()
   154  	}
   155  	go simpleGcs(t, c)
   156  	gcc := &GuestConnectionConfig{
   157  		Conn:     s,
   158  		Log:      logrus.NewEntry(logrus.StandardLogger()),
   159  		IoListen: npipeIoListen,
   160  	}
   161  	gc, err := gcc.Connect(context.Background(), true)
   162  	if err != nil {
   163  		c.Close()
   164  		t.Fatal(err)
   165  	}
   166  	return gc
   167  }
   168  
   169  func TestGcsConnect(t *testing.T) {
   170  	gc := connectGcs(context.Background(), t)
   171  	defer gc.Close()
   172  }
   173  
   174  func TestGcsCreateContainer(t *testing.T) {
   175  	gc := connectGcs(context.Background(), t)
   176  	defer gc.Close()
   177  	c, err := gc.CreateContainer(context.Background(), "foo", nil)
   178  	if err != nil {
   179  		t.Fatal(err)
   180  	}
   181  	c.Close()
   182  }
   183  
   184  func TestGcsWaitContainer(t *testing.T) {
   185  	gc := connectGcs(context.Background(), t)
   186  	defer gc.Close()
   187  	c, err := gc.CreateContainer(context.Background(), "foo", nil)
   188  	if err != nil {
   189  		t.Fatal(err)
   190  	}
   191  	defer c.Close()
   192  	err = c.Terminate(context.Background())
   193  	if err != nil {
   194  		t.Fatal(err)
   195  	}
   196  	err = c.Wait()
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  }
   201  
   202  func TestGcsWaitContainerBridgeTerminated(t *testing.T) {
   203  	ctx, cancel := context.WithCancel(context.Background())
   204  	defer cancel()
   205  	gc := connectGcs(ctx, t)
   206  	c, err := gc.CreateContainer(context.Background(), "foo", nil)
   207  	if err != nil {
   208  		t.Fatal(err)
   209  	}
   210  	defer c.Close()
   211  	cancel() // close the GCS connection
   212  	err = c.Wait()
   213  	if err != nil {
   214  		t.Fatal(err)
   215  	}
   216  }
   217  
   218  func TestGcsCreateProcess(t *testing.T) {
   219  	gc := connectGcs(context.Background(), t)
   220  	defer gc.Close()
   221  	p, err := gc.CreateProcess(context.Background(), &baseProcessParams{
   222  		CreateStdInPipe:  true,
   223  		CreateStdOutPipe: true,
   224  	})
   225  	if err != nil {
   226  		t.Fatal(err)
   227  	}
   228  	defer p.Close()
   229  	stdin, stdout, _ := p.Stdio()
   230  	_, err = stdin.Write(([]byte)("hello world"))
   231  	if err != nil {
   232  		t.Fatal(err)
   233  	}
   234  	err = p.CloseStdin(context.Background())
   235  	if err != nil {
   236  		t.Fatal(err)
   237  	}
   238  	b, err := io.ReadAll(stdout)
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  	if string(b) != "hello world" {
   243  		t.Errorf("unexpected: %q", string(b))
   244  	}
   245  }
   246  
   247  func TestGcsWaitProcessBridgeTerminated(t *testing.T) {
   248  	ctx, cancel := context.WithCancel(context.Background())
   249  	defer cancel()
   250  	gc := connectGcs(ctx, t)
   251  	defer gc.Close()
   252  	p, err := gc.CreateProcess(context.Background(), nil)
   253  	if err != nil {
   254  		t.Fatal(err)
   255  	}
   256  	defer p.Close()
   257  	cancel()
   258  	err = p.Wait()
   259  	if err == nil || !strings.Contains(err.Error(), "bridge closed") {
   260  		t.Fatal("unexpected: ", err)
   261  	}
   262  }
   263  
   264  func Test_makeRequestNoSpan(t *testing.T) {
   265  	r := makeRequest(context.Background(), t.Name())
   266  
   267  	if r.ContainerID != t.Name() {
   268  		t.Fatalf("expected ContainerID: %q, got: %q", t.Name(), r.ContainerID)
   269  	}
   270  	var empty guid.GUID
   271  	if r.ActivityID != empty {
   272  		t.Fatalf("expected ActivityID empty, got: %q", r.ActivityID.String())
   273  	}
   274  	if r.OpenCensusSpanContext != nil {
   275  		t.Fatal("expected nil span context")
   276  	}
   277  }
   278  
   279  func Test_makeRequestWithSpan(t *testing.T) {
   280  	ctx, span := oc.StartSpan(context.Background(), t.Name())
   281  	defer span.End()
   282  	r := makeRequest(ctx, t.Name())
   283  
   284  	if r.ContainerID != t.Name() {
   285  		t.Fatalf("expected ContainerID: %q, got: %q", t.Name(), r.ContainerID)
   286  	}
   287  	var empty guid.GUID
   288  	if r.ActivityID != empty {
   289  		t.Fatalf("expected ActivityID empty, got: %q", r.ActivityID.String())
   290  	}
   291  	if r.OpenCensusSpanContext == nil {
   292  		t.Fatal("expected non-nil span context")
   293  	}
   294  	sc := span.SpanContext()
   295  	encodedTraceID := hex.EncodeToString(sc.TraceID[:])
   296  	if r.OpenCensusSpanContext.TraceID != encodedTraceID {
   297  		t.Fatalf("expected encoded TraceID: %q, got: %q", encodedTraceID, r.OpenCensusSpanContext.TraceID)
   298  	}
   299  	encodedSpanID := hex.EncodeToString(sc.SpanID[:])
   300  	if r.OpenCensusSpanContext.SpanID != encodedSpanID {
   301  		t.Fatalf("expected encoded SpanID: %q, got: %q", encodedSpanID, r.OpenCensusSpanContext.SpanID)
   302  	}
   303  	encodedTraceOptions := uint32(sc.TraceOptions)
   304  	if r.OpenCensusSpanContext.TraceOptions != encodedTraceOptions {
   305  		t.Fatalf("expected encoded TraceOptions: %v, got: %v", encodedTraceOptions, r.OpenCensusSpanContext.TraceOptions)
   306  	}
   307  	if r.OpenCensusSpanContext.Tracestate != "" {
   308  		t.Fatalf("expected encoded TraceState: '', got: %q", r.OpenCensusSpanContext.Tracestate)
   309  	}
   310  }
   311  
   312  func Test_makeRequestWithSpan_TraceStateEmptyEntries(t *testing.T) {
   313  	// Start a remote context span so we can forward trace state.
   314  	ts, err := tracestate.New(nil)
   315  	if err != nil {
   316  		t.Fatalf("failed to make test Tracestate")
   317  	}
   318  	parent := trace.SpanContext{
   319  		Tracestate: ts,
   320  	}
   321  	ctx, span := trace.StartSpanWithRemoteParent(context.Background(), t.Name(), parent)
   322  	defer span.End()
   323  	r := makeRequest(ctx, t.Name())
   324  
   325  	if r.OpenCensusSpanContext == nil {
   326  		t.Fatal("expected non-nil span context")
   327  	}
   328  	if r.OpenCensusSpanContext.Tracestate != "" {
   329  		t.Fatalf("expected encoded TraceState: '', got: %q", r.OpenCensusSpanContext.Tracestate)
   330  	}
   331  }
   332  
   333  func Test_makeRequestWithSpan_TraceStateEntries(t *testing.T) {
   334  	// Start a remote context span so we can forward trace state.
   335  	ts, err := tracestate.New(nil, tracestate.Entry{Key: "test", Value: "test"})
   336  	if err != nil {
   337  		t.Fatalf("failed to make test Tracestate")
   338  	}
   339  	parent := trace.SpanContext{
   340  		Tracestate: ts,
   341  	}
   342  	ctx, span := trace.StartSpanWithRemoteParent(context.Background(), t.Name(), parent)
   343  	defer span.End()
   344  	r := makeRequest(ctx, t.Name())
   345  
   346  	if r.OpenCensusSpanContext == nil {
   347  		t.Fatal("expected non-nil span context")
   348  	}
   349  	encodedTraceState := base64.StdEncoding.EncodeToString([]byte(`[{"Key":"test","Value":"test"}]`))
   350  	if r.OpenCensusSpanContext.Tracestate != encodedTraceState {
   351  		t.Fatalf("expected encoded TraceState: %q, got: %q", encodedTraceState, r.OpenCensusSpanContext.Tracestate)
   352  	}
   353  }
   354  

View as plain text