...

Source file src/google.golang.org/grpc/internal/transport/handler_server_test.go

Documentation: google.golang.org/grpc/internal/transport

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package transport
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"reflect"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  
    34  	epb "google.golang.org/genproto/googleapis/rpc/errdetails"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/metadata"
    37  	"google.golang.org/grpc/status"
    38  	"google.golang.org/protobuf/proto"
    39  	"google.golang.org/protobuf/protoadapt"
    40  	"google.golang.org/protobuf/types/known/durationpb"
    41  )
    42  
    43  func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
    44  	type testCase struct {
    45  		name        string
    46  		req         *http.Request
    47  		wantErr     string
    48  		wantErrCode int
    49  		modrw       func(http.ResponseWriter) http.ResponseWriter
    50  		check       func(*serverHandlerTransport, *testCase) error
    51  	}
    52  	tests := []testCase{
    53  		{
    54  			name: "bad method",
    55  			req: &http.Request{
    56  				ProtoMajor: 2,
    57  				Method:     "GET",
    58  				Header:     http.Header{},
    59  			},
    60  			wantErr:     `invalid gRPC request method "GET"`,
    61  			wantErrCode: http.StatusMethodNotAllowed,
    62  		},
    63  		{
    64  			name: "bad content type",
    65  			req: &http.Request{
    66  				ProtoMajor: 2,
    67  				Method:     "POST",
    68  				Header: http.Header{
    69  					"Content-Type": {"application/foo"},
    70  				},
    71  			},
    72  			wantErr:     `invalid gRPC request content-type "application/foo"`,
    73  			wantErrCode: http.StatusUnsupportedMediaType,
    74  		},
    75  		{
    76  			name: "http/1.1",
    77  			req: &http.Request{
    78  				ProtoMajor: 1,
    79  				ProtoMinor: 1,
    80  				Method:     "POST",
    81  				Header:     http.Header{"Content-Type": []string{"application/grpc"}},
    82  			},
    83  			wantErr:     "gRPC requires HTTP/2",
    84  			wantErrCode: http.StatusHTTPVersionNotSupported,
    85  		},
    86  		{
    87  			name: "not flusher",
    88  			req: &http.Request{
    89  				ProtoMajor: 2,
    90  				Method:     "POST",
    91  				Header: http.Header{
    92  					"Content-Type": {"application/grpc"},
    93  				},
    94  			},
    95  			modrw: func(w http.ResponseWriter) http.ResponseWriter {
    96  				// Return w without its Flush method
    97  				type onlyCloseNotifier interface {
    98  					http.ResponseWriter
    99  				}
   100  				return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
   101  			},
   102  			wantErr:     "gRPC requires a ResponseWriter supporting http.Flusher",
   103  			wantErrCode: http.StatusInternalServerError,
   104  		},
   105  		{
   106  			name: "valid",
   107  			req: &http.Request{
   108  				ProtoMajor: 2,
   109  				Method:     "POST",
   110  				Header: http.Header{
   111  					"Content-Type": {"application/grpc"},
   112  				},
   113  				URL: &url.URL{
   114  					Path: "/service/foo.bar",
   115  				},
   116  			},
   117  			check: func(t *serverHandlerTransport, tt *testCase) error {
   118  				if t.req != tt.req {
   119  					return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
   120  				}
   121  				if t.rw == nil {
   122  					return errors.New("t.rw = nil; want non-nil")
   123  				}
   124  				return nil
   125  			},
   126  		},
   127  		{
   128  			name: "with timeout",
   129  			req: &http.Request{
   130  				ProtoMajor: 2,
   131  				Method:     "POST",
   132  				Header: http.Header{
   133  					"Content-Type": []string{"application/grpc"},
   134  					"Grpc-Timeout": {"200m"},
   135  				},
   136  				URL: &url.URL{
   137  					Path: "/service/foo.bar",
   138  				},
   139  			},
   140  			check: func(t *serverHandlerTransport, tt *testCase) error {
   141  				if !t.timeoutSet {
   142  					return errors.New("timeout not set")
   143  				}
   144  				if want := 200 * time.Millisecond; t.timeout != want {
   145  					return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
   146  				}
   147  				return nil
   148  			},
   149  		},
   150  		{
   151  			name: "with bad timeout",
   152  			req: &http.Request{
   153  				ProtoMajor: 2,
   154  				Method:     "POST",
   155  				Header: http.Header{
   156  					"Content-Type": []string{"application/grpc"},
   157  					"Grpc-Timeout": {"tomorrow"},
   158  				},
   159  				URL: &url.URL{
   160  					Path: "/service/foo.bar",
   161  				},
   162  			},
   163  			wantErr:     `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`,
   164  			wantErrCode: http.StatusBadRequest,
   165  		},
   166  		{
   167  			name: "with metadata",
   168  			req: &http.Request{
   169  				ProtoMajor: 2,
   170  				Method:     "POST",
   171  				Header: http.Header{
   172  					"Content-Type": []string{"application/grpc"},
   173  					"meta-foo":     {"foo-val"},
   174  					"meta-bar":     {"bar-val1", "bar-val2"},
   175  					"user-agent":   {"x/y a/b"},
   176  				},
   177  				URL: &url.URL{
   178  					Path: "/service/foo.bar",
   179  				},
   180  			},
   181  			check: func(ht *serverHandlerTransport, tt *testCase) error {
   182  				want := metadata.MD{
   183  					"meta-bar":     {"bar-val1", "bar-val2"},
   184  					"user-agent":   {"x/y a/b"},
   185  					"meta-foo":     {"foo-val"},
   186  					"content-type": {"application/grpc"},
   187  				}
   188  
   189  				if !reflect.DeepEqual(ht.headerMD, want) {
   190  					return fmt.Errorf("metadata = %#v; want %#v", ht.headerMD, want)
   191  				}
   192  				return nil
   193  			},
   194  		},
   195  	}
   196  
   197  	for _, tt := range tests {
   198  		rrec := httptest.NewRecorder()
   199  		rw := http.ResponseWriter(testHandlerResponseWriter{
   200  			ResponseRecorder: rrec,
   201  		})
   202  
   203  		if tt.modrw != nil {
   204  			rw = tt.modrw(rw)
   205  		}
   206  		got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
   207  		if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
   208  			t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
   209  			continue
   210  		}
   211  		if tt.wantErrCode == 0 {
   212  			tt.wantErrCode = http.StatusOK
   213  		}
   214  		if rrec.Code != tt.wantErrCode {
   215  			t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode)
   216  			continue
   217  		}
   218  		if gotErr != nil {
   219  			continue
   220  		}
   221  		if tt.check != nil {
   222  			if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
   223  				t.Errorf("%s: %v", tt.name, err)
   224  			}
   225  		}
   226  	}
   227  }
   228  
   229  type testHandlerResponseWriter struct {
   230  	*httptest.ResponseRecorder
   231  }
   232  
   233  func (w testHandlerResponseWriter) Flush() {}
   234  
   235  func newTestHandlerResponseWriter() http.ResponseWriter {
   236  	return testHandlerResponseWriter{
   237  		ResponseRecorder: httptest.NewRecorder(),
   238  	}
   239  }
   240  
   241  type handleStreamTest struct {
   242  	t     *testing.T
   243  	bodyw *io.PipeWriter
   244  	rw    testHandlerResponseWriter
   245  	ht    *serverHandlerTransport
   246  }
   247  
   248  func newHandleStreamTest(t *testing.T) *handleStreamTest {
   249  	bodyr, bodyw := io.Pipe()
   250  	req := &http.Request{
   251  		ProtoMajor: 2,
   252  		Method:     "POST",
   253  		Header: http.Header{
   254  			"Content-Type": {"application/grpc"},
   255  		},
   256  		URL: &url.URL{
   257  			Path: "/service/foo.bar",
   258  		},
   259  		Body: bodyr,
   260  	}
   261  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   262  	ht, err := NewServerHandlerTransport(rw, req, nil)
   263  	if err != nil {
   264  		t.Fatal(err)
   265  	}
   266  	return &handleStreamTest{
   267  		t:     t,
   268  		bodyw: bodyw,
   269  		ht:    ht.(*serverHandlerTransport),
   270  		rw:    rw,
   271  	}
   272  }
   273  
   274  func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
   275  	st := newHandleStreamTest(t)
   276  	handleStream := func(s *Stream) {
   277  		if want := "/service/foo.bar"; s.method != want {
   278  			t.Errorf("stream method = %q; want %q", s.method, want)
   279  		}
   280  
   281  		if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
   282  			t.Error(err)
   283  		}
   284  
   285  		if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
   286  			t.Error(err)
   287  		}
   288  
   289  		if err := s.SetSendCompress("gzip"); err != nil {
   290  			t.Error(err)
   291  		}
   292  
   293  		md := metadata.Pairs("custom-header", "Another custom header value")
   294  		if err := s.SendHeader(md); err != nil {
   295  			t.Error(err)
   296  		}
   297  		delete(md, "custom-header")
   298  
   299  		if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
   300  			t.Error("expected SetHeader call after SendHeader to fail")
   301  		}
   302  
   303  		if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
   304  			t.Error("expected second SendHeader call to fail")
   305  		}
   306  
   307  		if err := s.SetSendCompress("snappy"); err == nil {
   308  			t.Error("expected second SetSendCompress call to fail")
   309  		}
   310  
   311  		st.bodyw.Close() // no body
   312  		st.ht.WriteStatus(s, status.New(codes.OK, ""))
   313  	}
   314  	st.ht.HandleStreams(
   315  		context.Background(), func(s *Stream) { go handleStream(s) },
   316  	)
   317  	wantHeader := http.Header{
   318  		"Date":          nil,
   319  		"Content-Type":  {"application/grpc"},
   320  		"Trailer":       {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   321  		"Custom-Header": {"Custom header value", "Another custom header value"},
   322  		"Grpc-Encoding": {"gzip"},
   323  	}
   324  	wantTrailer := http.Header{
   325  		"Grpc-Status":    {"0"},
   326  		"Custom-Trailer": {"Custom trailer value"},
   327  	}
   328  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   329  }
   330  
   331  // Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
   332  func (s) TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
   333  	handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
   334  }
   335  
   336  // Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
   337  func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
   338  	handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
   339  }
   340  
   341  func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
   342  	st := newHandleStreamTest(t)
   343  
   344  	handleStream := func(s *Stream) {
   345  		st.ht.WriteStatus(s, status.New(statusCode, msg))
   346  	}
   347  	st.ht.HandleStreams(
   348  		context.Background(), func(s *Stream) { go handleStream(s) },
   349  	)
   350  	wantHeader := http.Header{
   351  		"Date":         nil,
   352  		"Content-Type": {"application/grpc"},
   353  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   354  	}
   355  	wantTrailer := http.Header{
   356  		"Grpc-Status":  {fmt.Sprint(uint32(statusCode))},
   357  		"Grpc-Message": {encodeGrpcMessage(msg)},
   358  	}
   359  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   360  }
   361  
   362  func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
   363  	bodyr, bodyw := io.Pipe()
   364  	req := &http.Request{
   365  		ProtoMajor: 2,
   366  		Method:     "POST",
   367  		Header: http.Header{
   368  			"Content-Type": {"application/grpc"},
   369  			"Grpc-Timeout": {"200m"},
   370  		},
   371  		URL: &url.URL{
   372  			Path: "/service/foo.bar",
   373  		},
   374  		Body: bodyr,
   375  	}
   376  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   377  	ht, err := NewServerHandlerTransport(rw, req, nil)
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  	runStream := func(s *Stream) {
   382  		defer bodyw.Close()
   383  		select {
   384  		case <-s.ctx.Done():
   385  		case <-time.After(5 * time.Second):
   386  			t.Errorf("timeout waiting for ctx.Done")
   387  			return
   388  		}
   389  		err := s.ctx.Err()
   390  		if err != context.DeadlineExceeded {
   391  			t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
   392  			return
   393  		}
   394  		ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
   395  	}
   396  	ht.HandleStreams(
   397  		context.Background(), func(s *Stream) { go runStream(s) },
   398  	)
   399  	wantHeader := http.Header{
   400  		"Date":         nil,
   401  		"Content-Type": {"application/grpc"},
   402  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   403  	}
   404  	wantTrailer := http.Header{
   405  		"Grpc-Status":  {"4"},
   406  		"Grpc-Message": {encodeGrpcMessage("too slow")},
   407  	}
   408  	checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer)
   409  }
   410  
   411  // TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
   412  // concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
   413  func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
   414  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
   415  		if want := "/service/foo.bar"; s.method != want {
   416  			t.Errorf("stream method = %q; want %q", s.method, want)
   417  		}
   418  		st.bodyw.Close() // no body
   419  
   420  		var wg sync.WaitGroup
   421  		wg.Add(5)
   422  		for i := 0; i < 5; i++ {
   423  			go func() {
   424  				defer wg.Done()
   425  				st.ht.WriteStatus(s, status.New(codes.OK, ""))
   426  			}()
   427  		}
   428  		wg.Wait()
   429  	})
   430  }
   431  
   432  // TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
   433  // following "WriteStatus" does not panic writing to closed "writes" channel.
   434  func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
   435  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
   436  		if want := "/service/foo.bar"; s.method != want {
   437  			t.Errorf("stream method = %q; want %q", s.method, want)
   438  		}
   439  		st.bodyw.Close() // no body
   440  
   441  		st.ht.WriteStatus(s, status.New(codes.OK, ""))
   442  		st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
   443  	})
   444  }
   445  
   446  func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
   447  	st := newHandleStreamTest(t)
   448  	st.ht.HandleStreams(
   449  		context.Background(), func(s *Stream) { go handleStream(st, s) },
   450  	)
   451  }
   452  
   453  func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
   454  	errDetails := []protoadapt.MessageV1{
   455  		&epb.RetryInfo{
   456  			RetryDelay: &durationpb.Duration{Seconds: 60},
   457  		},
   458  		&epb.ResourceInfo{
   459  			ResourceType: "foo bar",
   460  			ResourceName: "service.foo.bar",
   461  			Owner:        "User",
   462  		},
   463  	}
   464  
   465  	statusCode := codes.ResourceExhausted
   466  	msg := "you are being throttled"
   467  	st, err := status.New(statusCode, msg).WithDetails(errDetails...)
   468  	if err != nil {
   469  		t.Fatal(err)
   470  	}
   471  
   472  	stBytes, err := proto.Marshal(st.Proto())
   473  	if err != nil {
   474  		t.Fatal(err)
   475  	}
   476  
   477  	hst := newHandleStreamTest(t)
   478  	handleStream := func(s *Stream) {
   479  		hst.ht.WriteStatus(s, st)
   480  	}
   481  	hst.ht.HandleStreams(
   482  		context.Background(), func(s *Stream) { go handleStream(s) },
   483  	)
   484  	wantHeader := http.Header{
   485  		"Date":         nil,
   486  		"Content-Type": {"application/grpc"},
   487  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   488  	}
   489  	wantTrailer := http.Header{
   490  		"Grpc-Status":             {fmt.Sprint(uint32(statusCode))},
   491  		"Grpc-Message":            {encodeGrpcMessage(msg)},
   492  		"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
   493  	}
   494  
   495  	checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
   496  }
   497  
   498  // TestHandlerTransport_Drain verifies that Drain() is not implemented
   499  // by `serverHandlerTransport`.
   500  func (s) TestHandlerTransport_Drain(t *testing.T) {
   501  	defer func() { recover() }()
   502  	st := newHandleStreamTest(t)
   503  	st.ht.Drain("whatever")
   504  	t.Errorf("serverHandlerTransport.Drain() should have panicked")
   505  }
   506  
   507  // checkHeaderAndTrailer checks that the resulting header and trailer matches the expectation.
   508  func checkHeaderAndTrailer(t *testing.T, rw testHandlerResponseWriter, wantHeader, wantTrailer http.Header) {
   509  	// For trailer-only responses, the trailer values might be reported as part of the Header. They will however
   510  	// be present in Trailer in either case. Hence, normalize the header by removing all trailer values.
   511  	actualHeader := rw.Result().Header.Clone()
   512  	for _, trailerKey := range actualHeader["Trailer"] {
   513  		actualHeader.Del(trailerKey)
   514  	}
   515  
   516  	if !reflect.DeepEqual(actualHeader, wantHeader) {
   517  		t.Errorf("Header mismatch.\n got: %#v\n want: %#v", actualHeader, wantHeader)
   518  	}
   519  	if actualTrailer := rw.Result().Trailer; !reflect.DeepEqual(actualTrailer, wantTrailer) {
   520  		t.Errorf("Trailer mismatch.\n got: %#v\n want: %#v", actualTrailer, wantTrailer)
   521  	}
   522  }
   523  

View as plain text