...

Source file src/github.com/grpc-ecosystem/grpc-gateway/runtime/handler_test.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/runtime

     1  package runtime_test
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"github.com/grpc-ecosystem/grpc-gateway/internal"
    13  	"github.com/grpc-ecosystem/grpc-gateway/runtime"
    14  	pb "github.com/grpc-ecosystem/grpc-gateway/runtime/internal/examplepb"
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/status"
    17  )
    18  
    19  type fakeReponseBodyWrapper struct {
    20  	proto.Message
    21  }
    22  
    23  // XXX_ResponseBody returns id of SimpleMessage
    24  func (r fakeReponseBodyWrapper) XXX_ResponseBody() interface{} {
    25  	resp := r.Message.(*pb.SimpleMessage)
    26  	return resp.Id
    27  }
    28  
    29  func TestForwardResponseStream(t *testing.T) {
    30  	type msg struct {
    31  		pb  proto.Message
    32  		err error
    33  	}
    34  	tests := []struct {
    35  		name         string
    36  		msgs         []msg
    37  		statusCode   int
    38  		responseBody bool
    39  	}{{
    40  		name: "encoding",
    41  		msgs: []msg{
    42  			{&pb.SimpleMessage{Id: "One"}, nil},
    43  			{&pb.SimpleMessage{Id: "Two"}, nil},
    44  		},
    45  		statusCode: http.StatusOK,
    46  	}, {
    47  		name:       "empty",
    48  		statusCode: http.StatusOK,
    49  	}, {
    50  		name:       "error",
    51  		msgs:       []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
    52  		statusCode: http.StatusBadRequest,
    53  	}, {
    54  		name: "stream_error",
    55  		msgs: []msg{
    56  			{&pb.SimpleMessage{Id: "One"}, nil},
    57  			{nil, status.Errorf(codes.OutOfRange, "400")},
    58  		},
    59  		statusCode: http.StatusOK,
    60  	}, {
    61  		name: "response body stream case",
    62  		msgs: []msg{
    63  			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil},
    64  			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "Two"}}, nil},
    65  		},
    66  		responseBody: true,
    67  		statusCode:   http.StatusOK,
    68  	}, {
    69  		name: "response body stream error case",
    70  		msgs: []msg{
    71  			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil},
    72  			{nil, status.Errorf(codes.OutOfRange, "400")},
    73  		},
    74  		responseBody: true,
    75  		statusCode:   http.StatusOK,
    76  	}}
    77  
    78  	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
    79  		var count int
    80  		return func() (proto.Message, error) {
    81  			if count == len(msgs) {
    82  				return nil, io.EOF
    83  			} else if count > len(msgs) {
    84  				t.Errorf("recv() called %d times for %d messages", count, len(msgs))
    85  			}
    86  			count++
    87  			msg := msgs[count-1]
    88  			return msg.pb, msg.err
    89  		}
    90  	}
    91  	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
    92  	marshaler := &runtime.JSONPb{}
    93  	for _, tt := range tests {
    94  		t.Run(tt.name, func(t *testing.T) {
    95  			recv := newTestRecv(t, tt.msgs)
    96  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
    97  			resp := httptest.NewRecorder()
    98  
    99  			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
   100  
   101  			w := resp.Result()
   102  			if w.StatusCode != tt.statusCode {
   103  				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
   104  			}
   105  			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
   106  				t.Errorf("ForwardResponseStream missing header chunked")
   107  			}
   108  			body, err := ioutil.ReadAll(w.Body)
   109  			if err != nil {
   110  				t.Errorf("Failed to read response body with %v", err)
   111  			}
   112  			w.Body.Close()
   113  
   114  			var want []byte
   115  			for i, msg := range tt.msgs {
   116  				if msg.err != nil {
   117  					if i == 0 {
   118  						// Skip non-stream errors
   119  						t.Skip("checking error encodings")
   120  					}
   121  					st, _ := status.FromError(msg.err)
   122  					httpCode := runtime.HTTPStatusFromCode(st.Code())
   123  					b, err := marshaler.Marshal(map[string]proto.Message{
   124  						"error": &internal.StreamError{
   125  							GrpcCode:   int32(st.Code()),
   126  							HttpCode:   int32(httpCode),
   127  							Message:    st.Message(),
   128  							HttpStatus: http.StatusText(httpCode),
   129  							Details:    st.Proto().GetDetails(),
   130  						},
   131  					})
   132  					if err != nil {
   133  						t.Errorf("marshaler.Marshal() failed %v", err)
   134  					}
   135  					errBytes := body[len(want):]
   136  					if string(errBytes) != string(b) {
   137  						t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b)
   138  					}
   139  
   140  					return
   141  				}
   142  
   143  				var b []byte
   144  
   145  				if tt.responseBody {
   146  					// responseBody interface is in runtime package and test is in runtime_test package. hence can't use responseBody directly
   147  					// So type casting to fakeReponseBodyWrapper struct to verify the data.
   148  					rb, ok := msg.pb.(fakeReponseBodyWrapper)
   149  					if !ok {
   150  						t.Errorf("stream responseBody failed %v", err)
   151  					}
   152  
   153  					b, err = marshaler.Marshal(map[string]interface{}{"result": rb.XXX_ResponseBody()})
   154  				} else {
   155  					b, err = marshaler.Marshal(map[string]interface{}{"result": msg.pb})
   156  				}
   157  
   158  				if err != nil {
   159  					t.Errorf("marshaler.Marshal() failed %v", err)
   160  				}
   161  				want = append(want, b...)
   162  				want = append(want, marshaler.Delimiter()...)
   163  			}
   164  
   165  			if string(body) != string(want) {
   166  				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
   167  			}
   168  		})
   169  	}
   170  }
   171  
   172  // A custom marshaler implementation, that doesn't implement the delimited interface
   173  type CustomMarshaler struct {
   174  	m *runtime.JSONPb
   175  }
   176  
   177  func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error)       { return c.m.Marshal(v) }
   178  func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error  { return c.m.Unmarshal(data, v) }
   179  func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder      { return c.m.NewDecoder(r) }
   180  func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder      { return c.m.NewEncoder(w) }
   181  func (c *CustomMarshaler) ContentType() string                         { return c.m.ContentType() }
   182  func (c *CustomMarshaler) ContentTypeFromMessage(v interface{}) string { return "Custom-Content-Type" }
   183  
   184  func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
   185  	type msg struct {
   186  		pb  proto.Message
   187  		err error
   188  	}
   189  	tests := []struct {
   190  		name       string
   191  		msgs       []msg
   192  		statusCode int
   193  	}{{
   194  		name: "encoding",
   195  		msgs: []msg{
   196  			{&pb.SimpleMessage{Id: "One"}, nil},
   197  			{&pb.SimpleMessage{Id: "Two"}, nil},
   198  		},
   199  		statusCode: http.StatusOK,
   200  	}, {
   201  		name:       "empty",
   202  		statusCode: http.StatusOK,
   203  	}, {
   204  		name:       "error",
   205  		msgs:       []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
   206  		statusCode: http.StatusBadRequest,
   207  	}, {
   208  		name: "stream_error",
   209  		msgs: []msg{
   210  			{&pb.SimpleMessage{Id: "One"}, nil},
   211  			{nil, status.Errorf(codes.OutOfRange, "400")},
   212  		},
   213  		statusCode: http.StatusOK,
   214  	}}
   215  
   216  	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
   217  		var count int
   218  		return func() (proto.Message, error) {
   219  			if count == len(msgs) {
   220  				return nil, io.EOF
   221  			} else if count > len(msgs) {
   222  				t.Errorf("recv() called %d times for %d messages", count, len(msgs))
   223  			}
   224  			count++
   225  			msg := msgs[count-1]
   226  			return msg.pb, msg.err
   227  		}
   228  	}
   229  	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
   230  	marshaler := &CustomMarshaler{&runtime.JSONPb{}}
   231  	for _, tt := range tests {
   232  		t.Run(tt.name, func(t *testing.T) {
   233  			recv := newTestRecv(t, tt.msgs)
   234  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   235  			resp := httptest.NewRecorder()
   236  
   237  			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
   238  
   239  			w := resp.Result()
   240  			if w.StatusCode != tt.statusCode {
   241  				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
   242  			}
   243  			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
   244  				t.Errorf("ForwardResponseStream missing header chunked")
   245  			}
   246  			body, err := ioutil.ReadAll(w.Body)
   247  			if err != nil {
   248  				t.Errorf("Failed to read response body with %v", err)
   249  			}
   250  			w.Body.Close()
   251  
   252  			var want []byte
   253  			for _, msg := range tt.msgs {
   254  				if msg.err != nil {
   255  					t.Skip("checking erorr encodings")
   256  				}
   257  				b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
   258  				if err != nil {
   259  					t.Errorf("marshaler.Marshal() failed %v", err)
   260  				}
   261  				want = append(want, b...)
   262  				want = append(want, "\n"...)
   263  			}
   264  
   265  			if string(body) != string(want) {
   266  				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
   267  			}
   268  		})
   269  	}
   270  }
   271  
   272  func TestForwardResponseMessage(t *testing.T) {
   273  	msg := &pb.SimpleMessage{Id: "One"}
   274  	tests := []struct {
   275  		name        string
   276  		marshaler   runtime.Marshaler
   277  		contentType string
   278  	}{{
   279  		name:        "standard marshaler",
   280  		marshaler:   &runtime.JSONPb{},
   281  		contentType: "application/json",
   282  	}, {
   283  		name:        "httpbody marshaler",
   284  		marshaler:   &runtime.HTTPBodyMarshaler{&runtime.JSONPb{}},
   285  		contentType: "application/json",
   286  	}, {
   287  		name:        "custom marshaler",
   288  		marshaler:   &CustomMarshaler{&runtime.JSONPb{}},
   289  		contentType: "Custom-Content-Type",
   290  	}}
   291  
   292  	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
   293  	for _, tt := range tests {
   294  		t.Run(tt.name, func(t *testing.T) {
   295  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   296  			resp := httptest.NewRecorder()
   297  
   298  			runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(), tt.marshaler, resp, req, msg)
   299  
   300  			w := resp.Result()
   301  			if w.StatusCode != http.StatusOK {
   302  				t.Errorf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
   303  			}
   304  			if h := w.Header.Get("Content-Type"); h != tt.contentType {
   305  				t.Errorf("Content-Type %v want %v", h, tt.contentType)
   306  			}
   307  			body, err := ioutil.ReadAll(w.Body)
   308  			if err != nil {
   309  				t.Errorf("Failed to read response body with %v", err)
   310  			}
   311  			w.Body.Close()
   312  
   313  			want, err := tt.marshaler.Marshal(msg)
   314  			if err != nil {
   315  				t.Errorf("marshaler.Marshal() failed %v", err)
   316  			}
   317  
   318  			if string(body) != string(want) {
   319  				t.Errorf("ForwardResponseMessage() = \"%s\" want \"%s\"", body, want)
   320  			}
   321  		})
   322  	}
   323  }
   324  

View as plain text