1
2
3
4 package bridge
5
6 import (
7 "encoding/binary"
8 "encoding/json"
9 "io"
10 "os"
11 "strings"
12 "sync"
13 "testing"
14
15 "github.com/Microsoft/hcsshim/internal/guest/gcserr"
16 "github.com/Microsoft/hcsshim/internal/guest/prot"
17 "github.com/Microsoft/hcsshim/internal/guest/transport"
18 "github.com/pkg/errors"
19 "github.com/sirupsen/logrus"
20 )
21
22 func Test_Bridge_Mux_New(t *testing.T) {
23 m := NewBridgeMux()
24 if m == nil {
25 t.Error("Failed to create bridge mux")
26 }
27 }
28
29 func Test_Bridge_Mux_New_Success(t *testing.T) {
30 m := NewBridgeMux()
31 if m.m == nil {
32 t.Error("Bridge mux map is not initialized")
33 }
34 }
35
36 type thandler struct {
37 set bool
38 resp RequestResponse
39 err error
40 }
41
42 func (h *thandler) ServeMsg(_ *Request) (RequestResponse, error) {
43 h.set = true
44 return h.resp, h.err
45 }
46
47 func TestBridgeMux_Handle_NilHandler_Panic(t *testing.T) {
48 defer func() {
49 if r := recover(); r == nil {
50 t.Error("The code did not panic on nil handler")
51 }
52 }()
53
54 m := NewBridgeMux()
55 m.Handle(prot.ComputeSystemCreateV1, prot.PvInvalid, nil)
56 }
57
58 func TestBridgeMux_Handle_NilMap_Panic(t *testing.T) {
59 defer func() {
60 if r := recover(); r == nil {
61 t.Error("The code did not panic on nil map")
62 }
63 }()
64
65 m := &Mux{}
66 th := &thandler{}
67 m.Handle(prot.ComputeSystemCreateV1, prot.PvInvalid, th)
68 }
69
70 func Test_Bridge_Mux_Handle_Succeeds(t *testing.T) {
71 th := &thandler{}
72 m := NewBridgeMux()
73 m.Handle(prot.ComputeSystemCreateV1, prot.PvInvalid, th)
74
75 var verMap map[prot.ProtocolVersion]Handler
76 var ok bool
77 if verMap, ok = m.m[prot.ComputeSystemCreateV1]; !ok {
78 t.Error("The handler type map not successfully added.")
79 }
80
81 var hOut Handler
82 if hOut, ok = verMap[prot.PvInvalid]; !ok {
83 t.Error("The handler was not successfully added.")
84 }
85
86
87 _, _ = hOut.ServeMsg(nil)
88
89 if !th.set {
90 t.Error("The handler added was not the same handler.")
91 }
92 }
93
94 func TestBridgeMux_HandleFunc_NilHandleFunc_Panic(t *testing.T) {
95 defer func() {
96 if r := recover(); r == nil {
97 t.Error("The code did not panic on nil handler")
98 }
99 }()
100
101 m := NewBridgeMux()
102 m.HandleFunc(prot.ComputeSystemCreateV1, prot.PvInvalid, nil)
103 }
104
105 func TestBridgeMux_HandleFunc_NilMap_Panic(t *testing.T) {
106 defer func() {
107 if r := recover(); r == nil {
108 t.Error("The code did not panic on nil handler")
109 }
110 }()
111
112 hIn := func(*Request) (RequestResponse, error) {
113 return nil, nil
114 }
115
116 m := &Mux{}
117 m.HandleFunc(prot.ComputeSystemCreateV1, prot.PvInvalid, hIn)
118 }
119
120 func Test_Bridge_Mux_HandleFunc_Succeeds(t *testing.T) {
121 var set bool
122 hIn := func(*Request) (RequestResponse, error) {
123 set = true
124 return nil, nil
125 }
126
127 m := NewBridgeMux()
128 m.HandleFunc(prot.ComputeSystemCreateV1, prot.PvInvalid, hIn)
129
130 var verMap map[prot.ProtocolVersion]Handler
131 var ok bool
132 if verMap, ok = m.m[prot.ComputeSystemCreateV1]; !ok {
133 t.Error("The handler type map not successfully added.")
134 }
135
136 var hOut Handler
137 if hOut, ok = verMap[prot.PvInvalid]; !ok {
138 t.Error("The handler was not successfully added.")
139 }
140
141
142 _, _ = hOut.ServeMsg(nil)
143
144 if !set {
145 t.Error("The handler added was not the same handler.")
146 }
147 }
148
149 func Test_Bridge_Mux_Handler_NilRequest_Panic(t *testing.T) {
150 defer func() {
151 if r := recover(); r == nil {
152 t.Error("The code did not panic on nil request to handler")
153 }
154 }()
155
156 var set bool
157 hIn := func(*Request) (RequestResponse, error) {
158 set = true
159 return nil, nil
160 }
161
162 m := NewBridgeMux()
163 m.HandleFunc(prot.ComputeSystemCreateV1, prot.PvInvalid, hIn)
164 m.Handler(nil)
165 if set {
166 t.Fatal("should not be set on nil request")
167 }
168 }
169
170 func verifyResponseIsDefaultHandler(t *testing.T, resp RequestResponse) {
171 t.Helper()
172 if resp == nil {
173 t.Fatal("The response is nil")
174 }
175
176 base := resp.Base()
177 if base.Result != int32(gcserr.HrNotImpl) {
178 t.Fatal("The default handler did not set a -1 error result.")
179 }
180 if len(base.ErrorRecords) != 1 {
181 t.Fatal("The default handler did not set an error record.")
182 }
183 if !strings.Contains(base.ErrorRecords[0].Message, "bridge: function not supported") {
184 t.Fatal("The default handler did not return the not supported message")
185 }
186 }
187
188 func Test_Bridge_Mux_Handler_NotAdded_Default(t *testing.T) {
189
190
191
192 m := NewBridgeMux()
193
194 req := &Request{
195 Header: &prot.MessageHeader{
196 Type: prot.ComputeSystemCreateV1,
197 Size: 0,
198 ID: prot.SequenceID(1),
199 },
200 }
201
202 hOut := m.Handler(req)
203 resp, err := hOut.ServeMsg(req)
204 if resp != nil {
205 t.Fatalf("expected nil response got: %+v", resp)
206 }
207 if err == nil {
208 t.Fatal("expected valid error got: nil")
209 }
210 }
211
212 func Test_Bridge_Mux_Handler_Added_NotMatched(t *testing.T) {
213
214
215
216
217 m := NewBridgeMux()
218 th := &thandler{}
219
220
221 m.Handle(prot.ComputeSystemCreateV1, prot.PvInvalid, th)
222
223 req := &Request{
224 Header: &prot.MessageHeader{
225 Type: prot.ComputeSystemExecuteProcessV1,
226 Size: 0,
227 ID: prot.SequenceID(1),
228 },
229 }
230
231
232 hOut := m.Handler(req)
233 respChan := make(chan bridgeResponse, 1)
234 defer close(respChan)
235
236 resp, err := hOut.ServeMsg(req)
237 if resp != nil {
238 t.Fatalf("expected nil response got: %+v", resp)
239 }
240 if err == nil {
241 t.Fatal("expected valid error got: nil")
242 }
243 if th.set {
244 t.Error("Handler did not call the appropriate handler for a match request")
245 }
246 }
247
248 func Test_Bridge_Mux_Handler_Success(t *testing.T) {
249 m := NewBridgeMux()
250 th := &thandler{
251 resp: &prot.ContainerCreateResponse{},
252 }
253
254 m.Handle(prot.ComputeSystemCreateV1, prot.PvInvalid, th)
255
256 req := &Request{
257 Header: &prot.MessageHeader{
258 Type: prot.ComputeSystemCreateV1,
259 Size: 0,
260 ID: prot.SequenceID(1),
261 },
262 }
263
264 hOut := m.Handler(req)
265 respChan := make(chan bridgeResponse, 1)
266 defer close(respChan)
267
268 resp, err := hOut.ServeMsg(req)
269 if resp == nil {
270 t.Fatal("expected valid response got: nil")
271 }
272 if err != nil {
273 t.Fatalf("expected nil error got: %v", err)
274 }
275 if !th.set {
276 t.Error("Handler did not call the appropriate handler for a match request")
277 }
278 }
279
280 func Test_Bridge_Mux_ServeMsg_NotAdded_Default(t *testing.T) {
281
282
283
284 m := NewBridgeMux()
285
286 req := &Request{
287 Header: &prot.MessageHeader{
288 Type: prot.ComputeSystemCreateV1,
289 Size: 0,
290 ID: prot.SequenceID(1),
291 },
292 }
293
294 respChan := make(chan bridgeResponse, 1)
295 defer close(respChan)
296
297 resp, err := m.ServeMsg(req)
298 if resp != nil {
299 t.Fatalf("expected nil response, got: %+v", resp)
300 }
301 if err == nil {
302 t.Fatal("expected error got: nil")
303 }
304 }
305
306 func Test_Bridge_Mux_ServeMsg_Added_NotMatched(t *testing.T) {
307
308
309
310
311 m := NewBridgeMux()
312 th := &thandler{}
313
314
315 m.Handle(prot.ComputeSystemCreateV1, prot.PvInvalid, th)
316
317 req := &Request{
318 Header: &prot.MessageHeader{
319 Type: prot.ComputeSystemExecuteProcessV1,
320 Size: 0,
321 ID: prot.SequenceID(1),
322 },
323 }
324
325
326 respChan := make(chan bridgeResponse, 1)
327 defer close(respChan)
328
329 resp, err := m.ServeMsg(req)
330 if resp != nil {
331 t.Fatalf("expected nil response, got: %+v", resp)
332 }
333 if err == nil {
334 t.Fatal("expected error got: nil")
335 }
336 if th.set {
337 t.Error("Handler did not call the appropriate handler for a match request")
338 }
339 }
340
341 func Test_Bridge_Mux_ServeMsg_Success(t *testing.T) {
342 m := NewBridgeMux()
343 th := &thandler{
344 resp: &prot.ContainerCreateResponse{},
345 }
346
347 m.Handle(prot.ComputeSystemCreateV1, prot.PvInvalid, th)
348
349 req := &Request{
350 Header: &prot.MessageHeader{
351 Type: prot.ComputeSystemCreateV1,
352 Size: 0,
353 ID: prot.SequenceID(1),
354 },
355 }
356
357 respChan := make(chan bridgeResponse, 1)
358 defer close(respChan)
359
360 resp, err := m.ServeMsg(req)
361 if resp == nil {
362 t.Fatal("expected valid response got: nil")
363 }
364 if err != nil {
365 t.Fatalf("expected nil error got: %v", err)
366 }
367 if !th.set {
368 t.Error("Handler did not call the appropriate handler for a match request")
369 }
370 }
371
372
373 type errorTransport struct {
374 e error
375 }
376
377
378 func (e *errorTransport) Dial(_ uint32) (transport.Connection, error) {
379 return nil, e.e
380 }
381
382 func serverSend(conn io.Writer, messageType prot.MessageIdentifier, messageID prot.SequenceID, i interface{}) error {
383 body := make([]byte, 0)
384 if i != nil {
385 var err error
386 body, err = json.Marshal(i)
387 if err != nil {
388 return errors.Wrap(err, "failed to json marshal to server.")
389 }
390 }
391
392 header := prot.MessageHeader{
393 Type: messageType,
394 ID: messageID,
395 Size: uint32(len(body) + prot.MessageHeaderSize),
396 }
397
398
399 if err := binary.Write(conn, binary.LittleEndian, header); err != nil {
400 return errors.Wrap(err, "bridge_test: failed to write message header")
401 }
402
403 if _, err := conn.Write(body); err != nil {
404 return errors.Wrap(err, "bridge_test: failed to write the message body")
405 }
406 return nil
407 }
408
409 func serverRead(conn io.Reader) (*prot.MessageHeader, []byte, error) {
410 header := &prot.MessageHeader{}
411
412 if err := binary.Read(conn, binary.LittleEndian, header); err != nil {
413 return nil, nil, errors.Wrap(err, "bridge_test: failed to read message header")
414 }
415 message := make([]byte, header.Size-prot.MessageHeaderSize)
416
417 if _, err := io.ReadFull(conn, message); err != nil {
418 return nil, nil, errors.Wrap(err, "bridge_test: failed to read the message body")
419 }
420
421 return header, message, nil
422 }
423
// loopbackConnection joins a "client" side and a "server" side with two OS
// pipes. pipes[0]/pipes[1] are the read/write ends of the server->client pipe.
// pipes[2]/pipes[3] are the read/write ends of the client->server pipe.
type loopbackConnection struct {
	pipes [4]*os.File
}

// close closes every pipe end, from the last index down to the first. Close
// errors are deliberately ignored; this is test teardown.
func (lc *loopbackConnection) close() {
	last := len(lc.pipes) - 1
	for i := range lc.pipes {
		_ = lc.pipes[last-i].Close()
	}
}

// CRead is the client's read end; it receives what the server writes via SWrite.
func (lc *loopbackConnection) CRead() io.ReadCloser { return lc.pipes[0] }

// CWrite is the client's write end; its data is read by the server via SRead.
func (lc *loopbackConnection) CWrite() io.WriteCloser { return lc.pipes[3] }

// SRead is the server's read end; it receives what the client writes via CWrite.
func (lc *loopbackConnection) SRead() io.ReadCloser { return lc.pipes[2] }

// SWrite is the server's write end; its data is read by the client via CRead.
func (lc *loopbackConnection) SWrite() io.WriteCloser { return lc.pipes[1] }
450
451 func newLoopbackConnection() *loopbackConnection {
452 l := new(loopbackConnection)
453 l.pipes[0], l.pipes[1], _ = os.Pipe()
454 l.pipes[2], l.pipes[3], _ = os.Pipe()
455 return l
456 }
457
// Test_Bridge_ListenAndServe_UnknownMessageHandler_Success sends a resize
// console request to a Bridge whose Handler is UnknownMessageHandler(). It
// checks that the reply:
//   - is a resize console response,
//   - carries the request's sequence ID,
//   - has the default "not supported" payload (see verifyResponseIsDefaultHandler),
//   - echoes the request's ActivityID.
func Test_Bridge_ListenAndServe_UnknownMessageHandler_Success(t *testing.T) {
	// Silence logrus output for this test. NOTE(review): this changes the
	// global logrus writer and is never restored.
	logrus.SetOutput(io.Discard)

	lc := newLoopbackConnection()
	defer lc.close()

	b := &Bridge{
		Handler: UnknownMessageHandler(),
	}

	// The bridge serves on the server ends of the loopback; the test drives
	// the client ends.
	go func() {
		if err := b.ListenAndServe(lc.SRead(), lc.SWrite()); err != nil {
			t.Error(err)
		}
	}()
	// Deferred after lc.close, so it runs first: tell the bridge to quit
	// before the pipes are closed. NOTE(review): assumes b.quitChan has been
	// created by ListenAndServe by the time this runs — confirm against Bridge.
	defer func() {
		b.quitChan <- true
	}()

	message := &prot.ContainerResizeConsole{
		MessageBase: prot.MessageBase{
			ContainerID: "01234567-89ab-cdef-0123-456789abcdef",
			ActivityID:  "00000000-0000-0000-0000-000000000001",
		},
	}
	if err := serverSend(lc.CWrite(), prot.ComputeSystemResizeConsoleV1, prot.SequenceID(1), message); err != nil {
		t.Error("Failed to send message to server")
		return
	}
	header, body, err := serverRead(lc.CRead())
	if err != nil {
		t.Error("Failed to read message response from server")
		return
	}
	response := &prot.MessageResponseBase{}
	if err := json.Unmarshal(body, response); err != nil {
		t.Error("Failed to unmarshal response body from server")
		return
	}

	// The reply must be the response type paired with the request and must
	// reuse its sequence ID.
	if header.Type != prot.ComputeSystemResponseResizeConsoleV1 {
		t.Error("Response header was not resize console response.")
	}
	if header.ID != prot.SequenceID(1) {
		t.Error("Response header had wrong sequence id")
	}
	verifyResponseIsDefaultHandler(t, response)
	if response.ActivityID != message.ActivityID {
		t.Fatal("Response had invalid activity id")
	}
}
511
// Test_Bridge_ListenAndServe_CorrectHandler_Success registers a resize console
// handler for protocol version PvV4 on a Bridge whose protVer is also PvV4.
// It sends one resize console request over the loopback and checks that the
// reply carries the resize console response type, the request's sequence ID,
// and the Result and ActivityID set by the handler.
func Test_Bridge_ListenAndServe_CorrectHandler_Success(t *testing.T) {
	// Silence logrus output for this test. NOTE(review): this changes the
	// global logrus writer and is never restored.
	logrus.SetOutput(io.Discard)

	lc := newLoopbackConnection()
	defer lc.close()

	mux := NewBridgeMux()
	message := &prot.ContainerResizeConsole{
		MessageBase: prot.MessageBase{
			ContainerID: "01234567-89ab-cdef-0123-456789abcdef",
			ActivityID:  "00000000-0000-0000-0000-000000000010",
		},
	}
	// resizeFn checks that it received exactly the request the test sends
	// (type, sequence ID, decoded ContainerID). It replies with Result 1 and
	// the request's ActivityID, so the client side can confirm this handler
	// produced the response.
	resizeFn := func(r *Request) (RequestResponse, error) {
		// Reject anything other than the expected header.
		if r.Header.Type != prot.ComputeSystemResizeConsoleV1 {
			return nil, errors.New("bridge_test: wrong request type")
		}
		if r.Header.ID != prot.SequenceID(1) {
			return nil, errors.New("bridge_test: wrong sequence id")
		}

		rBody := prot.ContainerResizeConsole{}

		// r.Message holds the raw JSON body written by serverSend.
		if err := json.Unmarshal(r.Message, &rBody); err != nil {
			return nil, errors.New("failed to unmarshal body")
		}
		if message.ContainerID != rBody.ContainerID {
			return nil, errors.New("containerID of source and handler func not equal")
		}

		return &prot.MessageResponseBase{
			Result:     1,
			ActivityID: rBody.ActivityID,
		}, nil
	}
	mux.HandleFunc(prot.ComputeSystemResizeConsoleV1, prot.PvV4, resizeFn)
	b := &Bridge{
		Handler: mux,
		protVer: prot.PvV4,
	}

	// The bridge serves on the server ends of the loopback; the test drives
	// the client ends.
	go func() {
		if err := b.ListenAndServe(lc.SRead(), lc.SWrite()); err != nil {
			t.Error(err)
		}
	}()
	// Deferred after lc.close, so it runs first: tell the bridge to quit
	// before the pipes are closed. NOTE(review): assumes b.quitChan has been
	// created by ListenAndServe by the time this runs — confirm against Bridge.
	defer func() {
		b.quitChan <- true
	}()

	if err := serverSend(lc.CWrite(), prot.ComputeSystemResizeConsoleV1, prot.SequenceID(1), message); err != nil {
		t.Error("Failed to send message to server")
		return
	}
	header, body, err := serverRead(lc.CRead())
	if err != nil {
		t.Error("Failed to read message response from server")
		return
	}
	response := &prot.MessageResponseBase{}
	if err := json.Unmarshal(body, response); err != nil {
		t.Error("Failed to unmarshal response body from server")
		return
	}

	// Result == 1 and the echoed ActivityID show that resizeFn, and not a
	// default handler, produced this reply.
	if header.Type != prot.ComputeSystemResponseResizeConsoleV1 {
		t.Error("response header was not resize console response.")
	}
	if header.ID != prot.SequenceID(1) {
		t.Error("response header had wrong sequence id")
	}
	if response.ActivityID != message.ActivityID {
		t.Error("response body did not have same activity id")
	}
	if response.Result != 1 {
		t.Error("response result was not 1 as expected")
	}
}
592
593 func Test_Bridge_ListenAndServe_HandlersAreAsync_Success(t *testing.T) {
594
595 logrus.SetOutput(io.Discard)
596
597 lc := newLoopbackConnection()
598 defer lc.close()
599
600 mux := NewBridgeMux()
601
602 orderWg := sync.WaitGroup{}
603 orderWg.Add(1)
604
605 firstFn := func(r *Request) (RequestResponse, error) {
606
607 orderWg.Wait()
608 return &prot.MessageResponseBase{
609 Result: 1,
610 }, nil
611 }
612 secondFn := func(r *Request) (RequestResponse, error) {
613 defer orderWg.Done()
614 return &prot.MessageResponseBase{
615 Result: 10,
616 }, nil
617 }
618 mux.HandleFunc(prot.ComputeSystemResizeConsoleV1, prot.PvV4, firstFn)
619 mux.HandleFunc(prot.ComputeSystemModifySettingsV1, prot.PvV4, secondFn)
620
621 b := &Bridge{
622 Handler: mux,
623 protVer: prot.PvV4,
624 }
625
626 go func() {
627 if err := b.ListenAndServe(lc.SRead(), lc.SWrite()); err != nil {
628 t.Error(err)
629 }
630 }()
631 defer func() {
632 b.quitChan <- true
633 }()
634
635 if err := serverSend(lc.CWrite(), prot.ComputeSystemResizeConsoleV1, prot.SequenceID(0), nil); err != nil {
636 t.Error("Failed to send first message to server")
637 return
638 }
639 if err := serverSend(lc.CWrite(), prot.ComputeSystemModifySettingsV1, prot.SequenceID(1), nil); err != nil {
640 t.Error("Failed to send second message to server")
641 return
642 }
643
644 headerFirst, _, errFirst := serverRead(lc.CRead())
645 if errFirst != nil {
646 t.Error("Failed to read first response from server")
647 return
648 }
649 headerSecond, _, errSecond := serverRead(lc.CRead())
650 if errSecond != nil {
651 t.Error("Failed to read first response from server")
652 return
653 }
654
655 if headerFirst.Type != prot.ComputeSystemResponseModifySettingsV1 {
656 t.Error("Incorrect response type for 2nd request")
657 }
658 if headerFirst.ID != prot.SequenceID(1) {
659 t.Error("Incorrect response order for 2nd request")
660 }
661
662 if headerSecond.Type != prot.ComputeSystemResponseResizeConsoleV1 {
663 t.Error("Incorrect response for 1st request")
664 }
665 if headerSecond.ID != prot.SequenceID(0) {
666 t.Error("Incorrect response order for 1st request")
667 }
668 }
669
View as plain text