1 package runtime_test
2
3 import (
4 "context"
5 "io"
6 "io/ioutil"
7 "net/http"
8 "net/http/httptest"
9 "testing"
10
11 "github.com/golang/protobuf/proto"
12 "github.com/grpc-ecosystem/grpc-gateway/internal"
13 "github.com/grpc-ecosystem/grpc-gateway/runtime"
14 pb "github.com/grpc-ecosystem/grpc-gateway/runtime/internal/examplepb"
15 "google.golang.org/grpc/codes"
16 "google.golang.org/grpc/status"
17 )
18
19 type fakeReponseBodyWrapper struct {
20 proto.Message
21 }
22
23
24 func (r fakeReponseBodyWrapper) XXX_ResponseBody() interface{} {
25 resp := r.Message.(*pb.SimpleMessage)
26 return resp.Id
27 }
28
// TestForwardResponseStream checks the chunked output of ForwardResponseStream
// with the JSONPb marshaler. Each received message is expected as a
// {"result": ...} object followed by marshaler.Delimiter(). An error returned
// after at least one message is expected as a trailing {"error": ...} object
// built from the gRPC status. Cases with responseBody set wrap their messages
// in fakeReponseBodyWrapper and expect the XXX_ResponseBody() value under
// "result".
func TestForwardResponseStream(t *testing.T) {
	type msg struct {
		pb  proto.Message
		err error
	}
	tests := []struct {
		name         string
		msgs         []msg
		statusCode   int
		responseBody bool
	}{{
		name: "encoding",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{&pb.SimpleMessage{Id: "Two"}, nil},
		},
		statusCode: http.StatusOK,
	}, {
		name:       "empty",
		statusCode: http.StatusOK,
	}, {
		name:       "error",
		msgs:       []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
		statusCode: http.StatusBadRequest,
	}, {
		name: "stream_error",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{nil, status.Errorf(codes.OutOfRange, "400")},
		},
		statusCode: http.StatusOK,
	}, {
		name: "response body stream case",
		msgs: []msg{
			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil},
			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "Two"}}, nil},
		},
		responseBody: true,
		statusCode:   http.StatusOK,
	}, {
		name: "response body stream error case",
		msgs: []msg{
			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil},
			{nil, status.Errorf(codes.OutOfRange, "400")},
		},
		responseBody: true,
		statusCode:   http.StatusOK,
	}}

	// newTestRecv returns a recv function that yields msgs in order and then
	// io.EOF. Any call made after io.EOF has already been returned is reported
	// as a test error. (The previous check compared count before incrementing
	// it, so it could never fire.)
	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
		var count int // number of recv calls made so far, including this one
		return func() (proto.Message, error) {
			count++
			if count > len(msgs)+1 {
				t.Errorf("recv() called %d times for %d messages", count, len(msgs))
			}
			if count > len(msgs) {
				return nil, io.EOF
			}
			msg := msgs[count-1]
			return msg.pb, msg.err
		}
	}
	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
	marshaler := &runtime.JSONPb{}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			recv := newTestRecv(t, tt.msgs)
			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			resp := httptest.NewRecorder()

			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)

			w := resp.Result()
			if w.StatusCode != tt.statusCode {
				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
			}
			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
				t.Errorf("ForwardResponseStream missing header chunked")
			}
			body, err := ioutil.ReadAll(w.Body)
			w.Body.Close()
			if err != nil {
				t.Fatalf("Failed to read response body with %v", err)
			}

			var want []byte
			for i, msg := range tt.msgs {
				if msg.err != nil {
					if i == 0 {
						// An error before any message has been streamed is
						// only checked via the status code above.
						t.Skip("checking error encodings")
					}
					// Everything before the error chunk must be exactly the
					// messages streamed so far. Checking the length first also
					// keeps the slicing below from panicking on a short body.
					if len(body) < len(want) || string(body[:len(want)]) != string(want) {
						t.Fatalf("ForwardResponseStream() = \"%s\" want prefix \"%s\"", body, want)
					}
					st, _ := status.FromError(msg.err)
					httpCode := runtime.HTTPStatusFromCode(st.Code())
					b, err := marshaler.Marshal(map[string]proto.Message{
						"error": &internal.StreamError{
							GrpcCode:   int32(st.Code()),
							HttpCode:   int32(httpCode),
							Message:    st.Message(),
							HttpStatus: http.StatusText(httpCode),
							Details:    st.Proto().GetDetails(),
						},
					})
					if err != nil {
						t.Fatalf("marshaler.Marshal() failed %v", err)
					}
					errBytes := body[len(want):]
					if string(errBytes) != string(b) {
						t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b)
					}
					return
				}

				// By default the message itself is expected under "result".
				// Response-body cases expect the XXX_ResponseBody() value instead.
				var payload interface{} = msg.pb
				if tt.responseBody {
					rb, ok := msg.pb.(fakeReponseBodyWrapper)
					if !ok {
						// Fatal: continuing would panic inside XXX_ResponseBody.
						t.Fatalf("stream responseBody failed: message is %T, want fakeReponseBodyWrapper", msg.pb)
					}
					payload = rb.XXX_ResponseBody()
				}
				b, err := marshaler.Marshal(map[string]interface{}{"result": payload})
				if err != nil {
					t.Fatalf("marshaler.Marshal() failed %v", err)
				}
				want = append(want, b...)
				want = append(want, marshaler.Delimiter()...)
			}

			if string(body) != string(want) {
				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
			}
		})
	}
}
171
172
173 type CustomMarshaler struct {
174 m *runtime.JSONPb
175 }
176
177 func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error) { return c.m.Marshal(v) }
178 func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) }
179 func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c.m.NewDecoder(r) }
180 func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) }
181 func (c *CustomMarshaler) ContentType() string { return c.m.ContentType() }
182 func (c *CustomMarshaler) ContentTypeFromMessage(v interface{}) string { return "Custom-Content-Type" }
183
// TestForwardResponseStreamCustomMarshaler checks ForwardResponseStream with
// CustomMarshaler. CustomMarshaler has no Delimiter method, and this test
// expects each {"result": ...} message to be followed by "\n". An error
// returned after at least one message is expected as a trailing
// {"error": ...} object, checked the same way as in TestForwardResponseStream.
func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
	type msg struct {
		pb  proto.Message
		err error
	}
	tests := []struct {
		name       string
		msgs       []msg
		statusCode int
	}{{
		name: "encoding",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{&pb.SimpleMessage{Id: "Two"}, nil},
		},
		statusCode: http.StatusOK,
	}, {
		name:       "empty",
		statusCode: http.StatusOK,
	}, {
		name:       "error",
		msgs:       []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
		statusCode: http.StatusBadRequest,
	}, {
		name: "stream_error",
		msgs: []msg{
			{&pb.SimpleMessage{Id: "One"}, nil},
			{nil, status.Errorf(codes.OutOfRange, "400")},
		},
		statusCode: http.StatusOK,
	}}

	// newTestRecv returns a recv function that yields msgs in order and then
	// io.EOF. Any call made after io.EOF has already been returned is reported
	// as a test error. (The previous check compared count before incrementing
	// it, so it could never fire.)
	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
		var count int // number of recv calls made so far, including this one
		return func() (proto.Message, error) {
			count++
			if count > len(msgs)+1 {
				t.Errorf("recv() called %d times for %d messages", count, len(msgs))
			}
			if count > len(msgs) {
				return nil, io.EOF
			}
			msg := msgs[count-1]
			return msg.pb, msg.err
		}
	}
	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
	marshaler := &CustomMarshaler{&runtime.JSONPb{}}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			recv := newTestRecv(t, tt.msgs)
			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			resp := httptest.NewRecorder()

			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)

			w := resp.Result()
			if w.StatusCode != tt.statusCode {
				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
			}
			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
				t.Errorf("ForwardResponseStream missing header chunked")
			}
			body, err := ioutil.ReadAll(w.Body)
			w.Body.Close()
			if err != nil {
				t.Fatalf("Failed to read response body with %v", err)
			}

			var want []byte
			for i, msg := range tt.msgs {
				if msg.err != nil {
					if i == 0 {
						// An error before any message has been streamed is
						// only checked via the status code above.
						t.Skip("checking error encodings")
					}
					// The messages streamed before the error must appear
					// verbatim ahead of the error chunk.
					if len(body) < len(want) || string(body[:len(want)]) != string(want) {
						t.Fatalf("ForwardResponseStream() = \"%s\" want prefix \"%s\"", body, want)
					}
					st, _ := status.FromError(msg.err)
					httpCode := runtime.HTTPStatusFromCode(st.Code())
					b, err := marshaler.Marshal(map[string]proto.Message{
						"error": &internal.StreamError{
							GrpcCode:   int32(st.Code()),
							HttpCode:   int32(httpCode),
							Message:    st.Message(),
							HttpStatus: http.StatusText(httpCode),
							Details:    st.Proto().GetDetails(),
						},
					})
					if err != nil {
						t.Fatalf("marshaler.Marshal() failed %v", err)
					}
					if errBytes := body[len(want):]; string(errBytes) != string(b) {
						t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b)
					}
					return
				}
				b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
				if err != nil {
					t.Fatalf("marshaler.Marshal() failed %v", err)
				}
				want = append(want, b...)
				// CustomMarshaler defines no Delimiter method; this test
				// expects a newline between streamed messages.
				want = append(want, "\n"...)
			}

			if string(body) != string(want) {
				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
			}
		})
	}
}
271
// TestForwardResponseMessage checks that ForwardResponseMessage writes a 200
// response whose body is tt.marshaler.Marshal(msg) and whose Content-Type
// matches what each marshaler is expected to report.
func TestForwardResponseMessage(t *testing.T) {
	msg := &pb.SimpleMessage{Id: "One"}
	tests := []struct {
		name        string
		marshaler   runtime.Marshaler
		contentType string
	}{{
		name:        "standard marshaler",
		marshaler:   &runtime.JSONPb{},
		contentType: "application/json",
	}, {
		name: "httpbody marshaler",
		// Keyed field: go vet's composites check rejects unkeyed literals
		// of struct types declared in another package.
		marshaler:   &runtime.HTTPBodyMarshaler{Marshaler: &runtime.JSONPb{}},
		contentType: "application/json",
	}, {
		name:        "custom marshaler",
		marshaler:   &CustomMarshaler{&runtime.JSONPb{}},
		contentType: "Custom-Content-Type",
	}}

	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			resp := httptest.NewRecorder()

			runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(), tt.marshaler, resp, req, msg)

			w := resp.Result()
			if w.StatusCode != http.StatusOK {
				t.Errorf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
			}
			if h := w.Header.Get("Content-Type"); h != tt.contentType {
				t.Errorf("Content-Type %v want %v", h, tt.contentType)
			}
			body, err := ioutil.ReadAll(w.Body)
			w.Body.Close()
			if err != nil {
				t.Fatalf("Failed to read response body with %v", err)
			}

			want, err := tt.marshaler.Marshal(msg)
			if err != nil {
				t.Fatalf("marshaler.Marshal() failed %v", err)
			}

			if string(body) != string(want) {
				t.Errorf("ForwardResponseMessage() = \"%s\" want \"%s\"", body, want)
			}
		})
	}
}
324
View as plain text