1 package runtime
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "net/http"
9 "net/textproto"
10
11 "github.com/golang/protobuf/proto"
12 "github.com/grpc-ecosystem/grpc-gateway/internal"
13 "google.golang.org/grpc/grpclog"
14 )
15
var (
	// errEmptyResponse is converted into an error chunk when a stream's recv
	// function yields a nil message together with a nil error.
	errEmptyResponse = errors.New("empty response")
)
17
18
19 func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
20 f, ok := w.(http.Flusher)
21 if !ok {
22 grpclog.Infof("Flush not supported in %T", w)
23 http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
24 return
25 }
26
27 md, ok := ServerMetadataFromContext(ctx)
28 if !ok {
29 grpclog.Infof("Failed to extract ServerMetadata from context")
30 http.Error(w, "unexpected error", http.StatusInternalServerError)
31 return
32 }
33 handleForwardResponseServerMetadata(w, mux, md)
34
35 w.Header().Set("Transfer-Encoding", "chunked")
36 w.Header().Set("Content-Type", marshaler.ContentType())
37 if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
38 HTTPError(ctx, mux, marshaler, w, req, err)
39 return
40 }
41
42 var delimiter []byte
43 if d, ok := marshaler.(Delimited); ok {
44 delimiter = d.Delimiter()
45 } else {
46 delimiter = []byte("\n")
47 }
48
49 var wroteHeader bool
50 for {
51 resp, err := recv()
52 if err == io.EOF {
53 return
54 }
55 if err != nil {
56 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
57 return
58 }
59 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
60 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
61 return
62 }
63
64 var buf []byte
65 switch {
66 case resp == nil:
67 buf, err = marshaler.Marshal(errorChunk(streamError(ctx, mux.streamErrorHandler, errEmptyResponse)))
68 default:
69 result := map[string]interface{}{"result": resp}
70 if rb, ok := resp.(responseBody); ok {
71 result["result"] = rb.XXX_ResponseBody()
72 }
73
74 buf, err = marshaler.Marshal(result)
75 }
76
77 if err != nil {
78 grpclog.Infof("Failed to marshal response chunk: %v", err)
79 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
80 return
81 }
82 if _, err = w.Write(buf); err != nil {
83 grpclog.Infof("Failed to send response chunk: %v", err)
84 return
85 }
86 wroteHeader = true
87 if _, err = w.Write(delimiter); err != nil {
88 grpclog.Infof("Failed to send delimiter chunk: %v", err)
89 return
90 }
91 f.Flush()
92 }
93 }
94
95 func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
96 for k, vs := range md.HeaderMD {
97 if h, ok := mux.outgoingHeaderMatcher(k); ok {
98 for _, v := range vs {
99 w.Header().Add(h, v)
100 }
101 }
102 }
103 }
104
105 func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
106 for k := range md.TrailerMD {
107 tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
108 w.Header().Add("Trailer", tKey)
109 }
110 }
111
112 func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
113 for k, vs := range md.TrailerMD {
114 tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
115 for _, v := range vs {
116 w.Header().Add(tKey, v)
117 }
118 }
119 }
120
121
122
// responseBody is implemented by messages that supply an alternative value to
// marshal as the HTTP response body in place of the message itself.
type responseBody interface{ XXX_ResponseBody() interface{} }
126
127
128 func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
129 md, ok := ServerMetadataFromContext(ctx)
130 if !ok {
131 grpclog.Infof("Failed to extract ServerMetadata from context")
132 }
133
134 handleForwardResponseServerMetadata(w, mux, md)
135 handleForwardResponseTrailerHeader(w, md)
136
137 contentType := marshaler.ContentType()
138
139
140
141 if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok {
142 contentType = typeMarshaler.ContentTypeFromMessage(resp)
143 }
144 w.Header().Set("Content-Type", contentType)
145
146 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
147 HTTPError(ctx, mux, marshaler, w, req, err)
148 return
149 }
150 var buf []byte
151 var err error
152 if rb, ok := resp.(responseBody); ok {
153 buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
154 } else {
155 buf, err = marshaler.Marshal(resp)
156 }
157 if err != nil {
158 grpclog.Infof("Marshal error: %v", err)
159 HTTPError(ctx, mux, marshaler, w, req, err)
160 return
161 }
162
163 if _, err = w.Write(buf); err != nil {
164 grpclog.Infof("Failed to write response: %v", err)
165 }
166
167 handleForwardResponseTrailer(w, md)
168 }
169
170 func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
171 if len(opts) == 0 {
172 return nil
173 }
174 for _, opt := range opts {
175 if err := opt(ctx, w, resp); err != nil {
176 grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
177 return err
178 }
179 }
180 return nil
181 }
182
183 func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
184 serr := streamError(ctx, mux.streamErrorHandler, err)
185 if !wroteHeader {
186 w.WriteHeader(int(serr.HttpCode))
187 }
188 buf, merr := marshaler.Marshal(errorChunk(serr))
189 if merr != nil {
190 grpclog.Infof("Failed to marshal an error: %v", merr)
191 return
192 }
193 if _, werr := w.Write(buf); werr != nil {
194 grpclog.Infof("Failed to notify error to client: %v", werr)
195 return
196 }
197 }
198
199
200
201 func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
202 serr := errHandler(ctx, err)
203 if serr != nil {
204 return serr
205 }
206
207 return DefaultHTTPStreamErrorHandler(ctx, err)
208 }
209
210 func errorChunk(err *StreamError) map[string]proto.Message {
211 return map[string]proto.Message{"error": (*internal.StreamError)(err)}
212 }
213
View as plain text