1
2
3 package gcs
4
import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"testing"
	"time"

	"github.com/Microsoft/go-winio"
	"github.com/Microsoft/go-winio/pkg/guid"
	"github.com/sirupsen/logrus"
	"go.opencensus.io/trace"
	"go.opencensus.io/trace/tracestate"

	"github.com/Microsoft/hcsshim/internal/oc"
)
25
// pipePortFmt is the named-pipe path template these tests use in place of the
// guest's stdio relay ports; the %d verb is replaced by the port number.
const pipePortFmt = "\\\\.\\pipe\\gctest-port-%d"
27
28 func npipeIoListen(port uint32) (net.Listener, error) {
29 return winio.ListenPipe(fmt.Sprintf(pipePortFmt, port), &winio.PipeConfig{
30 MessageMode: true,
31 })
32 }
33
34 func dialPort(port uint32) (net.Conn, error) {
35 return winio.DialPipe(fmt.Sprintf(pipePortFmt, port), nil)
36 }
37
38 func simpleGcs(t *testing.T, rwc io.ReadWriteCloser) {
39 t.Helper()
40 defer rwc.Close()
41 err := simpleGcsLoop(t, rwc)
42 if err != nil {
43 t.Error(err)
44 }
45 }
46
47 func simpleGcsLoop(t *testing.T, rw io.ReadWriter) error {
48 t.Helper()
49 for {
50 id, typ, b, err := readMessage(rw)
51 if err != nil {
52 if err == io.EOF || err == io.ErrClosedPipe {
53 err = nil
54 }
55 return err
56 }
57 switch proc := rpcProc(typ &^ msgTypeRequest); proc {
58 case rpcNegotiateProtocol:
59 err := sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &negotiateProtocolResponse{
60 Version: protocolVersion,
61 Capabilities: gcsCapabilities{
62 RuntimeOsType: "linux",
63 },
64 })
65 if err != nil {
66 return err
67 }
68 case rpcCreate:
69 err := sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &containerCreateResponse{})
70 if err != nil {
71 return err
72 }
73 case rpcExecuteProcess:
74 var req containerExecuteProcess
75 var params baseProcessParams
76 req.Settings.ProcessParameters.Value = ¶ms
77 err := json.Unmarshal(b, &req)
78 if err != nil {
79 return err
80 }
81 var stdin, stdout, stderr net.Conn
82 if params.CreateStdInPipe {
83 stdin, err = dialPort(req.Settings.VsockStdioRelaySettings.StdIn)
84 if err != nil {
85 return err
86 }
87 defer stdin.Close()
88 }
89 if params.CreateStdOutPipe {
90 stdout, err = dialPort(req.Settings.VsockStdioRelaySettings.StdOut)
91 if err != nil {
92 return err
93 }
94 defer stdout.Close()
95 }
96 if params.CreateStdErrPipe {
97 stderr, err = dialPort(req.Settings.VsockStdioRelaySettings.StdErr)
98 if err != nil {
99 return err
100 }
101 defer stderr.Close()
102 }
103 if stdin != nil && stdout != nil {
104 go func() {
105 _, err := io.Copy(stdout, stdin)
106 if err != nil {
107 t.Error(err)
108 }
109 stdin.Close()
110 stdout.Close()
111 }()
112 }
113 err = sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &containerExecuteProcessResponse{
114 ProcessID: 42,
115 })
116 if err != nil {
117 return err
118 }
119 case rpcWaitForProcess:
120
121 case rpcShutdownForced:
122 var req requestBase
123 err = json.Unmarshal(b, &req)
124 if err != nil {
125 return err
126 }
127 err = sendJSON(t, rw, msgTypeResponse|msgType(proc), id, &responseBase{})
128 if err != nil {
129 return err
130 }
131 time.Sleep(50 * time.Millisecond)
132 err = sendJSON(t, rw, msgType(msgTypeNotify|notifyContainer), 0, &containerNotification{
133 requestBase: requestBase{
134 ContainerID: req.ContainerID,
135 },
136 })
137 if err != nil {
138 return err
139 }
140 default:
141 return fmt.Errorf("unsupported msg %s", typ)
142 }
143 }
144 }
145
146 func connectGcs(ctx context.Context, t *testing.T) *GuestConnection {
147 t.Helper()
148 s, c := pipeConn()
149 if ctx != context.Background() && ctx != context.TODO() {
150 go func() {
151 <-ctx.Done()
152 c.Close()
153 }()
154 }
155 go simpleGcs(t, c)
156 gcc := &GuestConnectionConfig{
157 Conn: s,
158 Log: logrus.NewEntry(logrus.StandardLogger()),
159 IoListen: npipeIoListen,
160 }
161 gc, err := gcc.Connect(context.Background(), true)
162 if err != nil {
163 c.Close()
164 t.Fatal(err)
165 }
166 return gc
167 }
168
169 func TestGcsConnect(t *testing.T) {
170 gc := connectGcs(context.Background(), t)
171 defer gc.Close()
172 }
173
174 func TestGcsCreateContainer(t *testing.T) {
175 gc := connectGcs(context.Background(), t)
176 defer gc.Close()
177 c, err := gc.CreateContainer(context.Background(), "foo", nil)
178 if err != nil {
179 t.Fatal(err)
180 }
181 c.Close()
182 }
183
184 func TestGcsWaitContainer(t *testing.T) {
185 gc := connectGcs(context.Background(), t)
186 defer gc.Close()
187 c, err := gc.CreateContainer(context.Background(), "foo", nil)
188 if err != nil {
189 t.Fatal(err)
190 }
191 defer c.Close()
192 err = c.Terminate(context.Background())
193 if err != nil {
194 t.Fatal(err)
195 }
196 err = c.Wait()
197 if err != nil {
198 t.Fatal(err)
199 }
200 }
201
202 func TestGcsWaitContainerBridgeTerminated(t *testing.T) {
203 ctx, cancel := context.WithCancel(context.Background())
204 defer cancel()
205 gc := connectGcs(ctx, t)
206 c, err := gc.CreateContainer(context.Background(), "foo", nil)
207 if err != nil {
208 t.Fatal(err)
209 }
210 defer c.Close()
211 cancel()
212 err = c.Wait()
213 if err != nil {
214 t.Fatal(err)
215 }
216 }
217
218 func TestGcsCreateProcess(t *testing.T) {
219 gc := connectGcs(context.Background(), t)
220 defer gc.Close()
221 p, err := gc.CreateProcess(context.Background(), &baseProcessParams{
222 CreateStdInPipe: true,
223 CreateStdOutPipe: true,
224 })
225 if err != nil {
226 t.Fatal(err)
227 }
228 defer p.Close()
229 stdin, stdout, _ := p.Stdio()
230 _, err = stdin.Write(([]byte)("hello world"))
231 if err != nil {
232 t.Fatal(err)
233 }
234 err = p.CloseStdin(context.Background())
235 if err != nil {
236 t.Fatal(err)
237 }
238 b, err := io.ReadAll(stdout)
239 if err != nil {
240 t.Fatal(err)
241 }
242 if string(b) != "hello world" {
243 t.Errorf("unexpected: %q", string(b))
244 }
245 }
246
247 func TestGcsWaitProcessBridgeTerminated(t *testing.T) {
248 ctx, cancel := context.WithCancel(context.Background())
249 defer cancel()
250 gc := connectGcs(ctx, t)
251 defer gc.Close()
252 p, err := gc.CreateProcess(context.Background(), nil)
253 if err != nil {
254 t.Fatal(err)
255 }
256 defer p.Close()
257 cancel()
258 err = p.Wait()
259 if err == nil || !strings.Contains(err.Error(), "bridge closed") {
260 t.Fatal("unexpected: ", err)
261 }
262 }
263
264 func Test_makeRequestNoSpan(t *testing.T) {
265 r := makeRequest(context.Background(), t.Name())
266
267 if r.ContainerID != t.Name() {
268 t.Fatalf("expected ContainerID: %q, got: %q", t.Name(), r.ContainerID)
269 }
270 var empty guid.GUID
271 if r.ActivityID != empty {
272 t.Fatalf("expected ActivityID empty, got: %q", r.ActivityID.String())
273 }
274 if r.OpenCensusSpanContext != nil {
275 t.Fatal("expected nil span context")
276 }
277 }
278
279 func Test_makeRequestWithSpan(t *testing.T) {
280 ctx, span := oc.StartSpan(context.Background(), t.Name())
281 defer span.End()
282 r := makeRequest(ctx, t.Name())
283
284 if r.ContainerID != t.Name() {
285 t.Fatalf("expected ContainerID: %q, got: %q", t.Name(), r.ContainerID)
286 }
287 var empty guid.GUID
288 if r.ActivityID != empty {
289 t.Fatalf("expected ActivityID empty, got: %q", r.ActivityID.String())
290 }
291 if r.OpenCensusSpanContext == nil {
292 t.Fatal("expected non-nil span context")
293 }
294 sc := span.SpanContext()
295 encodedTraceID := hex.EncodeToString(sc.TraceID[:])
296 if r.OpenCensusSpanContext.TraceID != encodedTraceID {
297 t.Fatalf("expected encoded TraceID: %q, got: %q", encodedTraceID, r.OpenCensusSpanContext.TraceID)
298 }
299 encodedSpanID := hex.EncodeToString(sc.SpanID[:])
300 if r.OpenCensusSpanContext.SpanID != encodedSpanID {
301 t.Fatalf("expected encoded SpanID: %q, got: %q", encodedSpanID, r.OpenCensusSpanContext.SpanID)
302 }
303 encodedTraceOptions := uint32(sc.TraceOptions)
304 if r.OpenCensusSpanContext.TraceOptions != encodedTraceOptions {
305 t.Fatalf("expected encoded TraceOptions: %v, got: %v", encodedTraceOptions, r.OpenCensusSpanContext.TraceOptions)
306 }
307 if r.OpenCensusSpanContext.Tracestate != "" {
308 t.Fatalf("expected encoded TraceState: '', got: %q", r.OpenCensusSpanContext.Tracestate)
309 }
310 }
311
312 func Test_makeRequestWithSpan_TraceStateEmptyEntries(t *testing.T) {
313
314 ts, err := tracestate.New(nil)
315 if err != nil {
316 t.Fatalf("failed to make test Tracestate")
317 }
318 parent := trace.SpanContext{
319 Tracestate: ts,
320 }
321 ctx, span := trace.StartSpanWithRemoteParent(context.Background(), t.Name(), parent)
322 defer span.End()
323 r := makeRequest(ctx, t.Name())
324
325 if r.OpenCensusSpanContext == nil {
326 t.Fatal("expected non-nil span context")
327 }
328 if r.OpenCensusSpanContext.Tracestate != "" {
329 t.Fatalf("expected encoded TraceState: '', got: %q", r.OpenCensusSpanContext.Tracestate)
330 }
331 }
332
333 func Test_makeRequestWithSpan_TraceStateEntries(t *testing.T) {
334
335 ts, err := tracestate.New(nil, tracestate.Entry{Key: "test", Value: "test"})
336 if err != nil {
337 t.Fatalf("failed to make test Tracestate")
338 }
339 parent := trace.SpanContext{
340 Tracestate: ts,
341 }
342 ctx, span := trace.StartSpanWithRemoteParent(context.Background(), t.Name(), parent)
343 defer span.End()
344 r := makeRequest(ctx, t.Name())
345
346 if r.OpenCensusSpanContext == nil {
347 t.Fatal("expected non-nil span context")
348 }
349 encodedTraceState := base64.StdEncoding.EncodeToString([]byte(`[{"Key":"test","Value":"test"}]`))
350 if r.OpenCensusSpanContext.Tracestate != encodedTraceState {
351 t.Fatalf("expected encoded TraceState: %q, got: %q", encodedTraceState, r.OpenCensusSpanContext.Tracestate)
352 }
353 }
354
View as plain text