1
18
19 package transport
20
21 import (
22 "context"
23 "errors"
24 "fmt"
25 "io"
26 "net/http"
27 "net/http/httptest"
28 "net/url"
29 "reflect"
30 "sync"
31 "testing"
32 "time"
33
34 epb "google.golang.org/genproto/googleapis/rpc/errdetails"
35 "google.golang.org/grpc/codes"
36 "google.golang.org/grpc/metadata"
37 "google.golang.org/grpc/status"
38 "google.golang.org/protobuf/proto"
39 "google.golang.org/protobuf/protoadapt"
40 "google.golang.org/protobuf/types/known/durationpb"
41 )
42
43 func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
44 type testCase struct {
45 name string
46 req *http.Request
47 wantErr string
48 wantErrCode int
49 modrw func(http.ResponseWriter) http.ResponseWriter
50 check func(*serverHandlerTransport, *testCase) error
51 }
52 tests := []testCase{
53 {
54 name: "bad method",
55 req: &http.Request{
56 ProtoMajor: 2,
57 Method: "GET",
58 Header: http.Header{},
59 },
60 wantErr: `invalid gRPC request method "GET"`,
61 wantErrCode: http.StatusMethodNotAllowed,
62 },
63 {
64 name: "bad content type",
65 req: &http.Request{
66 ProtoMajor: 2,
67 Method: "POST",
68 Header: http.Header{
69 "Content-Type": {"application/foo"},
70 },
71 },
72 wantErr: `invalid gRPC request content-type "application/foo"`,
73 wantErrCode: http.StatusUnsupportedMediaType,
74 },
75 {
76 name: "http/1.1",
77 req: &http.Request{
78 ProtoMajor: 1,
79 ProtoMinor: 1,
80 Method: "POST",
81 Header: http.Header{"Content-Type": []string{"application/grpc"}},
82 },
83 wantErr: "gRPC requires HTTP/2",
84 wantErrCode: http.StatusHTTPVersionNotSupported,
85 },
86 {
87 name: "not flusher",
88 req: &http.Request{
89 ProtoMajor: 2,
90 Method: "POST",
91 Header: http.Header{
92 "Content-Type": {"application/grpc"},
93 },
94 },
95 modrw: func(w http.ResponseWriter) http.ResponseWriter {
96
97 type onlyCloseNotifier interface {
98 http.ResponseWriter
99 }
100 return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
101 },
102 wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
103 wantErrCode: http.StatusInternalServerError,
104 },
105 {
106 name: "valid",
107 req: &http.Request{
108 ProtoMajor: 2,
109 Method: "POST",
110 Header: http.Header{
111 "Content-Type": {"application/grpc"},
112 },
113 URL: &url.URL{
114 Path: "/service/foo.bar",
115 },
116 },
117 check: func(t *serverHandlerTransport, tt *testCase) error {
118 if t.req != tt.req {
119 return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
120 }
121 if t.rw == nil {
122 return errors.New("t.rw = nil; want non-nil")
123 }
124 return nil
125 },
126 },
127 {
128 name: "with timeout",
129 req: &http.Request{
130 ProtoMajor: 2,
131 Method: "POST",
132 Header: http.Header{
133 "Content-Type": []string{"application/grpc"},
134 "Grpc-Timeout": {"200m"},
135 },
136 URL: &url.URL{
137 Path: "/service/foo.bar",
138 },
139 },
140 check: func(t *serverHandlerTransport, tt *testCase) error {
141 if !t.timeoutSet {
142 return errors.New("timeout not set")
143 }
144 if want := 200 * time.Millisecond; t.timeout != want {
145 return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
146 }
147 return nil
148 },
149 },
150 {
151 name: "with bad timeout",
152 req: &http.Request{
153 ProtoMajor: 2,
154 Method: "POST",
155 Header: http.Header{
156 "Content-Type": []string{"application/grpc"},
157 "Grpc-Timeout": {"tomorrow"},
158 },
159 URL: &url.URL{
160 Path: "/service/foo.bar",
161 },
162 },
163 wantErr: `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`,
164 wantErrCode: http.StatusBadRequest,
165 },
166 {
167 name: "with metadata",
168 req: &http.Request{
169 ProtoMajor: 2,
170 Method: "POST",
171 Header: http.Header{
172 "Content-Type": []string{"application/grpc"},
173 "meta-foo": {"foo-val"},
174 "meta-bar": {"bar-val1", "bar-val2"},
175 "user-agent": {"x/y a/b"},
176 },
177 URL: &url.URL{
178 Path: "/service/foo.bar",
179 },
180 },
181 check: func(ht *serverHandlerTransport, tt *testCase) error {
182 want := metadata.MD{
183 "meta-bar": {"bar-val1", "bar-val2"},
184 "user-agent": {"x/y a/b"},
185 "meta-foo": {"foo-val"},
186 "content-type": {"application/grpc"},
187 }
188
189 if !reflect.DeepEqual(ht.headerMD, want) {
190 return fmt.Errorf("metadata = %#v; want %#v", ht.headerMD, want)
191 }
192 return nil
193 },
194 },
195 }
196
197 for _, tt := range tests {
198 rrec := httptest.NewRecorder()
199 rw := http.ResponseWriter(testHandlerResponseWriter{
200 ResponseRecorder: rrec,
201 })
202
203 if tt.modrw != nil {
204 rw = tt.modrw(rw)
205 }
206 got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
207 if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
208 t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
209 continue
210 }
211 if tt.wantErrCode == 0 {
212 tt.wantErrCode = http.StatusOK
213 }
214 if rrec.Code != tt.wantErrCode {
215 t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode)
216 continue
217 }
218 if gotErr != nil {
219 continue
220 }
221 if tt.check != nil {
222 if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
223 t.Errorf("%s: %v", tt.name, err)
224 }
225 }
226 }
227 }
228
// testHandlerResponseWriter wraps an httptest.ResponseRecorder and overrides
// its Flush method with a no-op. The wrapper still satisfies http.Flusher,
// but flushing never changes the recorder's Flushed field.
type testHandlerResponseWriter struct {
	*httptest.ResponseRecorder
}

// Flush shadows the embedded recorder's Flush and deliberately does nothing.
func (testHandlerResponseWriter) Flush() {}

// newTestHandlerResponseWriter returns a testHandlerResponseWriter backed by
// a freshly allocated recorder, typed as an http.ResponseWriter.
func newTestHandlerResponseWriter() http.ResponseWriter {
	rec := httptest.NewRecorder()
	var w http.ResponseWriter = testHandlerResponseWriter{ResponseRecorder: rec}
	return w
}
240
241 type handleStreamTest struct {
242 t *testing.T
243 bodyw *io.PipeWriter
244 rw testHandlerResponseWriter
245 ht *serverHandlerTransport
246 }
247
248 func newHandleStreamTest(t *testing.T) *handleStreamTest {
249 bodyr, bodyw := io.Pipe()
250 req := &http.Request{
251 ProtoMajor: 2,
252 Method: "POST",
253 Header: http.Header{
254 "Content-Type": {"application/grpc"},
255 },
256 URL: &url.URL{
257 Path: "/service/foo.bar",
258 },
259 Body: bodyr,
260 }
261 rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
262 ht, err := NewServerHandlerTransport(rw, req, nil)
263 if err != nil {
264 t.Fatal(err)
265 }
266 return &handleStreamTest{
267 t: t,
268 bodyw: bodyw,
269 ht: ht.(*serverHandlerTransport),
270 rw: rw,
271 }
272 }
273
274 func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
275 st := newHandleStreamTest(t)
276 handleStream := func(s *Stream) {
277 if want := "/service/foo.bar"; s.method != want {
278 t.Errorf("stream method = %q; want %q", s.method, want)
279 }
280
281 if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
282 t.Error(err)
283 }
284
285 if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
286 t.Error(err)
287 }
288
289 if err := s.SetSendCompress("gzip"); err != nil {
290 t.Error(err)
291 }
292
293 md := metadata.Pairs("custom-header", "Another custom header value")
294 if err := s.SendHeader(md); err != nil {
295 t.Error(err)
296 }
297 delete(md, "custom-header")
298
299 if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
300 t.Error("expected SetHeader call after SendHeader to fail")
301 }
302
303 if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
304 t.Error("expected second SendHeader call to fail")
305 }
306
307 if err := s.SetSendCompress("snappy"); err == nil {
308 t.Error("expected second SetSendCompress call to fail")
309 }
310
311 st.bodyw.Close()
312 st.ht.WriteStatus(s, status.New(codes.OK, ""))
313 }
314 st.ht.HandleStreams(
315 context.Background(), func(s *Stream) { go handleStream(s) },
316 )
317 wantHeader := http.Header{
318 "Date": nil,
319 "Content-Type": {"application/grpc"},
320 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
321 "Custom-Header": {"Custom header value", "Another custom header value"},
322 "Grpc-Encoding": {"gzip"},
323 }
324 wantTrailer := http.Header{
325 "Grpc-Status": {"0"},
326 "Custom-Trailer": {"Custom trailer value"},
327 }
328 checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
329 }
330
331
332 func (s) TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
333 handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
334 }
335
336
337 func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
338 handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
339 }
340
341 func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
342 st := newHandleStreamTest(t)
343
344 handleStream := func(s *Stream) {
345 st.ht.WriteStatus(s, status.New(statusCode, msg))
346 }
347 st.ht.HandleStreams(
348 context.Background(), func(s *Stream) { go handleStream(s) },
349 )
350 wantHeader := http.Header{
351 "Date": nil,
352 "Content-Type": {"application/grpc"},
353 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
354 }
355 wantTrailer := http.Header{
356 "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
357 "Grpc-Message": {encodeGrpcMessage(msg)},
358 }
359 checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
360 }
361
362 func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
363 bodyr, bodyw := io.Pipe()
364 req := &http.Request{
365 ProtoMajor: 2,
366 Method: "POST",
367 Header: http.Header{
368 "Content-Type": {"application/grpc"},
369 "Grpc-Timeout": {"200m"},
370 },
371 URL: &url.URL{
372 Path: "/service/foo.bar",
373 },
374 Body: bodyr,
375 }
376 rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
377 ht, err := NewServerHandlerTransport(rw, req, nil)
378 if err != nil {
379 t.Fatal(err)
380 }
381 runStream := func(s *Stream) {
382 defer bodyw.Close()
383 select {
384 case <-s.ctx.Done():
385 case <-time.After(5 * time.Second):
386 t.Errorf("timeout waiting for ctx.Done")
387 return
388 }
389 err := s.ctx.Err()
390 if err != context.DeadlineExceeded {
391 t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
392 return
393 }
394 ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
395 }
396 ht.HandleStreams(
397 context.Background(), func(s *Stream) { go runStream(s) },
398 )
399 wantHeader := http.Header{
400 "Date": nil,
401 "Content-Type": {"application/grpc"},
402 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
403 }
404 wantTrailer := http.Header{
405 "Grpc-Status": {"4"},
406 "Grpc-Message": {encodeGrpcMessage("too slow")},
407 }
408 checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer)
409 }
410
411
412
413 func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
414 testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
415 if want := "/service/foo.bar"; s.method != want {
416 t.Errorf("stream method = %q; want %q", s.method, want)
417 }
418 st.bodyw.Close()
419
420 var wg sync.WaitGroup
421 wg.Add(5)
422 for i := 0; i < 5; i++ {
423 go func() {
424 defer wg.Done()
425 st.ht.WriteStatus(s, status.New(codes.OK, ""))
426 }()
427 }
428 wg.Wait()
429 })
430 }
431
432
433
434 func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
435 testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
436 if want := "/service/foo.bar"; s.method != want {
437 t.Errorf("stream method = %q; want %q", s.method, want)
438 }
439 st.bodyw.Close()
440
441 st.ht.WriteStatus(s, status.New(codes.OK, ""))
442 st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
443 })
444 }
445
446 func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
447 st := newHandleStreamTest(t)
448 st.ht.HandleStreams(
449 context.Background(), func(s *Stream) { go handleStream(st, s) },
450 )
451 }
452
453 func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
454 errDetails := []protoadapt.MessageV1{
455 &epb.RetryInfo{
456 RetryDelay: &durationpb.Duration{Seconds: 60},
457 },
458 &epb.ResourceInfo{
459 ResourceType: "foo bar",
460 ResourceName: "service.foo.bar",
461 Owner: "User",
462 },
463 }
464
465 statusCode := codes.ResourceExhausted
466 msg := "you are being throttled"
467 st, err := status.New(statusCode, msg).WithDetails(errDetails...)
468 if err != nil {
469 t.Fatal(err)
470 }
471
472 stBytes, err := proto.Marshal(st.Proto())
473 if err != nil {
474 t.Fatal(err)
475 }
476
477 hst := newHandleStreamTest(t)
478 handleStream := func(s *Stream) {
479 hst.ht.WriteStatus(s, st)
480 }
481 hst.ht.HandleStreams(
482 context.Background(), func(s *Stream) { go handleStream(s) },
483 )
484 wantHeader := http.Header{
485 "Date": nil,
486 "Content-Type": {"application/grpc"},
487 "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
488 }
489 wantTrailer := http.Header{
490 "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
491 "Grpc-Message": {encodeGrpcMessage(msg)},
492 "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
493 }
494
495 checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
496 }
497
498
499
500 func (s) TestHandlerTransport_Drain(t *testing.T) {
501 defer func() { recover() }()
502 st := newHandleStreamTest(t)
503 st.ht.Drain("whatever")
504 t.Errorf("serverHandlerTransport.Drain() should have panicked")
505 }
506
507
508 func checkHeaderAndTrailer(t *testing.T, rw testHandlerResponseWriter, wantHeader, wantTrailer http.Header) {
509
510
511 actualHeader := rw.Result().Header.Clone()
512 for _, trailerKey := range actualHeader["Trailer"] {
513 actualHeader.Del(trailerKey)
514 }
515
516 if !reflect.DeepEqual(actualHeader, wantHeader) {
517 t.Errorf("Header mismatch.\n got: %#v\n want: %#v", actualHeader, wantHeader)
518 }
519 if actualTrailer := rw.Result().Trailer; !reflect.DeepEqual(actualTrailer, wantTrailer) {
520 t.Errorf("Trailer mismatch.\n got: %#v\n want: %#v", actualTrailer, wantTrailer)
521 }
522 }
523
View as plain text