1
18
19 package transport
20
21 import (
22 "bytes"
23 "context"
24 "encoding/binary"
25 "errors"
26 "fmt"
27 "io"
28 "math"
29 "net"
30 "os"
31 "runtime"
32 "strconv"
33 "strings"
34 "sync"
35 "testing"
36 "time"
37
38 "github.com/google/go-cmp/cmp"
39 "golang.org/x/net/http2"
40 "golang.org/x/net/http2/hpack"
41 "google.golang.org/grpc/attributes"
42 "google.golang.org/grpc/codes"
43 "google.golang.org/grpc/credentials"
44 "google.golang.org/grpc/internal/channelz"
45 "google.golang.org/grpc/internal/grpctest"
46 "google.golang.org/grpc/internal/leakcheck"
47 "google.golang.org/grpc/internal/testutils"
48 "google.golang.org/grpc/metadata"
49 "google.golang.org/grpc/resolver"
50 "google.golang.org/grpc/status"
51 )
52
// s is the receiver type of every test in this file: tests are declared as
// methods "func (s) TestXxx" and run through grpctest.RunSubTests (see Test).
// Embedding grpctest.Tester presumably supplies per-test setup/teardown
// hooks — confirm against the grpctest package.
type s struct {
	grpctest.Tester
}
56
57 func Test(t *testing.T) {
58 grpctest.RunSubTests(t, s{})
59 }
60
var (
	// expectedRequest and expectedResponse are the payloads the default
	// handlers expect and send for any method other than "foo.Large".
	expectedRequest  = []byte("ping")
	expectedResponse = []byte("pong")
	// The "foo.Large" payloads are twice initialWindowSize bytes, presumably so
	// that a single message cannot fit into one initial flow-control window.
	// Their first and last bytes are set in init.
	expectedRequestLarge  = make([]byte, initialWindowSize*2)
	expectedResponseLarge = make([]byte, initialWindowSize*2)
	// expectedInvalidHeaderField is the content-type value sent by the
	// invalidHeaderField handler.
	expectedInvalidHeaderField = "invalid/content-type"
)
68
69 func init() {
70 expectedRequestLarge[0] = 'g'
71 expectedRequestLarge[len(expectedRequestLarge)-1] = 'r'
72 expectedResponseLarge[0] = 'p'
73 expectedResponseLarge[len(expectedResponseLarge)-1] = 'c'
74 }
75
// testStreamHandler holds the per-connection state used by the server-side
// stream handlers installed by server.start.
type testStreamHandler struct {
	// t is the server transport the handlers write responses and statuses on.
	t *http2Server
	// notify is closed by a handler to signal the test (used by the notifyCall
	// and delayRead handler types).
	notify chan struct{}
	// getNotified is closed by the test to unblock handleStreamDelayRead.
	getNotified chan struct{}
}
81
// hType selects which stream handler server.start installs on every accepted
// connection.
type hType int

const (
	// normal: handleStream — expect expectedRequest, reply expectedResponse
	// (the large variants for method "foo.Large").
	normal hType = iota
	// suspended: accept streams but never read from or write to them.
	suspended
	// notifyCall: handleStreamAndNotify — close h.notify when a stream arrives.
	notifyCall
	// misbehaved: handleStreamMisbehave — queue initialWindowSize bytes
	// (+1 for methods other than "foo.Connection") of data on the stream.
	misbehaved
	// encodingRequiredStatus: end the stream with encodingTestStatus.
	encodingRequiredStatus
	// invalidHeaderField: send a header frame with an invalid content-type.
	invalidHeaderField
	// delayRead: handleStreamDelayRead — delay reading until the test signals.
	delayRead
	// pingpong: handleStreamPingPong — echo length-prefixed messages until EOF.
	pingpong
)
94
95 func (h *testStreamHandler) handleStreamAndNotify(s *Stream) {
96 if h.notify == nil {
97 return
98 }
99 go func() {
100 select {
101 case <-h.notify:
102 default:
103 close(h.notify)
104 }
105 }()
106 }
107
108 func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
109 req := expectedRequest
110 resp := expectedResponse
111 if s.Method() == "foo.Large" {
112 req = expectedRequestLarge
113 resp = expectedResponseLarge
114 }
115 p := make([]byte, len(req))
116 _, err := s.Read(p)
117 if err != nil {
118 return
119 }
120 if !bytes.Equal(p, req) {
121 t.Errorf("handleStream got %v, want %v", p, req)
122 h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
123 return
124 }
125
126 h.t.Write(s, nil, resp, &Options{})
127
128 h.t.WriteStatus(s, status.New(codes.OK, ""))
129 }
130
131 func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
132 header := make([]byte, 5)
133 for {
134 if _, err := s.Read(header); err != nil {
135 if err == io.EOF {
136 h.t.WriteStatus(s, status.New(codes.OK, ""))
137 return
138 }
139 t.Errorf("Error on server while reading data header: %v", err)
140 h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
141 return
142 }
143 sz := binary.BigEndian.Uint32(header[1:])
144 msg := make([]byte, int(sz))
145 if _, err := s.Read(msg); err != nil {
146 t.Errorf("Error on server while reading message: %v", err)
147 h.t.WriteStatus(s, status.New(codes.Internal, "panic"))
148 return
149 }
150 buf := make([]byte, sz+5)
151 buf[0] = byte(0)
152 binary.BigEndian.PutUint32(buf[1:], uint32(sz))
153 copy(buf[5:], msg)
154 h.t.Write(s, nil, buf, &Options{})
155 }
156 }
157
// handleStreamMisbehave queues data frames for s directly on the server's
// control buffer (rather than through h.t.Write) until at least
// initialWindowSize bytes have been queued. Every frame except the last is
// http2MaxFrameLen bytes. For method "foo.Connection" the last frame brings
// the total to exactly initialWindowSize. For any other method it brings the
// total to initialWindowSize+1, one byte more than an initialWindowSize window
// admits.
// NOTE(review): presumably the two variants exercise the client's handling of
// connection-level vs. stream-level flow-control violations — confirm against
// the tests that use the misbehaved handler type.
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
	conn, ok := s.st.(*http2Server)
	if !ok {
		t.Errorf("Failed to convert %v to *http2Server", s.st)
		h.t.WriteStatus(s, status.New(codes.Internal, ""))
		return
	}
	var sent int
	p := make([]byte, http2MaxFrameLen)
	for sent < initialWindowSize {
		n := initialWindowSize - sent
		// The remainder fits in one frame: this is the last frame.
		if n <= http2MaxFrameLen {
			if s.Method() == "foo.Connection" {
				// Land exactly on initialWindowSize.
				p = make([]byte, n)
			} else {
				// Overshoot initialWindowSize by one byte.
				p = make([]byte, n+1)
			}
		}
		conn.controlBuf.put(&dataFrame{
			streamID:    s.id,
			h:           nil,
			d:           p,
			onEachWrite: func() {},
		})
		sent += len(p)
	}
}
189
190 func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *Stream) {
191
192 h.t.WriteStatus(s, encodingTestStatus)
193 }
194
195 func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
196 headerFields := []hpack.HeaderField{}
197 headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
198 h.t.controlBuf.put(&headerFrame{
199 streamID: s.id,
200 hf: headerFields,
201 endStream: false,
202 })
203 }
204
205
206
207
208
// handleStreamDelayRead is the delayRead handler. It:
//   - counts the bytes passed to s.wq.replenish and closes h.notify once the
//     running total equals defaultWindowSize;
//   - waits (up to 10 seconds) for the test to close h.getNotified before
//     reading anything;
//   - then reads the request, writes the response, reads one more message and
//     ends the stream with an OK status.
//
// NOTE(review): together with TestLargeMessageWithDelayRead this appears to
// check that window updates keep flowing while the applications on both sides
// delay their reads — confirm.
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	var (
		mu    sync.Mutex
		total int
	)
	// Wrap the replenish hook to keep a running total of replenished bytes,
	// still forwarding every call to s.wq.realReplenish.
	s.wq.replenish = func(n int) {
		mu.Lock()
		total += n
		mu.Unlock()
		s.wq.realReplenish(n)
	}
	getTotal := func() int {
		mu.Lock()
		defer mu.Unlock()
		return total
	}
	done := make(chan struct{})
	defer close(done)
	// Poll the total until it reaches defaultWindowSize, then notify the test.
	go func() {
		for {
			select {
			// The handler has returned; stop polling.
			case <-done:
				return
			default:
			}
			if getTotal() == defaultWindowSize {
				// A full window has been replenished; the test may now read
				// the response.
				close(h.notify)
				return
			}
			runtime.Gosched()
		}
	}()
	p := make([]byte, len(req))

	// Don't read until the test closes getNotified; give up after 10 seconds.
	timer := time.NewTimer(time.Second * 10)
	select {
	case <-h.getNotified:
		timer.Stop()
	case <-timer.C:
		t.Errorf("Server timed-out.")
		return
	}
	_, err := s.Read(p)
	if err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
		return
	}

	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		return
	}

	// Send the response. The test does not read it until h.notify is closed.
	if err := h.t.Write(s, nil, resp, &Options{}); err != nil {
		t.Errorf("server Write got %v, want <nil>", err)
		return
	}

	// Read the client's second message into the same buffer.
	_, err = s.Read(p)
	if err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, nil", err)
		return
	}

	if err := h.t.WriteStatus(s, status.New(codes.OK, "")); err != nil {
		t.Errorf("server WriteStatus got %v, want <nil>", err)
		return
	}
}
293
// server is a test transport server. It accepts connections on lis and serves
// each one with the handler selected by an hType (see start).
type server struct {
	lis net.Listener
	// port is the port lis is bound to, taken from lis.Addr().
	port string
	// startedErr receives the outcome of start's listen step (nil on success).
	startedErr chan error
	// mu guards conns, h and the closing of ready.
	mu sync.Mutex
	// conns maps each server transport to its raw connection; it is set to nil
	// by stop.
	conns map[ServerTransport]net.Conn
	// h is the handler of the most recently accepted connection.
	h *testStreamHandler
	// ready is closed once a delayRead handler has been configured.
	ready    chan struct{}
	channelz *channelz.Server
}
304
305 func newTestServer() *server {
306 return &server{
307 startedErr: make(chan error, 1),
308 ready: make(chan struct{}),
309 channelz: channelz.RegisterServer("test server"),
310 }
311 }
312
313
314 func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
315 var err error
316 if port == 0 {
317 s.lis, err = net.Listen("tcp", "localhost:0")
318 } else {
319 s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
320 }
321 if err != nil {
322 s.startedErr <- fmt.Errorf("failed to listen: %v", err)
323 return
324 }
325 _, p, err := net.SplitHostPort(s.lis.Addr().String())
326 if err != nil {
327 s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
328 return
329 }
330 s.port = p
331 s.conns = make(map[ServerTransport]net.Conn)
332 s.startedErr <- nil
333 for {
334 conn, err := s.lis.Accept()
335 if err != nil {
336 return
337 }
338 rawConn := conn
339 if serverConfig.MaxStreams == 0 {
340 serverConfig.MaxStreams = math.MaxUint32
341 }
342 transport, err := NewServerTransport(conn, serverConfig)
343 if err != nil {
344 return
345 }
346 s.mu.Lock()
347 if s.conns == nil {
348 s.mu.Unlock()
349 transport.Close(errors.New("s.conns is nil"))
350 return
351 }
352 s.conns[transport] = rawConn
353 h := &testStreamHandler{t: transport.(*http2Server)}
354 s.h = h
355 s.mu.Unlock()
356 switch ht {
357 case notifyCall:
358 go transport.HandleStreams(context.Background(), h.handleStreamAndNotify)
359 case suspended:
360 go transport.HandleStreams(context.Background(), func(*Stream) {})
361 case misbehaved:
362 go transport.HandleStreams(context.Background(), func(s *Stream) {
363 go h.handleStreamMisbehave(t, s)
364 })
365 case encodingRequiredStatus:
366 go transport.HandleStreams(context.Background(), func(s *Stream) {
367 go h.handleStreamEncodingRequiredStatus(s)
368 })
369 case invalidHeaderField:
370 go transport.HandleStreams(context.Background(), func(s *Stream) {
371 go h.handleStreamInvalidHeaderField(s)
372 })
373 case delayRead:
374 h.notify = make(chan struct{})
375 h.getNotified = make(chan struct{})
376 s.mu.Lock()
377 close(s.ready)
378 s.mu.Unlock()
379 go transport.HandleStreams(context.Background(), func(s *Stream) {
380 go h.handleStreamDelayRead(t, s)
381 })
382 case pingpong:
383 go transport.HandleStreams(context.Background(), func(s *Stream) {
384 go h.handleStreamPingPong(t, s)
385 })
386 default:
387 go transport.HandleStreams(context.Background(), func(s *Stream) {
388 go h.handleStream(t, s)
389 })
390 }
391 }
392 }
393
394 func (s *server) wait(t *testing.T, timeout time.Duration) {
395 select {
396 case err := <-s.startedErr:
397 if err != nil {
398 t.Fatal(err)
399 }
400 case <-time.After(timeout):
401 t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
402 }
403 }
404
405 func (s *server) stop() {
406 s.lis.Close()
407 s.mu.Lock()
408 for c := range s.conns {
409 c.Close(errors.New("server Stop called"))
410 }
411 s.conns = nil
412 s.mu.Unlock()
413 }
414
415 func (s *server) addr() string {
416 if s.lis == nil {
417 return ""
418 }
419 return s.lis.Addr().String()
420 }
421
422 func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server {
423 server := newTestServer()
424 sc.ChannelzParent = server.channelz
425 go server.start(t, port, sc, ht)
426 server.wait(t, 2*time.Second)
427 return server
428 }
429
430 func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
431 return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
432 }
433
434 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
435 server := setUpServerOnly(t, port, sc, ht)
436 addr := resolver.Address{Addr: "localhost:" + server.port}
437 copts.ChannelzParent = channelzSubChannel(t)
438
439 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
440 ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
441 if connErr != nil {
442 cancel()
443 t.Fatalf("failed to create transport: %v", connErr)
444 }
445 return server, ct.(*http2Client), cancel
446 }
447
448 func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) {
449 lis, err := net.Listen("tcp", "localhost:0")
450 if err != nil {
451 t.Fatalf("Failed to listen: %v", err)
452 }
453
454 go func() {
455 defer lis.Close()
456 conn, err := lis.Accept()
457 if err != nil {
458 t.Errorf("Error at server-side while accepting: %v", err)
459 close(connCh)
460 return
461 }
462 framer := http2.NewFramer(conn, conn)
463 if err := framer.WriteSettings(); err != nil {
464 t.Errorf("Error at server-side while writing settings: %v", err)
465 close(connCh)
466 return
467 }
468 connCh <- conn
469 }()
470 connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
471 tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
472 if err != nil {
473 cancel()
474
475 lis.Close()
476 if conn, ok := <-connCh; ok {
477 conn.Close()
478 }
479 t.Fatalf("Failed to dial: %v", err)
480 }
481 return tr.(*http2Client), cancel
482 }
483
484
485
// TestInflightStreamClosing verifies that when the client closes a stream with
// an error while a Read on that stream may be in flight, the Read returns
// exactly that error.
func (s) TestInflightStreamClosing(t *testing.T) {
	serverConfig := &ServerConfig{}
	server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
	defer cancel()
	defer server.stop()
	defer client.Close(fmt.Errorf("closed manually by test"))

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	stream, err := client.NewStream(ctx, &CallHdr{})
	if err != nil {
		t.Fatalf("Client failed to create RPC request: %v", err)
	}

	donec := make(chan struct{})
	serr := status.Error(codes.Internal, "client connection is closing")
	// The suspended server never responds, so this Read can only end when the
	// stream is closed below.
	go func() {
		defer close(donec)
		if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr {
			t.Errorf("unexpected Stream error %v, expected %v", err, serr)
		}
	}()

	// Close the stream with serr while the Read above may be pending.
	client.CloseStream(stream, serr)

	// Wait for the reader goroutine to observe serr.
	timeout := time.NewTimer(5 * time.Second)
	select {
	case <-donec:
		if !timeout.Stop() {
			<-timeout.C
		}
	case <-timeout.C:
		t.Fatalf("Test timed out, expected a status error.")
	}
}
523
524
525 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
526 server, ct, cancel := setUp(t, 0, normal)
527 defer cancel()
528 defer server.stop()
529 callHdr := &CallHdr{
530 Host: "localhost",
531 Method: "foo.Small",
532 }
533
534 originalMaxStreamID := MaxStreamID
535 MaxStreamID = 3
536 defer func() {
537 MaxStreamID = originalMaxStreamID
538 }()
539
540 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
541 defer ctxCancel()
542
543 s, err := ct.NewStream(ctx, callHdr)
544 if err != nil {
545 t.Fatalf("ct.NewStream() = %v", err)
546 }
547 if s.id != 1 {
548 t.Fatalf("Stream id: %d, want: 1", s.id)
549 }
550
551 if got, want := ct.stateForTesting(), reachable; got != want {
552 t.Fatalf("Client transport state %v, want %v", got, want)
553 }
554
555
556 s, err = ct.NewStream(ctx, callHdr)
557 if err != nil {
558 t.Fatalf("ct.NewStream() = %v", err)
559 }
560 if s.id != 3 {
561 t.Fatalf("Stream id: %d, want: 3", s.id)
562 }
563
564
565 if got, want := ct.stateForTesting(), draining; got != want {
566 t.Fatalf("Client transport state %v, want %v", got, want)
567 }
568 }
569
570 func (s) TestClientSendAndReceive(t *testing.T) {
571 server, ct, cancel := setUp(t, 0, normal)
572 defer cancel()
573 callHdr := &CallHdr{
574 Host: "localhost",
575 Method: "foo.Small",
576 }
577 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
578 defer ctxCancel()
579 s1, err1 := ct.NewStream(ctx, callHdr)
580 if err1 != nil {
581 t.Fatalf("failed to open stream: %v", err1)
582 }
583 if s1.id != 1 {
584 t.Fatalf("wrong stream id: %d", s1.id)
585 }
586 s2, err2 := ct.NewStream(ctx, callHdr)
587 if err2 != nil {
588 t.Fatalf("failed to open stream: %v", err2)
589 }
590 if s2.id != 3 {
591 t.Fatalf("wrong stream id: %d", s2.id)
592 }
593 opts := Options{Last: true}
594 if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF {
595 t.Fatalf("failed to send data: %v", err)
596 }
597 p := make([]byte, len(expectedResponse))
598 _, recvErr := s1.Read(p)
599 if recvErr != nil || !bytes.Equal(p, expectedResponse) {
600 t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
601 }
602 _, recvErr = s1.Read(p)
603 if recvErr != io.EOF {
604 t.Fatalf("Error: %v; want <EOF>", recvErr)
605 }
606 ct.Close(fmt.Errorf("closed manually by test"))
607 server.stop()
608 }
609
610 func (s) TestClientErrorNotify(t *testing.T) {
611 server, ct, cancel := setUp(t, 0, normal)
612 defer cancel()
613 go server.stop()
614
615 <-ct.Error()
616 ct.Close(fmt.Errorf("closed manually by test"))
617 }
618
619 func performOneRPC(ct ClientTransport) {
620 callHdr := &CallHdr{
621 Host: "localhost",
622 Method: "foo.Small",
623 }
624 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
625 defer cancel()
626 s, err := ct.NewStream(ctx, callHdr)
627 if err != nil {
628 return
629 }
630 opts := Options{Last: true}
631 if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF {
632 time.Sleep(5 * time.Millisecond)
633
634
635
636
637 p := make([]byte, len(expectedResponse))
638 s.Read(p)
639
640 s.Read(p)
641 }
642 }
643
644 func (s) TestClientMix(t *testing.T) {
645 s, ct, cancel := setUp(t, 0, normal)
646 defer cancel()
647 time.AfterFunc(time.Second, s.stop)
648 go func(ct ClientTransport) {
649 <-ct.Error()
650 ct.Close(fmt.Errorf("closed manually by test"))
651 }(ct)
652 for i := 0; i < 750; i++ {
653 time.Sleep(2 * time.Millisecond)
654 go performOneRPC(ct)
655 }
656 }
657
658 func (s) TestLargeMessage(t *testing.T) {
659 server, ct, cancel := setUp(t, 0, normal)
660 defer cancel()
661 callHdr := &CallHdr{
662 Host: "localhost",
663 Method: "foo.Large",
664 }
665 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
666 defer ctxCancel()
667 var wg sync.WaitGroup
668 for i := 0; i < 2; i++ {
669 wg.Add(1)
670 go func() {
671 defer wg.Done()
672 s, err := ct.NewStream(ctx, callHdr)
673 if err != nil {
674 t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
675 }
676 if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil && err != io.EOF {
677 t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
678 }
679 p := make([]byte, len(expectedResponseLarge))
680 if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
681 t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
682 }
683 if _, err = s.Read(p); err != io.EOF {
684 t.Errorf("Failed to complete the stream %v; want <EOF>", err)
685 }
686 }()
687 }
688 wg.Wait()
689 ct.Close(fmt.Errorf("closed manually by test"))
690 server.stop()
691 }
692
// TestLargeMessageWithDelayRead exercises flow control when neither side reads
// promptly. Both sides use defaultWindowSize stream and connection windows and
// exchange "foo.Large" messages. The server handler (handleStreamDelayRead)
// does not read until this test closes its getNotified channel. That happens
// once the client stream's write quota has been replenished by
// defaultWindowSize. In turn, this test does not read the response until the
// handler closes notify, which happens after the same amount has been
// replenished on the server stream.
func (s) TestLargeMessageWithDelayRead(t *testing.T) {
	// Use defaultWindowSize windows on both ends of the connection.
	sc := &ServerConfig{
		InitialWindowSize:     defaultWindowSize,
		InitialConnWindowSize: defaultWindowSize,
	}
	co := ConnectOptions{
		InitialWindowSize:     defaultWindowSize,
		InitialConnWindowSize: defaultWindowSize,
	}
	server, ct, cancel := setUpWithOptions(t, 0, sc, delayRead, co)
	defer cancel()
	defer server.stop()
	defer ct.Close(fmt.Errorf("closed manually by test"))
	server.mu.Lock()
	ready := server.ready
	server.mu.Unlock()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Large",
	}
	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
	defer cancel()
	s, err := ct.NewStream(ctx, callHdr)
	if err != nil {
		t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
		return
	}
	// Wait for server.start to configure the delayRead handler. It closes
	// ready only after creating the handler's notify and getNotified channels.
	select {
	case <-ready:
	case <-ctx.Done():
		t.Fatalf("Client timed out waiting for server handler to be initialized.")
	}
	server.mu.Lock()
	serviceHandler := server.h
	server.mu.Unlock()
	var (
		mu    sync.Mutex
		total int
	)
	// Wrap the client stream's replenish hook to keep a running total of
	// replenished bytes, still forwarding every call to s.wq.realReplenish.
	s.wq.replenish = func(n int) {
		mu.Lock()
		total += n
		mu.Unlock()
		s.wq.realReplenish(n)
	}
	getTotal := func() int {
		mu.Lock()
		defer mu.Unlock()
		return total
	}
	done := make(chan struct{})
	defer close(done)
	// Poll the total; once it reaches defaultWindowSize, let the server
	// handler start reading.
	go func() {
		for {
			select {
			// The test has returned; stop polling.
			case <-done:
				return
			default:
			}
			if getTotal() == defaultWindowSize {
				// A full window has been replenished on the client stream.
				close(serviceHandler.getNotified)
				return
			}
			runtime.Gosched()
		}
	}()
	// Send the first large request without ending the stream. The server
	// handler is not reading yet.
	if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{}); err != nil {
		t.Fatalf("write(_, _, _) = %v, want <nil>", err)
	}
	p := make([]byte, len(expectedResponseLarge))

	// Don't read the response until the server handler closes notify, i.e.
	// until its stream's write quota has been replenished by defaultWindowSize.
	select {
	case <-serviceHandler.notify:
	case <-ctx.Done():
		t.Fatalf("Client timed out")
	}
	if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
		t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
	}
	// Send a second request and half-close. The handler reads it and ends the
	// stream with an OK status, so the client sees io.EOF.
	if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil {
		t.Fatalf("Write(_, _, _) = %v, want <nil>", err)
	}
	if _, err = s.Read(p); err != io.EOF {
		t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
	}
}
789
790
791
792
793 func (s) TestGracefulClose(t *testing.T) {
794 server, ct, cancel := setUp(t, 0, pingpong)
795 defer cancel()
796 defer func() {
797
798
799 server.lis.Close()
800
801
802 leakcheck.Check(t)
803
804 server.stop()
805 }()
806 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
807 defer cancel()
808
809
810
811 s, err := ct.NewStream(ctx, &CallHdr{})
812 if err != nil {
813 t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err)
814 }
815 msg := make([]byte, 1024)
816 outgoingHeader := make([]byte, 5)
817 outgoingHeader[0] = byte(0)
818 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg)))
819 incomingHeader := make([]byte, 5)
820 if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil {
821 t.Fatalf("Error while writing: %v", err)
822 }
823 if _, err := s.Read(incomingHeader); err != nil {
824 t.Fatalf("Error while reading: %v", err)
825 }
826 sz := binary.BigEndian.Uint32(incomingHeader[1:])
827 recvMsg := make([]byte, int(sz))
828 if _, err := s.Read(recvMsg); err != nil {
829 t.Fatalf("Error while reading: %v", err)
830 }
831
832
833
834 ct.GracefulClose()
835
836 var wg sync.WaitGroup
837
838
839 for i := 0; i < 200; i++ {
840 wg.Add(1)
841 go func() {
842 defer wg.Done()
843 _, err := ct.NewStream(ctx, &CallHdr{})
844 if err != nil && err.(*NewStreamError).Err == ErrConnClosing && err.(*NewStreamError).AllowTransparentRetry {
845 return
846 }
847 t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
848 }()
849 }
850
851
852 ct.Write(s, nil, nil, &Options{Last: true})
853 if _, err := s.Read(incomingHeader); err != io.EOF {
854 t.Fatalf("Client expected EOF from the server. Got: %v", err)
855 }
856 wg.Wait()
857 }
858
859 func (s) TestLargeMessageSuspension(t *testing.T) {
860 server, ct, cancel := setUp(t, 0, suspended)
861 defer cancel()
862 callHdr := &CallHdr{
863 Host: "localhost",
864 Method: "foo.Large",
865 }
866
867 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
868 defer cancel()
869 s, err := ct.NewStream(ctx, callHdr)
870 if err != nil {
871 t.Fatalf("failed to open stream: %v", err)
872 }
873
874
875 go func() {
876 <-ctx.Done()
877 ct.CloseStream(s, ContextErr(ctx.Err()))
878 }()
879
880 msg := make([]byte, initialWindowSize*8)
881 ct.Write(s, nil, msg, &Options{})
882 err = ct.Write(s, nil, msg, &Options{Last: true})
883 if err != errStreamDone {
884 t.Fatalf("Write got %v, want io.EOF", err)
885 }
886 expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
887 if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() {
888 t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
889 }
890 ct.Close(fmt.Errorf("closed manually by test"))
891 server.stop()
892 }
893
// TestMaxStreams verifies that the client honours a server MaxStreams limit of
// 1. Once the limit is in effect, NewStream blocks (and here times out). A
// blocked NewStream proceeds only after the open stream is closed. Finally,
// the transport records maxConcurrentStreams == 1.
func (s) TestMaxStreams(t *testing.T) {
	serverConfig := &ServerConfig{
		MaxStreams: 1,
	}
	server, ct, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
	defer cancel()
	defer ct.Close(fmt.Errorf("closed manually by test"))
	defer server.stop()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Large",
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	s, err := ct.NewStream(ctx, callHdr)
	if err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
	// Keep creating streams until one fails with DeadlineExceeded. Until the
	// client has applied the server's SETTINGS (with the limit), creations can
	// still succeed; those streams are collected in slist. The timer bounds
	// how long we wait for the limit to take effect.
	slist := []*Stream{}
	pctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	timer := time.NewTimer(time.Second * 10)
	expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
	for {
		select {
		case <-timer.C:
			t.Fatalf("Test timeout: client didn't receive server settings.")
		default:
		}
		ctx, cancel := context.WithDeadline(pctx, time.Now().Add(time.Second))
		// The defer inside the loop presumably exists only to satisfy go vet's
		// lostcancel check: every ctx here derives from pctx, whose cancel is
		// deferred above and runs when the test returns.
		defer cancel()
		if str, err := ct.NewStream(ctx, callHdr); err == nil {
			slist = append(slist, str)
			continue
		} else if err.Error() != expectedErr.Error() {
			t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr)
		}
		timer.Stop()
		break
	}
	done := make(chan struct{})
	// This NewStream must stay blocked until s is closed below.
	go func() {
		defer close(done)
		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
		defer cancel()
		if _, err := ct.NewStream(ctx, callHdr); err != nil {
			t.Errorf("Failed to open stream: %v", err)
		}
	}()

	// Close the streams created before the limit took effect. s is still
	// open, so the goroutine above must not have finished.
	for _, str := range slist {
		ct.CloseStream(str, nil)
	}
	select {
	case <-done:
		t.Fatalf("Test failed: didn't expect new stream to be created just yet.")
	default:
	}

	// Closing s frees the only slot, letting the pending NewStream proceed.
	ct.CloseStream(s, nil)
	<-done
	ct.Close(fmt.Errorf("closed manually by test"))
	// Wait for the writer to finish (presumably so maxConcurrentStreams can be
	// read without racing it).
	<-ct.writerDone
	if ct.maxConcurrentStreams != 1 {
		t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams)
	}
}
966
// TestServerContextCanceledOnClosedConnection verifies that closing the client
// transport cancels the context of the matching server-side stream with
// context.Canceled.
func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
	server, ct, cancel := setUp(t, 0, suspended)
	defer cancel()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo",
	}
	var sc *http2Server
	// Wait until the server has registered the connection, then grab its
	// server transport.
	for {
		server.mu.Lock()
		if len(server.conns) == 0 {
			server.mu.Unlock()
			time.Sleep(time.Millisecond)
			continue
		}
		for k := range server.conns {
			var ok bool
			sc, ok = k.(*http2Server)
			if !ok {
				t.Fatalf("Failed to convert %v to *http2Server", k)
			}
		}
		server.mu.Unlock()
		break
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	s, err := ct.NewStream(ctx, callHdr)
	if err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
	// Queue one full-size data frame (not ending the stream) directly on the
	// client's control buffer.
	ct.controlBuf.put(&dataFrame{
		streamID:    s.id,
		endStream:   false,
		h:           nil,
		d:           make([]byte, http2MaxFrameLen),
		onEachWrite: func() {},
	})
	// Poll once per second until the server has an active stream, then look
	// up the one matching s.
	// NOTE(review): this loop has no deadline of its own, and ss would be nil
	// if the active stream had a different id.
	var ss *Stream
	for {
		time.Sleep(time.Second)
		sc.mu.Lock()
		if len(sc.activeStreams) == 0 {
			sc.mu.Unlock()
			continue
		}
		ss = sc.activeStreams[s.id]
		sc.mu.Unlock()
		break
	}
	// Closing the client transport must cancel the server stream's context.
	ct.Close(fmt.Errorf("closed manually by test"))
	select {
	case <-ss.Context().Done():
		if ss.Context().Err() != context.Canceled {
			t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled)
		}
	case <-time.After(5 * time.Second):
		t.Fatalf("Failed to cancel the context of the sever side stream.")
	}
	server.stop()
}
1030
// TestClientConnDecoupledFromApplicationRead verifies that data the client
// application has not read does not stall other streams on the same
// connection. The server writes defaultWindowSize bytes on a first stream that
// the client leaves unread, then another defaultWindowSize bytes on a second
// stream. Both must be readable afterwards. The client's stream and connection
// windows are both defaultWindowSize, so presumably the second write can only
// complete if connection-level window updates do not wait for application
// reads.
func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) {
	connectOptions := ConnectOptions{
		InitialWindowSize:     defaultWindowSize,
		InitialConnWindowSize: defaultWindowSize,
	}
	server, client, cancel := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions)
	defer cancel()
	defer server.stop()
	defer client.Close(fmt.Errorf("closed manually by test"))

	// Wait until the server has registered the connection.
	waitWhileTrue(t, func() (bool, error) {
		server.mu.Lock()
		defer server.mu.Unlock()

		if len(server.conns) == 0 {
			return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
		}
		return false, nil
	})

	var st *http2Server
	server.mu.Lock()
	for k := range server.conns {
		st = k.(*http2Server)
	}
	// Install a fresh channel; handleStreamAndNotify closes it when the next
	// stream arrives.
	notifyChan := make(chan struct{})
	server.h.notify = notifyChan
	server.mu.Unlock()
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	cstream1, err := client.NewStream(ctx, &CallHdr{})
	if err != nil {
		t.Fatalf("Client failed to create first stream. Err: %v", err)
	}

	<-notifyChan
	var sstream1 *Stream
	// Find the server-side stream matching cstream1.
	st.mu.Lock()
	for _, v := range st.activeStreams {
		if v.id == cstream1.id {
			sstream1 = v
		}
	}
	st.mu.Unlock()
	if sstream1 == nil {
		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
	}
	// Fill the first stream's window; the client application does not read it yet.
	if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
		t.Fatalf("Server failed to write data. Err: %v", err)
	}
	// Repeat the setup for a second stream.
	notifyChan = make(chan struct{})
	server.mu.Lock()
	server.h.notify = notifyChan
	server.mu.Unlock()

	cstream2, err := client.NewStream(ctx, &CallHdr{})
	if err != nil {
		t.Fatalf("Client failed to create second stream. Err: %v", err)
	}
	<-notifyChan
	var sstream2 *Stream
	st.mu.Lock()
	for _, v := range st.activeStreams {
		if v.id == cstream2.id {
			sstream2 = v
		}
	}
	st.mu.Unlock()
	if sstream2 == nil {
		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
	}
	// This write must get through even though stream 1's data is still unread.
	if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
		t.Fatalf("Server failed to write data. Err: %v", err)
	}

	// Read stream 2 first, while stream 1's data is still unread.
	if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil {
		t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
	}

	// Stream 1's data must still be available.
	if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil {
		t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
	}
}
1119
1120 func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) {
1121 serverConfig := &ServerConfig{
1122 InitialWindowSize: defaultWindowSize,
1123 InitialConnWindowSize: defaultWindowSize,
1124 }
1125 server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
1126 defer cancel()
1127 defer server.stop()
1128 defer client.Close(fmt.Errorf("closed manually by test"))
1129 waitWhileTrue(t, func() (bool, error) {
1130 server.mu.Lock()
1131 defer server.mu.Unlock()
1132
1133 if len(server.conns) == 0 {
1134 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
1135 }
1136 return false, nil
1137 })
1138 var st *http2Server
1139 server.mu.Lock()
1140 for k := range server.conns {
1141 st = k.(*http2Server)
1142 }
1143 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1144 defer cancel()
1145 server.mu.Unlock()
1146 cstream1, err := client.NewStream(ctx, &CallHdr{})
1147 if err != nil {
1148 t.Fatalf("Failed to create 1st stream. Err: %v", err)
1149 }
1150
1151 if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
1152 t.Fatalf("Client failed to write data. Err: %v", err)
1153 }
1154
1155 cstream2, err := client.NewStream(ctx, &CallHdr{})
1156 if err != nil {
1157 t.Fatalf("Failed to create 2nd stream. Err: %v", err)
1158 }
1159 if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil {
1160 t.Fatalf("Client failed to write data. Err: %v", err)
1161 }
1162
1163 waitWhileTrue(t, func() (bool, error) {
1164 st.mu.Lock()
1165 defer st.mu.Unlock()
1166
1167 if len(st.activeStreams) != 2 {
1168 return true, fmt.Errorf("timed-out while waiting for server to have created the streams")
1169 }
1170 return false, nil
1171 })
1172 var sstream1 *Stream
1173 st.mu.Lock()
1174 for _, v := range st.activeStreams {
1175 if v.id == 1 {
1176 sstream1 = v
1177 }
1178 }
1179 st.mu.Unlock()
1180
1181 if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil {
1182 t.Fatalf("_.Read(_) = %v, want <nil>", err)
1183 }
1184
1185 if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF {
1186 t.Fatalf("_.Read(_) = %v, want io.EOF", err)
1187 }
1188
1189 }
1190
1191 func (s) TestServerWithMisbehavedClient(t *testing.T) {
1192 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
1193 defer server.stop()
1194
1195 mconn, err := net.Dial("tcp", server.lis.Addr().String())
1196 if err != nil {
1197 t.Fatalf("Clent failed to dial:%v", err)
1198 }
1199 defer mconn.Close()
1200 if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil {
1201 t.Fatalf("Failed to set write deadline: %v", err)
1202 }
1203 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
1204 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
1205 }
1206
1207 success := make(chan struct{})
1208 var mu sync.Mutex
1209 framer := http2.NewFramer(mconn, mconn)
1210 if err := framer.WriteSettings(); err != nil {
1211 t.Fatalf("Error while writing settings: %v", err)
1212 }
1213 go func() {
1214 for {
1215 frame, err := framer.ReadFrame()
1216 if err != nil {
1217 return
1218 }
1219 switch frame := frame.(type) {
1220 case *http2.PingFrame:
1221
1222 mu.Lock()
1223 framer.WritePing(true, frame.Data)
1224 mu.Unlock()
1225 case *http2.RSTStreamFrame:
1226 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl {
1227 t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))
1228 }
1229 close(success)
1230 return
1231 default:
1232
1233 }
1234
1235 }
1236 }()
1237
1238 var buf bytes.Buffer
1239 henc := hpack.NewEncoder(&buf)
1240
1241 if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil {
1242 t.Fatalf("Error while encoding header: %v", err)
1243 }
1244 if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil {
1245 t.Fatalf("Error while encoding header: %v", err)
1246 }
1247 if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil {
1248 t.Fatalf("Error while encoding header: %v", err)
1249 }
1250 if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil {
1251 t.Fatalf("Error while encoding header: %v", err)
1252 }
1253 mu.Lock()
1254 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
1255 mu.Unlock()
1256 t.Fatalf("Error while writing headers: %v", err)
1257 }
1258 mu.Unlock()
1259
1260
1261 timer := time.NewTimer(time.Second * 5)
1262 dbuf := make([]byte, http2MaxFrameLen)
1263 for {
1264 select {
1265 case <-timer.C:
1266 t.Fatalf("Test timed out.")
1267 case <-success:
1268 return
1269 default:
1270 }
1271 mu.Lock()
1272 if err := framer.WriteData(1, false, dbuf); err != nil {
1273 mu.Unlock()
1274
1275
1276 select {
1277 case <-timer.C:
1278 t.Fatalf("Error while writing data: %v", err)
1279 case <-success:
1280 return
1281 }
1282 }
1283 mu.Unlock()
1284
1285
1286
1287 runtime.Gosched()
1288 }
1289 }
1290
1291 func (s) TestClientHonorsConnectContext(t *testing.T) {
1292
1293 lis, err := net.Listen("tcp", "localhost:0")
1294 if err != nil {
1295 t.Fatalf("Error while listening: %v", err)
1296 }
1297 defer lis.Close()
1298 go func() {
1299 sconn, err := lis.Accept()
1300 if err != nil {
1301 t.Errorf("Error while accepting: %v", err)
1302 return
1303 }
1304 defer sconn.Close()
1305 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
1306 t.Errorf("Error while reading client preface: %v", err)
1307 return
1308 }
1309 sfr := http2.NewFramer(sconn, sconn)
1310
1311 for {
1312 if _, err := sfr.ReadFrame(); err != nil {
1313 return
1314 }
1315 }
1316 }()
1317
1318
1319 timeBefore := time.Now()
1320 connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1321 time.AfterFunc(100*time.Millisecond, cancel)
1322
1323 parent := channelzSubChannel(t)
1324 copts := ConnectOptions{ChannelzParent: parent}
1325 _, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
1326 if err == nil {
1327 t.Fatalf("NewClientTransport() returned successfully; wanted error")
1328 }
1329 t.Logf("NewClientTransport() = _, %v", err)
1330 if time.Since(timeBefore) > 3*time.Second {
1331 t.Fatalf("NewClientTransport returned > 2.9s after context cancelation")
1332 }
1333
1334
1335 connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
1336 defer cancel()
1337 _, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
1338 if err == nil {
1339 t.Fatalf("NewClientTransport() returned successfully; wanted error")
1340 }
1341 t.Logf("NewClientTransport() = _, %v", err)
1342 }
1343
// TestClientWithMisbehavedServer runs a hand-written HTTP/2 server that, as
// soon as it sees the client's HEADERS frame, keeps writing max-size DATA
// frames on stream 1 without waiting for WINDOW_UPDATEs. The client transport
// is expected to reset stream 1 with ErrCodeFlowControl within 5 seconds.
func (s) TestClientWithMisbehavedServer(t *testing.T) {
	// Listener for the misbehaving fake server.
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Error while listening: %v", err)
	}
	defer lis.Close()
	// success is closed by the server goroutine when the client's RST_STREAM
	// for stream 1 arrives.
	
	success := make(chan struct{})
	go func() {
		sconn, err := lis.Accept()
		if err != nil {
			t.Errorf("Error while accepting: %v", err)
			return
		}
		defer sconn.Close()
		if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
			t.Errorf("Error while reading client preface: %v", err)
			return
		}
		sfr := http2.NewFramer(sconn, sconn)
		if err := sfr.WriteSettings(); err != nil {
			t.Errorf("Error while writing settings: %v", err)
			return
		}
		if err := sfr.WriteSettingsAck(); err != nil {
			t.Errorf("Error while writing settings: %v", err)
			return
		}
		// mu serializes writes on sfr between this read loop (PING acks) and
		// the DATA-flooding goroutine started on HEADERS.
		var mu sync.Mutex
		for {
			frame, err := sfr.ReadFrame()
			if err != nil {
				return
			}
			switch frame := frame.(type) {
			case *http2.HeadersFrame:
				// Flood stream 1 with DATA frames until a write fails.
				go func() {
					buf := make([]byte, http2MaxFrameLen)
					for {
						mu.Lock()
						if err := sfr.WriteData(1, false, buf); err != nil {
							mu.Unlock()
							return
						}
						mu.Unlock()
						// Yield between writes so the read loop can run.
						
						
						runtime.Gosched()
					}
				}()
			case *http2.RSTStreamFrame:
				if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl {
					t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))
				}
				close(success)
				return
			case *http2.PingFrame:
				// Acknowledge any PING from the client.
				mu.Lock()
				sfr.WritePing(true, frame.Data)
				mu.Unlock()
			default:
			}
		}
	}()
	connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
	defer cancel()

	parent := channelzSubChannel(t)
	copts := ConnectOptions{ChannelzParent: parent}
	ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
	if err != nil {
		t.Fatalf("Error while creating client transport: %v", err)
	}
	defer ct.Close(fmt.Errorf("closed manually by test"))

	// Opening the stream sends HEADERS, which triggers the server's flood.
	str, err := ct.NewStream(connectCtx, &CallHdr{})
	if err != nil {
		t.Fatalf("Error while creating stream: %v", err)
	}
	timer := time.NewTimer(time.Second * 5)
	// Close the client stream once the transport marks it done.
	go func() {
		<-str.Done()
		ct.CloseStream(str, nil)
	}()
	select {
	case <-timer.C:
		t.Fatalf("Test timed-out.")
	case <-success:
	}
}
1438
1439 var encodingTestStatus = status.New(codes.Internal, "\n")
1440
1441 func (s) TestEncodingRequiredStatus(t *testing.T) {
1442 server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
1443 defer cancel()
1444 callHdr := &CallHdr{
1445 Host: "localhost",
1446 Method: "foo",
1447 }
1448 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1449 defer cancel()
1450 s, err := ct.NewStream(ctx, callHdr)
1451 if err != nil {
1452 return
1453 }
1454 opts := Options{Last: true}
1455 if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone {
1456 t.Fatalf("Failed to write the request: %v", err)
1457 }
1458 p := make([]byte, http2MaxFrameLen)
1459 if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
1460 t.Fatalf("Read got error %v, want %v", err, io.EOF)
1461 }
1462 if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) {
1463 t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus)
1464 }
1465 ct.Close(fmt.Errorf("closed manually by test"))
1466 server.stop()
1467 }
1468
1469 func (s) TestInvalidHeaderField(t *testing.T) {
1470 server, ct, cancel := setUp(t, 0, invalidHeaderField)
1471 defer cancel()
1472 callHdr := &CallHdr{
1473 Host: "localhost",
1474 Method: "foo",
1475 }
1476 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1477 defer cancel()
1478 s, err := ct.NewStream(ctx, callHdr)
1479 if err != nil {
1480 return
1481 }
1482 p := make([]byte, http2MaxFrameLen)
1483 _, err = s.trReader.(*transportReader).Read(p)
1484 if se, ok := status.FromError(err); !ok || se.Code() != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
1485 t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
1486 }
1487 ct.Close(fmt.Errorf("closed manually by test"))
1488 server.stop()
1489 }
1490
1491 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
1492 server, ct, cancel := setUp(t, 0, invalidHeaderField)
1493 defer cancel()
1494 defer server.stop()
1495 defer ct.Close(fmt.Errorf("closed manually by test"))
1496 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1497 defer cancel()
1498 s, err := ct.NewStream(ctx, &CallHdr{Host: "localhost", Method: "foo"})
1499 if err != nil {
1500 t.Fatalf("failed to create the stream")
1501 }
1502 timer := time.NewTimer(time.Second)
1503 defer timer.Stop()
1504 select {
1505 case <-s.headerChan:
1506 case <-timer.C:
1507 t.Errorf("s.headerChan: got open, want closed")
1508 }
1509 }
1510
1511 func (s) TestIsReservedHeader(t *testing.T) {
1512 tests := []struct {
1513 h string
1514 want bool
1515 }{
1516 {"", false},
1517 {"foo", false},
1518 {"content-type", true},
1519 {"user-agent", true},
1520 {":anything", true},
1521 {"grpc-message-type", true},
1522 {"grpc-encoding", true},
1523 {"grpc-message", true},
1524 {"grpc-status", true},
1525 {"grpc-timeout", true},
1526 {"te", true},
1527 }
1528 for _, tt := range tests {
1529 got := isReservedHeader(tt.h)
1530 if got != tt.want {
1531 t.Errorf("isReservedHeader(%q) = %v; want %v", tt.h, got, tt.want)
1532 }
1533 }
1534 }
1535
1536 func (s) TestContextErr(t *testing.T) {
1537 for _, test := range []struct {
1538
1539 errIn error
1540
1541 errOut error
1542 }{
1543 {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())},
1544 {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())},
1545 } {
1546 err := ContextErr(test.errIn)
1547 if err.Error() != test.errOut.Error() {
1548 t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
1549 }
1550 }
1551 }
1552
// windowSizeConfig carries the initial flow-control window sizes used by
// testFlowControlAccountCheck. Stream fields feed InitialWindowSize and conn
// fields feed InitialConnWindowSize on the respective side. A zero value
// leaves the corresponding option unset.
type windowSizeConfig struct {
	serverStream, serverConn int32 // server side: stream window, connection window
	clientStream, clientConn int32 // client side: stream window, connection window
}
1559
1560 func (s) TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) {
1561 wc := windowSizeConfig{
1562 serverStream: 10 * 1024 * 1024,
1563 serverConn: 12 * 1024 * 1024,
1564 clientStream: 6 * 1024 * 1024,
1565 clientConn: 8 * 1024 * 1024,
1566 }
1567 testFlowControlAccountCheck(t, 1024*1024, wc)
1568 }
1569
1570 func (s) TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) {
1571
1572
1573 wc := windowSizeConfig{
1574 serverStream: defaultWindowSize,
1575 serverConn: defaultWindowSize,
1576 clientStream: defaultWindowSize,
1577 clientConn: defaultWindowSize,
1578 }
1579 testFlowControlAccountCheck(t, 1024*1024, wc)
1580 }
1581
1582 func (s) TestAccountCheckDynamicWindowSmallMessage(t *testing.T) {
1583 testFlowControlAccountCheck(t, 1024, windowSizeConfig{})
1584 }
1585
1586 func (s) TestAccountCheckDynamicWindowLargeMessage(t *testing.T) {
1587 testFlowControlAccountCheck(t, 1024*1024, windowSizeConfig{})
1588 }
1589
1590 func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) {
1591 sc := &ServerConfig{
1592 InitialWindowSize: wc.serverStream,
1593 InitialConnWindowSize: wc.serverConn,
1594 }
1595 co := ConnectOptions{
1596 InitialWindowSize: wc.clientStream,
1597 InitialConnWindowSize: wc.clientConn,
1598 }
1599 server, client, cancel := setUpWithOptions(t, 0, sc, pingpong, co)
1600 defer cancel()
1601 defer server.stop()
1602 defer client.Close(fmt.Errorf("closed manually by test"))
1603 waitWhileTrue(t, func() (bool, error) {
1604 server.mu.Lock()
1605 defer server.mu.Unlock()
1606 if len(server.conns) == 0 {
1607 return true, fmt.Errorf("timed out while waiting for server transport to be created")
1608 }
1609 return false, nil
1610 })
1611 var st *http2Server
1612 server.mu.Lock()
1613 for k := range server.conns {
1614 st = k.(*http2Server)
1615 }
1616 server.mu.Unlock()
1617
1618 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1619 defer cancel()
1620 const numStreams = 5
1621 clientStreams := make([]*Stream, numStreams)
1622 for i := 0; i < numStreams; i++ {
1623 var err error
1624 clientStreams[i], err = client.NewStream(ctx, &CallHdr{})
1625 if err != nil {
1626 t.Fatalf("Failed to create stream. Err: %v", err)
1627 }
1628 }
1629 var wg sync.WaitGroup
1630
1631 for _, stream := range clientStreams {
1632 wg.Add(1)
1633 go func(stream *Stream) {
1634 defer wg.Done()
1635 buf := make([]byte, msgSize+5)
1636 buf[0] = byte(0)
1637 binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
1638 opts := Options{}
1639 header := make([]byte, 5)
1640 for i := 1; i <= 5; i++ {
1641 if err := client.Write(stream, nil, buf, &opts); err != nil {
1642 t.Errorf("Error on client while writing message %v on stream %v: %v", i, stream.id, err)
1643 return
1644 }
1645 if _, err := stream.Read(header); err != nil {
1646 t.Errorf("Error on client while reading data frame header %v on stream %v: %v", i, stream.id, err)
1647 return
1648 }
1649 sz := binary.BigEndian.Uint32(header[1:])
1650 recvMsg := make([]byte, int(sz))
1651 if _, err := stream.Read(recvMsg); err != nil {
1652 t.Errorf("Error on client while reading data %v on stream %v: %v", i, stream.id, err)
1653 return
1654 }
1655 if len(recvMsg) != msgSize {
1656 t.Errorf("Length of message %v received by client on stream %v: %v, want: %v", i, stream.id, len(recvMsg), msgSize)
1657 return
1658 }
1659 }
1660 t.Logf("stream %v done with pingpongs", stream.id)
1661 }(stream)
1662 }
1663 wg.Wait()
1664 serverStreams := map[uint32]*Stream{}
1665 loopyClientStreams := map[uint32]*outStream{}
1666 loopyServerStreams := map[uint32]*outStream{}
1667
1668 st.mu.Lock()
1669 client.mu.Lock()
1670 for _, stream := range clientStreams {
1671 id := stream.id
1672 serverStreams[id] = st.activeStreams[id]
1673 loopyServerStreams[id] = st.loopy.estdStreams[id]
1674 loopyClientStreams[id] = client.loopy.estdStreams[id]
1675
1676 }
1677 client.mu.Unlock()
1678 st.mu.Unlock()
1679
1680 for _, stream := range clientStreams {
1681 client.Write(stream, nil, nil, &Options{Last: true})
1682 if _, err := stream.Read(make([]byte, 5)); err != io.EOF {
1683 t.Fatalf("Client expected an EOF from the server. Got: %v", err)
1684 }
1685 }
1686
1687
1688 client.Close(errors.New("closed manually by test"))
1689 st.Close(errors.New("closed manually by test"))
1690 <-st.readerDone
1691 <-st.loopyWriterDone
1692 <-client.readerDone
1693 <-client.writerDone
1694 for _, cstream := range clientStreams {
1695 id := cstream.id
1696 sstream := serverStreams[id]
1697 loopyServerStream := loopyServerStreams[id]
1698 loopyClientStream := loopyClientStreams[id]
1699 if loopyServerStream == nil {
1700 t.Fatalf("Unexpected nil loopyServerStream")
1701 }
1702
1703 if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding {
1704 t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding)
1705 }
1706 if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(client.loopy.oiws)-loopyClientStream.bytesOutStanding {
1707 t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, client.loopy.oiws, loopyClientStream.bytesOutStanding)
1708 }
1709 }
1710
1711 if client.fc.limit != client.fc.unacked+st.loopy.sendQuota {
1712 t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", client.fc.limit, client.fc.unacked, st.loopy.sendQuota)
1713 }
1714 if st.fc.limit != st.fc.unacked+client.loopy.sendQuota {
1715 t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, client.loopy.sendQuota)
1716 }
1717 }
1718
// waitWhileTrue polls condition every 50ms until it reports false.
//
// condition returns whether to keep waiting and the error to report if the
// wait times out. If condition still reports true once 5 seconds have elapsed,
// the test fails with that error, or with a generic message if the error is nil.
func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
	t.Helper()
	timer := time.NewTimer(time.Second * 5)
	defer timer.Stop()
	for {
		wait, err := condition()
		if !wait {
			return
		}
		select {
		case <-timer.C:
			if err == nil {
				// Avoid a nil dereference when condition gives no reason.
				err = errors.New("timed out waiting for condition")
			}
			// t.Fatal, not t.Fatalf(err.Error()): the message is not a format
			// string and may contain '%' verbs.
			t.Fatal(err)
		default:
			time.Sleep(50 * time.Millisecond)
		}
	}
}
1742
1743
1744
1745 func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
1746 testRecvBuffer := newRecvBuffer()
1747 s := &Stream{
1748 ctx: context.Background(),
1749 buf: testRecvBuffer,
1750 requestRead: func(int) {},
1751 }
1752 s.trReader = &transportReader{
1753 reader: &recvBufferReader{
1754 ctx: s.ctx,
1755 ctxDone: s.ctx.Done(),
1756 recv: s.buf,
1757 freeBuffer: func(*bytes.Buffer) {},
1758 },
1759 windowHandler: func(int) {},
1760 }
1761 testData := make([]byte, 1)
1762 testData[0] = 5
1763 testBuffer := bytes.NewBuffer(testData)
1764 testErr := errors.New("test error")
1765 s.write(recvMsg{buffer: testBuffer, err: testErr})
1766
1767 inBuf := make([]byte, 1)
1768 actualCount, actualErr := s.Read(inBuf)
1769 if actualCount != 0 {
1770 t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
1771 }
1772 if actualErr.Error() != testErr.Error() {
1773 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
1774 }
1775
1776 s.write(recvMsg{buffer: testBuffer, err: nil})
1777 s.write(recvMsg{buffer: testBuffer, err: errors.New("different error from first")})
1778
1779 for i := 0; i < 2; i++ {
1780 inBuf := make([]byte, 1)
1781 actualCount, actualErr := s.Read(inBuf)
1782 if actualCount != 0 {
1783 t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
1784 }
1785 if actualErr.Error() != testErr.Error() {
1786 t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
1787 }
1788 }
1789 }
1790
1791
1792
1793
1794 func (s) TestHeadersCausingStreamError(t *testing.T) {
1795 tests := []struct {
1796 name string
1797 headers []struct {
1798 name string
1799 values []string
1800 }
1801 }{
1802
1803
1804
1805 {
1806 name: "Connection header present",
1807 headers: []struct {
1808 name string
1809 values []string
1810 }{
1811 {name: ":method", values: []string{"POST"}},
1812 {name: ":path", values: []string{"foo"}},
1813 {name: ":authority", values: []string{"localhost"}},
1814 {name: "content-type", values: []string{"application/grpc"}},
1815 {name: "connection", values: []string{"not-supported"}},
1816 },
1817 },
1818
1819
1820
1821
1822 {
1823
1824
1825
1826
1827 name: "Multiple authority headers",
1828 headers: []struct {
1829 name string
1830 values []string
1831 }{
1832 {name: ":method", values: []string{"POST"}},
1833 {name: ":path", values: []string{"foo"}},
1834 {name: ":authority", values: []string{"localhost", "localhost2"}},
1835 {name: "host", values: []string{"localhost"}},
1836 },
1837 },
1838 }
1839 for _, test := range tests {
1840 t.Run(test.name, func(t *testing.T) {
1841 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
1842 defer server.stop()
1843
1844
1845 mconn, err := net.Dial("tcp", server.lis.Addr().String())
1846 if err != nil {
1847 t.Fatalf("Client failed to dial: %v", err)
1848 }
1849 defer mconn.Close()
1850
1851 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
1852 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
1853 }
1854
1855 framer := http2.NewFramer(mconn, mconn)
1856 if err := framer.WriteSettings(); err != nil {
1857 t.Fatalf("Error while writing settings: %v", err)
1858 }
1859
1860
1861
1862 result := testutils.NewChannel()
1863
1864
1865 go func() {
1866 for {
1867 frame, err := framer.ReadFrame()
1868 if err != nil {
1869 return
1870 }
1871 switch frame := frame.(type) {
1872 case *http2.SettingsFrame:
1873
1874 case *http2.RSTStreamFrame:
1875 if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeProtocol {
1876
1877 result.Send(fmt.Errorf("RST stream received with streamID: %d and code %v, want streamID: 1 and code: http.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode)))
1878 }
1879
1880 result.Send(nil)
1881 return
1882 default:
1883
1884 result.Send(errors.New("the client received a frame other than RST Stream"))
1885 }
1886 }
1887 }()
1888
1889 var buf bytes.Buffer
1890 henc := hpack.NewEncoder(&buf)
1891
1892
1893
1894 for _, header := range test.headers {
1895 for _, value := range header.values {
1896 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil {
1897 t.Fatalf("Error while encoding header: %v", err)
1898 }
1899 }
1900 }
1901
1902 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
1903 t.Fatalf("Error while writing headers: %v", err)
1904 }
1905 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1906 defer cancel()
1907 r, err := result.Receive(ctx)
1908 if err != nil {
1909 t.Fatalf("Error receiving from channel: %v", err)
1910 }
1911 if r != nil {
1912 t.Fatalf("want nil, got %v", r)
1913 }
1914 })
1915 }
1916 }
1917
1918
1919
1920 func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) {
1921 tests := []struct {
1922 name string
1923 headers []struct {
1924 name string
1925 values []string
1926 }
1927 httpStatusWant string
1928 grpcStatusWant string
1929 grpcMessageWant string
1930 }{
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942 {
1943 name: "Multiple host headers non grpc",
1944 headers: []struct {
1945 name string
1946 values []string
1947 }{
1948 {name: ":method", values: []string{"POST"}},
1949 {name: ":path", values: []string{"foo"}},
1950 {name: ":authority", values: []string{"localhost"}},
1951 {name: "host", values: []string{"localhost", "localhost2"}},
1952 },
1953 httpStatusWant: "400",
1954 grpcStatusWant: "13",
1955 grpcMessageWant: "both must only have 1 value as per HTTP/2 spec",
1956 },
1957 {
1958 name: "Multiple host headers grpc",
1959 headers: []struct {
1960 name string
1961 values []string
1962 }{
1963 {name: ":method", values: []string{"POST"}},
1964 {name: ":path", values: []string{"foo"}},
1965 {name: ":authority", values: []string{"localhost"}},
1966 {name: "content-type", values: []string{"application/grpc"}},
1967 {name: "host", values: []string{"localhost", "localhost2"}},
1968 },
1969 httpStatusWant: "400",
1970 grpcStatusWant: "13",
1971 grpcMessageWant: "both must only have 1 value as per HTTP/2 spec",
1972 },
1973
1974
1975
1976 {
1977 name: "Client Sending Wrong Method",
1978 headers: []struct {
1979 name string
1980 values []string
1981 }{
1982 {name: ":method", values: []string{"PUT"}},
1983 {name: ":path", values: []string{"foo"}},
1984 {name: ":authority", values: []string{"localhost"}},
1985 {name: "content-type", values: []string{"application/grpc"}},
1986 },
1987 httpStatusWant: "405",
1988 grpcStatusWant: "13",
1989 grpcMessageWant: "which should be POST",
1990 },
1991 {
1992 name: "Client Sending Wrong Content-Type",
1993 headers: []struct {
1994 name string
1995 values []string
1996 }{
1997 {name: ":method", values: []string{"POST"}},
1998 {name: ":path", values: []string{"foo"}},
1999 {name: ":authority", values: []string{"localhost"}},
2000 {name: "content-type", values: []string{"application/json"}},
2001 },
2002 httpStatusWant: "415",
2003 grpcStatusWant: "3",
2004 grpcMessageWant: `invalid gRPC request content-type "application/json"`,
2005 },
2006 {
2007 name: "Client Sending Bad Timeout",
2008 headers: []struct {
2009 name string
2010 values []string
2011 }{
2012 {name: ":method", values: []string{"POST"}},
2013 {name: ":path", values: []string{"foo"}},
2014 {name: ":authority", values: []string{"localhost"}},
2015 {name: "content-type", values: []string{"application/grpc"}},
2016 {name: "grpc-timeout", values: []string{"18f6n"}},
2017 },
2018 httpStatusWant: "400",
2019 grpcStatusWant: "13",
2020 grpcMessageWant: "malformed grpc-timeout",
2021 },
2022 {
2023 name: "Client Sending Bad Binary Header",
2024 headers: []struct {
2025 name string
2026 values []string
2027 }{
2028 {name: ":method", values: []string{"POST"}},
2029 {name: ":path", values: []string{"foo"}},
2030 {name: ":authority", values: []string{"localhost"}},
2031 {name: "content-type", values: []string{"application/grpc"}},
2032 {name: "foobar-bin", values: []string{"X()3e@#$-"}},
2033 },
2034 httpStatusWant: "400",
2035 grpcStatusWant: "13",
2036 grpcMessageWant: `header "foobar-bin": illegal base64 data`,
2037 },
2038 }
2039 for _, test := range tests {
2040 t.Run(test.name, func(t *testing.T) {
2041 server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
2042 defer server.stop()
2043
2044
2045 mconn, err := net.Dial("tcp", server.lis.Addr().String())
2046 if err != nil {
2047 t.Fatalf("Client failed to dial: %v", err)
2048 }
2049 defer mconn.Close()
2050
2051 if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
2052 t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
2053 }
2054
2055 framer := http2.NewFramer(mconn, mconn)
2056 framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil)
2057 if err := framer.WriteSettings(); err != nil {
2058 t.Fatalf("Error while writing settings: %v", err)
2059 }
2060
2061
2062
2063
2064 result := testutils.NewChannel()
2065
2066
2067 go func() {
2068 for {
2069 frame, err := framer.ReadFrame()
2070 if err != nil {
2071 return
2072 }
2073 switch frame := frame.(type) {
2074 case *http2.SettingsFrame:
2075
2076 case *http2.MetaHeadersFrame:
2077 var httpStatus, grpcStatus, grpcMessage string
2078 for _, header := range frame.Fields {
2079 if header.Name == ":status" {
2080 httpStatus = header.Value
2081 }
2082 if header.Name == "grpc-status" {
2083 grpcStatus = header.Value
2084 }
2085 if header.Name == "grpc-message" {
2086 grpcMessage = header.Value
2087 }
2088 }
2089 if httpStatus != test.httpStatusWant {
2090 result.Send(fmt.Errorf("incorrect HTTP Status got %v, want %v", httpStatus, test.httpStatusWant))
2091 return
2092 }
2093 if grpcStatus != test.grpcStatusWant {
2094 result.Send(fmt.Errorf("incorrect gRPC Status got %v, want %v", grpcStatus, test.grpcStatusWant))
2095 return
2096 }
2097 if !strings.Contains(grpcMessage, test.grpcMessageWant) {
2098 result.Send(fmt.Errorf("incorrect gRPC message, want %q got %q", test.grpcMessageWant, grpcMessage))
2099 return
2100 }
2101
2102
2103
2104 result.Send(nil)
2105 return
2106 default:
2107
2108 result.Send(errors.New("the client received a frame other than Settings or Headers"))
2109 }
2110 }
2111 }()
2112
2113 var buf bytes.Buffer
2114 henc := hpack.NewEncoder(&buf)
2115
2116
2117
2118 for _, header := range test.headers {
2119 for _, value := range header.values {
2120 if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil {
2121 t.Fatalf("Error while encoding header: %v", err)
2122 }
2123 }
2124 }
2125
2126 if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
2127 t.Fatalf("Error while writing headers: %v", err)
2128 }
2129 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2130 defer cancel()
2131 r, err := result.Receive(ctx)
2132 if err != nil {
2133 t.Fatalf("Error receiving from channel: %v", err)
2134 }
2135 if r != nil {
2136 t.Fatalf("want nil, got %v", r)
2137 }
2138 })
2139 }
2140 }
2141
2142 func (s) TestWriteHeaderConnectionError(t *testing.T) {
2143 server, client, cancel := setUp(t, 0, notifyCall)
2144 defer cancel()
2145 defer server.stop()
2146
2147 waitWhileTrue(t, func() (bool, error) {
2148 server.mu.Lock()
2149 defer server.mu.Unlock()
2150
2151 if len(server.conns) == 0 {
2152 return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
2153 }
2154 return false, nil
2155 })
2156
2157 server.mu.Lock()
2158
2159 if len(server.conns) != 1 {
2160 t.Fatalf("Server has %d connections from the client, want 1", len(server.conns))
2161 }
2162
2163
2164 var serverTransport *http2Server
2165 for k := range server.conns {
2166 serverTransport = k.(*http2Server)
2167 }
2168 notifyChan := make(chan struct{})
2169 server.h.notify = notifyChan
2170 server.mu.Unlock()
2171
2172 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2173 defer cancel()
2174 cstream, err := client.NewStream(ctx, &CallHdr{})
2175 if err != nil {
2176 t.Fatalf("Client failed to create first stream. Err: %v", err)
2177 }
2178
2179 <-notifyChan
2180 var sstream *Stream
2181
2182 serverTransport.mu.Lock()
2183 for _, v := range serverTransport.activeStreams {
2184 if v.id == cstream.id {
2185 sstream = v
2186 }
2187 }
2188 serverTransport.mu.Unlock()
2189 if sstream == nil {
2190 t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id)
2191 }
2192
2193 client.Close(fmt.Errorf("closed manually by test"))
2194
2195
2196 <-serverTransport.done
2197
2198
2199 err = serverTransport.WriteHeader(sstream, metadata.MD{})
2200 st := status.Convert(err)
2201 if st.Code() != codes.Unavailable {
2202 t.Fatalf("WriteHeader() failed with status code %s, want %s", st.Code(), codes.Unavailable)
2203 }
2204 }
2205
2206 func (s) TestPingPong1B(t *testing.T) {
2207 runPingPongTest(t, 1)
2208 }
2209
2210 func (s) TestPingPong1KB(t *testing.T) {
2211 runPingPongTest(t, 1024)
2212 }
2213
2214 func (s) TestPingPong64KB(t *testing.T) {
2215 runPingPongTest(t, 65536)
2216 }
2217
2218 func (s) TestPingPong1MB(t *testing.T) {
2219 runPingPongTest(t, 1048576)
2220 }
2221
2222
2223 func runPingPongTest(t *testing.T, msgSize int) {
2224 server, client, cancel := setUp(t, 0, pingpong)
2225 defer cancel()
2226 defer server.stop()
2227 defer client.Close(fmt.Errorf("closed manually by test"))
2228 waitWhileTrue(t, func() (bool, error) {
2229 server.mu.Lock()
2230 defer server.mu.Unlock()
2231 if len(server.conns) == 0 {
2232 return true, fmt.Errorf("timed out while waiting for server transport to be created")
2233 }
2234 return false, nil
2235 })
2236 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2237 defer cancel()
2238 stream, err := client.NewStream(ctx, &CallHdr{})
2239 if err != nil {
2240 t.Fatalf("Failed to create stream. Err: %v", err)
2241 }
2242 msg := make([]byte, msgSize)
2243 outgoingHeader := make([]byte, 5)
2244 outgoingHeader[0] = byte(0)
2245 binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize))
2246 opts := &Options{}
2247 incomingHeader := make([]byte, 5)
2248
2249 ctx, cancel = context.WithTimeout(ctx, time.Second)
2250 defer cancel()
2251 for ctx.Err() == nil {
2252 if err := client.Write(stream, outgoingHeader, msg, opts); err != nil {
2253 t.Fatalf("Error on client while writing message. Err: %v", err)
2254 }
2255 if _, err := stream.Read(incomingHeader); err != nil {
2256 t.Fatalf("Error on client while reading data header. Err: %v", err)
2257 }
2258 sz := binary.BigEndian.Uint32(incomingHeader[1:])
2259 recvMsg := make([]byte, int(sz))
2260 if _, err := stream.Read(recvMsg); err != nil {
2261 t.Fatalf("Error on client while reading data. Err: %v", err)
2262 }
2263 }
2264
2265 client.Write(stream, nil, nil, &Options{Last: true})
2266 if _, err := stream.Read(incomingHeader); err != io.EOF {
2267 t.Fatalf("Client expected EOF from the server. Got: %v", err)
2268 }
2269 }
2270
// tableSizeLimit records, in call order, every header table size limit handed
// to it via add. All access is serialized by mu; presumably add runs on a
// transport goroutine while the test goroutine polls getLen/getIndex.
type tableSizeLimit struct {
	mu     sync.Mutex
	limits []uint32 // guarded by mu
}

// add appends limit to the recorded sequence.
func (tl *tableSizeLimit) add(limit uint32) {
	tl.mu.Lock()
	defer tl.mu.Unlock()
	tl.limits = append(tl.limits, limit)
}

// getLen reports how many limits have been recorded so far.
func (tl *tableSizeLimit) getLen() int {
	tl.mu.Lock()
	n := len(tl.limits)
	tl.mu.Unlock()
	return n
}

// getIndex returns the i-th recorded limit. It panics if i is out of range;
// the deferred unlock keeps mu released in that case.
func (tl *tableSizeLimit) getIndex(i int) uint32 {
	tl.mu.Lock()
	defer tl.mu.Unlock()
	v := tl.limits[i]
	return v
}
2293
2294 func (s) TestHeaderTblSize(t *testing.T) {
2295 limits := &tableSizeLimit{}
2296 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
2297 e.SetMaxDynamicTableSizeLimit(v)
2298 limits.add(v)
2299 }
2300 defer func() {
2301 updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
2302 e.SetMaxDynamicTableSizeLimit(v)
2303 }
2304 }()
2305
2306 server, ct, cancel := setUp(t, 0, normal)
2307 defer cancel()
2308 defer ct.Close(fmt.Errorf("closed manually by test"))
2309 defer server.stop()
2310 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2311 defer ctxCancel()
2312 _, err := ct.NewStream(ctx, &CallHdr{})
2313 if err != nil {
2314 t.Fatalf("failed to open stream: %v", err)
2315 }
2316
2317 var svrTransport ServerTransport
2318 var i int
2319 for i = 0; i < 1000; i++ {
2320 server.mu.Lock()
2321 if len(server.conns) != 0 {
2322 server.mu.Unlock()
2323 break
2324 }
2325 server.mu.Unlock()
2326 time.Sleep(10 * time.Millisecond)
2327 continue
2328 }
2329 if i == 1000 {
2330 t.Fatalf("unable to create any server transport after 10s")
2331 }
2332
2333 for st := range server.conns {
2334 svrTransport = st
2335 break
2336 }
2337 svrTransport.(*http2Server).controlBuf.put(&outgoingSettings{
2338 ss: []http2.Setting{
2339 {
2340 ID: http2.SettingHeaderTableSize,
2341 Val: uint32(100),
2342 },
2343 },
2344 })
2345
2346 for i = 0; i < 1000; i++ {
2347 if limits.getLen() != 1 {
2348 time.Sleep(10 * time.Millisecond)
2349 continue
2350 }
2351 if val := limits.getIndex(0); val != uint32(100) {
2352 t.Fatalf("expected limits[0] = 100, got %d", val)
2353 }
2354 break
2355 }
2356 if i == 1000 {
2357 t.Fatalf("expected len(limits) = 1 within 10s, got != 1")
2358 }
2359
2360 ct.controlBuf.put(&outgoingSettings{
2361 ss: []http2.Setting{
2362 {
2363 ID: http2.SettingHeaderTableSize,
2364 Val: uint32(200),
2365 },
2366 },
2367 })
2368
2369 for i := 0; i < 1000; i++ {
2370 if limits.getLen() != 2 {
2371 time.Sleep(10 * time.Millisecond)
2372 continue
2373 }
2374 if val := limits.getIndex(1); val != uint32(200) {
2375 t.Fatalf("expected limits[1] = 200, got %d", val)
2376 }
2377 break
2378 }
2379 if i == 1000 {
2380 t.Fatalf("expected len(limits) = 2 within 10s, got != 2")
2381 }
2382 }
2383
2384
2385
2386
2387 type attrTransportCreds struct {
2388 credentials.TransportCredentials
2389 attr *attributes.Attributes
2390 }
2391
2392 func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
2393 ai := credentials.ClientHandshakeInfoFromContext(ctx)
2394 ac.attr = ai.Attributes
2395 return rawConn, nil, nil
2396 }
2397 func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
2398 return credentials.ProtocolInfo{}
2399 }
2400 func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
2401 return nil
2402 }
2403
2404
2405
2406
2407 func (s) TestClientHandshakeInfo(t *testing.T) {
2408 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong)
2409 defer server.stop()
2410
2411 const (
2412 testAttrKey = "foo"
2413 testAttrVal = "bar"
2414 )
2415 addr := resolver.Address{
2416 Addr: "localhost:" + server.port,
2417 Attributes: attributes.New(testAttrKey, testAttrVal),
2418 }
2419 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
2420 defer cancel()
2421 creds := &attrTransportCreds{}
2422
2423 copts := ConnectOptions{
2424 TransportCredentials: creds,
2425 ChannelzParent: channelzSubChannel(t),
2426 }
2427 tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
2428 if err != nil {
2429 t.Fatalf("NewClientTransport(): %v", err)
2430 }
2431 defer tr.Close(fmt.Errorf("closed manually by test"))
2432
2433 wantAttr := attributes.New(testAttrKey, testAttrVal)
2434 if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
2435 t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr)
2436 }
2437 }
2438
2439
2440
2441
2442 func (s) TestClientHandshakeInfoDialer(t *testing.T) {
2443 server := setUpServerOnly(t, 0, &ServerConfig{}, pingpong)
2444 defer server.stop()
2445
2446 const (
2447 testAttrKey = "foo"
2448 testAttrVal = "bar"
2449 )
2450 addr := resolver.Address{
2451 Addr: "localhost:" + server.port,
2452 Attributes: attributes.New(testAttrKey, testAttrVal),
2453 }
2454 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
2455 defer cancel()
2456
2457 var attr *attributes.Attributes
2458 dialer := func(ctx context.Context, addr string) (net.Conn, error) {
2459 ai := credentials.ClientHandshakeInfoFromContext(ctx)
2460 attr = ai.Attributes
2461 return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
2462 }
2463
2464 copts := ConnectOptions{
2465 Dialer: dialer,
2466 ChannelzParent: channelzSubChannel(t),
2467 }
2468 tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
2469 if err != nil {
2470 t.Fatalf("NewClientTransport(): %v", err)
2471 }
2472 defer tr.Close(fmt.Errorf("closed manually by test"))
2473
2474 wantAttr := attributes.New(testAttrKey, testAttrVal)
2475 if gotAttr := attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
2476 t.Errorf("Received attributes %v in custom dialer, want %v", gotAttr, wantAttr)
2477 }
2478 }
2479
2480 func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
2481 testStream := func() *Stream {
2482 return &Stream{
2483 done: make(chan struct{}),
2484 headerChan: make(chan struct{}),
2485 buf: &recvBuffer{
2486 c: make(chan recvMsg),
2487 mu: sync.Mutex{},
2488 },
2489 }
2490 }
2491
2492 testClient := func(ts *Stream) *http2Client {
2493 return &http2Client{
2494 mu: sync.Mutex{},
2495 activeStreams: map[uint32]*Stream{
2496 0: ts,
2497 },
2498 controlBuf: &controlBuffer{
2499 ch: make(chan struct{}),
2500 done: make(chan struct{}),
2501 list: &itemList{},
2502 },
2503 }
2504 }
2505
2506 for _, test := range []struct {
2507 name string
2508
2509 metaHeaderFrame *http2.MetaHeadersFrame
2510
2511 wantStatus *status.Status
2512 }{
2513 {
2514 name: "valid header",
2515 metaHeaderFrame: &http2.MetaHeadersFrame{
2516 Fields: []hpack.HeaderField{
2517 {Name: "content-type", Value: "application/grpc"},
2518 {Name: "grpc-status", Value: "0"},
2519 {Name: ":status", Value: "200"},
2520 },
2521 },
2522
2523 wantStatus: status.New(codes.OK, ""),
2524 },
2525 {
2526 name: "missing content-type header",
2527 metaHeaderFrame: &http2.MetaHeadersFrame{
2528 Fields: []hpack.HeaderField{
2529 {Name: "grpc-status", Value: "0"},
2530 {Name: ":status", Value: "200"},
2531 },
2532 },
2533 wantStatus: status.New(
2534 codes.Unknown,
2535 "malformed header: missing HTTP content-type",
2536 ),
2537 },
2538 {
2539 name: "invalid grpc status header field",
2540 metaHeaderFrame: &http2.MetaHeadersFrame{
2541 Fields: []hpack.HeaderField{
2542 {Name: "content-type", Value: "application/grpc"},
2543 {Name: "grpc-status", Value: "xxxx"},
2544 {Name: ":status", Value: "200"},
2545 },
2546 },
2547 wantStatus: status.New(
2548 codes.Internal,
2549 "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax",
2550 ),
2551 },
2552 {
2553 name: "invalid http content type",
2554 metaHeaderFrame: &http2.MetaHeadersFrame{
2555 Fields: []hpack.HeaderField{
2556 {Name: "content-type", Value: "application/json"},
2557 },
2558 },
2559 wantStatus: status.New(
2560 codes.Internal,
2561 "malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"",
2562 ),
2563 },
2564 {
2565 name: "http fallback and invalid http status",
2566 metaHeaderFrame: &http2.MetaHeadersFrame{
2567 Fields: []hpack.HeaderField{
2568
2569 {Name: ":status", Value: "xxxx"},
2570 },
2571 },
2572 wantStatus: status.New(
2573 codes.Internal,
2574 "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax",
2575 ),
2576 },
2577 {
2578 name: "http2 frame size exceeds",
2579 metaHeaderFrame: &http2.MetaHeadersFrame{
2580 Fields: nil,
2581 Truncated: true,
2582 },
2583 wantStatus: status.New(
2584 codes.Internal,
2585 "peer header list size exceeded limit",
2586 ),
2587 },
2588 {
2589 name: "bad status in grpc mode",
2590 metaHeaderFrame: &http2.MetaHeadersFrame{
2591 Fields: []hpack.HeaderField{
2592 {Name: "content-type", Value: "application/grpc"},
2593 {Name: "grpc-status", Value: "0"},
2594 {Name: ":status", Value: "504"},
2595 },
2596 },
2597 wantStatus: status.New(
2598 codes.Unavailable,
2599 "unexpected HTTP status code received from server: 504 (Gateway Timeout)",
2600 ),
2601 },
2602 {
2603 name: "missing http status",
2604 metaHeaderFrame: &http2.MetaHeadersFrame{
2605 Fields: []hpack.HeaderField{
2606 {Name: "content-type", Value: "application/grpc"},
2607 },
2608 },
2609 wantStatus: status.New(
2610 codes.Internal,
2611 "malformed header: missing HTTP status",
2612 ),
2613 },
2614 } {
2615
2616 t.Run(test.name, func(t *testing.T) {
2617 ts := testStream()
2618 s := testClient(ts)
2619
2620 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{
2621 FrameHeader: http2.FrameHeader{
2622 StreamID: 0,
2623 },
2624 }
2625
2626 s.operateHeaders(test.metaHeaderFrame)
2627
2628 got := ts.status
2629 want := test.wantStatus
2630 if got.Code() != want.Code() || got.Message() != want.Message() {
2631 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want)
2632 }
2633 })
2634 t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) {
2635 ts := testStream()
2636 s := testClient(ts)
2637
2638 test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{
2639 FrameHeader: http2.FrameHeader{
2640 StreamID: 0,
2641 Flags: http2.FlagHeadersEndStream,
2642 },
2643 }
2644
2645 s.operateHeaders(test.metaHeaderFrame)
2646
2647 got := ts.status
2648 want := test.wantStatus
2649 if got.Code() != want.Code() || got.Message() != want.Message() {
2650 t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want)
2651 }
2652 })
2653 }
2654 }
2655
2656 func TestConnectionError_Unwrap(t *testing.T) {
2657 err := connectionErrorf(false, os.ErrNotExist, "unwrap me")
2658 if !errors.Is(err, os.ErrNotExist) {
2659 t.Error("ConnectionError does not unwrap")
2660 }
2661 }
2662
2663
2664
2665
2666 func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
2667
2668 lis, err := net.Listen("tcp", "localhost:0")
2669 if err != nil {
2670 t.Fatalf("Error while listening: %v", err)
2671 }
2672 defer lis.Close()
2673
2674 greetDone := make(chan struct{})
2675
2676 errorCh := make(chan error)
2677 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
2678 defer cancel()
2679
2680 go func() {
2681 sconn, err := lis.Accept()
2682 if err != nil {
2683 t.Errorf("Error while accepting: %v", err)
2684 }
2685 defer sconn.Close()
2686 if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
2687 t.Errorf("Error while writing settings ack: %v", err)
2688 return
2689 }
2690 sfr := http2.NewFramer(sconn, sconn)
2691 if err := sfr.WriteSettings(); err != nil {
2692 t.Errorf("Error while writing settings %v", err)
2693 return
2694 }
2695 fr, _ := sfr.ReadFrame()
2696 if _, ok := fr.(*http2.SettingsFrame); !ok {
2697 t.Errorf("Expected settings frame, got %v", fr)
2698 }
2699 fr, _ = sfr.ReadFrame()
2700 if fr, ok := fr.(*http2.SettingsFrame); !ok || !fr.IsAck() {
2701 t.Errorf("Expected settings ACK frame, got %v", fr)
2702 }
2703 fr, _ = sfr.ReadFrame()
2704 if fr, ok := fr.(*http2.HeadersFrame); !ok || !fr.Flags.Has(http2.FlagHeadersEndHeaders) {
2705 t.Errorf("Expected Headers frame with END_HEADERS frame, got %v", fr)
2706 }
2707 close(greetDone)
2708
2709 frame, err := sfr.ReadFrame()
2710 if err != nil {
2711 return
2712 }
2713 switch fr := frame.(type) {
2714 case *http2.GoAwayFrame:
2715
2716 goAwayFrame := fr
2717 if goAwayFrame.ErrCode == http2.ErrCodeNo {
2718 t.Logf("Received goAway frame from client")
2719 close(errorCh)
2720 } else {
2721 errorCh <- fmt.Errorf("received unexpected goAway frame: %v", err)
2722 close(errorCh)
2723 }
2724 return
2725 default:
2726 errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", err)
2727 close(errorCh)
2728 return
2729 }
2730 }()
2731
2732 ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
2733 if err != nil {
2734 t.Fatalf("Error while creating client transport: %v", err)
2735 }
2736 _, err = ct.NewStream(ctx, &CallHdr{})
2737 if err != nil {
2738 t.Fatalf("failed to open stream: %v", err)
2739 }
2740
2741 <-greetDone
2742 ct.Close(errors.New("manually closed by client"))
2743 t.Logf("Closed the client connection")
2744 select {
2745 case err := <-errorCh:
2746 if err != nil {
2747 t.Errorf("Error receiving the GOAWAY frame: %v", err)
2748 }
2749 case <-ctx.Done():
2750 t.Errorf("Context timed out")
2751 }
2752 }
2753
View as plain text